package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.airlift.testing.Closeables;
import io.trino.Session;
import io.trino.cost.CachingCostProvider;
import io.trino.cost.CachingStatsProvider;
import io.trino.cost.CachingTableStatsProvider;
import io.trino.cost.CostComparator;
import io.trino.cost.CostProvider;
import io.trino.cost.PlanCostEstimate;
import io.trino.cost.StatsProvider;
import io.trino.execution.warnings.WarningCollector;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.ReorderJoins;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.testing.LocalQueryRunner;
import io.trino.testing.TestingSession;
import java.io.Closeable;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.Optional;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.parallel.Execution;
import org.junit.jupiter.api.parallel.ExecutionMode;

/**
 * Tests for {@code ReorderJoins.JoinEnumerator}.
 *
 * <p>A single {@link LocalQueryRunner} is created once for the class (per-class test
 * lifecycle) and is closed after all tests have run.
 */
@Execution(ExecutionMode.CONCURRENT)
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class TestJoinEnumerator {
    private LocalQueryRunner queryRunner;

    /** Creates the query runner used by the tests, with a default testing session. */
    @BeforeAll
    public void setUp() {
        queryRunner = LocalQueryRunner.create(TestingSession.testSessionBuilder().build());
    }

    /** Closes the query runner and drops the reference to it. */
    @AfterAll
    public void tearDown() {
        Closeables.closeAllRuntimeException(queryRunner);
        queryRunner = null;
    }

    /** Checks the exact partition sets produced by {@code generatePartitions} for 4 and 3 elements. */
    @Test
    public void testGeneratePartitions() {
        Assertions.assertThat(ReorderJoins.JoinEnumerator.generatePartitions(4))
                .isEqualTo(ImmutableSet.of(
                        ImmutableSet.of(0),
                        ImmutableSet.of(0, 1),
                        ImmutableSet.of(0, 2),
                        ImmutableSet.of(0, 3),
                        ImmutableSet.of(0, 1, 2),
                        ImmutableSet.of(0, 1, 3),
                        ImmutableSet.of(0, 2, 3)));

        Assertions.assertThat(ReorderJoins.JoinEnumerator.generatePartitions(3))
                .isEqualTo(ImmutableSet.of(
                        ImmutableSet.of(0),
                        ImmutableSet.of(0, 1),
                        ImmutableSet.of(0, 2)));
    }

    /**
     * Builds a multi-join over two values sources whose filter is the literal TRUE (so there is
     * no join predicate between them), asks the enumerator to create a join for the partition
     * {@code {0}}, and expects no plan node and an infinite cost.
     */
    @Test
    public void testDoesNotCreateJoinWhenPartitionedOnCrossJoin() {
        PlanBuilder planBuilder = new PlanBuilder(
                new PlanNodeIdAllocator(),
                queryRunner.getPlannerContext(),
                queryRunner.getDefaultSession());
        Symbol a1 = planBuilder.symbol("A1");
        Symbol b1 = planBuilder.symbol("B1");

        // The diamond lets the element type be inferred from the constructor parameter,
        // avoiding a raw LinkedHashSet and an unchecked cast.
        ReorderJoins.MultiJoinNode multiJoinNode = new ReorderJoins.MultiJoinNode(
                new LinkedHashSet<>(ImmutableList.of(planBuilder.values(a1), planBuilder.values(b1))),
                BooleanLiteral.TRUE_LITERAL,
                ImmutableList.of(a1, b1),
                false);

        ReorderJoins.JoinEnumerator joinEnumerator = new ReorderJoins.JoinEnumerator(
                queryRunner.getMetadata(),
                new CostComparator(1.0, 1.0, 1.0),
                multiJoinNode.getFilter(),
                createContext());

        ReorderJoins.JoinEnumerationResult actual = joinEnumerator.createJoinAccordingToPartitioning(
                multiJoinNode.getSources(),
                multiJoinNode.getOutputSymbols(),
                ImmutableSet.of(0));

        Assertions.assertThat(actual.getPlanNode().isPresent()).isFalse();
        Assertions.assertThat(actual.getCost()).isEqualTo(PlanCostEstimate.infinite());
    }

    /**
     * Creates a minimal {@link Rule.Context} backed by the query runner.
     *
     * <p>The context uses fresh id and symbol allocators, caching stats and cost providers built
     * from the runner's calculators, {@link Lookup#noLookup()}, a no-op timeout check and a
     * no-op warning collector.
     *
     * @return the rule context handed to the enumerator under test
     */
    private Rule.Context createContext() {
        PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator();
        SymbolAllocator symbols = new SymbolAllocator();

        CachingTableStatsProvider tableStatsProvider = new CachingTableStatsProvider(
                queryRunner.getMetadata(),
                queryRunner.getDefaultSession());
        CachingStatsProvider statsProvider = new CachingStatsProvider(
                queryRunner.getStatsCalculator(),
                Optional.empty(),
                Lookup.noLookup(),
                queryRunner.getDefaultSession(),
                symbols.getTypes(),
                tableStatsProvider);
        CachingCostProvider costProvider = new CachingCostProvider(
                queryRunner.getCostCalculator(),
                statsProvider,
                Optional.empty(),
                queryRunner.getDefaultSession(),
                symbols.getTypes());

        return new Rule.Context() {
            @Override
            public Lookup getLookup() {
                return Lookup.noLookup();
            }

            @Override
            public PlanNodeIdAllocator getIdAllocator() {
                return idAllocator;
            }

            @Override
            public SymbolAllocator getSymbolAllocator() {
                return symbols;
            }

            @Override
            public Session getSession() {
                return queryRunner.getDefaultSession();
            }

            @Override
            public StatsProvider getStatsProvider() {
                return statsProvider;
            }

            @Override
            public CostProvider getCostProvider() {
                return costProvider;
            }

            @Override
            public void checkTimeoutNotExhausted() {
                // Intentionally empty: no timeout is enforced in this test context.
            }

            @Override
            public WarningCollector getWarningCollector() {
                return WarningCollector.NOOP;
            }
        };
    }
}
