/*
 * Decompiled with CFR 0.152.
 */
package com.linkedin.coral.trino.rel2trino;

import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import com.google.gson.JsonPrimitive;
import com.linkedin.coral.hive.hive2rel.functions.HiveReturnTypes;
import com.linkedin.coral.trino.rel2trino.UDFMapUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;

/**
 * Rewrites a call to a source (Calcite) operator into a call to a target operator, optionally
 * reshaping the operands, the result and the target operator itself according to JSON
 * transformer definitions.
 *
 * <p>A transformer expression is a JSON object of one of three shapes, checked in this order:
 * <ul>
 *   <li>{@code {"op": "<name>", "operands": [ ... ]}} - a call to one of the operators registered
 *       in {@code OP_MAP}; every JSON-object element of {@code operands} is itself a transformer
 *       expression (elements that are not JSON objects are skipped);</li>
 *   <li>{@code {"input": <index>}} - a reference into the list of available source expressions;</li>
 *   <li>{@code {"value": <primitive>}} - a string, boolean or numeric literal.</li>
 * </ul>
 *
 * <p>A target operator transformer is a JSON object with the members {@code "regex"},
 * {@code "input"} (1-based operand index) and {@code "name"}: when the string form of the
 * referenced operand matches the regex in full, the target operator is replaced by a UDF with
 * the given name.
 */
public class UDFTransformer {
    /** Operators that an {@code "op"} transformer node may reference, keyed by their JSON name. */
    private static final Map<String, SqlOperator> OP_MAP = new HashMap<>();

    static {
        OP_MAP.put("+", SqlStdOperatorTable.PLUS);
        OP_MAP.put("-", SqlStdOperatorTable.MINUS);
        OP_MAP.put("*", SqlStdOperatorTable.MULTIPLY);
        OP_MAP.put("/", SqlStdOperatorTable.DIVIDE);
        OP_MAP.put("^", SqlStdOperatorTable.POWER);
        OP_MAP.put("%", SqlStdOperatorTable.MOD);
        // The explicit cast on the operand type checker is kept from the original source so that
        // constructor overload resolution stays exactly as it was.
        OP_MAP.put("hive_pattern_to_trino", new SqlUserDefinedFunction(new SqlIdentifier("hive_pattern_to_trino", SqlParserPos.ZERO), HiveReturnTypes.STRING, null, (SqlOperandTypeChecker)OperandTypes.STRING, null, null));
    }

    /** JSON member naming the operator of an operator-call node. */
    public static final String OPERATOR = "op";
    /** JSON member holding the operand array of an operator-call node. */
    public static final String OPERANDS = "operands";
    /** JSON member holding the index of a referenced source expression. */
    public static final String INPUT = "input";
    /** JSON member holding a literal value. */
    public static final String VALUE = "value";
    /** JSON member holding the regular expression of a target operator transformer. */
    public static final String REGEX = "regex";
    /** JSON member holding the replacement function name of a target operator transformer. */
    public static final String NAME = "name";

    /** Name of the source operator; used in the error message of {@link #transformCall}. */
    public final String calciteOperatorName;
    /** Operator used for the rewritten call unless an operator transformer replaces it. */
    public final SqlOperator targetOperator;
    /** One transformer expression per operand of the rewritten call, or {@code null} to pass operands through. */
    public final List<JsonObject> operandTransformers;
    /** Transformer expression applied to the rewritten call, or {@code null} to return it as is. */
    public final JsonObject resultTransformer;
    /** Target operator transformers, tried in order, or {@code null} to always use {@link #targetOperator}. */
    public final List<JsonObject> operatorTransformers;

    private UDFTransformer(String calciteOperatorName, SqlOperator targetOperator, List<JsonObject> operandTransformers, JsonObject resultTransformer, List<JsonObject> operatorTransformers) {
        this.calciteOperatorName = calciteOperatorName;
        this.targetOperator = targetOperator;
        this.operandTransformers = operandTransformers;
        this.resultTransformer = resultTransformer;
        this.operatorTransformers = operatorTransformers;
    }

    /**
     * Creates a transformer from JSON strings.
     *
     * <p>No parsing error is caught here, so whatever the JSON parser throws for malformed input,
     * or for a top-level element of the wrong kind, propagates to the caller.
     *
     * @param calciteOperatorName name of the source operator
     * @param targetOperator operator to call in place of the source operator
     * @param operandTransformers JSON array of transformer expressions, one per target operand,
     *        or {@code null} to keep the source operands unchanged
     * @param resultTransformer JSON object holding a transformer expression for the result,
     *        or {@code null} to keep the result unchanged
     * @param operatorTransformers JSON array of target operator transformers,
     *        or {@code null} to always use {@code targetOperator}
     * @return the configured transformer
     */
    public static UDFTransformer of(@Nonnull String calciteOperatorName, @Nonnull SqlOperator targetOperator, @Nullable String operandTransformers, @Nullable String resultTransformer, @Nullable String operatorTransformers) {
        List<JsonObject> operands = operandTransformers == null ? null : UDFTransformer.parseJsonObjectsFromString(operandTransformers);
        JsonObject result = resultTransformer == null ? null : new JsonParser().parse(resultTransformer).getAsJsonObject();
        List<JsonObject> operators = operatorTransformers == null ? null : UDFTransformer.parseJsonObjectsFromString(operatorTransformers);
        return new UDFTransformer(calciteOperatorName, targetOperator, operands, result, operators);
    }

    /**
     * Builds the rewritten call for the given source operands: selects the target operator,
     * transforms the operands, creates the call and finally transforms the result.
     *
     * @param rexBuilder builder used to create calls and literals
     * @param sourceOperands operands of the original call
     * @return the rewritten expression
     * @throws IllegalArgumentException if the selected target operator is {@code null} or has an
     *         empty name, or if a transformer definition is invalid for these operands
     * @throws UnsupportedOperationException if a transformer references an operator that is not
     *         registered for transformations
     */
    public RexNode transformCall(RexBuilder rexBuilder, List<RexNode> sourceOperands) {
        SqlOperator newTargetOperator = this.transformTargetOperator(this.targetOperator, sourceOperands);
        if (newTargetOperator == null || newTargetOperator.getName().isEmpty()) {
            // String.valueOf keeps this error path from failing with a NullPointerException of
            // its own should an operand be null.
            String operands = sourceOperands.stream().map(String::valueOf).collect(Collectors.joining(","));
            throw new IllegalArgumentException(String.format("An equivalent Trino operator was not found for the function call: %s(%s)", this.calciteOperatorName, operands));
        }
        List<RexNode> newOperands = this.transformOperands(rexBuilder, sourceOperands);
        RexNode newCall = rexBuilder.makeCall(newTargetOperator, newOperands);
        return this.transformResult(rexBuilder, newCall, sourceOperands);
    }

    /**
     * Evaluates every operand transformer against the source operands. Returns the source
     * operands unchanged when no operand transformers are configured.
     */
    private List<RexNode> transformOperands(RexBuilder rexBuilder, List<RexNode> sourceOperands) {
        if (this.operandTransformers == null) {
            return sourceOperands;
        }
        // Slot 0 is a null placeholder so that {"input": n} addresses operands 1-based;
        // {"input": 0} is rejected by the null check in resolveInput.
        List<RexNode> sources = new ArrayList<>(sourceOperands.size() + 1);
        sources.add(null);
        sources.addAll(sourceOperands);
        List<RexNode> results = new ArrayList<>(this.operandTransformers.size());
        for (JsonObject operandTransformer : this.operandTransformers) {
            results.add(this.transformExpression(rexBuilder, operandTransformer, sources));
        }
        return results;
    }

    /**
     * Applies the result transformer, if any, to the rewritten call. Within the transformer,
     * {@code {"input": 0}} refers to the rewritten call and {@code {"input": n}} (n &gt;= 1) to
     * the n-th source operand.
     */
    private RexNode transformResult(RexBuilder rexBuilder, RexNode result, List<RexNode> sourceOperands) {
        if (this.resultTransformer == null) {
            return result;
        }
        List<RexNode> sources = new ArrayList<>(sourceOperands.size() + 1);
        sources.add(result);
        sources.addAll(sourceOperands);
        return this.transformExpression(rexBuilder, this.resultTransformer, sources);
    }

    /**
     * Evaluates one transformer expression.
     *
     * @param sourceOperands expressions addressable through {@code "input"}; index 0 is the
     *        reserved slot set up by the caller (placeholder or result), indexes 1..n are the
     *        operands of the original call
     */
    private RexNode transformExpression(RexBuilder rexBuilder, JsonObject transformer, List<RexNode> sourceOperands) {
        if (transformer.get(OPERATOR) != null) {
            return this.transformOperatorNode(rexBuilder, transformer, sourceOperands);
        }
        if (transformer.get(INPUT) != null) {
            return UDFTransformer.resolveInput(transformer, sourceOperands);
        }
        JsonElement value = transformer.get(VALUE);
        if (value == null) {
            throw new IllegalArgumentException("JSON node for transformation should be either op, input, or value");
        }
        return UDFTransformer.toLiteral(rexBuilder, value);
    }

    /** Evaluates an {@code {"op": ..., "operands": [...]}} node into a call on a registered operator. */
    private RexNode transformOperatorNode(RexBuilder rexBuilder, JsonObject transformer, List<RexNode> sourceOperands) {
        JsonElement operandsNode = transformer.get(OPERANDS);
        if (operandsNode == null || !operandsNode.isJsonArray()) {
            // Previously surfaced as a bare NullPointerException / ClassCastException.
            throw new IllegalArgumentException("JSON node for operator transformation must have an operands array: " + transformer);
        }
        List<RexNode> inputOperands = new ArrayList<>();
        for (JsonElement inputOperand : operandsNode.getAsJsonArray()) {
            // Elements that are not JSON objects are skipped, as in the original implementation.
            if (inputOperand.isJsonObject()) {
                inputOperands.add(this.transformExpression(rexBuilder, inputOperand.getAsJsonObject(), sourceOperands));
            }
        }
        // The operator is looked up only after the operands were evaluated, so an invalid operand
        // is reported before an unsupported operator.
        String operatorName = transformer.get(OPERATOR).getAsString();
        SqlOperator op = OP_MAP.get(operatorName);
        if (op == null) {
            throw new UnsupportedOperationException("Operator " + operatorName + " is not supported in transformation");
        }
        return rexBuilder.makeCall(op, inputOperands);
    }

    /** Resolves an {@code {"input": n}} node to the n-th entry of {@code sourceOperands}. */
    private static RexNode resolveInput(JsonObject transformer, List<RexNode> sourceOperands) {
        int index = transformer.get(INPUT).getAsInt();
        if (index < 0 || index >= sourceOperands.size() || sourceOperands.get(index) == null) {
            // Index 0 is the reserved slot, not a source operand, hence the "- 1".
            throw new IllegalArgumentException("Invalid input value: " + index + ". Number of source operands: " + (sourceOperands.size() - 1));
        }
        return sourceOperands.get(index);
    }

    /** Converts a {@code "value"} member into a string, boolean or BIGINT literal. */
    private static RexNode toLiteral(RexBuilder rexBuilder, JsonElement value) {
        if (!value.isJsonPrimitive()) {
            throw new IllegalArgumentException("Value should be of primitive type: " + value);
        }
        JsonPrimitive primitive = value.getAsJsonPrimitive();
        if (primitive.isString()) {
            return rexBuilder.makeLiteral(primitive.getAsString());
        }
        if (primitive.isBoolean()) {
            return rexBuilder.makeLiteral(primitive.getAsBoolean());
        }
        if (primitive.isNumber()) {
            // NOTE(review): every number, including a fractional one, becomes a BIGINT literal -
            // confirm that transformer definitions only ever use integral values.
            return rexBuilder.makeBigintLiteral(primitive.getAsBigDecimal());
        }
        throw new UnsupportedOperationException("Invalid JSON literal value: " + primitive);
    }

    /**
     * Selects the operator for the rewritten call. The operator transformers are validated and
     * tried in order; the first one whose regex matches the whole string form of its referenced
     * operand yields a UDF with the transformer's name and the return type inference of
     * {@code operator}. Without a match, or without operator transformers, {@code operator}
     * itself is returned.
     *
     * @throws IllegalArgumentException if a transformer lacks one of its members, references an
     *         operand outside {@code [1, sourceOperands.size()]} or has an empty name
     */
    private SqlOperator transformTargetOperator(SqlOperator operator, List<RexNode> sourceOperands) {
        if (this.operatorTransformers == null) {
            return operator;
        }
        for (JsonObject operatorTransformer : this.operatorTransformers) {
            boolean complete = operatorTransformer.has(REGEX) && operatorTransformer.has(INPUT) && operatorTransformer.has(NAME);
            if (!complete) {
                throw new IllegalArgumentException("JSON node for target operator transformer must have a matcher, input and name");
            }
            // "input" is 1-based in the JSON definition.
            int index = operatorTransformer.get(INPUT).getAsInt() - 1;
            if (index < 0 || index >= sourceOperands.size()) {
                throw new IllegalArgumentException(String.format("Index is not within the acceptable range [%d, %d]", 1, sourceOperands.size()));
            }
            String functionName = operatorTransformer.get(NAME).getAsString();
            if (functionName.isEmpty()) {
                throw new IllegalArgumentException("JSON node for transformation must have a non-empty name");
            }
            String matcher = operatorTransformer.get(REGEX).getAsString();
            if (Pattern.matches(matcher, sourceOperands.get(index).toString())) {
                return UDFMapUtils.createUDF(functionName, operator.getReturnTypeInference());
            }
        }
        return operator;
    }

    /** Parses a JSON array string into the list of its elements, each of which must be a JSON object. */
    private static List<JsonObject> parseJsonObjectsFromString(String s) {
        JsonArray transformerArray = new JsonParser().parse(s).getAsJsonArray();
        List<JsonObject> objects = new ArrayList<>(transformerArray.size());
        for (JsonElement object : transformerArray) {
            objects.add(object.getAsJsonObject());
        }
        return objects;
    }
}

