package io.trino.cost;

import com.google.common.collect.ImmutableMap;
import com.google.inject.Inject;
import io.trino.Session;
import io.trino.cost.SymbolStatsEstimate;
import io.trino.execution.warnings.WarningCollector;
import io.trino.security.AllowAllAccessControl;
import io.trino.spi.statistics.StatsUtil;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.SmallintType;
import io.trino.spi.type.TinyintType;
import io.trino.spi.type.Type;
import io.trino.sql.ExpressionUtils;
import io.trino.sql.PlannerContext;
import io.trino.sql.analyzer.ExpressionAnalyzer;
import io.trino.sql.analyzer.Scope;
import io.trino.sql.planner.ExpressionInterpreter;
import io.trino.sql.planner.LiteralInterpreter;
import io.trino.sql.planner.NoOpSymbolResolver;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.tree.ArithmeticBinaryExpression;
import io.trino.sql.tree.ArithmeticUnaryExpression;
import io.trino.sql.tree.AstVisitor;
import io.trino.sql.tree.Cast;
import io.trino.sql.tree.CoalesceExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionCall;
import io.trino.sql.tree.Literal;
import io.trino.sql.tree.Node;
import io.trino.sql.tree.NullLiteral;
import io.trino.sql.tree.SymbolReference;
import io.trino.util.MoreMath;
import java.util.Iterator;
import java.util.Objects;
import java.util.OptionalDouble;

/* loaded from: input_file:io/trino/cost/ScalarStatsCalculator.class */
/**
 * Derives a {@link SymbolStatsEstimate} for a scalar {@link Expression} from the statistics of the
 * rows the expression is evaluated over.
 *
 * <p>Symbol references, literals, function calls that fold to constants, casts, unary and binary
 * arithmetic and {@code COALESCE} each have a dedicated rule. Every other expression node yields
 * {@link SymbolStatsEstimate#unknown()}.
 */
public class ScalarStatsCalculator {
    private final PlannerContext plannerContext;
    private final TypeAnalyzer typeAnalyzer;

    /**
     * @param plannerContext passed to the literal interpreter, expression analyzer and expression
     *     interpreter used while estimating
     * @param typeAnalyzer used to resolve the target type of {@code CAST} expressions
     * @throws NullPointerException if either argument is null
     */
    @Inject
    public ScalarStatsCalculator(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext cannot be null");
        this.typeAnalyzer = Objects.requireNonNull(typeAnalyzer, "typeAnalyzer is null");
    }

    /**
     * Estimates the statistics of {@code expression}.
     *
     * @param expression the scalar expression to estimate
     * @param planNodeStatsEstimate statistics of the input rows; symbol references and the output
     *     row count are read from it
     * @param session the current session
     * @param typeProvider symbol types used when typing function calls and casts
     * @return the estimate, or {@link SymbolStatsEstimate#unknown()} for unsupported expressions
     */
    public SymbolStatsEstimate calculate(Expression expression, PlanNodeStatsEstimate planNodeStatsEstimate, Session session, TypeProvider typeProvider) {
        return new Visitor(planNodeStatsEstimate, session, typeProvider).process(expression);
    }

    /** Statistics of an always-NULL value: zero distinct values and a nulls fraction of 1. */
    private static SymbolStatsEstimate nullStatsEstimate() {
        return SymbolStatsEstimate.builder()
                .setDistinctValuesCount(0.0d)
                .setNullsFraction(1.0d)
                .build();
    }

    /**
     * Rounds a finite value to the nearest integer, with ties rounding toward positive infinity
     * (the semantics of {@link Math#round(double)}), without saturating at the {@code long} range.
     *
     * <p>{@code Math.round(double)} returns a {@code long}, so any bound beyond +/-2^63 (legal for
     * e.g. {@code DECIMAL(38, 0)}) would be clamped to {@code Long.MIN_VALUE}/{@code Long.MAX_VALUE}.
     * This collapses the range and wrongly caps the distinct-value count. Every double whose
     * magnitude is at least 2^52 is already an integer, so such values are returned unchanged.
     */
    private static double roundToIntegral(double value) {
        if (Math.abs(value) >= 0x1p52) {
            return value;
        }
        return Math.round(value);
    }

    /** Computes the statistics of each supported expression node from those of its children. */
    private class Visitor extends AstVisitor<SymbolStatsEstimate, Void> {
        private final PlanNodeStatsEstimate input;
        private final Session session;
        private final LiteralInterpreter literalInterpreter;
        private final TypeProvider types;

        Visitor(PlanNodeStatsEstimate input, Session session, TypeProvider types) {
            this.input = input;
            this.session = session;
            this.literalInterpreter = new LiteralInterpreter(plannerContext, session);
            this.types = types;
        }

        @Override
        protected SymbolStatsEstimate visitNode(Node node, Void context) {
            // Expression forms without a dedicated rule carry no usable statistics.
            return SymbolStatsEstimate.unknown();
        }

        @Override
        protected SymbolStatsEstimate visitSymbolReference(SymbolReference node, Void context) {
            return input.getSymbolStatistics(Symbol.from(node));
        }

        @Override
        protected SymbolStatsEstimate visitNullLiteral(NullLiteral node, Void context) {
            return nullStatsEstimate();
        }

        @Override
        protected SymbolStatsEstimate visitLiteral(Literal node, Void context) {
            // Type the literal with an empty scope and no parameters, then evaluate it to a value.
            Type type = ExpressionAnalyzer.createConstantAnalyzer(plannerContext, new AllowAllAccessControl(), session, ImmutableMap.of(), WarningCollector.NOOP)
                    .analyze(node, Scope.create());
            Object value = literalInterpreter.evaluate(node, type);
            OptionalDouble statsValue = StatsUtil.toStatsRepresentation(type, value);

            SymbolStatsEstimate.Builder estimate = SymbolStatsEstimate.builder()
                    .setNullsFraction(0.0d)
                    .setDistinctValuesCount(1.0d);
            // Only types with a numeric stats representation get a (single-point) range.
            if (statsValue.isPresent()) {
                estimate.setLowValue(statsValue.getAsDouble());
                estimate.setHighValue(statsValue.getAsDouble());
            }
            return estimate.build();
        }

        @Override
        protected SymbolStatsEstimate visitFunctionCall(FunctionCall node, Void context) {
            ExpressionInterpreter interpreter = new ExpressionInterpreter(node, plannerContext, session, ExpressionUtils.getExpressionTypes(plannerContext, session, node, types));
            // NoOpSymbolResolver presumably leaves symbol references unresolved, so only calls that
            // do not depend on input symbols are expected to fold to a constant.
            Object value = interpreter.optimize(NoOpSymbolResolver.INSTANCE);

            if (value == null || value instanceof NullLiteral) {
                return nullStatsEstimate();
            }
            if (value instanceof Expression && !ExpressionUtils.isEffectivelyLiteral(plannerContext, session, (Expression) value)) {
                // The call did not fold to a constant.
                return SymbolStatsEstimate.unknown();
            }
            // The call is a constant: one distinct non-null value; the range is left unset.
            return SymbolStatsEstimate.builder()
                    .setNullsFraction(0.0d)
                    .setDistinctValuesCount(1.0d)
                    .build();
        }

        @Override
        protected SymbolStatsEstimate visitCast(Cast node, Void context) {
            SymbolStatsEstimate sourceStats = process(node.getExpression());
            double distinctValuesCount = sourceStats.getDistinctValuesCount();
            double lowValue = sourceStats.getLowValue();
            double highValue = sourceStats.getHighValue();

            if (isIntegralType(typeAnalyzer.getType(session, types, node))) {
                // Round finite bounds to integers; infinite and NaN bounds are kept as-is.
                if (Double.isFinite(lowValue)) {
                    lowValue = roundToIntegral(lowValue);
                }
                if (Double.isFinite(highValue)) {
                    highValue = roundToIntegral(highValue);
                }
                if (Double.isFinite(lowValue) && Double.isFinite(highValue)) {
                    // An integral range [low, high] holds at most high - low + 1 distinct values.
                    double integersInRange = highValue - lowValue + 1.0d;
                    if (!Double.isNaN(distinctValuesCount) && distinctValuesCount > integersInRange) {
                        distinctValuesCount = integersInRange;
                    }
                }
            }
            // Average row size is not carried over from the source estimate.
            return SymbolStatsEstimate.builder()
                    .setNullsFraction(sourceStats.getNullsFraction())
                    .setLowValue(lowValue)
                    .setHighValue(highValue)
                    .setDistinctValuesCount(distinctValuesCount)
                    .build();
        }

        /** Returns true for the integer types and for decimals with scale 0. */
        private boolean isIntegralType(Type type) {
            if (type instanceof BigintType || type instanceof IntegerType || type instanceof SmallintType || type instanceof TinyintType) {
                return true;
            }
            return type instanceof DecimalType && ((DecimalType) type).getScale() == 0;
        }

        @Override
        protected SymbolStatsEstimate visitArithmeticUnary(ArithmeticUnaryExpression node, Void context) {
            SymbolStatsEstimate operandStats = process(node.getValue());
            switch (node.getSign()) {
                case PLUS:
                    return operandStats;
                case MINUS:
                    // Negation mirrors the range: [low, high] becomes [-high, -low].
                    return SymbolStatsEstimate.buildFrom(operandStats)
                            .setLowValue(-operandStats.getHighValue())
                            .setHighValue(-operandStats.getLowValue())
                            .build();
                default:
                    throw new IllegalStateException("Unexpected sign: " + node.getSign());
            }
        }

        @Override
        protected SymbolStatsEstimate visitArithmeticBinary(ArithmeticBinaryExpression node, Void context) {
            Objects.requireNonNull(node, "node is null");
            SymbolStatsEstimate left = process(node.getLeft());
            SymbolStatsEstimate right = process(node.getRight());
            if (left.isUnknown() || right.isUnknown()) {
                return SymbolStatsEstimate.unknown();
            }

            // Null if either operand is null (operands treated as independent); the distinct count is
            // the product of the operands' counts, capped by the input row count.
            SymbolStatsEstimate.Builder result = SymbolStatsEstimate.builder()
                    .setAverageRowSize(Math.max(left.getAverageRowSize(), right.getAverageRowSize()))
                    .setNullsFraction(left.getNullsFraction() + right.getNullsFraction() - left.getNullsFraction() * right.getNullsFraction())
                    .setDistinctValuesCount(MoreMath.min(left.getDistinctValuesCount() * right.getDistinctValuesCount(), input.getOutputRowCount()));

            double leftLow = left.getLowValue();
            double leftHigh = left.getHighValue();
            double rightLow = right.getLowValue();
            double rightHigh = right.getHighValue();
            ArithmeticBinaryExpression.Operator operator = node.getOperator();

            if (Double.isNaN(leftLow) || Double.isNaN(leftHigh) || Double.isNaN(rightLow) || Double.isNaN(rightHigh)) {
                result.setLowValue(Double.NaN).setHighValue(Double.NaN);
            }
            else if (operator == ArithmeticBinaryExpression.Operator.DIVIDE && rightLow < 0.0d && rightHigh > 0.0d) {
                // The divisor can come arbitrarily close to zero from both sides.
                result.setLowValue(Double.NEGATIVE_INFINITY).setHighValue(Double.POSITIVE_INFINITY);
            }
            else if (operator == ArithmeticBinaryExpression.Operator.MODULUS) {
                // |a % b| is bounded by the largest divisor magnitude and by |a|; the sign follows the dividend.
                double maxDivisor = MoreMath.max(Math.abs(rightLow), Math.abs(rightHigh));
                if (leftHigh <= 0.0d) {
                    result.setLowValue(MoreMath.max(-maxDivisor, leftLow)).setHighValue(0.0d);
                }
                else if (leftLow >= 0.0d) {
                    result.setLowValue(0.0d).setHighValue(MoreMath.min(maxDivisor, leftHigh));
                }
                else {
                    result.setLowValue(MoreMath.max(-maxDivisor, leftLow)).setHighValue(MoreMath.min(maxDivisor, leftHigh));
                }
            }
            else {
                // Extremes of + - * / over a box of ranges are reached at its corners.
                double v1 = operate(operator, leftLow, rightLow);
                double v2 = operate(operator, leftLow, rightHigh);
                double v3 = operate(operator, leftHigh, rightLow);
                double v4 = operate(operator, leftHigh, rightHigh);
                result.setLowValue(MoreMath.min(v1, v2, v3, v4)).setHighValue(MoreMath.max(v1, v2, v3, v4));
            }
            return result.build();
        }

        /** Applies {@code operator} to two doubles. */
        private double operate(ArithmeticBinaryExpression.Operator operator, double left, double right) {
            switch (operator) {
                case ADD:
                    return left + right;
                case SUBTRACT:
                    return left - right;
                case MULTIPLY:
                    return left * right;
                case DIVIDE:
                    return left / right;
                case MODULUS:
                    return left % right;
                default:
                    throw new IllegalStateException("Unsupported ArithmeticBinaryExpression.Operator: " + operator);
            }
        }

        @Override
        protected SymbolStatsEstimate visitCoalesceExpression(CoalesceExpression node, Void context) {
            Objects.requireNonNull(node, "node is null");
            // Fold the operands left to right: each step combines "result so far" with the next operand.
            SymbolStatsEstimate result = null;
            for (Expression operand : node.getOperands()) {
                SymbolStatsEstimate operandStats = process(operand);
                result = (result == null) ? operandStats : estimateCoalesce(result, operandStats);
            }
            return Objects.requireNonNull(result, "result is null");
        }

        /** Estimates {@code COALESCE(left, right)} from the statistics of both arguments. */
        private SymbolStatsEstimate estimateCoalesce(SymbolStatsEstimate left, SymbolStatsEstimate right) {
            if (left.getNullsFraction() == 0.0d) {
                // The left side is never null, so the right side is never used.
                return left;
            }
            if (left.getNullsFraction() == 1.0d) {
                // The left side is always null, so the right side always provides the value.
                return right;
            }
            // The right side contributes at most one distinct value per row where the left side is null.
            return SymbolStatsEstimate.builder()
                    .setLowValue(MoreMath.min(left.getLowValue(), right.getLowValue()))
                    .setHighValue(MoreMath.max(left.getHighValue(), right.getHighValue()))
                    .setDistinctValuesCount(left.getDistinctValuesCount() + MoreMath.min(right.getDistinctValuesCount(), input.getOutputRowCount() * left.getNullsFraction()))
                    .setNullsFraction(left.getNullsFraction() * right.getNullsFraction())
                    .setAverageRowSize(MoreMath.max(left.getAverageRowSize(), right.getAverageRowSize()))
                    .build();
        }
    }
}
