package io.trino.plugin.pinot.query;

import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Maps;
import io.trino.matching.Captures;
import io.trino.matching.Match;
import io.trino.matching.Pattern;
import io.trino.plugin.pinot.PinotErrorCode;
import io.trino.plugin.pinot.PinotException;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.SchemaTableName;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import org.apache.pinot.common.function.TransformFunctionType;
import org.apache.pinot.common.request.Literal;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.request.context.FunctionContext;
import org.apache.pinot.core.operator.transform.function.DateTruncTransformFunction;
import org.apache.pinot.core.operator.transform.transformer.datetime.DateTimeTransformerFactory;
import org.apache.pinot.core.operator.transform.transformer.datetime.EpochToEpochTransformer;
import org.apache.pinot.core.operator.transform.transformer.timeunit.TimeUnitTransformerFactory;
import org.apache.pinot.segment.spi.AggregationFunctionType;

/* loaded from: input_file:io/trino/plugin/pinot/query/PinotExpressionRewriter.class */
public class PinotExpressionRewriter {
    /** Specialised rewrite rules for transform functions, keyed by transform function type; assigned in the static initializer. */
    private static final Map<TransformFunctionType, RewriteRule<FunctionContext>> FUNCTION_RULE_MAP;
    /** Specialised rewrite rules for aggregation functions, keyed by aggregation function type; assigned in the static initializer. */
    private static final Map<AggregationFunctionType, RewriteRule<FunctionContext>> AGGREGATION_FUNCTION_RULE_MAP;
    /** Fallback rule that {@code rewriteFunction} applies when no specialised rule produced a result. */
    private static final RewriteRule<FunctionContext> DEFAULT_REWRITE_RULE = new DefaultRewriteRule();

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: io.trino.plugin.pinot.query.PinotExpressionRewriter$2, reason: invalid class name */
    /* loaded from: input_file:io/trino/plugin/pinot/query/PinotExpressionRewriter$2.class */
    /**
     * Decompiler-emitted lookup table for switching over {@link ExpressionContext.Type}.
     *
     * <p>The array is indexed by the enum constant's ordinal and holds a dense case label:
     * 1 for LITERAL, 2 for IDENTIFIER, 3 for FUNCTION. Any slot that is never written keeps the
     * default array value 0, which matches none of those labels.
     */
    public static /* synthetic */ class AnonymousClass2 {
        // One slot per enum constant known at class-initialisation time.
        static final /* synthetic */ int[] $SwitchMap$org$apache$pinot$common$request$context$ExpressionContext$Type = new int[ExpressionContext.Type.values().length];

        static {
            // Each assignment is guarded separately so that a constant missing from the runtime
            // enum (NoSuchFieldError) only leaves its own slot unset instead of failing class init.
            try {
                $SwitchMap$org$apache$pinot$common$request$context$ExpressionContext$Type[ExpressionContext.Type.LITERAL.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
                // Intentionally ignored: the slot stays at 0.
            }
            try {
                $SwitchMap$org$apache$pinot$common$request$context$ExpressionContext$Type[ExpressionContext.Type.IDENTIFIER.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
                // Intentionally ignored: the slot stays at 0.
            }
            try {
                $SwitchMap$org$apache$pinot$common$request$context$ExpressionContext$Type[ExpressionContext.Type.FUNCTION.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
                // Intentionally ignored: the slot stays at 0.
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:io/trino/plugin/pinot/query/PinotExpressionRewriter$Context.class */
    /**
     * Information threaded through a rewrite: the table being queried and its column handles.
     * Both values are handed to {@code PinotSqlFormatter.getColumnHandle} when an identifier
     * expression is rewritten.
     */
    public interface Context {
        /** Returns the schema/table name used when resolving identifiers. */
        SchemaTableName getSchemaTableName();

        /**
         * Returns the column handles available for identifier resolution.
         * NOTE(review): the String keys are presumably column names - confirm against the caller.
         */
        Map<String, ColumnHandle> getColumnHandles();
    }

    /* loaded from: input_file:io/trino/plugin/pinot/query/PinotExpressionRewriter$CountStarRewriteRule.class */
    /**
     * Rule for {@code COUNT} over the wildcard identifier. A matching function is returned
     * untouched, i.e. its arguments are not passed through {@code rewriteExpression}.
     */
    private static class CountStarRewriteRule implements RewriteRule<FunctionContext> {
        private CountStarRewriteRule() {
        }

        /**
         * Builds the pattern: an aggregation function of type COUNT whose single input is an
         * IDENTIFIER expression equal to {@code PinotPatterns.WILDCARD}.
         */
        @Override
        public Pattern<FunctionContext> getPattern() {
            return PinotPatterns.aggregationFunction()
                    .with(PinotPatterns.aggregationFunctionType().equalTo(AggregationFunctionType.COUNT))
                    .with(PinotPatterns.singleInput().matching(
                            PinotPatterns.expression()
                                    .with(PinotPatterns.expressionType().equalTo(ExpressionContext.Type.IDENTIFIER))
                                    .with(PinotPatterns.identifier().equalTo(PinotPatterns.WILDCARD))));
        }

        /** Returns the matched function as-is; neither the captures nor the context are consulted. */
        @Override
        public FunctionContext rewrite(FunctionContext functionContext, Captures captures, Context context) {
            return functionContext;
        }
    }

    /* loaded from: input_file:io/trino/plugin/pinot/query/PinotExpressionRewriter$DateTimeConvertRewriteRule.class */
    /**
     * Rule for the {@code DATETIMECONVERT} transform function, which takes exactly four arguments:
     * an identifier or function followed by three literals.
     */
    private static class DateTimeConvertRewriteRule implements RewriteRule<FunctionContext> {
        private DateTimeConvertRewriteRule() {
        }

        /** Matches any transform function whose type is DATETIMECONVERT. */
        @Override
        public Pattern<FunctionContext> getPattern() {
            return PinotPatterns.transformFunction()
                    .with(PinotPatterns.transformFunctionType().equalTo(TransformFunctionType.DATETIMECONVERT));
        }

        /**
         * Rewrites the first argument recursively and upper-cases the three literal arguments.
         *
         * @throws IllegalStateException if Pinot does not resolve the upper-cased literals to an
         *         {@link EpochToEpochTransformer}
         */
        @Override
        public FunctionContext rewrite(FunctionContext functionContext, Captures captures, Context context) {
            List<ExpressionContext> arguments = functionContext.getArguments();
            Verify.verify(arguments.size() == 4);
            PinotExpressionRewriter.verifyIsIdentifierOrFunction(arguments.get(0));
            PinotExpressionRewriter.verifyTailArgumentsAllLiteral(arguments);

            ExpressionContext rewrittenInput = PinotExpressionRewriter.rewriteExpression(arguments.get(0), context);
            String inputFormat = upperCasedLiteral(arguments.get(1));
            String outputFormat = upperCasedLiteral(arguments.get(2));
            String granularity = upperCasedLiteral(arguments.get(3));

            // Only conversions Pinot handles with an epoch-to-epoch transformer are accepted.
            Preconditions.checkState(
                    DateTimeTransformerFactory.getDateTimeTransformer(inputFormat, outputFormat, granularity) instanceof EpochToEpochTransformer,
                    "Unsupported date format: simple date format not supported");

            return new FunctionContext(
                    functionContext.getType(),
                    functionContext.getFunctionName(),
                    ImmutableList.of(rewrittenInput, stringLiteral(inputFormat), stringLiteral(outputFormat), stringLiteral(granularity)));
        }

        /** Returns the literal's value as an upper-cased (ENGLISH locale) string. */
        private static String upperCasedLiteral(ExpressionContext literalArgument) {
            return literalArgument.getLiteral().getValue().toString().toUpperCase(Locale.ENGLISH);
        }

        /** Wraps a string in a literal expression. */
        private static ExpressionContext stringLiteral(String value) {
            return ExpressionContext.forLiteralContext(Literal.stringValue(value));
        }
    }

    /* loaded from: input_file:io/trino/plugin/pinot/query/PinotExpressionRewriter$DateTruncRewriteRule.class */
    /**
     * Rule for the {@code DATETRUNC} transform function, which accepts two to five arguments.
     * Literal arguments are normalised; the second argument is rewritten recursively.
     */
    private static class DateTruncRewriteRule implements RewriteRule<FunctionContext> {
        private DateTruncRewriteRule() {
        }

        /** Matches any transform function whose type is DATETRUNC. */
        @Override
        public Pattern<FunctionContext> getPattern() {
            return PinotPatterns.transformFunction()
                    .with(PinotPatterns.transformFunctionType().equalTo(TransformFunctionType.DATETRUNC));
        }

        /**
         * Produces a new function with the same type and name and normalised arguments:
         * <ol>
         *   <li>literal, lower-cased;</li>
         *   <li>identifier or function, rewritten recursively;</li>
         *   <li>optional literal, normalised to a {@link TimeUnit} name;</li>
         *   <li>optional literal, passed through unchanged;</li>
         *   <li>optional literal, normalised to a {@link TimeUnit} name.</li>
         * </ol>
         *
         * @throws IllegalStateException if the argument count or an argument's kind is wrong
         * @throws IllegalArgumentException if the 3rd or 5th argument is not a {@link TimeUnit} name
         */
        @Override
        public FunctionContext rewrite(FunctionContext functionContext, Captures captures, Context context) {
            List<ExpressionContext> arguments = functionContext.getArguments();
            int argumentCount = arguments.size();
            Preconditions.checkState(argumentCount >= 2 && argumentCount <= 5, "Between two to five arguments are required, example: %s", DateTruncTransformFunction.EXAMPLE_INVOCATION);

            ImmutableList.Builder<ExpressionContext> rewritten = ImmutableList.builder();

            ExpressionContext first = arguments.get(0);
            Preconditions.checkState(first.getType() == ExpressionContext.Type.LITERAL, "First argument must be a literal");
            String loweredFirst = first.getLiteral().getValue().toString().toLowerCase(Locale.ENGLISH);
            rewritten.add(ExpressionContext.forLiteralContext(Literal.stringValue(loweredFirst)));

            ExpressionContext second = arguments.get(1);
            PinotExpressionRewriter.verifyIsIdentifierOrFunction(second);
            rewritten.add(PinotExpressionRewriter.rewriteExpression(second, context));

            // The optional arguments are positional, so each one implies the presence of the previous ones.
            if (argumentCount >= 3) {
                ExpressionContext third = arguments.get(2);
                Preconditions.checkState(third.getType() == ExpressionContext.Type.LITERAL, "Unexpected 3rd argument: '%s'", third);
                rewritten.add(timeUnitLiteral(third));
            }
            if (argumentCount >= 4) {
                ExpressionContext fourth = arguments.get(3);
                Preconditions.checkState(fourth.getType() == ExpressionContext.Type.LITERAL, "Unexpected 4th argument '%s'", fourth);
                rewritten.add(fourth);
            }
            if (argumentCount >= 5) {
                ExpressionContext fifth = arguments.get(4);
                Preconditions.checkState(fifth.getType() == ExpressionContext.Type.LITERAL, "Unexpected 5th argument: '%s'", fifth);
                rewritten.add(timeUnitLiteral(fifth));
            }

            return new FunctionContext(functionContext.getType(), functionContext.getFunctionName(), rewritten.build());
        }

        /** Upper-cases the literal, resolves it as a {@link TimeUnit} and returns a literal holding that unit's name. */
        private static ExpressionContext timeUnitLiteral(ExpressionContext literalArgument) {
            String unitName = literalArgument.getLiteral().getValue().toString().toUpperCase(Locale.ENGLISH);
            return ExpressionContext.forLiteralContext(Literal.stringValue(TimeUnit.valueOf(unitName).name()));
        }
    }

    /* loaded from: input_file:io/trino/plugin/pinot/query/PinotExpressionRewriter$DefaultRewriteRule.class */
    private static class DefaultRewriteRule implements RewriteRule<FunctionContext> {
        private DefaultRewriteRule() {
        }

        @Override // io.trino.plugin.pinot.query.PinotExpressionRewriter.RewriteRule
        public Pattern<FunctionContext> getPattern() {
            return PinotPatterns.function();
        }

        @Override // io.trino.plugin.pinot.query.PinotExpressionRewriter.RewriteRule
        public FunctionContext rewrite(FunctionContext functionContext, Captures captures, Context context) {
            return new FunctionContext(functionContext.getType(), functionContext.getFunctionName(), (List) functionContext.getArguments().stream().map(expressionContext -> {
                return PinotExpressionRewriter.rewriteExpression(expressionContext, context);
            }).collect(ImmutableList.toImmutableList()));
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:io/trino/plugin/pinot/query/PinotExpressionRewriter$RewriteRule.class */
    /**
     * A rewrite step: a pattern selecting the objects the rule applies to, plus the transformation
     * performed on a matched object.
     *
     * @param <T> type of object the rule matches and produces
     */
    public interface RewriteRule<T> {
        /** Returns the pattern an object must match before {@link #rewrite} is invoked on it. */
        Pattern<T> getPattern();

        /**
         * Rewrites a matched object.
         *
         * @param t the object that matched the pattern
         * @param captures the captures of the match
         * @param context table and column information for the rewrite
         * @return the rewritten object
         */
        T rewrite(T t, Captures captures, Context context);
    }

    /* loaded from: input_file:io/trino/plugin/pinot/query/PinotExpressionRewriter$TimeConvertRewriteRule.class */
    /**
     * Rule for the {@code TIMECONVERT} transform function, which takes exactly three arguments:
     * an identifier or function followed by two literals.
     */
    private static class TimeConvertRewriteRule implements RewriteRule<FunctionContext> {
        private TimeConvertRewriteRule() {
        }

        /** Matches any transform function whose type is TIMECONVERT. */
        @Override
        public Pattern<FunctionContext> getPattern() {
            return PinotPatterns.transformFunction()
                    .with(PinotPatterns.transformFunctionType().equalTo(TransformFunctionType.TIMECONVERT));
        }

        /**
         * Rewrites the first argument recursively and upper-cases both literal arguments.
         *
         * @throws IllegalArgumentException if the second argument is not a {@link TimeUnit} name
         */
        @Override
        public FunctionContext rewrite(FunctionContext functionContext, Captures captures, Context context) {
            List<ExpressionContext> arguments = functionContext.getArguments();
            Verify.verify(arguments.size() == 3);
            PinotExpressionRewriter.verifyIsIdentifierOrFunction(arguments.get(0));
            PinotExpressionRewriter.verifyTailArgumentsAllLiteral(arguments);

            ExpressionContext rewrittenInput = PinotExpressionRewriter.rewriteExpression(arguments.get(0), context);

            String inputUnitName = upperCasedLiteral(arguments.get(1));
            TimeUnit inputUnit = TimeUnit.valueOf(inputUnitName);
            String outputUnitName = upperCasedLiteral(arguments.get(2));
            // Called only for its validation side effect; the transformer itself is discarded.
            TimeUnitTransformerFactory.getTimeUnitTransformer(inputUnit, outputUnitName);

            return new FunctionContext(
                    functionContext.getType(),
                    functionContext.getFunctionName(),
                    ImmutableList.of(
                            rewrittenInput,
                            ExpressionContext.forLiteralContext(Literal.stringValue(inputUnitName)),
                            ExpressionContext.forLiteralContext(Literal.stringValue(outputUnitName))));
        }

        /** Returns the literal's value as an upper-cased (ENGLISH locale) string. */
        private static String upperCasedLiteral(ExpressionContext literalArgument) {
            return literalArgument.getLiteral().getValue().toString().toUpperCase(Locale.ENGLISH);
        }
    }

    /** Private so the class cannot be instantiated from outside; the rewrite entry points are static. */
    private PinotExpressionRewriter() {
    }

    /**
     * Entry point: rewrites a Pinot expression against the given table and column handles.
     *
     * @param schemaTableName table used when resolving identifiers
     * @param expressionContext expression to rewrite; must not be null
     * @param map column handles used when resolving identifiers
     * @return the rewritten expression
     * @throws NullPointerException if {@code expressionContext} is null
     */
    public static ExpressionContext rewriteExpression(final SchemaTableName schemaTableName, ExpressionContext expressionContext, final Map<String, ColumnHandle> map) {
        Objects.requireNonNull(expressionContext, "expressionContext is null");

        // Bundle the table name and handles so they can be threaded through the recursive rewrite.
        Context rewriteContext = new Context() {
            @Override
            public SchemaTableName getSchemaTableName() {
                return schemaTableName;
            }

            @Override
            public Map<String, ColumnHandle> getColumnHandles() {
                return map;
            }
        };
        return rewriteExpression(expressionContext, rewriteContext);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Rewrites one expression according to its type: literals are returned unchanged, identifiers
     * are replaced by the column name of the handle resolved through
     * {@code PinotSqlFormatter.getColumnHandle}, and functions are delegated to
     * {@code rewriteFunction}.
     *
     * @throws PinotException if the expression type is none of LITERAL, IDENTIFIER or FUNCTION
     */
    public static ExpressionContext rewriteExpression(ExpressionContext expressionContext, Context context) {
        ExpressionContext.Type expressionType = expressionContext.getType();
        switch (expressionType) {
            case LITERAL:
                return expressionContext;
            case IDENTIFIER:
                return ExpressionContext.forIdentifier(
                        PinotSqlFormatter.getColumnHandle(expressionContext.getIdentifier(), context.getSchemaTableName(), context.getColumnHandles())
                                .getColumnName());
            case FUNCTION:
                return ExpressionContext.forFunction(rewriteFunction(expressionContext.getFunction(), context));
            default:
                throw new PinotException(PinotErrorCode.PINOT_EXCEPTION, Optional.empty(), String.format("Unsupported expression type '%s'", expressionType));
        }
    }

    /**
     * Rewrites a function. The specialised rule registered for its transform or aggregation type
     * is tried first, if there is one; when it is absent or does not match, the default rule is
     * applied.
     *
     * @throws IllegalStateException if the function is neither a TRANSFORM nor an AGGREGATION
     * @throws PinotException if not even the default rule matches
     */
    private static FunctionContext rewriteFunction(FunctionContext functionContext, Context context) {
        // Step 1: look up the specialised rule (may be null when none is registered).
        RewriteRule<FunctionContext> specializedRule;
        if (functionContext.getType() == FunctionContext.Type.TRANSFORM) {
            specializedRule = FUNCTION_RULE_MAP.get(PinotTransformFunctionTypeResolver.getTransformFunctionType(functionContext).orElseThrow());
        } else {
            Preconditions.checkState(functionContext.getType() == FunctionContext.Type.AGGREGATION, "Unexpected function type for '%s'", functionContext);
            specializedRule = AGGREGATION_FUNCTION_RULE_MAP.get(AggregationFunctionType.getAggregationFunctionType(functionContext.getFunctionName()));
        }

        // Step 2: apply it, falling back to the default rule when it is missing or did not match.
        Optional<FunctionContext> rewritten = Optional.empty();
        if (specializedRule != null) {
            rewritten = applyRule(specializedRule, functionContext, context);
        }
        if (!rewritten.isPresent()) {
            rewritten = applyRule(DEFAULT_REWRITE_RULE, functionContext, context);
        }
        return rewritten.orElseThrow(() ->
                new PinotException(PinotErrorCode.PINOT_EXCEPTION, Optional.empty(), String.format("Unsupported function expression '%s'", functionContext)));
    }

    /**
     * Applies a rule to an object if the rule's pattern matches it.
     *
     * <p>Only the first match is used; its captures are handed to the rule's {@code rewrite}.
     *
     * @param rewriteRule rule to try
     * @param t object to match and rewrite
     * @param context table and column information passed on to the rule
     * @return the rewritten object, or empty when the pattern does not match
     */
    private static <T> Optional<T> applyRule(RewriteRule<T> rewriteRule, T t, Context context) {
        // The pattern yields Match objects, not T, so the iterator is typed accordingly
        // rather than declared as Iterator<T> and cast back on every access.
        Iterator<Match> matches = rewriteRule.getPattern().match(t).iterator();
        if (!matches.hasNext()) {
            return Optional.empty();
        }
        return Optional.of(rewriteRule.rewrite(t, matches.next().captures(), context));
    }

    /** Verifies that the expression is either an IDENTIFIER or a FUNCTION; fails verification otherwise. */
    private static void verifyIsIdentifierOrFunction(ExpressionContext expressionContext) {
        ExpressionContext.Type type = expressionContext.getType();
        boolean identifierOrFunction = type == ExpressionContext.Type.IDENTIFIER
                || type == ExpressionContext.Type.FUNCTION;
        Verify.verify(identifierOrFunction);
    }

    /** Verifies that every argument after the first is a LITERAL; the first argument is not inspected. */
    private static void verifyTailArgumentsAllLiteral(List<ExpressionContext> list) {
        // Start at index 1: an empty or single-element list has no tail and passes trivially.
        for (int index = 1; index < list.size(); index++) {
            Verify.verify(list.get(index).getType() == ExpressionContext.Type.LITERAL);
        }
    }

    // Registers the specialised rewrite rules and freezes both lookup tables as immutable enum maps.
    static {
        Map<TransformFunctionType, RewriteRule<FunctionContext>> transformRules = new HashMap<>();
        transformRules.put(TransformFunctionType.DATETIMECONVERT, new DateTimeConvertRewriteRule());
        transformRules.put(TransformFunctionType.TIMECONVERT, new TimeConvertRewriteRule());
        transformRules.put(TransformFunctionType.DATETRUNC, new DateTruncRewriteRule());
        FUNCTION_RULE_MAP = Maps.immutableEnumMap(transformRules);

        Map<AggregationFunctionType, RewriteRule<FunctionContext>> aggregationRules = new HashMap<>();
        aggregationRules.put(AggregationFunctionType.COUNT, new CountStarRewriteRule());
        AGGREGATION_FUNCTION_RULE_MAP = Maps.immutableEnumMap(aggregationRules);
    }
}
