package io.trino.execution;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.MoreExecutors;
import io.airlift.node.NodeInfo;
import io.airlift.slice.Slice;
import io.airlift.stats.TestingGcMonitor;
import io.airlift.tracing.Tracing;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import io.opentelemetry.api.trace.Span;
import io.trino.Session;
import io.trino.SessionTestUtils;
import io.trino.connector.CatalogProperties;
import io.trino.connector.ConnectorServices;
import io.trino.connector.ConnectorServicesProvider;
import io.trino.exchange.ExchangeManagerRegistry;
import io.trino.execution.buffer.BufferResult;
import io.trino.execution.buffer.BufferState;
import io.trino.execution.buffer.OutputBuffers;
import io.trino.execution.buffer.PagesSerdeUtil;
import io.trino.execution.buffer.PipelinedOutputBuffers;
import io.trino.execution.executor.TaskExecutor;
import io.trino.memory.LocalMemoryManager;
import io.trino.memory.NodeMemoryConfig;
import io.trino.memory.QueryContext;
import io.trino.memory.context.LocalMemoryContext;
import io.trino.metadata.InternalNode;
import io.trino.metadata.WorkerLanguageFunctionProvider;
import io.trino.operator.DirectExchangeClient;
import io.trino.operator.DirectExchangeClientSupplier;
import io.trino.operator.RetryPolicy;
import io.trino.spi.QueryId;
import io.trino.spi.connector.CatalogHandle;
import io.trino.spi.exchange.ExchangeId;
import io.trino.spiller.LocalSpillManager;
import io.trino.spiller.NodeSpillConfig;
import io.trino.testing.TestingSession;
import io.trino.version.EmbedVersion;
import java.net.URI;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.parallel.Execution;
import org.junit.jupiter.api.parallel.ExecutionMode;

@Execution(ExecutionMode.CONCURRENT)
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
/* loaded from: input_file:io/trino/execution/BaseTestSqlTaskManager.class */
public abstract class BaseTestSqlTaskManager {
    // Identifier (0) of the single output buffer that the tests in this class register and read from.
    public static final PipelinedOutputBuffers.OutputBufferId OUT = new PipelinedOutputBuffers.OutputBufferId(0);
    // Counter incremented by newTaskId() so that every generated task id carries a distinct query name.
    private final AtomicInteger sequence = new AtomicInteger();
    // Executor obtained from createTaskExecutor() in setUp(); stopped and cleared in tearDown().
    private TaskExecutor taskExecutor;
    // Created in setUp() and closed in tearDown(); handed to every SqlTaskManager built by the tests.
    private TaskManagementExecutor taskManagementExecutor;

    /* loaded from: input_file:io/trino/execution/BaseTestSqlTaskManager$MockDirectExchangeClientSupplier.class */
    /**
     * {@link DirectExchangeClientSupplier} stub whose only method always fails with
     * {@link UnsupportedOperationException}.
     */
    public static class MockDirectExchangeClientSupplier implements DirectExchangeClientSupplier {
        /**
         * Always throws; no exchange client is ever produced by this stub.
         *
         * @throws UnsupportedOperationException on every call
         */
        public DirectExchangeClient get(
                QueryId queryId,
                ExchangeId exchangeId,
                LocalMemoryContext localMemoryContext,
                TaskFailureListener taskFailureListener,
                RetryPolicy retryPolicy) {
            throw new UnsupportedOperationException();
        }
    }

    /* loaded from: input_file:io/trino/execution/BaseTestSqlTaskManager$MockLocationFactory.class */
    /**
     * {@link LocationFactory} that fabricates URIs under the {@code http://fake.invalid/} host
     * from the identifiers it is given.
     */
    public static class MockLocationFactory implements LocationFactory {
        /** Returns {@code http://fake.invalid/query/<queryId>}. */
        public URI createQueryLocation(QueryId queryId) {
            String location = "http://fake.invalid/query/" + queryId;
            return URI.create(location);
        }

        /** Returns {@code http://fake.invalid/task/<taskId>}. */
        public URI createLocalTaskLocation(TaskId taskId) {
            String location = "http://fake.invalid/task/" + taskId;
            return URI.create(location);
        }

        /** Returns {@code http://fake.invalid/task/<nodeIdentifier>/<taskId>}. */
        public URI createTaskLocation(InternalNode internalNode, TaskId taskId) {
            String nodeIdentifier = internalNode.getNodeIdentifier();
            String location = "http://fake.invalid/task/" + nodeIdentifier + "/" + taskId;
            return URI.create(location);
        }

        /** Returns {@code http://fake.invalid/<nodeIdentifier>/memory}. */
        public URI createMemoryInfoLocation(InternalNode internalNode) {
            String nodeIdentifier = internalNode.getNodeIdentifier();
            String location = "http://fake.invalid/" + nodeIdentifier + "/memory";
            return URI.create(location);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:io/trino/execution/BaseTestSqlTaskManager$NoConnectorServicesProvider.class */
    /**
     * {@link ConnectorServicesProvider} that holds no catalogs: the loading methods do nothing and
     * the pruning/lookup methods throw {@link UnsupportedOperationException}.
     */
    public static class NoConnectorServicesProvider implements ConnectorServicesProvider {
        // Instantiated only from within the enclosing test class.
        private NoConnectorServicesProvider() {
        }

        /** Does nothing. */
        public void loadInitialCatalogs() {
            // intentionally empty
        }

        /** Does nothing; both arguments are ignored. */
        public void ensureCatalogsLoaded(Session session, List<CatalogProperties> list) {
            // intentionally empty
        }

        /**
         * Not supported by this stub.
         *
         * @throws UnsupportedOperationException on every call
         */
        public void pruneCatalogs(Set<CatalogHandle> set) {
            throw new UnsupportedOperationException();
        }

        /**
         * Not supported by this stub.
         *
         * @throws UnsupportedOperationException on every call
         */
        public ConnectorServices getConnectorServices(CatalogHandle catalogHandle) {
            throw new UnsupportedOperationException();
        }
    }

    /**
     * Supplies the {@link TaskExecutor} implementation under test.
     *
     * <p>Called once from {@link #setUp()}, which starts the returned executor and shares it with
     * every {@code SqlTaskManager} the tests create.
     *
     * @return the executor to start and use for the tests in this class
     */
    protected abstract TaskExecutor createTaskExecutor();

    /**
     * Creates and starts the shared task executor, then creates the shared task management
     * executor. Runs once before all tests of the class.
     */
    @BeforeAll
    public void setUp() {
        // Store the executor before starting it so the field is populated even if start() fails.
        taskExecutor = createTaskExecutor();
        taskExecutor.start();

        taskManagementExecutor = new TaskManagementExecutor();
    }

    /**
     * Stops the shared task executor and closes the shared task management executor, clearing
     * both fields. Runs once after all tests of the class.
     *
     * <p>Each resource is released independently: a failure while stopping the task executor no
     * longer prevents the task management executor from being closed, and fields left
     * {@code null} by a partially failed {@link #setUp()} are skipped instead of raising a
     * {@link NullPointerException} that would hide the original failure.
     */
    @AfterAll
    public void tearDown() {
        try {
            if (this.taskExecutor != null) {
                this.taskExecutor.stop();
            }
        } finally {
            this.taskExecutor = null;
            try {
                if (this.taskManagementExecutor != null) {
                    this.taskManagementExecutor.close();
                }
            } finally {
                this.taskManagementExecutor = null;
            }
        }
    }

    /**
     * Creates a task without any split assignment and expects it to be {@code RUNNING}; then
     * updates the same task with an empty split set and expects it to be {@code FINISHED}, both in
     * the returned info and in a subsequent {@code getTaskInfo} lookup.
     */
    @Test
    public void testEmptyQuery() {
        // try-with-resources closes the manager exactly once on every path; the previous
        // hand-expanded try/catch would invoke close() a second time if the first close() threw.
        try (SqlTaskManager sqlTaskManager = createSqlTaskManager(new TaskManagerConfig())) {
            TaskId taskId = newTaskId();

            TaskInfo created = createTask(
                    sqlTaskManager,
                    taskId,
                    PipelinedOutputBuffers.createInitial(PipelinedOutputBuffers.BufferType.PARTITIONED).withNoMoreBufferIds());
            Assertions.assertThat(created.getTaskStatus().getState()).isEqualTo(TaskState.RUNNING);
            Assertions.assertThat(sqlTaskManager.getTaskInfo(taskId).getTaskStatus().getState()).isEqualTo(TaskState.RUNNING);

            TaskInfo updated = createTask(
                    sqlTaskManager,
                    taskId,
                    ImmutableSet.of(),
                    PipelinedOutputBuffers.createInitial(PipelinedOutputBuffers.BufferType.PARTITIONED).withNoMoreBufferIds());
            Assertions.assertThat(updated.getTaskStatus().getState()).isEqualTo(TaskState.FINISHED);
            Assertions.assertThat(sqlTaskManager.getTaskInfo(taskId).getTaskStatus().getState()).isEqualTo(TaskState.FINISHED);
        }
    }

    /**
     * Runs a task with a single split and one output buffer: expects the task to reach
     * {@code FLUSHING}, reads one page holding one position, drains the buffer until it reports
     * completion, destroys the buffer and finally expects the task to be {@code FINISHED}.
     *
     * @throws Exception if waiting on any of the task or result futures fails
     */
    @Timeout(30)
    @Test
    public void testSimpleQuery() throws Exception {
        // try-with-resources closes the manager exactly once on every path; the previous
        // hand-expanded try/catch would invoke close() a second time if the first close() threw.
        try (SqlTaskManager sqlTaskManager = createSqlTaskManager(new TaskManagerConfig())) {
            TaskId taskId = newTaskId();
            createTask(
                    sqlTaskManager,
                    taskId,
                    ImmutableSet.of(TaskTestUtils.SPLIT),
                    PipelinedOutputBuffers.createInitial(PipelinedOutputBuffers.BufferType.PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds());

            TaskInfo taskInfo = (TaskInfo) sqlTaskManager.getTaskInfo(taskId, 0L).get();
            Assertions.assertThat(taskInfo.getTaskStatus().getState()).isEqualTo(TaskState.FLUSHING);

            // First read starts at token 0 and must return exactly one single-position page.
            BufferResult results = (BufferResult) sqlTaskManager
                    .getTaskResults(taskId, OUT, 0L, DataSize.of(1L, DataSize.Unit.MEGABYTE))
                    .getResultsFuture()
                    .get();
            Assertions.assertThat(results.isBufferComplete()).isFalse();
            Assertions.assertThat(results.getSerializedPages().size()).isEqualTo(1);
            Assertions.assertThat(PagesSerdeUtil.getSerializedPagePositionCount((Slice) results.getSerializedPages().get(0))).isEqualTo(1);

            // Keep requesting from the next token until the buffer reports completion
            // (at least one further request is always issued).
            do {
                long nextToken = results.getToken() + results.getSerializedPages().size();
                results = (BufferResult) sqlTaskManager
                        .getTaskResults(taskId, OUT, nextToken, DataSize.of(1L, DataSize.Unit.MEGABYTE))
                        .getResultsFuture()
                        .get();
            } while (!results.isBufferComplete());
            Assertions.assertThat(results.isBufferComplete()).isTrue();
            Assertions.assertThat(results.getSerializedPages().size()).isEqualTo(0);

            Assertions.assertThat(sqlTaskManager.destroyTaskResults(taskId, OUT).getOutputBuffers().getState()).isEqualTo(BufferState.FINISHED);

            // Wait for a status newer than the FLUSHING snapshot taken above.
            TaskInfo finalInfo = (TaskInfo) sqlTaskManager.getTaskInfo(taskId, taskInfo.getTaskStatus().getVersion()).get();
            Assertions.assertThat(finalInfo.getTaskStatus().getState()).isEqualTo(TaskState.FINISHED);
            Assertions.assertThat(sqlTaskManager.getTaskInfo(taskId).getTaskStatus().getState()).isEqualTo(TaskState.FINISHED);
        }
    }

    /**
     * Cancels a running task and expects it to end in {@code CANCELED} with an end time set,
     * after first confirming it was {@code RUNNING} with no end time.
     *
     * @throws InterruptedException if interrupted while polling the terminating task
     * @throws ExecutionException if polling the terminating task fails
     * @throws TimeoutException if a poll of the terminating task does not complete in time
     */
    @Test
    public void testCancel() throws InterruptedException, ExecutionException, TimeoutException {
        // try-with-resources closes the manager exactly once on every path; the previous
        // hand-expanded try/catch would invoke close() a second time if the first close() threw.
        try (SqlTaskManager sqlTaskManager = createSqlTaskManager(new TaskManagerConfig())) {
            TaskId taskId = newTaskId();

            TaskInfo taskInfo = createTask(
                    sqlTaskManager,
                    taskId,
                    PipelinedOutputBuffers.createInitial(PipelinedOutputBuffers.BufferType.PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds());
            Assertions.assertThat(taskInfo.getTaskStatus().getState()).isEqualTo(TaskState.RUNNING);
            Assertions.assertThat(taskInfo.getStats().getEndTime()).isNull();

            taskInfo = sqlTaskManager.getTaskInfo(taskId);
            Assertions.assertThat(taskInfo.getTaskStatus().getState()).isEqualTo(TaskState.RUNNING);
            Assertions.assertThat(taskInfo.getStats().getEndTime()).isNull();

            taskInfo = pollTerminatingTaskInfoUntilDone(sqlTaskManager, sqlTaskManager.cancelTask(taskId));
            Assertions.assertThat(taskInfo.getTaskStatus().getState()).isEqualTo(TaskState.CANCELED);
            Assertions.assertThat(taskInfo.getStats().getEndTime()).isNotNull();

            taskInfo = sqlTaskManager.getTaskInfo(taskId);
            Assertions.assertThat(taskInfo.getTaskStatus().getState()).isEqualTo(TaskState.CANCELED);
            Assertions.assertThat(taskInfo.getStats().getEndTime()).isNotNull();
        }
    }

    /**
     * Aborts a running task and expects it to end in {@code ABORTED} with an end time set,
     * after first confirming it was {@code RUNNING} with no end time.
     *
     * @throws InterruptedException if interrupted while polling the terminating task
     * @throws ExecutionException if polling the terminating task fails
     * @throws TimeoutException if a poll of the terminating task does not complete in time
     */
    @Test
    public void testAbort() throws InterruptedException, ExecutionException, TimeoutException {
        // try-with-resources closes the manager exactly once on every path; the previous
        // hand-expanded try/catch would invoke close() a second time if the first close() threw.
        try (SqlTaskManager sqlTaskManager = createSqlTaskManager(new TaskManagerConfig())) {
            TaskId taskId = newTaskId();

            TaskInfo taskInfo = createTask(
                    sqlTaskManager,
                    taskId,
                    PipelinedOutputBuffers.createInitial(PipelinedOutputBuffers.BufferType.PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds());
            Assertions.assertThat(taskInfo.getTaskStatus().getState()).isEqualTo(TaskState.RUNNING);
            Assertions.assertThat(taskInfo.getStats().getEndTime()).isNull();

            taskInfo = sqlTaskManager.getTaskInfo(taskId);
            Assertions.assertThat(taskInfo.getTaskStatus().getState()).isEqualTo(TaskState.RUNNING);
            Assertions.assertThat(taskInfo.getStats().getEndTime()).isNull();

            taskInfo = pollTerminatingTaskInfoUntilDone(sqlTaskManager, sqlTaskManager.abortTask(taskId));
            Assertions.assertThat(taskInfo.getTaskStatus().getState()).isEqualTo(TaskState.ABORTED);
            Assertions.assertThat(taskInfo.getStats().getEndTime()).isNotNull();

            taskInfo = sqlTaskManager.getTaskInfo(taskId);
            Assertions.assertThat(taskInfo.getTaskStatus().getState()).isEqualTo(TaskState.ABORTED);
            Assertions.assertThat(taskInfo.getStats().getEndTime()).isNotNull();
        }
    }

    /**
     * Runs a task with a single split until it is {@code FLUSHING}, destroys its output buffer
     * without reading it, and expects the task to become {@code FINISHED}.
     *
     * @throws Exception if waiting on any of the task info futures fails
     */
    @Timeout(30)
    @Test
    public void testAbortResults() throws Exception {
        // try-with-resources closes the manager exactly once on every path; the previous
        // hand-expanded try/catch would invoke close() a second time if the first close() threw.
        try (SqlTaskManager sqlTaskManager = createSqlTaskManager(new TaskManagerConfig())) {
            TaskId taskId = newTaskId();
            createTask(
                    sqlTaskManager,
                    taskId,
                    ImmutableSet.of(TaskTestUtils.SPLIT),
                    PipelinedOutputBuffers.createInitial(PipelinedOutputBuffers.BufferType.PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds());

            TaskInfo taskInfo = (TaskInfo) sqlTaskManager.getTaskInfo(taskId, 0L).get();
            Assertions.assertThat(taskInfo.getTaskStatus().getState()).isEqualTo(TaskState.FLUSHING);

            sqlTaskManager.destroyTaskResults(taskId, OUT);

            // Wait for a status newer than the FLUSHING snapshot taken above.
            TaskInfo finalInfo = (TaskInfo) sqlTaskManager.getTaskInfo(taskId, taskInfo.getTaskStatus().getVersion()).get();
            Assertions.assertThat(finalInfo.getTaskStatus().getState()).isEqualTo(TaskState.FINISHED);
            Assertions.assertThat(sqlTaskManager.getTaskInfo(taskId).getTaskStatus().getState()).isEqualTo(TaskState.FINISHED);
        }
    }

    /**
     * Cancels a task under a 5 ms info-max-age configuration, waits 100 ms, calls
     * {@code removeOldTasks()} and expects the task to be absent from {@code getAllTaskInfo()}.
     *
     * @throws InterruptedException if interrupted while sleeping or polling the terminating task
     * @throws ExecutionException if polling the terminating task fails
     * @throws TimeoutException if a poll of the terminating task does not complete in time
     */
    @Test
    public void testRemoveOldTasks() throws InterruptedException, ExecutionException, TimeoutException {
        TaskManagerConfig config = new TaskManagerConfig().setInfoMaxAge(new Duration(5.0d, TimeUnit.MILLISECONDS));
        // try-with-resources closes the manager exactly once on every path; the previous
        // hand-expanded try/catch would invoke close() a second time if the first close() threw.
        try (SqlTaskManager sqlTaskManager = createSqlTaskManager(config)) {
            TaskId taskId = newTaskId();

            TaskInfo taskInfo = createTask(
                    sqlTaskManager,
                    taskId,
                    PipelinedOutputBuffers.createInitial(PipelinedOutputBuffers.BufferType.PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds());
            Assertions.assertThat(taskInfo.getTaskStatus().getState()).isEqualTo(TaskState.RUNNING);

            taskInfo = pollTerminatingTaskInfoUntilDone(sqlTaskManager, sqlTaskManager.cancelTask(taskId));
            Assertions.assertThat(taskInfo.getTaskStatus().getState()).isEqualTo(TaskState.CANCELED);

            taskInfo = sqlTaskManager.getTaskInfo(taskId);
            Assertions.assertThat(taskInfo.getTaskStatus().getState()).isEqualTo(TaskState.CANCELED);

            // Let the task info outlive the 5 ms max age configured above before pruning.
            Thread.sleep(100L);
            sqlTaskManager.removeOldTasks();

            for (TaskInfo info : sqlTaskManager.getAllTaskInfo()) {
                Assertions.assertThat(info.getTaskStatus().getTaskId()).isNotEqualTo(taskId);
            }
        }
    }

    /**
     * Checks how the {@code query_max_memory_per_node} session property interacts with a 3-byte
     * per-node limit from {@link NodeMemoryConfig}: before any task update both query contexts
     * report the configured limit and are uninitialized; after an update with {@code "1B"} the
     * first context reports 1; after an update with {@code "10B"} the second context is expected
     * to still report the configured limit.
     */
    @Test
    public void testSessionPropertyMemoryLimitOverride() {
        NodeMemoryConfig memoryConfig = new NodeMemoryConfig().setMaxQueryMemoryPerNode(DataSize.ofBytes(3L));
        // try-with-resources closes the manager exactly once on every path; the previous
        // hand-expanded try/catch would invoke close() a second time if the first close() threw.
        try (SqlTaskManager sqlTaskManager = createSqlTaskManager(new TaskManagerConfig(), memoryConfig)) {
            TaskId reduceLimitsId = new TaskId(new StageId("q1", 0), 1, 0);
            TaskId increaseLimitsId = new TaskId(new StageId("q2", 0), 1, 0);

            QueryContext reducesLimitsContext = sqlTaskManager.getQueryContext(reduceLimitsId.getQueryId());
            QueryContext attemptsIncreaseContext = sqlTaskManager.getQueryContext(increaseLimitsId.getQueryId());

            // Before any task update: limits come straight from the node configuration.
            Assertions.assertThat(reducesLimitsContext.isMemoryLimitsInitialized()).isFalse();
            Assertions.assertThat(reducesLimitsContext.getMaxUserMemory()).isEqualTo(memoryConfig.getMaxQueryMemoryPerNode().toBytes());
            Assertions.assertThat(attemptsIncreaseContext.isMemoryLimitsInitialized()).isFalse();
            Assertions.assertThat(attemptsIncreaseContext.getMaxUserMemory()).isEqualTo(memoryConfig.getMaxQueryMemoryPerNode().toBytes());

            // Session value below the configured limit.
            sqlTaskManager.updateTask(
                    TestingSession.testSessionBuilder().setSystemProperty("query_max_memory_per_node", "1B").build(),
                    reduceLimitsId,
                    Span.getInvalid(),
                    Optional.of(TaskTestUtils.PLAN_FRAGMENT),
                    ImmutableList.of(new SplitAssignment(TaskTestUtils.TABLE_SCAN_NODE_ID, ImmutableSet.of(TaskTestUtils.SPLIT), true)),
                    PipelinedOutputBuffers.createInitial(PipelinedOutputBuffers.BufferType.PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds(),
                    ImmutableMap.of(),
                    false);
            Assertions.assertThat(reducesLimitsContext.isMemoryLimitsInitialized()).isTrue();
            Assertions.assertThat(reducesLimitsContext.getMaxUserMemory()).isEqualTo(1L);

            // Session value above the configured limit: the configured limit is expected to remain.
            sqlTaskManager.updateTask(
                    TestingSession.testSessionBuilder().setSystemProperty("query_max_memory_per_node", "10B").build(),
                    increaseLimitsId,
                    Span.getInvalid(),
                    Optional.of(TaskTestUtils.PLAN_FRAGMENT),
                    ImmutableList.of(new SplitAssignment(TaskTestUtils.TABLE_SCAN_NODE_ID, ImmutableSet.of(TaskTestUtils.SPLIT), true)),
                    PipelinedOutputBuffers.createInitial(PipelinedOutputBuffers.BufferType.PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds(),
                    ImmutableMap.of(),
                    false);
            Assertions.assertThat(attemptsIncreaseContext.isMemoryLimitsInitialized()).isTrue();
            Assertions.assertThat(attemptsIncreaseContext.getMaxUserMemory()).isEqualTo(memoryConfig.getMaxQueryMemoryPerNode().toBytes());
        }
    }

    /**
     * Builds a {@link SqlTaskManager} for the given task configuration using a freshly
     * constructed, unmodified {@link NodeMemoryConfig}.
     *
     * @param taskManagerConfig task manager configuration to pass through
     * @return a new task manager; the caller is responsible for closing it
     */
    private SqlTaskManager createSqlTaskManager(TaskManagerConfig taskManagerConfig) {
        NodeMemoryConfig defaultMemoryConfig = new NodeMemoryConfig();
        return createSqlTaskManager(taskManagerConfig, defaultMemoryConfig);
    }

    /**
     * Builds a {@link SqlTaskManager} wired with the shared executors of this test class, the
     * given configurations, and testing/stub collaborators for everything else.
     *
     * @param taskManagerConfig task manager configuration
     * @param nodeMemoryConfig memory configuration, also used to build the {@link LocalMemoryManager}
     * @return a new task manager; the caller is responsible for closing it
     */
    private SqlTaskManager createSqlTaskManager(TaskManagerConfig taskManagerConfig, NodeMemoryConfig nodeMemoryConfig) {
        // Arguments are kept in constructor order, one per line.
        return new SqlTaskManager(
                new EmbedVersion("testversion"),
                new NoConnectorServicesProvider(),
                TaskTestUtils.createTestingPlanner(),
                new WorkerLanguageFunctionProvider(),
                new MockLocationFactory(),
                this.taskExecutor,
                TaskTestUtils.createTestSplitMonitor(),
                new NodeInfo("test"),
                new LocalMemoryManager(nodeMemoryConfig),
                this.taskManagementExecutor,
                taskManagerConfig,
                nodeMemoryConfig,
                new LocalSpillManager(new NodeSpillConfig()),
                new NodeSpillConfig(),
                new TestingGcMonitor(),
                Tracing.noopTracer(),
                new ExchangeManagerRegistry());
    }

    /**
     * Updates (and thereby creates, if needed) the task with a single split assignment for the
     * test table scan node, using the shared test session and plan fragment.
     *
     * @param sqlTaskManager manager to update the task on
     * @param taskId id of the task
     * @param immutableSet splits to place in the assignment (may be empty)
     * @param outputBuffers output buffer layout for the task
     * @return the task info returned by {@code updateTask}
     */
    private TaskInfo createTask(SqlTaskManager sqlTaskManager, TaskId taskId, ImmutableSet<ScheduledSplit> immutableSet, OutputBuffers outputBuffers) {
        SplitAssignment assignment = new SplitAssignment(TaskTestUtils.TABLE_SCAN_NODE_ID, immutableSet, true);
        return sqlTaskManager.updateTask(
                SessionTestUtils.TEST_SESSION,
                taskId,
                Span.getInvalid(),
                Optional.of(TaskTestUtils.PLAN_FRAGMENT),
                ImmutableList.of(assignment),
                outputBuffers,
                ImmutableMap.of(),
                false);
    }

    /**
     * Registers a task context for {@code taskId} on its query context, then updates the task
     * with no split assignments, using the shared test session and plan fragment.
     *
     * @param sqlTaskManager manager to update the task on
     * @param taskId id of the task
     * @param outputBuffers output buffer layout for the task
     * @return the task info returned by {@code updateTask}
     */
    private TaskInfo createTask(SqlTaskManager sqlTaskManager, TaskId taskId, OutputBuffers outputBuffers) {
        QueryContext queryContext = sqlTaskManager.getQueryContext(taskId.getQueryId());
        TaskStateMachine stateMachine = new TaskStateMachine(taskId, MoreExecutors.directExecutor());
        queryContext.addTaskContext(
                stateMachine,
                TestingSession.testSessionBuilder().build(),
                () -> {
                },
                false,
                false);

        return sqlTaskManager.updateTask(
                SessionTestUtils.TEST_SESSION,
                taskId,
                Span.getInvalid(),
                Optional.of(TaskTestUtils.PLAN_FRAGMENT),
                ImmutableList.of(),
                outputBuffers,
                ImmutableMap.of(),
                false);
    }

    /**
     * Asserts that the given task is terminating or done, then re-fetches its info (waiting on
     * the current status version) for as long as it is still terminating, at most three times
     * with a 5-second wait each.
     *
     * @param sqlTaskManager manager to query
     * @param taskInfo info of a task that is expected to be terminating or done
     * @return the most recently fetched task info; it may still be terminating if all three
     *         attempts were used
     * @throws InterruptedException if interrupted while waiting for updated info
     * @throws ExecutionException if fetching updated info fails
     * @throws TimeoutException if updated info does not arrive within 5 seconds
     */
    private static TaskInfo pollTerminatingTaskInfoUntilDone(SqlTaskManager sqlTaskManager, TaskInfo taskInfo) throws InterruptedException, ExecutionException, TimeoutException {
        Assertions.assertThat(taskInfo.getTaskStatus().getState().isTerminatingOrDone()).isTrue();

        TaskInfo latest = taskInfo;
        int attemptsLeft = 3;
        while (attemptsLeft > 0 && latest.getTaskStatus().getState().isTerminating()) {
            latest = (TaskInfo) sqlTaskManager
                    .getTaskInfo(latest.getTaskStatus().getTaskId(), latest.getTaskStatus().getVersion())
                    .get(5L, TimeUnit.SECONDS);
            attemptsLeft--;
        }
        return latest;
    }

    /**
     * Produces a task id whose query name is {@code "query"} followed by the next value of the
     * shared sequence counter, with stage 0, partition 1 and attempt 0 as the remaining numbers.
     *
     * @return a task id not previously returned by this instance
     */
    private TaskId newTaskId() {
        int queryNumber = this.sequence.incrementAndGet();
        StageId stageId = new StageId("query" + queryNumber, 0);
        return new TaskId(stageId, 1, 0);
    }
}
