package ai.vespa.rankingexpression.importer.onnx;

import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import onnx.Onnx;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:ai/vespa/rankingexpression/importer/onnx/TypeConverter.class */
public class TypeConverter {

    TypeConverter() {
    }

    /**
     * Verifies that the shape of an ONNX type is compatible with the given Vespa tensor type.
     *
     * <p>The ranks must be equal. For each ONNX dimension with a known (non-negative) size, the
     * size must equal the size of the corresponding Vespa dimension. An unbound Vespa dimension
     * counts as size -1 here, so it only matches an ONNX dimension of unknown size. ONNX
     * dimensions with a negative size are treated as unknown and are not checked.
     *
     * @param typeProto         the ONNX type to check
     * @param orderedTensorType the Vespa type it is expected to match
     * @throws IllegalArgumentException if the ranks or any known dimension sizes differ
     */
    public static void verifyType(Onnx.TypeProto typeProto, OrderedTensorType orderedTensorType) {
        // Generated protobuf getters return a default instance rather than null, so no null check
        // is needed: an unset shape simply reports zero dimensions.
        Onnx.TensorShapeProto shape = typeProto.getTensorType().getShape();
        if (shape.getDimCount() != orderedTensorType.rank()) {
            throw new IllegalArgumentException("Onnx shape of rank " + shape.getDimCount() +
                                               " does not match Vespa type " + orderedTensorType.type() +
                                               " of rank " + orderedTensorType.rank());
        }
        for (int onnxIndex = 0; onnxIndex < orderedTensorType.dimensions().size(); onnxIndex++) {
            long onnxSize = onnxDimensionSize(shape.getDim(onnxIndex));
            // Negative sizes are unknown/unbound (same rule as typeFrom), so there is nothing to compare.
            if (onnxSize < 0) {
                continue;
            }
            // dimensionMap translates an ONNX-order index into an index into the Vespa type's
            // dimension list (presumably because Vespa orders dimensions differently — confirm).
            int vespaIndex = orderedTensorType.dimensionMap(onnxIndex);
            TensorType.Dimension vespaDimension = orderedTensorType.type().dimensions().get(vespaIndex);
            long vespaSize = vespaDimension.size().orElse(-1L);
            if (onnxSize != vespaSize) {
                throw new IllegalArgumentException("Onnx dimension " + onnxIndex + " of size " + onnxSize +
                                                   " does not match size " + vespaSize +
                                                   " of the corresponding dimension in Vespa type " +
                                                   orderedTensorType.type());
            }
        }
    }

    /**
     * Builds a Vespa tensor type from an ONNX type description.
     *
     * <p>Dimensions are named {@code d0, d1, ...} in ONNX order. A dimension with a non-negative
     * size becomes a bound indexed dimension; a negative size becomes an unbound indexed dimension.
     *
     * @param typeProto the ONNX type
     * @return the corresponding ordered Vespa tensor type
     * @throws IllegalArgumentException if the ONNX element type has no Vespa equivalent
     */
    public static OrderedTensorType typeFrom(Onnx.TypeProto typeProto) {
        Onnx.TensorShapeProto shape = typeProto.getTensorType().getShape();
        OrderedTensorType.Builder builder =
                new OrderedTensorType.Builder(toValueType(typeProto.getTensorType().getElemType()));
        for (int i = 0; i < shape.getDimCount(); i++) {
            String dimensionName = "d" + i;
            long size = onnxDimensionSize(shape.getDim(i));
            if (size >= 0) {
                builder.add(TensorType.Dimension.indexed(dimensionName, size));
            } else {
                builder.add(TensorType.Dimension.indexed(dimensionName));
            }
        }
        return builder.build();
    }

    /**
     * Builds a Vespa tensor type from an ONNX tensor (e.g. an initializer constant).
     *
     * @param tensorProto the ONNX tensor
     * @return the corresponding ordered Vespa tensor type
     * @throws IllegalArgumentException if the ONNX data type has no Vespa equivalent
     */
    public static OrderedTensorType typeFrom(Onnx.TensorProto tensorProto) {
        return OrderedTensorType.fromDimensionList(toValueType(tensorProto.getDataType()),
                                                   tensorProto.getDimsList());
    }

    /**
     * Returns the size of an ONNX dimension. A value of 0 is mapped to 1, which is how an unset
     * {@code dim_value} reads in protobuf (presumably e.g. a symbolic {@code dim_param} dimension —
     * confirm).
     */
    private static long onnxDimensionSize(Onnx.TensorShapeProto.Dimension dimension) {
        long value = dimension.getDimValue();
        return value == 0 ? 1L : value;
    }

    /**
     * Converts a raw ONNX data type number to a Vespa cell value type.
     *
     * @throws IllegalArgumentException if the number is not a known ONNX data type, or the type
     *                                  has no Vespa equivalent
     */
    private static TensorType.Value toValueType(int onnxDataType) {
        Onnx.TensorProto.DataType dataType = Onnx.TensorProto.DataType.forNumber(onnxDataType);
        // forNumber returns null for unrecognised values; switching on null would throw an NPE.
        if (dataType == null) {
            throw new IllegalArgumentException("A ONNX tensor with unknown data type number " + onnxDataType +
                                               " cannot be converted to a Vespa tensor type");
        }
        return toValueType(dataType);
    }

    /**
     * Maps an ONNX data type to a Vespa cell value type. DOUBLE stays DOUBLE; FLOAT and all
     * boolean/integer types are represented as FLOAT.
     *
     * @throws IllegalArgumentException for any other ONNX data type
     */
    private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) {
        switch (dataType) {
            case DOUBLE:
                return TensorType.Value.DOUBLE;
            case FLOAT:
            case BOOL:
            case INT8:
            case INT16:
            case INT32:
            case INT64:
            case UINT8:
            case UINT16:
            case UINT32:
            case UINT64:
                return TensorType.Value.FLOAT;
            default:
                throw new IllegalArgumentException("A ONNX tensor with data type " + dataType + " cannot be converted to a Vespa tensor type");
        }
    }
}
