package io.trino.sql.planner;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.SessionTestUtils;
import io.trino.execution.warnings.WarningCollector;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DateType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.analyzer.TypeSignatureTranslator;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.DataOrganizationSpecification;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.UnionNode;
import io.trino.sql.planner.plan.WindowNode;
import io.trino.sql.planner.sanity.TypeValidator;
import io.trino.sql.tree.Cast;
import io.trino.sql.tree.FrameBound;
import io.trino.sql.tree.WindowFrame;
import io.trino.testing.TestingHandles;
import io.trino.testing.TestingMetadata;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

import java.util.Optional;
import java.util.UUID;

/**
 * Tests for {@link TypeValidator}.
 *
 * <p>Each test builds a small plan over {@code baseTableScan} (columns {@code a bigint},
 * {@code b integer}, {@code c double}, {@code d date}, {@code e varchar(3)}). It then checks that the
 * validator either accepts the plan or rejects it with an {@link IllegalArgumentException}. The
 * rejection message names the symbol whose declared type disagrees with the type of the value
 * assigned to it.
 *
 * <p>All state is held in instance fields. This assumes JUnit Jupiter's default per-method test
 * instance lifecycle, so that every test starts with a fresh {@link SymbolAllocator} — TODO confirm
 * no junit-platform.properties override.
 */
public class TestTypeValidator {
    private static final TypeValidator TYPE_VALIDATOR = new TypeValidator();

    private final TestingFunctionResolution functionResolution = new TestingFunctionResolution();
    private final SymbolAllocator symbolAllocator = new SymbolAllocator();
    private final Symbol columnA = symbolAllocator.newSymbol("a", BigintType.BIGINT);
    private final Symbol columnB = symbolAllocator.newSymbol("b", IntegerType.INTEGER);
    private final Symbol columnC = symbolAllocator.newSymbol("c", DoubleType.DOUBLE);
    private final Symbol columnD = symbolAllocator.newSymbol("d", DateType.DATE);
    private final Symbol columnE = symbolAllocator.newSymbol("e", VarcharType.createVarcharType(3));
    // Must stay below the column fields: field initializers run in textual order and
    // createBaseTableScan() reads columnA..columnE.
    private final TableScanNode baseTableScan = createBaseTableScan();

    @Test
    public void testValidProject() {
        Cast castB = new Cast(columnB.toSymbolReference(), TypeSignatureTranslator.toSqlType(BigintType.BIGINT));
        Cast castC = new Cast(columnC.toSymbolReference(), TypeSignatureTranslator.toSqlType(BigintType.BIGINT));
        Assignments assignments = Assignments.builder()
                .put(symbolAllocator.newSymbol(castB, BigintType.BIGINT), castB)
                .put(symbolAllocator.newSymbol(castC, BigintType.BIGINT), castC)
                .build();
        assertTypesValid(new ProjectNode(newId(), baseTableScan, assignments));
    }

    @Test
    public void testValidUnion() {
        Symbol output = symbolAllocator.newSymbol("output", DateType.DATE);
        ImmutableListMultimap<Symbol, Symbol> mappings = ImmutableListMultimap.<Symbol, Symbol>builder()
                .put(output, columnD)
                .put(output, columnD)
                .build();
        assertTypesValid(new UnionNode(
                newId(),
                ImmutableList.of(baseTableScan, baseTableScan),
                mappings,
                ImmutableList.copyOf(mappings.keySet())));
    }

    @Test
    public void testValidWindow() {
        assertTypesValid(sumWindow(symbolAllocator.newSymbol("sum", DoubleType.DOUBLE), columnC));
    }

    @Test
    public void testValidAggregation() {
        assertTypesValid(sumAggregation(symbolAllocator.newSymbol("sum", DoubleType.DOUBLE), columnC));
    }

    @Test
    public void testValidTypeOnlyCoercion() {
        Cast castB = new Cast(columnB.toSymbolReference(), TypeSignatureTranslator.toSqlType(BigintType.BIGINT));
        Assignments assignments = Assignments.builder()
                .put(symbolAllocator.newSymbol(castB, BigintType.BIGINT), castB)
                // columnE is varchar(3) but the output symbol is unbounded varchar: accepted without an explicit cast
                .put(symbolAllocator.newSymbol(columnE.toSymbolReference(), VarcharType.VARCHAR), columnE.toSymbolReference())
                .build();
        assertTypesValid(new ProjectNode(newId(), baseTableScan, assignments));
    }

    @Test
    public void testInvalidProject() {
        Cast castBToInteger = new Cast(columnB.toSymbolReference(), TypeSignatureTranslator.toSqlType(IntegerType.INTEGER));
        Cast castAToInteger = new Cast(columnA.toSymbolReference(), TypeSignatureTranslator.toSqlType(IntegerType.INTEGER));
        Assignments assignments = Assignments.builder()
                // declared bigint, but the assigned expression produces integer
                .put(symbolAllocator.newSymbol(castBToInteger, BigintType.BIGINT), castBToInteger)
                .put(symbolAllocator.newSymbol(castBToInteger, IntegerType.INTEGER), castAToInteger)
                .build();
        assertTypesInvalid(
                new ProjectNode(newId(), baseTableScan, assignments),
                "type of symbol 'expr(_[0-9]+)?' is expected to be bigint, but the actual type is integer");
    }

    @Test
    public void testInvalidAggregationFunctionCall() {
        // sum(double) applied to the bigint column a
        assertTypesInvalid(
                sumAggregation(symbolAllocator.newSymbol("sum", DoubleType.DOUBLE), columnA),
                "type of symbol 'sum(_[0-9]+)?' is expected to be double, but the actual type is bigint");
    }

    @Test
    public void testInvalidAggregationFunctionSignature() {
        // output symbol declared bigint, but sum(double) returns double
        assertTypesInvalid(
                sumAggregation(symbolAllocator.newSymbol("sum", BigintType.BIGINT), columnC),
                "type of symbol 'sum(_[0-9]+)?' is expected to be bigint, but the actual type is double");
    }

    @Test
    public void testInvalidWindowFunctionCall() {
        // sum(double) applied to the bigint column a
        assertTypesInvalid(
                sumWindow(symbolAllocator.newSymbol("sum", DoubleType.DOUBLE), columnA),
                "type of symbol 'sum(_[0-9]+)?' is expected to be double, but the actual type is bigint");
    }

    @Test
    public void testInvalidWindowFunctionSignature() {
        // output symbol declared bigint, but sum(double) returns double
        assertTypesInvalid(
                sumWindow(symbolAllocator.newSymbol("sum", BigintType.BIGINT), columnC),
                "type of symbol 'sum(_[0-9]+)?' is expected to be bigint, but the actual type is double");
    }

    @Test
    public void testInvalidUnion() {
        Symbol output = symbolAllocator.newSymbol("output", DateType.DATE);
        // the second source feeds the bigint column a into a date output
        ImmutableListMultimap<Symbol, Symbol> mappings = ImmutableListMultimap.<Symbol, Symbol>builder()
                .put(output, columnD)
                .put(output, columnA)
                .build();
        assertTypesInvalid(
                new UnionNode(
                        newId(),
                        ImmutableList.of(baseTableScan, baseTableScan),
                        mappings,
                        ImmutableList.copyOf(mappings.keySet())),
                "type of symbol 'output(_[0-9]+)?' is expected to be date, but the actual type is bigint");
    }

    /**
     * Builds the table scan every test plan reads from, exposing columnA..columnE.
     * The assignment map is built once and reused for both the output list and the assignments.
     */
    private TableScanNode createBaseTableScan() {
        // Explicit type witness: a bare ImmutableMap.builder() at the head of a call chain is not
        // target-typed and would be inferred as Builder<Object, Object>.
        ImmutableMap<Symbol, ColumnHandle> assignments = ImmutableMap.<Symbol, ColumnHandle>builder()
                .put(columnA, new TestingMetadata.TestingColumnHandle("a"))
                .put(columnB, new TestingMetadata.TestingColumnHandle("b"))
                .put(columnC, new TestingMetadata.TestingColumnHandle("c"))
                .put(columnD, new TestingMetadata.TestingColumnHandle("d"))
                .put(columnE, new TestingMetadata.TestingColumnHandle("e"))
                .buildOrThrow();
        return new TableScanNode(
                newId(),
                TestingHandles.TEST_TABLE_HANDLE,
                ImmutableList.copyOf(assignments.keySet()),
                assignments,
                TupleDomain.all(),
                Optional.empty(),
                false,
                Optional.empty());
    }

    /**
     * Builds an unpartitioned, unordered window node over {@code baseTableScan}. The node computes
     * {@code sum(double)} of {@code argument} over an UNBOUNDED PRECEDING .. UNBOUNDED FOLLOWING
     * RANGE frame and stores it into {@code output}.
     */
    private WindowNode sumWindow(Symbol output, Symbol argument) {
        WindowNode.Frame frame = new WindowNode.Frame(
                WindowFrame.Type.RANGE,
                FrameBound.Type.UNBOUNDED_PRECEDING, Optional.empty(), Optional.empty(),
                FrameBound.Type.UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty(),
                Optional.empty(), Optional.empty());
        WindowNode.Function function = new WindowNode.Function(
                functionResolution.resolveFunction("sum", TypeSignatureProvider.fromTypes(DoubleType.DOUBLE)),
                ImmutableList.of(argument.toSymbolReference()),
                frame,
                false);
        return new WindowNode(
                newId(),
                baseTableScan,
                new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()),
                ImmutableMap.of(output, function),
                Optional.empty(),
                ImmutableSet.of(),
                0);
    }

    /**
     * Builds a single-step aggregation over {@code baseTableScan}, grouped by columns a and b. The
     * aggregation computes {@code sum(double)} of {@code argument} into {@code output}.
     */
    private AggregationNode sumAggregation(Symbol output, Symbol argument) {
        AggregationNode.Aggregation aggregation = new AggregationNode.Aggregation(
                functionResolution.resolveFunction("sum", TypeSignatureProvider.fromTypes(DoubleType.DOUBLE)),
                ImmutableList.of(argument.toSymbolReference()),
                false,
                Optional.empty(),
                Optional.empty(),
                Optional.empty());
        return AggregationNode.singleAggregation(
                newId(),
                baseTableScan,
                ImmutableMap.of(output, aggregation),
                AggregationNode.singleGroupingSet(ImmutableList.of(columnA, columnB)));
    }

    /** Runs the type validator on {@code planNode} using this test's symbol types. */
    private void assertTypesValid(PlanNode planNode) {
        TYPE_VALIDATOR.validate(
                planNode,
                SessionTestUtils.TEST_SESSION,
                TestingPlannerContext.PLANNER_CONTEXT,
                TypeAnalyzer.createTestingTypeAnalyzer(TestingPlannerContext.PLANNER_CONTEXT),
                symbolAllocator.getTypes(),
                WarningCollector.NOOP);
    }

    /**
     * Asserts that validating {@code planNode} fails with an {@link IllegalArgumentException}
     * whose message matches {@code expectedMessageRegex}.
     */
    private void assertTypesInvalid(PlanNode planNode, String expectedMessageRegex) {
        Assertions.assertThatThrownBy(() -> assertTypesValid(planNode))
                .isInstanceOf(IllegalArgumentException.class)
                .hasMessageMatching(expectedMessageRegex);
    }

    /** Returns a unique plan node id; a random UUID avoids collisions between nodes in one plan. */
    private static PlanNodeId newId() {
        return new PlanNodeId(UUID.randomUUID().toString());
    }
}
