package com.yahoo.schema.expressiontransforms;

import com.yahoo.path.Path;
import com.yahoo.schema.OnnxModel;
import com.yahoo.schema.RankProfile;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import com.yahoo.vespa.model.ml.ConvertedModel;
import com.yahoo.vespa.model.ml.FeatureArguments;
import com.yahoo.vespa.model.ml.ModelName;

import java.util.List;
import java.util.regex.Pattern;

/* loaded from: input_file:com/yahoo/schema/expressiontransforms/OnnxModelTransformer.class */
public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTransformContext> {
    public ExpressionNode transform(ExpressionNode expressionNode, RankProfileTransformContext rankProfileTransformContext) {
        return expressionNode instanceof ReferenceNode ? transformFeature((ReferenceNode) expressionNode, rankProfileTransformContext) : expressionNode instanceof CompositeNode ? super.transformChildren((CompositeNode) expressionNode, rankProfileTransformContext) : expressionNode;
    }

    private ExpressionNode transformFeature(ReferenceNode referenceNode, RankProfileTransformContext rankProfileTransformContext) {
        if (rankProfileTransformContext.rankProfile() != null && rankProfileTransformContext.rankProfile().schema() != null) {
            return transformFeature(referenceNode, rankProfileTransformContext.rankProfile());
        }
        return referenceNode;
    }

    public static ExpressionNode transformFeature(ReferenceNode referenceNode, RankProfile rankProfile) {
        String name = referenceNode.getName();
        if (!name.equals("onnxModel") && !name.equals("onnx")) {
            return referenceNode;
        }
        Arguments arguments = referenceNode.getArguments();
        if (arguments.isEmpty()) {
            throw new IllegalArgumentException("An " + name + " feature must take an argument referring to a onnx-model config or an ONNX file.");
        }
        if (arguments.expressions().size() > 3) {
            throw new IllegalArgumentException("An " + name + " feature can have at most 3 arguments.");
        }
        String modelConfigName = getModelConfigName(referenceNode.reference());
        OnnxModel onnxModel = rankProfile.onnxModels().get(modelConfigName);
        if (onnxModel == null) {
            String asString = asString((ExpressionNode) arguments.expressions().get(0));
            return ConvertedModel.fromStore(rankProfile.schema().applicationPackage(), new ModelName(null, Path.fromString(asString), true), asString, rankProfile).expression(new FeatureArguments(arguments), null);
        }
        String modelOutput = getModelOutput(referenceNode.reference(), onnxModel.getOutputMap().get(onnxModel.getDefaultOutput()));
        if (onnxModel.getOutputMap().containsValue(modelOutput)) {
            return new ReferenceNode("onnx", List.of(new ReferenceNode(modelConfigName)), modelOutput);
        }
        throw new IllegalArgumentException(name + " argument '" + modelOutput + "' output not found in model '" + onnxModel.getFileName() + "'");
    }

    public static String getModelConfigName(Reference reference) {
        if (reference.arguments().size() <= 0) {
            return null;
        }
        ExpressionNode expressionNode = (ExpressionNode) reference.arguments().expressions().get(0);
        if (expressionNode instanceof ReferenceNode) {
            return expressionNode.toString();
        }
        if (expressionNode instanceof ConstantNode) {
            return asValidIdentifier(expressionNode);
        }
        return null;
    }

    public static String getModelOutput(Reference reference, String str) {
        return reference.output() != null ? reference.output() : reference.arguments().expressions().size() == 2 ? asValidIdentifier((ExpressionNode) reference.arguments().expressions().get(1)) : reference.arguments().expressions().size() > 2 ? asValidIdentifier((ExpressionNode) reference.arguments().expressions().get(2)) : str;
    }

    public static String stripQuotes(String str) {
        if (isNotQuoteSign(str.codePointAt(0))) {
            return str;
        }
        if (isNotQuoteSign(str.codePointAt(str.length() - 1))) {
            throw new IllegalArgumentException("argument [" + str + "] is missing end quote");
        }
        return str.substring(1, str.length() - 1);
    }

    public static String asValidIdentifier(String str) {
        return str.replaceAll("[^\\w\\d\\$@_]", "_");
    }

    private static String asValidIdentifier(ExpressionNode expressionNode) {
        return asValidIdentifier(asString(expressionNode));
    }

    private static boolean isNotQuoteSign(int i) {
        return (i == 39 || i == 34) ? false : true;
    }

    public static String asString(ExpressionNode expressionNode) {
        if (expressionNode instanceof ConstantNode) {
            return stripQuotes(expressionNode.toString());
        }
        throw new IllegalArgumentException("Expected a constant string as argument, but got '" + String.valueOf(expressionNode));
    }
}
