package ai.vespa.models.evaluation;

import ai.vespa.models.evaluation.OnnxModel;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.TypeContext;

import java.util.ArrayList;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/* loaded from: input_file:ai/vespa/models/evaluation/OnnxExpressionNode.class */
class OnnxExpressionNode extends CompositeNode {
    private final OnnxModel model;
    private final String onnxOutputName;
    private final TensorType expectedType;
    private final String outputAs;
    private final List<String> modelInputs = new ArrayList();
    private final List<ExpressionNode> inputRefs = new ArrayList();

    /* JADX INFO: Access modifiers changed from: package-private */
    public OnnxExpressionNode(OnnxModel onnxModel, String str, TensorType tensorType, String str2) {
        this.model = onnxModel;
        this.onnxOutputName = str;
        this.expectedType = tensorType;
        this.outputAs = str2;
        for (OnnxModel.InputSpec inputSpec : onnxModel.inputSpecs) {
            this.modelInputs.add(inputSpec.onnxName);
            Optional<Reference> parseOnnxInput = parseOnnxInput(inputSpec.source);
            if (parseOnnxInput.isEmpty()) {
                throw new IllegalArgumentException("Bad input source for ONNX model " + onnxModel.name() + ": '" + inputSpec + "'");
            }
            this.inputRefs.add(new ReferenceNode(parseOnnxInput.get()));
        }
    }

    static Optional<Reference> parseOnnxInput(String str) {
        Optional<Reference> simple = Reference.simple(str);
        if (simple.isPresent()) {
            return simple;
        }
        try {
            return Optional.of(Reference.fromIdentifier(str));
        } catch (Exception e) {
            return Optional.empty();
        }
    }

    public List<ExpressionNode> children() {
        return List.copyOf(this.inputRefs);
    }

    public CompositeNode setChildren(List<ExpressionNode> list) {
        if (this.inputRefs.size() != list.size()) {
            throw new IllegalArgumentException("bad setChildren");
        }
        this.inputRefs.clear();
        this.inputRefs.addAll(list);
        return this;
    }

    public Value evaluate(Context context) {
        HashMap hashMap = new HashMap();
        for (int i = 0; i < this.modelInputs.size(); i++) {
            hashMap.put(this.modelInputs.get(i), this.inputRefs.get(i).evaluate(context).asTensor());
        }
        return new TensorValue(this.model.unmappedEvaluate(hashMap, this.onnxOutputName));
    }

    public TensorType type(TypeContext<Reference> typeContext) {
        return this.expectedType;
    }

    public int hashCode() {
        return Objects.hash("OnnxExpressionNode", this.model.name(), this.onnxOutputName);
    }

    public StringBuilder toString(StringBuilder sb, SerializationContext serializationContext, Deque<String> deque, CompositeNode compositeNode) {
        sb.append("onnx_expression_node(").append(this.model.name()).append(")");
        if (this.outputAs != null && !this.outputAs.equals("")) {
            sb.append(".").append(this.outputAs);
        }
        return sb;
    }
}
