package com.yahoo.tensor.serialization;

import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.nativec.PosixFAdvise;
import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.function.DoubleConsumer;
import java.util.function.DoubleSupplier;
import java.util.function.Supplier;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:com/yahoo/tensor/serialization/MixedBinaryFormat.class */
public class MixedBinaryFormat implements BinaryFormat {
    private final TensorType.Value serializationValueType;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: com.yahoo.tensor.serialization.MixedBinaryFormat$1, reason: invalid class name */
    /* loaded from: input_file:com/yahoo/tensor/serialization/MixedBinaryFormat$1.class */
    public static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$com$yahoo$tensor$TensorType$Value = new int[TensorType.Value.values().length];

        static {
            try {
                $SwitchMap$com$yahoo$tensor$TensorType$Value[TensorType.Value.DOUBLE.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$com$yahoo$tensor$TensorType$Value[TensorType.Value.FLOAT.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$com$yahoo$tensor$TensorType$Value[TensorType.Value.BFLOAT16.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$com$yahoo$tensor$TensorType$Value[TensorType.Value.INT8.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public MixedBinaryFormat() {
        this(TensorType.Value.DOUBLE);
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public MixedBinaryFormat(TensorType.Value value) {
        this.serializationValueType = value;
    }

    @Override // com.yahoo.tensor.serialization.BinaryFormat
    public void encode(GrowableByteBuffer growableByteBuffer, Tensor tensor) {
        if (!(tensor instanceof MixedTensor)) {
            throw new RuntimeException("The mixed format is only supported for mixed tensors");
        }
        MixedTensor mixedTensor = (MixedTensor) tensor;
        encodeSparseDimensions(growableByteBuffer, mixedTensor);
        encodeDenseDimensions(growableByteBuffer, mixedTensor);
        encodeCells(growableByteBuffer, mixedTensor);
    }

    private void encodeSparseDimensions(GrowableByteBuffer growableByteBuffer, MixedTensor mixedTensor) {
        List<TensorType.Dimension> list = mixedTensor.type().dimensions().stream().filter(dimension -> {
            return !dimension.isIndexed();
        }).toList();
        growableByteBuffer.putInt1_4Bytes(list.size());
        Iterator<TensorType.Dimension> it = list.iterator();
        while (it.hasNext()) {
            growableByteBuffer.putUtf8String(it.next().name());
        }
    }

    private void encodeDenseDimensions(GrowableByteBuffer growableByteBuffer, MixedTensor mixedTensor) {
        List<TensorType.Dimension> list = mixedTensor.type().dimensions().stream().filter(dimension -> {
            return dimension.isIndexed();
        }).toList();
        growableByteBuffer.putInt1_4Bytes(list.size());
        for (TensorType.Dimension dimension2 : list) {
            growableByteBuffer.putUtf8String(dimension2.name());
            growableByteBuffer.putInt1_4Bytes((int) dimension2.size().orElseThrow(() -> {
                return new IllegalArgumentException("Unknown size of indexed dimension.");
            }).longValue());
        }
    }

    private void encodeCells(GrowableByteBuffer growableByteBuffer, MixedTensor mixedTensor) {
        switch (AnonymousClass1.$SwitchMap$com$yahoo$tensor$TensorType$Value[this.serializationValueType.ordinal()]) {
            case 1:
                Objects.requireNonNull(growableByteBuffer);
                encodeCells(growableByteBuffer, mixedTensor, (v1) -> {
                    r3.putDouble(v1);
                });
                return;
            case 2:
                encodeCells(growableByteBuffer, mixedTensor, d -> {
                    growableByteBuffer.putFloat(d.floatValue());
                });
                return;
            case 3:
                encodeCells(growableByteBuffer, mixedTensor, d2 -> {
                    growableByteBuffer.putShort(TypedBinaryFormat.bFloat16BitsFromFloat(d2.floatValue()));
                });
                return;
            case PosixFAdvise.POSIX_FADV_DONTNEED /* 4 */:
                encodeCells(growableByteBuffer, mixedTensor, d3 -> {
                    growableByteBuffer.put((byte) d3.floatValue());
                });
                return;
            default:
                return;
        }
    }

    private void encodeCells(GrowableByteBuffer growableByteBuffer, MixedTensor mixedTensor, Consumer<Double> consumer) {
        List<TensorType.Dimension> list = mixedTensor.type().dimensions().stream().filter(dimension -> {
            return !dimension.isIndexed();
        }).toList();
        mixedTensor.denseSubspaceSize();
        List<MixedTensor.DenseSubspace> internalDenseSubspaces = mixedTensor.getInternalDenseSubspaces();
        if (list.size() > 0) {
            growableByteBuffer.putInt1_4Bytes(internalDenseSubspaces.size());
        }
        for (MixedTensor.DenseSubspace denseSubspace : internalDenseSubspaces) {
            for (int i = 0; i < denseSubspace.sparseAddress.size(); i++) {
                growableByteBuffer.putUtf8String(denseSubspace.sparseAddress.label(i));
            }
            for (double d : denseSubspace.cells) {
                consumer.accept(Double.valueOf(d));
            }
        }
    }

    @Override // com.yahoo.tensor.serialization.BinaryFormat
    public Tensor decode(Optional<TensorType> optional, GrowableByteBuffer growableByteBuffer) {
        TensorType decodeType;
        if (optional.isPresent()) {
            decodeType = optional.get();
            if (decodeType.valueType() != this.serializationValueType) {
                throw new IllegalArgumentException("Tensor value type mismatch. Value type " + String.valueOf(decodeType.valueType()) + " is not " + String.valueOf(this.serializationValueType));
            }
            TensorType decodeType2 = decodeType(growableByteBuffer);
            if (!decodeType2.isAssignableTo(decodeType)) {
                throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + String.valueOf(decodeType2) + " cannot be assigned to type " + String.valueOf(decodeType));
            }
        } else {
            decodeType = decodeType(growableByteBuffer);
        }
        MixedTensor.BoundBuilder boundBuilder = (MixedTensor.BoundBuilder) MixedTensor.Builder.of(decodeType);
        decodeCells(growableByteBuffer, boundBuilder, decodeType);
        return boundBuilder.build();
    }

    private TensorType decodeType(GrowableByteBuffer growableByteBuffer) {
        TensorType.Builder builder = new TensorType.Builder(this.serializationValueType);
        int int1_4Bytes = growableByteBuffer.getInt1_4Bytes();
        for (int i = 0; i < int1_4Bytes; i++) {
            builder.mapped(growableByteBuffer.getUtf8String());
        }
        int int1_4Bytes2 = growableByteBuffer.getInt1_4Bytes();
        for (int i2 = 0; i2 < int1_4Bytes2; i2++) {
            builder.indexed(growableByteBuffer.getUtf8String(), growableByteBuffer.getInt1_4Bytes());
        }
        return builder.build();
    }

    private void decodeCells(GrowableByteBuffer growableByteBuffer, MixedTensor.BoundBuilder boundBuilder, TensorType tensorType) {
        switch (AnonymousClass1.$SwitchMap$com$yahoo$tensor$TensorType$Value[this.serializationValueType.ordinal()]) {
            case 1:
                Objects.requireNonNull(growableByteBuffer);
                decodeCells(growableByteBuffer, boundBuilder, tensorType, growableByteBuffer::getDouble);
                return;
            case 2:
                decodeCells(growableByteBuffer, boundBuilder, tensorType, () -> {
                    return Double.valueOf(growableByteBuffer.getFloat());
                });
                return;
            case 3:
                decodeCells(growableByteBuffer, boundBuilder, tensorType, () -> {
                    return Double.valueOf(TypedBinaryFormat.floatFromBFloat16Bits(growableByteBuffer.getShort()));
                });
                return;
            case PosixFAdvise.POSIX_FADV_DONTNEED /* 4 */:
                decodeCells(growableByteBuffer, boundBuilder, tensorType, () -> {
                    return Double.valueOf(growableByteBuffer.get());
                });
                return;
            default:
                return;
        }
    }

    private void decodeCells(GrowableByteBuffer growableByteBuffer, MixedTensor.BoundBuilder boundBuilder, TensorType tensorType, Supplier<Double> supplier) {
        List<TensorType.Dimension> list = tensorType.dimensions().stream().filter(dimension -> {
            return !dimension.isIndexed();
        }).toList();
        TensorType createPartialType = MixedTensor.createPartialType(tensorType.valueType(), list);
        long denseSubspaceSize = boundBuilder.denseSubspaceSize();
        int int1_4Bytes = list.size() > 0 ? growableByteBuffer.getInt1_4Bytes() : 1;
        double[] dArr = new double[(int) denseSubspaceSize];
        for (int i = 0; i < int1_4Bytes; i++) {
            TensorAddress.Builder builder = new TensorAddress.Builder(createPartialType);
            Iterator<TensorType.Dimension> it = list.iterator();
            while (it.hasNext()) {
                builder.add(it.next().name(), growableByteBuffer.getUtf8String());
            }
            long j = 0;
            while (true) {
                long j2 = j;
                if (j2 < denseSubspaceSize) {
                    dArr[(int) j2] = supplier.get().doubleValue();
                    j = j2 + 1;
                }
            }
            boundBuilder.block(builder.build(), dArr);
        }
    }
}
