/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.tensor;

import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.TensorShape;
import com.github.tjake.jlama.util.DebugSupport;
import com.github.tjake.jlama.util.UnsafeDirectByteBuffer;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import java.lang.foreign.MemorySegment;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.Locale;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.VectorSpecies;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public final class FloatBufferTensor
extends AbstractTensor<FloatVector, Float> {
    private static final Logger logger = LoggerFactory.getLogger(FloatBufferTensor.class);
    private final FloatBuffer b;
    private final String name;
    private final MemorySegment segment;

    public FloatBufferTensor(AbstractTensor ft) {
        this(ft.shape);
        Preconditions.checkArgument((ft.dType != DType.I32 ? 1 : 0) != 0, (Object)"This should never happen, likely a bug");
        int[] cursor = new int[ft.shape.dims()];
        do {
            this.set(ft.get(cursor), cursor);
        } while (ft.iterate(cursor));
    }

    public FloatBufferTensor(int ... shape) {
        this(TensorShape.of(shape));
    }

    public FloatBufferTensor(TensorShape shape) {
        super(DType.F32, shape, true);
        this.name = "tmp";
        this.b = UnsafeDirectByteBuffer.allocateAlignedByteBuffer(Ints.checkedCast((long)(shape.size() * (long)this.dType().size())), 64L).asFloatBuffer();
        this.segment = MemorySegment.ofBuffer(this.b);
    }

    public FloatBufferTensor(FloatBuffer b, TensorShape shape, boolean cacheSlices) {
        this("none", b, shape, cacheSlices);
    }

    public FloatBufferTensor(String name, FloatBuffer b, TensorShape shape, boolean cacheSlices) {
        super(DType.F32, shape, cacheSlices);
        this.name = name;
        if (b.isDirect()) {
            this.b = b;
        } else {
            this.b = UnsafeDirectByteBuffer.allocateAlignedByteBuffer(Ints.checkedCast((long)(this.size() * (long)this.dType().size())), 64L).asFloatBuffer();
            this.b.duplicate().put(b);
        }
        this.segment = MemorySegment.ofBuffer(this.b);
    }

    @Override
    protected AbstractTensor make(TensorShape shape) {
        return new FloatBufferTensor(shape);
    }

    @Override
    protected AbstractTensor make(int offset, int length, TensorShape shape, boolean cacheSlices) {
        return new FloatBufferTensor(this.name, this.b.slice(offset, length), shape, cacheSlices);
    }

    @Override
    public float get(int ... dims) {
        Preconditions.checkArgument((dims.length <= this.shape.dims() ? 1 : 0) != 0, (Object)"Too many dimensions specified");
        Preconditions.checkArgument((dims.length == this.shape.dims() ? 1 : 0) != 0, (Object)"Must specify all dimensions");
        return this.b.get(this.getOffset(dims));
    }

    @Override
    public void set(float v, int ... dims) {
        Preconditions.checkArgument((dims.length <= this.shape.dims() ? 1 : 0) != 0, (Object)"Too many dimensions specified for tensor");
        Preconditions.checkArgument((dims.length == this.shape.dims() ? 1 : 0) != 0, (Object)"Must specify all dimensions");
        Preconditions.checkArgument((!this.b.isReadOnly() ? 1 : 0) != 0, (Object)"Can't modify a read only buffer");
        this.b.put(this.getOffset(dims), v);
    }

    @Override
    public MemorySegment getMemorySegment() {
        return this.segment;
    }

    @Override
    public void copyFrom(AbstractTensor src, int srcOffset, int destOffset, int length) {
        Preconditions.checkArgument((this.dType == src.dType ? 1 : 0) != 0, (Object)"Different types");
        this.segment.asSlice((long)this.getMemorySegmentOffset(destOffset), length * this.dType.size()).copyFrom(src.getMemorySegment().asSlice((long)src.getMemorySegmentOffset(srcOffset), length * this.dType.size()));
    }

    @Override
    public int getMemorySegmentOffset(int offset) {
        return offset * 4;
    }

    @Override
    public FloatVector getVector(VectorSpecies<Float> species, int ... voffset) {
        int offset = this.getOffset(voffset);
        return FloatVector.fromMemorySegment(species, (MemorySegment)this.segment, (long)this.getMemorySegmentOffset(offset), (ByteOrder)ByteOrder.LITTLE_ENDIAN);
    }

    @Override
    public void intoTensor(FloatVector vector, int ... aoffset) {
        int offset = this.getOffset(aoffset);
        vector.intoMemorySegment(this.segment, (long)this.getMemorySegmentOffset(offset), ByteOrder.LITTLE_ENDIAN);
    }

    @Override
    public void clear() {
        this.segment.fill((byte)0);
    }

    public String toString() {
        float[] sample = new float[DebugSupport.isDebug() ? this.b.remaining() : Math.min(10, this.b.remaining())];
        this.b.duplicate().get(sample);
        StringBuffer sb = new StringBuffer();
        for (int i = 0; i < sample.length; ++i) {
            sb.append(String.format("%8.4f", Float.valueOf(sample[i])));
            if (i >= sample.length - 1) continue;
            sb.append(", ");
        }
        return "FloatBufferTensor{name='" + this.name + "' shape=" + String.valueOf(this.shape) + ",\nb={" + String.valueOf(sb) + "...}";
    }
}

