package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.spi.Plugin;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.CorrelatedJoinNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.ComparisonExpression;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/TestPruneCorrelatedJoinColumns.class */
/**
 * Rule tests for {@link PruneCorrelatedJoinColumns}.
 *
 * <p>Every scenario builds a projection on top of a correlated join and then either
 * checks the shape of the rewritten plan or asserts that the rule does not fire.
 */
public class TestPruneCorrelatedJoinColumns extends BaseRuleTest {
    public TestPruneCorrelatedJoinColumns() {
        super(new Plugin[0]);
    }

    /**
     * Cases where the parent projection needs only one side of the join: the expected
     * plan contains no correlated join at all, just the values node of the side that
     * is still referenced.
     */
    @Test
    public void testRemoveUnusedCorrelatedJoinNode() {
        // Join built without an explicit type or filter; only input column "a" is projected.
        tester().assertThat(new PruneCorrelatedJoinColumns())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol correlation = p.symbol("correlation_symbol");
                    return p.project(
                            Assignments.identity(a),
                            p.correlatedJoin(
                                    ImmutableList.of(correlation),
                                    p.values(a, correlation),
                                    p.values(1, p.symbol("b"))));
                })
                .matches(identityProjection(
                        "a",
                        PlanMatchPattern.values("a", "correlationSymbol")));

        // LEFT join with a filter over the subquery column; still only "a" is projected.
        tester().assertThat(new PruneCorrelatedJoinColumns())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol correlation = p.symbol("correlation_symbol");
                    Symbol b = p.symbol("b");
                    return p.project(
                            Assignments.identity(a),
                            p.correlatedJoin(
                                    ImmutableList.of(correlation),
                                    p.values(a, correlation),
                                    CorrelatedJoinNode.Type.LEFT,
                                    greaterThan(b, correlation),
                                    p.values(1, b)));
                })
                .matches(identityProjection(
                        "a",
                        PlanMatchPattern.values("a", "correlationSymbol")));

        // Empty correlation list, join built without an explicit type; only subquery column "b" is projected.
        tester().assertThat(new PruneCorrelatedJoinColumns())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol b = p.symbol("b");
                    return p.project(
                            Assignments.identity(b),
                            p.correlatedJoin(
                                    ImmutableList.of(),
                                    p.values(1, a),
                                    p.values(b)));
                })
                .matches(identityProjection(
                        "b",
                        PlanMatchPattern.values("b")));

        // Empty correlation list, RIGHT join with a filter; only subquery column "b" is projected.
        tester().assertThat(new PruneCorrelatedJoinColumns())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol b = p.symbol("b");
                    return p.project(
                            Assignments.identity(b),
                            p.correlatedJoin(
                                    ImmutableList.of(),
                                    p.values(1, a),
                                    CorrelatedJoinNode.Type.RIGHT,
                                    greaterThan(b, a),
                                    p.values(b)));
                })
                .matches(identityProjection(
                        "b",
                        PlanMatchPattern.values("b")));
    }

    /**
     * The subquery produces "b" and "c" but nothing references "c": the expected plan
     * has a projection keeping only "b" placed on top of the subquery's filter.
     */
    @Test
    public void testPruneUnreferencedSubquerySymbol() {
        tester().assertThat(new PruneCorrelatedJoinColumns())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol correlation = p.symbol("correlation_symbol");
                    Symbol b = p.symbol("b");
                    return p.project(
                            Assignments.identity(a),
                            p.correlatedJoin(
                                    ImmutableList.of(correlation),
                                    p.values(a, correlation),
                                    CorrelatedJoinNode.Type.LEFT,
                                    greaterThan(b, a),
                                    p.filter(
                                            greaterThan(b, correlation),
                                            p.values(5, b, p.symbol("c")))));
                })
                .matches(identityProjection(
                        "a",
                        PlanMatchPattern.correlatedJoin(
                                ImmutableList.of("correlation_symbol"),
                                PlanMatchPattern.values("a", "correlation_symbol"),
                                identityProjection(
                                        "b",
                                        PlanMatchPattern.node(
                                                FilterNode.class,
                                                PlanMatchPattern.values("b", "c"))))));
    }

    /**
     * Input column "a" is referenced neither by the parent projection nor by the join
     * filter: the expected plan has a projection keeping only the correlation symbol
     * placed on top of the join input.
     */
    @Test
    public void testPruneUnreferencedInputSymbol() {
        tester().assertThat(new PruneCorrelatedJoinColumns())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol correlation = p.symbol("correlation_symbol");
                    Symbol b = p.symbol("b");
                    return p.project(
                            Assignments.identity(b),
                            p.correlatedJoin(
                                    ImmutableList.of(correlation),
                                    p.values(a, correlation),
                                    CorrelatedJoinNode.Type.LEFT,
                                    greaterThan(b, correlation),
                                    p.filter(
                                            greaterThanOrEqual(b, correlation),
                                            p.values(b))));
                })
                .matches(identityProjection(
                        "b",
                        PlanMatchPattern.correlatedJoin(
                                ImmutableList.of("correlation_symbol"),
                                identityProjection(
                                        "correlation_symbol",
                                        PlanMatchPattern.values("a", "correlation_symbol")),
                                PlanMatchPattern.node(
                                        FilterNode.class,
                                        PlanMatchPattern.values("b")))));
    }

    /**
     * The correlation symbol is not projected by the parent, yet it is listed in the
     * join's correlation: the rule is expected not to fire.
     */
    @Test
    public void testDoNotPruneUnreferencedCorrelationSymbol() {
        tester().assertThat(new PruneCorrelatedJoinColumns())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol correlation = p.symbol("correlation_symbol");
                    Symbol b = p.symbol("b");
                    return p.project(
                            Assignments.identity(a, b),
                            p.correlatedJoin(
                                    ImmutableList.of(correlation),
                                    p.values(a, correlation),
                                    CorrelatedJoinNode.Type.LEFT,
                                    BooleanLiteral.TRUE_LITERAL,
                                    p.values(b)));
                })
                .doesNotFire();
    }

    /**
     * The parent projects "a" and "b", and the subquery filter uses the correlation
     * symbol: the rule is expected not to fire.
     */
    @Test
    public void testAllOutputsReferenced() {
        tester().assertThat(new PruneCorrelatedJoinColumns())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol correlation = p.symbol("correlation_symbol");
                    Symbol b = p.symbol("b");
                    return p.project(
                            Assignments.identity(a, b),
                            p.correlatedJoin(
                                    ImmutableList.of(correlation),
                                    p.values(a, correlation),
                                    CorrelatedJoinNode.Type.LEFT,
                                    BooleanLiteral.TRUE_LITERAL,
                                    p.filter(
                                            greaterThanOrEqual(b, correlation),
                                            p.values(b))));
                })
                .doesNotFire();
    }

    /** Builds the comparison {@code left > right} over the two symbols' references. */
    private static ComparisonExpression greaterThan(Symbol left, Symbol right) {
        return new ComparisonExpression(
                ComparisonExpression.Operator.GREATER_THAN,
                left.toSymbolReference(),
                right.toSymbolReference());
    }

    /** Builds the comparison {@code left >= right} over the two symbols' references. */
    private static ComparisonExpression greaterThanOrEqual(Symbol left, Symbol right) {
        return new ComparisonExpression(
                ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL,
                left.toSymbolReference(),
                right.toSymbolReference());
    }

    /** Pattern for a projection that maps {@code alias} to itself on top of {@code source}. */
    private static PlanMatchPattern identityProjection(String alias, PlanMatchPattern source) {
        return PlanMatchPattern.project(
                ImmutableMap.of(alias, PlanMatchPattern.expression(alias)),
                source);
    }
}
