package io.trino.sql.planner;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.graph.Traverser;
import io.trino.Session;
import io.trino.cost.StatsAndCosts;
import io.trino.execution.QueryManagerConfig;
import io.trino.execution.warnings.WarningCollector;
import io.trino.plugin.tpch.TpchConnectorFactory;
import io.trino.security.AllowAllAccessControl;
import io.trino.spi.type.VarcharType;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PlanFragmentId;
import io.trino.testing.LocalQueryRunner;
import io.trino.testing.TestingSession;
import io.trino.transaction.TransactionBuilder;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.parallel.Execution;
import org.junit.jupiter.api.parallel.ExecutionMode;

/**
 * Checks the partition count that {@link PlanFragmenter} records on each
 * {@link PlanFragment} when the plan contains repartitioning exchanges whose
 * partitioning schemes were built with an explicit partition count.
 *
 * <p>A single {@link LocalQueryRunner} (with a TPCH catalog registered as
 * {@code test_catalog}) is shared by all tests of the class instance.
 */
@Execution(ExecutionMode.CONCURRENT)
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class TestPlanFragmentPartitionCount {
    private PlanFragmenter planFragmenter;
    private Session session;
    private LocalQueryRunner localQueryRunner;

    /**
     * Creates the shared session, query runner and the {@link PlanFragmenter} under test.
     */
    @BeforeAll
    public void setUp() {
        this.session = TestingSession.testSessionBuilder().setCatalog("test_catalog").build();
        this.localQueryRunner = LocalQueryRunner.create(this.session);
        this.localQueryRunner.createCatalog("test_catalog", new TpchConnectorFactory(), ImmutableMap.of());
        this.planFragmenter = new PlanFragmenter(
                this.localQueryRunner.getMetadata(),
                this.localQueryRunner.getFunctionManager(),
                this.localQueryRunner.getTransactionManager(),
                this.localQueryRunner.getCatalogManager(),
                this.localQueryRunner.getLanguageFunctionManager(),
                new QueryManagerConfig());
    }

    /**
     * Releases the shared fixtures.
     */
    @AfterAll
    public void tearDown() {
        this.planFragmenter = null;
        this.session = null;
        // Guard against setUp() having failed before the runner was assigned, so that
        // a NullPointerException here does not obscure the original failure.
        if (this.localQueryRunner != null) {
            this.localQueryRunner.close();
            this.localQueryRunner = null;
        }
    }

    /**
     * Builds a plan with four REPARTITION exchanges (partition counts 3, 2, 5 and 5),
     * fragments it, and asserts the partition count reported for every fragment id.
     */
    @Test
    public void testPartitionCountInPlanFragment() {
        PlanBuilder planBuilder = new PlanBuilder(new PlanNodeIdAllocator(), this.localQueryRunner.getPlannerContext(), this.session);
        Symbol a = planBuilder.symbol("a", VarcharType.VARCHAR);
        Symbol b = planBuilder.symbol("b", VarcharType.VARCHAR);
        Symbol c = planBuilder.symbol("c", VarcharType.VARCHAR);
        Symbol d = planBuilder.symbol("d", VarcharType.VARCHAR);
        Symbol f = planBuilder.symbol("f", VarcharType.VARCHAR);
        Symbol g = planBuilder.symbol("g", VarcharType.VARCHAR);
        Symbol h = planBuilder.symbol("h", VarcharType.VARCHAR);
        Symbol i = planBuilder.symbol("i", VarcharType.VARCHAR);

        // Each nested lambda uses a distinct parameter name: a lambda parameter may not
        // redeclare a parameter of an enclosing lambda.
        SubPlan rootSubPlan = fragment(new Plan(
                planBuilder.output(outputBuilder -> {
                    outputBuilder.source(planBuilder.exchange(rootExchange -> {
                        rootExchange
                                .type(ExchangeNode.Type.REPARTITION)
                                .addSource(planBuilder.exchange(arbitraryExchange -> {
                                    arbitraryExchange
                                            .type(ExchangeNode.Type.REPARTITION)
                                            .addSource(planBuilder.join(
                                                    JoinNode.Type.INNER,
                                                    planBuilder.exchange(leftExchange -> {
                                                        leftExchange
                                                                .type(ExchangeNode.Type.REPARTITION)
                                                                .addSource(planBuilder.values(a, b))
                                                                .addInputsSet(a, b)
                                                                .fixedHashDistributionPartitioningScheme(ImmutableList.of(a, b), ImmutableList.of(b), 5);
                                                    }),
                                                    planBuilder.exchange(rightExchange -> {
                                                        rightExchange
                                                                .type(ExchangeNode.Type.REPARTITION)
                                                                .addSource(planBuilder.values(c, d))
                                                                .addInputsSet(c, d)
                                                                .fixedHashDistributionPartitioningScheme(ImmutableList.of(c, d), ImmutableList.of(d), 5);
                                                    }),
                                                    new JoinNode.EquiJoinClause(b, d)))
                                            .addInputsSet(a, b, c, d)
                                            .fixedArbitraryDistributionPartitioningScheme(ImmutableList.of(a, b, c, d), 2);
                                }))
                                .addSource(planBuilder.values(f, g, h, i))
                                .addInputsSet(a, b, c, d)
                                .addInputsSet(f, g, h, i)
                                .fixedHashDistributionPartitioningScheme(ImmutableList.of(a, b, c, d), ImmutableList.of(b), 3);
                    }));
                }),
                planBuilder.getTypes(),
                StatsAndCosts.empty()));

        // Collect fragment id -> partition count over the whole sub-plan tree.
        ImmutableMap.Builder<PlanFragmentId, Optional<Integer>> actualPartitionCounts = ImmutableMap.builder();
        Traverser.forTree(SubPlan::getChildren)
                .depthFirstPreOrder(rootSubPlan)
                .forEach(subPlan -> actualPartitionCounts.put(subPlan.getFragment().getId(), subPlan.getFragment().getPartitionCount()));

        // NOTE(review): fragments 0-2 presumably correspond to the hash(3), arbitrary(2) and
        // join-side hash(5) exchanges, and 3-5 to the source fragments - confirm against PlanFragmenter.
        ImmutableMap<PlanFragmentId, Optional<Integer>> expectedPartitionCounts = ImmutableMap.of(
                new PlanFragmentId("0"), Optional.of(3),
                new PlanFragmentId("1"), Optional.of(2),
                new PlanFragmentId("2"), Optional.of(5),
                new PlanFragmentId("3"), Optional.empty(),
                new PlanFragmentId("4"), Optional.empty(),
                new PlanFragmentId("5"), Optional.empty());
        Assertions.assertThat(actualPartitionCounts.buildOrThrow()).isEqualTo(expectedPartitionCounts);
    }

    /**
     * Registers the test session with the language function manager and fragments
     * {@code plan} inside a single-statement transaction.
     *
     * @param plan the plan to split into fragments
     * @return the root of the resulting sub-plan tree
     */
    private SubPlan fragment(Plan plan) {
        this.localQueryRunner.getLanguageFunctionManager().registerQuery(this.session);
        return inTransaction(transactionSession ->
                this.planFragmenter.createSubPlans(transactionSession, plan, false, WarningCollector.NOOP));
    }

    /**
     * Runs {@code function} inside a single-statement transaction started from the test session.
     *
     * @param function callback receiving the session passed in by the transaction builder
     * @param <T> result type of the callback
     * @return whatever {@code function} returns
     */
    private <T> T inTransaction(Function<Session, T> function) {
        return TransactionBuilder.transaction(this.localQueryRunner.getTransactionManager(), this.localQueryRunner.getMetadata(), new AllowAllAccessControl())
                .singleStatement()
                .execute(this.session, transactionSession -> {
                    // The returned handle is intentionally discarded; the lookup is presumably
                    // made for its side effect of registering the catalog with the transaction
                    // (TODO confirm against Metadata#getCatalogHandle).
                    transactionSession.getCatalog().ifPresent(catalogName ->
                            this.localQueryRunner.getMetadata().getCatalogHandle(transactionSession, catalogName));
                    return function.apply(transactionSession);
                });
    }
}
