package io.trino.sql.planner.sanity;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.execution.warnings.WarningCollector;
import io.trino.metadata.TableHandle;
import io.trino.plugin.tpch.TpchColumnHandle;
import io.trino.plugin.tpch.TpchTableHandle;
import io.trino.plugin.tpch.TpchTransactionHandle;
import io.trino.spi.type.BigintType;
import io.trino.sql.PlannerContext;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.assertions.BasePlanTest;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.PlanNode;
import java.util.function.Function;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

/**
 * Tests for the {@link ValidateStreamingAggregations} plan sanity checker.
 *
 * <p>Each test builds a small plan by hand. The plan is an aggregation over a
 * scan of the TPCH {@code nation} table. The test then runs the checker on it
 * inside a transaction and checks whether validation passes or fails.
 */
public class TestValidateStreamingAggregations extends BasePlanTest {
    private PlannerContext plannerContext;
    private TypeAnalyzer typeAnalyzer;
    private final PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator();
    private TableHandle nationTableHandle;

    /**
     * Sets up the planner context, the type analyzer and the handle for the
     * TPCH {@code sf1.nation} table. All tests share these objects.
     */
    @BeforeAll
    public void setup() {
        this.plannerContext = getQueryRunner().getPlannerContext();
        this.typeAnalyzer = TypeAnalyzer.createTestingTypeAnalyzer(this.plannerContext);
        this.nationTableHandle = new TableHandle(getCurrentCatalogHandle(), new TpchTableHandle("sf1", "nation", 1.0d), TpchTransactionHandle.INSTANCE);
    }

    /**
     * Two plans that the checker is expected to accept:
     * <ol>
     *   <li>a SINGLE aggregation on {@code nationkey} with no pre-grouped symbols
     *       (not a streaming aggregation);</li>
     *   <li>a streaming aggregation whose pre-grouped symbols are {@code unique}
     *       and {@code nationkey}, over an {@code assignUniqueId} node that
     *       produces {@code unique}.</li>
     * </ol>
     */
    @Test
    public void testValidateSuccessful() {
        validatePlan(planBuilder -> planBuilder.aggregation(aggregationBuilder -> aggregationBuilder
                .step(AggregationNode.Step.SINGLE)
                .singleGroupingSet(planBuilder.symbol("nationkey"))
                .source(nationTableScan(planBuilder))));

        validatePlan(planBuilder -> planBuilder.aggregation(aggregationBuilder -> aggregationBuilder
                .step(AggregationNode.Step.SINGLE)
                .singleGroupingSet(planBuilder.symbol("unique"), planBuilder.symbol("nationkey"))
                .preGroupedSymbols(planBuilder.symbol("unique"), planBuilder.symbol("nationkey"))
                .source(planBuilder.assignUniqueId(planBuilder.symbol("unique"), nationTableScan(planBuilder)))));
    }

    /**
     * A streaming aggregation that claims its input is pre-grouped on
     * {@code nationkey`, directly over a plain table scan. The checker is
     * expected to reject this with an {@link IllegalArgumentException}.
     */
    @Test
    public void testValidateFailed() {
        Assertions.assertThatThrownBy(() -> validatePlan(planBuilder -> planBuilder.aggregation(aggregationBuilder -> aggregationBuilder
                        .step(AggregationNode.Step.SINGLE)
                        .singleGroupingSet(planBuilder.symbol("nationkey"))
                        .preGroupedSymbols(planBuilder.symbol("nationkey"))
                        .source(nationTableScan(planBuilder)))))
                .isInstanceOf(IllegalArgumentException.class)
                .hasMessage("Streaming aggregation with input not grouped on the grouping keys");
    }

    /**
     * Builds a scan of the {@code nation} table that outputs a single BIGINT
     * column {@code nationkey}.
     *
     * <p>All three test plans used to inline this same construction.
     *
     * @param planBuilder builder used to create the symbol and the scan node
     * @return the table-scan node
     */
    private PlanNode nationTableScan(PlanBuilder planBuilder) {
        return planBuilder.tableScan(
                this.nationTableHandle,
                ImmutableList.of(planBuilder.symbol("nationkey", BigintType.BIGINT)),
                ImmutableMap.of(planBuilder.symbol("nationkey", BigintType.BIGINT), new TpchColumnHandle("nationkey", BigintType.BIGINT)));
    }

    /**
     * Builds a plan with {@code function} and runs
     * {@link ValidateStreamingAggregations} on it inside a query-runner
     * transaction.
     *
     * @param function creates the plan under test from a fresh {@link PlanBuilder}
     */
    private void validatePlan(Function<PlanBuilder, PlanNode> function) {
        getQueryRunner().inTransaction(session -> {
            PlanBuilder planBuilder = new PlanBuilder(this.idAllocator, this.plannerContext.getMetadata(), session);
            PlanNode planNode = function.apply(planBuilder);
            TypeProvider types = planBuilder.getTypes();
            // Looks up the session catalog's handle; the result is discarded.
            // NOTE(review): presumably done so the catalog is registered in the
            // current transaction before validation runs — confirm.
            session.getCatalog().ifPresent(catalog -> this.plannerContext.getMetadata().getCatalogHandle(session, catalog));
            new ValidateStreamingAggregations().validate(planNode, session, this.plannerContext, this.typeAnalyzer, types, WarningCollector.NOOP);
            return null;
        });
    }
}
