package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.Plugin;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.List;
import java.util.Optional;
import java.util.function.Predicate;
import org.testng.annotations.Test;

/* loaded from: input_file:com/facebook/presto/sql/planner/iterative/rule/TestPruneAggregationSourceColumns.class */
/**
 * Rule tests for {@link PruneAggregationSourceColumns}.
 *
 * <p>The aggregation under test reads {@code input}, {@code key}, {@code keyHash} and {@code mask}.
 * When its VALUES source additionally produces {@code unused}, the rule is expected to put a
 * projection of only the four referenced columns under the aggregation. When the source produces
 * exactly the referenced columns, the rule is expected not to fire.
 */
public class TestPruneAggregationSourceColumns extends BaseRuleTest {
    public TestPruneAggregationSourceColumns() {
        // These plans need no additional plugins.
        super(new Plugin[0]);
    }

    /**
     * The source produces an extra, unreferenced column ({@code unused}); the rule should fire and
     * insert a strict projection that keeps only the columns the aggregation reads.
     */
    @Test
    public void testNotAllInputsReferenced() {
        tester().assertThat(new PruneAggregationSourceColumns())
                .on(planBuilder -> buildAggregation(planBuilder, Predicates.alwaysTrue()))
                .matches(
                        PlanMatchPattern.aggregation(
                                PlanMatchPattern.singleGroupingSet("key"),
                                ImmutableMap.of(
                                        Optional.of("avg"),
                                        PlanMatchPattern.functionCall("avg", ImmutableList.of("input"))),
                                ImmutableMap.of(new Symbol("avg"), new Symbol("mask")),
                                Optional.empty(),
                                AggregationNode.Step.SINGLE,
                                PlanMatchPattern.strictProject(
                                        ImmutableMap.of(
                                                "input", PlanMatchPattern.expression("input"),
                                                "key", PlanMatchPattern.expression("key"),
                                                "keyHash", PlanMatchPattern.expression("keyHash"),
                                                "mask", PlanMatchPattern.expression("mask")),
                                        PlanMatchPattern.values("input", "key", "keyHash", "mask", "unused"))));
    }

    /**
     * The source produces only columns the aggregation references, so there is nothing to prune
     * and the rule must not fire.
     */
    @Test
    public void testAllInputsReferenced() {
        tester().assertThat(new PruneAggregationSourceColumns())
                .on(planBuilder -> buildAggregation(
                        planBuilder,
                        variable -> !variable.getName().equals("unused")))
                .doesNotFire();
    }

    /**
     * Builds {@code avg(input)} grouped by {@code key}, with aggregation mask {@code mask} and hash
     * variable {@code keyHash}, over a VALUES node with no rows.
     *
     * @param planBuilder builder used to create the variables and plan nodes
     * @param sourceVariableFilter selects which of {@code input}, {@code key}, {@code keyHash},
     *     {@code mask} and {@code unused} the VALUES source produces
     * @return the aggregation node to run the rule against
     */
    private AggregationNode buildAggregation(
            PlanBuilder planBuilder, Predicate<VariableReferenceExpression> sourceVariableFilter) {
        VariableReferenceExpression avg = planBuilder.variable("avg");
        VariableReferenceExpression input = planBuilder.variable("input");
        VariableReferenceExpression key = planBuilder.variable("key");
        VariableReferenceExpression keyHash = planBuilder.variable("keyHash");
        VariableReferenceExpression mask = planBuilder.variable("mask");
        VariableReferenceExpression unused = planBuilder.variable("unused");

        // The aggregation never reads "unused"; whether the source keeps it decides whether the
        // rule has anything to prune.
        List<VariableReferenceExpression> sourceVariables = ImmutableList.of(input, key, keyHash, mask, unused)
                .stream()
                .filter(sourceVariableFilter)
                .collect(ImmutableList.toImmutableList());
        // Typed locals replace the decompiler's unchecked casts at the call sites below.
        List<Type> inputTypes = ImmutableList.of(BigintType.BIGINT);
        List<List<RowExpression>> rows = ImmutableList.of();

        return planBuilder.aggregation(aggregationBuilder -> aggregationBuilder
                .singleGroupingSet(key)
                .addAggregation(avg, PlanBuilder.expression("avg(input)"), inputTypes, mask)
                .hashVariable(keyHash)
                .source(planBuilder.values(sourceVariables, rows)));
    }
}
