package com.yahoo.tensor.serialization;

import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.nativec.PosixFAdvise;
import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.Optional;

/**
 * Typed binary serialization of tensors.
 *
 * <p>Every encoded tensor starts with a format-type identifier, written with
 * {@code putInt1_4Bytes}, that selects the {@link BinaryFormat} used for the rest of the payload:
 * <ul>
 *   <li>1 (sparse), 2 (dense), 3 (mixed): tensors with {@code double} cells; nothing else
 *       precedes the format-specific payload.</li>
 *   <li>5 (sparse), 6 (dense), 7 (mixed): the format type is followed by a value-type
 *       identifier (0 = double, 1 = float, 2 = bfloat16, 3 = int8).</li>
 * </ul>
 * Format type 4 is never written by this class and is rejected when decoding.
 */
public class TypedBinaryFormat {

    // Format-type identifiers written at the start of every encoded tensor.
    private static final int SPARSE_BINARY_FORMAT_TYPE = 1;
    private static final int DENSE_BINARY_FORMAT_TYPE = 2;
    private static final int MIXED_BINARY_FORMAT_TYPE = 3;
    private static final int SPARSE_BINARY_FORMAT_WITH_CELLTYPE = 5;
    private static final int DENSE_BINARY_FORMAT_WITH_CELLTYPE = 6;
    private static final int MIXED_BINARY_FORMAT_WITH_CELLTYPE = 7;

    // Value-type identifiers written directly after a *_WITH_CELLTYPE format type.
    private static final int DOUBLE_VALUE_TYPE = 0;
    private static final int FLOAT_VALUE_TYPE = 1;
    private static final int BFLOAT16_VALUE_TYPE = 2;
    private static final int INT8_VALUE_TYPE = 3;

    /**
     * Encodes a tensor into a newly allocated byte array.
     *
     * @param tensor the tensor to encode
     * @return the encoded bytes, starting with the format-type header
     * @throws IllegalArgumentException if the tensor's cell value type has no binary identifier
     */
    public static byte[] encode(Tensor tensor) {
        return asByteArray(encode(tensor, new GrowableByteBuffer()));
    }

    /**
     * Writes the encoding of a tensor (format-type header, optional value-type identifier,
     * then the format-specific payload) into the given buffer.
     *
     * @param tensor the tensor to encode
     * @param buffer the buffer to write into
     * @return the same buffer, for chaining
     * @throws IllegalArgumentException if the tensor's cell value type has no binary identifier
     */
    public static GrowableByteBuffer encode(Tensor tensor, GrowableByteBuffer buffer) {
        BinaryFormat format = getFormatEncoder(buffer, tensor);
        format.encode(buffer, tensor);
        return buffer;
    }

    /**
     * Decodes a tensor from the buffer, starting by reading its format-type header.
     *
     * @param type the expected tensor type, if known; passed unchanged to the selected format
     * @param buffer the buffer to read from
     * @return the decoded tensor
     * @throws IllegalArgumentException if the format-type or value-type identifier is unknown
     */
    public static Tensor decode(Optional<TensorType> type, GrowableByteBuffer buffer) {
        BinaryFormat format = getFormatDecoder(buffer);
        return format.decode(type, buffer);
    }

    /**
     * Chooses the binary format for a tensor and writes its header (format type and, for
     * non-double cells, the value type) to the buffer.
     */
    private static BinaryFormat getFormatEncoder(GrowableByteBuffer buffer, Tensor tensor) {
        TensorType tensorType = tensor.type();
        boolean hasMappedDimensions = tensorType.hasMappedDimensions();
        boolean hasIndexedDimensions = tensorType.hasIndexedDimensions();
        // A MixedTensor with only indexed dimensions is still written in the mixed format.
        // NOTE(review): presumably because the dense encoder does not handle the MixedTensor
        // implementation — confirm before changing.
        boolean isMixed = hasIndexedDimensions
                && (hasMappedDimensions || tensor instanceof MixedTensor);
        TensorType.Value valueType = tensorType.valueType();
        // Double cells use the legacy format types, which carry no explicit value type.
        boolean isDouble = valueType == TensorType.Value.DOUBLE;

        if (isMixed) {
            if (isDouble) {
                encodeFormatType(buffer, MIXED_BINARY_FORMAT_TYPE);
                return new MixedBinaryFormat();
            }
            encodeFormatType(buffer, MIXED_BINARY_FORMAT_WITH_CELLTYPE);
            encodeValueType(buffer, valueType);
            return new MixedBinaryFormat(valueType);
        }
        if (hasIndexedDimensions) {
            if (isDouble) {
                encodeFormatType(buffer, DENSE_BINARY_FORMAT_TYPE);
                return new DenseBinaryFormat();
            }
            encodeFormatType(buffer, DENSE_BINARY_FORMAT_WITH_CELLTYPE);
            encodeValueType(buffer, valueType);
            return new DenseBinaryFormat(valueType);
        }
        if (isDouble) {
            encodeFormatType(buffer, SPARSE_BINARY_FORMAT_TYPE);
            return new SparseBinaryFormat();
        }
        encodeFormatType(buffer, SPARSE_BINARY_FORMAT_WITH_CELLTYPE);
        encodeValueType(buffer, valueType);
        return new SparseBinaryFormat(valueType);
    }

    /**
     * Reads the format-type header (and, for *_WITH_CELLTYPE formats, the value type) and
     * returns the matching binary format.
     *
     * @throws IllegalArgumentException if the format-type identifier is unknown
     */
    private static BinaryFormat getFormatDecoder(GrowableByteBuffer buffer) {
        int formatType = decodeFormatType(buffer);
        switch (formatType) {
            case SPARSE_BINARY_FORMAT_TYPE:
                return new SparseBinaryFormat();
            case DENSE_BINARY_FORMAT_TYPE:
                return new DenseBinaryFormat();
            case MIXED_BINARY_FORMAT_TYPE:
                return new MixedBinaryFormat();
            case SPARSE_BINARY_FORMAT_WITH_CELLTYPE:
                return new SparseBinaryFormat(decodeValueType(buffer));
            case DENSE_BINARY_FORMAT_WITH_CELLTYPE:
                return new DenseBinaryFormat(decodeValueType(buffer));
            case MIXED_BINARY_FORMAT_WITH_CELLTYPE:
                return new MixedBinaryFormat(decodeValueType(buffer));
            default: // includes the unassigned type 4
                throw new IllegalArgumentException("Binary format type " + formatType + " is unknown");
        }
    }

    private static void encodeFormatType(GrowableByteBuffer buffer, int formatType) {
        buffer.putInt1_4Bytes(formatType);
    }

    private static int decodeFormatType(GrowableByteBuffer buffer) {
        return buffer.getInt1_4Bytes();
    }

    /**
     * Writes the value-type identifier for a cell type.
     *
     * @throws IllegalArgumentException if the value type has no binary identifier
     */
    private static void encodeValueType(GrowableByteBuffer buffer, TensorType.Value value) {
        switch (value) {
            case DOUBLE:
                buffer.putInt1_4Bytes(DOUBLE_VALUE_TYPE);
                return;
            case FLOAT:
                buffer.putInt1_4Bytes(FLOAT_VALUE_TYPE);
                return;
            case BFLOAT16:
                buffer.putInt1_4Bytes(BFLOAT16_VALUE_TYPE);
                return;
            case INT8:
                buffer.putInt1_4Bytes(INT8_VALUE_TYPE);
                return;
            default:
                throw new IllegalArgumentException("Attempt to encode unknown tensor value type: " + value);
        }
    }

    /**
     * Reads a value-type identifier and maps it back to a cell type.
     *
     * @throws IllegalArgumentException if the identifier is not 0, 1, 2 or 3
     */
    private static TensorType.Value decodeValueType(GrowableByteBuffer buffer) {
        int valueTypeId = buffer.getInt1_4Bytes();
        switch (valueTypeId) {
            case DOUBLE_VALUE_TYPE:
                return TensorType.Value.DOUBLE;
            case FLOAT_VALUE_TYPE:
                return TensorType.Value.FLOAT;
            case BFLOAT16_VALUE_TYPE:
                return TensorType.Value.BFLOAT16;
            case INT8_VALUE_TYPE:
                return TensorType.Value.INT8;
            default:
                throw new IllegalArgumentException("Received tensor value type '" + valueTypeId + "'. Only 0(double), 1(float), 2(bfloat16), or 3(int8) is legal.");
        }
    }

    /** Flips the buffer and copies the bytes between its position and limit into a new array. */
    private static byte[] asByteArray(GrowableByteBuffer buffer) {
        buffer.flip();
        byte[] bytes = new byte[buffer.remaining()];
        buffer.get(bytes);
        return bytes;
    }

    /**
     * Returns the bfloat16 bit pattern of a float: the upper 16 bits of its raw IEEE 754
     * single-precision representation.
     *
     * <p>The low 16 mantissa bits are truncated, not rounded.
     * NOTE(review): keep truncation unless every reader/writer of this format is changed together.
     */
    public static short bFloat16BitsFromFloat(float value) {
        return (short) (Float.floatToRawIntBits(value) >>> 16);
    }

    /**
     * Returns the float whose upper 16 bits are the given bfloat16 bits and whose lower 16 bits
     * are zero. The sign extension of the short is shifted out by {@code << 16}.
     */
    public static float floatFromBFloat16Bits(short bits) {
        return Float.intBitsToFloat(bits << 16);
    }
}
