package io.trino.sql.planner;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.cost.StatsAndCosts;
import io.trino.operator.RetryPolicy;
import io.trino.spi.type.VarcharType;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.IndexJoinNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PlanFragmentId;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.RemoteSourceNode;
import io.trino.sql.planner.plan.SemiJoinNode;
import io.trino.sql.planner.plan.SpatialJoinNode;
import io.trino.sql.planner.plan.ValuesNode;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.Row;
import io.trino.sql.tree.StringLiteral;
import java.util.List;
import java.util.Optional;
import org.assertj.core.api.Assertions;
import org.testng.annotations.Test;

/**
 * Tests {@link TopologicalOrderSubPlanVisitor#sortPlanInTopologicalOrder} on a three-level plan
 * whose fragments are connected by each kind of two-input join node (regular, semi, index and
 * spatial join).
 */
public class TestTopologicalOrderSubPlanVisitor {

    /**
     * Factory for a two-input join-like plan node. Each implementation maps {@code left} and
     * {@code right} onto its own node's first and second source slots.
     */
    @FunctionalInterface
    public interface JoinFunction {
        PlanNode apply(String id, PlanNode left, PlanNode right);
    }

    /**
     * Builds the fragment tree
     * <pre>
     *   root = join(a, middle)
     *   middle = join(b, c)
     * </pre>
     * where a, b and c are single-values fragments. It then asserts the sorted order
     * {@code c, b, middle, a, root}. Every fragment appears after all of its inputs, and the
     * second (right) join input's fragments are listed before the first (left) input's.
     * The children list passed to each SubPlan is in left-then-right order, so the expected
     * right-first order is not simply the children order.
     */
    private void assertTopologicalOrder(JoinFunction joinFactory) {
        SubPlan a = valuesSubPlan("a");
        SubPlan b = valuesSubPlan("b");
        SubPlan c = valuesSubPlan("c");

        PlanNode middleJoin = joinFactory.apply("middle_join", remoteSource("b"), remoteSource("c"));
        SubPlan middle = createSubPlan("middle", middleJoin, ImmutableList.of(b, c));

        PlanNode rootJoin = joinFactory.apply("root_join", remoteSource("a"), remoteSource("middle"));
        SubPlan root = createSubPlan("root", rootJoin, ImmutableList.of(a, middle));

        Assertions.assertThat(TopologicalOrderSubPlanVisitor.sortPlanInTopologicalOrder(root))
                .isEqualTo(ImmutableList.of(c, b, middle, a, root));
    }

    @Test
    public void testJoinOrder() {
        assertTopologicalOrder(TestTopologicalOrderSubPlanVisitor::join);
    }

    @Test
    public void testSemiJoinOrder() {
        assertTopologicalOrder(TestTopologicalOrderSubPlanVisitor::semiJoin);
    }

    @Test
    public void testIndexJoin() {
        assertTopologicalOrder(TestTopologicalOrderSubPlanVisitor::indexJoin);
    }

    @Test
    public void testSpatialJoin() {
        assertTopologicalOrder(TestTopologicalOrderSubPlanVisitor::spatialJoin);
    }

    /** Creates a remote source that reads from the single fragment {@code fragmentId}. */
    private static RemoteSourceNode remoteSource(String fragmentId) {
        return remoteSource(ImmutableList.of(fragmentId));
    }

    /**
     * Creates a remote source reading from the given fragments. The node id is taken from the
     * first fragment id, so {@code fragmentIds} must be non-empty.
     */
    private static RemoteSourceNode remoteSource(List<String> fragmentIds) {
        // Typed local instead of the decompiler's raw (List) cast, so the element type is checked.
        List<PlanFragmentId> sourceFragmentIds = fragmentIds.stream()
                .map(PlanFragmentId::new)
                .collect(ImmutableList.toImmutableList());
        return new RemoteSourceNode(
                new PlanNodeId(fragmentIds.get(0)),
                sourceFragmentIds,
                ImmutableList.of(new Symbol("blah")),
                Optional.empty(),
                ExchangeNode.Type.REPARTITION,
                RetryPolicy.TASK);
    }

    /** Inner {@link JoinNode} with no criteria that outputs all symbols of both inputs. */
    private static JoinNode join(String id, PlanNode left, PlanNode right) {
        return new JoinNode(
                new PlanNodeId(id),
                JoinNode.Type.INNER,
                left,
                right,
                ImmutableList.of(),
                left.getOutputSymbols(),
                right.getOutputSymbols(),
                false,
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                ImmutableMap.of(),
                Optional.empty());
    }

    /**
     * {@link SemiJoinNode} matching the first output symbol of {@code source} against the first
     * output symbol of {@code filteringSource}. Its output symbol is named after {@code id}.
     */
    private static SemiJoinNode semiJoin(String id, PlanNode source, PlanNode filteringSource) {
        return new SemiJoinNode(
                new PlanNodeId(id),
                source,
                filteringSource,
                source.getOutputSymbols().get(0),
                filteringSource.getOutputSymbols().get(0),
                new Symbol(id),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty());
    }

    /** Inner {@link IndexJoinNode} with no equi-join clauses. */
    private static IndexJoinNode indexJoin(String id, PlanNode probe, PlanNode index) {
        return new IndexJoinNode(
                new PlanNodeId(id),
                IndexJoinNode.Type.INNER,
                probe,
                index,
                ImmutableList.of(),
                Optional.empty(),
                Optional.empty());
    }

    /** Inner {@link SpatialJoinNode} with an always-true filter that outputs the left input's symbols. */
    private static SpatialJoinNode spatialJoin(String id, PlanNode left, PlanNode right) {
        return new SpatialJoinNode(
                new PlanNodeId(id),
                SpatialJoinNode.Type.INNER,
                left,
                right,
                left.getOutputSymbols(),
                BooleanLiteral.TRUE_LITERAL,
                Optional.empty(),
                Optional.empty(),
                Optional.empty());
    }

    /** Leaf sub-plan whose fragment is a single one-row, one-column {@link ValuesNode}. */
    private static SubPlan valuesSubPlan(String fragmentId) {
        ValuesNode values = new ValuesNode(
                new PlanNodeId(fragmentId + "Values"),
                ImmutableList.of(new Symbol("column")),
                ImmutableList.of(new Row(ImmutableList.of(new StringLiteral("foo")))));
        return createSubPlan(fragmentId, values, ImmutableList.of());
    }

    /**
     * Wraps {@code plan} in a fragment with id {@code fragmentId} and attaches {@code children}.
     * Only the plan's first output symbol is typed (as VARCHAR) and used in the output
     * partitioning scheme.
     */
    private static SubPlan createSubPlan(String fragmentId, PlanNode plan, List<SubPlan> children) {
        Symbol symbol = plan.getOutputSymbols().get(0);
        PlanFragment fragment = new PlanFragment(
                new PlanFragmentId(fragmentId),
                plan,
                ImmutableMap.of(symbol, VarcharType.VARCHAR),
                SystemPartitioningHandle.SOURCE_DISTRIBUTION,
                Optional.empty(),
                ImmutableList.of(new PlanNodeId("plan")),
                new PartitioningScheme(
                        Partitioning.create(SystemPartitioningHandle.SINGLE_DISTRIBUTION, ImmutableList.of()),
                        ImmutableList.of(symbol)),
                StatsAndCosts.empty(),
                ImmutableList.of(),
                Optional.empty());
        return new SubPlan(fragment, children);
    }
}
