package com.yahoo.tensor.serialization;

import com.google.common.annotations.Beta;
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.Iterator;
import java.util.Optional;

/**
 * Binary serialization of indexed tensors in a dense layout.
 * <p>
 * The layout written by {@link #encode} and read back by {@link #decode} is:
 * <ol>
 *   <li>the number of dimensions, as a 1-4 byte int</li>
 *   <li>for each dimension, in type order: its name as a UTF-8 string, then its size as a 1-4 byte int</li>
 *   <li>every cell value as a double, in the order returned by {@link Tensor#valueIterator()}</li>
 * </ol>
 * No cell count or cell addresses are written; the decoder derives the cell count from the
 * dimension sizes and assigns cells by direct index 0..n-1.
 * NOTE(review): this assumes valueIterator() of an indexed tensor yields cells in direct-index
 * order — confirm against IndexedTensor.
 */
@Beta
public class DenseBinaryFormat implements BinaryFormat {

    /**
     * Writes the given tensor to the buffer in dense format.
     *
     * @param growableByteBuffer the buffer to write to
     * @param tensor the tensor to encode; must be an {@link IndexedTensor}
     * @throws IllegalArgumentException if the tensor is not an {@link IndexedTensor}
     */
    @Override
    public void encode(GrowableByteBuffer growableByteBuffer, Tensor tensor) {
        // The format stores only dimension names and sizes (no labels), so only indexed tensors fit.
        if ( ! (tensor instanceof IndexedTensor)) {
            throw new IllegalArgumentException("The dense format is only supported for indexed tensors");
        }
        encodeDimensions(growableByteBuffer, (IndexedTensor) tensor);
        encodeCells(growableByteBuffer, tensor);
    }

    /** Writes the dimension count followed by each dimension's name and its size in this tensor. */
    private void encodeDimensions(GrowableByteBuffer buffer, IndexedTensor tensor) {
        int dimensionCount = tensor.type().dimensions().size();
        DimensionSizes sizes = tensor.dimensionSizes();
        buffer.putInt1_4Bytes(dimensionCount);
        for (int i = 0; i < dimensionCount; i++) {
            buffer.putUtf8String(tensor.type().dimensions().get(i).name());
            buffer.putInt1_4Bytes(sizes.size(i));
        }
    }

    /** Writes every cell value as a double, in value-iterator order. */
    private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) {
        Iterator<Double> values = tensor.valueIterator();
        while (values.hasNext()) {
            buffer.putDouble(values.next().doubleValue());
        }
    }

    /**
     * Reads a dense-format tensor from the buffer.
     *
     * @param optional the type the result should have, or empty to use the type stored in the buffer
     * @param growableByteBuffer the buffer to read from
     * @return the decoded tensor
     * @throws IllegalArgumentException if the stored type is not assignable to the requested type,
     *                                  or if the requested type does not produce an indexed tensor builder
     */
    @Override
    public Tensor decode(Optional<TensorType> optional, GrowableByteBuffer growableByteBuffer) {
        // The serialized type is always present in the buffer and must be consumed before the cells.
        TensorType serializedType = decodeType(growableByteBuffer);
        TensorType type = optional.orElse(serializedType);
        if (optional.isPresent() && ! serializedType.isAssignableTo(type)) {
            throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType +
                                               " cannot be assigned to type " + type);
        }
        // Sizes always come from the serialized type: it is fully bound, while the requested type may not be.
        DimensionSizes sizes = sizesFromType(serializedType);

        Tensor.Builder builder = Tensor.Builder.of(type, sizes);
        if ( ! (builder instanceof IndexedTensor.BoundBuilder)) {
            throw new IllegalArgumentException("The dense format can only be decoded into indexed tensor types, not " +
                                               type);
        }
        decodeCells(sizes, growableByteBuffer, (IndexedTensor.BoundBuilder) builder);
        return builder.build();
    }

    /** Reads the dimension header and returns it as a type with one bound indexed dimension per entry. */
    private TensorType decodeType(GrowableByteBuffer buffer) {
        int dimensionCount = buffer.getInt1_4Bytes();
        TensorType.Builder builder = new TensorType.Builder();
        for (int i = 0; i < dimensionCount; i++) {
            // Name precedes size in the stream; read them in that order explicitly.
            String name = buffer.getUtf8String();
            int size = buffer.getInt1_4Bytes();
            builder.indexed(name, size);
        }
        return builder.build();
    }

    /**
     * Returns the dimension sizes declared by the given type.
     * Every dimension must have a size; this holds for types produced by {@link #decodeType}.
     */
    private DimensionSizes sizesFromType(TensorType type) {
        int dimensionCount = type.dimensions().size();
        DimensionSizes.Builder builder = new DimensionSizes.Builder(dimensionCount);
        for (int i = 0; i < dimensionCount; i++) {
            builder.set(i, type.dimensions().get(i).size().get().intValue());
        }
        return builder.build();
    }

    /** Reads one double per cell and stores it at consecutive direct indexes starting at 0. */
    private void decodeCells(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) {
        long cellCount = sizes.totalSize();
        for (int i = 0; i < cellCount; i++) {
            builder.cellByDirectIndex(i, buffer.getDouble());
        }
    }
}
