package org.apache.flink.runtime.scheduler.adaptivebatch;

import java.util.Collections;
import java.util.concurrent.ScheduledExecutorService;
import org.apache.flink.runtime.jobgraph.IntermediateDataSet;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobmaster.event.ExecutionJobVertexFinishedEvent;
import org.apache.flink.runtime.scheduler.SchedulerBase;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.graph.StreamGraph;
import org.apache.flink.streaming.api.graph.StreamNode;
import org.apache.flink.testutils.TestingUtils;
import org.apache.flink.testutils.executor.TestExecutorExtension;
import org.apache.flink.util.DynamicCodeLoadingException;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

/**
 * Tests for the {@link ExecutionPlanSchedulingContext} created by {@link
 * DefaultAdaptiveExecutionHandler#createExecutionPlanSchedulingContext(int)}.
 */
class AdaptiveExecutionPlanSchedulingContextTest {

    /** Value passed to {@code createExecutionPlanSchedulingContext} in the parallelism tests. */
    private static final int DEFAULT_MAX_PARALLELISM = 100;

    /** Value returned by the parallelism retriever handed to the scheduling context. */
    private static final int RETRIEVED_PARALLELISM = 123;

    /** Value returned by the max-parallelism retriever handed to the scheduling context. */
    private static final int RETRIEVED_MAX_PARALLELISM = 456;

    // Package-private: JUnit Jupiter does not allow private @RegisterExtension fields.
    @RegisterExtension
    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
            TestingUtils.defaultExecutorExtension();

    @Test
    void testGetParallelismAndMaxParallelism() throws DynamicCodeLoadingException {
        final int sinkParallelism = 4;
        final int sinkMaxParallelism = 5;

        DefaultAdaptiveExecutionHandler handler =
                getDefaultAdaptiveExecutionHandler(sinkParallelism, sinkMaxParallelism);
        ExecutionPlanSchedulingContext schedulingContext =
                handler.createExecutionPlanSchedulingContext(DEFAULT_MAX_PARALLELISM);
        JobVertex source = getFirstVertexFromSources(handler);
        IntermediateDataSet dataSet = source.getProducedDataSets().get(0);

        // Before the producer has finished, the context is expected to report the values
        // configured on the sink stream node, ignoring the retriever functions.
        Assertions.assertThat(
                        schedulingContext.getConsumersParallelism(
                                id -> RETRIEVED_PARALLELISM, dataSet))
                .isEqualTo(sinkParallelism);
        Assertions.assertThat(
                        schedulingContext.getConsumersMaxParallelism(
                                id -> RETRIEVED_MAX_PARALLELISM, dataSet))
                .isEqualTo(sinkMaxParallelism);

        handler.handleJobEvent(
                new ExecutionJobVertexFinishedEvent(source.getID(), Collections.emptyMap()));

        // After the producer has finished, the values from the retriever functions are expected.
        Assertions.assertThat(
                        schedulingContext.getConsumersParallelism(
                                id -> RETRIEVED_PARALLELISM, dataSet))
                .isEqualTo(RETRIEVED_PARALLELISM);
        Assertions.assertThat(
                        schedulingContext.getConsumersMaxParallelism(
                                id -> RETRIEVED_MAX_PARALLELISM, dataSet))
                .isEqualTo(RETRIEVED_MAX_PARALLELISM);
    }

    @Test
    void testGetDefaultMaxParallelismWhenParallelismGreaterThanZero()
            throws DynamicCodeLoadingException {
        final int sinkParallelism = 4;

        // A non-positive max parallelism leaves the sink's max parallelism unset.
        DefaultAdaptiveExecutionHandler handler =
                getDefaultAdaptiveExecutionHandler(sinkParallelism, -1);
        IntermediateDataSet dataSet =
                getFirstVertexFromSources(handler).getProducedDataSets().get(0);

        Assertions.assertThat(
                        handler.createExecutionPlanSchedulingContext(DEFAULT_MAX_PARALLELISM)
                                .getConsumersMaxParallelism(id -> RETRIEVED_PARALLELISM, dataSet))
                .isEqualTo(SchedulerBase.getDefaultMaxParallelism(sinkParallelism));
    }

    @Test
    void testGetDefaultMaxParallelismWhenParallelismLessThanZero()
            throws DynamicCodeLoadingException {
        DefaultAdaptiveExecutionHandler handler = getDefaultAdaptiveExecutionHandler(-1, -1);
        IntermediateDataSet dataSet =
                getFirstVertexFromSources(handler).getProducedDataSets().get(0);

        // With neither parallelism nor max parallelism configured on the sink, the value given to
        // createExecutionPlanSchedulingContext is expected to be reported.
        Assertions.assertThat(
                        handler.createExecutionPlanSchedulingContext(DEFAULT_MAX_PARALLELISM)
                                .getConsumersMaxParallelism(id -> RETRIEVED_PARALLELISM, dataSet))
                .isEqualTo(DEFAULT_MAX_PARALLELISM);
    }

    @Test
    void testGetPendingOperatorCount() throws DynamicCodeLoadingException {
        DefaultAdaptiveExecutionHandler handler = getDefaultAdaptiveExecutionHandler();
        ExecutionPlanSchedulingContext schedulingContext =
                handler.createExecutionPlanSchedulingContext(1);

        Assertions.assertThat(schedulingContext.getPendingOperatorCount()).isEqualTo(1);

        JobVertex firstVertex = handler.getJobGraph().getVertices().iterator().next();
        handler.handleJobEvent(
                new ExecutionJobVertexFinishedEvent(firstVertex.getID(), Collections.emptyMap()));

        Assertions.assertThat(schedulingContext.getPendingOperatorCount()).isEqualTo(0);
    }

    /** Returns the first vertex of the handler's current job graph in topological order. */
    private static JobVertex getFirstVertexFromSources(DefaultAdaptiveExecutionHandler handler) {
        return handler.getJobGraph().getVerticesSortedTopologicallyFromSources().get(0);
    }

    /** Creates a handler whose sink has parallelism 2 and max parallelism 2. */
    private static DefaultAdaptiveExecutionHandler getDefaultAdaptiveExecutionHandler()
            throws DynamicCodeLoadingException {
        return getDefaultAdaptiveExecutionHandler(2, 2);
    }

    /**
     * Builds a {@code fromSequence -> print} stream graph with chaining disabled, configures the
     * sink node(s) and wraps the graph in a {@link DefaultAdaptiveExecutionHandler}.
     *
     * @param sinkParallelism parallelism set on every stream node whose operator name contains
     *     "Sink"
     * @param sinkMaxParallelism max parallelism set on those nodes; ignored when not positive,
     *     which leaves the node's max parallelism untouched
     */
    private static DefaultAdaptiveExecutionHandler getDefaultAdaptiveExecutionHandler(
            int sinkParallelism, int sinkMaxParallelism) throws DynamicCodeLoadingException {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.fromSequence(0L, 1L).disableChaining().print();
        StreamGraph streamGraph = env.getStreamGraph();

        for (StreamNode streamNode : streamGraph.getStreamNodes()) {
            if (!streamNode.getOperatorName().contains("Sink")) {
                continue;
            }
            streamNode.setParallelism(sinkParallelism);
            if (sinkMaxParallelism > 0) {
                streamNode.setMaxParallelism(sinkMaxParallelism);
            }
        }

        return new DefaultAdaptiveExecutionHandler(
                Thread.currentThread().getContextClassLoader(),
                streamGraph,
                EXECUTOR_RESOURCE.getExecutor());
    }
}
