package io.trino.execution.scheduler;

import com.google.common.base.Stopwatch;
import com.google.common.base.Ticker;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Sets;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.SettableFuture;
import io.airlift.testing.TestingTicker;
import io.airlift.units.DataSize;
import io.trino.Session;
import io.trino.client.NodeVersion;
import io.trino.connector.CatalogHandle;
import io.trino.cost.StatsAndCosts;
import io.trino.execution.DynamicFilterConfig;
import io.trino.execution.NodeTaskMap;
import io.trino.execution.RemoteTaskFactory;
import io.trino.execution.SqlStage;
import io.trino.execution.StageId;
import io.trino.execution.TaskId;
import io.trino.execution.TaskState;
import io.trino.execution.TestingRemoteTaskFactory;
import io.trino.execution.scheduler.FaultTolerantStageScheduler;
import io.trino.execution.scheduler.NodeAllocator;
import io.trino.execution.scheduler.TestingExchange;
import io.trino.execution.scheduler.TestingNodeSelectorFactory;
import io.trino.failuredetector.NoOpFailureDetector;
import io.trino.metadata.InternalNode;
import io.trino.metadata.Split;
import io.trino.operator.RetryPolicy;
import io.trino.server.DynamicFilterService;
import io.trino.spi.QueryId;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.exchange.Exchange;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.type.VarcharType;
import io.trino.sql.PlannerContext;
import io.trino.sql.planner.Partitioning;
import io.trino.sql.planner.PartitioningScheme;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SystemPartitioningHandle;
import io.trino.sql.planner.TestingPlannerContext;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PlanFragmentId;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.RemoteSourceNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.testing.TestingHandles;
import io.trino.testing.TestingMetadata;
import io.trino.testing.TestingSession;
import io.trino.testing.TestingSplit;
import io.trino.util.FinalizerService;
import java.net.URI;
import java.time.Duration;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.assertj.core.api.Assertions;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

@Test(singleThreaded = true)
/* loaded from: input_file:io/trino/execution/scheduler/TestFaultTolerantStageScheduler.class */
public class TestFaultTolerantStageScheduler {
    // Identifiers for the query and the stage/fragment under test.
    private static final QueryId QUERY_ID = new QueryId("query");
    // Session bound to QUERY_ID; used by tests that need no custom system properties
    // (testRetryDelay builds its own session with retry settings).
    private static final Session SESSION = TestingSession.testSessionBuilder().setQueryId(QUERY_ID).build();
    private static final StageId STAGE_ID = new StageId(QUERY_ID, 0);
    private static final PlanFragmentId FRAGMENT_ID = new PlanFragmentId("0");
    // Fragment ids used as keys of the source-exchange map handed to the scheduler.
    private static final PlanFragmentId SOURCE_FRAGMENT_ID_1 = new PlanFragmentId("1");
    private static final PlanFragmentId SOURCE_FRAGMENT_ID_2 = new PlanFragmentId("2");
    // NOTE(review): not referenced in the visible part of the file; presumably used by plan-building helpers below — confirm.
    private static final PlanNodeId TABLE_SCAN_NODE_ID = new PlanNodeId("table_scan_id");
    // Three test nodes on distinct local ports.
    private static final InternalNode NODE_1 = new InternalNode("node-1", URI.create("local://127.0.0.1:8080"), NodeVersion.UNKNOWN, false);
    private static final InternalNode NODE_2 = new InternalNode("node-2", URI.create("local://127.0.0.1:8081"), NodeVersion.UNKNOWN, false);
    private static final InternalNode NODE_3 = new InternalNode("node-3", URI.create("local://127.0.0.1:8082"), NodeVersion.UNKNOWN, false);
    // NOTE(review): presumably consumed by helpers outside this view when building the plan fragment — confirm.
    private static final PlannerContext PLANNER_CONTEXT = TestingPlannerContext.plannerContextBuilder().build();
    // Class-scoped fixtures created in beforeClass().
    private FinalizerService finalizerService;
    private NodeTaskMap nodeTaskMap;
    // Recreated per test by setupNodeAllocatorService() and stopped after every test method.
    private FixedCountNodeAllocatorService nodeAllocatorService;
    // Manually controlled clock backing futureCompletor's stopwatch.
    private TestingTicker ticker;
    // Completes delayed futures based on ticker time; presumably passed to the scheduler under test — confirm.
    private TestFutureCompletor futureCompletor;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:io/trino/execution/scheduler/TestFaultTolerantStageScheduler$TestFutureCompletor.class */
    public static class TestFutureCompletor implements FaultTolerantStageScheduler.DelayedFutureCompletor {
        private final Stopwatch stopwatch;
        private final Set<Entry> entries = Sets.newConcurrentHashSet();

        /* JADX INFO: Access modifiers changed from: private */
        /* loaded from: input_file:io/trino/execution/scheduler/TestFaultTolerantStageScheduler$TestFutureCompletor$Entry.class */
        public static class Entry {
            private final SettableFuture<Void> future;
            private final Duration completionTime;

            public Entry(SettableFuture<Void> settableFuture, Duration duration) {
                this.future = settableFuture;
                this.completionTime = duration;
            }
        }

        private TestFutureCompletor(Ticker ticker) {
            this.stopwatch = Stopwatch.createStarted(ticker);
        }

        public void completeFuture(SettableFuture<Void> settableFuture, Duration duration) {
            this.entries.add(new Entry(settableFuture, this.stopwatch.elapsed().plus(duration)));
        }

        public void trigger() {
            Duration elapsed = this.stopwatch.elapsed();
            Iterator<Entry> it = this.entries.iterator();
            while (it.hasNext()) {
                Entry next = it.next();
                if (next.completionTime.compareTo(elapsed) <= 0) {
                    next.future.set((Object) null);
                    it.remove();
                }
            }
        }
    }

    /**
     * Creates the class-scoped fixtures: a started finalizer service with the node/task map built on it,
     * and a manually controlled ticker that drives the delayed-future completor.
     */
    @BeforeClass
    public void beforeClass() {
        finalizerService = new FinalizerService();
        finalizerService.start();
        nodeTaskMap = new NodeTaskMap(finalizerService);

        ticker = new TestingTicker();
        futureCompletor = new TestFutureCompletor(ticker);
    }

    /**
     * Drops the node/task map and, if it was created, destroys the finalizer service.
     */
    @AfterClass(alwaysRun = true)
    public void afterClass() {
        nodeTaskMap = null;
        FinalizerService service = finalizerService;
        if (service == null) {
            return;
        }
        service.destroy();
        finalizerService = null;
    }

    /**
     * Replaces the node allocator service with a fresh one whose nodes come from {@code testingNodeSupplier}.
     * Any previously created service is stopped first.
     */
    private void setupNodeAllocatorService(TestingNodeSelectorFactory.TestingNodeSupplier testingNodeSupplier) {
        shutdownNodeAllocatorService();
        TestingNodeSelectorFactory selectorFactory = new TestingNodeSelectorFactory(NODE_1, testingNodeSupplier);
        NodeScheduler scheduler = new NodeScheduler(selectorFactory);
        this.nodeAllocatorService = new FixedCountNodeAllocatorService(scheduler);
    }

    /**
     * Stops the current node allocator service, if any, and clears the field. Runs after every test method.
     */
    @AfterMethod(alwaysRun = true)
    public void shutdownNodeAllocatorService() {
        FixedCountNodeAllocatorService service = nodeAllocatorService;
        if (service != null) {
            service.stop();
        }
        nodeAllocatorService = null;
    }

    /**
     * Walks the scheduler through a run with one task failure and a retry. It checks that tasks are created
     * as source handles arrive, that finished sinks are reported to the sink exchange, and that the scheduler
     * finishes once every task has finished.
     */
    @Test
    public void testHappyPath() throws Exception {
        TestingRemoteTaskFactory remoteTaskFactory = new TestingRemoteTaskFactory();
        TaskSourceFactory taskSourceFactory = createTaskSourceFactory(5, 2);
        setupNodeAllocatorService(TestingNodeSelectorFactory.TestingNodeSupplier.create(ImmutableMap.of(
                NODE_1, ImmutableList.of(TestingHandles.TEST_CATALOG_HANDLE),
                NODE_2, ImmutableList.of(TestingHandles.TEST_CATALOG_HANDLE),
                NODE_3, ImmutableList.of(TestingHandles.TEST_CATALOG_HANDLE))));

        TestingExchange sinkExchange = new TestingExchange();
        TestingExchange sourceExchange1 = new TestingExchange();
        TestingExchange sourceExchange2 = new TestingExchange();

        // try-with-resources closes the allocator exactly once, on both the success and the failure path.
        try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) {
            FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler(
                    remoteTaskFactory,
                    taskSourceFactory,
                    nodeAllocator,
                    sinkExchange,
                    ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2),
                    2,
                    1);

            assertUnblocked(scheduler.isBlocked());
            scheduler.schedule();

            // The scheduler stays blocked until both source exchanges have provided source handles.
            ListenableFuture<?> blocked = scheduler.isBlocked();
            assertBlocked(blocked);
            sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchange.TestingExchangeSourceHandle(0, 1L)));
            assertBlocked(blocked);
            Assert.assertFalse(scheduler.isBlocked().isDone());
            sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchange.TestingExchangeSourceHandle(0, 1L)));
            assertUnblocked(blocked);
            assertUnblocked(scheduler.isBlocked());

            scheduler.schedule();
            blocked = scheduler.isBlocked();
            assertBlocked(blocked);
            Assert.assertFalse(sinkExchange.isNoMoreSinks());

            Map<TaskId, TestingRemoteTaskFactory.TestingRemoteTask> tasks = remoteTaskFactory.getTasks();
            Assertions.assertThat(tasks).hasSize(3);
            Assertions.assertThat(tasks).containsKey(getTaskId(0, 0));
            Assertions.assertThat(tasks).containsKey(getTaskId(1, 0));
            Assertions.assertThat(tasks).containsKey(getTaskId(2, 0));

            // A task failure wakes the scheduler up.
            tasks.get(getTaskId(0, 0)).fail(new RuntimeException("some failure"));
            assertUnblocked(blocked);
            assertUnblocked(scheduler.isBlocked());

            // Advance time before rescheduling (presumably past any retry delay).
            moveTime(10, TimeUnit.SECONDS);
            scheduler.schedule();
            tasks = remoteTaskFactory.getTasks();
            Assertions.assertThat(tasks).hasSize(4);
            Assertions.assertThat(tasks).containsKey(getTaskId(3, 0));

            blocked = scheduler.isBlocked();
            assertBlocked(blocked);

            // Finishing a task unblocks the scheduler and reports its sink as finished.
            Assertions.assertThat(tasks).containsKey(getTaskId(1, 0));
            tasks.get(getTaskId(1, 0)).finish();
            assertUnblocked(blocked);
            assertUnblocked(scheduler.isBlocked());
            Assertions.assertThat(sinkExchange.getFinishedSinkHandles())
                    .contains(new TestingExchange.TestingExchangeSinkHandle(1));

            // The failed task 0 is re-created as attempt 1.
            scheduler.schedule();
            blocked = scheduler.isBlocked();
            assertBlocked(blocked);
            tasks = remoteTaskFactory.getTasks();
            Assertions.assertThat(tasks).hasSize(5);
            Assertions.assertThat(tasks).containsKey(getTaskId(0, 1));
            Assertions.assertThat(tasks).containsKey(getTaskId(3, 0));

            tasks.get(getTaskId(3, 0)).finish();
            Assertions.assertThat(sinkExchange.getFinishedSinkHandles()).contains(
                    new TestingExchange.TestingExchangeSinkHandle(1),
                    new TestingExchange.TestingExchangeSinkHandle(3));
            assertUnblocked(blocked);

            // Once the last task (4) is created, the sink exchange is told no more sinks are coming.
            scheduler.schedule();
            tasks = remoteTaskFactory.getTasks();
            Assertions.assertThat(tasks).hasSize(6);
            Assertions.assertThat(tasks).containsKey(getTaskId(4, 0));
            Assert.assertTrue(sinkExchange.isNoMoreSinks());
            Assert.assertFalse(scheduler.isFinished());

            blocked = scheduler.isBlocked();
            assertBlocked(blocked);

            // Finish all remaining tasks; the scheduler then reports completion.
            tasks = remoteTaskFactory.getTasks();
            Assertions.assertThat(tasks).containsKey(getTaskId(4, 0));
            tasks.get(getTaskId(0, 1)).finish();
            tasks.get(getTaskId(2, 0)).finish();
            tasks.get(getTaskId(4, 0)).finish();
            assertUnblocked(blocked);
            assertUnblocked(scheduler.isBlocked());
            Assertions.assertThat(sinkExchange.getFinishedSinkHandles()).contains(
                    new TestingExchange.TestingExchangeSinkHandle(0),
                    new TestingExchange.TestingExchangeSinkHandle(1),
                    new TestingExchange.TestingExchangeSinkHandle(2),
                    new TestingExchange.TestingExchangeSinkHandle(3),
                    new TestingExchange.TestingExchangeSinkHandle(4));
            Assert.assertTrue(sinkExchange.isAllRequiredSinksFinished());
            Assert.assertTrue(scheduler.isFinished());
        }
    }

    /**
     * Checks that tasks whose splits list a particular host address are only started as capacity on that
     * node frees up. Four of the six splits list NODE_1.
     */
    @Test
    public void testTasksWaitingForNodes() throws Exception {
        TestingRemoteTaskFactory remoteTaskFactory = new TestingRemoteTaskFactory();

        // Each split lists exactly one host address. The first TestingSplit argument is false
        // (presumably "not remotely accessible", which pins the split to that host — confirm).
        Optional<CatalogHandle> catalogHandle = Optional.of(TestingHandles.TEST_CATALOG_HANDLE);
        List<Split> splits = ImmutableList.of(
                new Split(TestingHandles.TEST_CATALOG_HANDLE, new TestingSplit(false, ImmutableList.of(NODE_1.getHostAndPort()))),
                new Split(TestingHandles.TEST_CATALOG_HANDLE, new TestingSplit(false, ImmutableList.of(NODE_1.getHostAndPort()))),
                new Split(TestingHandles.TEST_CATALOG_HANDLE, new TestingSplit(false, ImmutableList.of(NODE_1.getHostAndPort()))),
                new Split(TestingHandles.TEST_CATALOG_HANDLE, new TestingSplit(false, ImmutableList.of(NODE_2.getHostAndPort()))),
                new Split(TestingHandles.TEST_CATALOG_HANDLE, new TestingSplit(false, ImmutableList.of(NODE_1.getHostAndPort()))),
                new Split(TestingHandles.TEST_CATALOG_HANDLE, new TestingSplit(false, ImmutableList.of(NODE_3.getHostAndPort()))));
        TaskSourceFactory taskSourceFactory = new TestingTaskSourceFactory(catalogHandle, splits, 2);

        setupNodeAllocatorService(TestingNodeSelectorFactory.TestingNodeSupplier.create(ImmutableMap.of(
                NODE_1, ImmutableList.of(TestingHandles.TEST_CATALOG_HANDLE),
                NODE_2, ImmutableList.of(TestingHandles.TEST_CATALOG_HANDLE),
                NODE_3, ImmutableList.of(TestingHandles.TEST_CATALOG_HANDLE))));

        Exchange sinkExchange = new TestingExchange();
        TestingExchange sourceExchange1 = new TestingExchange();
        TestingExchange sourceExchange2 = new TestingExchange();

        // try-with-resources closes the allocator exactly once, on both the success and the failure path.
        try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) {
            FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler(
                    remoteTaskFactory,
                    taskSourceFactory,
                    nodeAllocator,
                    sinkExchange,
                    ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2),
                    2,
                    3);

            sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchange.TestingExchangeSourceHandle(0, 1L)));
            sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchange.TestingExchangeSourceHandle(0, 1L)));

            // Only two tasks can start initially.
            scheduler.schedule();
            assertBlocked(scheduler.isBlocked());
            Map<TaskId, TestingRemoteTaskFactory.TestingRemoteTask> tasks = remoteTaskFactory.getTasks();
            Assertions.assertThat(tasks).hasSize(2);
            Assertions.assertThat(tasks).containsKey(getTaskId(0, 0));
            Assertions.assertThat(tasks).containsKey(getTaskId(3, 0));

            // Finishing task 3 does not allow any new task to start.
            tasks.get(getTaskId(3, 0)).finish();
            scheduler.schedule();
            assertBlocked(scheduler.isBlocked());
            tasks = remoteTaskFactory.getTasks();
            Assertions.assertThat(tasks).hasSize(2);
            Assertions.assertThat(tasks).containsKey(getTaskId(0, 0));
            Assertions.assertThat(tasks).containsKey(getTaskId(3, 0));

            // Finishing task 0 lets tasks 1 and 5 start.
            tasks.get(getTaskId(0, 0)).finish();
            scheduler.schedule();
            assertBlocked(scheduler.isBlocked());
            tasks = remoteTaskFactory.getTasks();
            Assertions.assertThat(tasks).hasSize(4);
            Assertions.assertThat(tasks).containsKey(getTaskId(0, 0));
            Assertions.assertThat(tasks).containsKey(getTaskId(1, 0));
            Assertions.assertThat(tasks).containsKey(getTaskId(3, 0));
            Assertions.assertThat(tasks).containsKey(getTaskId(5, 0));

            // The remaining NODE_1 tasks start one at a time as the previous one finishes.
            tasks.get(getTaskId(1, 0)).finish();
            scheduler.schedule();
            tasks = remoteTaskFactory.getTasks();
            Assertions.assertThat(tasks).containsKey(getTaskId(2, 0));

            tasks.get(getTaskId(2, 0)).finish();
            scheduler.schedule();
            tasks = remoteTaskFactory.getTasks();
            Assertions.assertThat(tasks).containsKey(getTaskId(4, 0));

            tasks.get(getTaskId(4, 0)).finish();
            tasks.get(getTaskId(3, 0)).finish();
            tasks.get(getTaskId(5, 0)).finish();
            scheduler.schedule();
            assertUnblocked(scheduler.isBlocked());
            Assert.assertTrue(scheduler.isFinished());
        }
    }

    /**
     * Checks that a task failure (with a retry-count argument of 0 passed to the scheduler factory) makes the
     * next {@code schedule()} call rethrow the failure, and that the scheduler is not finished afterwards.
     */
    @Test
    public void testTaskFailure() throws Exception {
        TestingRemoteTaskFactory remoteTaskFactory = new TestingRemoteTaskFactory();
        TaskSourceFactory taskSourceFactory = createTaskSourceFactory(3, 1);
        setupNodeAllocatorService(TestingNodeSelectorFactory.TestingNodeSupplier.create(ImmutableMap.of(
                NODE_1, ImmutableList.of(TestingHandles.TEST_CATALOG_HANDLE),
                NODE_2, ImmutableList.of(TestingHandles.TEST_CATALOG_HANDLE))));

        TestingExchange sourceExchange1 = new TestingExchange();
        TestingExchange sourceExchange2 = new TestingExchange();

        // try-with-resources closes the allocator exactly once, on both the success and the failure path.
        try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) {
            FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler(
                    remoteTaskFactory,
                    taskSourceFactory,
                    nodeAllocator,
                    new TestingExchange(),
                    ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2),
                    0,
                    1);

            sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchange.TestingExchangeSourceHandle(0, 1L)));
            sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchange.TestingExchangeSourceHandle(0, 1L)));
            assertUnblocked(scheduler.isBlocked());

            scheduler.schedule();

            ListenableFuture<?> blocked = scheduler.isBlocked();
            assertBlocked(blocked);

            // Two extra lease requests issued directly against the allocator.
            NodeAllocator.NodeLease lease1 = nodeAllocator.acquire(
                    new NodeRequirements(Optional.of(TestingHandles.TEST_CATALOG_HANDLE), ImmutableSet.of()),
                    DataSize.of(1L, DataSize.Unit.GIGABYTE));
            NodeAllocator.NodeLease lease2 = nodeAllocator.acquire(
                    new NodeRequirements(Optional.of(TestingHandles.TEST_CATALOG_HANDLE), ImmutableSet.of()),
                    DataSize.of(1L, DataSize.Unit.GIGABYTE));

            remoteTaskFactory.getTasks().get(getTaskId(0, 0)).fail(new RuntimeException("some failure"));
            assertUnblocked(blocked);
            // Both pending leases get a node after the failure (presumably because the scheduler released its nodes).
            assertUnblocked(lease1.getNode());
            assertUnblocked(lease2.getNode());

            // The failure surfaces on the next schedule() call.
            Assertions.assertThatThrownBy(scheduler::schedule).hasMessageContaining("some failure");
            assertUnblocked(scheduler.isBlocked());
            Assert.assertFalse(scheduler.isFinished());
        }
    }

    @Test
    public void testRetryDelay() throws Exception {
        TestingRemoteTaskFactory testingRemoteTaskFactory = new TestingRemoteTaskFactory();
        TaskSourceFactory createTaskSourceFactory = createTaskSourceFactory(3, 1);
        setupNodeAllocatorService(TestingNodeSelectorFactory.TestingNodeSupplier.create(ImmutableMap.of(NODE_1, ImmutableList.of(TestingHandles.TEST_CATALOG_HANDLE), NODE_2, ImmutableList.of(TestingHandles.TEST_CATALOG_HANDLE), NODE_3, ImmutableList.of(TestingHandles.TEST_CATALOG_HANDLE))));
        TestingExchange testingExchange = new TestingExchange();
        TestingExchange testingExchange2 = new TestingExchange();
        Session build = TestingSession.testSessionBuilder().setQueryId(QUERY_ID).setSystemProperty("retry_initial_delay", "1s").setSystemProperty("retry_max_delay", "3s").setSystemProperty("retry_delay_scale_factor", "2.0").build();
        NodeAllocator nodeAllocator = this.nodeAllocatorService.getNodeAllocator(build, 1);
        try {
            FaultTolerantStageScheduler createFaultTolerantTaskScheduler = createFaultTolerantTaskScheduler(build, testingRemoteTaskFactory, createTaskSourceFactory, nodeAllocator, new TestingExchange(), ImmutableMap.of(SOURCE_FRAGMENT_ID_1, testingExchange, SOURCE_FRAGMENT_ID_2, testingExchange2), 6, 1);
            testingExchange.setSourceHandles(ImmutableList.of(new TestingExchange.TestingExchangeSourceHandle(0, 1L)));
            testingExchange2.setSourceHandles(ImmutableList.of(new TestingExchange.TestingExchangeSourceHandle(0, 1L)));
            assertUnblocked(createFaultTolerantTaskScheduler.isBlocked());
            createFaultTolerantTaskScheduler.schedule();
            ListenableFuture isBlocked = createFaultTolerantTaskScheduler.isBlocked();
            assertBlocked(isBlocked);
            Assertions.assertThat(testingRemoteTaskFactory.getTasks()).hasSize(3);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.RUNNING);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.RUNNING);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.RUNNING);
            testingRemoteTaskFactory.getTasks().get(getTaskId(0, 0)).fail(new RuntimeException("some failure"));
            assertUnblocked(isBlocked);
            createFaultTolerantTaskScheduler.schedule();
            ListenableFuture isBlocked2 = createFaultTolerantTaskScheduler.isBlocked();
            assertBlocked(isBlocked2);
            Assertions.assertThat(testingRemoteTaskFactory.getTasks()).hasSize(3);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.RUNNING);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.RUNNING);
            moveTime(900, TimeUnit.MILLISECONDS);
            assertBlocked(isBlocked2);
            Assertions.assertThat(testingRemoteTaskFactory.getTasks()).hasSize(3);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.RUNNING);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.RUNNING);
            moveTime(500, TimeUnit.MILLISECONDS);
            assertUnblocked(isBlocked2);
            createFaultTolerantTaskScheduler.schedule();
            ListenableFuture isBlocked3 = createFaultTolerantTaskScheduler.isBlocked();
            assertBlocked(isBlocked3);
            Assertions.assertThat(testingRemoteTaskFactory.getTasks()).hasSize(4);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.RUNNING);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.RUNNING);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.RUNNING);
            testingRemoteTaskFactory.getTasks().get(getTaskId(0, 1)).fail(new RuntimeException("some other failure"));
            assertUnblocked(isBlocked3);
            createFaultTolerantTaskScheduler.schedule();
            ListenableFuture isBlocked4 = createFaultTolerantTaskScheduler.isBlocked();
            assertBlocked(isBlocked4);
            Assertions.assertThat(testingRemoteTaskFactory.getTasks()).hasSize(4);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.RUNNING);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.RUNNING);
            moveTime(1900, TimeUnit.MILLISECONDS);
            assertBlocked(isBlocked4);
            Assertions.assertThat(testingRemoteTaskFactory.getTasks()).hasSize(4);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.RUNNING);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.RUNNING);
            moveTime(200, TimeUnit.MILLISECONDS);
            assertUnblocked(isBlocked4);
            createFaultTolerantTaskScheduler.schedule();
            ListenableFuture isBlocked5 = createFaultTolerantTaskScheduler.isBlocked();
            assertBlocked(isBlocked5);
            Assertions.assertThat(testingRemoteTaskFactory.getTasks()).hasSize(5);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 2)).getTaskStatus().getState(), TaskState.RUNNING);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.RUNNING);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.RUNNING);
            testingRemoteTaskFactory.getTasks().get(getTaskId(0, 2)).fail(new RuntimeException("some other failure"));
            assertUnblocked(isBlocked5);
            createFaultTolerantTaskScheduler.schedule();
            ListenableFuture isBlocked6 = createFaultTolerantTaskScheduler.isBlocked();
            assertBlocked(isBlocked6);
            Assertions.assertThat(testingRemoteTaskFactory.getTasks()).hasSize(5);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 2)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.RUNNING);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.RUNNING);
            moveTime(2900, TimeUnit.MILLISECONDS);
            assertBlocked(isBlocked6);
            Assertions.assertThat(testingRemoteTaskFactory.getTasks()).hasSize(5);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 2)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.RUNNING);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.RUNNING);
            moveTime(200, TimeUnit.MILLISECONDS);
            assertUnblocked(isBlocked6);
            createFaultTolerantTaskScheduler.schedule();
            ListenableFuture isBlocked7 = createFaultTolerantTaskScheduler.isBlocked();
            assertBlocked(isBlocked7);
            Assertions.assertThat(testingRemoteTaskFactory.getTasks()).hasSize(6);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 2)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 3)).getTaskStatus().getState(), TaskState.RUNNING);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.RUNNING);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.RUNNING);
            testingRemoteTaskFactory.getTasks().get(getTaskId(1, 0)).fail(new RuntimeException("some other failure"));
            assertUnblocked(isBlocked7);
            createFaultTolerantTaskScheduler.schedule();
            ListenableFuture isBlocked8 = createFaultTolerantTaskScheduler.isBlocked();
            assertBlocked(isBlocked8);
            Assertions.assertThat(testingRemoteTaskFactory.getTasks()).hasSize(6);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 2)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 3)).getTaskStatus().getState(), TaskState.RUNNING);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.RUNNING);
            moveTime(2400, TimeUnit.MILLISECONDS);
            assertBlocked(isBlocked8);
            Assertions.assertThat(testingRemoteTaskFactory.getTasks()).hasSize(6);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 2)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 3)).getTaskStatus().getState(), TaskState.RUNNING);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.RUNNING);
            testingRemoteTaskFactory.getTasks().get(getTaskId(0, 3)).finish();
            assertBlocked(isBlocked8);
            Assertions.assertThat(testingRemoteTaskFactory.getTasks()).hasSize(6);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 2)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 3)).getTaskStatus().getState(), TaskState.FINISHED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.RUNNING);
            moveTime(700, TimeUnit.MILLISECONDS);
            assertUnblocked(isBlocked8);
            createFaultTolerantTaskScheduler.schedule();
            ListenableFuture isBlocked9 = createFaultTolerantTaskScheduler.isBlocked();
            assertBlocked(isBlocked9);
            Assertions.assertThat(testingRemoteTaskFactory.getTasks()).hasSize(7);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 2)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 3)).getTaskStatus().getState(), TaskState.FINISHED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(1, 1)).getTaskStatus().getState(), TaskState.RUNNING);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.RUNNING);
            testingRemoteTaskFactory.getTasks().get(getTaskId(2, 0)).finish();
            assertUnblocked(isBlocked9);
            createFaultTolerantTaskScheduler.schedule();
            ListenableFuture isBlocked10 = createFaultTolerantTaskScheduler.isBlocked();
            assertBlocked(isBlocked10);
            Assertions.assertThat(testingRemoteTaskFactory.getTasks()).hasSize(7);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 2)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 3)).getTaskStatus().getState(), TaskState.FINISHED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(1, 1)).getTaskStatus().getState(), TaskState.RUNNING);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.FINISHED);
            testingRemoteTaskFactory.getTasks().get(getTaskId(1, 1)).fail(new RuntimeException("some other failure"));
            assertUnblocked(isBlocked10);
            createFaultTolerantTaskScheduler.schedule();
            ListenableFuture isBlocked11 = createFaultTolerantTaskScheduler.isBlocked();
            assertBlocked(isBlocked11);
            Assertions.assertThat(testingRemoteTaskFactory.getTasks()).hasSize(7);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 2)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 3)).getTaskStatus().getState(), TaskState.FINISHED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(1, 1)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.FINISHED);
            moveTime(900, TimeUnit.MILLISECONDS);
            assertBlocked(isBlocked11);
            Assertions.assertThat(testingRemoteTaskFactory.getTasks()).hasSize(7);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 2)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 3)).getTaskStatus().getState(), TaskState.FINISHED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(1, 1)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.FINISHED);
            moveTime(200, TimeUnit.MILLISECONDS);
            assertUnblocked(isBlocked11);
            createFaultTolerantTaskScheduler.schedule();
            ListenableFuture isBlocked12 = createFaultTolerantTaskScheduler.isBlocked();
            assertBlocked(isBlocked12);
            Assertions.assertThat(testingRemoteTaskFactory.getTasks()).hasSize(8);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 2)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 3)).getTaskStatus().getState(), TaskState.FINISHED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(1, 1)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(1, 2)).getTaskStatus().getState(), TaskState.RUNNING);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.FINISHED);
            testingRemoteTaskFactory.getTasks().get(getTaskId(1, 2)).fail(new TrinoException(StandardErrorCode.CLUSTER_OUT_OF_MEMORY, "oom"));
            assertUnblocked(isBlocked12);
            createFaultTolerantTaskScheduler.schedule();
            assertBlocked(createFaultTolerantTaskScheduler.isBlocked());
            Assertions.assertThat(testingRemoteTaskFactory.getTasks()).hasSize(9);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 2)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(0, 3)).getTaskStatus().getState(), TaskState.FINISHED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(1, 1)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(1, 2)).getTaskStatus().getState(), TaskState.FAILED);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(1, 3)).getTaskStatus().getState(), TaskState.RUNNING);
            Assert.assertEquals(testingRemoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.FINISHED);
            if (nodeAllocator != null) {
                nodeAllocator.close();
            }
        } catch (Throwable th) {
            if (nodeAllocator != null) {
                try {
                    nodeAllocator.close();
                } catch (Throwable th2) {
                    th.addSuppressed(th2);
                }
            }
            throw th;
        }
    }

    /**
     * Runs the cancellation scenario twice: first terminating the scheduler via {@code abort()},
     * then via {@code cancel()}.
     */
    @Test
    public void testCancellation() throws Exception {
        boolean[] abortModes = {true, false};
        for (boolean abort : abortModes) {
            testCancellation(abort);
        }
    }

    /**
     * Schedules a stage, then terminates it and checks the post-termination state.
     *
     * <p>Two node leases are requested directly from the allocator while the scheduler is blocked.
     * After termination the test checks three things:
     * <ul>
     *   <li>the scheduler's blocked future has completed;</li>
     *   <li>both external leases have been fulfilled, presumably because the scheduler released
     *       the nodes it held (TODO confirm);</li>
     *   <li>the scheduler is not reported as finished.</li>
     * </ul>
     *
     * @param abort {@code true} to terminate with {@code abort()}, {@code false} to use {@code cancel()}
     */
    private void testCancellation(boolean abort) throws Exception {
        TestingRemoteTaskFactory remoteTaskFactory = new TestingRemoteTaskFactory();
        TestingTaskSourceFactory taskSourceFactory = createTaskSourceFactory(3, 1);
        setupNodeAllocatorService(TestingNodeSelectorFactory.TestingNodeSupplier.create(ImmutableMap.of(NODE_1, ImmutableList.of(TestingHandles.TEST_CATALOG_HANDLE), NODE_2, ImmutableList.of(TestingHandles.TEST_CATALOG_HANDLE))));
        TestingExchange sourceExchange1 = new TestingExchange();
        TestingExchange sourceExchange2 = new TestingExchange();
        // try-with-resources closes the allocator exactly once, even when close() itself throws
        // (the previous hand-rolled try/catch would close it a second time in that case).
        try (NodeAllocator nodeAllocator = this.nodeAllocatorService.getNodeAllocator(SESSION, 1)) {
            FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler(remoteTaskFactory, taskSourceFactory, nodeAllocator, new TestingExchange(), ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), 0, 1);
            sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchange.TestingExchangeSourceHandle(0, 1L)));
            sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchange.TestingExchangeSourceHandle(0, 1L)));
            assertUnblocked(scheduler.isBlocked());

            scheduler.schedule();
            ListenableFuture<?> blocked = scheduler.isBlocked();
            assertBlocked(blocked);

            // Leases requested by a party other than the scheduler.
            NodeAllocator.NodeLease lease1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(TestingHandles.TEST_CATALOG_HANDLE), ImmutableSet.of()), DataSize.of(1L, DataSize.Unit.GIGABYTE));
            NodeAllocator.NodeLease lease2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(TestingHandles.TEST_CATALOG_HANDLE), ImmutableSet.of()), DataSize.of(1L, DataSize.Unit.GIGABYTE));

            if (abort) {
                scheduler.abort();
            }
            else {
                scheduler.cancel();
            }

            assertUnblocked(blocked);
            assertUnblocked(lease1.getNode());
            assertUnblocked(lease2.getNode());

            // Scheduling after termination must neither block nor report completion.
            scheduler.schedule();
            assertUnblocked(scheduler.isBlocked());
            Assert.assertFalse(scheduler.isFinished());
        }
    }

    /**
     * Verifies scheduling when the split list is produced asynchronously.
     *
     * <p>The sequence checked is:
     * <ol>
     *   <li>The scheduler is blocked after the first {@code schedule()} while the split future is pending.</li>
     *   <li>It unblocks once the splits are supplied.</li>
     *   <li>It then creates two tasks, each with three splits.</li>
     *   <li>It reports finished after those tasks complete.</li>
     * </ol>
     */
    @Test
    public void testAsyncTaskSource() throws Exception {
        TestingRemoteTaskFactory remoteTaskFactory = new TestingRemoteTaskFactory();
        // Typed future: avoids the raw SettableFuture and the unchecked cast the factory needed before.
        SettableFuture<List<Split>> splitsFuture = SettableFuture.create();
        TaskSourceFactory taskSourceFactory = new TestingTaskSourceFactory(Optional.of(TestingHandles.TEST_CATALOG_HANDLE), splitsFuture, 1);
        setupNodeAllocatorService(TestingNodeSelectorFactory.TestingNodeSupplier.create(ImmutableMap.of(NODE_1, ImmutableList.of(TestingHandles.TEST_CATALOG_HANDLE), NODE_2, ImmutableList.of(TestingHandles.TEST_CATALOG_HANDLE))));
        TestingExchange sourceExchange1 = new TestingExchange();
        TestingExchange sourceExchange2 = new TestingExchange();
        // try-with-resources closes the allocator exactly once, even when close() itself throws.
        try (NodeAllocator nodeAllocator = this.nodeAllocatorService.getNodeAllocator(SESSION, 1)) {
            FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler(remoteTaskFactory, taskSourceFactory, nodeAllocator, new TestingExchange(), ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), 2, 1);
            sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchange.TestingExchangeSourceHandle(0, 1L)));
            sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchange.TestingExchangeSourceHandle(0, 1L)));
            assertUnblocked(scheduler.isBlocked());

            // Splits not available yet: the scheduler must wait.
            scheduler.schedule();
            assertBlocked(scheduler.isBlocked());

            splitsFuture.set(createSplits(2));
            assertUnblocked(scheduler.isBlocked());

            scheduler.schedule();
            Assertions.assertThat(remoteTaskFactory.getTasks()).hasSize(2);
            remoteTaskFactory.getTasks().values().forEach(task -> {
                Assertions.assertThat(task.getSplits().values()).hasSize(3);
                task.finish();
            });
            Assertions.assertThat(scheduler.isFinished()).isTrue();
        }
    }

    /**
     * Verifies when the scheduler reports itself finished.
     *
     * <p>The checks are:
     * <ul>
     *   <li>The task source is created lazily, on the first {@code schedule()}.</li>
     *   <li>The scheduler is not finished while the task source's {@code getMoreTasks()} future
     *       is pending, even after the source reports {@code isFinished()}.</li>
     *   <li>Once that future completes with no tasks, the scheduler is finished.</li>
     *   <li>The sink exchange is marked no-more-sinks and all-required-sinks-finished.</li>
     * </ul>
     */
    @Test
    public void testIsFinished() throws Exception {
        setupNodeAllocatorService(TestingNodeSelectorFactory.TestingNodeSupplier.create(ImmutableMap.of(NODE_1, ImmutableList.of(TestingHandles.TEST_CATALOG_HANDLE))));
        RemoteTaskFactory remoteTaskFactory = new TestingRemoteTaskFactory();
        // Typed to match TaskSource.getMoreTasks(); the raw SettableFuture relied on an unchecked conversion.
        SettableFuture<List<TaskDescriptor>> moreTasks = SettableFuture.create();
        AtomicBoolean taskSourceCreated = new AtomicBoolean();
        AtomicBoolean taskSourceFinished = new AtomicBoolean();
        TaskSource taskSource = new TaskSource() {
            @Override
            public ListenableFuture<List<TaskDescriptor>> getMoreTasks() {
                return moreTasks;
            }

            @Override
            public boolean isFinished() {
                return taskSourceFinished.get();
            }

            @Override
            public void close() {
            }
        };
        // try-with-resources closes the allocator exactly once, even when close() itself throws.
        try (NodeAllocator nodeAllocator = this.nodeAllocatorService.getNodeAllocator(SESSION, 1)) {
            TestingExchange sourceExchange1 = new TestingExchange();
            sourceExchange1.setSourceHandles(ImmutableList.of());
            TestingExchange sourceExchange2 = new TestingExchange();
            sourceExchange2.setSourceHandles(ImmutableList.of());
            TestingExchange sinkExchange = new TestingExchange();
            FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler(remoteTaskFactory, (session, planFragment, multimap, longConsumer, faultTolerantPartitioningScheme) -> {
                taskSourceCreated.set(true);
                return taskSource;
            }, nodeAllocator, sinkExchange, ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), 1, 1);

            // The factory is not invoked until the first schedule() call.
            Assert.assertFalse(taskSourceCreated.get());
            scheduler.schedule();
            Assert.assertTrue(taskSourceCreated.get());

            scheduler.schedule();
            Assert.assertFalse(scheduler.isFinished());

            // Source finished but its pending getMoreTasks() future still outstanding: not finished yet.
            taskSourceFinished.set(true);
            scheduler.schedule();
            Assert.assertFalse(scheduler.isFinished());

            moreTasks.set(ImmutableList.of());
            Assert.assertTrue(scheduler.isFinished());
            Assert.assertTrue(sinkExchange.isNoMoreSinks());
            Assert.assertTrue(sinkExchange.isAllRequiredSinksFinished());
        }
    }

    /**
     * Convenience overload of the session-taking variant that uses the shared {@code SESSION}.
     * The two trailing ints are forwarded unchanged; presumably they are the retry budget and the
     * per-stage limit on tasks waiting for a node (TODO confirm).
     */
    private FaultTolerantStageScheduler createFaultTolerantTaskScheduler(RemoteTaskFactory remoteTaskFactory, TaskSourceFactory taskSourceFactory, NodeAllocator nodeAllocator, Exchange exchange, Map<PlanFragmentId, Exchange> map, int retryAttempts, int maxTasksWaitingForNode) {
        Session defaultSession = SESSION;
        return createFaultTolerantTaskScheduler(defaultSession, remoteTaskFactory, taskSourceFactory, nodeAllocator, exchange, map, retryAttempts, maxTasksWaitingForNode);
    }

    /**
     * Builds a scheduler for the intermediate (join) fragment.
     *
     * <p>For every entry in {@code map}, a leaf-stage scheduler is first created over an empty split
     * list and driven with {@code schedule()} until it reports finished. The finished leaf schedulers
     * are passed as the intermediate scheduler's source schedulers, alongside {@code map} itself.
     *
     * @param session session handed to every created scheduler
     * @param exchange sink exchange of the intermediate stage
     * @param map source fragment id to the exchange that the corresponding leaf stage writes to
     * @param i forwarded to {@code createStageScheduler} (presumably the retry budget — TODO confirm)
     * @param i2 forwarded to {@code createStageScheduler} (presumably the max tasks waiting for a node)
     */
    private FaultTolerantStageScheduler createFaultTolerantTaskScheduler(Session session, RemoteTaskFactory remoteTaskFactory, TaskSourceFactory taskSourceFactory, NodeAllocator nodeAllocator, Exchange exchange, Map<PlanFragmentId, Exchange> map, int i, int i2) {
        TaskDescriptorStorage taskDescriptorStorage = new TaskDescriptorStorage(DataSize.of(10L, DataSize.Unit.MEGABYTE));
        // NOTE(review): initialized with SESSION's query id rather than `session`'s. createSqlStage also
        // builds every stage from SESSION/STAGE_ID, so this is presumably intentional — confirm before changing.
        taskDescriptorStorage.initialize(SESSION.getQueryId());
        DynamicFilterService dynamicFilterService = new DynamicFilterService(PLANNER_CONTEXT.getMetadata(), PLANNER_CONTEXT.getFunctionManager(), PLANNER_CONTEXT.getTypeOperators(), new DynamicFilterConfig());

        // Created before the source schedulers to keep the original construction order.
        SqlStage intermediateStage = createSqlStage(createIntermediatePlanFragment(), remoteTaskFactory);

        // Typed map instead of a raw (Map) cast, so the compiler checks the key/value types.
        Map<PlanFragmentId, FaultTolerantStageScheduler> sourceSchedulers = map.entrySet().stream()
                .collect(ImmutableMap.toImmutableMap(
                        Map.Entry::getKey,
                        entry -> {
                            FaultTolerantStageScheduler sourceScheduler = createStageScheduler(session, createSqlStage(createLeafPlanFragment(entry.getKey()), remoteTaskFactory), nodeAllocator, i, i2, taskDescriptorStorage, new TestingTaskSourceFactory((Optional<CatalogHandle>) Optional.empty(), (List<Split>) ImmutableList.of(), 1), dynamicFilterService, entry.getValue(), ImmutableMap.of(), ImmutableMap.of());
                            // Drive the leaf stage to completion; any failure is rethrown unchecked
                            // so it can escape the lambda.
                            try {
                                while (!sourceScheduler.isFinished()) {
                                    sourceScheduler.schedule();
                                }
                            }
                            catch (Exception e) {
                                throw new RuntimeException(e);
                            }
                            return sourceScheduler;
                        }));

        return createStageScheduler(session, intermediateStage, nodeAllocator, i, i2, taskDescriptorStorage, taskSourceFactory, dynamicFilterService, exchange, sourceSchedulers, map);
    }

    /**
     * Constructs a {@link FaultTolerantStageScheduler} for {@code sqlStage} with testing collaborators.
     *
     * <p>Those collaborators are a no-op failure detector, a constant memory estimator, fresh task
     * stats, and the test's ticker and future completor. A single partitioning scheme instance is
     * passed in both of the constructor's partitioning-scheme positions. The first int is used both
     * to seed an {@link AtomicInteger} and as a plain int argument.
     */
    private FaultTolerantStageScheduler createStageScheduler(Session session, SqlStage sqlStage, NodeAllocator nodeAllocator, int retryAttempts, int maxTasksWaitingForNode, TaskDescriptorStorage taskDescriptorStorage, TaskSourceFactory taskSourceFactory, DynamicFilterService dynamicFilterService, Exchange sinkExchange, Map<PlanFragmentId, FaultTolerantStageScheduler> sourceSchedulers, Map<PlanFragmentId, Exchange> sourceExchanges) {
        // 3 is presumably the partition count (TODO confirm against FaultTolerantPartitioningScheme).
        FaultTolerantPartitioningScheme partitioningScheme = new FaultTolerantPartitioningScheme(3, Optional.empty(), Optional.empty(), Optional.empty());
        AtomicInteger remainingRetryAttempts = new AtomicInteger(retryAttempts);
        return new FaultTolerantStageScheduler(
                session,
                sqlStage,
                new NoOpFailureDetector(),
                taskSourceFactory,
                nodeAllocator,
                taskDescriptorStorage,
                new ConstantPartitionMemoryEstimator(),
                new TaskExecutionStats(),
                futureCompletor,
                ticker,
                sinkExchange,
                partitioningScheme,
                sourceSchedulers,
                sourceExchanges,
                partitioningScheme,
                remainingRetryAttempts,
                retryAttempts,
                maxTasksWaitingForNode,
                dynamicFilterService);
    }

    /**
     * Wraps {@code planFragment} in a {@link SqlStage}.
     *
     * <p>Every stage built here uses the shared {@code STAGE_ID} and {@code SESSION}, no extra
     * bucket-to-partition info (empty map), the test's node task map, and Guava's direct executor.
     */
    private SqlStage createSqlStage(PlanFragment planFragment, RemoteTaskFactory remoteTaskFactory) {
        SplitSchedulerStats schedulerStats = new SplitSchedulerStats();
        return SqlStage.createSqlStage(
                STAGE_ID,
                planFragment,
                ImmutableMap.of(),
                remoteTaskFactory,
                SESSION,
                false,
                nodeTaskMap,
                MoreExecutors.directExecutor(),
                schedulerStats);
    }

    /**
     * Builds the intermediate fragment used by the tests.
     *
     * <p>The fragment is an INNER, REPLICATED join. Its left side is a table scan producing
     * {@code probe_column}. Its right side is a REPLICATE remote source reading
     * {@code build_column} from both source fragments. The output is SINGLE-distributed over
     * both columns.
     */
    private PlanFragment createIntermediatePlanFragment() {
        Symbol probeColumn = new Symbol("probe_column");
        Symbol buildColumn = new Symbol("build_column");

        TableScanNode probeSide = new TableScanNode(
                TABLE_SCAN_NODE_ID,
                TestingHandles.TEST_TABLE_HANDLE,
                ImmutableList.of(probeColumn),
                ImmutableMap.of(probeColumn, new TestingMetadata.TestingColumnHandle("column")),
                TupleDomain.none(),
                Optional.empty(),
                false,
                Optional.empty());

        RemoteSourceNode buildSide = new RemoteSourceNode(
                new PlanNodeId("remote_source_id"),
                ImmutableList.of(SOURCE_FRAGMENT_ID_1, SOURCE_FRAGMENT_ID_2),
                ImmutableList.of(buildColumn),
                Optional.empty(),
                ExchangeNode.Type.REPLICATE,
                RetryPolicy.TASK);

        JoinNode join = new JoinNode(
                new PlanNodeId("join_id"),
                JoinNode.Type.INNER,
                probeSide,
                buildSide,
                ImmutableList.of(),
                probeSide.getOutputSymbols(),
                buildSide.getOutputSymbols(),
                false,
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.of(JoinNode.DistributionType.REPLICATED),
                Optional.empty(),
                ImmutableMap.of(),
                Optional.empty());

        PartitioningScheme outputPartitioning = new PartitioningScheme(
                Partitioning.create(SystemPartitioningHandle.SINGLE_DISTRIBUTION, ImmutableList.of()),
                ImmutableList.of(probeColumn, buildColumn));

        return new PlanFragment(
                FRAGMENT_ID,
                join,
                ImmutableMap.of(probeColumn, VarcharType.VARCHAR, buildColumn, VarcharType.VARCHAR),
                SystemPartitioningHandle.SOURCE_DISTRIBUTION,
                ImmutableList.of(TABLE_SCAN_NODE_ID),
                outputPartitioning,
                StatsAndCosts.empty(),
                ImmutableList.of(),
                Optional.empty());
    }

    /**
     * Builds a leaf fragment with the given id.
     *
     * <p>The fragment is a single table scan producing {@code output_column}, with
     * SINGLE-distributed output.
     */
    private PlanFragment createLeafPlanFragment(PlanFragmentId planFragmentId) {
        Symbol outputColumn = new Symbol("output_column");

        TableScanNode scan = new TableScanNode(
                TABLE_SCAN_NODE_ID,
                TestingHandles.TEST_TABLE_HANDLE,
                ImmutableList.of(outputColumn),
                ImmutableMap.of(outputColumn, new TestingMetadata.TestingColumnHandle("column")),
                TupleDomain.none(),
                Optional.empty(),
                false,
                Optional.empty());

        PartitioningScheme outputPartitioning = new PartitioningScheme(
                Partitioning.create(SystemPartitioningHandle.SINGLE_DISTRIBUTION, ImmutableList.of()),
                ImmutableList.of(outputColumn));

        return new PlanFragment(
                planFragmentId,
                scan,
                ImmutableMap.of(outputColumn, VarcharType.VARCHAR),
                SystemPartitioningHandle.SOURCE_DISTRIBUTION,
                ImmutableList.of(TABLE_SCAN_NODE_ID),
                outputPartitioning,
                StatsAndCosts.empty(),
                ImmutableList.of(),
                Optional.empty());
    }

    /**
     * Creates a task source factory with the given number of identical splits.
     *
     * <p>The factory requires {@code TEST_CATALOG_HANDLE}. The second argument is forwarded as the
     * factory's third constructor argument (presumably tasks per batch — TODO confirm).
     */
    private static TestingTaskSourceFactory createTaskSourceFactory(int splitCount, int tasksPerBatch) {
        Optional<CatalogHandle> catalogRequirement = Optional.of(TestingHandles.TEST_CATALOG_HANDLE);
        List<Split> splits = createSplits(splitCount);
        return new TestingTaskSourceFactory(catalogRequirement, splits, tasksPerBatch);
    }

    /**
     * Returns an immutable list of {@code count} entries that all reference one and the same
     * remote {@link Split} on {@code TEST_CATALOG_HANDLE}.
     */
    private static List<Split> createSplits(int count) {
        Split split = new Split(TestingHandles.TEST_CATALOG_HANDLE, TestingSplit.createRemoteSplit());
        Iterable<Split> repeated = Iterables.limit(Iterables.cycle(split), count);
        return ImmutableList.copyOf(repeated);
    }

    /**
     * Returns the id of the given partition and attempt within the shared {@code STAGE_ID}.
     */
    private static TaskId getTaskId(int partitionId, int attemptId) {
        TaskId taskId = new TaskId(STAGE_ID, partitionId, attemptId);
        return taskId;
    }

    /**
     * Asserts that {@code future} has not completed yet.
     */
    private static void assertBlocked(ListenableFuture<?> future) {
        boolean completed = future.isDone();
        Assert.assertFalse(completed);
    }

    /**
     * Asserts that {@code future} has already completed.
     */
    private static void assertUnblocked(ListenableFuture<?> future) {
        boolean completed = future.isDone();
        Assert.assertTrue(completed);
    }

    /**
     * Advances the test ticker by {@code amount} {@code unit}s, then triggers the future completor
     * so that any futures due at the new time can fire.
     */
    private void moveTime(int amount, TimeUnit unit) {
        ticker.increment(amount, unit);
        futureCompletor.trigger();
    }
}
