package ai.vespa.triton;

import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.triton.TritonOnnxClient;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * {@link OnnxEvaluator} that delegates evaluation of a named model to a Triton inference server
 * through a {@link TritonOnnxClient}.
 *
 * <p>The model's input/output metadata is fetched once at construction and cached. In explicit
 * model control mode this evaluator loads the model when constructed and unloads it on
 * {@link #close()}.
 */
class TritonOnnxEvaluator implements OnnxEvaluator {
    private final String modelName;
    private final TritonOnnxClient triton;
    private final TritonOnnxClient.ModelMetadata modelMetadata;
    private final boolean isExplicitControlMode;
    // Ensures the model is unloaded at most once, even if close() is called repeatedly.
    private final AtomicBoolean closed = new AtomicBoolean(false);

    /**
     * Creates an evaluator for a model served by Triton.
     *
     * @param tritonOnnxClient client used to communicate with the Triton server
     * @param modelName name of the model on the server
     * @param isExplicitControlMode if true, the model is loaded here and unloaded on {@link #close()}
     * @throws RuntimeException if loading the model fails; an exception from fetching the model
     *     metadata is propagated after first unloading the model (when it was loaded here)
     */
    public TritonOnnxEvaluator(TritonOnnxClient tritonOnnxClient, String modelName, boolean isExplicitControlMode) {
        this.modelName = modelName;
        this.triton = tritonOnnxClient;
        this.isExplicitControlMode = isExplicitControlMode;
        if (isExplicitControlMode) {
            try {
                this.triton.loadModel(modelName);
            } catch (TritonOnnxClient.TritonException e) {
                throw new RuntimeException("Failed to load model: " + modelName, e);
            }
        }
        try {
            this.modelMetadata = this.triton.getModelMetadata(modelName);
        } catch (RuntimeException | Error e) {
            // Construction is failing, so close() will never be called on this object:
            // release the model we just loaded instead of leaking it on the server.
            if (isExplicitControlMode) {
                try {
                    this.triton.unloadModel(modelName);
                } catch (RuntimeException unloadFailure) {
                    e.addSuppressed(unloadFailure);
                }
            }
            throw e;
        }
    }

    /**
     * Evaluates the model and returns a single named output.
     *
     * @param inputs input tensors keyed by input name
     * @param output name of the output to return
     * @return the requested output tensor
     * @throws RuntimeException wrapping the client's exception if evaluation fails
     */
    @Override
    public Tensor evaluate(Map<String, Tensor> inputs, String output) {
        try {
            return this.triton.evaluate(this.modelName, inputs, output);
        } catch (TritonOnnxClient.TritonException e) {
            throw new RuntimeException("Failed to evaluate model: " + this.modelName, e);
        }
    }

    /**
     * Evaluates the model and returns all outputs.
     *
     * @param inputs input tensors keyed by input name
     * @return output tensors as returned by the Triton client
     * @throws RuntimeException wrapping the client's exception if evaluation fails
     */
    @Override
    public Map<String, Tensor> evaluate(Map<String, Tensor> inputs) {
        try {
            return this.triton.evaluate(this.modelName, inputs);
        } catch (TritonOnnxClient.TritonException e) {
            throw new RuntimeException("Failed to evaluate model: " + this.modelName, e);
        }
    }

    /** Returns a new mutable map from each model input name to its id (the same name) and type. */
    @Override
    public Map<String, OnnxEvaluator.IdAndType> getInputs() {
        return toIdAndTypes(this.modelMetadata.inputs());
    }

    /** Returns a new mutable map from each model output name to its id (the same name) and type. */
    @Override
    public Map<String, OnnxEvaluator.IdAndType> getOutputs() {
        return toIdAndTypes(this.modelMetadata.outputs());
    }

    /** Returns the input tensor types from the cached model metadata. */
    @Override
    public Map<String, TensorType> getInputInfo() {
        return this.modelMetadata.inputs();
    }

    /** Returns the output tensor types from the cached model metadata. */
    @Override
    public Map<String, TensorType> getOutputInfo() {
        return this.modelMetadata.outputs();
    }

    /**
     * Unloads the model from the server if it was loaded by this evaluator (explicit control mode).
     * Only the first call has an effect, as recommended by {@link AutoCloseable#close()}.
     */
    @Override
    public void close() {
        if (this.isExplicitControlMode && this.closed.compareAndSet(false, true)) {
            this.triton.unloadModel(this.modelName);
        }
    }

    /** Wraps each (name, type) entry as an IdAndType whose id is the tensor name. */
    private static Map<String, OnnxEvaluator.IdAndType> toIdAndTypes(Map<String, TensorType> types) {
        Map<String, OnnxEvaluator.IdAndType> result = new HashMap<>();
        types.forEach((name, type) -> result.put(name, new OnnxEvaluator.IdAndType(name, type)));
        return result;
    }
}
