package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.Plugin;
import io.trino.spi.connector.SortOrder;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.planner.OrderingScheme;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.assertions.TopNRankingSymbolMatcher;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.DataOrganizationSpecification;
import io.trino.sql.planner.plan.WindowNode;
import java.util.Optional;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/TestPushdownFilterIntoWindow.class */
/**
 * Rule tests for {@link PushdownFilterIntoWindow}.
 *
 * <p>Every scenario is run twice, once with the built-in {@code row_number} function and once with
 * {@code rank}. Each input plan is a filter on top of a window node that is partitioned and ordered
 * by symbol {@code a} and reads from a values node producing {@code a}.
 */
public class TestPushdownFilterIntoWindow extends BaseRuleTest {
    public TestPushdownFilterIntoWindow() {
        super(new Plugin[0]);
    }

    @Test
    public void testEliminateFilter() {
        assertEliminateFilter("row_number");
        assertEliminateFilter("rank");
    }

    /**
     * Checks that a filter consisting solely of an upper bound on the ranking symbol
     * ({@code rank_1 < 100}) yields a plan matching a non-partial TopNRanking pattern with a
     * maximum ranking of 99 placed directly over the values node.
     */
    private void assertEliminateFilter(String functionName) {
        ResolvedFunction rankingFunction = resolveRankingFunction(functionName);
        tester().assertThat(new PushdownFilterIntoWindow(tester().getPlannerContext()))
                .on(p -> {
                    Symbol rankingSymbol = p.symbol("rank_1");
                    Symbol a = p.symbol("a", BigintType.BIGINT);
                    return p.filter(
                            PlanBuilder.expression("rank_1 < cast(100 as bigint)"),
                            p.window(
                                    partitionedAndOrderedBy(a),
                                    ImmutableMap.of(rankingSymbol, newWindowNodeFunction(rankingFunction, a)),
                                    p.values(p.symbol("a"))));
                })
                .matches(PlanMatchPattern.topNRanking(
                        matcher -> {
                            matcher.maxRankingPerPartition(99).partial(false);
                        },
                        PlanMatchPattern.values("a")));
    }

    @Test
    public void testKeepFilter() {
        assertKeepFilter("row_number");
        assertKeepFilter("rank");
    }

    /**
     * Checks two predicates for which the expected plan still has a filter above the TopNRanking
     * pattern: a lower and upper bound on the ranking symbol (the whole predicate is expected to
     * stay), and an upper bound combined with a condition on {@code a} (only the condition on
     * {@code a} is expected to stay).
     */
    private void assertKeepFilter(String functionName) {
        ResolvedFunction rankingFunction = resolveRankingFunction(functionName);

        // Lower bound and upper bound on the ranking symbol.
        tester().assertThat(new PushdownFilterIntoWindow(tester().getPlannerContext()))
                .on(p -> {
                    Symbol rowNumberSymbol = p.symbol("row_number_1");
                    Symbol a = p.symbol("a", BigintType.BIGINT);
                    return p.filter(
                            PlanBuilder.expression("cast(3 as bigint) < row_number_1 and row_number_1 < cast(100 as bigint)"),
                            p.window(
                                    partitionedAndOrderedBy(a),
                                    ImmutableMap.of(rowNumberSymbol, newWindowNodeFunction(rankingFunction, a)),
                                    p.values(p.symbol("a"))));
                })
                .matches(PlanMatchPattern.filter(
                        "cast(3 as bigint) < row_number_1 and row_number_1 < cast(100 as bigint)",
                        topNRankingAliasedAsRowNumber()));

        // Upper bound on the ranking symbol combined with a predicate on another symbol.
        tester().assertThat(new PushdownFilterIntoWindow(tester().getPlannerContext()))
                .on(p -> {
                    Symbol rowNumberSymbol = p.symbol("row_number_1");
                    Symbol a = p.symbol("a", BigintType.BIGINT);
                    return p.filter(
                            PlanBuilder.expression("row_number_1 < cast(100 as bigint) and a = BIGINT '1'"),
                            p.window(
                                    partitionedAndOrderedBy(a),
                                    ImmutableMap.of(rowNumberSymbol, newWindowNodeFunction(rankingFunction, a)),
                                    p.values(p.symbol("a"))));
                })
                .matches(PlanMatchPattern.filter(
                        "a = BIGINT '1'",
                        topNRankingAliasedAsRowNumber()));
    }

    @Test
    public void testNoUpperBound() {
        assertNoUpperBound("row_number");
        assertNoUpperBound("rank");
    }

    /**
     * Checks that the rule does not fire when the filter only puts a lower bound on the ranking
     * symbol ({@code 3 < row_number_1}).
     */
    private void assertNoUpperBound(String functionName) {
        ResolvedFunction rankingFunction = resolveRankingFunction(functionName);
        tester().assertThat(new PushdownFilterIntoWindow(tester().getPlannerContext()))
                .on(p -> {
                    Symbol rowNumberSymbol = p.symbol("row_number_1");
                    Symbol a = p.symbol("a");
                    return p.filter(
                            PlanBuilder.expression("cast(3 as bigint) < row_number_1"),
                            p.window(
                                    partitionedAndOrderedBy(a),
                                    ImmutableMap.of(rowNumberSymbol, newWindowNodeFunction(rankingFunction, a)),
                                    p.values(a)));
                })
                .doesNotFire();
    }

    /** Resolves the built-in function with the given name and an empty argument type list. */
    private ResolvedFunction resolveRankingFunction(String functionName) {
        return tester().getMetadata().resolveBuiltinFunction(functionName, TypeSignatureProvider.fromTypes(new Type[0]));
    }

    /**
     * Builds a window specification that partitions by {@code key} and orders by the same symbol,
     * ascending with nulls first.
     */
    private static DataOrganizationSpecification partitionedAndOrderedBy(Symbol key) {
        OrderingScheme ordering = new OrderingScheme(
                ImmutableList.of(key),
                ImmutableMap.of(key, SortOrder.ASC_NULLS_FIRST));
        return new DataOrganizationSpecification(ImmutableList.of(key), Optional.of(ordering));
    }

    /**
     * Expected pattern shared by the keep-filter scenarios: a non-partial TopNRanking with a maximum
     * ranking of 99, partitioned and ordered by {@code a}, over values {@code a}, whose ranking
     * output is aliased as {@code row_number_1}.
     */
    private static PlanMatchPattern topNRankingAliasedAsRowNumber() {
        return PlanMatchPattern.topNRanking(
                matcher -> {
                    matcher.partial(false)
                            .maxRankingPerPartition(99)
                            .specification(
                                    ImmutableList.of("a"),
                                    ImmutableList.of("a"),
                                    ImmutableMap.of("a", SortOrder.ASC_NULLS_FIRST));
                },
                PlanMatchPattern.values("a"))
                .withAlias("row_number_1", new TopNRankingSymbolMatcher());
    }

    /**
     * Wraps {@code resolvedFunction} in a window function that takes a reference to {@code symbol}
     * as its single argument, uses the default frame, and passes {@code false} as the last
     * constructor flag.
     */
    private static WindowNode.Function newWindowNodeFunction(ResolvedFunction resolvedFunction, Symbol symbol) {
        return new WindowNode.Function(
                resolvedFunction,
                ImmutableList.of(symbol.toSymbolReference()),
                WindowNode.Frame.DEFAULT_FRAME,
                false);
    }
}
