package org.apache.flink.runtime.jobgraph.forwardgroup;

import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.flink.runtime.executiongraph.VertexGroupComputeUtil;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.streaming.api.graph.StreamNode;
import org.apache.flink.util.Preconditions;

/* loaded from: input_file:org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtil.class */
/**
 * Utilities for computing forward groups.
 *
 * <p>Vertices are visited in topological order. Each vertex starts in its own group, and that group
 * is merged with the groups of all of the vertex's forward producers. Vertices transitively
 * connected via forward producers therefore end up in the same group.
 */
public class ForwardGroupComputeUtil {

    /**
     * Computes the forward groups of the given job vertices and checks that every vertex belonging
     * to a parallelism-decided forward group already has exactly that parallelism.
     *
     * @param topologicallySortedVertices job vertices in topological order; iterated twice
     * @return mapping from job vertex id to its forward group; vertices that are not grouped with
     *     at least one other vertex are absent from the map
     * @throws IllegalStateException if a vertex's parallelism differs from the decided parallelism
     *     of its forward group, or if a forward producer was not visited before its consumer
     */
    public static Map<JobVertexID, JobVertexForwardGroup> computeForwardGroupsAndCheckParallelism(
            final Iterable<JobVertex> topologicallySortedVertices) {
        final Map<JobVertexID, JobVertexForwardGroup> forwardGroupsByVertexId =
                computeForwardGroups(
                        topologicallySortedVertices, ForwardGroupComputeUtil::getForwardProducers);

        for (JobVertex jobVertex : topologicallySortedVertices) {
            final JobVertexForwardGroup forwardGroup =
                    forwardGroupsByVertexId.get(jobVertex.getID());
            if (forwardGroup != null && forwardGroup.isParallelismDecided()) {
                Preconditions.checkState(
                        jobVertex.getParallelism() == forwardGroup.getParallelism(),
                        "Parallelism %s of job vertex %s (%s) does not match the decided "
                                + "parallelism %s of its forward group.",
                        jobVertex.getParallelism(),
                        jobVertex.getName(),
                        jobVertex.getID(),
                        forwardGroup.getParallelism());
            }
        }
        return forwardGroupsByVertexId;
    }

    /**
     * Computes the forward groups of the given job vertices.
     *
     * @param topologicallySortedVertices job vertices in topological order
     * @param forwardProducersRetriever returns, for a vertex, the producers connected to it through
     *     forward edges
     * @return mapping from job vertex id to its forward group; groups consisting of a single vertex
     *     are not reported
     * @throws IllegalStateException if a forward producer was not visited before its consumer
     */
    public static Map<JobVertexID, JobVertexForwardGroup> computeForwardGroups(
            final Iterable<JobVertex> topologicallySortedVertices,
            final Function<JobVertex, Set<JobVertex>> forwardProducersRetriever) {
        final Map<JobVertex, Set<JobVertex>> vertexToGroup =
                computeVertexToGroup(topologicallySortedVertices, forwardProducersRetriever);

        final Map<JobVertexID, JobVertexForwardGroup> forwardGroupsByVertexId = new HashMap<>();
        for (Set<JobVertex> vertexGroup :
                VertexGroupComputeUtil.uniqueVertexGroups(vertexToGroup)) {
            // Singleton groups are intentionally not turned into JobVertexForwardGroups.
            if (vertexGroup.size() > 1) {
                final JobVertexForwardGroup forwardGroup = new JobVertexForwardGroup(vertexGroup);
                for (JobVertexID vertexId : forwardGroup.getVertexIds()) {
                    forwardGroupsByVertexId.put(vertexId, forwardGroup);
                }
            }
        }
        return forwardGroupsByVertexId;
    }

    /**
     * Computes the forward groups of the given stream nodes.
     *
     * <p>Unlike {@link #computeForwardGroups}, every group is reported, including groups that
     * contain a single node.
     *
     * @param topologicallySortedStreamNodes stream nodes in topological order
     * @param forwardProducersRetriever returns, for a node, the producers connected to it through
     *     forward edges
     * @return mapping from stream node id to its forward group
     * @throws IllegalStateException if a forward producer was not visited before its consumer
     */
    public static Map<Integer, StreamNodeForwardGroup> computeStreamNodeForwardGroup(
            final Iterable<StreamNode> topologicallySortedStreamNodes,
            final Function<StreamNode, Set<StreamNode>> forwardProducersRetriever) {
        final Map<StreamNode, Set<StreamNode>> nodeToGroup =
                computeVertexToGroup(topologicallySortedStreamNodes, forwardProducersRetriever);

        final Map<Integer, StreamNodeForwardGroup> forwardGroupsByNodeId = new HashMap<>();
        for (Set<StreamNode> nodeGroup : VertexGroupComputeUtil.uniqueVertexGroups(nodeToGroup)) {
            final StreamNodeForwardGroup forwardGroup = new StreamNodeForwardGroup(nodeGroup);
            for (Integer nodeId : forwardGroup.getVertexIds()) {
                forwardGroupsByNodeId.put(nodeId, forwardGroup);
            }
        }
        return forwardGroupsByNodeId;
    }

    /**
     * Assigns every vertex to a group, merging a vertex's group with the groups of its forward
     * producers.
     *
     * @param topologicallySortedVertices vertices in topological order; every forward producer
     *     must be visited before its consumers
     * @param forwardProducersRetriever returns the forward producers of a vertex
     * @param <T> the vertex type
     * @return identity-keyed mapping from each vertex to the group it belongs to
     * @throws IllegalStateException if a forward producer has not been visited yet
     */
    private static <T> Map<T, Set<T>> computeVertexToGroup(
            final Iterable<T> topologicallySortedVertices,
            final Function<T, Set<T>> forwardProducersRetriever) {
        // Keyed by reference identity rather than equals/hashCode.
        final Map<T, Set<T>> vertexToGroup = new IdentityHashMap<>();

        for (T vertex : topologicallySortedVertices) {
            Set<T> currentGroup = new HashSet<>();
            currentGroup.add(vertex);
            vertexToGroup.put(vertex, currentGroup);

            for (T producer : forwardProducersRetriever.apply(vertex)) {
                final Set<T> producerGroup = vertexToGroup.get(producer);
                if (producerGroup == null) {
                    throw new IllegalStateException(
                            "Producer task "
                                    + producer
                                    + " forward group is null while calculating forward group for the consumer task "
                                    + vertex
                                    + ". This should be a forward group building bug.");
                }
                if (currentGroup != producerGroup) {
                    // mergeVertexGroups returns the merged group; presumably it also re-points
                    // vertexToGroup for all members (see VertexGroupComputeUtil) — confirm there.
                    currentGroup =
                            VertexGroupComputeUtil.mergeVertexGroups(
                                    currentGroup, producerGroup, vertexToGroup);
                }
            }
        }
        return vertexToGroup;
    }

    /**
     * Determines whether the target forward group can be merged into the source forward group.
     *
     * <p>Merging is rejected when either group is {@code null}, when both groups have decided but
     * different parallelisms, or when the source's decided parallelism exceeds the target's decided
     * max parallelism. Identical groups are always mergeable.
     *
     * @param sourceForwardGroup the group to merge into, may be {@code null}
     * @param targetForwardGroup the group to be merged, may be {@code null}
     * @return {@code true} if the target group can be merged into the source group
     */
    public static boolean canTargetMergeIntoSourceForwardGroup(
            ForwardGroup<?> sourceForwardGroup, ForwardGroup<?> targetForwardGroup) {
        if (sourceForwardGroup == null || targetForwardGroup == null) {
            return false;
        }
        if (sourceForwardGroup == targetForwardGroup) {
            return true;
        }
        if (sourceForwardGroup.isParallelismDecided()
                && targetForwardGroup.isParallelismDecided()
                && sourceForwardGroup.getParallelism() != targetForwardGroup.getParallelism()) {
            return false;
        }
        // A source with decided parallelism must fit within the target's max parallelism.
        if (sourceForwardGroup.isParallelismDecided()
                && targetForwardGroup.isMaxParallelismDecided()
                && sourceForwardGroup.getParallelism() > targetForwardGroup.getMaxParallelism()) {
            return false;
        }
        return true;
    }

    /**
     * Returns the producers of all forward inputs of the given job vertex.
     *
     * @param jobVertex the consumer vertex
     * @return the set of job vertices producing data consumed by {@code jobVertex} via forward
     *     edges
     */
    static Set<JobVertex> getForwardProducers(final JobVertex jobVertex) {
        return jobVertex.getInputs().stream()
                .filter(edge -> edge.isForward())
                .map(edge -> edge.getSource())
                .map(dataSet -> dataSet.getProducer())
                .collect(Collectors.toSet());
    }
}
