package org.apache.flink.runtime.scheduler.adaptivebatch;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ScheduledExecutorService;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import javax.annotation.Nullable;
import org.apache.flink.api.common.JobID;
import org.apache.flink.configuration.BatchExecutionOptions;
import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutor;
import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutorServiceAdapter;
import org.apache.flink.runtime.execution.ExecutionState;
import org.apache.flink.runtime.executiongraph.DefaultExecutionGraph;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.executiongraph.IntermediateResultPartition;
import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
import org.apache.flink.runtime.executiongraph.failover.flip1.FixedDelayRestartBackoffTimeStrategy;
import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException;
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.io.network.partition.TestingJobMasterPartitionTracker;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.IntermediateDataSet;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.jobgraph.JobEdge;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.scheduler.DefaultSchedulerBuilder;
import org.apache.flink.runtime.scheduler.SchedulerBase;
import org.apache.flink.runtime.scheduler.SchedulerTestingUtils;
import org.apache.flink.runtime.taskmanager.TaskExecutionState;
import org.apache.flink.runtime.testtasks.NoOpInvokable;
import org.apache.flink.shaded.guava30.com.google.common.collect.Iterables;
import org.apache.flink.testutils.TestingUtils;
import org.apache.flink.testutils.executor.TestExecutorExtension;
import org.apache.flink.util.concurrent.ManuallyTriggeredScheduledExecutor;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

/** Tests for {@link AdaptiveBatchScheduler}. */
class AdaptiveBatchSchedulerTest {

    private static final int SOURCE_PARALLELISM_1 = 6;
    private static final int SOURCE_PARALLELISM_2 = 4;

    /** Number of bytes reported for every subpartition of an execution transitioned to FINISHED. */
    private static final long SUBPARTITION_BYTES = 100;

    @RegisterExtension
    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
            TestingUtils.defaultExecutorExtension();

    private ComponentMainThreadExecutor mainThreadExecutor;

    private ManuallyTriggeredScheduledExecutor taskRestartExecutor;

    @BeforeEach
    void setUp() {
        mainThreadExecutor = ComponentMainThreadExecutorServiceAdapter.forMainThread();
        taskRestartExecutor = new ManuallyTriggeredScheduledExecutor();
    }

    /**
     * Verifies that the sink's parallelism stays undecided (-1) until both of its sources have
     * finished, and is then set to the value returned by the parallelism decider.
     */
    @Test
    void testAdaptiveBatchScheduler() throws Exception {
        JobGraph jobGraph = createJobGraph();
        Iterator<JobVertex> jobVertexIterator = jobGraph.getVertices().iterator();
        JobVertex source1 = jobVertexIterator.next();
        JobVertex source2 = jobVertexIterator.next();
        JobVertex sink = jobVertexIterator.next();

        SchedulerBase scheduler = createScheduler(jobGraph);
        ExecutionJobVertex sinkExecutionJobVertex =
                scheduler.getExecutionGraph().getJobVertex(sink.getID());

        scheduler.startScheduling();
        Assertions.assertThat(sinkExecutionJobVertex.getParallelism()).isEqualTo(-1);

        // Only one of the two producers has finished: still undecided.
        transitionExecutionsState(scheduler, ExecutionState.FINISHED, source1);
        Assertions.assertThat(sinkExecutionJobVertex.getParallelism()).isEqualTo(-1);

        // Both producers finished: the decider configured by createScheduler(JobGraph) yields 10.
        transitionExecutionsState(scheduler, ExecutionState.FINISHED, source2);
        Assertions.assertThat(sinkExecutionJobVertex.getParallelism()).isEqualTo(10);
        Assertions.assertThat(sink.getParallelism()).isEqualTo(10);

        checkAggregatedInputDataBytesIsCalculated(sinkExecutionJobVertex, 26_000L);
    }

    /**
     * Verifies that a vertex consuming its input through a forward edge takes over the parallelism
     * decided for its producer rather than the value the decider would return for it.
     */
    @Test
    void testDecideParallelismForForwardTarget() throws Exception {
        JobVertex source = createJobVertex("source", SOURCE_PARALLELISM_1);
        JobVertex map = createJobVertex("map", -1);
        JobVertex sink = createJobVertex("sink", -1);

        map.connectNewDataSetAsInput(
                source, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING);
        sink.connectNewDataSetAsInput(
                map, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING);

        // Mark the map -> sink edge as a forward edge.
        JobEdge mapToSink = map.getProducedDataSets().get(0).getConsumers().get(0);
        mapToSink.setForward(true);

        JobGraph jobGraph =
                new JobGraph(new JobID(), "test job", new JobVertex[] {source, map, sink});

        // The decider returns 5 for the map vertex and 10 for any other vertex.
        Function<JobVertexID, Integer> parallelismByVertex =
                jobVertexId -> jobVertexId.equals(map.getID()) ? 5 : 10;
        SchedulerBase scheduler =
                createScheduler(
                        jobGraph,
                        DefaultSchedulerBuilder.createCustomParallelismDecider(parallelismByVertex),
                        128);

        DefaultExecutionGraph executionGraph = scheduler.getExecutionGraph();
        ExecutionJobVertex mapExecutionJobVertex = executionGraph.getJobVertex(map.getID());
        ExecutionJobVertex sinkExecutionJobVertex = executionGraph.getJobVertex(sink.getID());

        scheduler.startScheduling();
        Assertions.assertThat(mapExecutionJobVertex.getParallelism()).isEqualTo(-1);
        Assertions.assertThat(sinkExecutionJobVertex.getParallelism()).isEqualTo(-1);

        transitionExecutionsState(scheduler, ExecutionState.FINISHED, source);
        Assertions.assertThat(mapExecutionJobVertex.getParallelism()).isEqualTo(5);
        Assertions.assertThat(sinkExecutionJobVertex.getParallelism()).isEqualTo(-1);

        transitionExecutionsState(scheduler, ExecutionState.FINISHED, map);
        Assertions.assertThat(mapExecutionJobVertex.getParallelism()).isEqualTo(5);
        Assertions.assertThat(sinkExecutionJobVertex.getParallelism()).isEqualTo(5);
        Assertions.assertThat(sink.getParallelism()).isEqualTo(5);

        checkAggregatedInputDataBytesIsCalculated(sinkExecutionJobVertex, 13_000L);
    }

    /**
     * Verifies that the scheduler records produced partitions as producers finish, and that a
     * downstream failure caused by a missing partition removes that partition from the record.
     */
    @Test
    void testUpdateBlockingResultInfoWhileScheduling() throws Exception {
        JobGraph jobGraph = createJobGraph();
        Iterator<JobVertex> jobVertexIterator = jobGraph.getVertices().iterator();
        JobVertex source1 = jobVertexIterator.next();
        JobVertex source2 = jobVertexIterator.next();
        JobVertex sink = jobVertexIterator.next();

        TestingJobMasterPartitionTracker partitionTracker = new TestingJobMasterPartitionTracker();
        partitionTracker.setIsPartitionTrackedFunction(ignored -> true);

        final int maxParallelism = 6;
        AdaptiveBatchScheduler scheduler =
                new DefaultSchedulerBuilder(
                                jobGraph, mainThreadExecutor, EXECUTOR_RESOURCE.getExecutor())
                        .setDelayExecutor(taskRestartExecutor)
                        .setPartitionTracker(partitionTracker)
                        .setRestartBackoffTimeStrategy(
                                new FixedDelayRestartBackoffTimeStrategy
                                                .FixedDelayRestartBackoffTimeStrategyFactory(10, 0L)
                                        .create())
                        .setVertexParallelismAndInputInfosDecider(
                                DefaultSchedulerBuilder.createCustomParallelismDecider(
                                        maxParallelism))
                        .setDefaultMaxParallelism(maxParallelism)
                        .buildAdaptiveBatchJobScheduler();

        DefaultExecutionGraph executionGraph = scheduler.getExecutionGraph();
        ExecutionJobVertex source1ExecutionJobVertex = executionGraph.getJobVertex(source1.getID());
        ExecutionJobVertex sinkExecutionJobVertex = executionGraph.getJobVertex(sink.getID());

        scheduler.startScheduling();

        transitionExecutionsState(scheduler, ExecutionState.FINISHED, source1);
        Assertions.assertThat(getBlockingResultInfo(scheduler, source1).getNumOfRecordedPartitions())
                .isEqualTo(SOURCE_PARALLELISM_1);

        transitionExecutionsState(scheduler, ExecutionState.FINISHED, source2);
        Assertions.assertThat(getBlockingResultInfo(scheduler, source2).getNumOfRecordedPartitions())
                .isEqualTo(SOURCE_PARALLELISM_2);

        // Fail the first sink subtask because the first source1 partition cannot be found.
        triggerFailedByPartitionNotFound(
                scheduler,
                source1ExecutionJobVertex.getTaskVertices()[0],
                sinkExecutionJobVertex.getTaskVertices()[0]);
        taskRestartExecutor.triggerScheduledTasks();

        // The missing partition is expected to no longer be counted as recorded.
        Assertions.assertThat(getBlockingResultInfo(scheduler, source1).getNumOfRecordedPartitions())
                .isEqualTo(SOURCE_PARALLELISM_1 - 1);
    }

    /** Verifies parallelism decision when a vertex consumes the same data set through two edges. */
    @Test
    void testConsumeOneResultTwice() throws Exception {
        JobVertex source = createJobVertex("source1", 1);
        JobVertex sink = createJobVertex("sink", -1);

        IntermediateDataSetID sharedDataSetId = new IntermediateDataSetID();
        sink.connectNewDataSetAsInput(
                source,
                DistributionPattern.ALL_TO_ALL,
                ResultPartitionType.BLOCKING,
                sharedDataSetId,
                false);
        sink.connectNewDataSetAsInput(
                source,
                DistributionPattern.ALL_TO_ALL,
                ResultPartitionType.BLOCKING,
                sharedDataSetId,
                false);

        JobGraph jobGraph = new JobGraph(new JobID(), "test job", new JobVertex[] {source, sink});
        SchedulerBase scheduler =
                createScheduler(
                        jobGraph,
                        DefaultVertexParallelismAndInputInfosDeciderTest.createDecider(1, 16, 400L),
                        16);
        ExecutionJobVertex sinkExecutionJobVertex =
                scheduler.getExecutionGraph().getJobVertex(sink.getID());

        scheduler.startScheduling();
        transitionExecutionsState(scheduler, ExecutionState.FINISHED, source);

        Assertions.assertThat(sinkExecutionJobVertex.getParallelism()).isEqualTo(8);
        Assertions.assertThat(sink.getParallelism()).isEqualTo(8);
    }

    /**
     * Verifies that a vertex whose parallelism is configured up front is initialized as soon as
     * scheduling starts, without waiting for its producer to finish.
     */
    @Test
    void testParallelismDecidedVerticesCanBeInitializedEarlier() throws Exception {
        JobVertex source = createJobVertex("source", 8);
        JobVertex sink = createJobVertex("sink", 8);
        sink.connectNewDataSetAsInput(
                source, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING);

        SchedulerBase scheduler =
                createScheduler(
                        new JobGraph(new JobID(), "test job", new JobVertex[] {source, sink}));
        ExecutionJobVertex sinkExecutionJobVertex =
                scheduler.getExecutionGraph().getJobVertex(sink.getID());

        scheduler.startScheduling();
        Assertions.assertThat(sinkExecutionJobVertex.isInitialized()).isTrue();
    }

    @Test
    void testUserConfiguredMaxParallelismIsLargerThanGlobalMaxParallelism() throws Exception {
        testUserConfiguredMaxParallelism(1, 32, 128, 1L, 32);
    }

    @Test
    void testUserConfiguredMaxParallelismIsSmallerThanGlobalMaxParallelism() throws Exception {
        testUserConfiguredMaxParallelism(1, 128, 32, 1L, 32);
    }

    @Test
    void testUserConfiguredMaxParallelismIsSmallerThanGlobalMinParallelism() throws Exception {
        testUserConfiguredMaxParallelism(16, 128, 8, 400L, 8);
    }

    /**
     * Verifies that a source without configured parallelism is capped by its own configured max
     * parallelism when that is smaller than the decider's default source parallelism.
     */
    @Test
    void testUserConfiguredMaxParallelismIsSmallerThanGlobalDefaultSourceParallelism()
            throws Exception {
        JobVertex source = createJobVertex("source", -1);
        source.setMaxParallelism(8);

        SchedulerBase scheduler =
                createScheduler(
                        new JobGraph(new JobID(), "test job", new JobVertex[] {source}),
                        DefaultVertexParallelismAndInputInfosDeciderTest.createDecider(
                                1, 128, 1L, 32),
                        128);
        scheduler.startScheduling();

        Assertions.assertThat(source.getParallelism()).isEqualTo(8);
    }

    /**
     * Runs a source(8) -> sink job where the sink has a user-configured max parallelism, and checks
     * the parallelism decided for the sink once the source finishes.
     *
     * @param globalMinParallelism min parallelism passed to the decider
     * @param globalMaxParallelism max parallelism passed to the decider and used as the scheduler's
     *     default max parallelism
     * @param configuredMaxParallelism max parallelism set on the sink vertex
     * @param dataVolumePerTask data volume per task passed to the decider
     * @param expectedParallelism parallelism the sink is expected to end up with
     */
    void testUserConfiguredMaxParallelism(
            int globalMinParallelism,
            int globalMaxParallelism,
            int configuredMaxParallelism,
            long dataVolumePerTask,
            int expectedParallelism)
            throws Exception {
        JobVertex source = createJobVertex("source", 8);
        JobVertex sink = createJobVertex("sink", -1);
        sink.setMaxParallelism(configuredMaxParallelism);
        sink.connectNewDataSetAsInput(
                source, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING);

        SchedulerBase scheduler =
                createScheduler(
                        new JobGraph(new JobID(), "test job", new JobVertex[] {source, sink}),
                        DefaultVertexParallelismAndInputInfosDeciderTest.createDecider(
                                globalMinParallelism, globalMaxParallelism, dataVolumePerTask),
                        globalMaxParallelism);

        scheduler.startScheduling();
        transitionExecutionsState(scheduler, ExecutionState.FINISHED, source);

        Assertions.assertThat(sink.getParallelism()).isEqualTo(expectedParallelism);
    }

    /**
     * Returns the scheduler's blocking result info for the data set produced by {@code jobVertex}.
     * The vertex must produce exactly one data set.
     */
    private BlockingResultInfo getBlockingResultInfo(
            AdaptiveBatchScheduler scheduler, JobVertex jobVertex) {
        IntermediateDataSet producedDataSet =
                Iterables.getOnlyElement(jobVertex.getProducedDataSets());
        return scheduler.getBlockingResultInfo(producedDataSet.getId());
    }

    /**
     * Asserts that every subtask of {@code executionJobVertex} has its input bytes set (not -1) and
     * that they add up to {@code expectedTotalInputBytes}.
     */
    private void checkAggregatedInputDataBytesIsCalculated(
            ExecutionJobVertex executionJobVertex, long expectedTotalInputBytes) {
        long totalInputBytes = 0L;
        for (ExecutionVertex executionVertex : executionJobVertex.getTaskVertices()) {
            long inputBytes = executionVertex.getInputBytes();
            Assertions.assertThat(inputBytes).isNotEqualTo(-1L);
            totalInputBytes += inputBytes;
        }
        Assertions.assertThat(totalInputBytes).isEqualTo(expectedTotalInputBytes);
    }

    /**
     * Fails the current attempt of {@code consumerVertex} with a {@link
     * PartitionNotFoundException} that points at the single partition produced by the current
     * attempt of {@code producerVertex}.
     */
    private void triggerFailedByPartitionNotFound(
            SchedulerBase scheduler, ExecutionVertex producerVertex, ExecutionVertex consumerVertex) {
        IntermediateResultPartition producedPartition =
                Iterables.getOnlyElement(producerVertex.getProducedPartitions().values());
        ResultPartitionID missingPartitionId =
                new ResultPartitionID(
                        producedPartition.getPartitionId(),
                        producerVertex.getCurrentExecutionAttempt().getAttemptId());

        transitionExecutionsState(
                scheduler,
                ExecutionState.FAILED,
                Collections.singletonList(consumerVertex.getCurrentExecutionAttempt()),
                new PartitionNotFoundException(missingPartitionId));
    }

    /**
     * Reports each of the given executions to the scheduler as having reached {@code state}.
     *
     * <p>FINISHED executions report {@link #SUBPARTITION_BYTES} for every produced subpartition.
     * FAILED executions report {@code failureCause} as the failure cause.
     *
     * @throws UnsupportedOperationException if {@code state} is neither FINISHED nor FAILED and
     *     {@code executions} is non-empty
     */
    public static void transitionExecutionsState(
            SchedulerBase scheduler,
            ExecutionState state,
            List<Execution> executions,
            @Nullable Throwable failureCause) {
        for (Execution execution : executions) {
            final TaskExecutionState taskExecutionState;
            if (state == ExecutionState.FINISHED) {
                taskExecutionState =
                        SchedulerTestingUtils.createFinishedTaskExecutionState(
                                execution.getAttemptId(),
                                createResultPartitionBytesForExecution(execution));
            } else if (state == ExecutionState.FAILED) {
                taskExecutionState =
                        SchedulerTestingUtils.createFailedTaskExecutionState(
                                execution.getAttemptId(), failureCause);
            } else {
                throw new UnsupportedOperationException("Unsupported state " + state);
            }
            scheduler.updateTaskExecutionState(taskExecutionState);
        }
    }

    /**
     * Builds the per-partition byte statistics a finished {@code execution} reports: every
     * subpartition of every produced partition is given {@link #SUBPARTITION_BYTES} bytes.
     */
    public static Map<IntermediateResultPartitionID, ResultPartitionBytes>
            createResultPartitionBytesForExecution(Execution execution) {
        Map<IntermediateResultPartitionID, ResultPartitionBytes> partitionBytes = new HashMap<>();
        execution
                .getVertex()
                .getProducedPartitions()
                .forEach(
                        (partitionId, partition) -> {
                            long[] subpartitionBytes =
                                    LongStream.range(0L, partition.getNumberOfSubpartitions())
                                            .map(ignored -> SUBPARTITION_BYTES)
                                            .toArray();
                            partitionBytes.put(
                                    partitionId, new ResultPartitionBytes(subpartitionBytes));
                        });
        return partitionBytes;
    }

    /**
     * Reports the current execution attempt of every subtask of {@code jobVertex} to the scheduler
     * as having reached {@code state}, with no failure cause.
     */
    public static void transitionExecutionsState(
            SchedulerBase scheduler, ExecutionState state, JobVertex jobVertex) {
        ExecutionVertex[] taskVertices =
                scheduler.getExecutionGraph().getJobVertex(jobVertex.getID()).getTaskVertices();
        List<Execution> currentAttempts =
                Arrays.stream(taskVertices)
                        .map(ExecutionVertex::getCurrentExecutionAttempt)
                        .collect(Collectors.toList());
        transitionExecutionsState(scheduler, state, currentAttempts, null);
    }

    /**
     * Creates a {@link NoOpInvokable} job vertex named {@code jobVertexName}.
     *
     * @param parallelism the parallelism to configure. Values {@code <= 0} leave it unset, so it is
     *     left to the scheduler to decide.
     */
    public JobVertex createJobVertex(String jobVertexName, int parallelism) {
        JobVertex jobVertex = new JobVertex(jobVertexName);
        jobVertex.setInvokableClass(NoOpInvokable.class);
        if (parallelism > 0) {
            jobVertex.setParallelism(parallelism);
        }
        return jobVertex;
    }

    /**
     * Creates the job graph {@code source1 -> sink <- source2}. Both inputs are POINTWISE and
     * BLOCKING, and the sink's parallelism is unset. The vertices are added in the order source1,
     * source2, sink.
     */
    public JobGraph createJobGraph() {
        JobVertex source1 = createJobVertex("source1", SOURCE_PARALLELISM_1);
        JobVertex source2 = createJobVertex("source2", SOURCE_PARALLELISM_2);
        JobVertex sink = createJobVertex("sink", -1);
        sink.connectNewDataSetAsInput(
                source1, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING);
        sink.connectNewDataSetAsInput(
                source2, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING);
        return new JobGraph(new JobID(), "test job", new JobVertex[] {source1, source2, sink});
    }

    /**
     * Creates a scheduler whose decider always yields parallelism 10. The default max parallelism
     * is the default value of {@link
     * BatchExecutionOptions#ADAPTIVE_AUTO_PARALLELISM_MAX_PARALLELISM}.
     */
    private SchedulerBase createScheduler(JobGraph jobGraph) throws Exception {
        return createScheduler(
                jobGraph,
                DefaultSchedulerBuilder.createCustomParallelismDecider(10),
                BatchExecutionOptions.ADAPTIVE_AUTO_PARALLELISM_MAX_PARALLELISM.defaultValue());
    }

    /**
     * Creates an adaptive batch scheduler for {@code jobGraph} that uses the given decider and
     * default max parallelism. It runs on the test's main-thread and restart executors.
     */
    private SchedulerBase createScheduler(
            JobGraph jobGraph,
            VertexParallelismAndInputInfosDecider decider,
            int defaultMaxParallelism)
            throws Exception {
        return new DefaultSchedulerBuilder(
                        jobGraph, mainThreadExecutor, EXECUTOR_RESOURCE.getExecutor())
                .setDelayExecutor(taskRestartExecutor)
                .setVertexParallelismAndInputInfosDecider(decider)
                .setDefaultMaxParallelism(defaultMaxParallelism)
                .buildAdaptiveBatchJobScheduler();
    }
}
