package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.spi.Plugin;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.tree.Expression;
import java.util.List;
import java.util.Optional;
import java.util.function.Predicate;
import org.junit.jupiter.api.Test;

/**
 * Rule tests for {@link PruneSemiJoinFilteringSourceColumns}.
 *
 * <p>Each test builds a semi join whose probe side is a {@code values(leftKey)} node and whose
 * filtering side is a zero-row {@code values} node drawn from {@code rightKey}, {@code rightKeyHash}
 * and {@code rightValue}. The tests check whether the rule prunes the filtering-side columns that
 * the semi join does not use.
 */
public class TestPruneSemiJoinFilteringSourceColumns extends BaseRuleTest {
    public TestPruneSemiJoinFilteringSourceColumns() {
        super(new Plugin[0]);
    }

    /**
     * When the filtering source also produces {@code rightValue}, the rule should insert a strict
     * projection that keeps only {@code rightKey} and {@code rightKeyHash}.
     */
    @Test
    public void testNotAllColumnsReferenced() {
        PlanMatchPattern expected = PlanMatchPattern.semiJoin(
                "leftKey",
                "rightKey",
                "match",
                PlanMatchPattern.values("leftKey"),
                PlanMatchPattern.strictProject(
                        ImmutableMap.of(
                                "rightKey", PlanMatchPattern.expression("rightKey"),
                                "rightKeyHash", PlanMatchPattern.expression("rightKeyHash")),
                        PlanMatchPattern.values("rightKey", "rightKeyHash", "rightValue")));

        tester().assertThat(new PruneSemiJoinFilteringSourceColumns())
                .on(p -> buildSemiJoin(p, candidate -> true))
                .matches(expected);
    }

    /**
     * When the filtering source already lacks {@code rightValue}, every remaining column is used,
     * so the rule must not fire.
     */
    @Test
    public void testAllColumnsNeeded() {
        tester().assertThat(new PruneSemiJoinFilteringSourceColumns())
                .on(p -> buildSemiJoin(p, candidate -> !candidate.getName().equals("rightValue")))
                .doesNotFire();
    }

    /**
     * Builds the semi join used by the tests.
     *
     * <p>The semi join matches {@code leftKey} (probe side) against {@code rightKey} (filtering side)
     * and writes its result to {@code match}. The fourth {@code semiJoin} argument is
     * {@code Optional.empty()} and the fifth is {@code rightKeyHash}. Presumably these are the
     * probe-side and filtering-side hash symbols respectively; confirm against {@code PlanBuilder}.
     *
     * @param planBuilder builder used to allocate symbols and plan nodes
     * @param predicate decides which of {@code rightKey}, {@code rightKeyHash}, {@code rightValue}
     *     (in that order) become output columns of the zero-row filtering-side {@code values} node
     * @return the constructed semi join plan node
     */
    public static PlanNode buildSemiJoin(PlanBuilder planBuilder, Predicate<Symbol> predicate) {
        Symbol match = planBuilder.symbol("match");
        Symbol leftKey = planBuilder.symbol("leftKey");
        Symbol rightKey = planBuilder.symbol("rightKey");
        Symbol rightKeyHash = planBuilder.symbol("rightKeyHash");

        // Probe side. It is created before "rightValue" is allocated, which keeps the same call
        // order as a single left-to-right evaluated semiJoin(...) argument list.
        PlanNode probeSource = planBuilder.values(leftKey);

        Symbol rightValue = planBuilder.symbol("rightValue");
        List<Symbol> filteringColumns = ImmutableList.of(rightKey, rightKeyHash, rightValue).stream()
                .filter(predicate)
                .collect(ImmutableList.toImmutableList());
        List<List<Expression>> noRows = ImmutableList.of();
        PlanNode filteringSource = planBuilder.values(filteringColumns, noRows);

        return planBuilder.semiJoin(
                leftKey,
                rightKey,
                match,
                Optional.empty(),
                Optional.of(rightKeyHash),
                probeSource,
                filteringSource);
    }
}
