package org.apache.flink.runtime.jobgraph.forwardgroup;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.testtasks.NoOpInvokable;
import org.apache.flink.runtime.util.JobVertexConnectionUtils;
import org.apache.flink.streaming.api.graph.StreamNode;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.util.Preconditions;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtilTest.class */
/**
 * Tests for {@link ForwardGroupComputeUtil}.
 *
 * <p>Covers both {@link JobVertex}-based forward groups and {@link StreamNode}-based forward
 * groups.
 */
class ForwardGroupComputeUtilTest {

    /** Unconnected job vertices are expected to form no forward group at all. */
    @Test
    void testIsolatedVertices() throws Exception {
        checkGroupSize(
                computeForwardGroups(new JobVertex("v1"), new JobVertex("v2"), new JobVertex("v3")),
                0);
    }

    /**
     * Stream nodes without forward inputs are expected to each form their own singleton group.
     * This differs from the job vertex case above.
     */
    @Test
    void testIsolatedChainedStreamNodeGroups() throws Exception {
        checkGroupSize(computeForwardGroups(createStreamNodes(3), Collections.emptyMap()), 3, 1, 1, 1);
    }

    /** Chain v1 -> v2 -> v3 with different combinations of forward / non-forward edges. */
    @Test
    void testVariousResultPartitionTypesBetweenVertices() throws Exception {
        testThreeVerticesConnectSequentially(false, true, 1, 2);
        testThreeVerticesConnectSequentially(false, false, 0);
        testThreeVerticesConnectSequentially(true, true, 1, 3);
    }

    /**
     * Builds the job vertex chain v1 -> v2 -> v3 and checks the resulting forward groups.
     *
     * @param isV1ToV2Forward value passed as the last (forward) flag of the v1 -> v2 connection
     * @param isV2ToV3Forward value passed as the last (forward) flag of the v2 -> v3 connection
     * @param numOfGroups expected number of forward groups
     * @param groupSizes expected group sizes, in any order
     */
    private void testThreeVerticesConnectSequentially(
            boolean isV1ToV2Forward,
            boolean isV2ToV3Forward,
            int numOfGroups,
            Integer... groupSizes)
            throws Exception {
        JobVertex v1 = new JobVertex("v1");
        JobVertex v2 = new JobVertex("v2");
        JobVertex v3 = new JobVertex("v3");

        JobVertexConnectionUtils.connectNewDataSetAsInput(
                v2,
                v1,
                DistributionPattern.ALL_TO_ALL,
                ResultPartitionType.BLOCKING,
                false,
                isV1ToV2Forward);
        JobVertexConnectionUtils.connectNewDataSetAsInput(
                v3,
                v2,
                DistributionPattern.POINTWISE,
                ResultPartitionType.BLOCKING,
                false,
                isV2ToV3Forward);

        checkGroupSize(computeForwardGroups(v1, v2, v3), numOfGroups, groupSizes);
    }

    /** Stream node chain 1 -> 2 -> 3 with different combinations of forward edges. */
    @Test
    void testVariousConnectTypesBetweenChainedStreamNodeGroup() throws Exception {
        testThreeChainedStreamNodeGroupsConnectSequentially(false, true, 2, 1, 2);
        testThreeChainedStreamNodeGroupsConnectSequentially(false, false, 3, 1, 1, 1);
        testThreeChainedStreamNodeGroupsConnectSequentially(true, true, 1, 3);
    }

    /**
     * Builds the stream node chain 1 -> 2 -> 3 and checks the resulting forward groups.
     *
     * @param isFirstEdgeForward whether node 1 is registered as a forward input of node 2
     * @param isSecondEdgeForward whether node 2 is registered as a forward input of node 3
     * @param numOfGroups expected number of forward groups
     * @param groupSizes expected group sizes, in any order
     */
    private void testThreeChainedStreamNodeGroupsConnectSequentially(
            boolean isFirstEdgeForward,
            boolean isSecondEdgeForward,
            int numOfGroups,
            Integer... groupSizes)
            throws Exception {
        List<StreamNode> streamNodes = createStreamNodes(3);
        Map<StreamNode, Set<StreamNode>> forwardInputs = new HashMap<>();
        if (isFirstEdgeForward) {
            addForwardInput(forwardInputs, streamNodes.get(1), streamNodes.get(0));
        }
        if (isSecondEdgeForward) {
            addForwardInput(forwardInputs, streamNodes.get(2), streamNodes.get(1));
        }
        checkGroupSize(computeForwardGroups(streamNodes, forwardInputs), numOfGroups, groupSizes);
    }

    /**
     * v1 and v2 both forward into v3, and v3 -> v4 is a non-forward edge.
     *
     * <p>A single group {v1, v2, v3} is expected.
     */
    @Test
    void testTwoInputsMergesIntoOne() throws Exception {
        JobVertex v1 = new JobVertex("v1");
        JobVertex v2 = new JobVertex("v2");
        JobVertex v3 = new JobVertex("v3");
        JobVertex v4 = new JobVertex("v4");

        JobVertexConnectionUtils.connectNewDataSetAsInput(
                v3, v1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING, false, true);
        JobVertexConnectionUtils.connectNewDataSetAsInput(
                v3, v2, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING, false, true);
        JobVertexConnectionUtils.connectNewDataSetAsInput(
                v4, v3, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING);

        checkGroupSize(computeForwardGroups(v1, v2, v3, v4), 1, 3);
    }

    /**
     * Stream nodes 1 and 2 both forward into node 3; node 4 is isolated.
     *
     * <p>The expected groups are {1, 2, 3} and {4}.
     */
    @Test
    void testTwoInputsMergesIntoOneForStreamNodeForwardGroup() throws Exception {
        List<StreamNode> streamNodes = createStreamNodes(4);
        Map<StreamNode, Set<StreamNode>> forwardInputs = new HashMap<>();
        addForwardInput(forwardInputs, streamNodes.get(2), streamNodes.get(0));
        addForwardInput(forwardInputs, streamNodes.get(2), streamNodes.get(1));
        checkGroupSize(computeForwardGroups(streamNodes, forwardInputs), 2, 3, 1);
    }

    /**
     * v1 -> v2 is a non-forward edge, and v2 forwards into both v3 and v4.
     *
     * <p>A single group {v2, v3, v4} is expected.
     */
    @Test
    void testOneInputSplitsIntoTwo() throws Exception {
        JobVertex v1 = new JobVertex("v1");
        JobVertex v2 = new JobVertex("v2");
        JobVertex v3 = new JobVertex("v3");
        JobVertex v4 = new JobVertex("v4");

        JobVertexConnectionUtils.connectNewDataSetAsInput(
                v2, v1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING);
        JobVertexConnectionUtils.connectNewDataSetAsInput(
                v3, v2, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING, false, true);
        JobVertexConnectionUtils.connectNewDataSetAsInput(
                v4, v2, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING, false, true);

        checkGroupSize(computeForwardGroups(v1, v2, v3, v4), 1, 3);
    }

    /**
     * Stream node 2 forwards into nodes 3 and 4; node 1 is isolated.
     *
     * <p>The expected groups are {2, 3, 4} and {1}.
     */
    @Test
    void testOneInputSplitsIntoTwoForStreamNodeForwardGroup() throws Exception {
        List<StreamNode> streamNodes = createStreamNodes(4);
        Map<StreamNode, Set<StreamNode>> forwardInputs = new HashMap<>();
        addForwardInput(forwardInputs, streamNodes.get(3), streamNodes.get(1));
        addForwardInput(forwardInputs, streamNodes.get(2), streamNodes.get(1));
        checkGroupSize(computeForwardGroups(streamNodes, forwardInputs), 2, 3, 1);
    }

    /**
     * Computes job-vertex forward groups for the given vertices.
     *
     * <p>Each vertex is first assigned {@link NoOpInvokable} as its invokable class.
     *
     * @return the distinct forward groups
     */
    private static Set<ForwardGroup<?>> computeForwardGroups(JobVertex... vertices) {
        List<JobVertex> vertexList = Arrays.asList(vertices);
        vertexList.forEach(vertex -> vertex.setInvokableClass(NoOpInvokable.class));
        return new HashSet<>(
                ForwardGroupComputeUtil.computeForwardGroupsAndCheckParallelism(vertexList).values());
    }

    /**
     * Asserts the number of groups and the multiset of their sizes.
     *
     * @param groups the computed forward groups
     * @param numOfGroups expected number of groups
     * @param sizes expected group sizes, in any order (duplicates are significant)
     */
    private static void checkGroupSize(
            Set<ForwardGroup<?>> groups, int numOfGroups, Integer... sizes) {
        Assertions.assertThat(groups.size()).isEqualTo(numOfGroups);
        List<Integer> actualSizes =
                groups.stream()
                        .map(ForwardGroupComputeUtilTest::sizeOf)
                        .collect(Collectors.toList());
        // containsExactlyInAnyOrder (unlike contains) also checks multiplicities, so e.g.
        // expected sizes {1, 1, 1} can no longer pass against actual sizes {1, 1, 5}.
        Assertions.assertThat(actualSizes).containsExactlyInAnyOrder(sizes);
    }

    /**
     * Returns the size of a forward group.
     *
     * <p>Only the two concrete group types used by these tests are handled.
     */
    private static int sizeOf(ForwardGroup<?> group) {
        if (group instanceof JobVertexForwardGroup) {
            return ((JobVertexForwardGroup) group).size();
        }
        return ((StreamNodeForwardGroup) group).size();
    }

    /** Creates a bare stream node with the given id; all other constructor arguments are null. */
    private static StreamNode createStreamNode(int id) {
        // The cast selects the StreamOperator overload of the constructor.
        return new StreamNode(id, null, null, (StreamOperator<?>) null, null, null);
    }

    /** Creates {@code count} bare stream nodes with ids {@code 1..count}, in id order. */
    private static List<StreamNode> createStreamNodes(int count) {
        List<StreamNode> streamNodes = new ArrayList<>(count);
        for (int id = 1; id <= count; id++) {
            streamNodes.add(createStreamNode(id));
        }
        return streamNodes;
    }

    /** Registers {@code input} as a forward input of {@code consumer} in {@code forwardInputs}. */
    private static void addForwardInput(
            Map<StreamNode, Set<StreamNode>> forwardInputs, StreamNode consumer, StreamNode input) {
        forwardInputs.computeIfAbsent(consumer, ignored -> new HashSet<>()).add(input);
    }

    /**
     * Computes stream-node forward groups for the given nodes.
     *
     * @param streamNodes the stream nodes to group
     * @param forwardInputs maps a node to its forward inputs; nodes absent from the map have none
     * @return the distinct forward groups
     */
    private static Set<ForwardGroup<?>> computeForwardGroups(
            List<StreamNode> streamNodes, Map<StreamNode, Set<StreamNode>> forwardInputs) {
        return new HashSet<>(
                computeStreamNodeForwardGroupAndCheckParallelism(
                                streamNodes,
                                node -> forwardInputs.getOrDefault(node, Collections.emptySet()))
                        .values());
    }

    /**
     * Delegates to {@link ForwardGroupComputeUtil#computeStreamNodeForwardGroup}.
     *
     * <p>Afterwards it verifies that every node whose group has a decided parallelism has that
     * same parallelism.
     *
     * @param streamNodes the stream nodes to group
     * @param forwardInputsRetriever looked up per node by the utility; presumably returns the
     *     node's forward inputs (TODO confirm against {@code ForwardGroupComputeUtil})
     * @return the forward groups keyed by stream node id
     * @throws IllegalStateException if a node's parallelism differs from its group's decided
     *     parallelism
     */
    public static Map<Integer, StreamNodeForwardGroup> computeStreamNodeForwardGroupAndCheckParallelism(
            Iterable<StreamNode> streamNodes,
            Function<StreamNode, Set<StreamNode>> forwardInputsRetriever) {
        Map<Integer, StreamNodeForwardGroup> groupsByNodeId =
                ForwardGroupComputeUtil.computeStreamNodeForwardGroup(
                        streamNodes, forwardInputsRetriever);
        for (StreamNode node : streamNodes) {
            StreamNodeForwardGroup group = groupsByNodeId.get(node.getId());
            if (group != null && group.isParallelismDecided()) {
                Preconditions.checkState(
                        node.getParallelism() == group.getParallelism(),
                        "Parallelism %s of stream node %s differs from the parallelism %s of its forward group.",
                        node.getParallelism(),
                        node.getId(),
                        group.getParallelism());
            }
        }
        return groupsByNodeId;
    }
}
