package ai.vespa.triton;

import ai.onnxruntime.platform.Fp16Conversions;
import ai.vespa.llm.clients.TritonConfig;
import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
import com.google.protobuf.ByteString;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.annotation.Inject;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import grpc.health.v1.HealthGrpc;
import grpc.health.v1.HealthOuterClass;
import inference.GRPCInferenceServiceGrpc;
import inference.GrpcService;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.AbstractBlockingStub;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.ObjectMethods;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.logging.Logger;
import java.util.stream.Collectors;

@Beta
/* loaded from: input_file:ai/vespa/triton/TritonOnnxClient.class */
public class TritonOnnxClient implements AutoCloseable {
    // Class-scoped logger; this class only emits fine-level diagnostics through it.
    private static final Logger log = Logger.getLogger(TritonOnnxClient.class.getName());
    // Blocking stub used for model metadata, model load/unload and inference calls.
    private final GRPCInferenceServiceGrpc.GRPCInferenceServiceBlockingV2Stub grpcInferenceStub;
    // Blocking stub for the gRPC health service; the constructor creates it on the same channel as the inference stub.
    private final HealthGrpc.HealthBlockingV2Stub grpcHealthStub;

    /**
     * Input and output tensor types of a model, keyed by tensor name.
     *
     * <p>Declared as a real record: the compiler generates the canonical constructor,
     * the {@code inputs()}/{@code outputs()} accessors and component-wise
     * {@code equals}, {@code hashCode} and {@code toString}. The maps are stored as given
     * (no defensive copy).
     *
     * @param inputs  tensor type per model input name
     * @param outputs tensor type per model output name
     */
    public record ModelMetadata(Map<String, TensorType> inputs, Map<String, TensorType> outputs) {
    }

    /**
     * Unchecked exception thrown by this client, either wrapping a gRPC failure or
     * reporting unsupported input/output data.
     */
    public static class TritonException extends RuntimeException {

        /**
         * Creates an exception carrying a message and the failure that caused it.
         *
         * @param str human-readable description of what failed
         * @param th  underlying cause
         */
        public TritonException(String str, Throwable th) {
            super(str, th);
        }

        /**
         * Creates an exception carrying only a message.
         *
         * @param str human-readable description of what failed
         */
        public TritonException(String str) {
            super(str);
        }

        /**
         * Creates an exception that wraps another failure.
         *
         * @param th underlying cause
         */
        public TritonException(Throwable th) {
            super(th);
        }
    }

    @Inject
    public TritonOnnxClient(TritonConfig tritonConfig) {
        ManagedChannel build = ManagedChannelBuilder.forTarget(tritonConfig.target()).usePlaintext().build();
        this.grpcInferenceStub = GRPCInferenceServiceGrpc.newBlockingV2Stub(build);
        this.grpcHealthStub = HealthGrpc.newBlockingV2Stub(build);
    }

    /**
     * Fetches the input and output tensor types of a model from the server.
     *
     * <p>Tensor names in the returned maps are passed through
     * {@code OnnxImporter.asValidIdentifier}.
     *
     * @param modelName name of the model to describe
     * @return the model's input and output tensor types
     * @throws TritonException if the gRPC call fails with a {@link StatusRuntimeException}
     */
    public ModelMetadata getModelMetadata(String modelName) {
        GrpcService.ModelMetadataRequest request =
                GrpcService.ModelMetadataRequest.newBuilder().setName(modelName).build();
        GrpcService.ModelMetadataResponse response =
                invokeGrpc(this.grpcInferenceStub, stub -> stub.modelMetadata(request));
        return new ModelMetadata(
                toTensorTypes(response.getInputsList()),
                toTensorTypes(response.getOutputsList()));
    }

    /**
     * Runs a gRPC health check against the server.
     *
     * @return {@code true} only if the reported status is {@code SERVING}
     * @throws TritonException if the health call itself fails with a {@link StatusRuntimeException};
     *                         a failed call is not reported as {@code false}
     */
    public boolean isHealthy() {
        HealthOuterClass.HealthCheckRequest request = HealthOuterClass.HealthCheckRequest.newBuilder().build();
        HealthOuterClass.HealthCheckResponse response =
                invokeGrpc(this.grpcHealthStub, stub -> stub.check(request));
        HealthOuterClass.HealthCheckResponse.ServingStatus status = response.getStatus();
        log.fine(() -> "Triton health status: " + status);
        return status == HealthOuterClass.HealthCheckResponse.ServingStatus.SERVING;
    }

    /**
     * Asks the server to load a model from its model repository.
     *
     * @param modelName name of the model to load
     * @throws TritonException if the gRPC call fails with a {@link StatusRuntimeException}
     */
    public void loadModel(String modelName) {
        log.fine(() -> "Loading model " + modelName);
        GrpcService.RepositoryModelLoadRequest request =
                GrpcService.RepositoryModelLoadRequest.newBuilder().setModelName(modelName).build();
        invokeGrpc(this.grpcInferenceStub, stub -> stub.repositoryModelLoad(request));
    }

    /**
     * Asks the server to unload a model.
     *
     * @param modelName name of the model to unload
     * @throws TritonException if the gRPC call fails with a {@link StatusRuntimeException}
     */
    public void unloadModel(String modelName) {
        GrpcService.RepositoryModelUnloadRequest request =
                GrpcService.RepositoryModelUnloadRequest.newBuilder().setModelName(modelName).build();
        invokeGrpc(this.grpcInferenceStub, stub -> stub.repositoryModelUnload(request));
    }

    /**
     * Evaluates a model without naming any outputs explicitly; delegates to the
     * set-based overload with an empty output set, so the result holds whatever
     * outputs the inference response contains.
     *
     * @param str name of the model to evaluate
     * @param map input tensors keyed by input name
     * @return output tensors keyed by output name
     */
    public Map<String, Tensor> evaluate(String str, Map<String, Tensor> map) {
        Set<String> noExplicitOutputs = Set.of();
        return evaluate(str, map, noExplicitOutputs);
    }

    /**
     * Evaluates a model and returns a single named output.
     *
     * <p>The set-based overload keys its result by
     * {@code OnnxImporter.asValidIdentifier(outputName)}, so a lookup by the raw name
     * misses whenever normalization changes the name. The raw name is tried first and
     * the normalized name second.
     *
     * @param modelName  name of the model to evaluate
     * @param inputs     input tensors keyed by input name
     * @param outputName name of the output to request and return
     * @return the requested output tensor, or {@code null} if the response did not contain it
     * @throws TritonException if a gRPC call fails or the data types are unsupported
     */
    public Tensor evaluate(String modelName, Map<String, Tensor> inputs, String outputName) {
        Map<String, Tensor> results = evaluate(modelName, inputs, Set.of(outputName));
        Tensor exactMatch = results.get(outputName);
        if (exactMatch != null) {
            return exactMatch;
        }
        return results.get(OnnxImporter.asValidIdentifier(outputName));
    }

    /**
     * Evaluates a model with the given inputs and returns the outputs in the response.
     *
     * <p>Model metadata is fetched first so each input can be matched to its declared
     * name and datatype. Inputs are sent as raw little-endian bytes, and outputs are
     * decoded from the raw output contents of the response.
     *
     * @param modelName   name of the model to evaluate
     * @param inputs      input tensors keyed by input name
     * @param outputNames outputs to request explicitly; when empty, no outputs are named in the request
     * @return output tensors keyed by {@code OnnxImporter.asValidIdentifier(outputName)}
     * @throws TritonException if a gRPC call fails, an input is not an indexed tensor,
     *                         no declared input matches an input name, or a datatype is unsupported
     */
    public Map<String, Tensor> evaluate(String modelName, Map<String, Tensor> inputs, Set<String> outputNames) {
        GrpcService.ModelMetadataRequest metadataRequest =
                GrpcService.ModelMetadataRequest.newBuilder().setName(modelName).build();
        GrpcService.ModelMetadataResponse metadata =
                invokeGrpc(this.grpcInferenceStub, stub -> stub.modelMetadata(metadataRequest));
        List<GrpcService.ModelMetadataResponse.TensorMetadata> declaredInputs = metadata.getInputsList();

        GrpcService.ModelInferRequest.Builder request =
                GrpcService.ModelInferRequest.newBuilder().setModelName(modelName);
        inputs.forEach((inputName, tensor) -> addInputToBuilder(declaredInputs, request, tensor, inputName));
        for (String outputName : outputNames) {
            request.addOutputs(
                    GrpcService.ModelInferRequest.InferRequestedOutputTensor.newBuilder().setName(outputName).build());
        }

        GrpcService.ModelInferResponse response =
                invokeGrpc(this.grpcInferenceStub, stub -> stub.modelInfer(request.build()));

        Map<String, Tensor> results = new HashMap<>();
        for (int i = 0; i < response.getOutputsCount(); i++) {
            GrpcService.ModelInferResponse.InferOutputTensor output = response.getOutputs(i);
            // Raw contents are positional: entry i holds the bytes of output i.
            ByteBuffer rawContent = ByteBuffer.wrap(response.getRawOutputContents(i).toByteArray())
                    .order(ByteOrder.LITTLE_ENDIAN);
            Tensor tensor = createTensorFromRawOutput(rawContent, output.getDatatype(), output.getShapeList());
            results.put(OnnxImporter.asValidIdentifier(output.getName()), tensor);
        }
        return results;
    }

    /**
     * Shuts down the gRPC channel behind the inference stub, waiting up to five
     * seconds for termination. {@code shutdownNow()} is always invoked afterwards,
     * whether or not the wait succeeded.
     *
     * @throws IllegalStateException if the channel did not terminate within the wait
     * @throws TritonException       if the waiting thread is interrupted; the interrupt flag is restored
     */
    @Override
    public void close() {
        ManagedChannel channel = (ManagedChannel) invokeGrpc(this.grpcInferenceStub, stub -> stub.getChannel());
        channel.shutdown();
        try {
            boolean terminated = channel.awaitTermination(5L, TimeUnit.SECONDS);
            if (!terminated) {
                throw new IllegalStateException("Failed to close channel");
            }
        } catch (InterruptedException interrupted) {
            Thread.currentThread().interrupt();
            throw new TritonException("Failed to close channel", interrupted);
        } finally {
            channel.shutdownNow();
        }
    }

    /**
     * Adds one input tensor to an inference request: its descriptor (name, datatype,
     * shape) plus the matching raw byte content.
     *
     * <p>The name and datatype come from the declared input matched by
     * {@code findMatchingInput}; the shape comes from the tensor itself.
     *
     * @param declaredInputs input metadata declared by the model
     * @param requestBuilder inference request under construction; mutated by this call
     * @param tensor         tensor value to send; must be an {@link IndexedTensor}
     * @param inputName      name the caller used for this input
     * @throws TritonException if the tensor is not indexed, no declared input matches,
     *                         or the declared datatype is unsupported
     */
    public static void addInputToBuilder(List<GrpcService.ModelMetadataResponse.TensorMetadata> declaredInputs, GrpcService.ModelInferRequest.Builder requestBuilder, Tensor tensor, String inputName) {
        if (!(tensor instanceof IndexedTensor indexed)) {
            throw new TritonException("Nvidia Triton currently only supports tensors with indexed dimensions");
        }
        GrpcService.ModelMetadataResponse.TensorMetadata declared = findMatchingInput(declaredInputs, inputName);
        GrpcService.ModelInferRequest.InferInputTensor.Builder descriptor =
                GrpcService.ModelInferRequest.InferInputTensor.newBuilder()
                        .setName(declared.getName())
                        .setDatatype(declared.getDatatype());
        for (long dimensionSize : indexed.shape()) {
            descriptor.addShape(dimensionSize);
        }
        requestBuilder.addInputs(descriptor.build());
        requestBuilder.addRawInputContents(createRawInputContent(declared, indexed));
    }

    /**
     * Finds the declared model input that corresponds to a caller-supplied name.
     * An exact name match anywhere in the list wins; only if none exists is the first
     * input whose {@code OnnxImporter.asValidIdentifier} form equals the name used.
     *
     * @param list declared model inputs
     * @param str  name to look up
     * @return the matching input metadata
     * @throws TritonException if neither lookup finds a match
     */
    private static GrpcService.ModelMetadataResponse.TensorMetadata findMatchingInput(List<GrpcService.ModelMetadataResponse.TensorMetadata> list, String str) {
        return list.stream()
                .filter(candidate -> candidate.getName().equals(str))
                .findFirst()
                .or(() -> list.stream()
                        .filter(candidate -> OnnxImporter.asValidIdentifier(candidate.getName()).equals(str))
                        .findFirst())
                .orElseThrow(() -> new TritonException("No matching input type found for " + str));
    }

    /**
     * Serializes a tensor's cells into the raw little-endian byte layout for the
     * datatype declared by the model input.
     *
     * <p>Cells are written in direct-index order. Floating-point types are read with
     * {@code getFloat}/{@code get}; integer types are produced by narrowing the
     * {@code double} cell value with a plain cast.
     *
     * @param inputMetadata declared metadata of the target input; only its datatype is read
     * @param tensor        tensor whose cells are serialized
     * @return the encoded cell values
     * @throws TritonException if the declared datatype is not one of
     *                         FP32, FP64, INT8, INT16, INT32, INT64, BF16, FP16
     */
    private static ByteString createRawInputContent(GrpcService.ModelMetadataResponse.TensorMetadata inputMetadata, IndexedTensor tensor) {
        String datatype = inputMetadata.getDatatype();
        int cellCount = (int) tensor.size();
        ByteBuffer buffer;
        switch (datatype) {
            case "FP32" -> {
                buffer = ByteBuffer.allocate(cellCount * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
                FloatBuffer floats = buffer.asFloatBuffer();
                for (int i = 0; i < cellCount; i++) {
                    floats.put(tensor.getFloat(i));
                }
            }
            case "FP64" -> {
                buffer = ByteBuffer.allocate(cellCount * Double.BYTES).order(ByteOrder.LITTLE_ENDIAN);
                DoubleBuffer doubles = buffer.asDoubleBuffer();
                for (int i = 0; i < cellCount; i++) {
                    doubles.put(tensor.get(i));
                }
            }
            case "INT8" -> {
                buffer = ByteBuffer.allocate(cellCount).order(ByteOrder.LITTLE_ENDIAN);
                for (int i = 0; i < cellCount; i++) {
                    buffer.put((byte) tensor.get(i));
                }
            }
            case "INT16" -> {
                buffer = ByteBuffer.allocate(cellCount * Short.BYTES).order(ByteOrder.LITTLE_ENDIAN);
                ShortBuffer shorts = buffer.asShortBuffer();
                for (int i = 0; i < cellCount; i++) {
                    shorts.put((short) tensor.get(i));
                }
            }
            case "INT32" -> {
                buffer = ByteBuffer.allocate(cellCount * Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN);
                IntBuffer ints = buffer.asIntBuffer();
                for (int i = 0; i < cellCount; i++) {
                    ints.put((int) tensor.get(i));
                }
            }
            case "INT64" -> {
                buffer = ByteBuffer.allocate(cellCount * Long.BYTES).order(ByteOrder.LITTLE_ENDIAN);
                LongBuffer longs = buffer.asLongBuffer();
                for (int i = 0; i < cellCount; i++) {
                    longs.put((long) tensor.get(i));
                }
            }
            case "BF16" -> {
                buffer = ByteBuffer.allocate(cellCount * Short.BYTES).order(ByteOrder.LITTLE_ENDIAN);
                ShortBuffer halves = buffer.asShortBuffer();
                for (int i = 0; i < cellCount; i++) {
                    halves.put(Fp16Conversions.floatToBf16(tensor.getFloat(i)));
                }
            }
            case "FP16" -> {
                buffer = ByteBuffer.allocate(cellCount * Short.BYTES).order(ByteOrder.LITTLE_ENDIAN);
                ShortBuffer halves = buffer.asShortBuffer();
                for (int i = 0; i < cellCount; i++) {
                    halves.put(Fp16Conversions.floatToFp16(tensor.getFloat(i)));
                }
            }
            default -> throw new TritonException("Unsupported tensor datatype from Triton: " + datatype);
        }
        // The INT8 branch writes through the byte buffer itself and advances its position,
        // so rewind before copying; for the typed views this is a no-op.
        return ByteString.copyFrom(buffer.rewind());
    }

    /**
     * Decodes the raw little-endian bytes of one inference output into a tensor.
     *
     * <p>The tensor type is derived from the datatype and shape via
     * {@code toVespaTensorType}; cells are filled in direct-index order.
     * INT32 and INT64 values are stored as {@code double}: those datatypes map to
     * double-valued cells, and passing them as {@code int}/{@code long} would select
     * the {@code float} overload of {@code cellByDirectIndex} and round large values.
     *
     * @param rawContent raw output bytes; its byte order is set to little-endian by this call
     * @param datatype   datatype string reported for the output
     * @param shape      dimension sizes reported for the output
     * @return the decoded tensor
     * @throws TritonException if the datatype is not one of
     *                         BF16, FP16, FP32, FP64, INT8, INT16, INT32, INT64
     */
    private Tensor createTensorFromRawOutput(ByteBuffer rawContent, String datatype, List<Long> shape) {
        TensorType type = toVespaTensorType(datatype, shape);
        DimensionSizes sizes = DimensionSizes.of(type);
        rawContent.order(ByteOrder.LITTLE_ENDIAN);
        long cellCount = sizes.totalSize();
        // Explicit cast: harmless if the factory already returns a BoundBuilder.
        IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type, sizes);
        switch (datatype) {
            case "BF16" -> {
                ShortBuffer halves = rawContent.asShortBuffer();
                for (int i = 0; i < cellCount; i++) {
                    builder.cellByDirectIndex(i, Fp16Conversions.bf16ToFloat(halves.get(i)));
                }
            }
            case "FP16" -> {
                ShortBuffer halves = rawContent.asShortBuffer();
                for (int i = 0; i < cellCount; i++) {
                    builder.cellByDirectIndex(i, Fp16Conversions.fp16ToFloat(halves.get(i)));
                }
            }
            case "FP32" -> {
                FloatBuffer floats = rawContent.asFloatBuffer();
                for (int i = 0; i < cellCount; i++) {
                    builder.cellByDirectIndex(i, floats.get(i));
                }
            }
            case "FP64" -> {
                DoubleBuffer doubles = rawContent.asDoubleBuffer();
                for (int i = 0; i < cellCount; i++) {
                    builder.cellByDirectIndex(i, doubles.get(i));
                }
            }
            case "INT8" -> {
                for (int i = 0; i < cellCount; i++) {
                    builder.cellByDirectIndex(i, rawContent.get(i));
                }
            }
            case "INT16" -> {
                ShortBuffer shorts = rawContent.asShortBuffer();
                for (int i = 0; i < cellCount; i++) {
                    builder.cellByDirectIndex(i, shorts.get(i));
                }
            }
            case "INT32" -> {
                IntBuffer ints = rawContent.asIntBuffer();
                for (int i = 0; i < cellCount; i++) {
                    // double represents every int exactly; float does not.
                    builder.cellByDirectIndex(i, (double) ints.get(i));
                }
            }
            case "INT64" -> {
                LongBuffer longs = rawContent.asLongBuffer();
                for (int i = 0; i < cellCount; i++) {
                    // double keeps 53 bits of precision, versus 24 for float.
                    builder.cellByDirectIndex(i, (double) longs.get(i));
                }
            }
            default -> throw new TritonException("Unsupported type from ONNX output: %s".formatted(datatype));
        }
        return builder.build();
    }

    /**
     * Converts tensor metadata entries into a map from normalized tensor name to tensor type.
     *
     * <p>Names are normalized with {@code OnnxImporter.asValidIdentifier}. Because
     * {@link Collectors#toMap} is used without a merge function, two entries that
     * normalize to the same name cause an {@link IllegalStateException}.
     *
     * @param tensors metadata entries to convert
     * @return tensor type per normalized tensor name
     */
    private static Map<String, TensorType> toTensorTypes(Collection<GrpcService.ModelMetadataResponse.TensorMetadata> tensors) {
        return tensors.stream().collect(Collectors.toMap(
                tensor -> OnnxImporter.asValidIdentifier(tensor.getName()),
                tensor -> toVespaTensorType(tensor.getDatatype(), tensor.getShapeList())));
    }

    /**
     * Maps a datatype string and shape to a tensor type with one indexed dimension
     * per shape entry, named {@code d0}, {@code d1}, ... in order.
     *
     * <p>Cell type mapping: {@code INT8} to INT8, {@code BF16} to BFLOAT16,
     * {@code FP16} and {@code FP32} to FLOAT, and every other datatype to DOUBLE.
     * A non-negative dimension size yields a bound dimension; a negative size yields
     * an unbound one.
     *
     * @param datatype datatype string
     * @param shape    dimension sizes
     * @return the corresponding tensor type
     */
    public static TensorType toVespaTensorType(String datatype, List<Long> shape) {
        TensorType.Value cellType = switch (datatype) {
            case "INT8" -> TensorType.Value.INT8;
            case "BF16" -> TensorType.Value.BFLOAT16;
            case "FP16", "FP32" -> TensorType.Value.FLOAT;
            default -> TensorType.Value.DOUBLE;
        };
        TensorType.Builder builder = new TensorType.Builder(cellType);
        for (int i = 0; i < shape.size(); i++) {
            long dimensionSize = shape.get(i).longValue();
            String dimensionName = "d" + i;
            if (dimensionSize >= 0) {
                builder.indexed(dimensionName, dimensionSize);
            } else {
                builder.indexed(dimensionName);
            }
        }
        return builder.build();
    }

    /**
     * Runs a call against a blocking stub and rethrows any {@link StatusRuntimeException}
     * as a {@link TritonException} with the original exception as cause.
     *
     * @param s        stub to invoke the call on
     * @param function call to perform with the stub
     * @param <T>      result type of the call
     * @param <S>      stub type
     * @return whatever the call returns
     * @throws TritonException if the call throws a {@link StatusRuntimeException}
     */
    private <T, S extends AbstractBlockingStub<S>> T invokeGrpc(S s, Function<S, T> function) {
        T result;
        try {
            result = function.apply(s);
        } catch (StatusRuntimeException grpcFailure) {
            throw new TritonException(grpcFailure);
        }
        return result;
    }
}
