package org.apache.flink.streaming.api.graph;

import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.flink.annotation.Internal;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.IntermediateDataSet;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroupComputeUtil;
import org.apache.flink.runtime.jobgraph.forwardgroup.StreamNodeForwardGroup;
import org.apache.flink.streaming.api.graph.StreamGraphContext;
import org.apache.flink.streaming.api.graph.util.ImmutableStreamEdge;
import org.apache.flink.streaming.api.graph.util.ImmutableStreamGraph;
import org.apache.flink.streaming.api.graph.util.ImmutableStreamNode;
import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo;
import org.apache.flink.streaming.api.graph.util.StreamNodeUpdateRequestInfo;
import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
import org.apache.flink.streaming.runtime.partitioner.ForwardForConsecutiveHashPartitioner;
import org.apache.flink.streaming.runtime.partitioner.ForwardForUnspecifiedPartitioner;
import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner;
import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Default {@link StreamGraphContext} implementation.
 *
 * <p>It exposes a {@link StreamGraph} through a read-only {@link ImmutableStreamGraph} view. It
 * also applies a restricted set of modifications to the underlying mutable graph:
 *
 * <ul>
 *   <li>for stream edges: output partitioner, type number and intra-input key correlation;
 *   <li>for stream nodes: input type serializers.
 * </ul>
 *
 * <p>The maps passed to the constructor are stored by reference, not copied. Forward-group merges
 * write back into the forward-group map. Partitioner changes are propagated to the cached {@link
 * NonChainedOutput}s and {@link IntermediateDataSet}s held in those maps.
 */
@Internal
public class DefaultStreamGraphContext implements StreamGraphContext {

    private static final Logger LOG = LoggerFactory.getLogger(DefaultStreamGraphContext.class);

    /** The mutable stream graph that modifications are applied to. */
    private final StreamGraph streamGraph;

    /** Read-only view of {@link #streamGraph} returned by {@link #getStreamGraph()}. */
    private final ImmutableStreamGraph immutableStreamGraph;

    /** Forward group of each stream node, keyed by node id; rewritten when groups are merged. */
    private final Map<Integer, StreamNodeForwardGroup> streamNodeIdToForwardGroupMap;

    /** Keys are frozen stream node ids; only key membership is consulted by this class. */
    private final Map<Integer, Integer> frozenNodeToStartNodeMap;

    /** Cached non-chained outputs per source node id, keyed by the stream edge they serve. */
    private final Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputsCaches;

    /** Intermediate data set consumed through each stream edge, keyed by the edge id. */
    private final Map<String, IntermediateDataSet> consumerEdgeIdToIntermediateDataSetMap;

    /** Ids of stream nodes that have finished. */
    private final Set<Integer> finishedStreamNodeIds;

    /** Optional listener notified after a batch of modifications has been applied. */
    @Nullable
    private final StreamGraphContext.StreamGraphUpdateListener streamGraphUpdateListener;

    /** Creates a context without an update listener. */
    @VisibleForTesting
    public DefaultStreamGraphContext(
            StreamGraph streamGraph,
            Map<Integer, StreamNodeForwardGroup> streamNodeIdToForwardGroupMap,
            Map<Integer, Integer> frozenNodeToStartNodeMap,
            Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputsCaches,
            Map<String, IntermediateDataSet> consumerEdgeIdToIntermediateDataSetMap,
            Set<Integer> finishedStreamNodeIds,
            ClassLoader userClassloader) {
        this(
                streamGraph,
                streamNodeIdToForwardGroupMap,
                frozenNodeToStartNodeMap,
                opIntermediateOutputsCaches,
                consumerEdgeIdToIntermediateDataSetMap,
                finishedStreamNodeIds,
                userClassloader,
                null);
    }

    /**
     * Creates a context over {@code streamGraph}.
     *
     * @param userClassloader class loader passed to the {@link ImmutableStreamGraph} view
     * @param streamGraphUpdateListener notified after each successful modification batch; may be
     *     {@code null}
     * @throws NullPointerException if the graph or any of the maps is {@code null}
     */
    public DefaultStreamGraphContext(
            StreamGraph streamGraph,
            Map<Integer, StreamNodeForwardGroup> streamNodeIdToForwardGroupMap,
            Map<Integer, Integer> frozenNodeToStartNodeMap,
            Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputsCaches,
            Map<String, IntermediateDataSet> consumerEdgeIdToIntermediateDataSetMap,
            Set<Integer> finishedStreamNodeIds,
            ClassLoader userClassloader,
            @Nullable StreamGraphContext.StreamGraphUpdateListener streamGraphUpdateListener) {
        this.streamGraph = Preconditions.checkNotNull(streamGraph);
        this.streamNodeIdToForwardGroupMap =
                Preconditions.checkNotNull(streamNodeIdToForwardGroupMap);
        this.frozenNodeToStartNodeMap = Preconditions.checkNotNull(frozenNodeToStartNodeMap);
        this.opIntermediateOutputsCaches = Preconditions.checkNotNull(opIntermediateOutputsCaches);
        this.immutableStreamGraph = new ImmutableStreamGraph(this.streamGraph, userClassloader);
        this.consumerEdgeIdToIntermediateDataSetMap =
                Preconditions.checkNotNull(consumerEdgeIdToIntermediateDataSetMap);
        this.finishedStreamNodeIds = finishedStreamNodeIds;
        this.streamGraphUpdateListener = streamGraphUpdateListener;
    }

    @Override
    public ImmutableStreamGraph getStreamGraph() {
        return immutableStreamGraph;
    }

    /** Returns the operator factory of the stream node with the given id; may be {@code null}. */
    @Override
    @Nullable
    public StreamOperatorFactory<?> getOperatorFactory(Integer streamNodeId) {
        return streamGraph.getStreamNode(streamNodeId).getOperatorFactory();
    }

    /**
     * Applies a batch of stream edge modifications.
     *
     * <p>Every request is validated before any is applied. If a request is rejected, the graph is
     * left untouched and {@code false} is returned. Otherwise all requests are applied, the
     * listener (if any) is notified, and {@code true} is returned.
     *
     * @throws RuntimeException if a request references an edge that does not exist
     */
    @Override
    public boolean modifyStreamEdge(List<StreamEdgeUpdateRequestInfo> requestInfos) {
        for (StreamEdgeUpdateRequestInfo requestInfo : requestInfos) {
            if (!validateStreamEdgeUpdateRequest(requestInfo)) {
                return false;
            }
        }

        for (StreamEdgeUpdateRequestInfo requestInfo : requestInfos) {
            StreamEdge targetEdge =
                    getStreamEdge(
                            requestInfo.getSourceId(),
                            requestInfo.getTargetId(),
                            requestInfo.getEdgeId());
            StreamPartitioner<?> newPartitioner = requestInfo.getOutputPartitioner();
            if (newPartitioner != null) {
                modifyOutputPartitioner(targetEdge, newPartitioner);
            }
            // A type number of 0 means "not requested".
            if (requestInfo.getTypeNumber() != 0) {
                targetEdge.setTypeNumber(requestInfo.getTypeNumber());
            }
            if (requestInfo.getIntraInputKeyCorrelated() != null) {
                modifyIntraInputKeyCorrelation(
                        targetEdge, requestInfo.getIntraInputKeyCorrelated().booleanValue());
            }
        }

        notifyStreamGraphUpdated();
        return true;
    }

    /**
     * Applies a batch of stream node modifications.
     *
     * <p>Like {@link #modifyStreamEdge(List)}, every request is validated before any is applied.
     * A rejected batch therefore never leaves some nodes modified without the listener being
     * notified.
     *
     * @return {@code false} if any request's {@code typeSerializersIn} length differs from the
     *     node's current one (nothing is modified); {@code true} otherwise
     */
    @Override
    public boolean modifyStreamNode(List<StreamNodeUpdateRequestInfo> requestInfos) {
        for (StreamNodeUpdateRequestInfo requestInfo : requestInfos) {
            if (!validateStreamNodeUpdateRequest(requestInfo)) {
                return false;
            }
        }

        for (StreamNodeUpdateRequestInfo requestInfo : requestInfos) {
            if (requestInfo.getTypeSerializersIn() != null) {
                streamGraph
                        .getStreamNode(requestInfo.getNodeId())
                        .setSerializersIn(requestInfo.getTypeSerializersIn());
            }
        }

        notifyStreamGraphUpdated();
        return true;
    }

    /**
     * Returns whether the source of every input edge of {@code streamNode} has finished.
     *
     * <p>This is vacuously {@code true} for a node without input edges.
     */
    @Override
    public boolean areAllUpstreamNodesFinished(ImmutableStreamNode streamNode) {
        for (ImmutableStreamEdge inEdge : streamNode.getInEdges()) {
            if (!finishedStreamNodeIds.contains(inEdge.getSourceId())) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the id of the intermediate data set consumed through the edge with {@code edgeId}.
     *
     * @throws NullPointerException if no intermediate data set is registered for that edge
     */
    @Override
    public IntermediateDataSetID getConsumedIntermediateDataSetId(String edgeId) {
        IntermediateDataSet dataSet =
                Preconditions.checkNotNull(
                        consumerEdgeIdToIntermediateDataSetMap.get(edgeId),
                        "No intermediate data set is consumed by stream edge '%s'.",
                        edgeId);
        return dataSet.getId();
    }

    /** Returns the current partitioner of the identified stream edge. */
    @Override
    public StreamPartitioner<?> getOutputPartitioner(
            String edgeId, Integer sourceId, Integer targetId) {
        return Preconditions.checkNotNull(getStreamEdge(sourceId, targetId, edgeId))
                .getPartitioner();
    }

    /** Notifies the update listener, if one was supplied, that the stream graph was modified. */
    private void notifyStreamGraphUpdated() {
        if (streamGraphUpdateListener != null) {
            streamGraphUpdateListener.onStreamGraphUpdated();
        }
    }

    /** Returns whether a single stream node request may be applied; logs the reason if not. */
    private boolean validateStreamNodeUpdateRequest(StreamNodeUpdateRequestInfo requestInfo) {
        if (requestInfo.getTypeSerializersIn() == null) {
            return true;
        }
        StreamNode streamNode = streamGraph.getStreamNode(requestInfo.getNodeId());
        if (requestInfo.getTypeSerializersIn().length
                != streamNode.getTypeSerializersIn().length) {
            LOG.info(
                    "Modification for node {} is not allowed as the array size of typeSerializersIn is not matched.",
                    requestInfo.getNodeId());
            return false;
        }
        return true;
    }

    /**
     * Returns whether a single stream edge request may be applied; logs the reason if not.
     *
     * <p>A request is rejected when any of the following holds:
     *
     * <ul>
     *   <li>it changes the partitioner and the edge's cached non-chained output is shared with
     *       other edges;
     *   <li>the edge's target node is frozen;
     *   <li>it changes the partitioner and the edge's current partitioner is exactly {@link
     *       ForwardPartitioner};
     *   <li>it requests exactly {@link ForwardPartitioner}, but the target's forward group cannot
     *       merge into the source's.
     * </ul>
     */
    private boolean validateStreamEdgeUpdateRequest(StreamEdgeUpdateRequestInfo requestInfo) {
        Integer sourceNodeId = requestInfo.getSourceId();
        Integer targetNodeId = requestInfo.getTargetId();
        StreamEdge targetEdge = getStreamEdge(sourceNodeId, targetNodeId, requestInfo.getEdgeId());
        StreamPartitioner<?> newPartitioner = requestInfo.getOutputPartitioner();

        if (newPartitioner != null && isOutputReused(sourceNodeId, targetEdge)) {
            LOG.info("Skip modifying edge {} because the subscribing output is reused.", targetEdge);
            return false;
        }

        if (frozenNodeToStartNodeMap.containsKey(targetNodeId)) {
            LOG.info(
                    "Skip modifying edge {} because the target node with id {} is in frozen list.",
                    targetEdge,
                    targetNodeId);
            return false;
        }

        if (newPartitioner == null) {
            return true;
        }

        // Exact class comparison: subclasses of ForwardPartitioner are not covered by these checks.
        if (targetEdge.getPartitioner().getClass().equals(ForwardPartitioner.class)) {
            LOG.info(
                    "Modification for edge {} is not allowed as the origin partitioner is ForwardPartitioner.",
                    targetEdge);
            return false;
        }

        if (newPartitioner.getClass().equals(ForwardPartitioner.class)
                && !ForwardGroupComputeUtil.canTargetMergeIntoSourceForwardGroup(
                        streamNodeIdToForwardGroupMap.get(targetEdge.getSourceId()),
                        streamNodeIdToForwardGroupMap.get(targetEdge.getTargetId()))) {
            LOG.info("Skip modifying edge {} because forward groups can not be merged.", targetEdge);
            return false;
        }

        return true;
    }

    /**
     * Returns whether the cached non-chained output of {@code edge} also serves other edges.
     *
     * <p>Returns {@code false} when no output is cached for the edge.
     */
    private boolean isOutputReused(Integer sourceNodeId, StreamEdge edge) {
        Map<StreamEdge, NonChainedOutput> outputs = opIntermediateOutputsCaches.get(sourceNodeId);
        NonChainedOutput output = outputs != null ? outputs.get(edge) : null;
        if (output == null) {
            return false;
        }
        Set<StreamEdge> consumerEdges =
                outputs.entrySet().stream()
                        .filter(entry -> entry.getValue().equals(output))
                        .map(Map.Entry::getKey)
                        .collect(Collectors.toSet());
        return consumerEdges.size() != 1;
    }

    /**
     * Sets {@code newPartitioner} on {@code targetEdge}.
     *
     * <p>If the new partitioner is a {@link ForwardPartitioner} (including its subclasses), it may
     * be converted to another partitioner and forward groups may be merged. The resulting
     * partitioner is then propagated to the edge's cached non-chained output and to its consumed
     * intermediate data set, when either exists. Does nothing if {@code newPartitioner} is
     * {@code null}.
     */
    private void modifyOutputPartitioner(
            StreamEdge targetEdge, StreamPartitioner<?> newPartitioner) {
        if (newPartitioner == null) {
            return;
        }
        StreamPartitioner<?> oldPartitioner = targetEdge.getPartitioner();
        targetEdge.setPartitioner(newPartitioner);

        if (targetEdge.getPartitioner() instanceof ForwardPartitioner) {
            tryConvertForwardPartitionerAndMergeForwardGroup(targetEdge);
        }

        // Read the partitioner again: the forward conversion above may have replaced it.
        StreamPartitioner<?> finalPartitioner = targetEdge.getPartitioner();

        Map<StreamEdge, NonChainedOutput> outputs =
                opIntermediateOutputsCaches.get(targetEdge.getSourceId());
        NonChainedOutput output = outputs != null ? outputs.get(targetEdge) : null;
        if (output != null) {
            output.setPartitioner(finalPartitioner);
        }

        IntermediateDataSet dataSet =
                consumerEdgeIdToIntermediateDataSetMap.get(targetEdge.getEdgeId());
        if (dataSet != null) {
            dataSet.updateOutputPattern(
                    finalPartitioner.isPointwise()
                            ? DistributionPattern.POINTWISE
                            : DistributionPattern.ALL_TO_ALL,
                    finalPartitioner.isBroadcast(),
                    finalPartitioner.getClass().equals(ForwardPartitioner.class));
        }

        LOG.info(
                "The original partitioner of the edge {} is: {} , requested change to: {} , and finally modified to: {}.",
                targetEdge,
                oldPartitioner,
                newPartitioner,
                finalPartitioner);
    }

    /** Sets the intra-input key correlation flag on {@code targetEdge} if it differs. */
    private void modifyIntraInputKeyCorrelation(
            StreamEdge targetEdge, boolean intraInputKeyCorrelated) {
        if (targetEdge.isIntraInputKeyCorrelated() == intraInputKeyCorrelated) {
            return;
        }
        targetEdge.setIntraInputKeyCorrelated(intraInputKeyCorrelated);
    }

    /**
     * Resolves a forward-style partitioner on {@code targetEdge}. The edge is converted as follows:
     *
     * <ul>
     *   <li>a convertible edge gets a plain {@link ForwardPartitioner} and its forward groups are
     *       merged;
     *   <li>otherwise, {@link ForwardForUnspecifiedPartitioner} falls back to {@link
     *       RescalePartitioner};
     *   <li>otherwise, {@link ForwardForConsecutiveHashPartitioner} falls back to its wrapped hash
     *       partitioner;
     *   <li>any other {@link ForwardPartitioner} keeps its partitioner and its forward groups are
     *       merged.
     * </ul>
     *
     * @throws IllegalStateException if the edge's partitioner is not a {@link ForwardPartitioner},
     *     or a required forward-group merge fails
     */
    private void tryConvertForwardPartitionerAndMergeForwardGroup(StreamEdge targetEdge) {
        Preconditions.checkState(targetEdge.getPartitioner() instanceof ForwardPartitioner);
        Integer sourceNodeId = targetEdge.getSourceId();
        Integer targetNodeId = targetEdge.getTargetId();

        if (canConvertToForwardPartitioner(targetEdge)) {
            targetEdge.setPartitioner(new ForwardPartitioner<>());
            Preconditions.checkState(mergeForwardGroups(sourceNodeId, targetNodeId));
        } else if (targetEdge.getPartitioner() instanceof ForwardForUnspecifiedPartitioner) {
            targetEdge.setPartitioner(new RescalePartitioner<>());
        } else if (targetEdge.getPartitioner() instanceof ForwardForConsecutiveHashPartitioner) {
            targetEdge.setPartitioner(
                    ((ForwardForConsecutiveHashPartitioner<?>) targetEdge.getPartitioner())
                            .getHashPartitioner());
        } else {
            // A plain ForwardPartitioner was already checked for mergeability during validation.
            Preconditions.checkState(mergeForwardGroups(sourceNodeId, targetNodeId));
        }
    }

    /**
     * Returns whether a {@link ForwardForUnspecifiedPartitioner} or {@link
     * ForwardForConsecutiveHashPartitioner} on {@code targetEdge} can become a plain {@link
     * ForwardPartitioner}.
     *
     * <p>The unspecified variant additionally requires a non-frozen, chainable source. Both
     * variants require the target's forward group to be mergeable into the source's. Any other
     * partitioner yields {@code false}.
     */
    private boolean canConvertToForwardPartitioner(StreamEdge targetEdge) {
        Integer sourceNodeId = targetEdge.getSourceId();
        Integer targetNodeId = targetEdge.getTargetId();
        StreamPartitioner<?> partitioner = targetEdge.getPartitioner();

        if (partitioner instanceof ForwardForUnspecifiedPartitioner) {
            return !frozenNodeToStartNodeMap.containsKey(sourceNodeId)
                    && StreamingJobGraphGenerator.isChainable(targetEdge, streamGraph, true)
                    && ForwardGroupComputeUtil.canTargetMergeIntoSourceForwardGroup(
                            streamNodeIdToForwardGroupMap.get(sourceNodeId),
                            streamNodeIdToForwardGroupMap.get(targetNodeId));
        }
        if (partitioner instanceof ForwardForConsecutiveHashPartitioner) {
            return ForwardGroupComputeUtil.canTargetMergeIntoSourceForwardGroup(
                    streamNodeIdToForwardGroupMap.get(sourceNodeId),
                    streamNodeIdToForwardGroupMap.get(targetNodeId));
        }
        return false;
    }

    /**
     * Merges the target node's forward group into the source node's forward group.
     *
     * <p>On success, every node of the merged group is remapped to the source's group.
     *
     * @return {@code false} if either group is missing or the merge is refused
     */
    private boolean mergeForwardGroups(Integer sourceNodeId, Integer targetNodeId) {
        StreamNodeForwardGroup sourceForwardGroup = streamNodeIdToForwardGroupMap.get(sourceNodeId);
        StreamNodeForwardGroup forwardGroupToMerge =
                streamNodeIdToForwardGroupMap.get(targetNodeId);
        if (sourceForwardGroup == null
                || forwardGroupToMerge == null
                || !sourceForwardGroup.mergeForwardGroup(forwardGroupToMerge)) {
            return false;
        }
        forwardGroupToMerge
                .getVertexIds()
                .forEach(nodeId -> streamNodeIdToForwardGroupMap.put(nodeId, sourceForwardGroup));
        return true;
    }

    /**
     * Finds the edge with id {@code edgeId} between the two nodes.
     *
     * @throws RuntimeException if no such edge exists
     */
    private StreamEdge getStreamEdge(Integer sourceId, Integer targetId, String edgeId) {
        for (StreamEdge edge : streamGraph.getStreamEdges(sourceId.intValue(), targetId.intValue())) {
            if (edge.getEdgeId().equals(edgeId)) {
                return edge;
            }
        }
        throw new RuntimeException(
                String.format(
                        "Stream edge with id '%s' is not found whose source id is %d, target id is %d.",
                        edgeId, sourceId, targetId));
    }
}
