package com.github.tjake.jlama.model;

import com.github.tjake.jlama.tensor.AbstractTensor;
import com.google.common.base.Preconditions;
import net.jafama.FastMath;

/* loaded from: input_file:com/github/tjake/jlama/model/LayerNorm.class */
/**
 * Layer normalization with a learned per-element scale ({@code weights}) and shift ({@code bias}).
 *
 * <p>For every row {@code b} of a 2-D input, and every column {@code i} in the normalized range, the
 * output is {@code (x[b,i] - mean) / sqrt(variance + layerNormEps) * weights[0,i] + bias[0,i]}.
 * The mean and variance are taken over the same range of row {@code b}.
 */
public class LayerNorm {
    protected final AbstractModel m;
    private final AbstractTensor bias;
    protected final AbstractTensor weights;

    /**
     * Creates a layer-norm module.
     *
     * @param m       the owning model; its config supplies {@code embeddingLength} and {@code layerNormEps}
     * @param bias    additive shift, read as {@code bias.get(0, i)} (note: passed BEFORE {@code weights})
     * @param weights multiplicative scale, read as {@code weights.get(0, i)}
     */
    public LayerNorm(AbstractModel m, AbstractTensor bias, AbstractTensor weights) {
        this.m = m;
        this.bias = bias;
        this.weights = weights;
    }

    /**
     * Normalizes every row of {@code input} over the full embedding dimension.
     *
     * @param input a 2-D tensor whose last dimension equals the model's {@code embeddingLength}
     * @return a tensor obtained via {@code input.copyShape()} holding the normalized values
     * @throws IllegalArgumentException if {@code input} is not 2-D or its last dimension is not
     *     {@code embeddingLength}
     */
    public AbstractTensor forward(AbstractTensor input) {
        Preconditions.checkArgument(input.shape().dims() == 2,
                "expected a 2-D tensor but got %s dims", input.shape().dims());
        Preconditions.checkArgument(input.shape().last() == this.m.c.embeddingLength,
                "expected last dimension %s but got %s", this.m.c.embeddingLength, input.shape().last());
        return forward(input, 0, this.m.c.embeddingLength);
    }

    /**
     * Normalizes columns {@code [offset, offset + length)} of every row of {@code input}.
     *
     * <p>Only the columns in the range are written to the returned tensor; other columns are left as
     * produced by {@code input.copyShape()}. The range is not validated here.
     *
     * <p>NOTE(review): the sum and sum of squares are gathered over {@code length} elements but divided
     * by the full {@code embeddingLength}, not by {@code length}. This gives a true layer norm only when
     * the range spans the whole embedding. It presumably anticipates segmented/partial statistics —
     * confirm with callers before relying on sub-range results.
     *
     * @param input  a tensor indexed as {@code input.get(row, column)}
     * @param offset first column to normalize
     * @param length number of columns to normalize
     * @return a tensor obtained via {@code input.copyShape()} with the normalized range filled in
     */
    public AbstractTensor forward(AbstractTensor input, int offset, int length) {
        int rows = input.shape().first();
        int end = offset + length;
        float n = this.m.c.embeddingLength;
        AbstractTensor output = input.copyShape();
        for (int b = 0; b < rows; b++) {
            float sum = 0.0f;
            float sumSq = 0.0f;
            for (int i = offset; i < end; i++) {
                float v = input.get(b, i);
                sum += v;
                sumSq += v * v;
            }
            float mean = sum / n;
            // E[x^2] - E[x]^2 can come out slightly negative in float arithmetic (cancellation when
            // |mean| is large relative to the spread). Clamp at 0 so sqrt never sees a negative
            // argument, which would turn the entire row into NaN.
            float variance = Math.max(0.0f, (sumSq / n) - (mean * mean));
            float invStd = 1.0f / ((float) FastMath.sqrt(variance + this.m.c.layerNormEps));
            for (int i = offset; i < end; i++) {
                float normalized = (input.get(b, i) - mean) * invStd;
                output.set((normalized * this.weights.get(0, i)) + this.bias.get(0, i), b, i);
            }
        }
        return output;
    }
}
