/*
 * Decompiled with CFR 0.152.
 */
package io.trino.plugin.pinot.query;

import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Maps;
import io.trino.matching.Captures;
import io.trino.matching.Match;
import io.trino.matching.Pattern;
import io.trino.plugin.pinot.PinotErrorCode;
import io.trino.plugin.pinot.PinotException;
import io.trino.plugin.pinot.query.PinotPatterns;
import io.trino.plugin.pinot.query.PinotSqlFormatter;
import io.trino.plugin.pinot.query.PinotTransformFunctionTypeResolver;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.SchemaTableName;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import org.apache.pinot.common.function.TransformFunctionType;
import org.apache.pinot.common.request.Literal;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.request.context.FunctionContext;
import org.apache.pinot.core.operator.transform.function.DateTruncTransformFunction;
import org.apache.pinot.core.operator.transform.transformer.datetime.BaseDateTimeTransformer;
import org.apache.pinot.core.operator.transform.transformer.datetime.DateTimeTransformerFactory;
import org.apache.pinot.core.operator.transform.transformer.datetime.EpochToEpochTransformer;
import org.apache.pinot.core.operator.transform.transformer.timeunit.TimeUnitTransformerFactory;
import org.apache.pinot.segment.spi.AggregationFunctionType;

/**
 * Rewrites Pinot {@link ExpressionContext} trees before they are handed to Pinot.
 *
 * <p>Literals are kept as-is. Identifiers are replaced by the column name of the handle returned by
 * {@code PinotSqlFormatter.getColumnHandle}. For functions, the function-specific rule registered for
 * {@code DATETIMECONVERT}, {@code TIMECONVERT}, {@code DATETRUNC} or {@code COUNT} is applied when its
 * pattern matches. Otherwise a default rule rebuilds the function with recursively rewritten arguments.
 */
public class PinotExpressionRewriter {
    // Transform-function-specific rules, tried before DEFAULT_REWRITE_RULE.
    private static final Map<TransformFunctionType, RewriteRule<FunctionContext>> FUNCTION_RULE_MAP;
    // Aggregation-function-specific rules, tried before DEFAULT_REWRITE_RULE.
    private static final Map<AggregationFunctionType, RewriteRule<FunctionContext>> AGGREGATION_FUNCTION_RULE_MAP;
    // Fallback used when no specific rule is registered or its pattern does not match.
    private static final RewriteRule<FunctionContext> DEFAULT_REWRITE_RULE;

    private PinotExpressionRewriter() {
    }

    /**
     * Rewrites an expression so that identifiers refer to resolved Pinot column names and supported
     * functions have normalized arguments.
     *
     * @param schemaTableName table used when resolving identifiers to column handles
     * @param expressionContext expression to rewrite; must not be null
     * @param columnHandles column handles used when resolving identifiers
     * @return the rewritten expression
     * @throws NullPointerException if {@code expressionContext} is null
     * @throws PinotException if the expression type or a function expression is not supported;
     *         individual rules may additionally throw {@link IllegalStateException},
     *         {@link IllegalArgumentException} or Guava's {@code VerifyException} for malformed arguments
     */
    public static ExpressionContext rewriteExpression(final SchemaTableName schemaTableName, ExpressionContext expressionContext, final Map<String, ColumnHandle> columnHandles) {
        Objects.requireNonNull(expressionContext, "expressionContext is null");
        Context context = new Context() {
            @Override
            public SchemaTableName getSchemaTableName() {
                return schemaTableName;
            }

            @Override
            public Map<String, ColumnHandle> getColumnHandles() {
                return columnHandles;
            }
        };
        return PinotExpressionRewriter.rewriteExpression(expressionContext, context);
    }

    /**
     * Rewrites one expression node. Literals are returned unchanged, identifiers are resolved to
     * column names, and functions are delegated to {@link #rewriteFunction}.
     */
    private static ExpressionContext rewriteExpression(ExpressionContext expressionContext, Context context) {
        switch (expressionContext.getType()) {
            case LITERAL:
                return expressionContext;
            case IDENTIFIER:
                return ExpressionContext.forIdentifier(PinotSqlFormatter.getColumnHandle(expressionContext.getIdentifier(), context.getSchemaTableName(), context.getColumnHandles()).getColumnName());
            case FUNCTION:
                return ExpressionContext.forFunction(PinotExpressionRewriter.rewriteFunction(expressionContext.getFunction(), context));
            default:
                break;
        }
        throw new PinotException(PinotErrorCode.PINOT_EXCEPTION, Optional.empty(), String.format("Unsupported expression type '%s'", expressionContext.getType()));
    }

    /**
     * Rewrites a function by trying its specific rule (if one is registered) and then the default rule.
     *
     * @throws PinotException if the transform function type cannot be resolved or no rule matches
     * @throws IllegalStateException if the function is neither a transform nor an aggregation
     */
    private static FunctionContext rewriteFunction(FunctionContext functionContext, Context context) {
        RewriteRule<FunctionContext> specificRule;
        if (functionContext.getType() == FunctionContext.Type.TRANSFORM) {
            // Report unresolvable transform functions as a PinotException (the same error as an
            // unsupported function below) instead of a bare NoSuchElementException.
            TransformFunctionType transformFunctionType = PinotTransformFunctionTypeResolver.getTransformFunctionType(functionContext)
                    .orElseThrow(() -> unsupportedFunction(functionContext));
            specificRule = FUNCTION_RULE_MAP.get(transformFunctionType);
        } else {
            Preconditions.checkState(functionContext.getType() == FunctionContext.Type.AGGREGATION, "Unexpected function type for '%s'", functionContext);
            specificRule = AGGREGATION_FUNCTION_RULE_MAP.get(AggregationFunctionType.getAggregationFunctionType(functionContext.getFunctionName()));
        }

        // A registered rule whose pattern does not match falls back to the default rule.
        Optional<FunctionContext> result = Optional.ofNullable(specificRule)
                .flatMap(rule -> PinotExpressionRewriter.applyRule(rule, functionContext, context))
                .or(() -> PinotExpressionRewriter.applyRule(DEFAULT_REWRITE_RULE, functionContext, context));
        return result.orElseThrow(() -> unsupportedFunction(functionContext));
    }

    /** Builds the error reported for a function expression that cannot be rewritten. */
    private static PinotException unsupportedFunction(FunctionContext functionContext) {
        return new PinotException(PinotErrorCode.PINOT_EXCEPTION, Optional.empty(), String.format("Unsupported function expression '%s'", functionContext));
    }

    /**
     * Applies {@code rule} using the first match of its pattern against {@code object}.
     *
     * @return the rewritten object, or empty if the pattern does not match
     */
    private static <T> Optional<T> applyRule(RewriteRule<T> rule, T object, Context context) {
        Iterator<Match> matches = rule.getPattern().match(object).iterator();
        if (!matches.hasNext()) {
            return Optional.empty();
        }
        Match match = matches.next();
        return Optional.of(rule.rewrite(object, match.captures(), context));
    }

    private static void verifyIsIdentifierOrFunction(ExpressionContext expressionContext) {
        ExpressionContext.Type type = expressionContext.getType();
        Verify.verify(type == ExpressionContext.Type.IDENTIFIER || type == ExpressionContext.Type.FUNCTION, "Expected identifier or function but found '%s'", expressionContext);
    }

    /** Verifies that every argument except the first is a literal. */
    private static void verifyTailArgumentsAllLiteral(List<ExpressionContext> arguments) {
        arguments.stream()
                .skip(1L)
                .forEach(argument -> Verify.verify(argument.getType() == ExpressionContext.Type.LITERAL, "Expected literal argument but found '%s'", argument));
    }

    /** Returns the string form of a literal expression's value. */
    private static String literalString(ExpressionContext literal) {
        return literal.getLiteral().getValue().toString();
    }

    /** Returns the literal's string value upper-cased with {@link Locale#ENGLISH}. */
    private static String upperCaseLiteral(ExpressionContext literal) {
        return literalString(literal).toUpperCase(Locale.ENGLISH);
    }

    /** Wraps {@code value} in a string literal expression. */
    private static ExpressionContext stringLiteral(String value) {
        return ExpressionContext.forLiteralContext(Literal.stringValue(value));
    }

    static {
        DEFAULT_REWRITE_RULE = new DefaultRewriteRule();

        Map<TransformFunctionType, RewriteRule<FunctionContext>> functionRules = new HashMap<>();
        functionRules.put(TransformFunctionType.DATETIMECONVERT, new DateTimeConvertRewriteRule());
        functionRules.put(TransformFunctionType.TIMECONVERT, new TimeConvertRewriteRule());
        functionRules.put(TransformFunctionType.DATETRUNC, new DateTruncRewriteRule());
        FUNCTION_RULE_MAP = Maps.immutableEnumMap(functionRules);

        Map<AggregationFunctionType, RewriteRule<FunctionContext>> aggregationFunctionRules = new HashMap<>();
        aggregationFunctionRules.put(AggregationFunctionType.COUNT, new CountStarRewriteRule());
        AGGREGATION_FUNCTION_RULE_MAP = Maps.immutableEnumMap(aggregationFunctionRules);
    }

    /** Information rewrite rules need to resolve identifiers. */
    private interface Context {
        SchemaTableName getSchemaTableName();

        Map<String, ColumnHandle> getColumnHandles();
    }

    /** A rewrite that is applied only when {@link #getPattern()} matches the input. */
    private interface RewriteRule<T> {
        Pattern<T> getPattern();

        T rewrite(T object, Captures captures, Context context);
    }

    /** Matches any function and rebuilds it with every argument rewritten recursively. */
    private static final class DefaultRewriteRule implements RewriteRule<FunctionContext> {
        private DefaultRewriteRule() {
        }

        @Override
        public Pattern<FunctionContext> getPattern() {
            return PinotPatterns.function();
        }

        @Override
        public FunctionContext rewrite(FunctionContext object, Captures captures, Context context) {
            List<ExpressionContext> arguments = object.getArguments().stream()
                    .map(argument -> PinotExpressionRewriter.rewriteExpression(argument, context))
                    .collect(ImmutableList.toImmutableList());
            return new FunctionContext(object.getType(), object.getFunctionName(), arguments);
        }
    }

    /**
     * Handles {@code DATETIMECONVERT(column, inputFormat, outputFormat, granularity)}. It rewrites the
     * column, upper-cases the three literal arguments, and rejects format combinations whose
     * transformer is not an {@link EpochToEpochTransformer}.
     */
    private static final class DateTimeConvertRewriteRule implements RewriteRule<FunctionContext> {
        private DateTimeConvertRewriteRule() {
        }

        @Override
        public Pattern<FunctionContext> getPattern() {
            return PinotPatterns.transformFunction().with(PinotPatterns.transformFunctionType().equalTo(TransformFunctionType.DATETIMECONVERT));
        }

        @Override
        public FunctionContext rewrite(FunctionContext object, Captures captures, Context context) {
            List<ExpressionContext> arguments = object.getArguments();
            Verify.verify(arguments.size() == 4, "Unexpected number of arguments for '%s'", object);
            PinotExpressionRewriter.verifyIsIdentifierOrFunction(arguments.get(0));
            PinotExpressionRewriter.verifyTailArgumentsAllLiteral(arguments);

            ExpressionContext column = PinotExpressionRewriter.rewriteExpression(arguments.get(0), context);
            String inputFormat = upperCaseLiteral(arguments.get(1));
            String outputFormat = upperCaseLiteral(arguments.get(2));
            String granularity = upperCaseLiteral(arguments.get(3));
            BaseDateTimeTransformer dateTimeTransformer = DateTimeTransformerFactory.getDateTimeTransformer(inputFormat, outputFormat, granularity);
            Preconditions.checkState(dateTimeTransformer instanceof EpochToEpochTransformer, "Unsupported date format: simple date format not supported");

            return new FunctionContext(object.getType(), object.getFunctionName(),
                    ImmutableList.of(column, stringLiteral(inputFormat), stringLiteral(outputFormat), stringLiteral(granularity)));
        }
    }

    /**
     * Handles {@code TIMECONVERT(column, inputTimeUnit, outputTimeUnit)}. It rewrites the column and
     * upper-cases both unit literals. The input unit must name a {@link TimeUnit} constant.
     */
    private static final class TimeConvertRewriteRule implements RewriteRule<FunctionContext> {
        private TimeConvertRewriteRule() {
        }

        @Override
        public Pattern<FunctionContext> getPattern() {
            return PinotPatterns.transformFunction().with(PinotPatterns.transformFunctionType().equalTo(TransformFunctionType.TIMECONVERT));
        }

        @Override
        public FunctionContext rewrite(FunctionContext object, Captures captures, Context context) {
            List<ExpressionContext> arguments = object.getArguments();
            Verify.verify(arguments.size() == 3, "Unexpected number of arguments for '%s'", object);
            PinotExpressionRewriter.verifyIsIdentifierOrFunction(arguments.get(0));
            PinotExpressionRewriter.verifyTailArgumentsAllLiteral(arguments);

            ExpressionContext column = PinotExpressionRewriter.rewriteExpression(arguments.get(0), context);
            String inputTimeUnitArgument = upperCaseLiteral(arguments.get(1));
            TimeUnit inputTimeUnit = TimeUnit.valueOf(inputTimeUnitArgument);
            String outputTimeUnitArgument = upperCaseLiteral(arguments.get(2));
            // Result is discarded; presumably called to validate the unit combination (TODO confirm it throws on bad units).
            TimeUnitTransformerFactory.getTimeUnitTransformer(inputTimeUnit, outputTimeUnitArgument);

            return new FunctionContext(object.getType(), object.getFunctionName(),
                    ImmutableList.of(column, stringLiteral(inputTimeUnitArgument), stringLiteral(outputTimeUnitArgument)));
        }
    }

    /**
     * Handles {@code DATETRUNC(unit, value[, inputTimeUnit[, arg4[, outputTimeUnit]]])}. It
     * lower-cases the unit literal, rewrites the value, and normalizes the optional time unit literals
     * to {@link TimeUnit} constant names.
     */
    private static final class DateTruncRewriteRule implements RewriteRule<FunctionContext> {
        private DateTruncRewriteRule() {
        }

        @Override
        public Pattern<FunctionContext> getPattern() {
            return PinotPatterns.transformFunction().with(PinotPatterns.transformFunctionType().equalTo(TransformFunctionType.DATETRUNC));
        }

        @Override
        public FunctionContext rewrite(FunctionContext object, Captures captures, Context context) {
            List<ExpressionContext> arguments = object.getArguments();
            Preconditions.checkState(arguments.size() >= 2 && arguments.size() <= 5, "Between two to five arguments are required, example: %s", DateTruncTransformFunction.EXAMPLE_INVOCATION);
            ImmutableList.Builder<ExpressionContext> argumentsBuilder = ImmutableList.builder();

            Preconditions.checkState(arguments.get(0).getType() == ExpressionContext.Type.LITERAL, "First argument must be a literal");
            argumentsBuilder.add(stringLiteral(literalString(arguments.get(0)).toLowerCase(Locale.ENGLISH)));

            PinotExpressionRewriter.verifyIsIdentifierOrFunction(arguments.get(1));
            argumentsBuilder.add(PinotExpressionRewriter.rewriteExpression(arguments.get(1), context));

            if (arguments.size() >= 3) {
                Preconditions.checkState(arguments.get(2).getType() == ExpressionContext.Type.LITERAL, "Unexpected 3rd argument: '%s'", arguments.get(2));
                TimeUnit inputTimeUnit = TimeUnit.valueOf(upperCaseLiteral(arguments.get(2)));
                argumentsBuilder.add(stringLiteral(inputTimeUnit.name()));
                if (arguments.size() >= 4) {
                    // Only required to be a literal; passed through unchanged
                    // (presumably the time zone in Pinot's dateTrunc signature - TODO confirm).
                    Preconditions.checkState(arguments.get(3).getType() == ExpressionContext.Type.LITERAL, "Unexpected 4th argument '%s'", arguments.get(3));
                    argumentsBuilder.add(arguments.get(3));
                    if (arguments.size() >= 5) {
                        Preconditions.checkState(arguments.get(4).getType() == ExpressionContext.Type.LITERAL, "Unexpected 5th argument: '%s'", arguments.get(4));
                        TimeUnit outputTimeUnit = TimeUnit.valueOf(upperCaseLiteral(arguments.get(4)));
                        argumentsBuilder.add(stringLiteral(outputTimeUnit.name()));
                    }
                }
            }
            return new FunctionContext(object.getType(), object.getFunctionName(), argumentsBuilder.build());
        }
    }

    /**
     * Matches {@code COUNT(*)} and returns it unchanged. This keeps the {@code *} identifier away
     * from the default rule, which would try to resolve it as a column.
     */
    private static final class CountStarRewriteRule implements RewriteRule<FunctionContext> {
        private CountStarRewriteRule() {
        }

        @Override
        public Pattern<FunctionContext> getPattern() {
            return PinotPatterns.aggregationFunction()
                    .with(PinotPatterns.aggregationFunctionType().equalTo(AggregationFunctionType.COUNT))
                    .with(PinotPatterns.singleInput().matching(PinotPatterns.expression()
                            .with(PinotPatterns.expressionType().equalTo(ExpressionContext.Type.IDENTIFIER))
                            .with(PinotPatterns.identifier().equalTo("*"))));
        }

        @Override
        public FunctionContext rewrite(FunctionContext object, Captures captures, Context context) {
            return object;
        }
    }
}

