package io.trino.execution;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.concurrent.Threads;
import io.airlift.tracing.Tracing;
import io.opentelemetry.api.trace.Span;
import io.trino.cost.StatsAndCosts;
import io.trino.execution.scheduler.SplitSchedulerStats;
import io.trino.spi.type.VarcharType;
import io.trino.sql.planner.Partitioning;
import io.trino.sql.planner.PartitioningScheme;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SystemPartitioningHandle;
import io.trino.sql.planner.plan.PlanFragmentId;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.ValuesNode;
import io.trino.sql.tree.Row;
import io.trino.sql.tree.StringLiteral;
import java.io.IOException;
import java.sql.SQLException;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.Test;

/* loaded from: input_file:io/trino/execution/TestStageStateMachine.class */
/**
 * Unit tests for {@link StageStateMachine} state transitions.
 *
 * <p>Each test builds a fresh state machine over a single-row VALUES plan fragment. It then checks
 * which transitions are accepted from a given state. It also checks that the {@link StageInfo}
 * snapshot reported by the machine stays consistent with the machine's current state.
 */
public class TestStageStateMachine {
    private static final StageId STAGE_ID = new StageId("query", 0);
    private static final PlanFragment PLAN_FRAGMENT = createValuesPlan();
    private static final SQLException FAILED_CAUSE = createFailedCause();

    // Non-final because tearDown() shuts it down and clears the reference once the class is done.
    private ExecutorService executor = Executors.newCachedThreadPool(Threads.daemonThreadsNamed(getClass().getSimpleName() + "-%s"));

    @AfterClass(alwaysRun = true)
    public void tearDown() {
        this.executor.shutdownNow();
        this.executor = null;
    }

    @Test
    public void testBasicStateChanges() {
        StageStateMachine stateMachine = createStageStateMachine();
        assertState(stateMachine, StageState.PLANNED);

        Assert.assertTrue(stateMachine.transitionToScheduling());
        assertState(stateMachine, StageState.SCHEDULING);

        Assert.assertTrue(stateMachine.transitionToRunning());
        assertState(stateMachine, StageState.RUNNING);

        Assert.assertTrue(stateMachine.transitionToPending());
        assertState(stateMachine, StageState.PENDING);

        Assert.assertTrue(stateMachine.transitionToRunning());
        assertState(stateMachine, StageState.RUNNING);

        Assert.assertTrue(stateMachine.transitionToFinished());
        assertState(stateMachine, StageState.FINISHED);
    }

    @Test
    public void testPlanned() {
        // A freshly created machine starts in PLANNED.
        assertState(createStageStateMachine(), StageState.PLANNED);

        StageStateMachine toScheduling = createStageStateMachine();
        Assert.assertTrue(toScheduling.transitionToScheduling());
        assertState(toScheduling, StageState.SCHEDULING);

        StageStateMachine toRunning = createStageStateMachine();
        Assert.assertTrue(toRunning.transitionToRunning());
        assertState(toRunning, StageState.RUNNING);

        StageStateMachine toFinished = createStageStateMachine();
        Assert.assertTrue(toFinished.transitionToFinished());
        assertState(toFinished, StageState.FINISHED);

        StageStateMachine toFailed = createStageStateMachine();
        Assert.assertTrue(toFailed.transitionToFailed(FAILED_CAUSE));
        assertState(toFailed, StageState.FAILED);
    }

    @Test
    public void testScheduling() {
        StageStateMachine stateMachine = createStageStateMachine();
        Assert.assertTrue(stateMachine.transitionToScheduling());
        assertState(stateMachine, StageState.SCHEDULING);

        // Re-entering SCHEDULING is rejected and leaves the state unchanged.
        Assert.assertFalse(stateMachine.transitionToScheduling());
        assertState(stateMachine, StageState.SCHEDULING);

        StageStateMachine toRunning = createStageStateMachine();
        toRunning.transitionToScheduling();
        Assert.assertTrue(toRunning.transitionToRunning());
        assertState(toRunning, StageState.RUNNING);

        StageStateMachine toFinished = createStageStateMachine();
        toFinished.transitionToScheduling();
        Assert.assertTrue(toFinished.transitionToFinished());
        assertState(toFinished, StageState.FINISHED);

        StageStateMachine toFailed = createStageStateMachine();
        toFailed.transitionToScheduling();
        Assert.assertTrue(toFailed.transitionToFailed(FAILED_CAUSE));
        assertState(toFailed, StageState.FAILED);
    }

    @Test
    public void testRunning() {
        StageStateMachine stateMachine = createStageStateMachine();
        Assert.assertTrue(stateMachine.transitionToRunning());
        assertState(stateMachine, StageState.RUNNING);

        // Neither SCHEDULING nor a repeated RUNNING transition is accepted from RUNNING.
        Assert.assertFalse(stateMachine.transitionToScheduling());
        assertState(stateMachine, StageState.RUNNING);

        Assert.assertFalse(stateMachine.transitionToRunning());
        assertState(stateMachine, StageState.RUNNING);

        // RUNNING and PENDING may alternate.
        Assert.assertTrue(stateMachine.transitionToPending());
        assertState(stateMachine, StageState.PENDING);

        Assert.assertTrue(stateMachine.transitionToRunning());
        assertState(stateMachine, StageState.RUNNING);

        StageStateMachine toFinished = createStageStateMachine();
        toFinished.transitionToRunning();
        Assert.assertTrue(toFinished.transitionToFinished());
        assertState(toFinished, StageState.FINISHED);

        StageStateMachine toFailed = createStageStateMachine();
        toFailed.transitionToRunning();
        Assert.assertTrue(toFailed.transitionToFailed(FAILED_CAUSE));
        assertState(toFailed, StageState.FAILED);
    }

    @Test
    public void testFinished() {
        StageStateMachine stateMachine = createStageStateMachine();
        Assert.assertTrue(stateMachine.transitionToFinished());
        assertFinalState(stateMachine, StageState.FINISHED);
    }

    @Test
    public void testFailed() {
        StageStateMachine stateMachine = createStageStateMachine();
        Assert.assertTrue(stateMachine.transitionToFailed(FAILED_CAUSE));
        assertFinalState(stateMachine, StageState.FAILED);
    }

    /**
     * Asserts that {@code stageState} is a done state and that every transition attempted from it
     * is rejected without changing the machine's state.
     *
     * @param stageStateMachine machine currently expected to be in {@code stageState}
     * @param stageState the expected terminal state
     */
    private static void assertFinalState(StageStateMachine stageStateMachine, StageState stageState) {
        Assert.assertTrue(stageState.isDone());
        assertState(stageStateMachine, stageState);

        Assert.assertFalse(stageStateMachine.transitionToScheduling());
        assertState(stageStateMachine, stageState);

        Assert.assertFalse(stageStateMachine.transitionToPending());
        assertState(stageStateMachine, stageState);

        Assert.assertFalse(stageStateMachine.transitionToRunning());
        assertState(stageStateMachine, stageState);

        Assert.assertFalse(stageStateMachine.transitionToFinished());
        assertState(stageStateMachine, stageState);

        Assert.assertFalse(stageStateMachine.transitionToFailed(FAILED_CAUSE));
        assertState(stageStateMachine, stageState);

        // A different failure cause must not overwrite the terminal state either.
        Assert.assertFalse(stageStateMachine.transitionToFailed(new IOException("failure after finish")));
        assertState(stageStateMachine, stageState);
    }

    /**
     * Asserts that both the machine and its {@link StageInfo} snapshot report {@code stageState}.
     * Also checks the fixed identity, plan and type data built by {@link #createValuesPlan()}.
     *
     * <p>For FAILED, the recorded failure info must carry the message and class name of
     * {@link #FAILED_CAUSE}. For every other state, no failure cause may be present.
     *
     * @param stageStateMachine machine under test
     * @param stageState the state the machine is expected to be in
     */
    private static void assertState(StageStateMachine stageStateMachine, StageState stageState) {
        Assert.assertEquals(stageStateMachine.getStageId(), STAGE_ID);

        StageInfo stageInfo = stageStateMachine.getStageInfo(ImmutableList::of);
        Assert.assertEquals(stageInfo.getStageId(), STAGE_ID);
        Assert.assertEquals(stageInfo.getSubStages(), ImmutableList.of());
        Assert.assertEquals(stageInfo.getTasks(), ImmutableList.of());
        Assert.assertEquals(stageInfo.getTypes(), ImmutableList.of(VarcharType.VARCHAR));
        Assert.assertSame(stageInfo.getPlan(), PLAN_FRAGMENT);

        Assert.assertEquals(stageStateMachine.getState(), stageState);
        Assert.assertEquals(stageInfo.getState(), stageState);

        ExecutionFailureInfo failureCause = stageInfo.getFailureCause();
        if (stageState != StageState.FAILED) {
            Assert.assertNull(failureCause);
            return;
        }
        // Report a missing cause as an assertion failure instead of an NPE on the lines below.
        Assert.assertNotNull(failureCause, "failure cause must be recorded for a FAILED stage");
        Assert.assertEquals(failureCause.getMessage(), FAILED_CAUSE.getMessage());
        Assert.assertEquals(failureCause.getType(), FAILED_CAUSE.getClass().getName());
    }

    /**
     * Creates a new machine for {@link #STAGE_ID} and {@link #PLAN_FRAGMENT}.
     *
     * <p>The machine uses this test's executor, a no-op tracer, an invalid span and fresh scheduler stats.
     */
    private StageStateMachine createStageStateMachine() {
        return new StageStateMachine(STAGE_ID, PLAN_FRAGMENT, ImmutableMap.of(), this.executor, Tracing.noopTracer(), Span.getInvalid(), new SplitSchedulerStats());
    }

    /**
     * Builds a plan fragment with a single VALUES node that produces one VARCHAR column
     * ({@code column}) holding the single row {@code 'foo'}.
     */
    private static PlanFragment createValuesPlan() {
        Symbol symbol = new Symbol("column");
        PlanNodeId planNodeId = new PlanNodeId("plan");
        ValuesNode valuesNode = new ValuesNode(
                planNodeId,
                ImmutableList.of(symbol),
                ImmutableList.of(new Row(ImmutableList.of(new StringLiteral("foo")))));
        PartitioningScheme outputPartitioning = new PartitioningScheme(
                Partitioning.create(SystemPartitioningHandle.SINGLE_DISTRIBUTION, ImmutableList.of()),
                ImmutableList.of(symbol));
        return new PlanFragment(
                new PlanFragmentId("plan"),
                valuesNode,
                ImmutableMap.of(symbol, VarcharType.VARCHAR),
                SystemPartitioningHandle.SOURCE_DISTRIBUTION,
                Optional.empty(),
                ImmutableList.of(planNodeId),
                outputPartitioning,
                StatsAndCosts.empty(),
                ImmutableList.of(),
                Optional.empty());
    }

    /**
     * Creates the shared failure cause used by the FAILED-state tests, with an empty stack trace.
     */
    private static SQLException createFailedCause() {
        SQLException cause = new SQLException("FAILED");
        // NOTE(review): presumably cleared to keep the recorded failure info small and deterministic.
        cause.setStackTrace(new StackTraceElement[0]);
        return cause;
    }
}
