package ai.vespa.models.evaluation;

import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:ai/vespa/models/evaluation/OnnxModel.class */
public class OnnxModel implements AutoCloseable {
    private final String name;
    private final File modelFile;
    private final OnnxEvaluatorOptions options;
    private final OnnxRuntime onnx;
    /** Non-null only between a successful {@link #load()} and {@link #close()}. */
    private OnnxEvaluator evaluator;
    final List<InputSpec> inputSpecs = new ArrayList<>();
    final List<OutputSpec> outputSpecs = new ArrayList<>();
    private final Map<String, ExpressionNode> exprPerOutput = new HashMap<>();

    /**
     * Maps an input value, supplied under the name {@code source}, onto the ONNX model input
     * named {@code onnxName}.
     */
    public static class InputSpec {
        String onnxName;
        String source;
        /** Type of the ONNX input; {@code null} until filled in by {@link OnnxModel#load()}. */
        TensorType wantedType;

        InputSpec(String onnxName, String source, TensorType wantedType) {
            this.onnxName = onnxName;
            this.source = source;
            this.wantedType = wantedType;
        }

        InputSpec(String onnxName, String source) {
            this(onnxName, source, null);
        }
    }

    /** Exposes the ONNX model output named {@code onnxName} under the name {@code outputAs}. */
    public static class OutputSpec {
        String onnxName;
        String outputAs;
        /** Type of the ONNX output; {@code null} until filled in by {@link OnnxModel#load()}. */
        TensorType expectedType;

        OutputSpec(String onnxName, String outputAs, TensorType expectedType) {
            this.onnxName = onnxName;
            this.outputAs = outputAs;
            this.expectedType = expectedType;
        }

        OutputSpec(String onnxName, String outputAs) {
            this(onnxName, outputAs, null);
        }
    }

    /**
     * Registers an explicit input mapping.
     *
     * @param onnxName name of the input in the ONNX model
     * @param source   name under which the input value is supplied to {@link #evaluate}
     * @throws IllegalStateException if the model is already loaded
     */
    public void addInputMapping(String onnxName, String source) {
        if (evaluator != null) {
            throw new IllegalStateException("input mapping must be added before load()");
        }
        inputSpecs.add(new InputSpec(onnxName, source));
    }

    /**
     * Registers an explicit output mapping.
     *
     * @param onnxName name of the output in the ONNX model
     * @param outputAs name under which the output is requested from {@link #evaluate}
     * @throws IllegalStateException if the model is already loaded
     */
    public void addOutputMapping(String onnxName, String outputAs) {
        if (evaluator != null) {
            throw new IllegalStateException("output mapping must be added before load()");
        }
        outputSpecs.add(new OutputSpec(onnxName, outputAs));
    }

    /**
     * Creates an unloaded model; no file is read until {@link #load()} is called.
     *
     * @param name      name used to identify the model in error messages
     * @param modelFile the ONNX model file
     * @param options   options passed to the runtime when the evaluator is created
     * @param onnx      runtime used to create the evaluator
     */
    public OnnxModel(String name, File modelFile, OnnxEvaluatorOptions options, OnnxRuntime onnx) {
        this.name = name;
        this.modelFile = modelFile;
        this.options = options;
        this.onnx = onnx;
    }

    public String name() {
        return name;
    }

    /**
     * Creates the evaluator for the model file. It then resolves input types, output types and
     * per-output expressions against that evaluator. Does nothing if the model is already loaded.
     *
     * <p>If resolving fails, the model is left unloaded rather than half-loaded:
     * <ul>
     *   <li>the freshly created evaluator is closed;</li>
     *   <li>any mappings derived automatically during this call are discarded.</li>
     * </ul>
     * This means mappings can still be added and {@code load()} retried.
     *
     * @throws IllegalArgumentException if the configured mappings do not match the model
     */
    public void load() {
        if (evaluator != null) {
            return;
        }
        // Remember how many specs were configured so auto-derived ones can be rolled back.
        int configuredInputs = inputSpecs.size();
        int configuredOutputs = outputSpecs.size();
        evaluator = onnx.evaluatorOf(modelFile.getPath(), options);
        try {
            fillInputTypes(evaluator.getInputs());
            fillOutputTypes(evaluator.getOutputs());
            fillOutputExpressions();
        } catch (RuntimeException e) {
            // Leaving the evaluator assigned would make later load() calls no-ops. It would also
            // block addInputMapping/addOutputMapping and keep the evaluator alive until close().
            OnnxEvaluator failed = evaluator;
            evaluator = null;
            inputSpecs.subList(configuredInputs, inputSpecs.size()).clear();
            outputSpecs.subList(configuredOutputs, outputSpecs.size()).clear();
            try {
                failed.close();
            } catch (RuntimeException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
    }

    /**
     * Fills in input types from the model's actual inputs.
     *
     * <p>If no input mappings are configured, one spec is created per model input. The map key
     * becomes the ONNX name, and the entry's id becomes the source name. Otherwise each configured
     * spec must name an existing model input, and the number of configured inputs must match.
     */
    void fillInputTypes(Map<String, OnnxEvaluator.IdAndType> modelInputs) {
        if (inputSpecs.isEmpty()) {
            for (Map.Entry<String, OnnxEvaluator.IdAndType> entry : modelInputs.entrySet()) {
                inputSpecs.add(new InputSpec(entry.getKey(), entry.getValue().id(), entry.getValue().type()));
            }
            return;
        }
        if (modelInputs.size() != inputSpecs.size()) {
            throw new IllegalArgumentException("Onnx model " + name() + ": Mismatch between " + inputSpecs.size() + " configured inputs and " + modelInputs.size() + " actual model inputs");
        }
        for (InputSpec spec : inputSpecs) {
            OnnxEvaluator.IdAndType actual = modelInputs.get(spec.onnxName);
            if (actual == null) {
                throw new IllegalArgumentException("Onnx model " + name() + ": No type in actual model for configured input " + spec.onnxName);
            }
            spec.wantedType = actual.type();
        }
    }

    /**
     * Fills in output types from the model's actual outputs.
     *
     * <p>Works the same way as {@link #fillInputTypes}. When specs are created automatically, the
     * entry's id becomes the {@code outputAs} name.
     */
    void fillOutputTypes(Map<String, OnnxEvaluator.IdAndType> modelOutputs) {
        if (outputSpecs.isEmpty()) {
            for (Map.Entry<String, OnnxEvaluator.IdAndType> entry : modelOutputs.entrySet()) {
                outputSpecs.add(new OutputSpec(entry.getKey(), entry.getValue().id(), entry.getValue().type()));
            }
            return;
        }
        if (modelOutputs.size() != outputSpecs.size()) {
            throw new IllegalArgumentException("Onnx model " + name() + ": Mismatch between " + outputSpecs.size() + " configured outputs and " + modelOutputs.size() + " actual model outputs");
        }
        for (OutputSpec spec : outputSpecs) {
            OnnxEvaluator.IdAndType actual = modelOutputs.get(spec.onnxName);
            if (actual == null) {
                throw new IllegalArgumentException("Onnx model " + name() + ": No type in actual model for configured output " + spec.onnxName);
            }
            spec.expectedType = actual.type();
        }
    }

    /**
     * Returns a new mutable map from input source name to input type.
     * Types are {@code null} for explicitly added mappings until the model is loaded.
     */
    public Map<String, TensorType> inputs() {
        Map<String, TensorType> types = new HashMap<>();
        for (InputSpec spec : inputSpecs) {
            types.put(spec.source, spec.wantedType);
        }
        return types;
    }

    /**
     * Returns a new mutable map from output name ({@code outputAs}) to output type.
     * Types are {@code null} for explicitly added mappings until the model is loaded.
     */
    public Map<String, TensorType> outputs() {
        Map<String, TensorType> types = new HashMap<>();
        for (OutputSpec spec : outputSpecs) {
            types.put(spec.outputAs, spec.expectedType);
        }
        return types;
    }

    /** Builds one expression node per output spec, keyed by its {@code outputAs} name. */
    void fillOutputExpressions() {
        for (OutputSpec spec : outputSpecs) {
            exprPerOutput.put(spec.outputAs, new OnnxExpressionNode(this, spec.onnxName, spec.expectedType, spec.outputAs));
        }
    }

    /**
     * Returns the expression for the named output, or {@code null} if there is none. A
     * {@code null} name selects the only output when the model has exactly one.
     */
    public ExpressionNode getExpressionForOutput(String outputName) {
        if (outputName == null && exprPerOutput.size() == 1) {
            return exprPerOutput.values().iterator().next();
        }
        return exprPerOutput.get(outputName);
    }

    /**
     * Evaluates the model using mapped names.
     *
     * @param inputs values keyed by input source name; every configured input must be present
     * @param output the {@code outputAs} name of the wanted output
     * @return the computed output tensor
     * @throws IllegalArgumentException if an input is missing or no output is mapped as {@code output}
     * @throws IllegalStateException    if the model is not loaded
     */
    public Tensor evaluate(Map<String, Tensor> inputs, String output) {
        Map<String, Tensor> onnxInputs = new HashMap<>();
        for (InputSpec spec : inputSpecs) {
            Tensor value = inputs.get(spec.source);
            if (value == null) {
                throw new IllegalArgumentException("evaluate ONNX model " + name() + ": missing input from source " + spec.source);
            }
            onnxInputs.put(spec.onnxName, value);
        }
        // The last matching spec wins, consistent with how exprPerOutput is filled.
        String onnxOutput = null;
        for (OutputSpec spec : outputSpecs) {
            if (spec.outputAs.equals(output)) {
                onnxOutput = spec.onnxName;
            }
        }
        if (onnxOutput == null) {
            throw new IllegalArgumentException("evaluate ONNX model " + name() + ": no output available as: " + output);
        }
        return unmappedEvaluate(onnxInputs, onnxOutput);
    }

    /** Evaluates the model directly with ONNX input and output names, without any mapping. */
    public Tensor unmappedEvaluate(Map<String, Tensor> onnxInputs, String onnxOutput) {
        return evaluator().evaluate(onnxInputs, onnxOutput);
    }

    private OnnxEvaluator evaluator() {
        if (evaluator == null) {
            throw new IllegalStateException("ONNX model has not been loaded.");
        }
        return evaluator;
    }

    /** Closes the evaluator if loaded; the model may be loaded again afterwards. */
    @Override
    public void close() {
        if (evaluator != null) {
            evaluator.close();
            evaluator = null;
        }
    }
}
