package ai.vespa.modelintegration.evaluator;

import ai.onnxruntime.NodeInfo;
import ai.onnxruntime.OnnxJavaType;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.TensorInfo;
import ai.onnxruntime.ValueInfo;
import ai.onnxruntime.platform.Fp16Conversions;
import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/* loaded from: input_file:ai/vespa/modelintegration/evaluator/TensorConverter.class */
/**
 * Converts between Vespa tensors and ONNX Runtime tensors, and maps ONNX value infos
 * to Vespa tensor types.
 */
class TensorConverter {

    /** Converts a raw 16-bit cell value (for example fp16 or bfloat16 bits) to a float. */
    @FunctionalInterface
    public interface Short2Float {
        float convert(short s);
    }

    TensorConverter() {
    }

    /**
     * Converts each input tensor to an ONNX tensor with the element type the session expects
     * for the corresponding model input.
     *
     * @param map input tensors keyed either by the model's ONNX input name or by its
     *            {@link #asValidName} form
     * @param ortEnvironment environment used to create the ONNX tensors
     * @param ortSession session whose input infos determine the target names and element types
     * @return the converted tensors keyed by ONNX input name (they are not closed here)
     * @throws IllegalArgumentException if an input name is unknown to the model or a tensor
     *                                  cannot be converted
     * @throws OrtException if ONNX Runtime fails to create a tensor
     */
    public static Map<String, OnnxTensor> toOnnxTensors(Map<String, Tensor> map, OrtEnvironment ortEnvironment, OrtSession ortSession) throws OrtException {
        Map<String, NodeInfo> inputInfo = ortSession.getInputInfo();
        Map<String, OnnxTensor> result = new HashMap<>();
        try {
            for (Map.Entry<String, Tensor> input : map.entrySet()) {
                String onnxName = toOnnxName(input.getKey(), inputInfo.keySet());
                TensorInfo tensorInfo = toTensorInfo(inputInfo.get(onnxName).getInfo());
                result.put(onnxName, toOnnxTensor(input.getValue(), tensorInfo, ortEnvironment));
            }
        } catch (OrtException | RuntimeException e) {
            // The partially filled map is never handed to the caller, so close what was already
            // created here instead of leaking the tensors' native buffers.
            for (OnnxTensor created : result.values()) {
                try {
                    created.close();
                } catch (RuntimeException closeFailure) {
                    e.addSuppressed(closeFailure);
                }
            }
            throw e;
        }
        return result;
    }

    /**
     * Copies an indexed Vespa tensor into a direct, native-order buffer and wraps it in an
     * ONNX tensor of the element type given by {@code tensorInfo}, using the Vespa tensor's shape.
     *
     * @throws IllegalArgumentException if the tensor is not indexed, is too large for a single
     *                                  buffer, or the element type is not supported
     * @throws OrtException if ONNX Runtime fails to create the tensor
     */
    static OnnxTensor toOnnxTensor(Tensor tensor, TensorInfo tensorInfo, OrtEnvironment ortEnvironment) throws OrtException {
        if (!(tensor instanceof IndexedTensor)) {
            throw new IllegalArgumentException("OnnxEvaluator currently only supports tensors with indexed dimensions");
        }
        IndexedTensor indexedTensor = (IndexedTensor) tensor;
        OnnxJavaType type = tensorInfo.type;
        int cellCount = checkedTotalSize(indexedTensor.size());
        long byteCount = (long) cellCount * type.size;
        if (byteCount > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Tensor of " + cellCount + " cells of type " + type +
                                               " is too large for a single buffer");
        }
        ByteBuffer buffer = ByteBuffer.allocateDirect((int) byteCount).order(ByteOrder.nativeOrder());
        long[] shape = indexedTensor.shape();
        switch (type) {
            case FLOAT:
                for (int i = 0; i < cellCount; i++) {
                    buffer.putFloat(indexedTensor.getFloat(i));
                }
                return OnnxTensor.createTensor(ortEnvironment, buffer.rewind().asFloatBuffer(), shape);
            case DOUBLE:
                for (int i = 0; i < cellCount; i++) {
                    buffer.putDouble(indexedTensor.get(i));
                }
                return OnnxTensor.createTensor(ortEnvironment, buffer.rewind().asDoubleBuffer(), shape);
            case INT8:
                for (int i = 0; i < cellCount; i++) {
                    buffer.put((byte) indexedTensor.get(i));
                }
                return OnnxTensor.createTensor(ortEnvironment, buffer.rewind(), shape);
            case INT16:
                for (int i = 0; i < cellCount; i++) {
                    buffer.putShort((short) indexedTensor.get(i));
                }
                return OnnxTensor.createTensor(ortEnvironment, buffer.rewind().asShortBuffer(), shape);
            case INT32:
                for (int i = 0; i < cellCount; i++) {
                    buffer.putInt((int) indexedTensor.get(i));
                }
                return OnnxTensor.createTensor(ortEnvironment, buffer.rewind().asIntBuffer(), shape);
            case INT64:
                for (int i = 0; i < cellCount; i++) {
                    buffer.putLong((long) indexedTensor.get(i));
                }
                return OnnxTensor.createTensor(ortEnvironment, buffer.rewind().asLongBuffer(), shape);
            case FLOAT16:
                // Half-precision cells are written as raw 16-bit patterns and tagged with their type.
                for (int i = 0; i < cellCount; i++) {
                    buffer.putShort(Fp16Conversions.floatToFp16((float) indexedTensor.get(i)));
                }
                return OnnxTensor.createTensor(ortEnvironment, buffer.rewind(), shape, OnnxJavaType.FLOAT16);
            case BFLOAT16:
                for (int i = 0; i < cellCount; i++) {
                    buffer.putShort(Fp16Conversions.floatToBf16((float) indexedTensor.get(i)));
                }
                return OnnxTensor.createTensor(ortEnvironment, buffer.rewind(), shape, OnnxJavaType.BFLOAT16);
            default:
                throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + type);
        }
    }

    /** Narrows a cell count to int; the NIO buffers used here are int-indexed. */
    private static int checkedTotalSize(long totalSize) {
        if (totalSize > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("TotalSize=" + totalSize + " currently limited at INTEGER.MAX_VALUE");
        }
        return (int) totalSize;
    }

    // The extractTensor overloads copy the first `count` cells of an ONNX output buffer into
    // the builder, in direct (flat) index order, using absolute reads.

    private static void extractTensor(FloatBuffer floatBuffer, IndexedTensor.BoundBuilder boundBuilder, int count) {
        for (int i = 0; i < count; i++) {
            boundBuilder.cellByDirectIndex(i, floatBuffer.get(i));
        }
    }

    private static void extractTensor(DoubleBuffer doubleBuffer, IndexedTensor.BoundBuilder boundBuilder, int count) {
        for (int i = 0; i < count; i++) {
            boundBuilder.cellByDirectIndex(i, doubleBuffer.get(i));
        }
    }

    private static void extractTensor(ByteBuffer byteBuffer, IndexedTensor.BoundBuilder boundBuilder, int count) {
        for (int i = 0; i < count; i++) {
            boundBuilder.cellByDirectIndex(i, byteBuffer.get(i));
        }
    }

    private static void extractTensor(ShortBuffer shortBuffer, IndexedTensor.BoundBuilder boundBuilder, int count) {
        for (int i = 0; i < count; i++) {
            boundBuilder.cellByDirectIndex(i, shortBuffer.get(i));
        }
    }

    /** Like the plain short overload, but decodes each 16-bit cell with {@code short2Float}. */
    private static void extractTensor(ShortBuffer shortBuffer, Short2Float short2Float, IndexedTensor.BoundBuilder boundBuilder, int count) {
        for (int i = 0; i < count; i++) {
            boundBuilder.cellByDirectIndex(i, short2Float.convert(shortBuffer.get(i)));
        }
    }

    private static void extractTensor(IntBuffer intBuffer, IndexedTensor.BoundBuilder boundBuilder, int count) {
        for (int i = 0; i < count; i++) {
            // Explicit double: without it Java picks the float overload (int -> float is the more
            // specific widening), which rounds values above 2^24.
            boundBuilder.cellByDirectIndex(i, (double) intBuffer.get(i));
        }
    }

    private static void extractTensor(LongBuffer longBuffer, IndexedTensor.BoundBuilder boundBuilder, int count) {
        for (int i = 0; i < count; i++) {
            // Double keeps integers exact up to 2^53; float would lose them past 2^24.
            boundBuilder.cellByDirectIndex(i, (double) longBuffer.get(i));
        }
    }

    /**
     * Converts an ONNX Runtime output value to an indexed Vespa tensor whose type is derived
     * from the value's tensor info via {@link #toVespaType}.
     *
     * @throws IllegalArgumentException if the value is not a tensor, is too large, or has an
     *                                  unsupported element type
     */
    public static Tensor toVespaTensor(OnnxValue onnxValue) {
        if (!(onnxValue instanceof OnnxTensor)) {
            throw new IllegalArgumentException("ONNX value is not a tensor: maps and sequences are not yet supported");
        }
        OnnxTensor onnxTensor = (OnnxTensor) onnxValue;
        TensorInfo info = onnxTensor.getInfo();
        TensorType vespaType = toVespaType(info);
        DimensionSizes sizes = DimensionSizes.of(vespaType);
        IndexedTensor.BoundBuilder builder = Tensor.Builder.of(vespaType, sizes);
        int totalSize = checkedTotalSize(sizes.totalSize());
        switch (info.type) {
            case FLOAT:
                extractTensor(onnxTensor.getFloatBuffer(), builder, totalSize);
                break;
            case DOUBLE:
                extractTensor(onnxTensor.getDoubleBuffer(), builder, totalSize);
                break;
            case INT8:
                extractTensor(onnxTensor.getByteBuffer(), builder, totalSize);
                break;
            case INT16:
                extractTensor(onnxTensor.getShortBuffer(), builder, totalSize);
                break;
            case INT32:
                extractTensor(onnxTensor.getIntBuffer(), builder, totalSize);
                break;
            case INT64:
                extractTensor(onnxTensor.getLongBuffer(), builder, totalSize);
                break;
            case FLOAT16:
                extractTensor(onnxTensor.getShortBuffer(), Fp16Conversions::fp16ToFloat, builder, totalSize);
                break;
            case BFLOAT16:
                extractTensor(onnxTensor.getShortBuffer(), Fp16Conversions::bf16ToFloat, builder, totalSize);
                break;
            default:
                throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + info.type);
        }
        return builder.build();
    }

    /**
     * Maps each ONNX node to its Vespa tensor type, keyed by the {@link #asValidName} form of
     * the node name. {@code Collectors.toMap} throws if two node names map to the same key.
     */
    public static Map<String, TensorType> toVespaTypes(Map<String, NodeInfo> map) {
        return map.entrySet().stream()
                  .collect(Collectors.toMap(entry -> asValidName(entry.getKey()),
                                            entry -> toVespaType(entry.getValue().getInfo())));
    }

    /** Returns the identifier form of an ONNX name, as produced by {@code OnnxImporter}. */
    public static String asValidName(String str) {
        return OnnxImporter.asValidIdentifier(str);
    }

    /**
     * Resolves a caller-supplied input name to an ONNX input name: an exact match wins,
     * otherwise the first ONNX name whose {@link #asValidName} form equals {@code str}.
     *
     * @throws IllegalArgumentException if no ONNX input matches
     */
    static String toOnnxName(String str, Set<String> set) {
        if (set.contains(str)) {
            return str;
        }
        for (String onnxName : set) {
            if (asValidName(onnxName).equals(str)) {
                return onnxName;
            }
        }
        throw new IllegalArgumentException("ONNX model has no input with name " + str);
    }

    /**
     * Builds the Vespa tensor type for an ONNX tensor info: one indexed dimension per axis,
     * named {@code d0, d1, ...}, bound when the ONNX size is positive and unbound otherwise.
     */
    public static TensorType toVespaType(ValueInfo valueInfo) {
        TensorInfo tensorInfo = toTensorInfo(valueInfo);
        TensorType.Builder builder = new TensorType.Builder(toVespaValueType(tensorInfo.onnxType));
        long[] shape = tensorInfo.getShape();
        for (int i = 0; i < shape.length; i++) {
            String dimension = "d" + i;
            if (shape[i] > 0) {
                builder.indexed(dimension, shape[i]);
            } else {
                builder.indexed(dimension);
            }
        }
        return builder.build();
    }

    /** Maps an ONNX element type to a Vespa cell type; float16 widens to float, anything else not listed to double. */
    private static TensorType.Value toVespaValueType(TensorInfo.OnnxTensorType onnxTensorType) {
        switch (onnxTensorType) {
            case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
                return TensorType.Value.INT8;
            case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
                return TensorType.Value.BFLOAT16;
            case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
            case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
                return TensorType.Value.FLOAT;
            default:
                return TensorType.Value.DOUBLE;
        }
    }

    /** Narrows a value info to a tensor info, rejecting maps and sequences. */
    private static TensorInfo toTensorInfo(ValueInfo valueInfo) {
        if (valueInfo instanceof TensorInfo) {
            return (TensorInfo) valueInfo;
        }
        throw new IllegalArgumentException("ONNX value is not a tensor: maps and sequences are not yet supported");
    }
}
