package org.apache.flink.runtime.scheduler.adaptivebatch.util;

import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingInputInfo;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/flink/runtime/scheduler/adaptivebatch/util/AggregatedBlockingInputInfo.class */
/**
 * Aggregated byte-size statistics over a list of {@link BlockingInputInfo}s.
 *
 * <p>Holds the per-subpartition byte totals summed over all inputs and, when it could be
 * computed, a per-partition breakdown of those totals. Instances are created through {@link
 * #createAggregatedBlockingInputInfo(long, double, long, List)}.
 */
public class AggregatedBlockingInputInfo {
    private static final Logger LOG = LoggerFactory.getLogger(AggregatedBlockingInputInfo.class);

    /** Result of {@code VertexParallelismAndInputInfosDeciderUtils.getMaxNumPartitions}. */
    private final int maxPartitionNum;

    /** A subpartition whose aggregated bytes strictly exceed this value is reported as skewed. */
    private final long skewedThreshold;

    /** Result of {@code VertexParallelismAndInputInfosDeciderUtils.computeTargetSize}. */
    private final long targetSize;

    /** Result of {@code VertexParallelismAndInputInfosDeciderUtils.checkAndGetIntraCorrelation}. */
    private final boolean intraInputKeyCorrelated;

    /**
     * Partition index to per-subpartition bytes summed over all inputs. Empty when the inputs are
     * intra-correlated or have different numbers of partitions.
     */
    private final Map<Integer, long[]> subpartitionBytesByPartition;

    /** Bytes of each subpartition (indexed by subpartition index), summed over all inputs. */
    private final long[] aggregatedSubpartitionBytes;

    private AggregatedBlockingInputInfo(
            long targetSize,
            long skewedThreshold,
            int maxPartitionNum,
            boolean intraInputKeyCorrelated,
            Map<Integer, long[]> subpartitionBytesByPartition,
            long[] aggregatedSubpartitionBytes) {
        this.maxPartitionNum = maxPartitionNum;
        this.skewedThreshold = skewedThreshold;
        this.targetSize = targetSize;
        this.intraInputKeyCorrelated = intraInputKeyCorrelated;
        this.subpartitionBytesByPartition = Preconditions.checkNotNull(subpartitionBytesByPartition);
        this.aggregatedSubpartitionBytes = Preconditions.checkNotNull(aggregatedSubpartitionBytes);
    }

    /** Returns the maximum number of partitions as computed when this instance was created. */
    public int getMaxPartitionNum() {
        return maxPartitionNum;
    }

    /** Returns the target size as computed when this instance was created. */
    public long getTargetSize() {
        return targetSize;
    }

    /**
     * Returns an unmodifiable view of the per-partition subpartition bytes.
     *
     * <p>Only the map itself is protected against modification; the {@code long[]} values are the
     * internal arrays and must not be mutated by callers.
     */
    public Map<Integer, long[]> getSubpartitionBytesByPartition() {
        return Collections.unmodifiableMap(subpartitionBytesByPartition);
    }

    /**
     * Returns the bytes of the given subpartition, summed over all inputs.
     *
     * @param subpartitionIndex index of the subpartition
     * @throws ArrayIndexOutOfBoundsException if the index is outside {@code [0,
     *     getNumSubpartitions())}
     */
    public long getAggregatedSubpartitionBytes(int subpartitionIndex) {
        return aggregatedSubpartitionBytes[subpartitionIndex];
    }

    /**
     * Returns {@code true} only if the inputs are not intra-correlated and a per-partition
     * breakdown of subpartition bytes is available.
     */
    public boolean isSplittable() {
        return !intraInputKeyCorrelated && !subpartitionBytesByPartition.isEmpty();
    }

    /**
     * Returns whether the aggregated bytes of the given subpartition strictly exceed the skewed
     * threshold.
     *
     * @param subpartitionIndex index of the subpartition
     * @throws ArrayIndexOutOfBoundsException if the index is outside {@code [0,
     *     getNumSubpartitions())}
     */
    public boolean isSkewedSubpartition(int subpartitionIndex) {
        return aggregatedSubpartitionBytes[subpartitionIndex] > skewedThreshold;
    }

    /** Returns the number of subpartitions covered by the aggregated statistics. */
    public int getNumSubpartitions() {
        return aggregatedSubpartitionBytes.length;
    }

    /**
     * Sums, per subpartition index, the aggregated subpartition bytes reported by every input.
     *
     * @param inputInfos inputs to aggregate
     * @param subpartitionNum length of the resulting array; every input is expected to report at
     *     most this many subpartitions
     * @return array of length {@code subpartitionNum} holding the per-subpartition totals
     */
    private static long[] computeAggregatedSubpartitionBytes(
            List<BlockingInputInfo> inputInfos, int subpartitionNum) {
        long[] totals = new long[subpartitionNum];
        for (BlockingInputInfo inputInfo : inputInfos) {
            List<Long> bytesPerSubpartition = inputInfo.getAggregatedSubpartitionBytes();
            for (int idx = 0; idx < bytesPerSubpartition.size(); idx++) {
                totals[idx] += bytesPerSubpartition.get(idx);
            }
        }
        return totals;
    }

    /**
     * Sums, per partition index, the subpartition bytes reported by every input.
     *
     * @param inputInfos inputs to aggregate
     * @param subpartitionNum number of subpartitions accumulated for each partition
     * @return partition index to per-subpartition totals, or an empty map when the inputs do not
     *     all have the same number of partitions
     */
    private static Map<Integer, long[]> computeSubpartitionBytesByPartitionIndex(
            List<BlockingInputInfo> inputInfos, int subpartitionNum) {
        if (!VertexParallelismAndInputInfosDeciderUtils.hasSameNumPartitions(inputInfos)) {
            LOG.warn("Input infos have different num partitions, skip calculate SubpartitionBytesByPartitionIndex");
            return Collections.emptyMap();
        }
        Map<Integer, long[]> totalsByPartition = new HashMap<>();
        for (BlockingInputInfo inputInfo : inputInfos) {
            inputInfo
                    .getSubpartitionBytesByPartitionIndex()
                    .forEach(
                            (partitionIndex, inputBytes) -> {
                                // The accumulator and the input array need distinct names: the
                                // accumulator is shared across inputs, the input array is only read.
                                long[] accumulated =
                                        totalsByPartition.computeIfAbsent(
                                                partitionIndex, ignored -> new long[subpartitionNum]);
                                for (int idx = 0; idx < subpartitionNum; idx++) {
                                    accumulated[idx] += inputBytes[idx];
                                }
                            });
        }
        return totalsByPartition;
    }

    /**
     * Builds the aggregated statistics for the given inputs.
     *
     * <p>The skewed threshold is obtained from {@code computeSkewThreshold(median, skewedFactor,
     * defaultSkewedThreshold)} where {@code median} is the median of the aggregated subpartition
     * bytes; the target size is obtained from {@code computeTargetSize(aggregatedBytes,
     * skewedThreshold, dataVolumePerTask)}. The per-partition breakdown is skipped (left empty)
     * when the inputs are intra-correlated.
     *
     * @param defaultSkewedThreshold forwarded as the last argument of {@code computeSkewThreshold};
     *     presumably a default/lower bound for the threshold (confirm against that helper)
     * @param skewedFactor forwarded as the second argument of {@code computeSkewThreshold};
     *     presumably a multiplier applied to the median (confirm against that helper)
     * @param dataVolumePerTask forwarded as the last argument of {@code computeTargetSize};
     *     presumably the desired data volume per task (confirm against that helper)
     * @param inputInfos inputs to aggregate
     * @return the aggregated statistics
     */
    public static AggregatedBlockingInputInfo createAggregatedBlockingInputInfo(
            long defaultSkewedThreshold,
            double skewedFactor,
            long dataVolumePerTask,
            List<BlockingInputInfo> inputInfos) {
        int subpartitionNum =
                VertexParallelismAndInputInfosDeciderUtils.checkAndGetSubpartitionNum(inputInfos);
        long[] aggregatedBytes = computeAggregatedSubpartitionBytes(inputInfos, subpartitionNum);
        long skewedThreshold =
                VertexParallelismAndInputInfosDeciderUtils.computeSkewThreshold(
                        VertexParallelismAndInputInfosDeciderUtils.median(aggregatedBytes),
                        skewedFactor,
                        defaultSkewedThreshold);
        long targetSize =
                VertexParallelismAndInputInfosDeciderUtils.computeTargetSize(
                        aggregatedBytes, skewedThreshold, dataVolumePerTask);
        boolean intraInputKeyCorrelated =
                VertexParallelismAndInputInfosDeciderUtils.checkAndGetIntraCorrelation(inputInfos);

        Map<Integer, long[]> bytesByPartition;
        if (intraInputKeyCorrelated) {
            // No per-partition breakdown is computed for intra-correlated inputs, which also
            // makes isSplittable() return false.
            bytesByPartition = new HashMap<>();
        } else {
            bytesByPartition = computeSubpartitionBytesByPartitionIndex(inputInfos, subpartitionNum);
        }

        return new AggregatedBlockingInputInfo(
                targetSize,
                skewedThreshold,
                VertexParallelismAndInputInfosDeciderUtils.getMaxNumPartitions(inputInfos),
                intraInputKeyCorrelated,
                bytesByPartition,
                aggregatedBytes);
    }
}
