package org.apache.flink.runtime.scheduler.adaptivebatch;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.flink.runtime.executiongraph.ExecutionVertexInputInfo;
import org.apache.flink.runtime.executiongraph.IndexRange;
import org.apache.flink.runtime.executiongraph.IndexRangeUtil;
import org.apache.flink.runtime.executiongraph.JobVertexInputInfo;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.assertj.core.api.Assertions;

/* loaded from: input_file:org/apache/flink/runtime/scheduler/adaptivebatch/VertexInputInfoComputerTestUtil.class */
/**
 * Test utilities for building {@link BlockingInputInfo} fixtures and for verifying the {@link
 * JobVertexInputInfo}s computed for them.
 *
 * <p>The verification helpers interpret each entry of {@link
 * ExecutionVertexInputInfo#getConsumedSubpartitionGroups()} as (partition index range -&gt;
 * subpartition index range).
 */
public class VertexInputInfoComputerTestUtil {

    /**
     * Creates {@code numInputInfos} blocking inputs that share the same synthetic data distribution.
     *
     * <p>Every subpartition of every partition produces {@code defaultSize} bytes. Two cases produce
     * {@code (long) (defaultSize * skewedFactor)} bytes instead: any subpartition whose index is in
     * {@code skewedSubpartitionIndex}, and every subpartition of a partition whose index is in
     * {@code skewedPartitionIndex}. Each input gets its own {@link IntermediateDataSetID} and its
     * own byte arrays.
     *
     * @param typeNumber input type number passed to every created {@link BlockingInputInfo}
     * @param numInputInfos number of inputs to create
     * @param numPartitions number of partitions per input
     * @param numSubpartitions number of subpartitions per partition
     * @param existIntraInputKeyCorrelation passed as the 4th {@link BlockingInputInfo} constructor
     *     argument (presumably the intra-input key-correlation flag; confirm against the class)
     * @param existInterInputsKeyCorrelation passed as the 3rd {@link BlockingInputInfo} constructor
     *     argument (presumably the inter-inputs key-correlation flag; confirm against the class)
     * @param defaultSize byte size of a non-skewed subpartition
     * @param skewedFactor multiplier applied to {@code defaultSize} for skewed subpartitions
     * @param skewedPartitionIndex partition indices whose subpartitions are all skewed
     * @param skewedSubpartitionIndex subpartition indices that are skewed in every partition
     * @param isPointwise {@code true} for a {@link PointwiseBlockingResultInfo}, {@code false} for
     *     a non-broadcast {@link AllToAllBlockingResultInfo}
     * @return the created inputs, in creation order
     */
    public static List<BlockingInputInfo> createBlockingInputInfos(
            int typeNumber,
            int numInputInfos,
            int numPartitions,
            int numSubpartitions,
            boolean existIntraInputKeyCorrelation,
            boolean existInterInputsKeyCorrelation,
            int defaultSize,
            double skewedFactor,
            List<Integer> skewedPartitionIndex,
            List<Integer> skewedSubpartitionIndex,
            boolean isPointwise) {
        List<BlockingInputInfo> blockingInputInfos = new ArrayList<>();
        for (int inputIndex = 0; inputIndex < numInputInfos; inputIndex++) {
            // A fresh map per input so result infos never share mutable byte arrays.
            Map<Integer, long[]> subpartitionBytesByPartitionIndex =
                    createSubpartitionBytesByPartitionIndex(
                            numPartitions,
                            numSubpartitions,
                            defaultSize,
                            skewedFactor,
                            skewedPartitionIndex,
                            skewedSubpartitionIndex);
            blockingInputInfos.add(
                    new BlockingInputInfo(
                            isPointwise
                                    ? new PointwiseBlockingResultInfo(
                                            new IntermediateDataSetID(),
                                            numPartitions,
                                            numSubpartitions,
                                            subpartitionBytesByPartitionIndex)
                                    : new AllToAllBlockingResultInfo(
                                            new IntermediateDataSetID(),
                                            numPartitions,
                                            numSubpartitions,
                                            false,
                                            subpartitionBytesByPartitionIndex),
                            typeNumber,
                            existInterInputsKeyCorrelation,
                            existIntraInputKeyCorrelation));
        }
        return blockingInputInfos;
    }

    /** Builds partition index -&gt; per-subpartition byte sizes for one synthetic input. */
    private static Map<Integer, long[]> createSubpartitionBytesByPartitionIndex(
            int numPartitions,
            int numSubpartitions,
            int defaultSize,
            double skewedFactor,
            List<Integer> skewedPartitionIndex,
            List<Integer> skewedSubpartitionIndex) {
        long skewedSize = (long) (defaultSize * skewedFactor);
        Map<Integer, long[]> bytesByPartition = new HashMap<>();
        for (int partition = 0; partition < numPartitions; partition++) {
            boolean partitionSkewed = skewedPartitionIndex.contains(partition);
            long[] subpartitionBytes = new long[numSubpartitions];
            for (int subpartition = 0; subpartition < numSubpartitions; subpartition++) {
                boolean skewed =
                        partitionSkewed || skewedSubpartitionIndex.contains(subpartition);
                subpartitionBytes[subpartition] = skewed ? skewedSize : defaultSize;
            }
            bytesByPartition.put(partition, subpartitionBytes);
        }
        return bytesByPartition;
    }

    /** Asserts that every computed {@link JobVertexInputInfo} has exactly {@code parallelism} subtasks. */
    private static void checkParallelism(
            int parallelism, Map<IntermediateDataSetID, JobVertexInputInfo> vertexInputs) {
        vertexInputs
                .values()
                .forEach(
                        jobVertexInputInfo ->
                                Assertions.assertThat(
                                                jobVertexInputInfo
                                                        .getExecutionVertexInputInfos()
                                                        .size())
                                        .isEqualTo(parallelism));
    }

    /**
     * Asserts that the subtasks of the (shared) {@link JobVertexInputInfo} of {@code inputInfos}
     * consume exactly the expected subpartition groups, subtask by subtask.
     *
     * @param expectedConsumedSubpartitionGroups expected groups, indexed by subtask
     * @param inputInfos inputs that must all map to an equal {@link JobVertexInputInfo}
     * @param vertexInputs computed vertex input infos keyed by result id
     */
    public static void checkConsumedSubpartitionGroups(
            List<Map<IndexRange, IndexRange>> expectedConsumedSubpartitionGroups,
            List<BlockingInputInfo> inputInfos,
            Map<IntermediateDataSetID, JobVertexInputInfo> vertexInputs) {
        List<ExecutionVertexInputInfo> executionVertexInputInfos =
                checkAndGetJobVertexInputInfo(inputInfos, vertexInputs)
                        .getExecutionVertexInputInfos();
        // Without this, surplus expected entries would never be compared.
        Assertions.assertThat(executionVertexInputInfos)
                .hasSize(expectedConsumedSubpartitionGroups.size());
        for (int subtask = 0; subtask < executionVertexInputInfos.size(); subtask++) {
            Assertions.assertThat(
                            executionVertexInputInfos
                                    .get(subtask)
                                    .getConsumedSubpartitionGroups())
                    .isEqualTo(expectedConsumedSubpartitionGroups.get(subtask));
        }
    }

    /**
     * Asserts the total number of bytes each subtask consumes, summed over all {@code inputInfos}.
     *
     * @param expectedDataVolumePerSubtask expected byte count, indexed by subtask
     * @param inputInfos inputs whose consumed bytes are accumulated
     * @param vertexInputs computed vertex input infos keyed by result id
     */
    public static void checkConsumedDataVolumePerSubtask(
            long[] expectedDataVolumePerSubtask,
            List<BlockingInputInfo> inputInfos,
            Map<IntermediateDataSetID, JobVertexInputInfo> vertexInputs) {
        long[] actualDataVolumePerSubtask = new long[expectedDataVolumePerSubtask.length];
        for (BlockingInputInfo inputInfo : inputInfos) {
            List<ExecutionVertexInputInfo> executionVertexInputInfos =
                    vertexInputs.get(inputInfo.getResultId()).getExecutionVertexInputInfos();
            for (int subtask = 0; subtask < executionVertexInputInfos.size(); subtask++) {
                actualDataVolumePerSubtask[subtask] +=
                        executionVertexInputInfos
                                .get(subtask)
                                .getConsumedSubpartitionGroups()
                                .entrySet()
                                .stream()
                                .mapToLong(
                                        group ->
                                                inputInfo.getNumBytesProduced(
                                                        group.getKey(), group.getValue()))
                                .sum();
            }
        }
        Assertions.assertThat(actualDataVolumePerSubtask).isEqualTo(expectedDataVolumePerSubtask);
    }

    /**
     * Looks up the {@link JobVertexInputInfo} of every input. Asserts that every lookup is present
     * and that all of them are equal, then returns the first.
     */
    private static JobVertexInputInfo checkAndGetJobVertexInputInfo(
            List<BlockingInputInfo> inputInfos,
            Map<IntermediateDataSetID, JobVertexInputInfo> vertexInputs) {
        List<JobVertexInputInfo> jobVertexInputInfos =
                inputInfos.stream()
                        .map(inputInfo -> vertexInputs.get(inputInfo.getResultId()))
                        .collect(Collectors.toList());
        // A size check would be vacuous after map(); what can actually go wrong is an empty input
        // list or a missing lookup.
        Assertions.assertThat(jobVertexInputInfos).isNotEmpty().doesNotContainNull();
        JobVertexInputInfo first = jobVertexInputInfos.get(0);
        for (int i = 1; i < jobVertexInputInfos.size(); i++) {
            Assertions.assertThat(jobVertexInputInfos.get(i)).isEqualTo(first);
        }
        return first;
    }

    /**
     * Verifies the computed distribution of an input that has no key correlation. Every partition
     * must be consumed (possibly split across subtasks), and for each partition the consumed
     * subpartition ranges must merge into one full range. The subpartition count of partition 0
     * defines that full range.
     *
     * @param vertexInputs computed vertex input infos keyed by result id
     * @param inputInfo the input to verify
     * @param parallelism expected number of subtasks for every vertex input
     */
    public static void checkCorrectnessForNonCorrelatedInput(
            Map<IntermediateDataSetID, JobVertexInputInfo> vertexInputs,
            BlockingInputInfo inputInfo,
            int parallelism) {
        checkParallelism(parallelism, vertexInputs);
        Map<Integer, List<IndexRange>> subpartitionRangesByPartition = new HashMap<>();
        for (ExecutionVertexInputInfo executionVertexInputInfo :
                vertexInputs.get(inputInfo.getResultId()).getExecutionVertexInputInfos()) {
            executionVertexInputInfo
                    .getConsumedSubpartitionGroups()
                    .forEach(
                            (partitionRange, subpartitionRange) -> {
                                for (int partition = partitionRange.getStartIndex();
                                        partition <= partitionRange.getEndIndex();
                                        partition++) {
                                    subpartitionRangesByPartition
                                            .computeIfAbsent(partition, k -> new ArrayList<>())
                                            .add(subpartitionRange);
                                }
                            });
        }
        Set<Integer> expectedPartitions = indexSet(inputInfo.getNumPartitions());
        IndexRange expectedSubpartitionRange =
                new IndexRange(0, inputInfo.getNumSubpartitions(0) - 1);
        Assertions.assertThat(subpartitionRangesByPartition.keySet()).isEqualTo(expectedPartitions);
        subpartitionRangesByPartition
                .values()
                .forEach(ranges -> assertMergesToSingleRange(ranges, expectedSubpartitionRange));
    }

    /**
     * Verifies the computed distribution of key-correlated inputs. The inputs must use exactly the
     * type numbers 1 and 2. For every pair (type-1 input, type-2 input), each subpartition index and
     * each partition of the first input must be co-consumed with all partitions of the second.
     *
     * @param vertexInputs computed vertex input infos keyed by result id
     * @param inputInfos all inputs; grouped by {@link BlockingInputInfo#getInputTypeNumber()}
     * @param parallelism expected number of subtasks for every vertex input
     * @param numSubpartitions expected number of subpartition indices
     */
    public static void checkCorrectnessForCorrelatedInputs(
            Map<IntermediateDataSetID, JobVertexInputInfo> vertexInputs,
            List<BlockingInputInfo> inputInfos,
            int parallelism,
            int numSubpartitions) {
        checkParallelism(parallelism, vertexInputs);
        Map<Integer, List<BlockingInputInfo>> inputInfosByTypeNumber =
                inputInfos.stream()
                        .collect(Collectors.groupingBy(BlockingInputInfo::getInputTypeNumber));
        // Assert the exact keys (not just the count) so a wrong type number fails clearly
        // instead of causing an NPE below.
        Assertions.assertThat(inputInfosByTypeNumber).containsOnlyKeys(1, 2);
        // Pair the BlockingInputInfos directly rather than keying a map by JobVertexInputInfo:
        // equal JobVertexInputInfos would otherwise be rejected as duplicate map keys.
        for (BlockingInputInfo first : inputInfosByTypeNumber.get(1)) {
            for (BlockingInputInfo second : inputInfosByTypeNumber.get(2)) {
                checkCorrectnessForConsumedSubpartitionRanges(
                        vertexInputs.get(first.getResultId()),
                        vertexInputs.get(second.getResultId()),
                        first.getNumPartitions(),
                        second.getNumPartitions(),
                        numSubpartitions);
            }
        }
    }

    /**
     * Checks one correlated pair. Collects, per subpartition index and per partition of {@code
     * first}, the partition ranges of {@code second} consumed by the same subtasks. Then asserts
     * that every subpartition index and every partition of {@code first} is covered, and that the
     * collected ranges merge into the full partition range of {@code second}.
     */
    private static void checkCorrectnessForConsumedSubpartitionRanges(
            JobVertexInputInfo first,
            JobVertexInputInfo second,
            int numPartitionsOfFirst,
            int numPartitionsOfSecond,
            int numSubpartitions) {
        List<ExecutionVertexInputInfo> firstInfos = first.getExecutionVertexInputInfos();
        List<ExecutionVertexInputInfo> secondInfos = second.getExecutionVertexInputInfos();
        Assertions.assertThat(firstInfos.size()).isEqualTo(secondInfos.size());

        // subpartition index -> partition index of `first` -> co-consumed partition ranges of `second`
        Map<Integer, Map<Integer, Set<IndexRange>>> coConsumed = new HashMap<>();
        for (int subtask = 0; subtask < firstInfos.size(); subtask++) {
            Map<Integer, Set<IndexRange>> firstRangesBySubpartition =
                    getConsumedSubpartitionIndexToPartitionRanges(
                            firstInfos.get(subtask).getConsumedSubpartitionGroups());
            Map<Integer, Set<IndexRange>> secondRangesBySubpartition =
                    getConsumedSubpartitionIndexToPartitionRanges(
                            secondInfos.get(subtask).getConsumedSubpartitionGroups());
            for (Map.Entry<Integer, Set<IndexRange>> entry :
                    firstRangesBySubpartition.entrySet()) {
                Integer subpartition = entry.getKey();
                // A subtask reading subpartition s of `first` must also read s of `second`.
                Assertions.assertThat(secondRangesBySubpartition.containsKey(subpartition))
                        .isTrue();
                Set<IndexRange> secondPartitionRanges =
                        secondRangesBySubpartition.get(subpartition);
                for (IndexRange firstPartitionRange : entry.getValue()) {
                    for (int partition = firstPartitionRange.getStartIndex();
                            partition <= firstPartitionRange.getEndIndex();
                            partition++) {
                        coConsumed
                                .computeIfAbsent(subpartition, k -> new HashMap<>())
                                .computeIfAbsent(partition, k -> new HashSet<>())
                                .addAll(secondPartitionRanges);
                    }
                }
            }
        }

        Set<Integer> expectedFirstPartitions = indexSet(numPartitionsOfFirst);
        IndexRange expectedSecondPartitionRange = new IndexRange(0, numPartitionsOfSecond - 1);
        Assertions.assertThat(coConsumed.keySet()).isEqualTo(indexSet(numSubpartitions));
        coConsumed
                .values()
                .forEach(
                        rangesByFirstPartition -> {
                            Assertions.assertThat(rangesByFirstPartition.keySet())
                                    .isEqualTo(expectedFirstPartitions);
                            rangesByFirstPartition
                                    .values()
                                    .forEach(
                                            ranges ->
                                                    assertMergesToSingleRange(
                                                            ranges,
                                                            expectedSecondPartitionRange));
                        });
    }

    /**
     * Inverts a (partition range -&gt; subpartition range) group map into (subpartition index -&gt;
     * set of partition ranges that are consumed for that subpartition).
     */
    private static Map<Integer, Set<IndexRange>> getConsumedSubpartitionIndexToPartitionRanges(
            Map<IndexRange, IndexRange> consumedSubpartitionGroups) {
        Map<Integer, Set<IndexRange>> partitionRangesBySubpartition = new HashMap<>();
        consumedSubpartitionGroups.forEach(
                (partitionRange, subpartitionRange) -> {
                    for (int subpartition = subpartitionRange.getStartIndex();
                            subpartition <= subpartitionRange.getEndIndex();
                            subpartition++) {
                        partitionRangesBySubpartition
                                .computeIfAbsent(subpartition, k -> new HashSet<>())
                                .add(partitionRange);
                    }
                });
        return partitionRangesBySubpartition;
    }

    /** Returns the set {0, 1, ..., size - 1}. */
    private static Set<Integer> indexSet(int size) {
        return IntStream.range(0, size).boxed().collect(Collectors.toSet());
    }

    /** Asserts that {@code ranges} merge into exactly one range, equal to {@code expected}. */
    private static void assertMergesToSingleRange(
            Collection<IndexRange> ranges, IndexRange expected) {
        List<IndexRange> merged = IndexRangeUtil.mergeIndexRanges(ranges);
        Assertions.assertThat(merged.size()).isEqualTo(1);
        Assertions.assertThat(merged.get(0)).isEqualTo(expected);
    }
}
