/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.tensor.operations;

import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.operations.TensorOperations;
import com.google.common.base.Preconditions;

/**
 * Reference implementation of {@link TensorOperations} built from plain scalar loops over
 * {@code AbstractTensor.get} / {@code AbstractTensor.set}.
 *
 * <p>Throughout this class {@code shape().first()} is used as the row (batch) count, and row
 * slices obtained via {@code slice(row)} are addressed with a leading index of {@code 0}. The
 * in-place methods ({@link #accumulate}, {@link #maccumulate}) write through such row slices,
 * which assumes a slice shares storage with its parent tensor — TODO confirm against
 * {@code AbstractTensor.slice}.
 */
public class NaiveTensorOperations implements TensorOperations {
    @Override
    public String name() {
        return "Naive Java Operations";
    }

    /**
     * Adds {@code b} into {@code a} element-wise over columns {@code [offset, offset + length)}
     * of every row of {@code a}.
     *
     * <p>If {@code b} has more than one row, row {@code r} of {@code b} is added to row {@code r}
     * of {@code a}; otherwise {@code b} itself is added to every row of {@code a}.
     *
     * @param a destination tensor, updated in place
     * @param b tensor whose values are added into {@code a}
     * @param offset first column to update
     * @param length number of columns to update
     * @throws IllegalArgumentException if {@code a} and {@code b} differ in number of dimensions
     */
    @Override
    public void accumulate(AbstractTensor a, AbstractTensor b, int offset, int length) {
        Preconditions.checkArgument(
                a.dims() == b.dims(), "accumulate: dims mismatch (a=%s, b=%s)", a.dims(), b.dims());
        boolean isBatch = b.shape().first() > 1;
        int end = offset + length;
        for (int row = 0; row < a.shape().first(); ++row) {
            AbstractTensor aRow = a.slice(row);
            AbstractTensor bRow = isBatch ? b.slice(row) : b;
            for (int i = offset; i < end; ++i) {
                aRow.set(aRow.get(0, i) + bRow.get(0, i), 0, i);
            }
        }
    }

    /**
     * Multiplies {@code a} by {@code b} element-wise, in place, over columns
     * {@code [offset, offset + length)} of every row of {@code a}.
     *
     * <p>If {@code b} has more than one row, row {@code r} of {@code b} multiplies row {@code r}
     * of {@code a}; otherwise {@code b} itself multiplies every row of {@code a}.
     *
     * <p>NOTE(review): the {@code size()} check below rejects a single-row {@code b} against a
     * multi-row {@code a} of the same width, so the non-batch path appears reachable only when
     * {@code a} also has one row. Confirm whether broadcasting was intended here, as it is in
     * {@link #accumulate}.
     *
     * @param a destination tensor, updated in place
     * @param b tensor whose values multiply {@code a}
     * @param offset first column to update
     * @param length number of columns to update
     * @throws IllegalArgumentException if {@code a} and {@code b} differ in size or in number of
     *     dimensions
     */
    @Override
    public void maccumulate(AbstractTensor a, AbstractTensor b, int offset, int length) {
        Preconditions.checkArgument(
                a.size() == b.size(), "maccumulate: size mismatch (a=%s, b=%s)", a.size(), b.size());
        Preconditions.checkArgument(
                a.dims() == b.dims(), "maccumulate: dims mismatch (a=%s, b=%s)", a.dims(), b.dims());
        boolean isBatch = b.shape().first() > 1;
        int end = offset + length;
        for (int row = 0; row < a.shape().first(); ++row) {
            AbstractTensor aRow = a.slice(row);
            AbstractTensor bRow = isBatch ? b.slice(row) : b;
            for (int i = offset; i < end; ++i) {
                aRow.set(aRow.get(0, i) * bRow.get(0, i), 0, i);
            }
        }
    }

    /**
     * Computes the dot product of {@code limit} consecutive elements of row 0 of {@code a}
     * (starting at column {@code aoffset}) with those of row 0 of {@code b} (starting at column
     * {@code boffset}).
     *
     * @param a left operand; must have exactly one row
     * @param b right operand; only row 0 is read
     * @param aoffset first column read from {@code a}
     * @param boffset first column read from {@code b}
     * @param limit number of element pairs to multiply and sum
     * @return the accumulated sum, or {@code 0.0f} when {@code limit <= 0}
     * @throws IllegalArgumentException if the tensors differ in number of dimensions or {@code a}
     *     does not have exactly one row
     */
    @Override
    public float dotProduct(AbstractTensor a, AbstractTensor b, int aoffset, int boffset, int limit) {
        Preconditions.checkArgument(
                a.dims() == b.dims(), "dotProduct: dims mismatch (a=%s, b=%s)", a.dims(), b.dims());
        Preconditions.checkArgument(
                a.shape().first() == 1, "dotProduct: a must have exactly one row, got %s", a.shape().first());
        int aEnd = aoffset + limit;
        int bEnd = boffset + limit;
        float sum = 0.0f;
        for (int ai = aoffset, bi = boffset; ai < aEnd && bi < bEnd; ++ai, ++bi) {
            sum += a.get(0, ai) * b.get(0, bi);
        }
        return sum;
    }

    /**
     * For every row {@code i} of {@code a} and every row {@code j} of {@code b} in
     * {@code [bRowOffset, bRowOffset + rowChunkSize)}, stores the dot product of the selected
     * column ranges into {@code result[i, rRowOffset + (j - bRowOffset)]}.
     *
     * @param result 2-D output tensor, written in place
     * @param a 2-D left operand; every row is used
     * @param b 2-D right operand; only the chunk of rows starting at {@code bRowOffset} is used
     * @param aColumnOffset first column read from each row of {@code a}
     * @param bColumnOffset first column read from each row of {@code b}
     * @param columnLength number of columns in each dot product
     * @param rRowOffset column of {@code result} that receives the first {@code b} row's product
     * @param bRowOffset first row of {@code b} to use
     * @param rowChunkSize number of rows of {@code b} to use
     * @throws IllegalArgumentException if any of the three tensors is not 2-D
     */
    @Override
    public void batchDotProduct(AbstractTensor result, AbstractTensor a, AbstractTensor b, int aColumnOffset, int bColumnOffset, int columnLength, int rRowOffset, int bRowOffset, int rowChunkSize) {
        Preconditions.checkArgument(
                a.dims() == 2 && b.dims() == 2 && result.dims() == 2,
                "batchDotProduct: expected 2D tensors (result=%s, a=%s, b=%s)",
                result.dims(),
                a.dims(),
                b.dims());
        int bRowLimit = bRowOffset + rowChunkSize;
        for (int i = 0; i < a.shape().first(); ++i) {
            // The row of a is the same for the whole inner loop; slice it once.
            AbstractTensor aRow = a.slice(i);
            for (int j = bRowOffset, r = rRowOffset; j < bRowLimit; ++j, ++r) {
                float d = this.dotProduct(aRow, b.slice(j), aColumnOffset, bColumnOffset, columnLength);
                result.set(d, i, r);
            }
        }
    }

    /**
     * Computes {@code y[yoffset + k] = alpha * x[xoffset + k] + y[yoffset + k]} for
     * {@code k} in {@code [0, limit)}, operating on row 0 of each tensor.
     *
     * @param alpha scalar multiplier applied to {@code x}
     * @param x input tensor; must have exactly one row
     * @param y input/output tensor, updated in place; must have exactly one row
     * @param xoffset first column read from {@code x}
     * @param yoffset first column read from and written to {@code y}
     * @param limit number of elements to process
     * @throws IllegalArgumentException if either tensor does not have exactly one row
     */
    @Override
    public void saxpy(float alpha, AbstractTensor x, AbstractTensor y, int xoffset, int yoffset, int limit) {
        Preconditions.checkArgument(
                x.shape().first() == 1 && y.shape().first() == 1,
                "saxpy: expected single-row tensors (x rows=%s, y rows=%s)",
                x.shape().first(),
                y.shape().first());
        int xEnd = xoffset + limit;
        int yEnd = yoffset + limit;
        for (int xi = xoffset, yi = yoffset; xi < xEnd && yi < yEnd; ++xi, ++yi) {
            float v = alpha * x.get(0, xi) + y.get(0, yi);
            y.set(v, 0, yi);
        }
    }

    /**
     * Multiplies, in place, columns {@code [offset, offset + length)} of every row of {@code x}
     * by {@code factor}.
     *
     * @param factor scalar multiplier
     * @param x tensor to scale, updated in place; indexed directly as {@code (row, column)}
     * @param offset first column to scale
     * @param length number of columns to scale
     */
    @Override
    public void scale(float factor, AbstractTensor x, int offset, int length) {
        int end = offset + length;
        for (int row = 0; row < x.shape().first(); ++row) {
            for (int i = offset; i < end; ++i) {
                x.set(x.get(row, i) * factor, row, i);
            }
        }
    }

    /**
     * Returns {@code t.quantize(qtype, true)}.
     *
     * <p>{@code offset} and {@code length} are ignored by this implementation; the call is
     * delegated to the tensor's own {@code quantize} for the whole tensor.
     *
     * @param t tensor to quantize
     * @param qtype target data type
     * @param offset ignored
     * @param length ignored
     * @return the tensor produced by {@code t.quantize(qtype, true)}
     */
    @Override
    public AbstractTensor quantize(AbstractTensor t, DType qtype, int offset, int length) {
        return t.quantize(qtype, true);
    }
}

