package com.github.tjake.jlama.tensor;

import com.github.tjake.jlama.safetensors.TensorInfo;
import com.google.common.base.Preconditions;
import java.io.IOException;
import java.lang.foreign.MemorySegment;
import java.nio.ShortBuffer;
import java.nio.channels.FileChannel;
import java.util.List;
import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.VectorSpecies;

/* loaded from: input_file:com/github/tjake/jlama/tensor/SegmentedTensor.class */
/**
 * A read-only view that stacks several 2D tensors along the first (row) dimension.
 *
 * <p>All wrapped tensors must share the same last (column) dimension. The logical shape is
 * {@code [sum of segment row counts, columns]}. This object allocates an empty backing buffer of
 * its own and holds no element data. Only first-dimension slicing ({@link #slice(int...)}) and
 * {@link #save(FileChannel)} are supported. Element access, vector access, memory-segment access,
 * copying, clearing and {@code make(...)} all throw {@link UnsupportedOperationException}.
 */
public class SegmentedTensor extends BFloat16BufferTensor {
    /** The wrapped segments, in row order. */
    private final AbstractTensor[] tensors;

    /**
     * Cumulative exclusive end row of each segment. {@code splitPoints[i]} is the total number of
     * rows in {@code tensors[0]} through {@code tensors[i]}.
     */
    private final int[] splitPoints;

    /**
     * Builds a segmented view that stacks the given tensors along the first dimension.
     *
     * @param list at least two 2D tensors, all with the same last dimension
     * @return a tensor of shape {@code [total rows, shared columns]} backed by the tensors in
     *     {@code list}
     * @throws IllegalArgumentException if fewer than two tensors are given, if any tensor is not
     *     2D, or if the last dimensions differ
     */
    public static SegmentedTensor wrap(List<AbstractTensor> list) {
        Preconditions.checkArgument(list.size() > 1, "Must have at least two tensor to segment");
        AbstractTensor head = list.get(0);
        Preconditions.checkArgument(head.shape().dims() == 2, "First tensor must be 2D");

        int columns = head.shape().last();
        int rows = head.shape().first();
        int[] ends = new int[list.size()];
        ends[0] = rows;
        for (int i = 1; i < list.size(); i++) {
            AbstractTensor segment = list.get(i);
            // Without this check a non-2D tensor could still satisfy the last-dimension check
            // below and silently corrupt the row accounting.
            Preconditions.checkArgument(
                    segment.shape().dims() == 2, "All tensors must be 2D (tensor %s is not)", i);
            Preconditions.checkArgument(
                    segment.shape().last() == columns,
                    "All tensors must have the same second dimension");
            rows += segment.shape().first();
            ends[i] = rows;
        }
        return new SegmentedTensor(
                TensorShape.of(rows, columns), ends, list.toArray(new AbstractTensor[0]));
    }

    /**
     * Creates a segmented view. Callers are expected to go through {@link #wrap(List)}, which
     * computes {@code splitPoints} consistently with {@code tensors}.
     *
     * @param tensorShape logical shape of the stacked view
     * @param splitPoints cumulative exclusive end row of each segment (stored without copying)
     * @param tensors the segments, in row order (stored without copying)
     */
    protected SegmentedTensor(TensorShape tensorShape, int[] splitPoints, AbstractTensor... tensors) {
        super("segmented-tensor", ShortBuffer.allocate(0), tensorShape, false);
        this.splitPoints = splitPoints;
        this.tensors = tensors;
    }

    /**
     * Saves every segment to {@code fileChannel}, one after another, starting at the channel's
     * current position.
     *
     * @return a {@link TensorInfo} carrying this tensor's {@code dType}, its logical shape, and
     *     the channel positions before and after all segments were written
     * @throws IOException if any segment fails to save
     */
    @Override
    public TensorInfo save(FileChannel fileChannel) throws IOException {
        long start = fileChannel.position();
        for (AbstractTensor segment : tensors) {
            segment.save(fileChannel);
        }
        // NOTE(review): this records this.dType for the whole range; confirm every segment is
        // written with that same dtype.
        long[] dims = new long[shape.dims()];
        for (int d = 0; d < dims.length; d++) {
            dims[d] = shape.dim(d);
        }
        return new TensorInfo(dType, dims, new long[] {start, fileChannel.position()});
    }

    /**
     * Returns row {@code dims[0]} of the stacked view by delegating to the segment that contains
     * it.
     *
     * @param dims exactly one index, the row to slice
     * @return the corresponding slice of the owning segment
     * @throws IllegalArgumentException if more than one index is given, or the row is negative or
     *     not less than the total row count
     */
    @Override
    public AbstractTensor slice(int... dims) {
        Preconditions.checkArgument(dims.length == 1, "Must slice on first dimension");
        int row = dims[0];
        if (row >= 0) {
            int segmentStart = 0;
            for (int s = 0; s < splitPoints.length; s++) {
                if (row < splitPoints[s]) {
                    return tensors[s].slice(row - segmentStart);
                }
                segmentStart = splitPoints[s];
            }
        }
        throw new IllegalArgumentException("Index out of range");
    }

    /**
     * Segment-aware slice. This delegates to {@link #slice(int...)}. The parent implementation
     * presumably slices this object's own backing buffer, which is empty, rather than the wrapped
     * segments.
     *
     * <p>NOTE(review): {@code cacheSlice} is not forwarded to the segments. Confirm whether it
     * should be.
     */
    @Override
    public AbstractTensor slice(boolean cacheSlice, int... dims) {
        return slice(dims);
    }

    @Override
    protected AbstractTensor make(TensorShape tensorShape) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    protected AbstractTensor make(int offset, int length, TensorShape tensorShape, boolean cacheSlices) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public float get(int... dims) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void set(float value, int... dims) {
        throw new UnsupportedOperationException("Not supported");
    }

    /* JADX WARN: Can't rename method to resolve collision */
    // Decompiler-assigned name; the original method is getVector.
    @Override
    public ShortVector mo41getVector(VectorSpecies<Short> species, int... offset) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void intoTensor(ShortVector vector, int... offset) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public String toString() {
        return "SegmentedBF16Tensor{shape=" + shape + ", tensors=" + tensors.length + "}";
    }

    @Override
    public MemorySegment getMemorySegment() {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public int getMemorySegmentOffset(int offset) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void copyFrom(AbstractTensor src, int srcOffset, int destOffset, int length) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void clear() {
        throw new UnsupportedOperationException("Not supported");
    }
}
