package com.yahoo.tensor.serialization;

import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.nativec.PosixFAdvise;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.Optional;

/* loaded from: input_file:com/yahoo/tensor/serialization/DenseBinaryFormat.class */
/**
 * Binary format for indexed tensors where every cell is written in direct index order.
 *
 * <p>As written by {@link #encode}, the serialized form consists of:
 * <ol>
 *   <li>the number of dimensions, written with {@code putInt1_4Bytes},</li>
 *   <li>for each dimension: its name as a UTF-8 string followed by its size
 *       (written with {@code putInt1_4Bytes}),</li>
 *   <li>one value per cell, for direct indexes {@code 0 .. size - 1}, written using the
 *       cell value type this format was constructed with.</li>
 * </ol>
 *
 * <p>The cell value type is not part of the serialized form; the reader must use a format
 * instance configured with the same value type as the writer.
 */
public class DenseBinaryFormat implements BinaryFormat {

    /** The cell value type used both when writing and when reading cell values. */
    private final TensorType.Value serializationValueType;

    /** Creates a format which serializes cells as {@link TensorType.Value#DOUBLE}. */
    public DenseBinaryFormat() {
        this(TensorType.Value.DOUBLE);
    }

    /**
     * Creates a format which serializes cells as the given value type.
     *
     * @param serializationValueType the cell value type to write and to expect when reading
     */
    public DenseBinaryFormat(TensorType.Value serializationValueType) {
        this.serializationValueType = serializationValueType;
    }

    /**
     * Appends the dimensions and then all cells of the given tensor to the buffer.
     *
     * @param buffer the buffer to append to
     * @param tensor the tensor to serialize; must be an {@link IndexedTensor}
     * @throws RuntimeException if the tensor is not an {@link IndexedTensor}
     */
    @Override
    public void encode(GrowableByteBuffer buffer, Tensor tensor) {
        if (!(tensor instanceof IndexedTensor)) {
            throw new RuntimeException("The dense format is only supported for indexed tensors");
        }
        IndexedTensor indexed = (IndexedTensor) tensor;
        encodeDimensions(buffer, indexed);
        encodeCells(buffer, indexed);
    }

    /** Writes the dimension count followed by the name and size of each dimension. */
    private void encodeDimensions(GrowableByteBuffer buffer, IndexedTensor tensor) {
        TensorType type = tensor.type();
        int dimensionCount = type.dimensions().size();
        buffer.putInt1_4Bytes(dimensionCount);
        for (int i = 0; i < dimensionCount; i++) {
            buffer.putUtf8String(type.dimensions().get(i).name());
            // The size is narrowed to int before being written.
            buffer.putInt1_4Bytes((int) tensor.dimensionSizes().size(i));
        }
    }

    /** Writes all cells using the encoder matching the configured value type. */
    private void encodeCells(GrowableByteBuffer buffer, IndexedTensor tensor) {
        switch (serializationValueType) {
            case DOUBLE:
                encodeDoubleCells(tensor, buffer);
                break;
            case FLOAT:
                encodeFloatCells(tensor, buffer);
                break;
            case BFLOAT16:
                encodeBFloat16Cells(tensor, buffer);
                break;
            case INT8:
                encodeInt8Cells(tensor, buffer);
                break;
            default:
                // Any other value type writes no cells.
                break;
        }
    }

    /** Writes each cell with {@code putDouble}. */
    private void encodeDoubleCells(IndexedTensor tensor, GrowableByteBuffer buffer) {
        for (int i = 0; i < tensor.sizeAsInt(); i++) {
            buffer.putDouble(tensor.get(i));
        }
    }

    /** Writes each cell with {@code putFloat}. */
    private void encodeFloatCells(IndexedTensor tensor, GrowableByteBuffer buffer) {
        for (int i = 0; i < tensor.sizeAsInt(); i++) {
            buffer.putFloat(tensor.getFloat(i));
        }
    }

    /** Writes each cell as the short produced by {@code TypedBinaryFormat.bFloat16BitsFromFloat}. */
    private void encodeBFloat16Cells(IndexedTensor tensor, GrowableByteBuffer buffer) {
        for (int i = 0; i < tensor.sizeAsInt(); i++) {
            buffer.putShort(TypedBinaryFormat.bFloat16BitsFromFloat(tensor.getFloat(i)));
        }
    }

    /** Writes each cell as a single byte obtained by casting the float cell value to {@code byte}. */
    private void encodeInt8Cells(IndexedTensor tensor, GrowableByteBuffer buffer) {
        for (int i = 0; i < tensor.sizeAsInt(); i++) {
            buffer.put((byte) tensor.getFloat(i));
        }
    }

    /**
     * Reads a tensor from the buffer.
     *
     * <p>If a type is given, its value type is checked against this format's value type before
     * anything is read from the buffer; the serialized type is then read and must be assignable
     * to the given type. The returned tensor is built with the given type, using the dimension
     * sizes of the serialized type. If no type is given, the serialized type is used for both.
     *
     * @param optionalType the type the caller expects, or empty to use the serialized type
     * @param buffer the buffer to read from
     * @return the decoded tensor
     * @throws IllegalArgumentException if the given type has a different value type than this format,
     *         or if the serialized type is not assignable to the given type
     */
    @Override
    public Tensor decode(Optional<TensorType> optionalType, GrowableByteBuffer buffer) {
        TensorType type;
        DimensionSizes sizes;
        if (optionalType.isPresent()) {
            type = optionalType.get();
            // Checked before reading the serialized type, so a mismatch leaves the buffer untouched.
            if (type.valueType() != serializationValueType) {
                throw new IllegalArgumentException("Tensor value type mismatch. Value type " + type.valueType() +
                                                   " is not " + serializationValueType);
            }
            TensorType serializedType = decodeType(buffer);
            if (!serializedType.isAssignableTo(type)) {
                throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType +
                                                   " cannot be assigned to type " + type);
            }
            // Sizes come from the serialized type, not from the requested type.
            sizes = sizesFromType(serializedType);
        } else {
            type = decodeType(buffer);
            sizes = sizesFromType(type);
        }
        Tensor.Builder builder = Tensor.Builder.of(type, sizes);
        decodeCells(sizes, buffer, (IndexedTensor.BoundBuilder) builder);
        return builder.build();
    }

    /**
     * Reads the dimension count and then a (name, size) pair per dimension, returning a type
     * with this format's value type and one indexed dimension per pair.
     */
    private TensorType decodeType(GrowableByteBuffer buffer) {
        TensorType.Builder builder = new TensorType.Builder(serializationValueType);
        int dimensionCount = buffer.getInt1_4Bytes();
        for (int i = 0; i < dimensionCount; i++) {
            // Arguments are evaluated left to right: the name is read before the size.
            builder.indexed(buffer.getUtf8String(), buffer.getInt1_4Bytes());
        }
        return builder.build();
    }

    /**
     * Returns the size of each dimension of the given type, in dimension order.
     * Every dimension must have a size; {@code Optional.get()} fails otherwise.
     */
    private DimensionSizes sizesFromType(TensorType type) {
        int dimensionCount = type.dimensions().size();
        DimensionSizes.Builder builder = new DimensionSizes.Builder(dimensionCount);
        for (int i = 0; i < dimensionCount; i++) {
            builder.set(i, type.dimensions().get(i).size().get().longValue());
        }
        return builder.build();
    }

    /** Reads all cells using the decoder matching the configured value type. */
    private void decodeCells(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) {
        switch (serializationValueType) {
            case DOUBLE:
                decodeDoubleCells(sizes, builder, buffer);
                break;
            case FLOAT:
                decodeFloatCells(sizes, builder, buffer);
                break;
            case BFLOAT16:
                decodeBFloat16Cells(sizes, builder, buffer);
                break;
            case INT8:
                decodeInt8Cells(sizes, builder, buffer);
                break;
            default:
                // Any other value type reads no cells.
                break;
        }
    }

    /** Reads one {@code double} per cell and stores it at the cell's direct index. */
    private void decodeDoubleCells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) {
        for (long i = 0; i < sizes.totalSize(); i++) {
            builder.cellByDirectIndex(i, buffer.getDouble());
        }
    }

    /** Reads one {@code float} per cell and stores it at the cell's direct index. */
    private void decodeFloatCells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) {
        for (long i = 0; i < sizes.totalSize(); i++) {
            builder.cellByDirectIndex(i, buffer.getFloat());
        }
    }

    /** Reads one short per cell and stores its {@code TypedBinaryFormat.floatFromBFloat16Bits} conversion. */
    private void decodeBFloat16Cells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) {
        for (long i = 0; i < sizes.totalSize(); i++) {
            builder.cellByDirectIndex(i, TypedBinaryFormat.floatFromBFloat16Bits(buffer.getShort()));
        }
    }

    /** Reads one byte per cell and stores it at the cell's direct index. */
    private void decodeInt8Cells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) {
        for (long i = 0; i < sizes.totalSize(); i++) {
            builder.cellByDirectIndex(i, buffer.get());
        }
    }
}
