package org.apache.flink.runtime.scheduler.adaptivebatch;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ScheduledExecutorService;
import java.util.stream.Collectors;
import org.apache.flink.api.common.JobID;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.JobManagerOptions;
import org.apache.flink.runtime.accumulators.AccumulatorSnapshot;
import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutor;
import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutorServiceAdapter;
import org.apache.flink.runtime.execution.ExecutionState;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.IOMetrics;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.IntermediateDataSet;
import org.apache.flink.runtime.jobgraph.JobEdge;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.scheduler.DefaultSchedulerBuilder;
import org.apache.flink.runtime.scheduler.SchedulerBase;
import org.apache.flink.runtime.taskmanager.TaskExecutionState;
import org.apache.flink.runtime.testtasks.NoOpInvokable;
import org.apache.flink.testutils.TestingUtils;
import org.apache.flink.testutils.executor.TestExecutorExtension;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

/**
 * Tests for the scheduler built via {@link
 * DefaultSchedulerBuilder#buildAdaptiveBatchJobScheduler()}.
 *
 * <p>Every test uses the same topology: two sources with fixed parallelism ("source1" and
 * "source2") feeding a single "sink" through POINTWISE / BLOCKING edges. The sink's parallelism
 * is left unset, and the tests check that it is only decided once both producers have finished.
 */
class AdaptiveBatchSchedulerTest {

    /** Parallelism of "source1". */
    private static final int SOURCE_PARALLELISM_1 = 6;

    /** Parallelism of "source2". */
    private static final int SOURCE_PARALLELISM_2 = 4;

    /** Value always returned by the stub vertex parallelism decider in {@link #createScheduler}. */
    private static final int DECIDED_PARALLELISM = 10;

    /**
     * Parallelism the sink's execution vertex reports before it has been decided (as asserted by
     * the tests below). Also passed to {@link #createJobVertex} to leave a vertex's parallelism
     * unset.
     */
    private static final int UNDECIDED_PARALLELISM = -1;

    @RegisterExtension
    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
            TestingUtils.defaultExecutorExtension();

    private static final ComponentMainThreadExecutor MAIN_THREAD_EXECUTOR =
            ComponentMainThreadExecutorServiceAdapter.forMainThread();

    /**
     * Without a forward edge, the sink receives the parallelism chosen by the vertex parallelism
     * decider once both sources have finished.
     */
    @Test
    void testAdaptiveBatchScheduler() throws Exception {
        checkSinkParallelismDecidedAfterAllSourcesFinish(false, DECIDED_PARALLELISM);
    }

    /**
     * With a forward edge from "source1" to the sink, the sink ends up with source1's parallelism
     * instead of the decider's value.
     */
    @Test
    void testDecideParallelismForForwardTarget() throws Exception {
        checkSinkParallelismDecidedAfterAllSourcesFinish(true, SOURCE_PARALLELISM_1);
    }

    /**
     * Shared body of both tests: finishes the sources one at a time and checks when, and to what
     * value, the sink's parallelism is decided.
     *
     * @param withForwardEdge whether the source1 -> sink edge is marked as forward
     * @param expectedSinkParallelism parallelism the sink must have after both sources finished
     */
    private void checkSinkParallelismDecidedAfterAllSourcesFinish(
            boolean withForwardEdge, int expectedSinkParallelism) throws Exception {
        final JobGraph jobGraph = createJobGraph(withForwardEdge);
        // Relies on getVertices() yielding the vertices in the order createJobGraph passed them
        // to the JobGraph (source1, source2, sink).
        final Iterator<JobVertex> vertices = jobGraph.getVertices().iterator();
        final JobVertex source1 = vertices.next();
        final JobVertex source2 = vertices.next();
        final JobVertex sink = vertices.next();

        final SchedulerBase scheduler = createScheduler(jobGraph);
        final ExecutionJobVertex sinkExecutionJobVertex =
                scheduler.getExecutionGraph().getJobVertex(sink.getID());

        scheduler.startScheduling();
        Assertions.assertThat(sinkExecutionJobVertex.getParallelism())
                .isEqualTo(UNDECIDED_PARALLELISM);

        // Only one of the sink's two producers has finished: still undecided.
        transitionExecutionsState(scheduler, ExecutionState.FINISHED, source1);
        Assertions.assertThat(sinkExecutionJobVertex.getParallelism())
                .isEqualTo(UNDECIDED_PARALLELISM);

        // Both producers have finished: the sink's parallelism is now decided ...
        transitionExecutionsState(scheduler, ExecutionState.FINISHED, source2);
        Assertions.assertThat(sinkExecutionJobVertex.getParallelism())
                .isEqualTo(expectedSinkParallelism);
        // ... and is also reflected on the sink vertex of the JobGraph.
        Assertions.assertThat(sink.getParallelism()).isEqualTo(expectedSinkParallelism);
    }

    /**
     * Reports every given execution to the scheduler as having reached {@code executionState}.
     *
     * <p>Each update carries no failure cause, no accumulator snapshot and all-zero IO metrics.
     *
     * @param schedulerBase scheduler that receives the task execution state updates
     * @param executionState state to report for every execution
     * @param list executions whose attempt IDs are reported
     */
    public static void transitionExecutionsState(
            SchedulerBase schedulerBase, ExecutionState executionState, List<Execution> list) {
        for (Execution execution : list) {
            schedulerBase.updateTaskExecutionState(
                    new TaskExecutionState(
                            execution.getAttemptId(),
                            executionState,
                            (Throwable) null,
                            (AccumulatorSnapshot) null,
                            new IOMetrics(0L, 0L, 0L, 0L, 0L, 0L, 0L)));
        }
    }

    /**
     * Reports the current execution attempt of every subtask of {@code jobVertex} to the scheduler
     * as having reached {@code executionState}.
     *
     * @param schedulerBase scheduler whose execution graph contains {@code jobVertex}
     * @param executionState state to report for every current attempt
     * @param jobVertex job vertex whose subtasks are transitioned
     */
    public static void transitionExecutionsState(
            SchedulerBase schedulerBase, ExecutionState executionState, JobVertex jobVertex) {
        final List<Execution> currentAttempts =
                Arrays.stream(
                                schedulerBase
                                        .getExecutionGraph()
                                        .getJobVertex(jobVertex.getID())
                                        .getTaskVertices())
                        .map(taskVertex -> taskVertex.getCurrentExecutionAttempt())
                        .collect(Collectors.toList());
        transitionExecutionsState(schedulerBase, executionState, currentAttempts);
    }

    /**
     * Creates a job vertex that runs {@link NoOpInvokable}.
     *
     * @param jobVertexName name of the vertex
     * @param parallelism parallelism to set; values {@code <= 0} leave the parallelism unset
     * @return the new vertex
     */
    public JobVertex createJobVertex(String jobVertexName, int parallelism) {
        final JobVertex jobVertex = new JobVertex(jobVertexName);
        jobVertex.setInvokableClass(NoOpInvokable.class);
        // Non-positive values keep the vertex's parallelism unset (used for the sink).
        if (parallelism > 0) {
            jobVertex.setParallelism(parallelism);
        }
        return jobVertex;
    }

    /**
     * Builds the graph source1 -> sink <- source2, with both edges POINTWISE and BLOCKING, and
     * the sink's parallelism left unset.
     *
     * @param withForwardEdge if {@code true}, the source1 -> sink edge is marked as forward
     * @return the job graph, with vertices added in the order source1, source2, sink
     */
    public JobGraph createJobGraph(boolean withForwardEdge) {
        final JobVertex source1 = createJobVertex("source1", SOURCE_PARALLELISM_1);
        final JobVertex source2 = createJobVertex("source2", SOURCE_PARALLELISM_2);
        final JobVertex sink = createJobVertex("sink", UNDECIDED_PARALLELISM);

        sink.connectNewDataSetAsInput(
                source1, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING);
        sink.connectNewDataSetAsInput(
                source2, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING);

        if (withForwardEdge) {
            final IntermediateDataSet source1Output = source1.getProducedDataSets().get(0);
            final JobEdge source1ToSink = source1Output.getConsumers().get(0);
            source1ToSink.setForward(true);
        }

        return new JobGraph(
                new JobID(), "test job", new JobVertex[] {source1, source2, sink});
    }

    /**
     * Creates an adaptive batch scheduler for {@code jobGraph}.
     *
     * <p>The scheduler type is set to {@code AdaptiveBatch} in the job master configuration, and
     * a stub vertex parallelism decider always returns {@link #DECIDED_PARALLELISM}.
     *
     * @param jobGraph job graph to schedule
     * @return the scheduler (scheduling is not started)
     * @throws Exception if building the scheduler fails
     */
    public SchedulerBase createScheduler(JobGraph jobGraph) throws Exception {
        final Configuration configuration = new Configuration();
        configuration.set(
                JobManagerOptions.SCHEDULER, JobManagerOptions.SchedulerType.AdaptiveBatch);
        return new DefaultSchedulerBuilder(
                        jobGraph, MAIN_THREAD_EXECUTOR, EXECUTOR_RESOURCE.getExecutor())
                .setJobMasterConfiguration(configuration)
                // Stub decider: ignores its input and always picks the same parallelism.
                .setVertexParallelismDecider(ignored -> DECIDED_PARALLELISM)
                .buildAdaptiveBatchJobScheduler();
    }
}
