/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.scheduler.adaptivebatch;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.flink.runtime.executiongraph.ExecutionVertexInputInfo;
import org.apache.flink.runtime.executiongraph.IndexRange;
import org.apache.flink.runtime.executiongraph.IndexRangeUtil;
import org.apache.flink.runtime.executiongraph.JobVertexInputInfo;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.scheduler.adaptivebatch.AllToAllBlockingResultInfo;
import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingInputInfo;
import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingResultInfo;
import org.apache.flink.runtime.scheduler.adaptivebatch.PointwiseBlockingResultInfo;
import org.assertj.core.api.Assertions;

/**
 * Test helpers for building {@link BlockingInputInfo}s and for checking the {@link
 * JobVertexInputInfo}s that are computed from them.
 */
public class VertexInputInfoComputerTestUtil {

    /**
     * Creates {@code numInputInfos} blocking inputs that all share the same data layout.
     *
     * <p>Each input has {@code numPartitions} partitions with {@code numSubpartitions} subpartitions
     * each. A subpartition holds {@code defaultSize} bytes. It holds {@code defaultSize *
     * skewedFactor} bytes (truncated to {@code long}) instead when its partition index is in {@code
     * skewedPartitionIndex} or its subpartition index is in {@code skewedSubpartitionIndex}. Every
     * input gets a freshly generated {@link IntermediateDataSetID}.
     *
     * @param typeNumber input type number assigned to every created input
     * @param numInputInfos number of inputs to create
     * @param numPartitions number of partitions per input
     * @param numSubpartitions number of subpartitions per partition
     * @param existIntraInputKeyCorrelation intra-input key correlation flag passed to each input
     * @param existInterInputsKeyCorrelation inter-inputs key correlation flag passed to each input
     * @param defaultSize byte count of a non-skewed subpartition
     * @param skewedFactor multiplier applied to {@code defaultSize} for skewed subpartitions
     * @param skewedPartitionIndex partition indices whose subpartitions are all skewed
     * @param skewedSubpartitionIndex subpartition indices that are skewed in every partition
     * @param isPointwise {@code true} to create {@link PointwiseBlockingResultInfo}s, {@code false}
     *     to create {@link AllToAllBlockingResultInfo}s
     * @return the created inputs, in creation order
     */
    public static List<BlockingInputInfo> createBlockingInputInfos(
            int typeNumber,
            int numInputInfos,
            int numPartitions,
            int numSubpartitions,
            boolean existIntraInputKeyCorrelation,
            boolean existInterInputsKeyCorrelation,
            int defaultSize,
            double skewedFactor,
            List<Integer> skewedPartitionIndex,
            List<Integer> skewedSubpartitionIndex,
            boolean isPointwise) {
        long skewedSize = (long) (defaultSize * skewedFactor);
        List<BlockingInputInfo> blockingInputInfos = new ArrayList<>();
        for (int i = 0; i < numInputInfos; ++i) {
            Map<Integer, long[]> subpartitionBytesByPartitionIndex = new HashMap<>();
            for (int j = 0; j < numPartitions; ++j) {
                boolean isSkewedPartition = skewedPartitionIndex.contains(j);
                long[] subpartitionBytes = new long[numSubpartitions];
                for (int k = 0; k < numSubpartitions; ++k) {
                    subpartitionBytes[k] =
                            isSkewedPartition || skewedSubpartitionIndex.contains(k)
                                    ? skewedSize
                                    : defaultSize;
                }
                subpartitionBytesByPartitionIndex.put(j, subpartitionBytes);
            }
            BlockingResultInfo resultInfo =
                    isPointwise
                            ? new PointwiseBlockingResultInfo(
                                    new IntermediateDataSetID(),
                                    numPartitions,
                                    numSubpartitions,
                                    subpartitionBytesByPartitionIndex)
                            // NOTE(review): the `false` argument is presumably the broadcast
                            // flag — confirm against the AllToAllBlockingResultInfo constructor.
                            : new AllToAllBlockingResultInfo(
                                    new IntermediateDataSetID(),
                                    numPartitions,
                                    numSubpartitions,
                                    false,
                                    subpartitionBytesByPartitionIndex);
            blockingInputInfos.add(
                    new BlockingInputInfo(
                            resultInfo,
                            typeNumber,
                            existInterInputsKeyCorrelation,
                            existIntraInputKeyCorrelation));
        }
        return blockingInputInfos;
    }

    /** Asserts that every vertex input info has exactly {@code targetParallelism} subtasks. */
    private static void checkParallelism(
            int targetParallelism, Map<IntermediateDataSetID, JobVertexInputInfo> vertexInputInfoMap) {
        for (JobVertexInputInfo info : vertexInputInfoMap.values()) {
            Assertions.assertThat(info.getExecutionVertexInputInfos()).hasSize(targetParallelism);
        }
    }

    /**
     * Asserts that all {@code inputInfos} share one vertex input info and that subtask {@code i}
     * consumes exactly {@code targetConsumedSubpartitionGroups.get(i)}.
     *
     * @param targetConsumedSubpartitionGroups expected partition-range to subpartition-range
     *     mapping, one entry per subtask
     * @param inputInfos inputs whose vertex input infos are checked
     * @param vertexInputInfoMap computed vertex input infos keyed by result id
     */
    public static void checkConsumedSubpartitionGroups(
            List<Map<IndexRange, IndexRange>> targetConsumedSubpartitionGroups,
            List<BlockingInputInfo> inputInfos,
            Map<IntermediateDataSetID, JobVertexInputInfo> vertexInputInfoMap) {
        List<ExecutionVertexInputInfo> executionVertexInputInfos =
                checkAndGetJobVertexInputInfo(inputInfos, vertexInputInfoMap)
                        .getExecutionVertexInputInfos();
        // Without this check, surplus expected groups would be silently ignored and missing ones
        // would surface as an IndexOutOfBoundsException instead of an assertion failure.
        Assertions.assertThat(executionVertexInputInfos)
                .hasSize(targetConsumedSubpartitionGroups.size());
        for (int i = 0; i < executionVertexInputInfos.size(); ++i) {
            Assertions.assertThat(executionVertexInputInfos.get(i).getConsumedSubpartitionGroups())
                    .as("consumed subpartition groups of subtask %d", i)
                    .isEqualTo(targetConsumedSubpartitionGroups.get(i));
        }
    }

    /**
     * Asserts that subtask {@code i} consumes {@code targetConsumedDataVolume[i]} bytes in total,
     * summed over all {@code inputInfos}.
     *
     * @param targetConsumedDataVolume expected byte count per subtask
     * @param inputInfos inputs whose consumed bytes are summed
     * @param vertexInputs computed vertex input infos keyed by result id
     */
    public static void checkConsumedDataVolumePerSubtask(
            long[] targetConsumedDataVolume,
            List<BlockingInputInfo> inputInfos,
            Map<IntermediateDataSetID, JobVertexInputInfo> vertexInputs) {
        long[] consumedDataVolume = new long[targetConsumedDataVolume.length];
        for (BlockingInputInfo inputInfo : inputInfos) {
            List<ExecutionVertexInputInfo> executionVertexInputInfos =
                    getVertexInputInfo(vertexInputs, inputInfo).getExecutionVertexInputInfos();
            // Turns an ArrayIndexOutOfBoundsException into a readable assertion failure.
            Assertions.assertThat(executionVertexInputInfos).hasSize(targetConsumedDataVolume.length);
            for (int i = 0; i < executionVertexInputInfos.size(); ++i) {
                consumedDataVolume[i] +=
                        executionVertexInputInfos.get(i).getConsumedSubpartitionGroups().entrySet()
                                .stream()
                                .mapToLong(
                                        entry ->
                                                inputInfo.getNumBytesProduced(
                                                        entry.getKey(), entry.getValue()))
                                .sum();
            }
        }
        Assertions.assertThat(consumedDataVolume).containsExactly(targetConsumedDataVolume);
    }

    /**
     * Asserts that every input in {@code inputInfos} has a vertex input info and that they are all
     * equal, then returns that shared info.
     */
    private static JobVertexInputInfo checkAndGetJobVertexInputInfo(
            List<BlockingInputInfo> inputInfos,
            Map<IntermediateDataSetID, JobVertexInputInfo> vertexInputInfoMap) {
        List<JobVertexInputInfo> vertexInputInfos =
                inputInfos.stream()
                        .map(inputInfo -> getVertexInputInfo(vertexInputInfoMap, inputInfo))
                        .collect(Collectors.toList());
        Assertions.assertThat(vertexInputInfos).isNotEmpty();
        JobVertexInputInfo baseVertexInputInfo = vertexInputInfos.get(0);
        for (int i = 1; i < vertexInputInfos.size(); ++i) {
            Assertions.assertThat((Object) vertexInputInfos.get(i)).isEqualTo(baseVertexInputInfo);
        }
        return baseVertexInputInfo;
    }

    /**
     * Asserts that the parallelism is {@code targetParallelism} and that, across all subtasks, every
     * partition of {@code inputInfo} is consumed over the full subpartition range.
     *
     * <p>The full range is taken from partition 0, which assumes every partition has the same
     * number of subpartitions.
     *
     * @param vertexInputInfoMap computed vertex input infos keyed by result id
     * @param inputInfo the (non key-correlated) input to check
     * @param targetParallelism expected number of subtasks
     */
    public static void checkCorrectnessForNonCorrelatedInput(
            Map<IntermediateDataSetID, JobVertexInputInfo> vertexInputInfoMap,
            BlockingInputInfo inputInfo,
            int targetParallelism) {
        checkParallelism(targetParallelism, vertexInputInfoMap);

        // partition index -> every subpartition range consumed from it, over all subtasks
        Map<Integer, List<IndexRange>> consumedPartitionToSubpartitionRanges = new HashMap<>();
        for (ExecutionVertexInputInfo info :
                getVertexInputInfo(vertexInputInfoMap, inputInfo).getExecutionVertexInputInfos()) {
            info.getConsumedSubpartitionGroups()
                    .forEach(
                            (partitionRange, subpartitionRange) -> {
                                for (int i = partitionRange.getStartIndex();
                                        i <= partitionRange.getEndIndex();
                                        ++i) {
                                    consumedPartitionToSubpartitionRanges
                                            .computeIfAbsent(i, k -> new ArrayList<>())
                                            .add(subpartitionRange);
                                }
                            });
        }

        Set<Integer> partitionIndex =
                IntStream.range(0, inputInfo.getNumPartitions())
                        .boxed()
                        .collect(Collectors.toSet());
        IndexRange fullSubpartitionRange = new IndexRange(0, inputInfo.getNumSubpartitions(0) - 1);
        Assertions.assertThat(consumedPartitionToSubpartitionRanges.keySet())
                .isEqualTo(partitionIndex);
        for (List<IndexRange> subpartitionRanges : consumedPartitionToSubpartitionRanges.values()) {
            assertMergedInto(subpartitionRanges, fullSubpartitionRange);
        }
    }

    /**
     * Asserts that the parallelism is {@code targetParallelism} and that key-correlated inputs are
     * consumed consistently.
     *
     * <p>{@code inputInfos} must have type numbers exactly 1 and 2. Every pair (type-1 input,
     * type-2 input) is then checked with {@link #checkCorrectnessForConsumedSubpartitionRanges}.
     *
     * @param vertexInputInfoMap computed vertex input infos keyed by result id
     * @param inputInfos the correlated inputs, with type numbers 1 and 2
     * @param targetParallelism expected number of subtasks
     * @param numSubpartitions number of subpartitions expected to be covered
     */
    public static void checkCorrectnessForCorrelatedInputs(
            Map<IntermediateDataSetID, JobVertexInputInfo> vertexInputInfoMap,
            List<BlockingInputInfo> inputInfos,
            int targetParallelism,
            int numSubpartitions) {
        checkParallelism(targetParallelism, vertexInputInfoMap);
        Map<Integer, List<BlockingInputInfo>> inputInfosGroupByTypeNumber =
                inputInfos.stream()
                        .collect(Collectors.groupingBy(BlockingInputInfo::getInputTypeNumber));
        Assertions.assertThat(inputInfosGroupByTypeNumber.keySet()).containsExactlyInAnyOrder(1, 2);

        // Iterate BlockingInputInfo pairs directly. Keying a map by JobVertexInputInfo would fail
        // with a duplicate-key error whenever two inputs produce equal vertex input infos.
        for (BlockingInputInfo inputInfo1 : inputInfosGroupByTypeNumber.get(1)) {
            for (BlockingInputInfo inputInfo2 : inputInfosGroupByTypeNumber.get(2)) {
                checkCorrectnessForConsumedSubpartitionRanges(
                        getVertexInputInfo(vertexInputInfoMap, inputInfo1),
                        getVertexInputInfo(vertexInputInfoMap, inputInfo2),
                        inputInfo1.getNumPartitions(),
                        inputInfo2.getNumPartitions(),
                        numSubpartitions);
            }
        }
    }

    /**
     * Asserts the pairing between two correlated inputs.
     *
     * <p>Both inputs must have the same number of subtasks. Within each subtask, every subpartition
     * index consumed from input 1 must also be consumed from input 2. Aggregated over all subtasks,
     * every subpartition index in {@code [0, numSubpartitions)} must appear. For each such index,
     * every input-1 partition in {@code [0, numPartitions1)} must be paired with input-2 partition
     * ranges that merge into exactly {@code [0, numPartitions2 - 1]}.
     */
    private static void checkCorrectnessForConsumedSubpartitionRanges(
            JobVertexInputInfo inputInfo1,
            JobVertexInputInfo inputInfo2,
            int numPartitions1,
            int numPartitions2,
            int numSubpartitions) {
        List<ExecutionVertexInputInfo> executionVertexInputInfos1 =
                inputInfo1.getExecutionVertexInputInfos();
        List<ExecutionVertexInputInfo> executionVertexInputInfos2 =
                inputInfo2.getExecutionVertexInputInfos();
        Assertions.assertThat(executionVertexInputInfos1).hasSameSizeAs(executionVertexInputInfos2);

        // subpartition index -> (input-1 partition index -> input-2 partition ranges consumed by
        // the same subtask for that subpartition index)
        Map<Integer, Map<Integer, Set<IndexRange>>> input1ToInput2 = new HashMap<>();
        for (int i = 0; i < executionVertexInputInfos1.size(); ++i) {
            Map<Integer, Set<IndexRange>> subpartitionIndexToPartition1 =
                    getConsumedSubpartitionIndexToPartitionRanges(
                            executionVertexInputInfos1.get(i).getConsumedSubpartitionGroups());
            Map<Integer, Set<IndexRange>> subpartitionIndexToPartition2 =
                    getConsumedSubpartitionIndexToPartitionRanges(
                            executionVertexInputInfos2.get(i).getConsumedSubpartitionGroups());
            subpartitionIndexToPartition1.forEach(
                    (subpartitionIndex, partitionRanges1) -> {
                        Assertions.assertThat(subpartitionIndexToPartition2)
                                .containsKey(subpartitionIndex);
                        Set<IndexRange> partitionRanges2 =
                                subpartitionIndexToPartition2.get(subpartitionIndex);
                        for (IndexRange partitionRange : partitionRanges1) {
                            for (int j = partitionRange.getStartIndex();
                                    j <= partitionRange.getEndIndex();
                                    ++j) {
                                input1ToInput2
                                        .computeIfAbsent(subpartitionIndex, k -> new HashMap<>())
                                        .computeIfAbsent(j, k -> new HashSet<>())
                                        .addAll(partitionRanges2);
                            }
                        }
                    });
        }

        Set<Integer> partitionIndex =
                IntStream.range(0, numPartitions1).boxed().collect(Collectors.toSet());
        Set<Integer> subpartitionIndexSet =
                IntStream.range(0, numSubpartitions).boxed().collect(Collectors.toSet());
        IndexRange fullPartitionRange2 = new IndexRange(0, numPartitions2 - 1);
        Assertions.assertThat(input1ToInput2.keySet()).isEqualTo(subpartitionIndexSet);
        input1ToInput2.forEach(
                (subpartitionIndex, input1ToInput2PartitionRanges) -> {
                    Assertions.assertThat(input1ToInput2PartitionRanges.keySet())
                            .isEqualTo(partitionIndex);
                    for (Set<IndexRange> partitionRanges : input1ToInput2PartitionRanges.values()) {
                        assertMergedInto(partitionRanges, fullPartitionRange2);
                    }
                });
    }

    /**
     * Inverts a partition-range to subpartition-range mapping into a mapping from each individual
     * subpartition index to the set of partition ranges consumed for it.
     */
    private static Map<Integer, Set<IndexRange>> getConsumedSubpartitionIndexToPartitionRanges(
            Map<IndexRange, IndexRange> consumedSubpartitionGroups) {
        Map<Integer, Set<IndexRange>> subpartitionIndexToPartition = new HashMap<>();
        consumedSubpartitionGroups.forEach(
                (partitionRange, subpartitionRange) -> {
                    for (int j = subpartitionRange.getStartIndex();
                            j <= subpartitionRange.getEndIndex();
                            ++j) {
                        subpartitionIndexToPartition
                                .computeIfAbsent(j, key -> new HashSet<>())
                                .add(partitionRange);
                    }
                });
        return subpartitionIndexToPartition;
    }

    /** Asserts that {@code ranges} merge into exactly one range, equal to {@code expectedRange}. */
    private static void assertMergedInto(Collection<IndexRange> ranges, IndexRange expectedRange) {
        Assertions.assertThat(IndexRangeUtil.mergeIndexRanges(ranges)).containsExactly(expectedRange);
    }

    /** Returns the vertex input info computed for {@code inputInfo}, failing if it is missing. */
    private static JobVertexInputInfo getVertexInputInfo(
            Map<IntermediateDataSetID, JobVertexInputInfo> vertexInputInfoMap,
            BlockingInputInfo inputInfo) {
        JobVertexInputInfo vertexInputInfo = vertexInputInfoMap.get(inputInfo.getResultId());
        Assertions.assertThat((Object) vertexInputInfo)
                .as("vertex input info of result %s", inputInfo.getResultId())
                .isNotNull();
        return vertexInputInfo;
    }
}

