package io.trino.operator;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import com.google.common.util.concurrent.Uninterruptibles;
import io.airlift.concurrent.Threads;
import io.airlift.units.Duration;
import io.trino.RowPagesBuilder;
import io.trino.SessionTestUtils;
import io.trino.execution.ScheduledSplit;
import io.trino.execution.SplitAssignment;
import io.trino.metadata.Split;
import io.trino.metadata.TableHandle;
import io.trino.spi.HostAddress;
import io.trino.spi.Page;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ConnectorSplit;
import io.trino.spi.connector.DynamicFilter;
import io.trino.spi.connector.FixedPageSource;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.split.PageSourceProvider;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.testing.MaterializedResult;
import io.trino.testing.PageConsumerOperator;
import io.trino.testing.TestingHandles;
import io.trino.testing.TestingTaskContext;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import java.util.function.Function;
import org.assertj.core.api.Assertions;
import org.testng.Assert;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

/**
 * Tests for {@link Driver}. Covered scenarios:
 * <ul>
 *   <li>normal and abrupt completion;</li>
 *   <li>completion after a split is assigned;</li>
 *   <li>races between {@code close()} and processing;</li>
 *   <li>operators that stall while holding a lock;</li>
 *   <li>a memory-revocation request raised by a blocked operator;</li>
 *   <li>unblocking once the sink finishes.</li>
 * </ul>
 *
 * <p>The executors and {@link DriverContext} are instance fields rebuilt before every test
 * method, so the methods are declared single-threaded.
 */
@Test(singleThreaded = true)
public class TestDriver {
    private ExecutorService executor;
    private ScheduledExecutorService scheduledExecutor;
    private DriverContext driverContext;

    /**
     * Table scan that never unblocks. Each {@code isBlocked()} call also reserves 100 bytes of
     * revocable memory and requests that it be revoked.
     */
    private static class AlwaysBlockedMemoryRevokingTableScanOperator extends TableScanOperator {
        public AlwaysBlockedMemoryRevokingTableScanOperator(
                OperatorContext operatorContext,
                PlanNodeId sourceId,
                PageSourceProvider pageSourceProvider,
                TableHandle tableHandle,
                Iterable<ColumnHandle> columns) {
            super(operatorContext, sourceId, pageSourceProvider, tableHandle, columns, DynamicFilter.EMPTY);
        }

        @Override
        public ListenableFuture<Void> isBlocked() {
            // Raise a revocation request from inside the driver's blocked-state check.
            getOperatorContext().localRevocableMemoryContext().setBytes(100L);
            getOperatorContext().requestMemoryRevoking();
            // A fresh future that nobody ever completes: the operator stays blocked.
            return SettableFuture.create();
        }
    }

    /** Table scan whose {@code isBlocked()} always returns a future that is never completed. */
    private static class AlwaysBlockedTableScanOperator extends TableScanOperator {
        public AlwaysBlockedTableScanOperator(
                OperatorContext operatorContext,
                PlanNodeId sourceId,
                PageSourceProvider pageSourceProvider,
                TableHandle tableHandle,
                Iterable<ColumnHandle> columns) {
            super(operatorContext, sourceId, pageSourceProvider, tableHandle, columns, DynamicFilter.EMPTY);
        }

        @Override
        public ListenableFuture<Void> isBlocked() {
            return SettableFuture.create();
        }
    }

    /**
     * Page-consuming sink whose completion is controlled by the test via {@link #setFinished()}.
     * The same future is registered as the operator context's finished future.
     */
    private static class BlockedSinkOperator extends PageConsumerOperator {
        private final SettableFuture<Void> finished;

        public BlockedSinkOperator(OperatorContext operatorContext, Consumer<Page> pageConsumer, Function<Page, Page> pagePreprocessor) {
            super(operatorContext, pageConsumer, pagePreprocessor);
            this.finished = SettableFuture.create();
            operatorContext.setFinishedFuture(finished);
        }

        @Override
        public boolean isFinished() {
            return finished.isDone();
        }

        /** Completes the finished future, marking this sink (and its operator context) finished. */
        void setFinished() {
            finished.set(null);
        }
    }

    /**
     * Operator whose methods stall the calling thread. Every operator method (and
     * {@code close()} only when {@code lockForClose} is set) does the following:
     * <ol>
     *   <li>acquires an internal lock, waiting up to 1 second;</li>
     *   <li>releases {@link #waitForLocked()};</li>
     *   <li>blocks until {@link #unlock()} is called, for at most 5 seconds, failing the
     *       assertion on timeout.</li>
     * </ol>
     */
    private static class BrokenOperator implements Operator {
        private final OperatorContext operatorContext;
        private final ReentrantLock lock = new ReentrantLock();
        private final CountDownLatch lockedLatch = new CountDownLatch(1);
        private final CountDownLatch unlockLatch = new CountDownLatch(1);
        private final boolean lockForClose;

        private BrokenOperator(OperatorContext operatorContext) {
            this(operatorContext, false);
        }

        /**
         * @param operatorContext context of this operator
         * @param lockForClose whether {@link #close()} should also stall until {@link #unlock()}
         */
        private BrokenOperator(OperatorContext operatorContext, boolean lockForClose) {
            this.operatorContext = Objects.requireNonNull(operatorContext, "operatorContext is null");
            this.lockForClose = lockForClose;
        }

        @Override
        public OperatorContext getOperatorContext() {
            return operatorContext;
        }

        /** Releases every thread currently (or later) stalled inside this operator. */
        public void unlock() {
            unlockLatch.countDown();
        }

        /** Waits up to 10 seconds until some thread has entered a stalling method. */
        private void waitForLocked() {
            try {
                Assert.assertTrue(lockedLatch.await(10L, TimeUnit.SECONDS));
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException("Interrupted", e);
            }
        }

        /** Holds the lock while waiting for {@link #unlock()}; the lock is always released. */
        private void waitForUnlock() {
            try {
                Assert.assertTrue(lock.tryLock(1L, TimeUnit.SECONDS));
                try {
                    lockedLatch.countDown();
                    Assert.assertTrue(unlockLatch.await(5L, TimeUnit.SECONDS));
                } finally {
                    lock.unlock();
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException("Interrupted", e);
            }
        }

        @Override
        public void finish() {
            waitForUnlock();
        }

        @Override
        public boolean isFinished() {
            waitForUnlock();
            return true;
        }

        @Override
        public ListenableFuture<Void> isBlocked() {
            waitForUnlock();
            return NOT_BLOCKED;
        }

        @Override
        public boolean needsInput() {
            waitForUnlock();
            return false;
        }

        @Override
        public void addInput(Page page) {
            waitForUnlock();
        }

        @Override
        public Page getOutput() {
            waitForUnlock();
            return null;
        }

        @Override
        public void close() {
            if (lockForClose) {
                waitForUnlock();
            }
        }
    }

    /** Minimal split: not remotely accessible, no addresses, no info, zero retained size. */
    private static class MockSplit implements ConnectorSplit {
        @Override
        public boolean isRemotelyAccessible() {
            return false;
        }

        @Override
        public List<HostAddress> getAddresses() {
            return ImmutableList.of();
        }

        @Override
        public Object getInfo() {
            return null;
        }

        @Override
        public long getRetainedSizeInBytes() {
            return 0L;
        }
    }

    /** Table scan whose {@code isBlocked()} always reports {@code NOT_BLOCKED}. */
    private static class NotBlockedTableScanOperator extends TableScanOperator {
        public NotBlockedTableScanOperator(
                OperatorContext operatorContext,
                PlanNodeId sourceId,
                PageSourceProvider pageSourceProvider,
                TableHandle tableHandle,
                Iterable<ColumnHandle> columns) {
            super(operatorContext, sourceId, pageSourceProvider, tableHandle, columns, DynamicFilter.EMPTY);
        }

        @Override
        public ListenableFuture<Void> isBlocked() {
            return NOT_BLOCKED;
        }
    }

    @BeforeMethod
    public void setUp() {
        executor = Executors.newCachedThreadPool(Threads.daemonThreadsNamed(getClass().getSimpleName() + "-%s"));
        scheduledExecutor = Executors.newScheduledThreadPool(2, Threads.daemonThreadsNamed(getClass().getSimpleName() + "-scheduledExecutor-%s"));
        driverContext = TestingTaskContext.createTaskContext(executor, scheduledExecutor, SessionTestUtils.TEST_SESSION)
                .addPipelineContext(0, true, true, false)
                .addDriverContext();
    }

    @AfterMethod(alwaysRun = true)
    public void tearDown() {
        executor.shutdownNow();
        scheduledExecutor.shutdownNow();
    }

    @Test
    public void testNormalFinish() {
        List<Type> types = ImmutableList.of(VarcharType.VARCHAR, BigintType.BIGINT, BigintType.BIGINT);
        ValuesOperator source = new ValuesOperator(
                driverContext.addOperatorContext(0, new PlanNodeId("test"), "values"),
                RowPagesBuilder.rowPagesBuilder(types).addSequencePage(10, 20, 30, 40).build());
        Operator sink = createSinkOperator(types);
        Driver driver = Driver.createDriver(driverContext, source, sink);

        Assert.assertSame(driver.getDriverContext(), driverContext);
        Assert.assertFalse(driver.isFinished());
        Assert.assertTrue(driver.processForDuration(new Duration(1.0d, TimeUnit.SECONDS)).isDone());
        Assert.assertTrue(driver.isFinished());
        Assert.assertTrue(sink.isFinished());
        Assert.assertTrue(source.isFinished());
    }

    /** Races processing against close on two scheduler threads; the timeout catches a hang. */
    @Test(invocationCount = 1000, timeOut = 10000)
    public void testConcurrentClose() {
        List<Type> types = ImmutableList.of(VarcharType.VARCHAR, BigintType.BIGINT, BigintType.BIGINT);
        ValuesOperator source = new ValuesOperator(
                driverContext.addOperatorContext(0, new PlanNodeId("test"), "values"),
                RowPagesBuilder.rowPagesBuilder(types).addSequencePage(10, 20, 30, 40).build());
        Operator sink = createSinkOperator(types);
        Driver driver = Driver.createDriver(driverContext, source, sink);

        scheduledExecutor.submit(() -> driver.processForDuration(new Duration(1.0d, TimeUnit.NANOSECONDS)));
        scheduledExecutor.submit(driver::close);
        while (!driverContext.isTerminatingOrDone()) {
            Uninterruptibles.sleepUninterruptibly(1L, TimeUnit.MILLISECONDS);
        }
    }

    @Test
    public void testAbruptFinish() {
        List<Type> types = ImmutableList.of(VarcharType.VARCHAR, BigintType.BIGINT, BigintType.BIGINT);
        ValuesOperator source = new ValuesOperator(
                driverContext.addOperatorContext(0, new PlanNodeId("test"), "values"),
                RowPagesBuilder.rowPagesBuilder(types).addSequencePage(10, 20, 30, 40).build());
        // Keep the concrete type: isClosed() is not part of the Operator interface.
        PageConsumerOperator sink = createSinkOperator(types);
        Driver driver = Driver.createDriver(driverContext, source, sink);

        Assert.assertSame(driver.getDriverContext(), driverContext);
        Assert.assertFalse(driver.isFinished());
        driver.close();
        Assert.assertTrue(driver.isFinished());
        // Closing does not finish the operators; it only closes them.
        Assert.assertFalse(source.isFinished());
        Assert.assertFalse(sink.isFinished());
        Assert.assertTrue(sink.isClosed());
    }

    @Test
    public void testAddSourceFinish() {
        PlanNodeId sourceId = new PlanNodeId("source");
        List<Type> types = ImmutableList.of(VarcharType.VARCHAR, BigintType.BIGINT, BigintType.BIGINT);
        TableScanOperator source = new TableScanOperator(
                driverContext.addOperatorContext(99, new PlanNodeId("test"), "values"),
                sourceId,
                fixedPageSourceProvider(types),
                TestingHandles.TEST_TABLE_HANDLE,
                ImmutableList.of(),
                DynamicFilter.EMPTY);
        Operator sink = createSinkOperator(types);
        Driver driver = Driver.createDriver(driverContext, source, sink);

        Assert.assertSame(driver.getDriverContext(), driverContext);
        Assert.assertFalse(driver.isFinished());
        // No split assigned yet, so processing must not complete.
        Assert.assertFalse(driver.processForDuration(new Duration(1.0d, TimeUnit.MILLISECONDS)).isDone());
        Assert.assertFalse(driver.isFinished());

        // Assign a single split; the trailing `true` presumably means "no more splits" -- confirm against SplitAssignment.
        driver.updateSplitAssignment(new SplitAssignment(sourceId, ImmutableSet.of(new ScheduledSplit(0L, sourceId, newMockSplit())), true));
        Assert.assertFalse(driver.isFinished());
        Assert.assertTrue(driver.processForDuration(new Duration(1.0d, TimeUnit.SECONDS)).isDone());
        Assert.assertTrue(driver.isFinished());
        Assert.assertTrue(sink.isFinished());
        Assert.assertTrue(source.isFinished());
    }

    @Test
    public void testBrokenOperatorCloseWhileProcessing() {
        BrokenOperator brokenOperator = new BrokenOperator(driverContext.addOperatorContext(0, new PlanNodeId("test"), "source"), false);
        Driver driver = Driver.createDriver(driverContext, brokenOperator, createSinkOperator(ImmutableList.of()));
        Assert.assertSame(driver.getDriverContext(), driverContext);

        // Stall a background thread inside the operator, then close the driver underneath it.
        Future<Boolean> processing = executor.submit(() -> driver.processForDuration(new Duration(1.0d, TimeUnit.MILLISECONDS)).isDone());
        brokenOperator.waitForLocked();
        driver.close();
        Assert.assertTrue(driver.isFinished());

        Assertions.assertThatThrownBy(() -> processing.get(1L, TimeUnit.SECONDS))
                .isInstanceOf(ExecutionException.class)
                .hasCause(new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, "Driver was interrupted"));
        Assert.assertTrue(driver.getDestroyedFuture().isDone());
    }

    @Test
    public void testBrokenOperatorProcessWhileClosing() throws Exception {
        BrokenOperator brokenOperator = new BrokenOperator(driverContext.addOperatorContext(0, new PlanNodeId("test"), "source"), true);
        Driver driver = Driver.createDriver(driverContext, brokenOperator, createSinkOperator(ImmutableList.of()));
        Assert.assertSame(driver.getDriverContext(), driverContext);

        // Stall a background thread inside close(), then try to process concurrently.
        Future<Boolean> closing = executor.submit(() -> {
            driver.close();
            return true;
        });
        brokenOperator.waitForLocked();

        Assert.assertTrue(driver.processForDuration(new Duration(1.0d, TimeUnit.MILLISECONDS)).isDone());
        Assert.assertTrue(driver.isFinished());
        // Destruction cannot complete while close() is still stalled in the operator.
        Assert.assertFalse(driver.getDestroyedFuture().isDone());

        brokenOperator.unlock();
        Assert.assertTrue(closing.get());
        Assert.assertTrue(driver.getDestroyedFuture().isDone());
    }

    @Test
    public void testMemoryRevocationRace() {
        List<Type> types = ImmutableList.of(VarcharType.VARCHAR, BigintType.BIGINT, BigintType.BIGINT);
        TableScanOperator source = new AlwaysBlockedMemoryRevokingTableScanOperator(
                driverContext.addOperatorContext(99, new PlanNodeId("test"), "scan"),
                new PlanNodeId("source"),
                fixedPageSourceProvider(types),
                TestingHandles.TEST_TABLE_HANDLE,
                ImmutableList.of());
        Driver driver = Driver.createDriver(driverContext, source, createSinkOperator(types));

        // The revocation request raised inside isBlocked() must not leave processing hanging.
        Assert.assertTrue(driver.processForDuration(new Duration(100.0d, TimeUnit.MILLISECONDS)).isDone());
    }

    @Test
    public void testUnblocksOnFinish() {
        List<Type> types = ImmutableList.of(VarcharType.VARCHAR, BigintType.BIGINT, BigintType.BIGINT);
        TableScanOperator source = new AlwaysBlockedTableScanOperator(
                driverContext.addOperatorContext(99, new PlanNodeId("test"), "scan"),
                new PlanNodeId("source"),
                fixedPageSourceProvider(types),
                TestingHandles.TEST_TABLE_HANDLE,
                ImmutableList.of());
        MaterializedResult.Builder resultBuilder = MaterializedResult.resultBuilder(driverContext.getSession(), types);
        // Keep the concrete type: setFinished() is not part of the Operator interface.
        BlockedSinkOperator sink = new BlockedSinkOperator(
                driverContext.addOperatorContext(1, new PlanNodeId("test"), "sink"),
                resultBuilder::page,
                Function.identity());

        ListenableFuture<?> blocked = Driver.createDriver(driverContext, source, sink)
                .processForDuration(new Duration(100.0d, TimeUnit.MILLISECONDS));
        Assert.assertFalse(blocked.isDone());

        sink.setFinished();
        Assert.assertTrue(blocked.isDone());
    }

    @Test
    public void testBrokenOperatorAddSource() {
        PlanNodeId sourceId = new PlanNodeId("source");
        List<Type> types = ImmutableList.of(VarcharType.VARCHAR, BigintType.BIGINT, BigintType.BIGINT);
        // A scan that never reports itself blocked, feeding an operator that stalls.
        TableScanOperator source = new NotBlockedTableScanOperator(
                driverContext.addOperatorContext(99, new PlanNodeId("test"), "values"),
                sourceId,
                fixedPageSourceProvider(types),
                TestingHandles.TEST_TABLE_HANDLE,
                ImmutableList.of());
        BrokenOperator brokenOperator = new BrokenOperator(driverContext.addOperatorContext(0, new PlanNodeId("test"), "source"));
        Driver driver = Driver.createDriver(driverContext, source, brokenOperator);

        // Stall a background thread inside the broken operator.
        Future<Boolean> processing = executor.submit(() -> driver.processForDuration(new Duration(1.0d, TimeUnit.MILLISECONDS)).isDone());
        brokenOperator.waitForLocked();

        Assert.assertSame(driver.getDriverContext(), driverContext);
        Assert.assertFalse(driver.isFinished());
        // NOTE(review): presumably returns immediately because the stalled thread still owns the driver.
        Assert.assertTrue(driver.processForDuration(new Duration(1.0d, TimeUnit.MILLISECONDS)).isDone());
        Assert.assertFalse(driver.isFinished());

        driver.updateSplitAssignment(new SplitAssignment(sourceId, ImmutableSet.of(new ScheduledSplit(0L, sourceId, newMockSplit())), true));
        Assert.assertFalse(driver.getDestroyedFuture().isDone());
        Assert.assertTrue(driver.processForDuration(new Duration(1.0d, TimeUnit.SECONDS)).isDone());
        Assert.assertFalse(driver.isFinished());

        driver.close();
        Assert.assertTrue(driver.isFinished());
        Assertions.assertThatThrownBy(() -> processing.get(1L, TimeUnit.SECONDS))
                .isInstanceOf(ExecutionException.class)
                .hasCause(new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, "Driver was interrupted"));
        Assert.assertTrue(driver.getDestroyedFuture().isDone());
    }

    /** Wraps a {@link MockSplit} in a {@link Split} for the testing catalog. */
    private static Split newMockSplit() {
        return new Split(TestingHandles.TEST_CATALOG_HANDLE, new MockSplit());
    }

    /**
     * Returns a provider that ignores its arguments. Each call builds a fresh
     * {@link FixedPageSource} over the same sequence page (10 rows) of the given types.
     */
    private static PageSourceProvider fixedPageSourceProvider(List<Type> types) {
        return (session, split, tableHandle, columns, dynamicFilter) ->
                new FixedPageSource(RowPagesBuilder.rowPagesBuilder(types).addSequencePage(10, 20, 30, 40).build());
    }

    /**
     * Creates a sink operator (operator id 1) that feeds every page into a
     * {@link MaterializedResult} builder; the builder's result is never inspected.
     */
    private PageConsumerOperator createSinkOperator(List<Type> types) {
        MaterializedResult.Builder resultBuilder = MaterializedResult.resultBuilder(driverContext.getSession(), types);
        OperatorContext operatorContext = driverContext.addOperatorContext(1, new PlanNodeId("test"), "sink");
        return new PageConsumerOperator(operatorContext, resultBuilder::page, Function.identity());
    }
}
