/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.tensor;

import com.github.tjake.jlama.math.VectorMath;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.FloatBufferTensor;
import com.github.tjake.jlama.tensor.Q4ByteBufferTensor;
import com.github.tjake.jlama.tensor.TensorShape;
import com.github.tjake.jlama.util.UnsafeDirectByteBuffer;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import java.lang.foreign.MemorySegment;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Arrays;
import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.VectorSpecies;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Tensor whose values are stored in a 5-bit block-quantized ("Q5") layout.
 *
 * <p>Values are grouped into blocks of {@link #BLOCK_SIZE} consecutive elements (in the order
 * produced by {@code AbstractTensor.iterate}). Each block is stored as:
 *
 * <ul>
 *   <li>one float scale in {@link #blockF}, addressed by the block coordinates returned from
 *       {@code Q4ByteBufferTensor.makeBlockShape};
 *   <li>the low 4 bits of each element in {@link #b}, two elements per byte: the element at an
 *       even offset in the low nibble, the following odd-offset element in the high nibble;
 *   <li>the 5th bit of every element of the block in one {@code int} of {@link #b5}: bit
 *       {@code j} holds the 5th bit of block position {@code 2j} and bit {@code j + 16} that of
 *       block position {@code 2j + 1}, for {@code 0 <= j < 16}.
 * </ul>
 *
 * <p>A stored 5-bit code {@code x} decodes to {@code (x - 16) * scale}. Individual values cannot
 * be {@link #set}; contents are produced by quantizing another tensor with
 * {@link #Q5ByteBufferTensor(AbstractTensor)}.
 */
public class Q5ByteBufferTensor extends AbstractTensor<ByteVector, Byte> {
    private static final Logger logger = LoggerFactory.getLogger(Q5ByteBufferTensor.class);

    /** Number of consecutive elements sharing one scale factor and one high-bit word. */
    public static final int BLOCK_SIZE = 32;

    /** Packed low nibbles, two elements per byte. */
    final ByteBuffer b;
    /** One scale factor per block, indexed by block coordinates. */
    final FloatBufferTensor blockF;
    /** One word of packed 5th bits per block, in the same linear order as {@link #blockF}. */
    final int[] b5;

    private final String name;
    private final MemorySegment segment;

    /**
     * Quantizes {@code ft} into a new Q5 tensor with the same shape.
     *
     * @param ft source tensor; must not already be Q5 and its size must be a multiple of
     *     {@link #BLOCK_SIZE}
     * @throws IllegalArgumentException if {@code ft} is already Q5 or its size is not a multiple
     *     of {@link #BLOCK_SIZE}
     */
    public Q5ByteBufferTensor(AbstractTensor ft) {
        this(ft.shape);
        Preconditions.checkArgument(ft.dType != DType.Q5, "This should never happen, likely a bug");
        Preconditions.checkArgument(ft.size() % BLOCK_SIZE == 0, "Tensor must be a multiple of BLOCK_SIZE");

        // Record the cursor of the first element of every block so the blocks can be
        // quantized independently (and in parallel) afterwards.
        ArrayList<int[]> blockStarts = new ArrayList<>();
        int[] cursor = new int[ft.shape.dims()];
        long index = 0;
        do {
            if (index++ % BLOCK_SIZE == 0) {
                blockStarts.add(Arrays.copyOf(cursor, cursor.length));
            }
        } while (ft.iterate(cursor));

        VectorMath.pfor(0, blockStarts.size(), i -> this.processBlock(ft, blockStarts.get(i)));
    }

    /**
     * Quantizes the {@link #BLOCK_SIZE} elements of {@code ft} starting at {@code blockStartCursor}
     * into this tensor's scale ({@link #blockF}), nibble ({@link #b}) and high-bit ({@link #b5})
     * storage.
     *
     * @param ft source tensor; assumed to have the same shape as this tensor
     * @param blockStartCursor coordinates of the block's first element; copied, not modified here
     */
    void processBlock(AbstractTensor ft, int[] blockStartCursor) {
        // Pass 1: find the value with the largest magnitude, keeping its sign.
        int[] cursor = Arrays.copyOf(blockStartCursor, blockStartCursor.length);
        float max = 0.0f;
        float amax = 0.0f;
        for (int k = 0; k < BLOCK_SIZE; ++k) {
            float v = ft.get(cursor);
            float absv = Math.abs(v);
            if (absv > amax) {
                max = v;
                amax = absv;
            }
            ft.iterate(cursor);
        }

        // The largest-magnitude value maps to code 0, i.e. (0 - 16) * scale == max.
        float scale = max / -16.0f;
        float iscale = scale != 0.0f ? 1.0f / scale : 0.0f;
        int[] blockCoords = Q4ByteBufferTensor.makeBlockShape(blockStartCursor);
        this.blockF.set(scale, blockCoords);

        // Pass 2: quantize pairs of consecutive elements. The two low nibbles share one byte of b;
        // the 5th bits go to bit j (even element) and bit j + 16 (odd element) of one word.
        int offset = ft.getOffset(blockStartCursor);
        int highBits = 0;
        cursor = Arrays.copyOf(blockStartCursor, blockStartCursor.length);
        for (int j = 0; j < BLOCK_SIZE / 2; ++j, offset += 2) {
            int q0 = quantize(ft.get(cursor), iscale);
            ft.iterate(cursor);
            int q1 = quantize(ft.get(cursor), iscale);
            ft.iterate(cursor);
            this.b.put(offset / 2, (byte) ((q0 & 0xF) | ((q1 & 0xF) << 4)));
            highBits |= ((q0 & 0x10) >>> 4) << j;
            highBits |= ((q1 & 0x10) >>> 4) << (j + 16);
        }
        // b5 has one entry per block, so index it with the block tensor's strides
        // (the full tensor's strides would run past the end of b5 for multi-row tensors).
        this.b5[this.blockF.getOffset(blockCoords)] = highBits;
    }

    /**
     * Rounds {@code value * iscale} to a 5-bit code: adds the +16 bias and 0.5 for rounding,
     * truncates through a byte, and clamps the result to at most 31.
     */
    private static int quantize(float value, float iscale) {
        return (byte) Math.min(31, (byte) (value * iscale + 16.5f));
    }

    /**
     * Allocates an empty Q5 tensor (zeroed storage) of the given shape.
     *
     * @param shape tensor shape; its size must be a multiple of {@link #BLOCK_SIZE}
     * @throws IllegalArgumentException if the size is not a multiple of {@link #BLOCK_SIZE}
     */
    protected Q5ByteBufferTensor(TensorShape shape) {
        super(DType.Q5, shape, true);
        Preconditions.checkArgument(this.size() % BLOCK_SIZE == 0, "Tensor must be a multiple of BLOCK_SIZE");
        TensorShape blockShape = Q4ByteBufferTensor.makeBlockShape(shape);
        this.blockF = new FloatBufferTensor(blockShape);
        this.b5 = new int[Ints.checkedCast(blockShape.size())];
        this.name = "tmp";
        // Two elements per byte, 64-byte aligned.
        this.b = UnsafeDirectByteBuffer.allocateAlignedByteBuffer(Ints.checkedCast(this.size() / 2L), 64L)
                .order(ByteOrder.LITTLE_ENDIAN);
        this.segment = MemorySegment.ofBuffer(this.b);
    }

    /**
     * Wraps existing Q5 storage.
     *
     * @param name tensor name, used by {@link #toString()}
     * @param b packed low-nibble buffer; must be direct
     * @param blockF per-block scale factors
     * @param b5 per-block packed 5th bits
     * @param shape tensor shape
     * @param cacheSlices forwarded to {@code AbstractTensor}
     * @throws IllegalArgumentException if {@code b} is not a direct buffer
     */
    public Q5ByteBufferTensor(String name, ByteBuffer b, FloatBufferTensor blockF, int[] b5, TensorShape shape, boolean cacheSlices) {
        super(DType.Q5, shape, cacheSlices);
        Preconditions.checkArgument(b.isDirect(), "Must use direct buffers");
        this.name = name;
        this.b = b;
        this.blockF = blockF;
        this.b5 = b5;
        this.segment = MemorySegment.ofBuffer(b);
    }

    @Override
    protected AbstractTensor make(TensorShape shape) {
        return new Q5ByteBufferTensor(shape);
    }

    /** Returns the per-block scale factors. */
    public FloatBufferTensor getBlockF() {
        return this.blockF;
    }

    /**
     * Creates a view over {@code length} elements starting at element {@code offset}.
     *
     * <p>{@code offset} and {@code length} are element counts and are assumed to be block
     * aligned. The nibble buffer is sliced at half those values (two elements per byte), and the
     * view receives its own copy of the matching {@link #b5} range, because the view indexes
     * {@code b5} relative to its own first block.
     */
    @Override
    protected AbstractTensor make(int offset, int length, TensorShape shape, boolean cacheSlices) {
        int firstBlock = offset / BLOCK_SIZE;
        int blockCount = length / BLOCK_SIZE;
        FloatBufferTensor newBlockF = (FloatBufferTensor) this.blockF.make(
                firstBlock, blockCount, Q4ByteBufferTensor.makeBlockShape(shape), cacheSlices);
        int[] newB5 = Arrays.copyOfRange(this.b5, firstBlock, firstBlock + blockCount);
        return new Q5ByteBufferTensor(this.name, this.b.slice(offset / 2, length / 2), newBlockF, newB5, shape, cacheSlices);
    }

    /**
     * Dequantizes and returns the element at {@code dims}.
     *
     * @param dims full coordinates of the element
     * @return {@code (code - 16) * scale} for the element's 5-bit code and its block's scale
     * @throws IllegalArgumentException if the number of coordinates differs from the tensor rank
     */
    @Override
    public float get(int... dims) {
        Preconditions.checkArgument(dims.length <= this.shape.dims(), "Too many dimensions specified");
        Preconditions.checkArgument(dims.length == this.shape.dims(), "Must specify all dimensions");
        int offset = this.getOffset(dims);
        int[] blockCoords = Q4ByteBufferTensor.makeBlockShape(dims);
        float scale = this.blockF.get(blockCoords);
        int highBits = this.b5[this.blockF.getOffset(blockCoords)];
        byte packed = this.b.get(offset / 2);
        int j = (offset % BLOCK_SIZE) / 2;
        int code;
        if (offset % 2 == 0) {
            code = (packed & 0xF) | (((highBits >>> j) & 1) << 4);
        } else {
            code = ((packed >> 4) & 0xF) | (((highBits >>> (j + 16)) & 1) << 4);
        }
        return (code - 16) * scale;
    }

    /**
     * Returns the scale factor of the block containing linear element index {@code i}.
     *
     * @throws IndexOutOfBoundsException if the block index is not below {@code blockF.size()}
     */
    public final float getFactorForIndex(int i) {
        int ix = i / BLOCK_SIZE;
        if ((long) ix >= this.blockF.size()) {
            throw new IndexOutOfBoundsException("Block index " + ix + " out of range for " + this.blockF.size() + " blocks");
        }
        return this.blockF.get(ix);
    }

    /** Not supported: Q5 values can only be produced by quantizing another tensor. */
    @Override
    public void set(float v, int... dims) {
        throw new UnsupportedOperationException();
    }

    @Override
    public MemorySegment getMemorySegment() {
        return this.segment;
    }

    @Override
    public int getMemorySegmentOffset(int offset) {
        // NOTE(review): b packs two elements per byte (get/processBlock address it at offset / 2),
        // so scaling an element offset by dType.size() may not give the matching byte offset.
        // Confirm against DType.Q5 and the callers before relying on this.
        return offset * this.dType.size();
    }

    /**
     * Copies {@code length} bytes of packed nibble data from {@code src} into this tensor.
     *
     * <p>NOTE(review): only the nibble bytes are copied. The source's scale factors
     * ({@code blockF}) and high bits ({@code b5}) are not, and {@code length} is used directly
     * as a byte count. Confirm this is intended.
     *
     * @throws IllegalArgumentException if the dtypes differ or this tensor is read-only
     */
    @Override
    public void copyFrom(AbstractTensor src, int srcOffset, int destOffset, int length) {
        Preconditions.checkArgument(this.dType == src.dType, "different types");
        Preconditions.checkArgument(!this.b.isReadOnly(), "Read-only");
        MemorySegment source = src.getMemorySegment().asSlice((long) src.getMemorySegmentOffset(srcOffset), length);
        this.segment.asSlice((long) this.getMemorySegmentOffset(destOffset), length).copyFrom(source);
    }

    /** Loads raw packed bytes (not dequantized values) starting at the given coordinates. */
    @Override
    public ByteVector getVector(VectorSpecies<Byte> species, int... voffset) {
        int offset = this.getOffset(voffset);
        return ByteVector.fromMemorySegment(species, this.segment, (long) this.getMemorySegmentOffset(offset), ByteOrder.LITTLE_ENDIAN);
    }

    /** Stores raw packed bytes starting at the given coordinates. */
    @Override
    public void intoTensor(ByteVector vector, int... aoffset) {
        Preconditions.checkArgument(!this.b.isReadOnly());
        int offset = this.getOffset(aoffset);
        vector.intoMemorySegment(this.segment, (long) this.getMemorySegmentOffset(offset), ByteOrder.LITTLE_ENDIAN);
    }

    /** Zeroes the packed nibble bytes. Scale factors and high bits are left untouched. */
    @Override
    public void clear() {
        Preconditions.checkArgument(!this.b.isReadOnly(), "Can't clear a read-only buffer");
        this.segment.fill((byte) 0);
    }

    @Override
    public String toString() {
        byte[] sample = new byte[Math.min(10, this.b.remaining())];
        this.b.duplicate().get(sample);
        return "Q5BufferTensor{name='" + this.name + "', shape=" + this.shape + ", b=" + Arrays.toString(sample) + "...}";
    }
}

