package io.trino.sql.planner.sanity;

import com.google.common.base.Preconditions;
import com.google.common.collect.ListMultimap;
import io.trino.Session;
import io.trino.execution.warnings.WarningCollector;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeManager;
import io.trino.sql.PlannerContext;
import io.trino.sql.planner.SimplePlanVisitor;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.UnionNode;
import io.trino.sql.planner.plan.WindowNode;
import io.trino.sql.planner.sanity.PlanSanityChecker;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.SymbolReference;
import io.trino.type.FunctionType;
import io.trino.type.TypeCoercion;
import io.trino.type.UnknownType;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * Plan sanity check verifying that the type recorded for each plan symbol agrees with the type
 * actually produced for it.
 *
 * <p>The check covers project assignments, union symbol mappings, and the return/argument types of
 * aggregation and window function calls. A mismatch fails validation with an
 * {@link IllegalArgumentException}. An actual type of {@link UnknownType}, or a type-only coercion
 * from the actual type to the expected one, is accepted.
 */
public final class TypeValidator implements PlanSanityChecker.Checker {

    /**
     * Visitor performing the per-node type checks. Each {@code visitX} override first delegates to
     * {@code visitPlan} (presumably recursing into the node's sources; confirm against
     * {@link SimplePlanVisitor}) and then checks the node itself.
     */
    private static final class Visitor extends SimplePlanVisitor<Void> {
        private final Session session;
        private final TypeCoercion typeCoercion;
        private final TypeAnalyzer typeAnalyzer;
        private final TypeProvider types;

        /**
         * @param session session handed to the type analyzer when typing expressions
         * @param typeManager supplies {@code getType} lookups to the {@link TypeCoercion} helper
         * @param typeAnalyzer computes the type of an expression
         * @param typeProvider declared types of the plan's symbols
         * @throws NullPointerException if any argument is null
         */
        public Visitor(Session session, TypeManager typeManager, TypeAnalyzer typeAnalyzer, TypeProvider typeProvider) {
            this.session = Objects.requireNonNull(session, "session is null");
            Objects.requireNonNull(typeManager, "typeManager is null");
            this.typeCoercion = new TypeCoercion(typeManager::getType);
            this.typeAnalyzer = Objects.requireNonNull(typeAnalyzer, "typeAnalyzer is null");
            this.types = Objects.requireNonNull(typeProvider, "types is null");
        }

        /**
         * Checks aggregation outputs.
         *
         * <p>For a SINGLE step, both the return type and the argument types are checked. For a FINAL
         * step, only the return type is checked. All other steps are not checked.
         */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public Void visitAggregation(AggregationNode node, Void context) {
            visitPlan(node, context);
            AggregationNode.Step step = node.getStep();
            for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
                Symbol symbol = entry.getKey();
                AggregationNode.Aggregation aggregation = entry.getValue();
                switch (step) {
                    case SINGLE:
                        // checkCall verifies the return type as well as every argument type
                        checkCall(symbol, aggregation.getResolvedFunction().getSignature(), aggregation.getArguments());
                        break;
                    case FINAL:
                        // Arguments are not compared here. Presumably they carry intermediate
                        // state rather than the signature's argument types; TODO confirm.
                        checkSignature(symbol, aggregation.getResolvedFunction().getSignature());
                        break;
                    default:
                        // Remaining steps are not validated by this checker
                        break;
                }
            }
            return null;
        }

        /** Checks the return and argument types of every window function on the node. */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public Void visitWindow(WindowNode node, Void context) {
            visitPlan(node, context);
            checkWindowFunctions(node.getWindowFunctions());
            return null;
        }

        /**
         * Checks that each assignment target's declared type matches its expression's type. A bare
         * symbol reference is typed by that symbol's declared type. Any other expression is typed
         * by the {@link TypeAnalyzer}.
         */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public Void visitProject(ProjectNode node, Void context) {
            visitPlan(node, context);
            for (Map.Entry<Symbol, Expression> assignment : node.getAssignments().entrySet()) {
                Symbol target = assignment.getKey();
                Expression expression = assignment.getValue();
                Type expectedType = this.types.get(target);
                Type actualType;
                if (expression instanceof SymbolReference) {
                    actualType = this.types.get(Symbol.from((SymbolReference) expression));
                } else {
                    actualType = this.typeAnalyzer.getType(this.session, this.types, expression);
                }
                verifyTypeSignature(target, expectedType, actualType);
            }
            return null;
        }

        /**
         * Checks that each key of the union's symbol mapping has a type compatible with every
         * symbol it maps to.
         */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public Void visitUnion(UnionNode node, Void context) {
            visitPlan(node, context);
            ListMultimap<Symbol, Symbol> symbolMapping = node.getSymbolMapping();
            for (Symbol key : symbolMapping.keySet()) {
                Type expectedType = this.types.get(key);
                for (Symbol mapped : symbolMapping.get(key)) {
                    verifyTypeSignature(key, expectedType, this.types.get(mapped));
                }
            }
            return null;
        }

        /** Checks each window function's return type and argument types. */
        private void checkWindowFunctions(Map<Symbol, WindowNode.Function> functions) {
            functions.forEach((symbol, function) ->
                    checkCall(symbol, function.getResolvedFunction().getSignature(), function.getArguments()));
        }

        /** Verifies that {@code symbol}'s declared type matches the signature's return type. */
        private void checkSignature(Symbol symbol, BoundSignature boundSignature) {
            verifyTypeSignature(symbol, this.types.get(symbol), boundSignature.getReturnType());
        }

        /**
         * Verifies the return type via {@link #checkSignature}, then checks every argument
         * expression against the corresponding signature argument type.
         *
         * <p>Arguments whose expected type is a {@link FunctionType} are skipped. Presumably these
         * are lambda arguments that cannot be typed in isolation; TODO confirm. Argument
         * mismatches are reported against the output {@code symbol}.
         *
         * @throws IllegalArgumentException if the argument counts differ or any type mismatches
         */
        private void checkCall(Symbol symbol, BoundSignature boundSignature, List<Expression> arguments) {
            checkSignature(symbol, boundSignature);
            List<Type> argumentTypes = boundSignature.getArgumentTypes();
            Preconditions.checkArgument(argumentTypes.size() == arguments.size(), "expected %s arguments, but found %s arguments", argumentTypes.size(), arguments.size());
            for (int i = 0; i < arguments.size(); i++) {
                Type expectedType = argumentTypes.get(i);
                if (expectedType instanceof FunctionType) {
                    continue;
                }
                verifyTypeSignature(symbol, expectedType, this.typeAnalyzer.getType(this.session, this.types, arguments.get(i)));
            }
        }

        /**
         * Accepts {@code actual} when it is {@link UnknownType} (matches any expected type), when it
         * is a type-only coercion to {@code expected}, or when the two types are equal.
         *
         * @throws IllegalArgumentException otherwise
         */
        private void verifyTypeSignature(Symbol symbol, Type expected, Type actual) {
            if (actual instanceof UnknownType || this.typeCoercion.isTypeOnlyCoercion(actual, expected)) {
                return;
            }
            Preconditions.checkArgument(expected.equals(actual), "type of symbol '%s' is expected to be %s, but the actual type is %s", symbol, expected, actual);
        }
    }

    /**
     * Runs the type checks over {@code planNode} and its subtree. {@code warningCollector} is not
     * used: mismatches are reported by throwing {@link IllegalArgumentException}.
     */
    @Override // io.trino.sql.planner.sanity.PlanSanityChecker.Checker
    public void validate(PlanNode planNode, Session session, PlannerContext plannerContext, TypeAnalyzer typeAnalyzer, TypeProvider typeProvider, WarningCollector warningCollector) {
        planNode.accept(new Visitor(session, plannerContext.getTypeManager(), typeAnalyzer, typeProvider), null);
    }
}
