package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.spi.Plugin;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.Type;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionCall;
import io.trino.sql.tree.QualifiedName;
import java.util.List;
import java.util.Optional;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/TestImplementFilteredAggregations.class */
/**
 * Rule tests for {@link ImplementFilteredAggregations}.
 *
 * <p>Each test builds an aggregation in which at least one aggregate call carries a filter
 * symbol. Each test then checks the plan the rule produces:
 * <ul>
 *   <li>The filter is expected to become the aggregation mask, either directly or ANDed with an
 *       existing mask.</li>
 *   <li>A filter node over an identity projection of the original values is expected below the
 *       aggregation.</li>
 * </ul>
 */
public class TestImplementFilteredAggregations extends BaseRuleTest {
    public TestImplementFilteredAggregations() {
        // Passes an empty plugin array to the base rule tester.
        super(new Plugin[0]);
    }

    /**
     * Grouped aggregation with a single filtered {@code sum}.
     *
     * <p>Expects the filter symbol to be used as the mask. The filter node below is trivially
     * {@code true}.
     */
    @Test
    public void testFilterToMask() {
        tester().assertThat(newRule()).on(planBuilder -> {
            Symbol a = planBuilder.symbol("a");
            Symbol g = planBuilder.symbol("g");
            Symbol filter = planBuilder.symbol("filter", BooleanType.BOOLEAN);
            return planBuilder.aggregation(builder -> {
                builder.singleGroupingSet(g)
                        .addAggregation(
                                planBuilder.symbol("sum"),
                                functionWithFilter("sum", a, Optional.of(filter)),
                                ImmutableList.of(BigintType.BIGINT))
                        .source(planBuilder.values(a, g, filter));
            });
        }).matches(PlanMatchPattern.aggregation(
                PlanMatchPattern.singleGroupingSet("g"),
                ImmutableMap.of(
                        Optional.of("sum"), PlanMatchPattern.functionCall("sum", ImmutableList.of("a"))),
                ImmutableList.of(),
                ImmutableList.of("filter"),
                Optional.empty(),
                AggregationNode.Step.SINGLE,
                PlanMatchPattern.filter(
                        "true",
                        PlanMatchPattern.project(
                                ImmutableMap.of(
                                        "a", PlanMatchPattern.expression("a"),
                                        "g", PlanMatchPattern.expression("g"),
                                        "filter", PlanMatchPattern.expression("filter")),
                                PlanMatchPattern.values("a", "g", "filter")))));
    }

    /**
     * Filtered {@code sum} that already has a mask.
     *
     * <p>Expects a new mask symbol, projected as {@code mask AND filter}, to replace the
     * original mask.
     */
    @Test
    public void testCombineMaskAndFilter() {
        tester().assertThat(newRule()).on(planBuilder -> {
            Symbol a = planBuilder.symbol("a");
            Symbol g = planBuilder.symbol("g");
            Symbol mask = planBuilder.symbol("mask", BooleanType.BOOLEAN);
            Symbol filter = planBuilder.symbol("filter", BooleanType.BOOLEAN);
            return planBuilder.aggregation(builder -> {
                builder.singleGroupingSet(g)
                        .addAggregation(
                                planBuilder.symbol("sum"),
                                functionWithFilter("sum", a, Optional.of(filter)),
                                // Passed without a cast so inference targets List<Type>. A cast
                                // would infer ImmutableList<BigintType>, which cannot be cast to
                                // List<Type>.
                                ImmutableList.of(BigintType.BIGINT),
                                mask)
                        .source(planBuilder.values(a, g, mask, filter));
            });
        }).matches(PlanMatchPattern.aggregation(
                PlanMatchPattern.singleGroupingSet("g"),
                ImmutableMap.of(
                        Optional.of("sum"), PlanMatchPattern.functionCall("sum", ImmutableList.of("a"))),
                ImmutableList.of(),
                ImmutableList.of("new_mask"),
                Optional.empty(),
                AggregationNode.Step.SINGLE,
                PlanMatchPattern.filter(
                        "true",
                        PlanMatchPattern.project(
                                ImmutableMap.of(
                                        "a", PlanMatchPattern.expression("a"),
                                        "g", PlanMatchPattern.expression("g"),
                                        "mask", PlanMatchPattern.expression("mask"),
                                        "filter", PlanMatchPattern.expression("filter"),
                                        "new_mask", PlanMatchPattern.expression("mask AND filter")),
                                PlanMatchPattern.values("a", "g", "mask", "filter")))));
    }

    /**
     * Global aggregation whose only aggregate is filtered.
     *
     * <p>Expects the filter symbol both as the mask and as the predicate of the filter node
     * below the aggregation.
     */
    @Test
    public void testWithFilterPushdown() {
        tester().assertThat(newRule()).on(planBuilder -> {
            Symbol a = planBuilder.symbol("a");
            Symbol g = planBuilder.symbol("g");
            Symbol filter = planBuilder.symbol("filter", BooleanType.BOOLEAN);
            return planBuilder.aggregation(builder -> {
                builder.globalGrouping()
                        .addAggregation(
                                planBuilder.symbol("sum"),
                                functionWithFilter("sum", a, Optional.of(filter)),
                                ImmutableList.of(BigintType.BIGINT))
                        .source(planBuilder.values(a, g, filter));
            });
        }).matches(PlanMatchPattern.aggregation(
                PlanMatchPattern.globalAggregation(),
                ImmutableMap.of(
                        Optional.of("sum"), PlanMatchPattern.functionCall("sum", ImmutableList.of("a"))),
                ImmutableList.of(),
                ImmutableList.of("filter"),
                Optional.empty(),
                AggregationNode.Step.SINGLE,
                PlanMatchPattern.filter(
                        "filter",
                        PlanMatchPattern.project(
                                ImmutableMap.of(
                                        "a", PlanMatchPattern.expression("a"),
                                        "g", PlanMatchPattern.expression("g"),
                                        "filter", PlanMatchPattern.expression("filter")),
                                PlanMatchPattern.values("a", "g", "filter")))));
    }

    /**
     * Global aggregation mixing a filtered {@code sum} with an unfiltered {@code avg}.
     *
     * <p>Unlike {@link #testWithFilterPushdown()}, the expected filter node below the
     * aggregation is trivially {@code true}.
     */
    @Test
    public void testWithMultipleAggregations() {
        tester().assertThat(newRule()).on(planBuilder -> {
            Symbol a = planBuilder.symbol("a");
            Symbol g = planBuilder.symbol("g");
            Symbol filter = planBuilder.symbol("filter", BooleanType.BOOLEAN);
            return planBuilder.aggregation(builder -> {
                builder.globalGrouping()
                        .addAggregation(
                                planBuilder.symbol("sum"),
                                functionWithFilter("sum", a, Optional.of(filter)),
                                ImmutableList.of(BigintType.BIGINT))
                        .addAggregation(
                                planBuilder.symbol("avg"),
                                functionWithFilter("avg", a, Optional.empty()),
                                ImmutableList.of(BigintType.BIGINT))
                        .source(planBuilder.values(a, g, filter));
            });
        }).matches(PlanMatchPattern.aggregation(
                PlanMatchPattern.globalAggregation(),
                ImmutableMap.of(
                        Optional.of("sum"), PlanMatchPattern.functionCall("sum", ImmutableList.of("a")),
                        Optional.of("avg"), PlanMatchPattern.functionCall("avg", ImmutableList.of("a"))),
                ImmutableList.of(),
                ImmutableList.of("filter"),
                Optional.empty(),
                AggregationNode.Step.SINGLE,
                PlanMatchPattern.filter(
                        "true",
                        PlanMatchPattern.project(
                                ImmutableMap.of(
                                        "a", PlanMatchPattern.expression("a"),
                                        "g", PlanMatchPattern.expression("g"),
                                        "filter", PlanMatchPattern.expression("filter")),
                                PlanMatchPattern.values("a", "g", "filter")))));
    }

    /** Creates the rule under test, constructed with the tester's metadata. */
    private ImplementFilteredAggregations newRule() {
        return new ImplementFilteredAggregations(tester().getMetadata());
    }

    /**
     * Builds a {@link FunctionCall} to {@code name} with a single argument referencing
     * {@code argument}.
     *
     * <p>When {@code filter} is present, its symbol reference is passed as the call's filter
     * expression, as the helper's name indicates. The remaining optional components of the
     * call are left empty.
     */
    private static FunctionCall functionWithFilter(String name, Symbol argument, Optional<Symbol> filter) {
        return new FunctionCall(
                Optional.empty(),
                QualifiedName.of(name),
                Optional.empty(),
                filter.map(Symbol::toSymbolReference),
                Optional.empty(),
                false,
                Optional.empty(),
                Optional.empty(),
                ImmutableList.of(argument.toSymbolReference()));
    }
}
