package org.apache.flink.runtime.executiongraph;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ScheduledExecutorService;
import java.util.stream.Collectors;
import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutorServiceAdapter;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.JobGraphTestUtils;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.scheduler.SchedulerTestingUtils;
import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
import org.apache.flink.runtime.testtasks.NoOpInvokable;
import org.apache.flink.testutils.TestingUtils;
import org.apache.flink.testutils.executor.TestExecutorExtension;
import org.apache.flink.util.Preconditions;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

/* loaded from: input_file:org/apache/flink/runtime/executiongraph/EdgeManagerTest.class */
class EdgeManagerTest {

    @RegisterExtension
    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
            TestingUtils.defaultExecutorExtension();

    /**
     * Checks that both ends of a 2x2 ALL_TO_ALL edge refer to the same {@link
     * ConsumedPartitionGroup}. The group returned by the first consumer subtask must equal the group
     * registered on the first producer partition. It must also equal the group returned for that
     * partition by the graph's scheduling topology.
     */
    @Test
    void testGetConsumedPartitionGroup() throws Exception {
        JobVertex source = new JobVertex("source");
        JobVertex sink = new JobVertex("sink");
        ExecutionGraph eg =
                buildExecutionGraph(source, sink, 2, 2, DistributionPattern.ALL_TO_ALL);

        ConsumedPartitionGroup groupSeenByConsumer =
                Objects.requireNonNull(eg.getJobVertex(sink.getID()))
                        .getTaskVertices()[0]
                        .getAllConsumedPartitionGroups()
                        .get(0);

        IntermediateResultPartition consumedPartition =
                Objects.requireNonNull(eg.getJobVertex(source.getID()))
                        .getProducedDataSets()[0]
                        .getPartitions()[0];

        ConsumedPartitionGroup groupSeenByPartition =
                consumedPartition.getConsumedPartitionGroups().get(0);
        Assertions.assertThat(groupSeenByPartition).isEqualTo(groupSeenByConsumer);

        ConsumedPartitionGroup groupSeenBySchedulingTopology =
                eg.getSchedulingTopology()
                        .getResultPartition(consumedPartition.getPartitionId())
                        .getConsumedPartitionGroups()
                        .get(0);
        Assertions.assertThat(groupSeenBySchedulingTopology).isEqualTo(groupSeenByConsumer);
    }

    /**
     * Checks {@link ConsumedPartitionGroup#getNumConsumers()} for several producer/consumer
     * parallelism combinations, under both ALL_TO_ALL and POINTWISE distribution.
     */
    @Test
    void testCalculateNumberOfConsumers() throws Exception {
        testCalculateNumberOfConsumers(5, 2, DistributionPattern.ALL_TO_ALL, new int[] {2, 2});
        testCalculateNumberOfConsumers(5, 2, DistributionPattern.POINTWISE, new int[] {1, 1});
        testCalculateNumberOfConsumers(
                2, 5, DistributionPattern.ALL_TO_ALL, new int[] {5, 5, 5, 5, 5});
        testCalculateNumberOfConsumers(
                2, 5, DistributionPattern.POINTWISE, new int[] {3, 3, 3, 2, 2});
        testCalculateNumberOfConsumers(
                5, 5, DistributionPattern.ALL_TO_ALL, new int[] {5, 5, 5, 5, 5});
        testCalculateNumberOfConsumers(
                5, 5, DistributionPattern.POINTWISE, new int[] {1, 1, 1, 1, 1});
    }

    /**
     * Builds a producer -> consumer graph. It then asserts the consumer counts of all consumed
     * partition groups of the consumer vertex.
     *
     * @param producerParallelism parallelism of the producer vertex
     * @param consumerParallelism parallelism of the consumer vertex
     * @param distributionPattern how the consumer is connected to the producer
     * @param expectedConsumers expected {@code getNumConsumers()} of every group. Order matches
     *     walking the consumer subtasks in order and, for each, its consumed partition groups. The
     *     length must equal the total number of groups.
     */
    private void testCalculateNumberOfConsumers(
            int producerParallelism,
            int consumerParallelism,
            DistributionPattern distributionPattern,
            int[] expectedConsumers)
            throws Exception {
        JobVertex producer = new JobVertex("producer");
        JobVertex consumer = new JobVertex("consumer");
        ExecutionGraph eg =
                buildExecutionGraph(
                        producer,
                        consumer,
                        producerParallelism,
                        consumerParallelism,
                        distributionPattern);

        List<ConsumedPartitionGroup> partitionGroups =
                Arrays.stream(
                                Preconditions.checkNotNull(eg.getJobVertex(consumer.getID()))
                                        .getTaskVertices())
                        .flatMap(vertex -> vertex.getAllConsumedPartitionGroups().stream())
                        .collect(Collectors.toList());

        int[] actualConsumers =
                partitionGroups.stream()
                        .mapToInt(ConsumedPartitionGroup::getNumConsumers)
                        .toArray();

        // Compare whole arrays so a missing or extra group fails the assertion. Checking element
        // by element would silently skip missing groups or overrun the expected array.
        Assertions.assertThat(actualConsumers).containsExactly(expectedConsumers);
    }

    /**
     * Connects {@code consumer} to {@code producer} with a BLOCKING result partition. It then
     * creates a scheduler for the resulting batch job graph and returns the scheduler's execution
     * graph.
     *
     * <p>Both vertices are mutated. Their parallelism and invokable class ({@link NoOpInvokable})
     * are set here.
     *
     * @param producer upstream vertex
     * @param consumer downstream vertex
     * @param producerParallelism parallelism assigned to {@code producer}
     * @param consumerParallelism parallelism assigned to {@code consumer}
     * @param distributionPattern distribution pattern of the connecting edge
     * @return the execution graph of the created scheduler
     */
    private ExecutionGraph buildExecutionGraph(
            JobVertex producer,
            JobVertex consumer,
            int producerParallelism,
            int consumerParallelism,
            DistributionPattern distributionPattern)
            throws Exception {
        producer.setParallelism(producerParallelism);
        consumer.setParallelism(consumerParallelism);
        producer.setInvokableClass(NoOpInvokable.class);
        consumer.setInvokableClass(NoOpInvokable.class);
        consumer.connectNewDataSetAsInput(
                producer, distributionPattern, ResultPartitionType.BLOCKING);

        return SchedulerTestingUtils.createScheduler(
                        JobGraphTestUtils.batchJobGraph(producer, consumer),
                        ComponentMainThreadExecutorServiceAdapter.forMainThread(),
                        EXECUTOR_RESOURCE.getExecutor())
                .getExecutionGraph();
    }
}
