/*
 * Decompiled with CFR 0.152.
 */
package com.linkedin.coral.trino.rel2trino;

import com.linkedin.coral.com.google.common.collect.ImmutableList;
import com.linkedin.coral.com.google.common.collect.ImmutableMultimap;
import com.linkedin.coral.com.google.common.collect.Multimap;
import com.linkedin.coral.hive.hive2rel.functions.GenericProjectFunction;
import com.linkedin.coral.trino.rel2trino.CalciteTrinoUDFMap;
import com.linkedin.coral.trino.rel2trino.TrinoTryCastFunction;
import com.linkedin.coral.trino.rel2trino.UDFMapUtils;
import com.linkedin.coral.trino.rel2trino.UDFTransformer;
import com.linkedin.coral.trino.rel2trino.functions.GenericProjectToTrinoConverter;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelShuttle;
import org.apache.calcite.rel.RelShuttleImpl;
import org.apache.calcite.rel.core.TableFunctionScan;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalCorrelate;
import org.apache.calcite.rel.logical.LogicalExchange;
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.logical.LogicalIntersect;
import org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.calcite.rel.logical.LogicalMatch;
import org.apache.calcite.rel.logical.LogicalMinus;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.logical.LogicalSort;
import org.apache.calcite.rel.logical.LogicalUnion;
import org.apache.calcite.rel.logical.LogicalValues;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlMapValueConstructor;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;

/**
 * Converts the row expressions inside a Calcite relational plan into forms that Trino understands.
 *
 * <p>The plan is walked with a {@link RelShuttleImpl}; every visited node has its expressions passed
 * through a {@link TrinoRexConverter}, which rewrites generic projections, map constructors,
 * {@code from_utc_timestamp}, functions registered in {@link CalciteTrinoUDFMap}, and certain
 * mixed-type equality comparisons.
 */
public class Calcite2TrinoUDFConverter {
    private Calcite2TrinoUDFConverter() {
        // static utility; not instantiable
    }

    /**
     * Rewrites the row expressions of {@code calciteNode} and of every node reachable from it.
     *
     * <p>Each override first delegates to the {@link RelShuttleImpl} implementation for the node and
     * then applies a {@link TrinoRexConverter}, built from the original node's cluster, to the result.
     *
     * @param calciteNode root of the Calcite plan to convert
     * @return the converted plan
     */
    public static RelNode convertRel(RelNode calciteNode) {
        RelShuttleImpl converter = new RelShuttleImpl() {

            @Override
            public RelNode visit(LogicalProject project) {
                return convertExpressions(super.visit(project), project);
            }

            @Override
            public RelNode visit(LogicalFilter inputFilter) {
                return convertExpressions(super.visit(inputFilter), inputFilter);
            }

            @Override
            public RelNode visit(LogicalAggregate aggregate) {
                return convertExpressions(super.visit(aggregate), aggregate);
            }

            @Override
            public RelNode visit(LogicalMatch match) {
                return convertExpressions(super.visit(match), match);
            }

            @Override
            public RelNode visit(TableScan scan) {
                return convertExpressions(super.visit(scan), scan);
            }

            @Override
            public RelNode visit(TableFunctionScan scan) {
                return convertExpressions(super.visit(scan), scan);
            }

            @Override
            public RelNode visit(LogicalValues values) {
                return convertExpressions(super.visit(values), values);
            }

            @Override
            public RelNode visit(LogicalJoin join) {
                return convertExpressions(super.visit(join), join);
            }

            @Override
            public RelNode visit(LogicalCorrelate correlate) {
                return convertExpressions(super.visit(correlate), correlate);
            }

            @Override
            public RelNode visit(LogicalUnion union) {
                return convertExpressions(super.visit(union), union);
            }

            @Override
            public RelNode visit(LogicalIntersect intersect) {
                return convertExpressions(super.visit(intersect), intersect);
            }

            @Override
            public RelNode visit(LogicalMinus minus) {
                return convertExpressions(super.visit(minus), minus);
            }

            @Override
            public RelNode visit(LogicalSort sort) {
                return convertExpressions(super.visit(sort), sort);
            }

            @Override
            public RelNode visit(LogicalExchange exchange) {
                return convertExpressions(super.visit(exchange), exchange);
            }

            @Override
            public RelNode visit(RelNode other) {
                return convertExpressions(super.visit(other), other);
            }

            /**
             * Applies a {@link TrinoRexConverter} built from {@code original}'s cluster to {@code visited}
             * (the node returned by the superclass visit).
             */
            private RelNode convertExpressions(RelNode visited, RelNode original) {
                TrinoRexConverter rexConverter = new TrinoRexConverter(
                    original.getCluster().getRexBuilder(), original.getCluster().getTypeFactory());
                return visited.accept(rexConverter);
            }
        };
        return calciteNode.accept(converter);
    }

    /**
     * {@link RexShuttle} that rewrites individual Calcite calls into Trino-compatible expressions.
     */
    public static class TrinoRexConverter
    extends RexShuttle {
        /**
         * (left operand family, right operand family) pairs for which an {@code =} comparison is
         * rewritten to {@code TRY_CAST} the left operand to the right operand's type.
         */
        private static final Multimap<SqlTypeFamily, SqlTypeFamily> SUPPORTED_TYPE_CAST_MAP =
            ImmutableMultimap.<SqlTypeFamily, SqlTypeFamily>builder()
                .putAll(SqlTypeFamily.CHARACTER, SqlTypeFamily.NUMERIC, SqlTypeFamily.BOOLEAN)
                .build();

        private final RexBuilder rexBuilder;
        private final RelDataTypeFactory typeFactory;

        /**
         * @param rexBuilder builder used to create the rewritten expressions
         * @param typeFactory factory used to create the SQL types referenced by casts
         */
        public TrinoRexConverter(RexBuilder rexBuilder, RelDataTypeFactory typeFactory) {
            this.rexBuilder = rexBuilder;
            this.typeFactory = typeFactory;
        }

        /**
         * Rewrites {@code call}, trying in order: generic projection, map constructor,
         * {@code from_utc_timestamp}, a transformer registered in {@link CalciteTrinoUDFMap}, and finally
         * the mixed-type equality adjustment followed by the default operand visit.
         */
        @Override
        public RexNode visitCall(RexCall call) {
            SqlOperator operator = call.getOperator();

            // Generic projections need a structural rewrite, which is delegated wholesale.
            if (operator instanceof GenericProjectFunction) {
                return GenericProjectToTrinoConverter.convertGenericProject(rexBuilder, call);
            }

            // MAP(k1, v1, k2, v2, ...) becomes MAP(ARRAY[k1, k2, ...], ARRAY[v1, v2, ...]).
            if (operator instanceof SqlMapValueConstructor) {
                return convertMapValueConstructor(rexBuilder, call);
            }

            String operatorName = operator.getName();
            if ("from_utc_timestamp".equals(operatorName)) {
                Optional<RexNode> rewritten = visitFromUtcTimestampCall(call);
                if (rewritten.isPresent()) {
                    return rewritten.get();
                }
                // Unhandled input shape: fall through to the generic handling below.
            }

            UDFTransformer transformer = CalciteTrinoUDFMap.getUDFTransformer(operatorName, call.getOperands().size());
            if (transformer != null) {
                RexNode transformed = transformer.transformCall(rexBuilder, call.getOperands());
                if (transformed instanceof RexCall) {
                    return super.visitCall((RexCall) transformed);
                }
                // A non-call result is visited generically instead of failing on a blind downcast.
                return transformed.accept(this);
            }

            return super.visitCall(adjustInconsistentTypesToEqualityOperator(call));
        }

        /**
         * Rewrites {@code from_utc_timestamp(value, timezone)} using Trino date/time functions.
         *
         * @return the rewritten expression, or empty when the call has fewer than two operands or the
         *     first operand's SQL type is not an integral, fractional, TIMESTAMP or DATE type
         */
        private Optional<RexNode> visitFromUtcTimestampCall(RexCall call) {
            List<RexNode> operands = call.getOperands();
            if (operands.size() < 2) {
                return Optional.empty();
            }
            SqlTypeName inputTypeName = operands.get(0).getType().getSqlTypeName();
            if (inputTypeName == null) {
                return Optional.empty();
            }
            // Decide support before visiting operands so unsupported calls do no wasted work.
            switch (inputTypeName) {
                case BIGINT:
                case INTEGER:
                case SMALLINT:
                case TINYINT:
                    return Optional.of(fromUtcTimestampForIntegral(call));
                case DOUBLE:
                case FLOAT:
                case DECIMAL:
                    return Optional.of(fromUtcTimestampForFractional(call));
                case TIMESTAMP:
                case DATE:
                    return Optional.of(fromUtcTimestampForDateTime(call));
                default:
                    return Optional.empty();
            }
        }

        /**
         * Integral input: CAST to BIGINT, multiply by 1,000,000 and pass to {@code from_unixtime_nanos}
         * (so the value is treated as milliseconds).
         */
        private RexNode fromUtcTimestampForIntegral(RexCall call) {
            List<RexNode> operands = visitList(call.getOperands(), null);
            RexNode asBigint = rexBuilder.makeCast(typeFactory.createSqlType(SqlTypeName.BIGINT), operands.get(0));
            RexNode scaled = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, asBigint,
                rexBuilder.makeBigintLiteral(BigDecimal.valueOf(1000000L)));
            RexNode utcTimestamp = rexBuilder.makeCall(createTimestampUdf("from_unixtime_nanos"), scaled);
            return castAtHiveTimezone(utcTimestamp, operands.get(1));
        }

        /** Fractional input: CAST to DOUBLE and pass to {@code from_unixtime}. */
        private RexNode fromUtcTimestampForFractional(RexCall call) {
            List<RexNode> operands = visitList(call.getOperands(), null);
            RexNode asDouble = rexBuilder.makeCast(typeFactory.createSqlType(SqlTypeName.DOUBLE), operands.get(0));
            RexNode utcTimestamp = rexBuilder.makeCall(createTimestampUdf("from_unixtime"), asDouble);
            return castAtHiveTimezone(utcTimestamp, operands.get(1));
        }

        /** Date/time input: {@code from_unixtime(to_unixtime(with_timezone(value, 'UTC')))}. */
        private RexNode fromUtcTimestampForDateTime(RexCall call) {
            List<RexNode> operands = visitList(call.getOperands(), null);
            SqlOperator toUnixTime = UDFMapUtils.createUDF("to_unixtime", ReturnTypes.explicit(SqlTypeName.DOUBLE));
            RexNode inUtc = rexBuilder.makeCall(createTimestampUdf("with_timezone"), operands.get(0),
                rexBuilder.makeLiteral("UTC"));
            RexNode unixTime = rexBuilder.makeCall(toUnixTime, inUtc);
            RexNode utcTimestamp = rexBuilder.makeCall(createTimestampUdf("from_unixtime"), unixTime);
            return castAtHiveTimezone(utcTimestamp, operands.get(1));
        }

        /**
         * Shared tail of every from_utc_timestamp rewrite:
         * {@code CAST(at_timezone(utcTimestamp, $canonicalize_hive_timezone_id(timezone)) AS TIMESTAMP(3))}.
         */
        private RexNode castAtHiveTimezone(RexNode utcTimestamp, RexNode timezone) {
            SqlOperator canonicalizeTimezone =
                UDFMapUtils.createUDF("$canonicalize_hive_timezone_id", ReturnTypes.explicit(SqlTypeName.VARCHAR));
            RexNode localized = rexBuilder.makeCall(createTimestampUdf("at_timezone"), utcTimestamp,
                rexBuilder.makeCall(canonicalizeTimezone, timezone));
            return rexBuilder.makeCast(typeFactory.createSqlType(SqlTypeName.TIMESTAMP, 3), localized);
        }

        /** Creates a UDF operator named {@code name} whose explicit return type is TIMESTAMP. */
        private static SqlOperator createTimestampUdf(String name) {
            return UDFMapUtils.createUDF(name, ReturnTypes.explicit(SqlTypeName.TIMESTAMP));
        }

        /**
         * For an {@code =} call whose (left, right) operand type families appear in
         * {@link #SUPPORTED_TYPE_CAST_MAP}, replaces the left operand with
         * {@code TRY_CAST(left AS <right type>)}. A CAST directly on the left operand is unwrapped first.
         *
         * @return the adjusted call, or {@code call} unchanged when no adjustment applies
         */
        private RexCall adjustInconsistentTypesToEqualityOperator(RexCall call) {
            SqlOperator op = call.getOperator();
            if (op.getKind() != SqlKind.EQUALS) {
                return call;
            }
            RexNode leftOperand = call.getOperands().get(0);
            RexNode rightOperand = call.getOperands().get(1);
            if (leftOperand.getKind() == SqlKind.CAST) {
                leftOperand = ((RexCall) leftOperand).getOperands().get(0);
            }
            SqlTypeFamily leftFamily = typeFamilyOf(leftOperand);
            SqlTypeFamily rightFamily = typeFamilyOf(rightOperand);
            if (leftFamily == null || rightFamily == null
                || !SUPPORTED_TYPE_CAST_MAP.containsEntry(leftFamily, rightFamily)) {
                return call;
            }
            RexNode tryCastNode = rexBuilder.makeCall(rightOperand.getType(), TrinoTryCastFunction.INSTANCE,
                ImmutableList.of(leftOperand));
            return (RexCall) rexBuilder.makeCall(op, tryCastNode, rightOperand);
        }

        /** Returns the type family of {@code node}'s SQL type, or null when either is unavailable. */
        private static SqlTypeFamily typeFamilyOf(RexNode node) {
            SqlTypeName typeName = node.getType().getSqlTypeName();
            return typeName == null ? null : typeName.getFamily();
        }

        /**
         * Rewrites {@code MAP(k1, v1, k2, v2, ...)} as {@code MAP(ARRAY[k1, k2, ...], ARRAY[v1, v2, ...])}.
         * Operands are visited first; even positions become keys and odd positions become values.
         */
        private RexNode convertMapValueConstructor(RexBuilder rexBuilder, RexCall call) {
            List<RexNode> sourceOperands = visitList(call.getOperands(), null);
            List<RexNode> keys = new ArrayList<>();
            List<RexNode> values = new ArrayList<>();
            for (int i = 0; i < sourceOperands.size(); i++) {
                (i % 2 == 0 ? keys : values).add(sourceOperands.get(i));
            }
            List<RexNode> results = new ArrayList<>(2);
            results.add(rexBuilder.makeCall(SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR, keys));
            results.add(rexBuilder.makeCall(SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR, values));
            return rexBuilder.makeCall(call.getOperator(), results);
        }
    }
}

