package org.apache.flink.table.runtime.strategy;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.flink.configuration.MemorySize;
import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.scheduler.adaptivebatch.AllToAllBlockingResultInfo;
import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingResultInfo;
import org.apache.flink.runtime.scheduler.adaptivebatch.OperatorsFinished;
import org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils;
import org.apache.flink.streaming.api.graph.StreamGraphContext;
import org.apache.flink.streaming.api.graph.util.ImmutableStreamEdge;
import org.apache.flink.streaming.api.graph.util.ImmutableStreamNode;
import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo;
import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
import org.apache.flink.table.api.config.OptimizerConfigOptions;
import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
import org.apache.flink.table.runtime.operators.join.adaptive.AdaptiveJoin;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/flink/table/runtime/strategy/AdaptiveSkewedJoinOptimizationStrategy.class */
/**
 * Optimization strategy that mitigates data skew on the inputs of adaptive joins.
 *
 * <p>While upstream operators finish, the strategy sums the produced bytes of every subpartition
 * consumed by each input of a downstream adaptive join, keyed by join node id and input type
 * number. Once all upstream nodes of a join have finished, each input that is eligible for the
 * join type is checked for a subpartition whose bytes exceed a skew threshold. That threshold is
 * computed from the median subpartition bytes, the configured skewed factor and the configured
 * skewed threshold. For an input found to be skewed, its input edges are marked as not intra-input
 * key correlated. In addition, every forward-for-consecutive-hash output edge of the join is
 * switched to the hash partitioner obtained from its current output partitioner.
 */
public class AdaptiveSkewedJoinOptimizationStrategy extends BaseAdaptiveJoinOperatorOptimizationStrategy {
    private static final Logger LOG = LoggerFactory.getLogger(AdaptiveSkewedJoinOptimizationStrategy.class);

    /** Type number of the stream edges that feed the left input of a join. */
    private static final int LEFT_INPUT_TYPE_NUMBER = 1;

    /** Type number of the stream edges that feed the right input of a join. */
    private static final int RIGHT_INPUT_TYPE_NUMBER = 2;

    /**
     * Join node id -> input type number -> bytes per subpartition, summed over all finished
     * upstream results consumed by that input. A node's entry is removed once it has been
     * processed.
     */
    private Map<Integer, Map<Integer, long[]>> aggregatedProducedBytesByTypeNumberAndNodeId;

    /** Configured mode; only {@code AUTO} and {@code FORCED} enable the optimization. */
    private OptimizerConfigOptions.AdaptiveSkewedJoinOptimizationStrategy adaptiveSkewedJoinOptimizationStrategy;

    /** Configured skewed threshold in bytes, passed to the skew-threshold computation. */
    private long skewedThresholdInBytes;

    /** Configured skewed factor, passed to the skew-threshold computation. */
    private double skewedFactor;

    /**
     * Reads the skewed-join options from the stream graph configuration and resets the collected
     * statistics.
     *
     * @param streamGraphContext context giving access to the stream graph and its configuration
     */
    public void initialize(StreamGraphContext streamGraphContext) {
        ReadableConfig configuration = streamGraphContext.getStreamGraph().getConfiguration();
        this.aggregatedProducedBytesByTypeNumberAndNodeId = new HashMap<>();
        this.adaptiveSkewedJoinOptimizationStrategy =
                configuration.get(
                        OptimizerConfigOptions.TABLE_OPTIMIZER_ADAPTIVE_SKEWED_JOIN_OPTIMIZATION_STRATEGY);
        this.skewedFactor =
                configuration.get(
                        OptimizerConfigOptions.TABLE_OPTIMIZER_ADAPTIVE_SKEWED_JOIN_OPTIMIZATION_SKEWED_FACTOR);
        this.skewedThresholdInBytes =
                configuration
                        .get(OptimizerConfigOptions.TABLE_OPTIMIZER_ADAPTIVE_SKEWED_JOIN_OPTIMIZATION_SKEWED_THRESHOLD)
                        .getBytes();
    }

    /**
     * Delegates to the base class's {@code visitDownstreamAdaptiveJoinNode}. Presumably that
     * method calls {@link #tryOptimizeAdaptiveJoin} for each affected adaptive join; confirm in
     * the base class.
     *
     * @return always {@code true}
     */
    public boolean onOperatorsFinished(OperatorsFinished operatorsFinished, StreamGraphContext streamGraphContext)
            throws Exception {
        visitDownstreamAdaptiveJoinNode(operatorsFinished, streamGraphContext);
        return true;
    }

    /**
     * Accumulates the subpartition bytes of the given finished upstream edges for the join node.
     * Once all upstream nodes of the join are finished, applies the skewed-join optimization and
     * drops the collected statistics for that node.
     */
    @Override
    void tryOptimizeAdaptiveJoin(
            OperatorsFinished operatorsFinished,
            StreamGraphContext streamGraphContext,
            ImmutableStreamNode immutableStreamNode,
            List<ImmutableStreamEdge> list,
            AdaptiveJoin adaptiveJoin) {
        if (!canPerformOptimization(immutableStreamNode)) {
            return;
        }
        for (ImmutableStreamEdge upstreamEdge : list) {
            BlockingResultInfo resultInfo =
                    getBlockingResultInfo(operatorsFinished, streamGraphContext, upstreamEdge);
            // Per-subpartition byte statistics are only available for all-to-all results.
            Preconditions.checkState(
                    resultInfo instanceof AllToAllBlockingResultInfo,
                    "Expected an AllToAllBlockingResultInfo for edge %s, but got %s.",
                    upstreamEdge.getEdgeId(),
                    resultInfo.getClass().getName());
            aggregatedProducedBytesByTypeNumber(
                    immutableStreamNode,
                    upstreamEdge.getTypeNumber(),
                    ((AllToAllBlockingResultInfo) resultInfo).getAggregatedSubpartitionBytes());
        }
        // Skew can only be judged once every input's statistics are complete.
        if (streamGraphContext.areAllUpstreamNodesFinished(immutableStreamNode)) {
            applyAdaptiveSkewedJoinOptimization(streamGraphContext, immutableStreamNode, adaptiveJoin.getJoinType());
            freeNodeStatistic(immutableStreamNode.getId());
        }
    }

    /**
     * Returns whether the optimization may be applied to the node.
     *
     * <p>Broadcast joins are never optimized. {@code AUTO} and {@code FORCED} use their respective
     * out-edge checks; any other mode disables the optimization.
     */
    private boolean canPerformOptimization(ImmutableStreamNode immutableStreamNode) {
        if (AdaptiveJoinOptimizationUtils.isBroadcastJoin(immutableStreamNode)) {
            return false;
        }
        if (adaptiveSkewedJoinOptimizationStrategy
                == OptimizerConfigOptions.AdaptiveSkewedJoinOptimizationStrategy.AUTO) {
            return canPerformOptimizationAutomatic(immutableStreamNode);
        }
        if (adaptiveSkewedJoinOptimizationStrategy
                == OptimizerConfigOptions.AdaptiveSkewedJoinOptimizationStrategy.FORCED) {
            return canPerformOptimizationForced(immutableStreamNode);
        }
        return false;
    }

    /**
     * Finds the finished result that the given edge consumes.
     *
     * @throws IllegalStateException if the edge's source has no finished results, or none of them
     *     is the data set consumed by the edge
     */
    private static BlockingResultInfo getBlockingResultInfo(
            OperatorsFinished operatorsFinished,
            StreamGraphContext streamGraphContext,
            ImmutableStreamEdge upstreamEdge) {
        List<BlockingResultInfo> resultInfos =
                operatorsFinished.getResultInfoMap().get(upstreamEdge.getSourceId());
        // A missing entry is reported with the same error as a missing match, not as an NPE.
        if (resultInfos != null) {
            IntermediateDataSetID consumedDataSetId =
                    streamGraphContext.getConsumedIntermediateDataSetId(upstreamEdge.getEdgeId());
            for (BlockingResultInfo resultInfo : resultInfos) {
                if (resultInfo.getResultId().equals(consumedDataSetId)) {
                    return resultInfo;
                }
            }
        }
        throw new IllegalStateException(
                "No matching BlockingResultInfo found for edge ID: " + upstreamEdge.getEdgeId());
    }

    /**
     * Adds the given per-subpartition bytes to the running totals of one input of a join node.
     *
     * @throws IllegalStateException if the subpartition count differs from earlier contributions
     *     to the same input
     */
    private void aggregatedProducedBytesByTypeNumber(
            ImmutableStreamNode immutableStreamNode, int typeNumber, List<Long> subpartitionBytes) {
        long[] aggregatedBytes =
                aggregatedProducedBytesByTypeNumberAndNodeId
                        .computeIfAbsent(immutableStreamNode.getId(), ignored -> new HashMap<>())
                        .computeIfAbsent(typeNumber, ignored -> new long[subpartitionBytes.size()]);
        Preconditions.checkState(
                subpartitionBytes.size() == aggregatedBytes.length,
                "Subpartition count mismatch for input %s of node %s: expected %s but got %s.",
                typeNumber,
                immutableStreamNode.getId(),
                aggregatedBytes.length,
                subpartitionBytes.size());
        for (int i = 0; i < aggregatedBytes.length; i++) {
            aggregatedBytes[i] += subpartitionBytes.get(i);
        }
    }

    /**
     * Checks the inputs that are eligible for the join type for skew and rewrites their edges.
     *
     * <p>Eligible inputs are: RIGHT -> right input; INNER -> both inputs; LEFT, SEMI, ANTI -> left
     * input. FULL and any other join type are rejected.
     *
     * @throws IllegalStateException if statistics for either input are missing or the join type
     *     is unsupported
     */
    private void applyAdaptiveSkewedJoinOptimization(
            StreamGraphContext streamGraphContext,
            ImmutableStreamNode immutableStreamNode,
            FlinkJoinType flinkJoinType) {
        int nodeId = immutableStreamNode.getId();
        Map<Integer, long[]> bytesByTypeNumber = aggregatedProducedBytesByTypeNumberAndNodeId.get(nodeId);
        long[] leftInputBytes =
                bytesByTypeNumber == null ? null : bytesByTypeNumber.get(LEFT_INPUT_TYPE_NUMBER);
        Preconditions.checkState(
                leftInputBytes != null,
                "Left input bytes of adaptive join [%s] is unknown, which is unexpected.",
                nodeId);
        long[] rightInputBytes = bytesByTypeNumber.get(RIGHT_INPUT_TYPE_NUMBER);
        Preconditions.checkState(
                rightInputBytes != null,
                "Right input bytes of adaptive join [%s] is unknown, which is unexpected.",
                nodeId);

        boolean leftEligible;
        boolean rightEligible;
        switch (flinkJoinType) {
            case RIGHT:
                leftEligible = false;
                rightEligible = true;
                break;
            case INNER:
                leftEligible = true;
                rightEligible = true;
                break;
            case LEFT:
            case SEMI:
            case ANTI:
                leftEligible = true;
                rightEligible = false;
                break;
            case FULL:
            default:
                throw new IllegalStateException(String.format("Unexpected join type %s.", flinkJoinType));
        }

        // Thresholds are only computed for sides the join type allows to be optimized.
        if (leftEligible && isSkewed(leftInputBytes)) {
            LOG.info(
                    "Apply skewed join optimization {} for left input of node {}.",
                    tryModifyInputAndOutputEdges(streamGraphContext, immutableStreamNode, LEFT_INPUT_TYPE_NUMBER)
                            ? "succeeded"
                            : "failed",
                    nodeId);
        }
        if (rightEligible && isSkewed(rightInputBytes)) {
            LOG.info(
                    "Apply skewed join optimization {} for right input of node {}.",
                    tryModifyInputAndOutputEdges(streamGraphContext, immutableStreamNode, RIGHT_INPUT_TYPE_NUMBER)
                            ? "succeeded"
                            : "failed",
                    nodeId);
        }
    }

    /**
     * Returns whether any subpartition exceeds the skew threshold computed from the median of
     * {@code subpartitionBytes} and the configured factor and threshold.
     */
    private boolean isSkewed(long[] subpartitionBytes) {
        long skewThreshold =
                VertexParallelismAndInputInfosDeciderUtils.computeSkewThreshold(
                        VertexParallelismAndInputInfosDeciderUtils.median(subpartitionBytes),
                        skewedFactor,
                        skewedThresholdInBytes);
        return existBytesLargerThanThreshold(subpartitionBytes, skewThreshold);
    }

    /**
     * Submits a single edge-modification request for the input of the given type number and the
     * node's forward-for-consecutive-hash out edges.
     *
     * @return the result of {@code StreamGraphContext#modifyStreamEdge}
     */
    private static boolean tryModifyInputAndOutputEdges(
            StreamGraphContext streamGraphContext, ImmutableStreamNode immutableStreamNode, int typeNumber) {
        List<StreamEdgeUpdateRequestInfo> requestInfos =
                new ArrayList<>(
                        generateCorrelationModificationRequestInfos(
                                AdaptiveJoinOptimizationUtils.filterEdges(
                                        immutableStreamNode.getInEdges(), typeNumber)));
        requestInfos.addAll(
                generateForwardPartitionerModificationRequestInfos(
                        immutableStreamNode.getOutEdges(), streamGraphContext));
        return streamGraphContext.modifyStreamEdge(requestInfos);
    }

    /** Builds requests that mark each given edge as not intra-input key correlated. */
    private static List<StreamEdgeUpdateRequestInfo> generateCorrelationModificationRequestInfos(
            List<ImmutableStreamEdge> edges) {
        List<StreamEdgeUpdateRequestInfo> requestInfos = new ArrayList<>(edges.size());
        for (ImmutableStreamEdge edge : edges) {
            requestInfos.add(
                    new StreamEdgeUpdateRequestInfo(edge.getEdgeId(), edge.getSourceId(), edge.getTargetId())
                            .withIntraInputKeyCorrelated(false));
        }
        return requestInfos;
    }

    /**
     * Builds requests that replace the partitioner of each forward-for-consecutive-hash edge with
     * the hash partitioner obtained from its current output partitioner.
     *
     * @throws NullPointerException if such an edge has no output partitioner
     */
    private static List<StreamEdgeUpdateRequestInfo> generateForwardPartitionerModificationRequestInfos(
            List<ImmutableStreamEdge> edges, StreamGraphContext streamGraphContext) {
        List<StreamEdgeUpdateRequestInfo> requestInfos = new ArrayList<>();
        for (ImmutableStreamEdge edge : edges) {
            if (!edge.isForwardForConsecutiveHashEdge()) {
                continue;
            }
            StreamPartitioner<?> partitioner =
                    (StreamPartitioner<?>)
                            Preconditions.checkNotNull(
                                    streamGraphContext.getOutputPartitioner(
                                            edge.getEdgeId(), edge.getSourceId(), edge.getTargetId()));
            requestInfos.add(
                    new StreamEdgeUpdateRequestInfo(edge.getEdgeId(), edge.getSourceId(), edge.getTargetId())
                            .withOutputPartitioner(partitioner.getHashPartitioner()));
        }
        return requestInfos;
    }

    /** Drops the collected byte statistics of the given join node. */
    private void freeNodeStatistic(Integer num) {
        aggregatedProducedBytesByTypeNumberAndNodeId.remove(num);
    }

    /** Returns whether any value in {@code bytes} is strictly greater than {@code threshold}. */
    private static boolean existBytesLargerThanThreshold(long[] bytes, long threshold) {
        for (long value : bytes) {
            if (value > threshold) {
                return true;
            }
        }
        return false;
    }

    /** {@code AUTO} mode: allowed only if no out edge is intra-input key correlated. */
    private static boolean canPerformOptimizationAutomatic(ImmutableStreamNode immutableStreamNode) {
        return immutableStreamNode.getOutEdges().stream()
                .noneMatch(ImmutableStreamEdge::isIntraInputKeyCorrelated);
    }

    /**
     * {@code FORCED} mode: allowed unless an out edge is intra-input key correlated and not a
     * forward-for-consecutive-hash edge. Such forward edges are rewritten to hash partitioning
     * when the optimization is applied.
     */
    private static boolean canPerformOptimizationForced(ImmutableStreamNode immutableStreamNode) {
        return immutableStreamNode.getOutEdges().stream()
                .noneMatch(edge -> edge.isIntraInputKeyCorrelated() && !edge.isForwardForConsecutiveHashEdge());
    }
}
