package com.github.tjake.jlama.tensor;

import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.util.UnsafeDirectByteBuffer;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import java.lang.foreign.MemorySegment;
import java.nio.ByteOrder;
import java.nio.ShortBuffer;
import java.util.Arrays;
import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.VectorSpecies;

/* loaded from: input_file:com/github/tjake/jlama/tensor/Float16BufferTensor.class */
/**
 * A tensor of half-precision ({@link DType#F16}) values stored as raw 16-bit patterns in a direct
 * {@link ShortBuffer}.
 *
 * <p>Scalar accessors convert with {@link Float#float16ToFloat(short)} / {@link Float#floatToFloat16(float)}.
 * A {@link MemorySegment} view over the same buffer is kept so SIMD code can load and store
 * {@link ShortVector}s directly; those vector accesses use little-endian byte order.
 */
public class Float16BufferTensor extends AbstractTensor<ShortVector, Short> {
    /** Backing storage; each element is the bit pattern of one half-precision float. */
    private final ShortBuffer b;
    /** Label shown by {@link #toString()}. */
    private final String name;
    /** Byte-addressed view of {@link #b}, used for vector and bulk access. */
    private final MemorySegment segment;

    /**
     * Creates a new F16 tensor holding a converted copy of {@code abstractTensor}.
     *
     * @param abstractTensor source tensor to convert; must not already be F16
     * @throws IllegalArgumentException if {@code abstractTensor} is already F16
     */
    public Float16BufferTensor(AbstractTensor abstractTensor) {
        // Validate before delegating so a bad call fails before a potentially large buffer is allocated.
        this(requireNotF16(abstractTensor).shape);
        int[] cursor = new int[abstractTensor.shape.dims()];
        // Walk every element starting at the all-zero index; iterate() is assumed to advance the
        // cursor in place and return false after the last element.
        do {
            set(abstractTensor.get(cursor), cursor);
        } while (abstractTensor.iterate(cursor));
    }

    /**
     * Rejects sources that are already F16: converting F16 to F16 indicates a caller bug.
     *
     * @param source tensor passed to the converting constructor
     * @return {@code source}, unchanged
     * @throws IllegalArgumentException if {@code source} is already F16
     */
    private static AbstractTensor requireNotF16(AbstractTensor source) {
        Preconditions.checkArgument(source.dType != DType.F16, "This should never happen, likely a bug");
        return source;
    }

    /**
     * Allocates a new F16 tensor with the given dimensions.
     *
     * @param iArr the size of each dimension
     */
    public Float16BufferTensor(int... iArr) {
        this(TensorShape.of(iArr));
    }

    /**
     * Allocates a new F16 tensor of {@code tensorShape}, backed by a 64-byte-aligned direct buffer.
     *
     * @param tensorShape shape of the new tensor
     * @throws IllegalArgumentException if the required byte count does not fit in an {@code int}
     */
    public Float16BufferTensor(TensorShape tensorShape) {
        super(DType.F16, tensorShape, true);
        this.name = "tmp";
        // checkedCast fails loudly instead of silently truncating a byte count beyond int range.
        // The explicit little-endian order keeps scalar get/set consistent with the little-endian
        // vector loads/stores performed through the segment.
        this.b = UnsafeDirectByteBuffer.allocateAlignedByteBuffer(Ints.checkedCast(size() * dType().size()), 64L)
                .order(ByteOrder.LITTLE_ENDIAN)
                .asShortBuffer();
        this.segment = MemorySegment.ofBuffer(this.b);
    }

    /**
     * Wraps an existing direct buffer, using the name {@code "none"}.
     *
     * @param shortBuffer direct buffer holding half-precision bit patterns
     * @param tensorShape shape of the tensor
     * @param z flag forwarded unchanged to the {@link AbstractTensor} constructor
     * @throws IllegalArgumentException if {@code shortBuffer} is not direct
     */
    public Float16BufferTensor(ShortBuffer shortBuffer, TensorShape tensorShape, boolean z) {
        this("none", shortBuffer, tensorShape, z);
    }

    /**
     * Wraps an existing direct buffer without copying it.
     *
     * <p>NOTE(review): scalar access uses {@code shortBuffer}'s own byte order while vector access is
     * little-endian; callers presumably supply little-endian buffers — confirm.
     *
     * @param str name shown by {@link #toString()}
     * @param shortBuffer direct buffer holding half-precision bit patterns
     * @param tensorShape shape of the tensor
     * @param z flag forwarded unchanged to the {@link AbstractTensor} constructor
     * @throws IllegalArgumentException if {@code shortBuffer} is not direct
     */
    public Float16BufferTensor(String str, ShortBuffer shortBuffer, TensorShape tensorShape, boolean z) {
        super(DType.F16, tensorShape, z);
        // MemorySegment.ofBuffer on a heap buffer would not give the off-heap layout expected here.
        Preconditions.checkArgument(shortBuffer.isDirect(), "Must use direct buffers");
        this.name = str;
        this.b = shortBuffer;
        this.segment = MemorySegment.ofBuffer(shortBuffer);
    }

    /**
     * Creates a fresh, independently allocated F16 tensor.
     *
     * @param tensorShape shape of the new tensor
     * @return the new tensor
     */
    @Override // com.github.tjake.jlama.tensor.AbstractTensor
    protected AbstractTensor make(TensorShape tensorShape) {
        return new Float16BufferTensor(tensorShape);
    }

    /**
     * Creates a view over a range of this tensor's elements; the view shares storage with this tensor.
     *
     * @param i first element (in elements, not bytes) of the view
     * @param i2 number of elements in the view
     * @param tensorShape shape of the view
     * @param z flag forwarded unchanged to the {@link AbstractTensor} constructor
     * @return the view, carrying this tensor's name
     */
    @Override // com.github.tjake.jlama.tensor.AbstractTensor
    protected AbstractTensor make(int i, int i2, TensorShape tensorShape, boolean z) {
        return new Float16BufferTensor(this.name, this.b.slice(i, i2), tensorShape, z);
    }

    /**
     * Reads one element and widens it to {@code float}.
     *
     * @param iArr one index per dimension
     * @return the element value
     * @throws IllegalArgumentException if the number of indices differs from the tensor's rank
     */
    @Override // com.github.tjake.jlama.tensor.AbstractTensor
    public float get(int... iArr) {
        Preconditions.checkArgument(iArr.length <= this.shape.dims(), "Too many dimensions specified");
        Preconditions.checkArgument(iArr.length == this.shape.dims(), "Must specify all dimensions");
        return Float.float16ToFloat(this.b.get(getOffset(iArr)));
    }

    /**
     * Narrows {@code f} to half precision and stores it at the given position.
     *
     * @param f value to store
     * @param iArr one index per dimension
     * @throws IllegalArgumentException if the number of indices differs from the tensor's rank, or
     *     the backing buffer is read-only
     */
    @Override // com.github.tjake.jlama.tensor.AbstractTensor
    public void set(float f, int... iArr) {
        Preconditions.checkArgument(iArr.length <= this.shape.dims(), "Too many dimensions specified for tensor");
        Preconditions.checkArgument(iArr.length == this.shape.dims(), "Must specify all dimensions");
        Preconditions.checkArgument(!this.b.isReadOnly(), "Can't modify a read only buffer");
        this.b.put(getOffset(iArr), Float.floatToFloat16(f));
    }

    /**
     * Loads a vector of raw half-precision bit patterns starting at the given position.
     *
     * @param vectorSpecies species (lane count) of the vector to load
     * @param iArr index of the first lane's element
     * @return the loaded vector, read little-endian
     */
    // The decompiler renamed this from getVector; the name is kept so it still matches the
    // overridden declaration.
    @Override // com.github.tjake.jlama.tensor.AbstractTensor
    public ShortVector mo41getVector(VectorSpecies<Short> vectorSpecies, int... iArr) {
        return ShortVector.fromMemorySegment(vectorSpecies, this.segment, getMemorySegmentOffset(getOffset(iArr)), ByteOrder.LITTLE_ENDIAN);
    }

    /**
     * Stores a vector of raw half-precision bit patterns starting at the given position.
     *
     * @param shortVector vector to store, written little-endian
     * @param iArr index of the first lane's element
     * @throws IllegalArgumentException if the backing buffer is read-only
     */
    @Override // com.github.tjake.jlama.tensor.AbstractTensor
    public void intoTensor(ShortVector shortVector, int... iArr) {
        Preconditions.checkArgument(!this.b.isReadOnly(), "Can't modify a read only buffer");
        shortVector.intoMemorySegment(this.segment, getMemorySegmentOffset(getOffset(iArr)), ByteOrder.LITTLE_ENDIAN);
    }

    /** Returns the byte-addressed segment view of the backing buffer. */
    @Override // com.github.tjake.jlama.tensor.AbstractTensor
    public MemorySegment getMemorySegment() {
        return this.segment;
    }

    /**
     * Converts an element offset to a byte offset within {@link #getMemorySegment()}.
     *
     * @param i element offset
     * @return {@code i} multiplied by the element size in bytes
     */
    @Override // com.github.tjake.jlama.tensor.AbstractTensor
    public int getMemorySegmentOffset(int i) {
        return i * this.dType.size();
    }

    /**
     * Copies a contiguous run of elements from {@code abstractTensor} into this tensor.
     *
     * @param abstractTensor source tensor; must have the same dtype as this tensor
     * @param i first source element (element offset)
     * @param i2 first destination element (element offset)
     * @param i3 number of elements to copy
     * @throws IllegalArgumentException if the dtypes differ or this tensor is read-only
     */
    @Override // com.github.tjake.jlama.tensor.AbstractTensor
    public void copyFrom(AbstractTensor abstractTensor, int i, int i2, int i3) {
        Preconditions.checkArgument(this.dType == abstractTensor.dType, "different types");
        Preconditions.checkArgument(!this.b.isReadOnly(), "Read-only");
        // Offsets and length are element counts; MemorySegment slicing is in bytes, so all three
        // must be converted (previously the length was not, copying only half the elements).
        MemorySegment source = abstractTensor.getMemorySegment()
                .asSlice(abstractTensor.getMemorySegmentOffset(i), abstractTensor.getMemorySegmentOffset(i3));
        this.segment.asSlice(getMemorySegmentOffset(i2), getMemorySegmentOffset(i3)).copyFrom(source);
    }

    /**
     * Sets every byte of the backing storage to zero.
     *
     * @throws IllegalArgumentException if the backing buffer is read-only
     */
    @Override // com.github.tjake.jlama.tensor.AbstractTensor
    public void clear() {
        Preconditions.checkArgument(!this.b.isReadOnly(), "Can't clear a read-only buffer");
        this.segment.fill((byte) 0);
    }

    /** Returns a debug summary with the name, shape, and up to the first ten values decoded to float. */
    @Override
    public String toString() {
        short[] raw = new short[Math.min(10, this.b.remaining())];
        // duplicate() so reading the preview does not move this buffer's position.
        this.b.duplicate().get(raw);
        float[] preview = new float[raw.length];
        for (int k = 0; k < raw.length; k++) {
            preview[k] = Float.float16ToFloat(raw[k]);
        }
        return "Float16BufferTensor{name='" + this.name + "', shape=" + this.shape + ", b=" + Arrays.toString(preview) + "...}";
    }
}
