package org.apache.flink.runtime.scheduler.adaptivebatch.util;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.flink.runtime.executiongraph.IndexRange;
import org.apache.flink.runtime.executiongraph.JobVertexInputInfo;
import org.apache.flink.runtime.executiongraph.VertexInputInfoComputationUtils;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingInputInfo;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/flink/runtime/scheduler/adaptivebatch/util/AllToAllVertexInputInfoComputer.class */
/**
 * Computes the {@link JobVertexInputInfo}s of a job vertex from its {@link BlockingInputInfo}s,
 * trying to spread the consumed data volume evenly across the subtasks.
 *
 * <p>The inputs are handled in two groups:
 *
 * <ul>
 *   <li>Inputs for which {@link BlockingInputInfo#areInterInputsKeysCorrelated()} is {@code true}
 *       are computed together, so that they share a single list of subpartition slice ranges.
 *       Subpartitions that {@code AggregatedBlockingInputInfo} reports as skewed (and splittable)
 *       are split into several slices along partition ranges.
 *   <li>All remaining inputs are computed one by one, using the parallelism decided for the first
 *       group, or the given parallelism when the first group is empty.
 * </ul>
 *
 * <p>Whenever no legal parallelism for an even data distribution can be found, the computation
 * falls back to a parallelism that evenly distributes the number of subpartitions.
 */
public class AllToAllVertexInputInfoComputer {
    private static final Logger LOG =
            LoggerFactory.getLogger(AllToAllVertexInputInfoComputer.class);

    /**
     * Forwarded to {@code AggregatedBlockingInputInfo.createAggregatedBlockingInputInfo};
     * presumably used there to decide which subpartitions are skewed (TODO confirm).
     */
    private final double skewedFactor;

    /**
     * Forwarded to {@code AggregatedBlockingInputInfo.createAggregatedBlockingInputInfo};
     * presumably a default skew threshold in bytes (TODO confirm).
     */
    private final long defaultSkewedThreshold;

    /**
     * Creates a computer with the given skew settings.
     *
     * @param skewedFactor factor forwarded to the aggregated input infos
     * @param defaultSkewedThreshold threshold forwarded to the aggregated input infos
     */
    public AllToAllVertexInputInfoComputer(double skewedFactor, long defaultSkewedThreshold) {
        this.skewedFactor = skewedFactor;
        this.defaultSkewedThreshold = defaultSkewedThreshold;
    }

    /**
     * Computes the input infos of all given inputs of a job vertex.
     *
     * @param jobVertexId the vertex the inputs belong to (used for logging)
     * @param inputInfos all blocking inputs of the vertex
     * @param parallelism parallelism used by the fallbacks; for inputs without inter-input key
     *     correlation it is replaced by the parallelism decided for the correlated inputs, if any
     * @param minParallelism lower parallelism bound for the correlated inputs
     * @param maxParallelism upper parallelism bound for the correlated inputs
     * @param dataVolumePerTask data volume per task for all inputs together; it is apportioned to
     *     each input group via {@code calculateDataVolumePerTaskForInputsGroup}
     * @return the computed input info for every input, keyed by its result id
     */
    public Map<IntermediateDataSetID, JobVertexInputInfo> compute(
            JobVertexID jobVertexId,
            List<BlockingInputInfo> inputInfos,
            int parallelism,
            int minParallelism,
            int maxParallelism,
            long dataVolumePerTask) {
        List<BlockingInputInfo> inputInfosWithoutInterKeysCorrelation = new ArrayList<>();
        List<BlockingInputInfo> inputInfosWithInterKeysCorrelation = new ArrayList<>();
        for (BlockingInputInfo inputInfo : inputInfos) {
            if (inputInfo.areInterInputsKeysCorrelated()) {
                inputInfosWithInterKeysCorrelation.add(inputInfo);
            } else {
                inputInfosWithoutInterKeysCorrelation.add(inputInfo);
            }
        }

        Map<IntermediateDataSetID, JobVertexInputInfo> vertexInputInfos = new HashMap<>();
        int decidedParallelism = parallelism;
        if (!inputInfosWithInterKeysCorrelation.isEmpty()) {
            vertexInputInfos.putAll(
                    computeJobVertexInputInfosForInputsWithInterKeysCorrelation(
                            jobVertexId,
                            inputInfosWithInterKeysCorrelation,
                            parallelism,
                            minParallelism,
                            maxParallelism,
                            VertexParallelismAndInputInfosDeciderUtils
                                    .calculateDataVolumePerTaskForInputsGroup(
                                            dataVolumePerTask,
                                            inputInfosWithInterKeysCorrelation,
                                            inputInfos)));
            // Reuse the parallelism decided for the correlated inputs, so that all inputs of
            // the vertex are computed with the same parallelism.
            decidedParallelism =
                    VertexParallelismAndInputInfosDeciderUtils.checkAndGetParallelism(
                            vertexInputInfos.values());
        }

        if (!inputInfosWithoutInterKeysCorrelation.isEmpty()) {
            vertexInputInfos.putAll(
                    computeJobVertexInputInfosForInputsWithoutInterKeysCorrelation(
                            inputInfosWithoutInterKeysCorrelation,
                            decidedParallelism,
                            VertexParallelismAndInputInfosDeciderUtils
                                    .calculateDataVolumePerTaskForInputsGroup(
                                            dataVolumePerTask,
                                            inputInfosWithoutInterKeysCorrelation,
                                            inputInfos)));
        }

        return vertexInputInfos;
    }

    /**
     * Computes the input infos for inputs whose keys are correlated with each other. All
     * non-broadcast inputs are assigned the same list of subpartition slice ranges.
     *
     * <p>Falls back to {@code VertexInputInfoComputationUtils.computeVertexInputInfos} with the
     * given parallelism in two cases: when every input is broadcast, or when no legal slice range
     * can be computed.
     */
    private Map<IntermediateDataSetID, JobVertexInputInfo>
            computeJobVertexInputInfosForInputsWithInterKeysCorrelation(
                    JobVertexID jobVertexId,
                    List<BlockingInputInfo> inputInfos,
                    int parallelism,
                    int minParallelism,
                    int maxParallelism,
                    long dataVolumePerTask) {
        List<BlockingInputInfo> nonBroadcastInputInfos =
                VertexParallelismAndInputInfosDeciderUtils.getNonBroadcastInputInfos(inputInfos);
        if (nonBroadcastInputInfos.isEmpty()) {
            LOG.info(
                    "All inputs are broadcast for vertex {}, fallback to compute a parallelism that can evenly distribute num subpartitions.",
                    jobVertexId);
            return VertexInputInfoComputationUtils.computeVertexInputInfos(
                    parallelism, inputInfos, true);
        }

        Map<Integer, List<SubpartitionSlice>> subpartitionSlicesByTypeNumber =
                createSubpartitionSlicesForInputsWithInterKeysCorrelation(
                        nonBroadcastInputInfos, dataVolumePerTask);
        Optional<List<IndexRange>> optionalSliceRanges =
                VertexParallelismAndInputInfosDeciderUtils.tryComputeSubpartitionSliceRange(
                        minParallelism,
                        maxParallelism,
                        dataVolumePerTask,
                        subpartitionSlicesByTypeNumber);
        if (optionalSliceRanges.isEmpty()) {
            LOG.info(
                    "Cannot find a legal parallelism to evenly distribute data amount for job vertex {}, fallback to compute a parallelism that can evenly distribute num subpartitions.",
                    jobVertexId);
            return VertexInputInfoComputationUtils.computeVertexInputInfos(
                    parallelism, inputInfos, true);
        }

        List<IndexRange> sliceRanges = optionalSliceRanges.get();
        Preconditions.checkState(
                VertexParallelismAndInputInfosDeciderUtils.isLegalParallelism(
                        sliceRanges.size(), minParallelism, maxParallelism),
                "Computed parallelism %s is not legal for min parallelism %s and max parallelism %s.",
                sliceRanges.size(),
                minParallelism,
                maxParallelism);
        return createJobVertexInputInfos(inputInfos, subpartitionSlicesByTypeNumber, sliceRanges);
    }

    /**
     * Builds the subpartition slices of all correlated inputs, keyed by input type number.
     *
     * <p>For each subpartition index, the slices of the different input types are combined via a
     * cartesian product. Each combination appends exactly one slice to every type's list, so the
     * lists of all types stay aligned by position.
     */
    private Map<Integer, List<SubpartitionSlice>>
            createSubpartitionSlicesForInputsWithInterKeysCorrelation(
                    List<BlockingInputInfo> inputInfos, long dataVolumePerTask) {
        Map<Integer, AggregatedBlockingInputInfo> aggregatedInputInfoByTypeNumber =
                createAggregatedBlockingInputInfos(inputInfos, dataVolumePerTask);
        int subpartitionNum =
                VertexParallelismAndInputInfosDeciderUtils
                        .checkAndGetSubpartitionNumForAggregatedInputs(
                                aggregatedInputInfoByTypeNumber.values());

        Map<Integer, List<SubpartitionSlice>> slicesByTypeNumber = new HashMap<>();
        for (int subpartitionIndex = 0; subpartitionIndex < subpartitionNum; subpartitionIndex++) {
            Map<Integer, List<SubpartitionSlice>> balancedSlices =
                    createBalancedSubpartitionSlicesForInputsWithInterKeysCorrelation(
                            subpartitionIndex, aggregatedInputInfoByTypeNumber);

            // Collect keys and values in a single pass so that sliceLists.get(j) belongs to
            // typeNumbers.get(j). cartesianProduct is assumed to keep the position of its input
            // lists in every combination.
            List<Integer> typeNumbers = new ArrayList<>(balancedSlices.size());
            List<List<SubpartitionSlice>> sliceLists = new ArrayList<>(balancedSlices.size());
            for (Map.Entry<Integer, List<SubpartitionSlice>> entry : balancedSlices.entrySet()) {
                typeNumbers.add(entry.getKey());
                sliceLists.add(entry.getValue());
            }

            for (List<SubpartitionSlice> combination :
                    VertexParallelismAndInputInfosDeciderUtils.cartesianProduct(sliceLists)) {
                for (int j = 0; j < combination.size(); j++) {
                    slicesByTypeNumber
                            .computeIfAbsent(typeNumbers.get(j), ignored -> new ArrayList<>())
                            .add(combination.get(j));
                }
            }
        }
        return slicesByTypeNumber;
    }

    /**
     * Groups the inputs by input type number and aggregates each group.
     *
     * @throws IllegalStateException if inputs of the same type number differ in {@link
     *     BlockingInputInfo#isIntraInputKeyCorrelated()}
     */
    private Map<Integer, AggregatedBlockingInputInfo> createAggregatedBlockingInputInfos(
            List<BlockingInputInfo> inputInfos, long dataVolumePerTask) {
        Map<Integer, List<BlockingInputInfo>> inputInfosByTypeNumber =
                inputInfos.stream()
                        .collect(Collectors.groupingBy(BlockingInputInfo::getInputTypeNumber));
        Preconditions.checkState(
                hasSameIntraInputKeyCorrelation(inputInfosByTypeNumber),
                "Inputs with the same input type number must have the same intra-input key correlation.");

        Map<Integer, AggregatedBlockingInputInfo> aggregatedInputInfoByTypeNumber =
                new HashMap<>();
        for (Map.Entry<Integer, List<BlockingInputInfo>> entry :
                inputInfosByTypeNumber.entrySet()) {
            aggregatedInputInfoByTypeNumber.put(
                    entry.getKey(),
                    AggregatedBlockingInputInfo.createAggregatedBlockingInputInfo(
                            defaultSkewedThreshold,
                            skewedFactor,
                            dataVolumePerTask,
                            entry.getValue()));
        }
        return aggregatedInputInfoByTypeNumber;
    }

    /**
     * Creates, for a single subpartition index, the slices of every input type.
     *
     * <p>When an input type is splittable and the subpartition is skewed, the subpartition is
     * split along partition ranges (see {@link #computePartitionRangesEvenlyData}). Otherwise it
     * becomes one slice covering partitions {@code [0, maxPartitionNum - 1]}.
     */
    private static Map<Integer, List<SubpartitionSlice>>
            createBalancedSubpartitionSlicesForInputsWithInterKeysCorrelation(
                    int subpartitionIndex,
                    Map<Integer, AggregatedBlockingInputInfo> aggregatedInputInfoByTypeNumber) {
        Map<Integer, List<SubpartitionSlice>> slicesByTypeNumber = new HashMap<>();
        IndexRange subpartitionRange = new IndexRange(subpartitionIndex, subpartitionIndex);
        for (Map.Entry<Integer, AggregatedBlockingInputInfo> entry :
                aggregatedInputInfoByTypeNumber.entrySet()) {
            AggregatedBlockingInputInfo aggregatedInfo = entry.getValue();
            List<SubpartitionSlice> slices;
            if (aggregatedInfo.isSplittable()
                    && aggregatedInfo.isSkewedSubpartition(subpartitionIndex)) {
                List<IndexRange> partitionRanges =
                        computePartitionRangesEvenlyData(
                                subpartitionIndex,
                                aggregatedInfo.getTargetSize(),
                                aggregatedInfo.getSubpartitionBytesByPartition());
                slices =
                        SubpartitionSlice.createSubpartitionSlicesByMultiPartitionRanges(
                                partitionRanges,
                                subpartitionRange,
                                aggregatedInfo.getSubpartitionBytesByPartition());
            } else {
                slices =
                        Collections.singletonList(
                                SubpartitionSlice.createSubpartitionSlice(
                                        new IndexRange(0, aggregatedInfo.getMaxPartitionNum() - 1),
                                        subpartitionRange,
                                        aggregatedInfo.getAggregatedSubpartitionBytes(
                                                subpartitionIndex)));
            }
            slicesByTypeNumber.put(entry.getKey(), slices);
        }
        return slicesByTypeNumber;
    }

    /**
     * Splits the partitions into consecutive ranges based on the bytes of one subpartition.
     *
     * <p>Partitions are appended to the current range while its accumulated bytes stay below
     * {@code targetSize}. A range therefore exceeds {@code targetSize} only if it holds a single
     * partition. Every partition ends up in exactly one range.
     *
     * <p>Assumes the map keys are the partition indices {@code 0 .. size - 1} (TODO confirm).
     *
     * @param subpartitionIndex index into each partition's byte array
     * @param targetSize byte budget per range
     * @param subpartitionBytesByPartition per-partition subpartition byte counts
     * @return consecutive, non-overlapping partition ranges covering all partitions
     */
    private static List<IndexRange> computePartitionRangesEvenlyData(
            int subpartitionIndex, long targetSize, Map<Integer, long[]> subpartitionBytesByPartition) {
        List<IndexRange> partitionRanges = new ArrayList<>();
        int partitionNum = subpartitionBytesByPartition.size();
        long rangeBytes = 0;
        int rangeStart = 0;
        for (int partitionIndex = 0; partitionIndex < partitionNum; partitionIndex++) {
            long bytes = subpartitionBytesByPartition.get(partitionIndex)[subpartitionIndex];
            if (partitionIndex == rangeStart || rangeBytes + bytes < targetSize) {
                rangeBytes += bytes;
            } else {
                partitionRanges.add(new IndexRange(rangeStart, partitionIndex - 1));
                rangeStart = partitionIndex;
                rangeBytes = bytes;
            }
        }
        partitionRanges.add(new IndexRange(rangeStart, partitionNum - 1));
        return partitionRanges;
    }

    /**
     * Creates the input info of every input from the shared slice ranges. Broadcast inputs get
     * one entry per slice range. Non-broadcast inputs use the slices of their input type number.
     */
    private static Map<IntermediateDataSetID, JobVertexInputInfo> createJobVertexInputInfos(
            List<BlockingInputInfo> inputInfos,
            Map<Integer, List<SubpartitionSlice>> subpartitionSlicesByTypeNumber,
            List<IndexRange> sliceRanges) {
        Map<IntermediateDataSetID, JobVertexInputInfo> vertexInputInfos = new HashMap<>();
        for (BlockingInputInfo inputInfo : inputInfos) {
            JobVertexInputInfo vertexInputInfo;
            if (inputInfo.isBroadcast()) {
                vertexInputInfo =
                        VertexParallelismAndInputInfosDeciderUtils
                                .createdJobVertexInputInfoForBroadcast(
                                        inputInfo, sliceRanges.size());
            } else {
                vertexInputInfo =
                        VertexParallelismAndInputInfosDeciderUtils
                                .createdJobVertexInputInfoForNonBroadcast(
                                        inputInfo,
                                        sliceRanges,
                                        subpartitionSlicesByTypeNumber.get(
                                                inputInfo.getInputTypeNumber()));
            }
            vertexInputInfos.put(inputInfo.getResultId(), vertexInputInfo);
        }
        return vertexInputInfos;
    }

    /**
     * Computes the input info of each input independently, with the given (already decided)
     * parallelism. The data volume per task is apportioned by each input's share of the total
     * produced bytes.
     */
    private Map<IntermediateDataSetID, JobVertexInputInfo>
            computeJobVertexInputInfosForInputsWithoutInterKeysCorrelation(
                    List<BlockingInputInfo> inputInfos, int parallelism, long dataVolumePerTask) {
        long totalBytes =
                inputInfos.stream().mapToLong(BlockingInputInfo::getNumBytesProduced).sum();
        Map<IntermediateDataSetID, JobVertexInputInfo> vertexInputInfos = new HashMap<>();
        for (BlockingInputInfo inputInfo : inputInfos) {
            vertexInputInfos.put(
                    inputInfo.getResultId(),
                    computeVertexInputInfoForInputWithoutInterKeysCorrelation(
                            inputInfo,
                            parallelism,
                            VertexParallelismAndInputInfosDeciderUtils
                                    .calculateDataVolumePerTaskForInput(
                                            dataVolumePerTask,
                                            inputInfo.getNumBytesProduced(),
                                            totalBytes)));
        }
        return vertexInputInfos;
    }

    /**
     * Computes the input info of a single input that is not key-correlated with other inputs.
     * Falls back to a pointwise computation when no legal slice range exists for
     * {@code parallelism}.
     */
    private JobVertexInputInfo computeVertexInputInfoForInputWithoutInterKeysCorrelation(
            BlockingInputInfo inputInfo, int parallelism, long dataVolumePerTask) {
        if (inputInfo.isBroadcast()) {
            return VertexParallelismAndInputInfosDeciderUtils.createdJobVertexInputInfoForBroadcast(
                    inputInfo, parallelism);
        }

        List<SubpartitionSlice> subpartitionSlices =
                createSubpartitionSlicesForInputWithoutInterKeysCorrelation(inputInfo);
        // The parallelism is passed as both bounds: this input has to follow the parallelism
        // already decided for the vertex.
        Optional<List<IndexRange>> optionalSliceRanges =
                VertexParallelismAndInputInfosDeciderUtils.tryComputeSubpartitionSliceRange(
                        parallelism,
                        parallelism,
                        dataVolumePerTask,
                        Map.of(inputInfo.getInputTypeNumber(), subpartitionSlices));
        if (optionalSliceRanges.isPresent()) {
            List<IndexRange> sliceRanges = optionalSliceRanges.get();
            Preconditions.checkState(
                    VertexParallelismAndInputInfosDeciderUtils.isLegalParallelism(
                            sliceRanges.size(), parallelism, parallelism),
                    "Computed parallelism %s is not legal for min parallelism %s and max parallelism %s.",
                    sliceRanges.size(),
                    parallelism,
                    parallelism);
            return VertexParallelismAndInputInfosDeciderUtils
                    .createdJobVertexInputInfoForNonBroadcast(
                            inputInfo, sliceRanges, subpartitionSlices);
        }

        LOG.info(
                "Cannot find a legal parallelism to evenly distribute data amount for input {}, fallback to compute a parallelism that can evenly distribute num subpartitions.",
                inputInfo.getResultId());
        return VertexInputInfoComputationUtils.computeVertexInputInfoForPointwise(
                inputInfo.getNumPartitions(), parallelism, inputInfo::getNumSubpartitions, true);
    }

    /**
     * Creates the slices of a single input.
     *
     * <p>If the input is intra-input key correlated, each subpartition index becomes one slice
     * spanning all partitions. Otherwise each (partition, subpartition) pair becomes its own
     * slice.
     */
    private List<SubpartitionSlice> createSubpartitionSlicesForInputWithoutInterKeysCorrelation(
            BlockingInputInfo inputInfo) {
        List<SubpartitionSlice> subpartitionSlices = new ArrayList<>();
        if (inputInfo.isIntraInputKeyCorrelated()) {
            int subpartitionNum =
                    VertexParallelismAndInputInfosDeciderUtils.checkAndGetSubpartitionNum(
                            List.of(inputInfo));
            IndexRange allPartitions = new IndexRange(0, inputInfo.getNumPartitions() - 1);
            for (int subpartitionIndex = 0; subpartitionIndex < subpartitionNum; subpartitionIndex++) {
                IndexRange subpartitionRange = new IndexRange(subpartitionIndex, subpartitionIndex);
                subpartitionSlices.add(
                        SubpartitionSlice.createSubpartitionSlice(
                                allPartitions,
                                subpartitionRange,
                                inputInfo.getNumBytesProduced(allPartitions, subpartitionRange)));
            }
        } else {
            for (int partitionIndex = 0;
                    partitionIndex < inputInfo.getNumPartitions();
                    partitionIndex++) {
                IndexRange partitionRange = new IndexRange(partitionIndex, partitionIndex);
                for (int subpartitionIndex = 0;
                        subpartitionIndex < inputInfo.getNumSubpartitions(partitionIndex);
                        subpartitionIndex++) {
                    IndexRange subpartitionRange =
                            new IndexRange(subpartitionIndex, subpartitionIndex);
                    subpartitionSlices.add(
                            SubpartitionSlice.createSubpartitionSlice(
                                    partitionRange,
                                    subpartitionRange,
                                    inputInfo.getNumBytesProduced(
                                            partitionRange, subpartitionRange)));
                }
            }
        }
        return subpartitionSlices;
    }

    /**
     * Returns {@code true} if, within every group, all inputs agree on {@link
     * BlockingInputInfo#isIntraInputKeyCorrelated()}.
     */
    private static boolean hasSameIntraInputKeyCorrelation(
            Map<Integer, List<BlockingInputInfo>> inputInfosByTypeNumber) {
        return inputInfosByTypeNumber.values().stream()
                .allMatch(
                        group ->
                                group.stream()
                                                .map(BlockingInputInfo::isIntraInputKeyCorrelated)
                                                .distinct()
                                                .count()
                                        == 1);
    }
}
