/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.tensor;

import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.TensorInfo;
import com.github.tjake.jlama.tensor.BFloat16BufferTensor;
import com.github.tjake.jlama.tensor.FloatBufferTensor;
import com.github.tjake.jlama.tensor.Q4ByteBufferTensor;
import com.github.tjake.jlama.tensor.Q8ByteBufferTensor;
import com.github.tjake.jlama.tensor.TensorCache;
import com.github.tjake.jlama.tensor.TensorShape;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import java.io.IOException;
import java.lang.foreign.MemorySegment;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.util.Arrays;
import jdk.incubator.vector.Vector;
import jdk.incubator.vector.VectorMask;
import jdk.incubator.vector.VectorSpecies;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for the tensor implementations in this package.
 *
 * <p>A tensor has a fixed {@link TensorShape} and {@link DType}, exposes its storage through
 * {@link #getMemorySegment()}, and supports vectorized access through {@link #getVector} and
 * {@link #intoTensor}. If an owning {@link TensorCache} has been recorded via
 * {@link #setOwnerCache}, {@link #close()} releases the tensor back to that cache.
 *
 * @param <V> the vector type produced by {@link #getVector} and consumed by {@link #intoTensor}
 * @param <T> the boxed element type of the {@link VectorSpecies} accepted by {@link #getVector}
 */
public abstract class AbstractTensor<V extends Vector<?>, T extends Number>
implements AutoCloseable {
    private static final Logger logger = LoggerFactory.getLogger(AbstractTensor.class);
    protected final TensorShape shape;
    protected final DType dType;
    /** Memoized single-index slices, indexed by the first dimension; {@code null} when caching is off. */
    protected final AbstractTensor[] sliceCache;
    /** Value computed once in the constructor; see {@link #getStride()}. */
    private final int stride;
    /** Cache that owns this tensor, or {@code null}; set by {@link #setOwnerCache}. */
    private volatile TensorCache originCache = null;

    /**
     * Initializes the shared tensor state.
     *
     * @param dType element type of this tensor
     * @param shape shape of this tensor; must be non-null and have at least one dimension
     * @param cacheSlices when {@code true}, {@link #slice(boolean, int...)} memoizes slices taken
     *     with a single index (one cache slot per entry of the first dimension)
     * @throws IllegalArgumentException if {@code shape} is null or has no dimensions
     */
    protected AbstractTensor(DType dType, TensorShape shape, boolean cacheSlices) {
        Preconditions.checkArgument(shape != null && shape.dims() > 0);
        this.dType = dType;
        this.shape = shape;
        this.sliceCache = cacheSlices ? new AbstractTensor[shape.first()] : null;
        // Only 2-D tensors with more than one row get a non-zero stride: the flat offset of the
        // element at (sparseRowOffset + 1, sparseColumnOffset).
        this.stride = shape.first() > 1 && this.dims() == 2
                ? this.getOffset(shape.sparseRowOffset() + 1, shape.sparseColumnOffset())
                : 0;
    }

    /**
     * Creates a new tensor of the given element type and shape.
     *
     * @param dType element type; only {@code F32}, {@code BF16} and {@code I8} are supported
     * @param shape shape of the new tensor
     * @return a new tensor of the matching concrete implementation
     * @throws RuntimeException if {@code dType} is not one of the supported types
     */
    public static AbstractTensor make(DType dType, TensorShape shape) {
        return switch (dType) {
            case F32 -> new FloatBufferTensor(shape);
            case BF16 -> new BFloat16BufferTensor(shape);
            case I8 -> new Q8ByteBufferTensor(shape);
            default -> throw new RuntimeException("Unsupported tensor type: " + dType);
        };
    }

    /**
     * Creates a tensor of this tensor's concrete type with the given shape.
     *
     * @param shape shape of the new tensor
     * @return the new tensor
     */
    protected abstract AbstractTensor make(TensorShape shape);

    /**
     * Creates a tensor of this tensor's concrete type covering {@code length} elements starting
     * at flat {@code offset}. NOTE(review): {@link #slice} and {@link #split} use this to build
     * sub-tensors; whether the result shares storage with this tensor is up to the subclass.
     *
     * @param offset flat element offset into this tensor
     * @param length number of elements covered
     * @param shape shape of the resulting tensor
     * @param cacheSlices whether the resulting tensor memoizes its own slices
     * @return the new tensor
     */
    protected abstract AbstractTensor make(int offset, int length, TensorShape shape, boolean cacheSlices);

    /**
     * Obtains a tensor with the same element type and shape from the shared {@link TensorCache}.
     * NOTE(review): whether its contents are cleared is decided by {@code TensorCache.get}.
     *
     * @return a tensor with this tensor's {@code dType} and shape
     */
    public AbstractTensor copyShape() {
        return TensorCache.instance.get(this.dType, this.shape);
    }

    /** Returns the number of dimensions of this tensor. */
    public final int dims() {
        return this.shape.dims();
    }

    /** Returns the shape of this tensor. */
    public final TensorShape shape() {
        return this.shape;
    }

    /** Returns {@code shape().size()}. */
    public final long size() {
        return this.shape.size();
    }

    /**
     * Reads one element.
     *
     * @param dims the element's index in each dimension
     * @return the element value as a float
     */
    public abstract float get(int ... dims);

    /**
     * Writes one element.
     *
     * @param value the value to store
     * @param dims the element's index in each dimension
     */
    public abstract void set(float value, int ... dims);

    /**
     * Equivalent to {@code slice(false, dims)}.
     *
     * @param dims leading indices to fix
     * @return the sliced tensor
     */
    public AbstractTensor slice(int ... dims) {
        return this.slice(false, dims);
    }

    /**
     * Returns the sub-tensor obtained by fixing the leading {@code dims.length} indices.
     *
     * <p>When exactly one index is given and this tensor caches slices, the result is memoized
     * and returned on subsequent calls with the same index.
     *
     * @param cacheInnerSlice whether the returned tensor memoizes its own slices
     * @param dims leading indices to fix; must be fewer than {@link #dims()}
     * @return the sliced tensor
     * @throws IllegalArgumentException if too many indices are given
     */
    public AbstractTensor slice(boolean cacheInnerSlice, int ... dims) {
        Preconditions.checkArgument(dims.length < this.shape.dims(), "Too many dimensions specified for tensor");
        try {
            if (dims.length == 1 && this.sliceCache != null && this.sliceCache[dims[0]] != null) {
                return this.sliceCache[dims[0]];
            }
        } catch (Throwable t) {
            // Typically an out-of-range index into the slice cache; log the request before rethrowing.
            logger.warn("Dims = {}", Arrays.toString(dims), t);
            throw t;
        }
        TensorShape slicedShape = this.shape.slice(dims.length);
        int totalOffset = 0;
        if (dims.length == 1 && this.shape.dims() == 2) {
            // Fast path: one row of a 2-D tensor.
            totalOffset = this.shape.sparseColumnLength() * dims[0];
        } else {
            // Row-major offset, with sparseColumnLength() as the innermost extent.
            for (int d = 0; d < dims.length; ++d) {
                int offset = this.shape.sparseColumnLength();
                for (int i = this.shape.dims() - 2; i > d; --i) {
                    offset *= this.shape.dim(i);
                }
                totalOffset += dims[d] * offset;
            }
        }
        AbstractTensor r = this.make(totalOffset, (int) slicedShape.size(), slicedShape, cacheInnerSlice);
        if (dims.length == 1 && this.sliceCache != null) {
            this.sliceCache[dims[0]] = r;
        }
        return r;
    }

    /**
     * Returns a tensor holding only the columns {@code [offset, offset + length)} of the last
     * dimension.
     *
     * <p>Returns {@code this} unchanged if the tensor is already sparse or if {@code length}
     * equals the full last dimension. Otherwise a new tensor with the sparsified shape is
     * allocated and each row's column range is copied into it.
     *
     * @param offset first column to keep
     * @param length number of columns to keep
     * @return the sparsified tensor, or {@code this}
     */
    public AbstractTensor<V, T> sparsify(int offset, int length) {
        if (this.shape.isSparse()) {
            return this;
        }
        if (length == this.shape.last()) {
            return this;
        }
        AbstractTensor sparseT = this.make(this.shape.sparsifyColumns(offset, length));
        int originalLength = this.shape.last();
        int[] cursor = new int[this.shape.dims()];
        try {
            do {
                cursor[cursor.length - 1] = offset;
                sparseT.copyFrom(this, this.getOffset(cursor), sparseT.getOffset(cursor), length);
                // Park on the last column so iterate() wraps it and moves on to the next row.
                cursor[cursor.length - 1] = originalLength - 1;
            } while (this.iterate(cursor));
        } catch (Throwable t) {
            logger.warn("Cursor = {}", Arrays.toString(cursor), t);
            throw t;
        }
        return sparseT;
    }

    /**
     * Splits this tensor into {@code numChunks} equally sized chunks.
     *
     * <p>Chunk {@code i} starts at flat offset {@code i * chunkSize}, so the chunks are contiguous
     * ranges of the underlying storage. NOTE(review): for {@code dim == 0} this matches a split
     * along the first dimension; for other dimensions the chunks are contiguous storage ranges
     * given the reduced shape, not slices along {@code dim}. Confirm callers only rely on
     * {@code dim == 0}.
     *
     * @param numChunks number of chunks
     * @param dim dimension whose extent is divided by {@code numChunks}
     * @return the chunks, in order
     * @throws IllegalStateException if the dimension is not evenly divisible
     */
    public AbstractTensor[] split(int numChunks, int dim) {
        AbstractTensor[] chunks = new AbstractTensor[numChunks];
        int innerLength = this.shape.dim(dim) / numChunks;
        if (innerLength * numChunks != this.shape.dim(dim)) {
            throw new IllegalStateException("Chunks must be of equal size");
        }
        TensorShape newShape = this.shape.setDimValue(dim, innerLength);
        for (int i = 0; i < numChunks; ++i) {
            chunks[i] = this.make(Ints.checkedCast((long) i * newShape.size()), Ints.checkedCast(newShape.size()), newShape, true);
        }
        return chunks;
    }

    /**
     * Advances {@code cursor} to the next index in row-major order (last dimension fastest).
     *
     * @param cursor current index, modified in place; one entry per dimension
     * @return {@code true} if the cursor was advanced; {@code false} once it has wrapped past the
     *     final element (in which case it has been reset to all zeros)
     * @throws IllegalArgumentException if the cursor has the wrong length or an entry is out of range
     */
    public final boolean iterate(int[] cursor) {
        Preconditions.checkArgument(cursor.length == this.shape.dims());
        for (int i = cursor.length - 1; i >= 0; --i) {
            Preconditions.checkArgument(cursor[i] >= 0 && cursor[i] < this.shape.dim(i));
            if (cursor[i] + 1 < this.shape.dim(i)) {
                cursor[i]++;
                return true;
            }
            // This dimension overflowed: reset it and carry into the next-outer dimension.
            cursor[i] = 0;
            if (i == 0) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the stride computed at construction. This is the flat offset of the element at
     * (sparseRowOffset + 1, sparseColumnOffset) for 2-D tensors with more than one row, and 0
     * otherwise.
     */
    public final int getStride() {
        return this.stride;
    }

    /**
     * Returns the flat element offset for the given index, as computed by
     * {@link TensorShape#getOffset}.
     *
     * @param dims per-dimension index
     * @return flat element offset
     */
    public final int getOffset(int ... dims) {
        return this.shape.getOffset(dims);
    }

    /**
     * Returns a new tensor with all axes reversed, copied element by element.
     *
     * @return the transposed tensor
     * @throws IllegalArgumentException if this tensor is sparse
     */
    public final AbstractTensor transpose() {
        Preconditions.checkArgument(!this.shape.isSparse(), "Cannot transpose a sparse tensor");
        int[] tshape = new int[this.dims()];
        for (int i = 0; i < tshape.length; ++i) {
            tshape[i] = this.shape.dim(this.shape.dims() - i - 1);
        }
        AbstractTensor tt = this.make(TensorShape.of(tshape));
        int[] cursor = new int[this.dims()];
        int[] tcursor = new int[this.dims()];
        do {
            float v = this.get(cursor);
            for (int i = 0; i < tcursor.length; ++i) {
                tcursor[i] = cursor[cursor.length - i - 1];
            }
            tt.set(v, tcursor);
        } while (this.iterate(cursor));
        return tt;
    }

    /** Returns the element type of this tensor. */
    public final DType dType() {
        return this.dType;
    }

    /**
     * Loads a vector of elements starting at the given index.
     *
     * @param species vector species to load
     * @param offset starting index
     * @return the loaded vector
     */
    public abstract V getVector(VectorSpecies<T> species, int ... offset);

    /**
     * Stores a vector into this tensor starting at the given index.
     *
     * @param vector the vector to store
     * @param offset starting index
     */
    public abstract void intoTensor(V vector, int ... offset);

    /**
     * Masked variant of {@link #intoTensor(Vector, int...)}; unsupported unless overridden.
     *
     * @throws UnsupportedOperationException always, in this base implementation
     */
    public void intoTensor(V vector, VectorMask<T> mask, int ... offset) {
        throw new UnsupportedOperationException();
    }

    /** Returns the memory segment holding this tensor's data. */
    public abstract MemorySegment getMemorySegment();

    /**
     * Maps an element offset to an offset within {@link #getMemorySegment()}.
     * NOTE(review): exact units are defined by the subclass — confirm.
     */
    public abstract int getMemorySegmentOffset(int offset);

    /**
     * Copies {@code length} elements from {@code src} into this tensor.
     *
     * @param src source tensor
     * @param srcOffset flat offset in {@code src}
     * @param destOffset flat offset in this tensor
     * @param length number of elements to copy
     */
    public abstract void copyFrom(AbstractTensor src, int srcOffset, int destOffset, int length);

    /** Clears this tensor's contents. */
    public abstract void clear();

    /** Releases this tensor back to its owning {@link TensorCache}, if one was recorded. */
    @Override
    public void close() {
        if (this.originCache != null) {
            this.originCache.release(this);
        }
    }

    /** Records the cache that {@link #close()} should release this tensor to. */
    void setOwnerCache(TensorCache cache) {
        this.originCache = cache;
    }

    /**
     * Equivalent to {@code quantize(dType, false)}.
     *
     * @param dType target element type
     * @return the quantized tensor, or {@code this}
     */
    public AbstractTensor quantize(DType dType) {
        return this.quantize(dType, false);
    }

    /**
     * Converts this tensor to another element type.
     *
     * <p>Unless {@code force} is set, conversion only happens when the first dimension is not 1,
     * the target type differs from the current one, and the target element size is not larger
     * than the current one; otherwise {@code this} is returned. Sparse tensors are never
     * converted. Target types other than Q4, I8, F32 and BF16 return {@code this}.
     *
     * @param dType target element type
     * @param force convert even when the size/type heuristics above would skip it
     * @return the converted tensor, or {@code this}
     */
    public AbstractTensor quantize(DType dType, boolean force) {
        boolean shouldConvert = force
                || (this.shape().first() != 1 && this.dType != dType && this.dType.size() >= dType.size());
        if (!shouldConvert) {
            return this;
        }
        if (this.shape.isSparse()) {
            logger.info("Quantizing sparse tensor is not supported");
            return this;
        }
        return switch (dType) {
            case Q4 -> new Q4ByteBufferTensor(this);
            case I8 -> new Q8ByteBufferTensor(this);
            case F32 -> new FloatBufferTensor(this);
            case BF16 -> new BFloat16BufferTensor(this);
            default -> this;
        };
    }

    /**
     * Writes this tensor's raw memory segment to {@code out} at its current position.
     *
     * @param out destination channel
     * @return a {@link TensorInfo} with this tensor's type, shape and the {@code [start, end)}
     *     byte offsets written in the channel
     * @throws IOException if writing fails
     * @throws IllegalArgumentException if this tensor is sparse
     */
    public TensorInfo save(FileChannel out) throws IOException {
        Preconditions.checkArgument(!this.shape.isSparse(), "Cannot save a sparse tensor");
        ByteBuffer bb = this.getMemorySegment().asByteBuffer().order(ByteOrder.LITTLE_ENDIAN);
        long startOffset = out.position();
        // A single write() may transfer fewer bytes than remain (the OS can cap one native write),
        // so keep writing until the whole buffer has been consumed; otherwise the file would be
        // silently truncated while the returned offsets still looked valid.
        while (bb.hasRemaining()) {
            out.write(bb);
        }
        long[] lshape = new long[this.shape.dims()];
        for (int i = 0; i < this.shape.dims(); ++i) {
            lshape[i] = this.shape.dim(i);
        }
        return new TensorInfo(this.dType, lshape, new long[]{startOffset, out.position()});
    }

    /**
     * Prints {@code "<id> = <sum>"} to stdout, where the sum is over {@code get(0, i)} for
     * {@code i} in {@code [0, size())}.
     * NOTE(review): this presumes {@code get(0, i)} addresses flat element {@code i}; confirm for
     * tensors that are not 2-D with a single row.
     *
     * @param id label printed before the sum
     */
    public void debug(String id) {
        double tmp = 0.0;
        for (int i = 0; i < this.size(); ++i) {
            tmp += this.get(0, i);
        }
        System.out.println(String.format("%s = %.5f", id, tmp));
    }
}

