package io.trino.sql.planner.optimizations;

import com.google.common.collect.Iterables;
import io.trino.sql.planner.LogicalPlanner;
import io.trino.sql.planner.Plan;
import io.trino.sql.planner.assertions.BasePlanTest;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.TopNNode;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.testng.Assert;
import org.testng.annotations.Test;

/* loaded from: input_file:io/trino/sql/planner/optimizations/TestUnion.class */
/**
 * Planner tests for queries containing {@code UNION ALL}.
 *
 * <p>Each test plans a query up to {@link LogicalPlanner.Stage#OPTIMIZED_AND_VALIDATED} and then
 * inspects the resulting plan tree: how many remote exchanges it contains, of which type, and
 * which nodes sit above the remote gathering exchange.
 */
public class TestUnion extends BasePlanTest {
    public TestUnion() {
    }

    /**
     * Creates the test with the given properties, which are forwarded unchanged to
     * {@link BasePlanTest}.
     *
     * @param map properties handed to the base class (presumably session properties; confirm
     *            against {@code BasePlanTest})
     */
    public TestUnion(Map<String, String> map) {
        super(map);
    }

    /** A plain two-branch union should need exactly one remote exchange, and it must be a GATHER. */
    @Test
    public void testSimpleUnion() {
        Plan plan = plan("SELECT suppkey FROM supplier UNION ALL SELECT nationkey FROM nation", LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, false);

        List<PlanNode> remotes = PlanNodeSearcher.searchFrom(plan.getRoot())
                .where(TestUnion::isRemoteExchange)
                .findAll();

        Assert.assertEquals(remotes.size(), 1, "There should be exactly one RemoteExchange");
        Assert.assertEquals(((ExchangeNode) Iterables.getOnlyElement(remotes)).getType(), ExchangeNode.Type.GATHER);
        assertPlanIsFullyDistributed(plan);
    }

    /**
     * A TopN over a union should produce one remote GATHER exchange and two partial TopN nodes
     * (the query has two union branches).
     */
    @Test
    public void testUnionUnderTopN() {
        Plan plan = plan("SELECT * FROM (   SELECT regionkey FROM nation    UNION ALL    SELECT nationkey FROM nation) t(a) ORDER BY a LIMIT 1", LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, false);

        List<PlanNode> remotes = PlanNodeSearcher.searchFrom(plan.getRoot())
                .where(TestUnion::isRemoteExchange)
                .findAll();

        Assert.assertEquals(remotes.size(), 1, "There should be exactly one RemoteExchange");
        Assert.assertEquals(((ExchangeNode) Iterables.getOnlyElement(remotes)).getType(), ExchangeNode.Type.GATHER);

        int partialTopNCount = PlanNodeSearcher.searchFrom(plan.getRoot())
                .where(TestUnion::isPartialTopN)
                .count();
        Assert.assertEquals(partialTopNCount, 2, "There should be exactly two partial TopN nodes");
        assertPlanIsFullyDistributed(plan);
    }

    /**
     * A count over a union of a grouped aggregation and a nested union should yield exactly two
     * remote exchanges; in search order the first is a GATHER and the second a REPARTITION.
     */
    @Test
    public void testUnionOverSingleNodeAggregationAndUnion() {
        Plan plan = plan("SELECT count(*) FROM (SELECT 1 FROM nation GROUP BY regionkey UNION ALL (   SELECT 1 FROM nation    UNION ALL    SELECT 1 FROM nation))", LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, false);

        List<PlanNode> remotes = PlanNodeSearcher.searchFrom(plan.getRoot())
                .where(TestUnion::isRemoteExchange)
                .findAll();

        Assert.assertEquals(remotes.size(), 2, "There should be exactly two RemoteExchanges");
        // The assertions below depend on the order in which the searcher reports matches.
        Assert.assertEquals(((ExchangeNode) remotes.get(0)).getType(), ExchangeNode.Type.GATHER);
        Assert.assertEquals(((ExchangeNode) remotes.get(1)).getType(), ExchangeNode.Type.REPARTITION);
    }

    /** Grouped aggregation over a union: at most one aggregation per fragment, plan fully distributed. */
    @Test
    public void testPartialAggregationsWithUnion() {
        Plan plan = plan("SELECT orderstatus, sum(orderkey) FROM (SELECT orderkey, orderstatus FROM orders UNION ALL SELECT orderkey, orderstatus FROM orders) x GROUP BY (orderstatus)", LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, false);
        assertAtMostOneAggregationBetweenRemoteExchanges(plan);
        assertPlanIsFullyDistributed(plan);
    }

    /** Same as {@link #testPartialAggregationsWithUnion()} but grouping with {@code ROLLUP}. */
    @Test
    public void testPartialRollupAggregationsWithUnion() {
        Plan plan = plan("SELECT orderstatus, sum(orderkey) FROM (SELECT orderkey, orderstatus FROM orders UNION ALL SELECT orderkey, orderstatus FROM orders) x GROUP BY ROLLUP (orderstatus)", LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, false);
        assertAtMostOneAggregationBetweenRemoteExchanges(plan);
        assertPlanIsFullyDistributed(plan);
    }

    /** Aggregation over a union of a table scan and a VALUES list: at most one aggregation per fragment. */
    @Test
    public void testAggregationWithUnionAndValues() {
        Plan plan = plan("SELECT regionkey, count(*) FROM (SELECT regionkey FROM nation UNION ALL SELECT * FROM (VALUES 2, 100) t(regionkey)) GROUP BY regionkey", LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, false);
        assertAtMostOneAggregationBetweenRemoteExchanges(plan);
    }

    /** A join whose left input is a union must still be planned fully distributed. */
    @Test
    public void testUnionOnProbeSide() {
        Plan plan = plan("SELECT * FROM (SELECT * FROM nation UNION ALL SELECT * from nation) n, region r WHERE n.regionkey=r.regionkey", LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, false);
        assertPlanIsFullyDistributed(plan);
    }

    /**
     * Asserts that no node which should run distributed is found when searching from the plan
     * root with recursion restricted to nodes that are not remote GATHER exchanges, and that the
     * plan contains exactly one remote GATHER exchange.
     *
     * <p>A plan with no remote GATHER exchange at all passes without further checks.
     *
     * @param plan the plan to inspect
     */
    private void assertPlanIsFullyDistributed(Plan plan) {
        int gatherCount = PlanNodeSearcher.searchFrom(plan.getRoot())
                .where(TestUnion::isRemoteGatheringExchange)
                .findAll()
                .size();

        if (gatherCount == 0) {
            // Nothing is gathered, so there is no single-node section to validate.
            return;
        }

        List<PlanNode> nodesAboveGather = PlanNodeSearcher.searchFrom(plan.getRoot())
                .recurseOnlyWhen(TestUnion::isNotRemoteGatheringExchange)
                .findAll();

        Assert.assertTrue(
                nodesAboveGather.stream().noneMatch(this::shouldBeDistributed),
                "There is a node that should be distributed between output and first REMOTE GATHER ExchangeNode");
        Assert.assertEquals(gatherCount, 1, "Only a single REMOTE GATHER was expected");
    }

    /**
     * Tells whether a node is of a kind this test expects to run distributed: any join, any
     * aggregation, or a TopN in its PARTIAL step.
     *
     * @param planNode the node to classify
     * @return {@code true} for joins, aggregations and partial TopN nodes
     */
    private boolean shouldBeDistributed(PlanNode planNode) {
        return planNode instanceof JoinNode
                || planNode instanceof AggregationNode
                || isPartialTopN(planNode);
    }

    /**
     * Asserts that, below each source of every remote exchange, a search that recurses only
     * through nodes which are not remote exchanges finds at most one {@link AggregationNode}.
     *
     * @param plan the plan to inspect
     */
    private static void assertAtMostOneAggregationBetweenRemoteExchanges(Plan plan) {
        // Each source of a remote exchange is treated as the root of one fragment to check.
        List<PlanNode> fragments = PlanNodeSearcher.searchFrom(plan.getRoot())
                .where(TestUnion::isRemoteExchange)
                .findAll()
                .stream()
                .flatMap(exchangeNode -> exchangeNode.getSources().stream())
                .collect(Collectors.toList());

        for (PlanNode fragment : fragments) {
            List<PlanNode> aggregations = PlanNodeSearcher.searchFrom(fragment)
                    .where(AggregationNode.class::isInstance)
                    .recurseOnlyWhen(TestUnion::isNotRemoteExchange)
                    .findAll();

            Assert.assertFalse(aggregations.size() > 1, "More than a single AggregationNode between remote exchanges");
        }
    }

    /** Returns {@code true} when the node is a {@link TopNNode} whose step is PARTIAL. */
    private static boolean isPartialTopN(PlanNode planNode) {
        return (planNode instanceof TopNNode) && ((TopNNode) planNode).getStep() == TopNNode.Step.PARTIAL;
    }

    /** Negation of {@link #isRemoteGatheringExchange(PlanNode)}, usable as a method reference. */
    private static boolean isNotRemoteGatheringExchange(PlanNode planNode) {
        return !isRemoteGatheringExchange(planNode);
    }

    /** Returns {@code true} when the node is a remote exchange of type GATHER. */
    private static boolean isRemoteGatheringExchange(PlanNode planNode) {
        return isRemoteExchange(planNode) && ((ExchangeNode) planNode).getType() == ExchangeNode.Type.GATHER;
    }

    /** Negation of {@link #isRemoteExchange(PlanNode)}, usable as a method reference. */
    private static boolean isNotRemoteExchange(PlanNode planNode) {
        return !isRemoteExchange(planNode);
    }

    /** Returns {@code true} when the node is an {@link ExchangeNode} with REMOTE scope. */
    private static boolean isRemoteExchange(PlanNode planNode) {
        return (planNode instanceof ExchangeNode) && ((ExchangeNode) planNode).getScope() == ExchangeNode.Scope.REMOTE;
    }
}
