package ai.vespa.triton;

import ai.vespa.llm.clients.TritonConfig;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.vespa.defaults.Defaults;
import inference.ModelConfigOuterClass;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.nio.file.attribute.FileAttribute;

/**
 * {@link OnnxRuntime} implementation that hands ONNX model evaluation to a Triton Inference Server
 * through a {@link TritonOnnxClient}.
 *
 * <p>When the configured model control mode is {@link TritonConfig.ModelControlMode#EXPLICIT}, each
 * model passed to {@link #evaluatorOf} is first copied into the configured model repository (under
 * Vespa home). A {@code config.pbtxt} is written next to it before the evaluator is created.
 */
public class TritonOnnxRuntime extends AbstractComponent implements OnnxRuntime {

    /** Version sub-directory the model file is placed in inside its repository directory. */
    private static final String MODEL_VERSION = "1";

    private final TritonConfig config;
    private final TritonOnnxClient client;

    /** Creates a runtime using a {@link TritonConfig} built with all default values. */
    public TritonOnnxRuntime() {
        this(new TritonConfig.Builder().build());
    }

    /**
     * Creates a runtime for the given Triton configuration.
     *
     * @param tritonConfig configuration used both by this runtime and by the underlying client
     */
    @Inject
    public TritonOnnxRuntime(TritonConfig tritonConfig) {
        this.config = tritonConfig;
        this.client = new TritonOnnxClient(tritonConfig);
    }

    /**
     * Returns an evaluator for the model at {@code modelPath}.
     *
     * <p>In EXPLICIT model control mode, the model is copied into the Triton model repository first.
     *
     * @param modelPath path to the ONNX model file; its base name (minus extension) is the Triton model name
     * @param options   evaluator options; used to produce the model's {@code config.pbtxt}
     * @return an evaluator bound to the derived model name
     * @throws UncheckedIOException if copying the model into the repository fails
     */
    @Override
    public OnnxEvaluator evaluatorOf(String modelPath, OnnxEvaluatorOptions options) {
        boolean explicitModelControl = config.modelControlMode() == TritonConfig.ModelControlMode.EXPLICIT;
        if (explicitModelControl) {
            copyModelToRepository(modelPath, options);
        }
        return new TritonOnnxEvaluator(client, modelName(modelPath), explicitModelControl);
    }

    /** Closes the Triton client when the component is taken out of service. */
    @Override
    public void deconstruct() {
        client.close();
    }

    /**
     * Lays out the model in the repository as
     * {@code <repo>/<modelName>/1/model.onnx} plus {@code <repo>/<modelName>/config.pbtxt}.
     * Existing files are overwritten.
     */
    private void copyModelToRepository(String modelPath, OnnxEvaluatorOptions options) {
        Path modelDir = Paths.get(Defaults.getDefaults().underVespaHome(config.modelRepositoryPath()), modelName(modelPath));
        Path versionDir = modelDir.resolve(MODEL_VERSION);
        Path modelFile = versionDir.resolve("model.onnx");
        Path configFile = modelDir.resolve("config.pbtxt");
        try {
            Files.createDirectories(versionDir);
            Files.copy(Paths.get(modelPath), modelFile, StandardCopyOption.REPLACE_EXISTING);
            // A user-supplied raw config wins; otherwise derive one from the evaluator options.
            // Relies on protobuf's Message.toString() rendering the text format used by config.pbtxt.
            String modelConfig = options.rawConfig()
                    .orElseGet(() -> generateConfigFromEvaluatorOptions(modelPath, options).toString());
            Files.writeString(configFile, modelConfig);
        } catch (IOException e) {
            throw new UncheckedIOException(
                    "Failed to copy model file to repository: " + modelPath + " -> " + modelDir, e);
        }
    }

    /**
     * Derives the Triton model name from a model file path.
     *
     * <p>Takes the part after the last {@code '/'} and strips the last extension. A file name
     * without a {@code '.'} is returned unchanged instead of throwing.
     *
     * @param modelPath path to the model file, e.g. {@code "models/e5.onnx"}
     * @return the model name, e.g. {@code "e5"}
     */
    static String modelName(String modelPath) {
        String fileName = modelPath.substring(modelPath.lastIndexOf('/') + 1);
        int extensionStart = fileName.lastIndexOf('.');
        return extensionStart < 0 ? fileName : fileName.substring(0, extensionStart);
    }

    /**
     * Builds a Triton model configuration for an ONNX Runtime model from the evaluator options.
     *
     * <p>The ONNX Runtime memory arena and memory pattern optimisations are disabled. The intra/inter
     * op thread counts are taken from {@code options}.
     */
    private static ModelConfigOuterClass.ModelConfig generateConfigFromEvaluatorOptions(String modelPath, OnnxEvaluatorOptions options) {
        return ModelConfigOuterClass.ModelConfig.newBuilder()
                .setName(modelName(modelPath))
                .setPlatform("onnxruntime_onnx")
                // NOTE(review): per Triton docs, 0 means Triton does not manage a batch dimension.
                .setMaxBatchSize(0)
                // Key must be spelled "enable_mem_arena" for Triton's ONNX Runtime backend to recognise it.
                .putParameters("enable_mem_arena", stringParameter("0"))
                .putParameters("enable_mem_pattern", stringParameter("0"))
                .putParameters("intra_op_thread_count", stringParameter(Integer.toString(options.intraOpThreads())))
                .putParameters("inter_op_thread_count", stringParameter(Integer.toString(options.interOpThreads())))
                .build();
    }

    /** Wraps a string value in a Triton {@code ModelParameter}. */
    private static ModelConfigOuterClass.ModelParameter stringParameter(String value) {
        return ModelConfigOuterClass.ModelParameter.newBuilder().setStringValue(value).build();
    }
}
