/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import io.trino.Session;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.DateType;
import io.trino.spi.type.LongTimestamp;
import io.trino.spi.type.TimestampType;
import io.trino.spi.type.TimestampWithTimeZoneType;
import io.trino.spi.type.Type;
import io.trino.sql.ExpressionUtils;
import io.trino.sql.PlannerContext;
import io.trino.sql.analyzer.TypeSignatureTranslator;
import io.trino.sql.planner.ExpressionInterpreter;
import io.trino.sql.planner.LiteralEncoder;
import io.trino.sql.planner.NoOpSymbolResolver;
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.rule.ExpressionRewriteRuleSet;
import io.trino.sql.tree.BetweenPredicate;
import io.trino.sql.tree.Cast;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.ExpressionRewriter;
import io.trino.sql.tree.ExpressionTreeRewriter;
import io.trino.sql.tree.FunctionCall;
import io.trino.sql.tree.InListExpression;
import io.trino.sql.tree.InPredicate;
import io.trino.sql.tree.IsNotNullPredicate;
import io.trino.sql.tree.IsNullPredicate;
import io.trino.sql.tree.Node;
import io.trino.sql.tree.NodeRef;
import io.trino.sql.tree.NotExpression;
import io.trino.sql.tree.NullLiteral;
import io.trino.type.DateTimes;
import java.time.DateTimeException;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.time.temporal.TemporalAdjusters;
import java.util.Collection;
import java.util.Map;
import java.util.Objects;

public class UnwrapYearInComparison
extends ExpressionRewriteRuleSet {
    public UnwrapYearInComparison(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer) {
        super(UnwrapYearInComparison.createRewrite(plannerContext, typeAnalyzer));
    }

    private static ExpressionRewriteRuleSet.ExpressionRewriter createRewrite(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer) {
        Objects.requireNonNull(plannerContext, "plannerContext is null");
        Objects.requireNonNull(typeAnalyzer, "typeAnalyzer is null");
        return (expression, context) -> UnwrapYearInComparison.unwrapYear(context.getSession(), plannerContext, typeAnalyzer, context.getSymbolAllocator().getTypes(), expression);
    }

    private static Expression unwrapYear(Session session, PlannerContext plannerContext, TypeAnalyzer typeAnalyzer, TypeProvider types, Expression expression) {
        return ExpressionTreeRewriter.rewriteWith((ExpressionRewriter)new Visitor(plannerContext, typeAnalyzer, session, types), (Expression)expression);
    }

    private static Object calculateRangeStartInclusive(int year, Type type) {
        if (type == DateType.DATE) {
            LocalDate firstDay = LocalDate.ofYearDay(year, 1);
            return firstDay.toEpochDay();
        }
        if (type instanceof TimestampType) {
            TimestampType timestampType = (TimestampType)type;
            long yearStartEpochSecond = LocalDateTime.of(year, 1, 1, 0, 0).toEpochSecond(ZoneOffset.UTC);
            long yearStartEpochMicros = Math.multiplyExact(yearStartEpochSecond, 1000000);
            if (timestampType.isShort()) {
                return yearStartEpochMicros;
            }
            return new LongTimestamp(yearStartEpochMicros, 0);
        }
        throw new UnsupportedOperationException("Unsupported type: " + type);
    }

    @VisibleForTesting
    public static Object calculateRangeEndInclusive(int year, Type type) {
        if (type == DateType.DATE) {
            LocalDate lastDay = LocalDate.ofYearDay(year, 1).with(TemporalAdjusters.lastDayOfYear());
            return lastDay.toEpochDay();
        }
        if (type instanceof TimestampType) {
            TimestampType timestampType = (TimestampType)type;
            long nextYearStartEpochSecond = LocalDateTime.of(year + 1, 1, 1, 0, 0).toEpochSecond(ZoneOffset.UTC);
            long nextYearStartEpochMicros = Math.multiplyExact(nextYearStartEpochSecond, 1000000);
            if (timestampType.isShort()) {
                return nextYearStartEpochMicros - DateTimes.scaleFactor(timestampType.getPrecision(), 6);
            }
            int picosOfMicro = Math.toIntExact(1000000L - DateTimes.scaleFactor(timestampType.getPrecision(), 12));
            return new LongTimestamp(nextYearStartEpochMicros - 1L, picosOfMicro);
        }
        throw new UnsupportedOperationException("Unsupported type: " + type);
    }

    private static class Visitor
    extends ExpressionRewriter<Void> {
        private final PlannerContext plannerContext;
        private final TypeAnalyzer typeAnalyzer;
        private final Session session;
        private final TypeProvider types;
        private final LiteralEncoder literalEncoder;

        public Visitor(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer, Session session, TypeProvider types) {
            this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
            this.typeAnalyzer = Objects.requireNonNull(typeAnalyzer, "typeAnalyzer is null");
            this.session = Objects.requireNonNull(session, "session is null");
            this.types = Objects.requireNonNull(types, "types is null");
            this.literalEncoder = new LiteralEncoder(plannerContext);
        }

        public Expression rewriteComparisonExpression(ComparisonExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter) {
            ComparisonExpression expression = (ComparisonExpression)treeRewriter.defaultRewrite((Expression)node, null);
            return this.unwrapYear(expression);
        }

        public Expression rewriteInPredicate(InPredicate node, Void context, ExpressionTreeRewriter<Void> treeRewriter) {
            FunctionCall call;
            InPredicate inPredicate = (InPredicate)treeRewriter.defaultRewrite((Expression)node, null);
            Expression value = inPredicate.getValue();
            Expression valueList = inPredicate.getValueList();
            if (!(value instanceof FunctionCall && ResolvedFunction.extractFunctionName((call = (FunctionCall)value).getName()).equals("year") && call.getArguments().size() == 1 && valueList instanceof InListExpression)) {
                return inPredicate;
            }
            InListExpression inListExpression = (InListExpression)valueList;
            ImmutableList.Builder comparisonExpressions = ImmutableList.builderWithExpectedSize((int)inListExpression.getValues().size());
            for (Expression rightExpression : inListExpression.getValues()) {
                ComparisonExpression comparisonExpression = new ComparisonExpression(ComparisonExpression.Operator.EQUAL, value, rightExpression);
                Expression unwrappedExpression = this.unwrapYear(comparisonExpression);
                if (unwrappedExpression == comparisonExpression) {
                    return inPredicate;
                }
                comparisonExpressions.add((Object)unwrappedExpression);
            }
            return ExpressionUtils.or((Collection<Expression>)comparisonExpressions.build());
        }

        private Expression unwrapYear(ComparisonExpression expression) {
            FunctionCall call;
            Expression expression2 = expression.getLeft();
            if (!(expression2 instanceof FunctionCall) || !ResolvedFunction.extractFunctionName((call = (FunctionCall)expression2).getName()).equals("year") || call.getArguments().size() != 1) {
                return expression;
            }
            Map<NodeRef<Expression>, Type> expressionTypes = this.typeAnalyzer.getTypes(this.session, this.types, (Expression)expression);
            Expression argument = (Expression)Iterables.getOnlyElement((Iterable)call.getArguments());
            Type argumentType = expressionTypes.get(NodeRef.of((Node)argument));
            Object right = new ExpressionInterpreter(expression.getRight(), this.plannerContext, this.session, expressionTypes).optimize(NoOpSymbolResolver.INSTANCE);
            if (right == null || right instanceof NullLiteral) {
                return switch (expression.getOperator()) {
                    default -> throw new IncompatibleClassChangeError();
                    case ComparisonExpression.Operator.EQUAL, ComparisonExpression.Operator.NOT_EQUAL, ComparisonExpression.Operator.LESS_THAN, ComparisonExpression.Operator.LESS_THAN_OR_EQUAL, ComparisonExpression.Operator.GREATER_THAN, ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL -> new Cast((Expression)new NullLiteral(), TypeSignatureTranslator.toSqlType((Type)BooleanType.BOOLEAN));
                    case ComparisonExpression.Operator.IS_DISTINCT_FROM -> new IsNotNullPredicate(argument);
                };
            }
            if (right instanceof Expression) {
                return expression;
            }
            if (argumentType instanceof TimestampWithTimeZoneType) {
                return expression;
            }
            if (argumentType != DateType.DATE && !(argumentType instanceof TimestampType)) {
                return expression;
            }
            int year = Math.toIntExact((Long)right);
            return switch (expression.getOperator()) {
                default -> throw new IncompatibleClassChangeError();
                case ComparisonExpression.Operator.EQUAL -> this.between(argument, argumentType, UnwrapYearInComparison.calculateRangeStartInclusive(year, argumentType), UnwrapYearInComparison.calculateRangeEndInclusive(year, argumentType));
                case ComparisonExpression.Operator.NOT_EQUAL -> new NotExpression((Expression)this.between(argument, argumentType, UnwrapYearInComparison.calculateRangeStartInclusive(year, argumentType), UnwrapYearInComparison.calculateRangeEndInclusive(year, argumentType)));
                case ComparisonExpression.Operator.IS_DISTINCT_FROM -> ExpressionUtils.or(new Expression[]{new IsNullPredicate(argument), new NotExpression((Expression)this.between(argument, argumentType, UnwrapYearInComparison.calculateRangeStartInclusive(year, argumentType), UnwrapYearInComparison.calculateRangeEndInclusive(year, argumentType)))});
                case ComparisonExpression.Operator.LESS_THAN -> new ComparisonExpression(ComparisonExpression.Operator.LESS_THAN, argument, this.toExpression(UnwrapYearInComparison.calculateRangeStartInclusive(year, argumentType), argumentType));
                case ComparisonExpression.Operator.LESS_THAN_OR_EQUAL -> new ComparisonExpression(ComparisonExpression.Operator.LESS_THAN_OR_EQUAL, argument, this.toExpression(UnwrapYearInComparison.calculateRangeEndInclusive(year, argumentType), argumentType));
                case ComparisonExpression.Operator.GREATER_THAN -> new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN, argument, this.toExpression(UnwrapYearInComparison.calculateRangeEndInclusive(year, argumentType), argumentType));
                case ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL -> new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL, argument, this.toExpression(UnwrapYearInComparison.calculateRangeStartInclusive(year, argumentType), argumentType));
            };
        }

        private BetweenPredicate between(Expression argument, Type type, Object minInclusive, Object maxInclusive) {
            return new BetweenPredicate(argument, this.toExpression(minInclusive, type), this.toExpression(maxInclusive, type));
        }

        private Expression toExpression(Object value, Type type) {
            return this.literalEncoder.toExpression(this.session, value, type);
        }
    }
}

