package org.apache.flink.runtime.scheduler.adaptivebatch.util;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
import java.util.stream.Stream;
import org.apache.flink.runtime.executiongraph.ExecutionVertexInputInfo;
import org.apache.flink.runtime.executiongraph.IndexRange;
import org.apache.flink.runtime.executiongraph.IndexRangeUtil;
import org.apache.flink.runtime.executiongraph.JobVertexInputInfo;
import org.apache.flink.runtime.scheduler.adaptivebatch.BisectionSearchUtils;
import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingInputInfo;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static helpers for deciding vertex parallelism and for computing how the partitions and
 * subpartitions of blocking inputs are distributed across execution vertices.
 */
public class VertexParallelismAndInputInfosDeciderUtils {
    private static final Logger LOG =
            LoggerFactory.getLogger(VertexParallelismAndInputInfosDeciderUtils.class);

    /**
     * Searches for a data volume limit whose resulting parallelism lies within {@code
     * [minParallelism, maxParallelism]}, starting from {@code currentDataVolumeLimit}.
     *
     * <p>If {@code currentParallelism} is too small, the largest limit in {@code [minLimit,
     * currentDataVolumeLimit]} that still yields at least {@code minParallelism} is located first.
     * Then the smallest limit in {@code [minLimit, thatLimit]} yielding exactly the same
     * parallelism is chosen. If {@code currentParallelism} is too large, the smallest limit in
     * {@code [currentDataVolumeLimit, maxLimit]} yielding at most {@code maxParallelism} is
     * chosen. The search is delegated to {@link BisectionSearchUtils}.
     *
     * @param currentDataVolumeLimit the limit that produced {@code currentParallelism}
     * @param currentParallelism the parallelism computed for {@code currentDataVolumeLimit}
     * @param minParallelism inclusive lower bound of a legal parallelism
     * @param maxParallelism inclusive upper bound of a legal parallelism
     * @param minLimit lower bound of the limit search space
     * @param maxLimit upper bound of the limit search space
     * @param parallelismComputer maps a data volume limit to the parallelism it produces
     * @param subpartitionRangesComputer maps a data volume limit to the ranges it produces
     * @return the ranges computed for the adjusted limit, or empty if the adjusted limit still
     *     does not produce a legal parallelism
     */
    public static Optional<List<IndexRange>> adjustToClosestLegalParallelism(
            long currentDataVolumeLimit,
            int currentParallelism,
            int minParallelism,
            int maxParallelism,
            long minLimit,
            long maxLimit,
            Function<Long, Integer> parallelismComputer,
            Function<Long, List<IndexRange>> subpartitionRangesComputer) {
        long adjustedDataVolumeLimit = currentDataVolumeLimit;
        if (currentParallelism < minParallelism) {
            long maxLegalLimit =
                    BisectionSearchUtils.findMaxLegalValue(
                            value -> parallelismComputer.apply(value) >= minParallelism,
                            minLimit,
                            currentDataVolumeLimit);
            // Several limits may map to the same parallelism; pick the smallest of them.
            final int closestParallelism = parallelismComputer.apply(maxLegalLimit);
            adjustedDataVolumeLimit =
                    BisectionSearchUtils.findMinLegalValue(
                            value -> parallelismComputer.apply(value) == closestParallelism,
                            minLimit,
                            maxLegalLimit);
        } else if (currentParallelism > maxParallelism) {
            adjustedDataVolumeLimit =
                    BisectionSearchUtils.findMinLegalValue(
                            value -> parallelismComputer.apply(value) <= maxParallelism,
                            currentDataVolumeLimit,
                            maxLimit);
        }

        int adjustedParallelism = parallelismComputer.apply(adjustedDataVolumeLimit);
        return isLegalParallelism(adjustedParallelism, minParallelism, maxParallelism)
                ? Optional.of(subpartitionRangesComputer.apply(adjustedDataVolumeLimit))
                : Optional.empty();
    }

    /**
     * Computes the cartesian product of the given lists.
     *
     * @param lists the lists to combine, in order
     * @return every combination that takes one element from each list, preserving list order; a
     *     single empty combination when {@code lists} is empty
     */
    public static <T> List<List<T>> cartesianProduct(List<List<T>> lists) {
        List<List<T>> result = new ArrayList<>();
        if (lists.isEmpty()) {
            result.add(new ArrayList<>());
            return result;
        }
        List<T> firstList = lists.get(0);
        List<List<T>> remainingProduct = cartesianProduct(lists.subList(1, lists.size()));
        for (T element : firstList) {
            for (List<T> remaining : remainingProduct) {
                List<T> combination = new ArrayList<>(remaining.size() + 1);
                combination.add(element);
                combination.addAll(remaining);
                result.add(combination);
            }
        }
        return result;
    }

    /**
     * Returns the median of {@code nums}, clamped to be at least 1.
     *
     * <p>For an even number of values, the mean of the two middle values is rounded down. The
     * midpoint is computed without forming {@code a + b}, so it cannot overflow. Rounding of
     * non-positive results does not matter because of the clamp.
     *
     * @param nums the values; must not be empty
     * @return the median, or 1 if the median is smaller than 1
     * @throws IllegalArgumentException if {@code nums} is empty
     */
    public static long median(long[] nums) {
        Preconditions.checkArgument(nums.length > 0, "Cannot compute the median of no values.");
        int len = nums.length;
        long[] sorted = LongStream.of(nums).sorted().toArray();
        long median;
        if (len % 2 == 0) {
            long lower = sorted[len / 2 - 1];
            long upper = sorted[len / 2];
            // floor((lower + upper) / 2) without overflow.
            median = (lower >> 1) + (upper >> 1) + (lower & upper & 1L);
        } else {
            median = sorted[len / 2];
        }
        return Math.max(median, 1L);
    }

    /**
     * Returns {@code max(median * skewedFactor, defaultSkewedThreshold)} truncated to a long.
     *
     * @param median the median data size
     * @param skewedFactor multiplier applied to the median
     * @param defaultSkewedThreshold lower bound of the result
     * @return the skew threshold
     */
    public static long computeSkewThreshold(
            long median, double skewedFactor, long defaultSkewedThreshold) {
        return (long) Math.max(median * skewedFactor, defaultSkewedThreshold);
    }

    /**
     * Returns the average of all sizes not exceeding {@code skewedThreshold}, but at least {@code
     * dataVolumePerTask}.
     *
     * @param subpartitionBytes the sizes to average
     * @param skewedThreshold sizes above this value are ignored
     * @param dataVolumePerTask lower bound of the result; also returned when no size qualifies
     * @return the target size
     */
    public static long computeTargetSize(
            long[] subpartitionBytes, long skewedThreshold, long dataVolumePerTask) {
        long[] nonSkewed =
                LongStream.of(subpartitionBytes).filter(v -> v <= skewedThreshold).toArray();
        if (nonSkewed.length == 0) {
            return dataVolumePerTask;
        }
        return Math.max(dataVolumePerTask, LongStream.of(nonSkewed).sum() / nonSkewed.length);
    }

    /**
     * Returns the inputs that are not broadcast, in their original order.
     *
     * @param inputInfos the inputs to filter
     * @return a new mutable list of the non-broadcast inputs
     */
    public static List<BlockingInputInfo> getNonBroadcastInputInfos(
            List<BlockingInputInfo> inputInfos) {
        return inputInfos.stream()
                .filter(inputInfo -> !inputInfo.isBroadcast())
                .collect(Collectors.toList());
    }

    /**
     * Returns whether all inputs report the same number of partitions.
     *
     * @param inputInfos the inputs to check
     * @return {@code true} iff exactly one distinct partition count exists ({@code false} for an
     *     empty list)
     */
    public static boolean hasSameNumPartitions(List<BlockingInputInfo> inputInfos) {
        Set<Integer> numPartitions =
                inputInfos.stream()
                        .map(BlockingInputInfo::getNumPartitions)
                        .collect(Collectors.toSet());
        return numPartitions.size() == 1;
    }

    /**
     * Returns the largest partition count among the inputs.
     *
     * @param inputInfos the inputs; must not be empty
     * @return the maximum number of partitions
     * @throws IllegalArgumentException if {@code inputInfos} is empty
     */
    public static int getMaxNumPartitions(List<BlockingInputInfo> inputInfos) {
        Preconditions.checkArgument(!inputInfos.isEmpty());
        return inputInfos.stream().mapToInt(BlockingInputInfo::getNumPartitions).max().getAsInt();
    }

    /**
     * Returns the subpartition count shared by every partition of every input.
     *
     * @param inputInfos the inputs to inspect
     * @return the common number of subpartitions
     * @throws IllegalStateException if the counts differ or no partition exists
     */
    public static int checkAndGetSubpartitionNum(List<BlockingInputInfo> inputInfos) {
        Set<Integer> subpartitionNums =
                inputInfos.stream()
                        .flatMap(
                                inputInfo ->
                                        IntStream.range(0, inputInfo.getNumPartitions())
                                                .boxed()
                                                .map(inputInfo::getNumSubpartitions))
                        .collect(Collectors.toSet());
        Preconditions.checkState(subpartitionNums.size() == 1);
        return subpartitionNums.iterator().next();
    }

    /**
     * Returns the subpartition count shared by all aggregated inputs.
     *
     * @param inputInfos the aggregated inputs to inspect
     * @return the common number of subpartitions
     * @throws IllegalStateException if the counts differ or the collection is empty
     */
    public static int checkAndGetSubpartitionNumForAggregatedInputs(
            Collection<AggregatedBlockingInputInfo> inputInfos) {
        Set<Integer> subpartitionNums =
                inputInfos.stream()
                        .map(AggregatedBlockingInputInfo::getNumSubpartitions)
                        .collect(Collectors.toSet());
        Preconditions.checkState(subpartitionNums.size() == 1);
        return subpartitionNums.iterator().next();
    }

    /**
     * Returns whether {@code parallelism} lies within {@code [minParallelism, maxParallelism]}.
     *
     * @param parallelism the value to check
     * @param minParallelism inclusive lower bound
     * @param maxParallelism inclusive upper bound
     * @return {@code true} iff the parallelism is within bounds
     */
    public static boolean isLegalParallelism(
            int parallelism, int minParallelism, int maxParallelism) {
        return parallelism >= minParallelism && parallelism <= maxParallelism;
    }

    /**
     * Returns the intra-input key correlation flag shared by all inputs.
     *
     * @param inputInfos the inputs to inspect
     * @return the common {@code isIntraInputKeyCorrelated()} value
     * @throws IllegalArgumentException if the flags differ or the list is empty
     */
    public static boolean checkAndGetIntraCorrelation(List<BlockingInputInfo> inputInfos) {
        Set<Boolean> correlations =
                inputInfos.stream()
                        .map(BlockingInputInfo::isIntraInputKeyCorrelated)
                        .collect(Collectors.toSet());
        Preconditions.checkArgument(correlations.size() == 1);
        return correlations.iterator().next();
    }

    /**
     * Returns the number of execution vertex input infos shared by all job vertex input infos.
     *
     * @param vertexInputInfos the job vertex input infos to inspect
     * @return the common parallelism
     * @throws IllegalStateException if the sizes differ or the collection is empty
     */
    public static int checkAndGetParallelism(Collection<JobVertexInputInfo> vertexInputInfos) {
        Set<Integer> parallelisms =
                vertexInputInfos.stream()
                        .map(info -> info.getExecutionVertexInputInfos().size())
                        .collect(Collectors.toSet());
        Preconditions.checkState(parallelisms.size() == 1);
        return parallelisms.iterator().next();
    }

    /**
     * Groups consecutive subpartition slices into ranges, one range per task.
     *
     * <p>First tries to balance the data volume per range. If that cannot reach a legal
     * parallelism, it falls back to balancing the number of slices per range.
     *
     * @param minParallelism inclusive lower bound of the number of ranges
     * @param maxParallelism inclusive upper bound of the number of ranges
     * @param maxDataVolumePerTask the initial data volume limit per range
     * @param subpartitionSlices slice lists to group; all lists must have the same size
     * @return the slice index ranges, or empty if no legal grouping was found
     */
    public static Optional<List<IndexRange>> tryComputeSubpartitionSliceRange(
            int minParallelism,
            int maxParallelism,
            long maxDataVolumePerTask,
            Map<Integer, List<SubpartitionSlice>> subpartitionSlices) {
        Optional<List<IndexRange>> ranges =
                tryComputeSubpartitionSliceRangeEvenlyDistributedData(
                        minParallelism, maxParallelism, maxDataVolumePerTask, subpartitionSlices);
        if (ranges.isEmpty()) {
            LOG.info(
                    "Failed to compute a legal subpartition slice range that can evenly distribute data amount, fallback to compute it that can evenly distribute the number of subpartition slices.");
            ranges =
                    tryComputeSubpartitionSliceRangeEvenlyDistributedSubpartitionSlices(
                            minParallelism, maxParallelism, subpartitionSlices);
        }
        return ranges;
    }

    /**
     * Builds the input info of a broadcast input: every one of the {@code parallelism} vertices
     * consumes all partitions.
     *
     * <p>Each vertex consumes subpartition 0 only when a single subpartition contains all data.
     * Otherwise it consumes all subpartitions of partition 0.
     *
     * @param inputInfo a broadcast input
     * @param parallelism number of execution vertices
     * @return the job vertex input info
     * @throws IllegalArgumentException if {@code inputInfo} is not broadcast
     */
    public static JobVertexInputInfo createdJobVertexInputInfoForBroadcast(
            BlockingInputInfo inputInfo, int parallelism) {
        Preconditions.checkArgument(inputInfo.isBroadcast());
        int numPartitions = inputInfo.getNumPartitions();
        List<ExecutionVertexInputInfo> executionVertexInputInfos = new ArrayList<>();
        for (int subtaskIndex = 0; subtaskIndex < parallelism; subtaskIndex++) {
            IndexRange partitionRange = new IndexRange(0, numPartitions - 1);
            IndexRange subpartitionRange =
                    inputInfo.isSingleSubpartitionContainsAllData()
                            ? new IndexRange(0, 0)
                            : new IndexRange(0, inputInfo.getNumSubpartitions(0) - 1);
            executionVertexInputInfos.add(
                    new ExecutionVertexInputInfo(subtaskIndex, partitionRange, subpartitionRange));
        }
        return new JobVertexInputInfo(executionVertexInputInfos);
    }

    /**
     * Builds the input info of a non-broadcast input; vertex {@code i} consumes the slices in
     * {@code subpartitionSliceRanges.get(i)}.
     *
     * @param inputInfo a non-broadcast input
     * @param subpartitionSliceRanges slice index ranges, one per execution vertex
     * @param subpartitionSlices the slices of this input
     * @return the job vertex input info
     * @throws IllegalArgumentException if {@code inputInfo} is broadcast
     */
    public static JobVertexInputInfo createdJobVertexInputInfoForNonBroadcast(
            BlockingInputInfo inputInfo,
            List<IndexRange> subpartitionSliceRanges,
            List<SubpartitionSlice> subpartitionSlices) {
        Preconditions.checkArgument(!inputInfo.isBroadcast());
        int numPartitions = inputInfo.getNumPartitions();
        List<ExecutionVertexInputInfo> executionVertexInputInfos = new ArrayList<>();
        for (int subtaskIndex = 0; subtaskIndex < subpartitionSliceRanges.size(); subtaskIndex++) {
            executionVertexInputInfos.add(
                    new ExecutionVertexInputInfo(
                            subtaskIndex,
                            computeConsumedSubpartitionGroups(
                                    subpartitionSliceRanges.get(subtaskIndex),
                                    subpartitionSlices,
                                    numPartitions,
                                    inputInfo.isPointwise())));
        }
        return new JobVertexInputInfo(executionVertexInputInfos);
    }

    /**
     * Groups slices by data volume, adjusting the volume limit when the resulting number of
     * ranges is outside the legal parallelism bounds.
     */
    private static Optional<List<IndexRange>> tryComputeSubpartitionSliceRangeEvenlyDistributedData(
            int minParallelism,
            int maxParallelism,
            long maxDataVolumePerTask,
            Map<Integer, List<SubpartitionSlice>> subpartitionSlices) {
        int subpartitionSlicesSize = checkAndGetSubpartitionSlicesSize(subpartitionSlices);
        List<IndexRange> ranges =
                computeSubpartitionSliceRanges(
                        maxDataVolumePerTask, subpartitionSlicesSize, subpartitionSlices);
        if (isLegalParallelism(ranges.size(), minParallelism, maxParallelism)) {
            return Optional.of(ranges);
        }

        // Search bounds for the limit: the smallest per-index total (capped at the initial limit)
        // and the grand total over all indices.
        long minBytesSize = maxDataVolumePerTask;
        long sumBytesSize = 0;
        for (int i = 0; i < subpartitionSlicesSize; i++) {
            long currentBytesSize = 0;
            for (List<SubpartitionSlice> slices : subpartitionSlices.values()) {
                currentBytesSize += slices.get(i).getDataBytes();
            }
            minBytesSize = Math.min(minBytesSize, currentBytesSize);
            sumBytesSize += currentBytesSize;
        }

        return adjustToClosestLegalParallelism(
                maxDataVolumePerTask,
                ranges.size(),
                minParallelism,
                maxParallelism,
                minBytesSize,
                sumBytesSize,
                limit -> computeParallelism(limit, subpartitionSlicesSize, subpartitionSlices),
                limit ->
                        computeSubpartitionSliceRanges(
                                limit, subpartitionSlicesSize, subpartitionSlices));
    }

    /**
     * Splits the slice indices into {@code min(numSlices, maxParallelism)} ranges of
     * (near-)equal length; empty if there are fewer slices than {@code minParallelism}.
     */
    private static Optional<List<IndexRange>>
            tryComputeSubpartitionSliceRangeEvenlyDistributedSubpartitionSlices(
                    int minParallelism,
                    int maxParallelism,
                    Map<Integer, List<SubpartitionSlice>> subpartitionSlices) {
        int subpartitionSlicesSize = checkAndGetSubpartitionSlicesSize(subpartitionSlices);
        if (subpartitionSlicesSize < minParallelism) {
            return Optional.empty();
        }
        int parallelism = Math.min(subpartitionSlicesSize, maxParallelism);
        List<IndexRange> ranges = new ArrayList<>();
        for (int i = 0; i < parallelism; i++) {
            ranges.add(
                    new IndexRange(
                            i * subpartitionSlicesSize / parallelism,
                            (i + 1) * subpartitionSlicesSize / parallelism - 1));
        }
        Preconditions.checkState(ranges.size() == parallelism);
        return Optional.of(ranges);
    }

    /**
     * Computes the partition-range to subpartition-range map consumed by one execution vertex.
     *
     * <p>Slices are grouped by a key range; their value ranges are merged per key. The mapping is
     * then inverted so that key ranges are merged per value range, with exactly one merged range
     * required per value. Non-pointwise inputs use subpartition ranges as keys; pointwise inputs
     * use partition ranges as keys and invert the final map. In both cases the result maps
     * partition ranges to subpartition ranges.
     *
     * @throws IllegalStateException if a range cannot be merged into a single contiguous range,
     *     or if the pointwise inversion encounters a duplicate key
     */
    private static Map<IndexRange, IndexRange> computeConsumedSubpartitionGroups(
            IndexRange subpartitionSliceRange,
            List<SubpartitionSlice> subpartitionSlices,
            int numPartitions,
            boolean isPointwise) {
        // NOTE(review): this TreeMap compares keys by start index only, so two key ranges with the
        // same start but different ends would share one entry — confirm this cannot happen.
        Map<IndexRange, List<IndexRange>> rangeMap =
                new TreeMap<>(Comparator.comparingInt(IndexRange::getStartIndex));
        for (int i = subpartitionSliceRange.getStartIndex();
                i <= subpartitionSliceRange.getEndIndex();
                i++) {
            SubpartitionSlice slice = subpartitionSlices.get(i);
            IndexRange keyRange;
            IndexRange valueRange;
            if (isPointwise) {
                keyRange = slice.getPartitionRange(numPartitions);
                valueRange = slice.getSubpartitionRange();
            } else {
                keyRange = slice.getSubpartitionRange();
                valueRange = slice.getPartitionRange(numPartitions);
            }
            rangeMap.computeIfAbsent(keyRange, ignored -> new ArrayList<>()).add(valueRange);
        }

        Map<IndexRange, List<IndexRange>> mergedValuesByKey =
                rangeMap.entrySet().stream()
                        .collect(
                                Collectors.toMap(
                                        Map.Entry::getKey,
                                        entry -> IndexRangeUtil.mergeIndexRanges(entry.getValue())));

        // Invert: each merged value range collects the key ranges that referenced it.
        Map<IndexRange, List<IndexRange>> keysByValue = new HashMap<>();
        for (Map.Entry<IndexRange, List<IndexRange>> entry : mergedValuesByKey.entrySet()) {
            for (IndexRange mergedValue : entry.getValue()) {
                keysByValue
                        .computeIfAbsent(mergedValue, ignored -> new ArrayList<>())
                        .add(entry.getKey());
            }
        }

        Map<IndexRange, IndexRange> mergedRangeMap =
                keysByValue.entrySet().stream()
                        .collect(
                                Collectors.toMap(
                                        Map.Entry::getKey,
                                        entry -> {
                                            List<IndexRange> merged =
                                                    IndexRangeUtil.mergeIndexRanges(
                                                            entry.getValue());
                                            Preconditions.checkState(merged.size() == 1);
                                            return merged.get(0);
                                        }));
        return isPointwise ? reverseIndexRangeMap(mergedRangeMap) : mergedRangeMap;
    }

    /**
     * Greedily groups consecutive slice indices so that each group's data volume stays within
     * {@code maxDataVolumePerTask}; a group always holds at least one index.
     *
     * <p>Within a group, a slice already counted for the same map key is not counted again. A new
     * group starts with the full (non-deduplicated) bytes of its first index.
     *
     * <p>This method is public for use by lambdas and callers; the visibility is kept for
     * compatibility.
     *
     * @param maxDataVolumePerTask data volume limit per group
     * @param subpartitionSlicesSize number of slice indices to group
     * @param subpartitionSlices slice lists, each of size {@code subpartitionSlicesSize}
     * @return one inclusive index range per group, in index order
     */
    public static List<IndexRange> computeSubpartitionSliceRanges(
            long maxDataVolumePerTask,
            int subpartitionSlicesSize,
            Map<Integer, List<SubpartitionSlice>> subpartitionSlices) {
        List<Integer> startIndices =
                computeSubpartitionSliceRangeStartIndices(
                        maxDataVolumePerTask, subpartitionSlicesSize, subpartitionSlices);
        List<IndexRange> ranges = new ArrayList<>(startIndices.size());
        for (int k = 0; k < startIndices.size(); k++) {
            int end =
                    k + 1 < startIndices.size()
                            ? startIndices.get(k + 1) - 1
                            : subpartitionSlicesSize - 1;
            ranges.add(new IndexRange(startIndices.get(k), end));
        }
        return ranges;
    }

    /** Returns the number of groups {@link #computeSubpartitionSliceRanges} would produce. */
    private static int computeParallelism(
            long maxDataVolumePerTask,
            int subpartitionSlicesSize,
            Map<Integer, List<SubpartitionSlice>> subpartitionSlices) {
        return computeSubpartitionSliceRangeStartIndices(
                        maxDataVolumePerTask, subpartitionSlicesSize, subpartitionSlices)
                .size();
    }

    /** Runs the greedy grouping and returns the start index of every group (always has 0). */
    private static List<Integer> computeSubpartitionSliceRangeStartIndices(
            long maxDataVolumePerTask,
            int subpartitionSlicesSize,
            Map<Integer, List<SubpartitionSlice>> subpartitionSlices) {
        List<Integer> startIndices = new ArrayList<>();
        startIndices.add(0);
        int currentStart = 0;
        long currentRangeBytes = 0;
        // Slices already counted in the current group, tracked per map key.
        Map<Integer, Set<SubpartitionSlice>> countedSlices = new HashMap<>();
        for (int i = 0; i < subpartitionSlicesSize; i++) {
            long indexBytes = 0;
            long newIndexBytes = 0;
            for (Map.Entry<Integer, List<SubpartitionSlice>> entry :
                    subpartitionSlices.entrySet()) {
                SubpartitionSlice slice = entry.getValue().get(i);
                Set<SubpartitionSlice> counted =
                        countedSlices.computeIfAbsent(entry.getKey(), ignored -> new HashSet<>());
                if (!counted.contains(slice)) {
                    newIndexBytes += slice.getDataBytes();
                }
                indexBytes += slice.getDataBytes();
            }
            if (i == currentStart || currentRangeBytes + newIndexBytes <= maxDataVolumePerTask) {
                currentRangeBytes += newIndexBytes;
            } else {
                startIndices.add(i);
                currentStart = i;
                currentRangeBytes = indexBytes;
                countedSlices.clear();
            }
            for (Map.Entry<Integer, List<SubpartitionSlice>> entry :
                    subpartitionSlices.entrySet()) {
                countedSlices
                        .computeIfAbsent(entry.getKey(), ignored -> new HashSet<>())
                        .add(entry.getValue().get(i));
            }
        }
        return startIndices;
    }

    /**
     * Returns the size shared by all slice lists.
     *
     * @throws IllegalArgumentException if the sizes differ or the map is empty
     */
    private static int checkAndGetSubpartitionSlicesSize(
            Map<Integer, List<SubpartitionSlice>> subpartitionSlices) {
        Set<Integer> sizes =
                subpartitionSlices.values().stream().map(List::size).collect(Collectors.toSet());
        Preconditions.checkArgument(sizes.size() == 1);
        return sizes.iterator().next();
    }

    /**
     * Swaps keys and values of {@code map}.
     *
     * @throws IllegalStateException if two keys map to the same value
     */
    private static Map<IndexRange, IndexRange> reverseIndexRangeMap(
            Map<IndexRange, IndexRange> map) {
        Map<IndexRange, IndexRange> reversed = new HashMap<>();
        for (Map.Entry<IndexRange, IndexRange> entry : map.entrySet()) {
            Preconditions.checkState(!reversed.containsKey(entry.getValue()));
            reversed.put(entry.getValue(), entry.getKey());
        }
        return reversed;
    }

    /**
     * Scales {@code globalDataVolumePerTask} by the byte share of {@code inputsGroup} among {@code
     * allInputs}.
     *
     * @param globalDataVolumePerTask the data volume per task for all inputs
     * @param inputsGroup the inputs of the group
     * @param allInputs all inputs
     * @return the group's data volume per task
     */
    public static long calculateDataVolumePerTaskForInputsGroup(
            long globalDataVolumePerTask,
            List<BlockingInputInfo> inputsGroup,
            List<BlockingInputInfo> allInputs) {
        long inputsGroupBytes =
                inputsGroup.stream().mapToLong(BlockingInputInfo::getNumBytesProduced).sum();
        long totalBytes =
                allInputs.stream().mapToLong(BlockingInputInfo::getNumBytesProduced).sum();
        return calculateDataVolumePerTaskForInput(
                globalDataVolumePerTask, inputsGroupBytes, totalBytes);
    }

    /**
     * Returns {@code globalDataVolumePerTask * inputsGroupBytes / totalDataBytes}, truncated.
     *
     * <p>The ratio is computed in floating point. Long division would truncate every proper
     * fraction to 0, and it throws on a zero total. With a zero total and zero group bytes, the
     * ratio is NaN, which casts to 0.
     *
     * @param globalDataVolumePerTask the data volume per task for all inputs
     * @param inputsGroupBytes bytes produced by the inputs group
     * @param totalDataBytes bytes produced by all inputs
     * @return the proportional data volume per task
     */
    public static long calculateDataVolumePerTaskForInput(
            long globalDataVolumePerTask, long inputsGroupBytes, long totalDataBytes) {
        return (long) ((double) inputsGroupBytes / totalDataBytes * globalDataVolumePerTask);
    }
}
