package com.github.tjake.jlama.tensor.operations;

import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.BFloat16BufferTensor;
import com.github.tjake.jlama.tensor.FloatBufferTensor;
import com.github.tjake.jlama.tensor.Q4ByteBufferTensor;
import com.github.tjake.jlama.tensor.Q8ByteBufferTensor;
import com.github.tjake.jlama.tensor.TensorCache;
import com.github.tjake.jlama.util.BiIntConsumer;
import com.github.tjake.jlama.util.MachineSpec;
import com.github.tjake.jlama.util.PhysicalCoreExecutor;
import com.google.common.base.Preconditions;
import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.IntVector;
import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.Vector;
import jdk.incubator.vector.VectorMask;
import jdk.incubator.vector.VectorOperators;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.class */
public final class PanamaTensorOperations implements TensorOperations {
    private static final Logger logger = LoggerFactory.getLogger(PanamaTensorOperations.class);

    // ---- Q4 unpacking constants --------------------------------------------------------------
    // A packed Q4 byte carries two 4-bit values. Kernels isolate a nibble (mask with 0x0F, or
    // shift right by 4) and then subtract 8, moving the unsigned nibble 0..15 into -8..7.
    static final ByteVector Q4_BYTE_SUB_128 = ByteVector.broadcast(ByteVector.SPECIES_128, 8);
    static final ByteVector Q4_BYTE_MASK_128 = ByteVector.broadcast(ByteVector.SPECIES_128, 0x0F);
    static final ByteVector Q4_BYTE_SHIFT_128 = ByteVector.broadcast(ByteVector.SPECIES_128, 4);
    static final ByteVector Q4_BYTE_SUB_64 = ByteVector.broadcast(ByteVector.SPECIES_64, 8);
    static final ByteVector Q4_BYTE_MASK_64 = ByteVector.broadcast(ByteVector.SPECIES_64, 0x0F);
    static final ByteVector Q4_BYTE_SHIFT_64 = ByteVector.broadcast(ByteVector.SPECIES_64, 4);

    // ---- BF16 -> F32 widening shift ----------------------------------------------------------
    // Shifting a 16-bit bf16 pattern left by 16 inside an int lane yields the matching f32 bits.
    static final IntVector BF16_BYTE_SHIFT = IntVector.broadcast(IntVector.SPECIES_PREFERRED, 16);
    static final IntVector BF16_BYTE_SHIFT_512 = IntVector.broadcast(IntVector.SPECIES_512, 16);
    static final IntVector BF16_BYTE_SHIFT_256 = IntVector.broadcast(IntVector.SPECIES_256, 16);
    static final IntVector BF16_BYTE_SHIFT_128 = IntVector.broadcast(IntVector.SPECIES_128, 16);

    // ---- 0.5f per lane -----------------------------------------------------------------------
    // NOTE(review): used outside this view; presumably a rounding bias for float->int paths.
    static final FloatVector F32_ROUND_UP_512 = FloatVector.broadcast(FloatVector.SPECIES_512, 0.5f);
    static final FloatVector F32_ROUND_UP_256 = FloatVector.broadcast(FloatVector.SPECIES_256, 0.5f);
    static final FloatVector F32_ROUND_UP_128 = FloatVector.broadcast(FloatVector.SPECIES_128, 0.5f);

    // Selects the low four byte lanes (32 bits) of an 8-lane, 64-bit byte vector.
    static final VectorMask<Byte> BYTE_MASK_32 = VectorMask.fromLong(ByteVector.SPECIES_64, 0b1111L);

    private final MachineSpec.Type vectorType;

    /* loaded from: input_file:com/github/tjake/jlama/tensor/operations/PanamaTensorOperations$Gemmer.class */
    /**
     * Recursive tiling driver for a blocked matrix multiply.
     *
     * <p>{@link #matmul} covers the output rectangle rows {@code [m0, m)} x columns {@code [n0, n)}.
     * Subclasses pick a register-tile shape in {@link #pickKernel}, run it over the largest
     * evenly-divisible sub-rectangle, and this class recurses on whatever is left over.
     */
    private abstract class Gemmer {
        /** Number of inner-dimension elements each kernel walks. */
        final int k;
        final AbstractTensor a;
        final AbstractTensor b;
        final AbstractTensor c;
        /** Inner-dimension start offset into {@code a}. */
        final int aColumnOffset;
        /** Inner-dimension start offset into {@code b}. */
        final int bColumnOffset;
        /** Offset the subclass kernels add to the column index when writing into {@code c}. */
        final int rOffset;

        /**
         * @param panamaTensorOperations unused here; kept so existing call sites still compile
         *     (looks like the decompiled outer-instance argument — NOTE(review): confirm)
         */
        Gemmer(PanamaTensorOperations panamaTensorOperations, int k, AbstractTensor a, AbstractTensor b,
                AbstractTensor c, int aColumnOffset, int bColumnOffset, int rOffset) {
            this.k = k;
            this.a = a;
            this.b = b;
            this.c = c;
            this.aColumnOffset = aColumnOffset;
            this.bColumnOffset = bColumnOffset;
            this.rOffset = rOffset;
        }

        /** Computes every output element in rows {@code [m0, m)} and columns {@code [n0, n)}. */
        void matmul(int m0, int m, int n0, int n) {
            mnpack(m0, m, n0, n);
        }

        private void mnpack(int m0, int m, int n0, int n) {
            if (m - m0 <= 0 || n - n0 <= 0) {
                return;
            }
            // pickKernel has already processed [m0, mp) x [n0, np); it tells us the tile shape.
            int shape = pickKernel(m0, m, n0, n);
            int mc = shape >> 4;
            int nc = shape & 0xF;
            int mp = m0 + ((m - m0) / mc) * mc;
            int np = n0 + ((n - n0) / nc) * nc;
            // The remainder is an L-shape: rows [mp, m) under the processed columns, plus the
            // right-hand strip [np, n) across ALL rows. The strip must span [m0, m); spanning
            // only [m0, mp) would drop the corner [mp, m) x [np, n) whenever both mc > 1 and
            // nc > 1 leave a remainder (e.g. the 3x4 kernel on a 5x5 range).
            mnpack(mp, m, n0, np);
            mnpack(m0, m, np, n);
        }

        /**
         * Runs the chosen tile kernel over {@code [m0, m) x [n0, n)} (via {@link #kernel}) and
         * returns the tile shape packed as {@code (rows << 4) | cols}. Implementations must only
         * choose a shape that fits, so each call makes progress.
         */
        protected abstract int pickKernel(int m0, int m, int n0, int n);

        /**
         * Invokes {@code tile} once per whole {@code rm x rn} tile in {@code [m0, m) x [n0, n)},
         * passing the tile's top-left (row, column). Partial tiles are left to the caller.
         */
        void kernel(int m0, int m, int rm, int n0, int n, int rn, BiIntConsumer tile) {
            int xtiles = (n - n0) / rn;
            int tiles = ((m - m0) / rm) * xtiles;
            for (int job = 0; job < tiles; job++) {
                tile.accept(m0 + (job / xtiles) * rm, n0 + (job % xtiles) * rn);
            }
        }
    }

    /* loaded from: input_file:com/github/tjake/jlama/tensor/operations/PanamaTensorOperations$GemmerBF16.class */
    /** Gemmer for BF16 {@code a} and BF16 {@code b}; only a 1x1 tile kernel is provided. */
    private class GemmerBF16 extends Gemmer {
        final BiIntConsumer matmul1x1;
        final BFloat16BufferTensor a;
        final BFloat16BufferTensor b;

        GemmerBF16(PanamaTensorOperations panamaTensorOperations, int k, AbstractTensor a, AbstractTensor b,
                AbstractTensor c, int aColumnOffset, int bColumnOffset, int rOffset) {
            super(panamaTensorOperations, k, a, b, c, aColumnOffset, bColumnOffset, rOffset);
            this.a = (BFloat16BufferTensor) a;
            this.b = (BFloat16BufferTensor) b;
            this.matmul1x1 = initMatmul1x1();
        }

        @Override
        protected int pickKernel(int m0, int m, int n0, int n) {
            kernel(m0, m, 1, n0, n, 1, this.matmul1x1);
            return (1 << 4) | 1;
        }

        /**
         * Widens one half ({@code part} 0 or 1) of a BF16 short vector to floats: sign-extend to
         * int, then shift the 16 bf16 bits into the high half of each 32-bit lane.
         */
        private FloatVector bf16ToF32(ShortVector v, int part) {
            return v.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, part)
                    .lanewise(VectorOperators.LSHL, PanamaTensorOperations.BF16_BYTE_SHIFT)
                    .reinterpretAsFloats();
        }

        /**
         * 1x1 kernel: writes {@code c[ii, jj + rOffset]} = dot(row {@code ii} of {@code a},
         * row {@code jj} of {@code b}) over {@code k} elements.
         * NOTE(review): assumes k is a multiple of the preferred short-vector length.
         */
        protected BiIntConsumer initMatmul1x1() {
            return (ii, jj) -> {
                final int step = ShortVector.SPECIES_PREFERRED.length();
                final int aEnd = this.aColumnOffset + this.k;
                final int bEnd = this.bColumnOffset + this.k;
                FloatVector acc = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                for (int i = this.aColumnOffset, j = this.bColumnOffset; i < aEnd && j < bEnd; i += step, j += step) {
                    ShortVector va = this.a.mo41getVector(ShortVector.SPECIES_PREFERRED, ii, i);
                    ShortVector vb = this.b.mo41getVector(ShortVector.SPECIES_PREFERRED, jj, j);
                    // Each short vector widens to two float vectors: accumulate lower then upper half.
                    acc = bf16ToF32(va, 0).fma(bf16ToF32(vb, 0), acc);
                    acc = bf16ToF32(va, 1).fma(bf16ToF32(vb, 1), acc);
                }
                this.c.set(acc.reduceLanes(VectorOperators.ADD), ii, jj + this.rOffset);
            };
        }
    }

    /* loaded from: input_file:com/github/tjake/jlama/tensor/operations/PanamaTensorOperations$GemmerF32.class */
    /**
     * Gemmer for float tensors, with 1x1, 1x4, 3x4 and 4x1 register tiles.
     *
     * <p>Every kernel computes {@code c[ii + r, jj + col + rOffset]} = dot(row {@code ii + r} of
     * {@code a}, row {@code jj + col} of {@code b}) over {@code k} inner elements. Inputs are read
     * through {@code AbstractTensor} and reinterpreted as floats.
     * NOTE(review): assumes k is a multiple of {@code FloatVector.SPECIES_PREFERRED.length()}.
     */
    private class GemmerF32 extends Gemmer {
        final BiIntConsumer matmul1x1;
        final BiIntConsumer matmul1x4;
        final BiIntConsumer matmul3x4;
        final BiIntConsumer matmul4x1;

        GemmerF32(PanamaTensorOperations panamaTensorOperations, int k, AbstractTensor a, AbstractTensor b,
                AbstractTensor c, int aColumnOffset, int bColumnOffset, int rOffset) {
            super(panamaTensorOperations, k, a, b, c, aColumnOffset, bColumnOffset, rOffset);
            this.matmul1x1 = initMatmul1x1();
            this.matmul1x4 = initMatmul1x4();
            this.matmul3x4 = initMatmul3x4();
            this.matmul4x1 = initMatmul4x1();
        }

        /** Prefers the 3x4 tile, then 4x1, then 1x4, falling back to 1x1. */
        @Override
        protected int pickKernel(int m0, int m, int n0, int n) {
            final int rows = m - m0;
            final int cols = n - n0;
            if (rows >= 3 && cols >= 4) {
                kernel(m0, m, 3, n0, n, 4, this.matmul3x4);
                return (3 << 4) | 4;
            }
            if (rows >= 4 && cols >= 1) {
                kernel(m0, m, 4, n0, n, 1, this.matmul4x1);
                return (4 << 4) | 1;
            }
            if (rows >= 1 && cols >= 4) {
                kernel(m0, m, 1, n0, n, 4, this.matmul1x4);
                return (1 << 4) | 4;
            }
            kernel(m0, m, 1, n0, n, 1, this.matmul1x1);
            return (1 << 4) | 1;
        }

        protected BiIntConsumer initMatmul1x1() {
            return (ii, jj) -> {
                final int step = FloatVector.SPECIES_PREFERRED.length();
                final int aEnd = this.aColumnOffset + this.k;
                final int bEnd = this.bColumnOffset + this.k;
                FloatVector acc = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                for (int i = this.aColumnOffset, j = this.bColumnOffset; i < aEnd && j < bEnd; i += step, j += step) {
                    FloatVector va = this.a.mo41getVector(FloatVector.SPECIES_PREFERRED, ii, i).reinterpretAsFloats();
                    FloatVector vb = this.b.mo41getVector(FloatVector.SPECIES_PREFERRED, jj, j).reinterpretAsFloats();
                    acc = va.fma(vb, acc);
                }
                this.c.set(acc.reduceLanes(VectorOperators.ADD), ii, jj + this.rOffset);
            };
        }

        protected BiIntConsumer initMatmul1x4() {
            return (ii, jj) -> {
                final int step = FloatVector.SPECIES_PREFERRED.length();
                final int aEnd = this.aColumnOffset + this.k;
                final int bEnd = this.bColumnOffset + this.k;
                FloatVector c0 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector c1 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector c2 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector c3 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                for (int i = this.aColumnOffset, j = this.bColumnOffset; i < aEnd && j < bEnd; i += step, j += step) {
                    // One row of a is reused against four rows of b.
                    FloatVector va = this.a.mo41getVector(FloatVector.SPECIES_PREFERRED, ii, i).reinterpretAsFloats();
                    FloatVector b0 = this.b.mo41getVector(FloatVector.SPECIES_PREFERRED, jj + 0, j).reinterpretAsFloats();
                    FloatVector b1 = this.b.mo41getVector(FloatVector.SPECIES_PREFERRED, jj + 1, j).reinterpretAsFloats();
                    FloatVector b2 = this.b.mo41getVector(FloatVector.SPECIES_PREFERRED, jj + 2, j).reinterpretAsFloats();
                    FloatVector b3 = this.b.mo41getVector(FloatVector.SPECIES_PREFERRED, jj + 3, j).reinterpretAsFloats();
                    c0 = va.fma(b0, c0);
                    c1 = va.fma(b1, c1);
                    c2 = va.fma(b2, c2);
                    c3 = va.fma(b3, c3);
                }
                this.c.set(c0.reduceLanes(VectorOperators.ADD), ii, jj + 0 + this.rOffset);
                this.c.set(c1.reduceLanes(VectorOperators.ADD), ii, jj + 1 + this.rOffset);
                this.c.set(c2.reduceLanes(VectorOperators.ADD), ii, jj + 2 + this.rOffset);
                this.c.set(c3.reduceLanes(VectorOperators.ADD), ii, jj + 3 + this.rOffset);
            };
        }

        protected BiIntConsumer initMatmul3x4() {
            return (ii, jj) -> {
                final int step = FloatVector.SPECIES_PREFERRED.length();
                final int aEnd = this.aColumnOffset + this.k;
                final int bEnd = this.bColumnOffset + this.k;
                // cRC accumulates output (ii + R, jj + C).
                FloatVector c00 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector c01 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector c02 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector c03 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector c10 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector c11 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector c12 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector c13 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector c20 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector c21 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector c22 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector c23 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                for (int i = this.aColumnOffset, j = this.bColumnOffset; i < aEnd && j < bEnd; i += step, j += step) {
                    // Four b rows are loaded once and reused by each of the three a rows.
                    FloatVector b0 = this.b.mo41getVector(FloatVector.SPECIES_PREFERRED, jj + 0, j).reinterpretAsFloats();
                    FloatVector b1 = this.b.mo41getVector(FloatVector.SPECIES_PREFERRED, jj + 1, j).reinterpretAsFloats();
                    FloatVector b2 = this.b.mo41getVector(FloatVector.SPECIES_PREFERRED, jj + 2, j).reinterpretAsFloats();
                    FloatVector b3 = this.b.mo41getVector(FloatVector.SPECIES_PREFERRED, jj + 3, j).reinterpretAsFloats();

                    FloatVector a0 = this.a.mo41getVector(FloatVector.SPECIES_PREFERRED, ii + 0, i).reinterpretAsFloats();
                    c00 = a0.fma(b0, c00);
                    c01 = a0.fma(b1, c01);
                    c02 = a0.fma(b2, c02);
                    c03 = a0.fma(b3, c03);

                    FloatVector a1 = this.a.mo41getVector(FloatVector.SPECIES_PREFERRED, ii + 1, i).reinterpretAsFloats();
                    c10 = a1.fma(b0, c10);
                    c11 = a1.fma(b1, c11);
                    c12 = a1.fma(b2, c12);
                    c13 = a1.fma(b3, c13);

                    FloatVector a2 = this.a.mo41getVector(FloatVector.SPECIES_PREFERRED, ii + 2, i).reinterpretAsFloats();
                    c20 = a2.fma(b0, c20);
                    c21 = a2.fma(b1, c21);
                    c22 = a2.fma(b2, c22);
                    c23 = a2.fma(b3, c23);
                }
                this.c.set(c00.reduceLanes(VectorOperators.ADD), ii + 0, jj + 0 + this.rOffset);
                this.c.set(c01.reduceLanes(VectorOperators.ADD), ii + 0, jj + 1 + this.rOffset);
                this.c.set(c02.reduceLanes(VectorOperators.ADD), ii + 0, jj + 2 + this.rOffset);
                this.c.set(c03.reduceLanes(VectorOperators.ADD), ii + 0, jj + 3 + this.rOffset);
                this.c.set(c10.reduceLanes(VectorOperators.ADD), ii + 1, jj + 0 + this.rOffset);
                this.c.set(c11.reduceLanes(VectorOperators.ADD), ii + 1, jj + 1 + this.rOffset);
                this.c.set(c12.reduceLanes(VectorOperators.ADD), ii + 1, jj + 2 + this.rOffset);
                this.c.set(c13.reduceLanes(VectorOperators.ADD), ii + 1, jj + 3 + this.rOffset);
                this.c.set(c20.reduceLanes(VectorOperators.ADD), ii + 2, jj + 0 + this.rOffset);
                this.c.set(c21.reduceLanes(VectorOperators.ADD), ii + 2, jj + 1 + this.rOffset);
                this.c.set(c22.reduceLanes(VectorOperators.ADD), ii + 2, jj + 2 + this.rOffset);
                this.c.set(c23.reduceLanes(VectorOperators.ADD), ii + 2, jj + 3 + this.rOffset);
            };
        }

        protected BiIntConsumer initMatmul4x1() {
            return (ii, jj) -> {
                final int step = FloatVector.SPECIES_PREFERRED.length();
                final int aEnd = this.aColumnOffset + this.k;
                final int bEnd = this.bColumnOffset + this.k;
                FloatVector c0 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector c1 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector c2 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector c3 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                for (int i = this.aColumnOffset, j = this.bColumnOffset; i < aEnd && j < bEnd; i += step, j += step) {
                    // Four rows of a are reused against a single row of b.
                    FloatVector a0 = this.a.mo41getVector(FloatVector.SPECIES_PREFERRED, ii + 0, i).reinterpretAsFloats();
                    FloatVector a1 = this.a.mo41getVector(FloatVector.SPECIES_PREFERRED, ii + 1, i).reinterpretAsFloats();
                    FloatVector a2 = this.a.mo41getVector(FloatVector.SPECIES_PREFERRED, ii + 2, i).reinterpretAsFloats();
                    FloatVector a3 = this.a.mo41getVector(FloatVector.SPECIES_PREFERRED, ii + 3, i).reinterpretAsFloats();
                    FloatVector vb = this.b.mo41getVector(FloatVector.SPECIES_PREFERRED, jj, j).reinterpretAsFloats();
                    c0 = a0.fma(vb, c0);
                    c1 = a1.fma(vb, c1);
                    c2 = a2.fma(vb, c2);
                    c3 = a3.fma(vb, c3);
                }
                this.c.set(c0.reduceLanes(VectorOperators.ADD), ii + 0, jj + this.rOffset);
                this.c.set(c1.reduceLanes(VectorOperators.ADD), ii + 1, jj + this.rOffset);
                this.c.set(c2.reduceLanes(VectorOperators.ADD), ii + 2, jj + this.rOffset);
                this.c.set(c3.reduceLanes(VectorOperators.ADD), ii + 3, jj + this.rOffset);
            };
        }
    }

    /* loaded from: input_file:com/github/tjake/jlama/tensor/operations/PanamaTensorOperations$GemmerF32BF16.class */
    /** Gemmer for float {@code a} and BF16 {@code b}; only a 1x1 tile kernel is provided. */
    private class GemmerF32BF16 extends Gemmer {
        final BiIntConsumer matmul1x1;
        final FloatBufferTensor a;
        final BFloat16BufferTensor b;

        GemmerF32BF16(PanamaTensorOperations panamaTensorOperations, int k, AbstractTensor a, AbstractTensor b,
                AbstractTensor c, int aColumnOffset, int bColumnOffset, int rOffset) {
            super(panamaTensorOperations, k, a, b, c, aColumnOffset, bColumnOffset, rOffset);
            this.a = (FloatBufferTensor) a;
            this.b = (BFloat16BufferTensor) b;
            this.matmul1x1 = initMatmul1x1();
        }

        @Override
        protected int pickKernel(int m0, int m, int n0, int n) {
            kernel(m0, m, 1, n0, n, 1, this.matmul1x1);
            return (1 << 4) | 1;
        }

        /**
         * Widens one half ({@code part} 0 or 1) of a BF16 short vector to floats by zero-extending
         * to int and shifting the bf16 bits into the high half of each 32-bit lane.
         */
        private FloatVector bf16ToF32(ShortVector v, int part) {
            return v.convertShape(VectorOperators.ZERO_EXTEND_S2I, IntVector.SPECIES_PREFERRED, part)
                    .lanewise(VectorOperators.LSHL, PanamaTensorOperations.BF16_BYTE_SHIFT)
                    .reinterpretAsFloats();
        }

        /**
         * 1x1 kernel: writes {@code c[ii, jj + rOffset]} = dot(row {@code ii} of {@code a},
         * row {@code jj} of {@code b}). Each step consumes one preferred short vector of {@code b}
         * and the two float vectors of {@code a} it widens against.
         * NOTE(review): assumes k is a multiple of the preferred short-vector length.
         */
        protected BiIntConsumer initMatmul1x1() {
            return (ii, jj) -> {
                final int floatLanes = FloatVector.SPECIES_PREFERRED.length();
                final int step = ShortVector.SPECIES_PREFERRED.length();
                final int aEnd = this.aColumnOffset + this.k;
                final int bEnd = this.bColumnOffset + this.k;
                FloatVector acc = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                for (int i = this.aColumnOffset, j = this.bColumnOffset; i < aEnd && j < bEnd; i += step, j += step) {
                    FloatVector aLo = this.a.mo41getVector(FloatVector.SPECIES_PREFERRED, ii, i);
                    FloatVector aHi = this.a.mo41getVector(FloatVector.SPECIES_PREFERRED, ii, i + floatLanes);
                    ShortVector packed = this.b.mo41getVector(ShortVector.SPECIES_PREFERRED, jj, j);
                    acc = aLo.fma(bf16ToF32(packed, 0), acc);
                    acc = aHi.fma(bf16ToF32(packed, 1), acc);
                }
                this.c.set(acc.reduceLanes(VectorOperators.ADD), ii, jj + this.rOffset);
            };
        }
    }

    /* loaded from: input_file:com/github/tjake/jlama/tensor/operations/PanamaTensorOperations$GemmerF32Q4_256.class */
    /**
     * Gemmer for float {@code a} and Q4-packed {@code b} using 256-bit float vectors.
     *
     * <p>Each inner step handles 32 values: 16 packed bytes of {@code b} (low nibbles first, then
     * high nibbles, each minus 8) against 32 floats of {@code a}, scaled by the block factor from
     * {@code getFactorForIndex}. Only the 1x1 kernel is currently selected by {@link #pickKernel}.
     * NOTE(review): assumes k is a multiple of 32.
     */
    private class GemmerF32Q4_256 extends Gemmer {
        final BiIntConsumer matmul1x1;
        final BiIntConsumer matmul1x4;
        final BiIntConsumer matmul3x4;
        final BiIntConsumer matmul4x1;
        final Q4ByteBufferTensor b;
        final FloatBufferTensor a;

        GemmerF32Q4_256(PanamaTensorOperations panamaTensorOperations, int k, AbstractTensor a, AbstractTensor b,
                AbstractTensor c, int aColumnOffset, int bColumnOffset, int rOffset) {
            super(panamaTensorOperations, k, a, b, c, aColumnOffset, bColumnOffset, rOffset);
            this.a = (FloatBufferTensor) a;
            this.b = (Q4ByteBufferTensor) b;
            this.matmul1x1 = initMatmul1x1();
            this.matmul1x4 = initMatmul1x4();
            this.matmul3x4 = null;
            this.matmul4x1 = null;
        }

        @Override
        protected int pickKernel(int m0, int m, int n0, int n) {
            kernel(m0, m, 1, n0, n, 1, this.matmul1x1);
            return (1 << 4) | 1;
        }

        /** Broadcasts the Q4 scale factor for {@code b} row {@code row} at inner index {@code j}. */
        private FloatVector q4Scale(int row, int j) {
            return FloatVector.broadcast(FloatVector.SPECIES_256, this.b.getFactorForIndex(row, j));
        }

        /**
         * Unscaled partial dot of 32 floats of {@code a} ({@code a0..a3}, 8 lanes each) with the
         * 16 packed bytes of {@code b} row {@code row} starting at inner index {@code j}.
         */
        private FloatVector q4Dot(FloatVector a0, FloatVector a1, FloatVector a2, FloatVector a3, int row, int j) {
            ByteVector packed = this.b.mo41getVector(ByteVector.SPECIES_128, row, j);
            ByteVector lo = packed.and(PanamaTensorOperations.Q4_BYTE_MASK_128)
                    .sub(PanamaTensorOperations.Q4_BYTE_SUB_128);
            ByteVector hi = packed.lanewise(VectorOperators.LSHR, PanamaTensorOperations.Q4_BYTE_SHIFT_128)
                    .sub(PanamaTensorOperations.Q4_BYTE_SUB_128);
            return a0.mul(lo.castShape(FloatVector.SPECIES_256, 0))
                    .add(a1.mul(lo.castShape(FloatVector.SPECIES_256, 1)))
                    .add(a2.mul(hi.castShape(FloatVector.SPECIES_256, 0)))
                    .add(a3.mul(hi.castShape(FloatVector.SPECIES_256, 1)));
        }

        /** 1x1 kernel: {@code c[ii, jj + rOffset]} = dot(a row {@code ii}, dequantized b row {@code jj}). */
        protected BiIntConsumer initMatmul1x1() {
            return (ii, jj) -> {
                final int aEnd = this.aColumnOffset + this.k;
                final int bEnd = this.bColumnOffset + this.k;
                FloatVector acc = FloatVector.zero(FloatVector.SPECIES_256);
                for (int i = this.aColumnOffset, j = this.bColumnOffset; i < aEnd && j < bEnd; i += 32, j += 32) {
                    FloatVector a0 = this.a.mo41getVector(FloatVector.SPECIES_256, ii, i);
                    FloatVector a1 = this.a.mo41getVector(FloatVector.SPECIES_256, ii, i + 8);
                    FloatVector a2 = this.a.mo41getVector(FloatVector.SPECIES_256, ii, i + 16);
                    FloatVector a3 = this.a.mo41getVector(FloatVector.SPECIES_256, ii, i + 24);
                    acc = q4Dot(a0, a1, a2, a3, jj, j).fma(q4Scale(jj, j), acc);
                }
                this.c.set(acc.reduceLanes(VectorOperators.ADD), ii, jj + this.rOffset);
            };
        }

        /**
         * 1x4 kernel: one row of {@code a} against four rows {@code jj..jj+3} of {@code b},
         * writing four outputs. Previously only columns {@code jj} and {@code jj+1} were computed,
         * which would leave two outputs per tile unwritten if dispatched as a 4-wide kernel.
         */
        protected BiIntConsumer initMatmul1x4() {
            return (ii, jj) -> {
                final int aEnd = this.aColumnOffset + this.k;
                final int bEnd = this.bColumnOffset + this.k;
                FloatVector c0 = FloatVector.zero(FloatVector.SPECIES_256);
                FloatVector c1 = FloatVector.zero(FloatVector.SPECIES_256);
                FloatVector c2 = FloatVector.zero(FloatVector.SPECIES_256);
                FloatVector c3 = FloatVector.zero(FloatVector.SPECIES_256);
                for (int i = this.aColumnOffset, j = this.bColumnOffset; i < aEnd && j < bEnd; i += 32, j += 32) {
                    // The 32 activations are loaded once and shared by all four b rows.
                    FloatVector a0 = this.a.mo41getVector(FloatVector.SPECIES_256, ii, i);
                    FloatVector a1 = this.a.mo41getVector(FloatVector.SPECIES_256, ii, i + 8);
                    FloatVector a2 = this.a.mo41getVector(FloatVector.SPECIES_256, ii, i + 16);
                    FloatVector a3 = this.a.mo41getVector(FloatVector.SPECIES_256, ii, i + 24);
                    c0 = q4Dot(a0, a1, a2, a3, jj + 0, j).fma(q4Scale(jj + 0, j), c0);
                    c1 = q4Dot(a0, a1, a2, a3, jj + 1, j).fma(q4Scale(jj + 1, j), c1);
                    c2 = q4Dot(a0, a1, a2, a3, jj + 2, j).fma(q4Scale(jj + 2, j), c2);
                    c3 = q4Dot(a0, a1, a2, a3, jj + 3, j).fma(q4Scale(jj + 3, j), c3);
                }
                this.c.set(c0.reduceLanes(VectorOperators.ADD), ii, jj + 0 + this.rOffset);
                this.c.set(c1.reduceLanes(VectorOperators.ADD), ii, jj + 1 + this.rOffset);
                this.c.set(c2.reduceLanes(VectorOperators.ADD), ii, jj + 2 + this.rOffset);
                this.c.set(c3.reduceLanes(VectorOperators.ADD), ii, jj + 3 + this.rOffset);
            };
        }
    }

    /* loaded from: input_file:com/github/tjake/jlama/tensor/operations/PanamaTensorOperations$GemmerF32Q4_512.class */
    private class GemmerF32Q4_512 extends Gemmer {
        // Tile kernels; this format never builds a 3x4 tile, so that slot stays null.
        final BiIntConsumer matmul1x1, matmul1x4, matmul3x4, matmul4x1;
        // Narrowed views of the operands, read by the kernel lambdas when they run.
        final Q4ByteBufferTensor b;
        final FloatBufferTensor a;

        GemmerF32Q4_512(PanamaTensorOperations outer, int depth, AbstractTensor left, AbstractTensor right,
                AbstractTensor out, int leftOffset, int rightOffset, int rowOffset) {
            super(outer, depth, left, right, out, leftOffset, rightOffset, rowOffset);
            this.a = (FloatBufferTensor) left;
            this.b = (Q4ByteBufferTensor) right;
            this.matmul1x1 = initMatmul1x1();
            this.matmul1x4 = initMatmul1x4();
            this.matmul3x4 = null;
            this.matmul4x1 = initMatmul4x1();
        }

        /** Prefers a 4x1 tile, then 1x4, and otherwise falls back to 1x1. */
        @Override
        protected int pickKernel(int m0, int m, int n0, int n) {
            final int rows = m - m0;
            final int cols = n - n0;
            if (rows >= 4 && cols >= 1) {
                kernel(m0, m, 4, n0, n, 1, this.matmul4x1);
                return (4 << 4) | 1;
            }
            if (rows >= 1 && cols >= 4) {
                kernel(m0, m, 1, n0, n, 4, this.matmul1x4);
                return (1 << 4) | 4;
            }
            kernel(m0, m, 1, n0, n, 1, this.matmul1x1);
            return (1 << 4) | 1;
        }

        /**
         * 1x1 kernel: writes {@code c[ii, jj + rOffset]} = dot(a row {@code ii}, dequantized b row
         * {@code jj}). Each step takes 16 packed bytes of {@code b} (32 nibbles), unpacks them to
         * 512-bit float vectors (value - 8, times the block factor) and FMAs them against 32
         * floats of {@code a}: low nibbles with {@code a[i..i+15]}, high nibbles with
         * {@code a[i+16..i+31]}.
         * NOTE(review): assumes k is a multiple of 32.
         */
        protected BiIntConsumer initMatmul1x1() {
            return (ii, jj) -> {
                final int aEnd = this.aColumnOffset + this.k;
                final int bEnd = this.bColumnOffset + this.k;
                FloatVector acc = FloatVector.zero(FloatVector.SPECIES_512);
                for (int i = this.aColumnOffset, j = this.bColumnOffset; i < aEnd && j < bEnd; i += 32, j += 32) {
                    FloatVector scale = FloatVector.broadcast(FloatVector.SPECIES_512, this.b.getFactorForIndex(jj, j));
                    ByteVector packed = this.b.mo41getVector(ByteVector.SPECIES_128, jj, j);
                    Vector<Float> lo = packed
                            .lanewise(VectorOperators.AND, PanamaTensorOperations.Q4_BYTE_MASK_128)
                            .sub(PanamaTensorOperations.Q4_BYTE_SUB_128)
                            .convertShape(VectorOperators.B2F, FloatVector.SPECIES_512, 0)
                            .mul(scale);
                    // ASHR sign-extends, so mask afterwards to keep only the high nibble.
                    Vector<Float> hi = packed
                            .lanewise(VectorOperators.ASHR, PanamaTensorOperations.Q4_BYTE_SHIFT_128)
                            .lanewise(VectorOperators.AND, PanamaTensorOperations.Q4_BYTE_MASK_128)
                            .sub(PanamaTensorOperations.Q4_BYTE_SUB_128)
                            .convertShape(VectorOperators.B2F, FloatVector.SPECIES_512, 0)
                            .mul(scale);
                    acc = this.a.mo41getVector(FloatVector.SPECIES_512, ii, i).fma(lo, acc);
                    acc = this.a.mo41getVector(FloatVector.SPECIES_512, ii, i + 16).fma(hi, acc);
                }
                this.c.set(acc.reduceLanes(VectorOperators.ADD), ii, jj + this.rOffset);
            };
        }

        /**
         * Builds the 4x1 kernel: rows {@code i..i+3} of {@code a} against row {@code j} of
         * {@code b}, accumulated over {@code k} columns and written to
         * {@code c[i + r][j + rOffset]} for {@code r = 0..3}.
         *
         * <p>Each step consumes one 32-column block. The 16 packed bytes of {@code b} are split
         * into low and high nibbles, de-biased by 8 and scaled by the block factor. They are then
         * reused for all four rows of {@code a}. The lambda parameters are the fixed row indices.
         * The column cursors {@code ao}/{@code bo} are separate locals that advance by one block
         * per step.
         */
        protected final BiIntConsumer initMatmul4x1() {
            return (i, j) -> {
                int ao = this.aColumnOffset;
                int bo = this.bColumnOffset;
                final int aLimit = ao + this.k;
                final int bLimit = bo + this.k;
                FloatVector acc0 = FloatVector.zero(FloatVector.SPECIES_512);
                FloatVector acc1 = FloatVector.zero(FloatVector.SPECIES_512);
                FloatVector acc2 = FloatVector.zero(FloatVector.SPECIES_512);
                FloatVector acc3 = FloatVector.zero(FloatVector.SPECIES_512);
                while (ao < aLimit && bo < bLimit) {
                    FloatVector scale = FloatVector.broadcast(FloatVector.SPECIES_512, this.b.getFactorForIndex(j, bo));
                    ByteVector packed = this.b.mo41getVector(ByteVector.SPECIES_128, j, bo);
                    // Low nibbles line up with columns ao..ao+15 of a, high nibbles with ao+16..ao+31.
                    Vector<Float> lo = packed.lanewise(VectorOperators.AND, PanamaTensorOperations.Q4_BYTE_MASK_128)
                            .sub(PanamaTensorOperations.Q4_BYTE_SUB_128)
                            .convertShape(VectorOperators.B2F, FloatVector.SPECIES_512, 0)
                            .mul(scale);
                    Vector<Float> hi = packed.lanewise(VectorOperators.ASHR, PanamaTensorOperations.Q4_BYTE_SHIFT_128)
                            .lanewise(VectorOperators.AND, PanamaTensorOperations.Q4_BYTE_MASK_128)
                            .sub(PanamaTensorOperations.Q4_BYTE_SUB_128)
                            .convertShape(VectorOperators.B2F, FloatVector.SPECIES_512, 0)
                            .mul(scale);
                    FloatVector a0Lo = this.a.mo41getVector(FloatVector.SPECIES_512, i, ao);
                    FloatVector a0Hi = this.a.mo41getVector(FloatVector.SPECIES_512, i, ao + 16);
                    FloatVector a1Lo = this.a.mo41getVector(FloatVector.SPECIES_512, i + 1, ao);
                    FloatVector a1Hi = this.a.mo41getVector(FloatVector.SPECIES_512, i + 1, ao + 16);
                    FloatVector a2Lo = this.a.mo41getVector(FloatVector.SPECIES_512, i + 2, ao);
                    FloatVector a2Hi = this.a.mo41getVector(FloatVector.SPECIES_512, i + 2, ao + 16);
                    FloatVector a3Lo = this.a.mo41getVector(FloatVector.SPECIES_512, i + 3, ao);
                    FloatVector a3Hi = this.a.mo41getVector(FloatVector.SPECIES_512, i + 3, ao + 16);
                    acc0 = a0Hi.fma(hi, a0Lo.fma(lo, acc0));
                    acc1 = a1Hi.fma(hi, a1Lo.fma(lo, acc1));
                    acc2 = a2Hi.fma(hi, a2Lo.fma(lo, acc2));
                    acc3 = a3Hi.fma(hi, a3Lo.fma(lo, acc3));
                    ao += 32;
                    bo += 32;
                }
                this.c.set(acc0.reduceLanes(VectorOperators.ADD), i, j + this.rOffset);
                this.c.set(acc1.reduceLanes(VectorOperators.ADD), i + 1, j + this.rOffset);
                this.c.set(acc2.reduceLanes(VectorOperators.ADD), i + 2, j + this.rOffset);
                this.c.set(acc3.reduceLanes(VectorOperators.ADD), i + 3, j + this.rOffset);
            };
        }

        /**
         * Builds the 1x4 kernel: row {@code i} of {@code a} against rows {@code j..j+3} of
         * {@code b}, accumulated over {@code k} columns and written to
         * {@code c[i][j + r + rOffset]} for {@code r = 0..3}.
         *
         * <p>The two 16-float halves of the current 32-column block of {@code a} are loaded once
         * per step and reused for all four Q4 rows of {@code b}. Each of those rows is unpacked
         * and scaled by its own block factor. The lambda parameters are the fixed row indices; the
         * column cursors {@code ao}/{@code bo} are separate locals.
         */
        protected BiIntConsumer initMatmul1x4() {
            return (i, j) -> {
                int ao = this.aColumnOffset;
                int bo = this.bColumnOffset;
                final int aLimit = ao + this.k;
                final int bLimit = bo + this.k;
                FloatVector acc0 = FloatVector.zero(FloatVector.SPECIES_512);
                FloatVector acc1 = FloatVector.zero(FloatVector.SPECIES_512);
                FloatVector acc2 = FloatVector.zero(FloatVector.SPECIES_512);
                FloatVector acc3 = FloatVector.zero(FloatVector.SPECIES_512);
                while (ao < aLimit && bo < bLimit) {
                    FloatVector aLo = this.a.mo41getVector(FloatVector.SPECIES_512, i, ao);
                    FloatVector aHi = this.a.mo41getVector(FloatVector.SPECIES_512, i, ao + 16);

                    FloatVector scale0 = FloatVector.broadcast(FloatVector.SPECIES_512, this.b.getFactorForIndex(j, bo));
                    ByteVector packed0 = this.b.mo41getVector(ByteVector.SPECIES_128, j, bo);
                    Vector<Float> lo0 = packed0.lanewise(VectorOperators.AND, PanamaTensorOperations.Q4_BYTE_MASK_128)
                            .sub(PanamaTensorOperations.Q4_BYTE_SUB_128)
                            .convertShape(VectorOperators.B2F, FloatVector.SPECIES_512, 0)
                            .mul(scale0);
                    Vector<Float> hi0 = packed0.lanewise(VectorOperators.ASHR, PanamaTensorOperations.Q4_BYTE_SHIFT_128)
                            .lanewise(VectorOperators.AND, PanamaTensorOperations.Q4_BYTE_MASK_128)
                            .sub(PanamaTensorOperations.Q4_BYTE_SUB_128)
                            .convertShape(VectorOperators.B2F, FloatVector.SPECIES_512, 0)
                            .mul(scale0);
                    acc0 = aHi.fma(hi0, aLo.fma(lo0, acc0));

                    FloatVector scale1 = FloatVector.broadcast(FloatVector.SPECIES_512, this.b.getFactorForIndex(j + 1, bo));
                    ByteVector packed1 = this.b.mo41getVector(ByteVector.SPECIES_128, j + 1, bo);
                    Vector<Float> lo1 = packed1.lanewise(VectorOperators.AND, PanamaTensorOperations.Q4_BYTE_MASK_128)
                            .sub(PanamaTensorOperations.Q4_BYTE_SUB_128)
                            .convertShape(VectorOperators.B2F, FloatVector.SPECIES_512, 0)
                            .mul(scale1);
                    Vector<Float> hi1 = packed1.lanewise(VectorOperators.ASHR, PanamaTensorOperations.Q4_BYTE_SHIFT_128)
                            .lanewise(VectorOperators.AND, PanamaTensorOperations.Q4_BYTE_MASK_128)
                            .sub(PanamaTensorOperations.Q4_BYTE_SUB_128)
                            .convertShape(VectorOperators.B2F, FloatVector.SPECIES_512, 0)
                            .mul(scale1);
                    acc1 = aHi.fma(hi1, aLo.fma(lo1, acc1));

                    FloatVector scale2 = FloatVector.broadcast(FloatVector.SPECIES_512, this.b.getFactorForIndex(j + 2, bo));
                    ByteVector packed2 = this.b.mo41getVector(ByteVector.SPECIES_128, j + 2, bo);
                    Vector<Float> lo2 = packed2.lanewise(VectorOperators.AND, PanamaTensorOperations.Q4_BYTE_MASK_128)
                            .sub(PanamaTensorOperations.Q4_BYTE_SUB_128)
                            .convertShape(VectorOperators.B2F, FloatVector.SPECIES_512, 0)
                            .mul(scale2);
                    Vector<Float> hi2 = packed2.lanewise(VectorOperators.ASHR, PanamaTensorOperations.Q4_BYTE_SHIFT_128)
                            .lanewise(VectorOperators.AND, PanamaTensorOperations.Q4_BYTE_MASK_128)
                            .sub(PanamaTensorOperations.Q4_BYTE_SUB_128)
                            .convertShape(VectorOperators.B2F, FloatVector.SPECIES_512, 0)
                            .mul(scale2);
                    acc2 = aHi.fma(hi2, aLo.fma(lo2, acc2));

                    FloatVector scale3 = FloatVector.broadcast(FloatVector.SPECIES_512, this.b.getFactorForIndex(j + 3, bo));
                    ByteVector packed3 = this.b.mo41getVector(ByteVector.SPECIES_128, j + 3, bo);
                    Vector<Float> lo3 = packed3.lanewise(VectorOperators.AND, PanamaTensorOperations.Q4_BYTE_MASK_128)
                            .sub(PanamaTensorOperations.Q4_BYTE_SUB_128)
                            .convertShape(VectorOperators.B2F, FloatVector.SPECIES_512, 0)
                            .mul(scale3);
                    Vector<Float> hi3 = packed3.lanewise(VectorOperators.ASHR, PanamaTensorOperations.Q4_BYTE_SHIFT_128)
                            .lanewise(VectorOperators.AND, PanamaTensorOperations.Q4_BYTE_MASK_128)
                            .sub(PanamaTensorOperations.Q4_BYTE_SUB_128)
                            .convertShape(VectorOperators.B2F, FloatVector.SPECIES_512, 0)
                            .mul(scale3);
                    acc3 = aHi.fma(hi3, aLo.fma(lo3, acc3));

                    ao += 32;
                    bo += 32;
                }
                this.c.set(acc0.reduceLanes(VectorOperators.ADD), i, j + this.rOffset);
                this.c.set(acc1.reduceLanes(VectorOperators.ADD), i, j + 1 + this.rOffset);
                this.c.set(acc2.reduceLanes(VectorOperators.ADD), i, j + 2 + this.rOffset);
                this.c.set(acc3.reduceLanes(VectorOperators.ADD), i, j + 3 + this.rOffset);
            };
        }
    }

    /**
     * Int8 x Q4 GEMM kernel for 256-bit vector units ({@code AVX_256}).
     *
     * <p>{@code a} is a {@link Q8ByteBufferTensor} and {@code b} a {@link Q4ByteBufferTensor}.
     * Only the 1x1 kernel is implemented. {@code matmul1x4}, {@code matmul3x4} and
     * {@code matmul4x1} are left {@code null}, and {@link #pickKernel} always dispatches to
     * {@code matmul1x1}.
     */
    private class GemmerI8Q4_256 extends Gemmer {
        final BiIntConsumer matmul1x1;
        final BiIntConsumer matmul1x4;
        final BiIntConsumer matmul3x4;
        final BiIntConsumer matmul4x1;
        final Q8ByteBufferTensor a;
        final Q4ByteBufferTensor b;

        GemmerI8Q4_256(PanamaTensorOperations panamaTensorOperations, int i, AbstractTensor abstractTensor, AbstractTensor abstractTensor2, AbstractTensor abstractTensor3, int i2, int i3, int i4) {
            super(panamaTensorOperations, i, abstractTensor, abstractTensor2, abstractTensor3, i2, i3, i4);
            this.a = (Q8ByteBufferTensor) abstractTensor;
            this.b = (Q4ByteBufferTensor) abstractTensor2;
            this.matmul1x1 = initMatmul1x1();
            this.matmul1x4 = null;
            this.matmul3x4 = null;
            this.matmul4x1 = null;
        }

        /**
         * Runs the 1x1 kernel over the given spans.
         *
         * @return the tile footprint passed to {@code kernel}, encoded as {@code (rows << 4) | cols}
         *     (always 1x1 here)
         */
        @Override // com.github.tjake.jlama.tensor.operations.PanamaTensorOperations.Gemmer
        protected int pickKernel(int i, int i2, int i3, int i4) {
            kernel(i, i2, 1, i3, i4, 1, this.matmul1x1);
            return (1 << 4) | 1;
        }

        /**
         * Builds the 1x1 kernel: the dot product of row {@code i} of {@code a} with row {@code j}
         * of {@code b} over {@code k} columns, written to {@code c[i][j + rOffset]}.
         *
         * <p>Per-block scale factors are fetched {@code SPECIES_256.length()} blocks at a time from
         * the block-factor tensors. When the block count is not a multiple of that width, the
         * remaining blocks use scalar {@code getFactorForIndex} lookups instead of running past
         * {@code k}. This assumes {@code getFactorForIndex(row, col)} returns the factor of the
         * 32-column block containing {@code col}, as the 512-bit kernels also rely on.
         */
        protected BiIntConsumer initMatmul1x1() {
            return (i, j) -> {
                final int lanes = FloatVector.SPECIES_256.length();
                final int numBlocks = this.k / 32;
                // Blocks that can be scaled with a full vector of per-block factors.
                final int vectorBlocks = numBlocks - numBlocks % lanes;
                int ao = this.aColumnOffset;
                int bo = this.bColumnOffset;
                FloatVector acc = FloatVector.zero(FloatVector.SPECIES_256);
                int block = 0;
                for (; block < vectorBlocks; block += lanes) {
                    // Exact integer block index (was (int) (0.03125f * offset), inexact past 2^24).
                    FloatVector scales = this.a.getBlockF().mo41getVector(FloatVector.SPECIES_256, i, ao / 32)
                            .mul(this.b.getBlockF().mo41getVector(FloatVector.SPECIES_256, j, bo / 32));
                    for (int lane = 0; lane < lanes; lane++) {
                        FloatVector scale = FloatVector.broadcast(FloatVector.SPECIES_256, scales.lane(lane));
                        acc = scale.fma(q8q4BlockDot256(i, j, ao, bo), acc);
                        ao += 32;
                        bo += 32;
                    }
                }
                // Tail: fewer blocks remain than one scale vector covers.
                for (; block < numBlocks; block++) {
                    FloatVector scale = FloatVector.broadcast(FloatVector.SPECIES_256,
                            this.a.getFactorForIndex(i, ao) * this.b.getFactorForIndex(j, bo));
                    acc = scale.fma(q8q4BlockDot256(i, j, ao, bo), acc);
                    ao += 32;
                    bo += 32;
                }
                this.c.set(acc.reduceLanes(VectorOperators.ADD), i, j + this.rOffset);
            };
        }

        /**
         * Computes the unscaled integer dot product of one 32-column block as 8 float partial
         * sums. The block is row {@code i} of {@code a} at column {@code ao} and row {@code j} of
         * {@code b} at column {@code bo}.
         */
        private Vector<Float> q8q4BlockDot256(int i, int j, int ao, int bo) {
            ByteVector aBytes = this.a.mo41getVector(ByteVector.SPECIES_256, i, ao);
            Vector<Short> aFirst = aBytes.convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0);
            Vector<Short> aSecond = aBytes.convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 1);
            ByteVector packed = this.b.mo41getVector(ByteVector.SPECIES_128, j, bo);
            // Low nibbles pair with the first 16 values of a, high nibbles (via LSHR) with the last 16.
            Vector<Short> products = packed.and(PanamaTensorOperations.Q4_BYTE_MASK_128)
                    .sub(PanamaTensorOperations.Q4_BYTE_SUB_128)
                    .convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0)
                    .mul(aFirst)
                    .add(packed.lanewise(VectorOperators.LSHR, PanamaTensorOperations.Q4_BYTE_SHIFT_128)
                            .sub(PanamaTensorOperations.Q4_BYTE_SUB_128)
                            .convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0)
                            .mul(aSecond));
            return products.convertShape(VectorOperators.S2F, FloatVector.SPECIES_256, 0)
                    .add(products.convertShape(VectorOperators.S2F, FloatVector.SPECIES_256, 1));
        }
    }

    /* loaded from: input_file:com/github/tjake/jlama/tensor/operations/PanamaTensorOperations$GemmerI8Q4_512.class */
    private class GemmerI8Q4_512 extends Gemmer {
        final BiIntConsumer matmul1x1;
        final BiIntConsumer matmul1x4;
        final BiIntConsumer matmul3x4;
        final Q8ByteBufferTensor a;
        final Q4ByteBufferTensor b;

        GemmerI8Q4_512(PanamaTensorOperations panamaTensorOperations, int i, AbstractTensor abstractTensor, AbstractTensor abstractTensor2, AbstractTensor abstractTensor3, int i2, int i3, int i4) {
            super(panamaTensorOperations, i, abstractTensor, abstractTensor2, abstractTensor3, i2, i3, i4);
            this.a = (Q8ByteBufferTensor) abstractTensor;
            this.b = (Q4ByteBufferTensor) abstractTensor2;
            this.matmul1x1 = initMatmul1x1();
            this.matmul1x4 = initMatmul1x4();
            this.matmul3x4 = initMatmul3x4();
        }

        @Override // com.github.tjake.jlama.tensor.operations.PanamaTensorOperations.Gemmer
        protected int pickKernel(int i, int i2, int i3, int i4) {
            int i5;
            int i6;
            if (i2 - i >= 2 && i4 - i3 >= 2) {
                i5 = 2;
                i6 = 2;
                kernel(i, i2, 2, i3, i4, 2, this.matmul3x4);
            } else if (i2 - i < 1 || i4 - i3 < 4) {
                i5 = 1;
                i6 = 1;
                kernel(i, i2, 1, i3, i4, 1, this.matmul1x1);
            } else {
                i5 = 1;
                i6 = 4;
                kernel(i, i2, 1, i3, i4, 4, this.matmul1x4);
            }
            return (i5 << 4) | i6;
        }

        protected BiIntConsumer initMatmul1x1() {
            return (i, i2) -> {
                int i = this.aColumnOffset;
                int i2 = this.bColumnOffset;
                FloatVector zero = FloatVector.zero(FloatVector.SPECIES_512);
                int i3 = 0;
                while (i3 < this.k) {
                    FloatVector broadcast = FloatVector.broadcast(FloatVector.SPECIES_512, this.a.getFactorForIndex(i, i) * this.b.getFactorForIndex(i2, i2));
                    ShortVector reinterpretAsShorts = this.a.mo41getVector(ByteVector.SPECIES_256, i, i).convertShape(VectorOperators.B2S, ShortVector.SPECIES_512, 0).reinterpretAsShorts();
                    ByteVector mo41getVector = this.b.mo41getVector(ByteVector.SPECIES_128, i2, i2);
                    zero = broadcast.fma(mo41getVector.lanewise(VectorOperators.AND, PanamaTensorOperations.Q4_BYTE_MASK_128).sub(PanamaTensorOperations.Q4_BYTE_SUB_128).convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0).mul(reinterpretAsShorts.castShape(ShortVector.SPECIES_256, 0)).add(mo41getVector.lanewise(VectorOperators.ASHR, PanamaTensorOperations.Q4_BYTE_SHIFT_128).lanewise(VectorOperators.AND, PanamaTensorOperations.Q4_BYTE_MASK_128).sub(PanamaTensorOperations.Q4_BYTE_SUB_128).convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0).mul(reinterpretAsShorts.castShape(ShortVector.SPECIES_256, 1))).convertShape(VectorOperators.S2F, FloatVector.SPECIES_512, 0), zero);
                    i3 += 32;
                    i += 32;
                    i2 += 32;
                }
                this.c.set(zero.reduceLanes(VectorOperators.ADD), i, i2 + this.rOffset);
            };
        }

        protected BiIntConsumer initMatmul1x4() {
            return (i, i2) -> {
                int i = this.k / 32;
                int i2 = this.aColumnOffset;
                int i3 = this.bColumnOffset;
                FloatVector zero = FloatVector.zero(FloatVector.SPECIES_512);
                FloatVector zero2 = FloatVector.zero(FloatVector.SPECIES_512);
                FloatVector zero3 = FloatVector.zero(FloatVector.SPECIES_512);
                FloatVector zero4 = FloatVector.zero(FloatVector.SPECIES_512);
                int i4 = 0;
                while (true) {
                    int i5 = i4;
                    if (i5 >= i) {
                        float reduceLanes = zero.reduceLanes(VectorOperators.ADD);
                        float reduceLanes2 = zero2.reduceLanes(VectorOperators.ADD);
                        float reduceLanes3 = zero3.reduceLanes(VectorOperators.ADD);
                        float reduceLanes4 = zero4.reduceLanes(VectorOperators.ADD);
                        this.c.set(reduceLanes, i, i2 + 0 + this.rOffset);
                        this.c.set(reduceLanes2, i, i2 + 1 + this.rOffset);
                        this.c.set(reduceLanes3, i, i2 + 2 + this.rOffset);
                        this.c.set(reduceLanes4, i, i2 + 3 + this.rOffset);
                        return;
                    }
                    int i6 = 0;
                    while (i6 < FloatVector.SPECIES_512.length()) {
                        float factorForIndex = this.a.getFactorForIndex(i + 0, i2);
                        FloatVector broadcast = FloatVector.broadcast(FloatVector.SPECIES_512, factorForIndex * this.b.getFactorForIndex(i2 + 0, i3));
                        FloatVector broadcast2 = FloatVector.broadcast(FloatVector.SPECIES_512, factorForIndex * this.b.getFactorForIndex(i2 + 1, i3));
                        FloatVector broadcast3 = FloatVector.broadcast(FloatVector.SPECIES_512, factorForIndex * this.b.getFactorForIndex(i2 + 2, i3));
                        FloatVector broadcast4 = FloatVector.broadcast(FloatVector.SPECIES_512, factorForIndex * this.b.getFactorForIndex(i2 + 3, i3));
                        Vector convertShape = this.a.mo41getVector(ByteVector.SPECIES_256, i, i2).convertShape(VectorOperators.B2S, ShortVector.SPECIES_512, 0);
                        Vector castShape = convertShape.castShape(ShortVector.SPECIES_256, 0);
                        Vector castShape2 = convertShape.castShape(ShortVector.SPECIES_256, 1);
                        ByteVector mo41getVector = this.b.mo41getVector(ByteVector.SPECIES_128, i2 + 0, i3);
                        ByteVector mo41getVector2 = this.b.mo41getVector(ByteVector.SPECIES_128, i2 + 1, i3);
                        ByteVector mo41getVector3 = this.b.mo41getVector(ByteVector.SPECIES_128, i2 + 2, i3);
                        ByteVector mo41getVector4 = this.b.mo41getVector(ByteVector.SPECIES_128, i2 + 3, i3);
                        Vector convertShape2 = mo41getVector.lanewise(VectorOperators.AND, PanamaTensorOperations.Q4_BYTE_MASK_128).sub(PanamaTensorOperations.Q4_BYTE_SUB_128).convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0).mul(castShape).add(mo41getVector.lanewise(VectorOperators.ASHR, PanamaTensorOperations.Q4_BYTE_SHIFT_128).lanewise(VectorOperators.AND, PanamaTensorOperations.Q4_BYTE_MASK_128).sub(PanamaTensorOperations.Q4_BYTE_SUB_128).convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0).mul(castShape2)).convertShape(VectorOperators.S2F, FloatVector.SPECIES_512, 0);
                        Vector convertShape3 = mo41getVector2.lanewise(VectorOperators.AND, PanamaTensorOperations.Q4_BYTE_MASK_128).sub(PanamaTensorOperations.Q4_BYTE_SUB_128).convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0).mul(castShape).add(mo41getVector2.lanewise(VectorOperators.ASHR, PanamaTensorOperations.Q4_BYTE_SHIFT_128).lanewise(VectorOperators.AND, PanamaTensorOperations.Q4_BYTE_MASK_128).sub(PanamaTensorOperations.Q4_BYTE_SUB_128).convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0).mul(castShape2)).convertShape(VectorOperators.S2F, FloatVector.SPECIES_512, 0);
                        Vector convertShape4 = mo41getVector3.lanewise(VectorOperators.AND, PanamaTensorOperations.Q4_BYTE_MASK_128).sub(PanamaTensorOperations.Q4_BYTE_SUB_128).convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0).mul(castShape).add(mo41getVector3.lanewise(VectorOperators.ASHR, PanamaTensorOperations.Q4_BYTE_SHIFT_128).lanewise(VectorOperators.AND, PanamaTensorOperations.Q4_BYTE_MASK_128).sub(PanamaTensorOperations.Q4_BYTE_SUB_128).convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0).mul(castShape2)).convertShape(VectorOperators.S2F, FloatVector.SPECIES_512, 0);
                        Vector convertShape5 = mo41getVector4.lanewise(VectorOperators.AND, PanamaTensorOperations.Q4_BYTE_MASK_128).sub(PanamaTensorOperations.Q4_BYTE_SUB_128).convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0).mul(castShape).add(mo41getVector4.lanewise(VectorOperators.ASHR, PanamaTensorOperations.Q4_BYTE_SHIFT_128).lanewise(VectorOperators.AND, PanamaTensorOperations.Q4_BYTE_MASK_128).sub(PanamaTensorOperations.Q4_BYTE_SUB_128).convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0).mul(castShape2)).convertShape(VectorOperators.S2F, FloatVector.SPECIES_512, 0);
                        zero = broadcast.fma(convertShape2, zero);
                        zero2 = broadcast2.fma(convertShape3, zero2);
                        zero3 = broadcast3.fma(convertShape4, zero3);
                        zero4 = broadcast4.fma(convertShape5, zero4);
                        i6++;
                        i2 += 32;
                        i3 += 32;
                    }
                    i4 = i5 + FloatVector.SPECIES_512.length();
                }
            };
        }

        protected BiIntConsumer initMatmul3x4() {
            return (i, i2) -> {
                int i = this.aColumnOffset;
                int i2 = this.bColumnOffset;
                FloatVector zero = FloatVector.zero(FloatVector.SPECIES_512);
                FloatVector zero2 = FloatVector.zero(FloatVector.SPECIES_512);
                FloatVector zero3 = FloatVector.zero(FloatVector.SPECIES_512);
                FloatVector zero4 = FloatVector.zero(FloatVector.SPECIES_512);
                int i3 = 0;
                while (i3 < this.k) {
                    float factorForIndex = this.a.getFactorForIndex(i + 0, i);
                    float factorForIndex2 = this.a.getFactorForIndex(i + 1, i);
                    float factorForIndex3 = this.b.getFactorForIndex(i2 + 0, i2);
                    float factorForIndex4 = this.b.getFactorForIndex(i2 + 1, i2);
                    FloatVector broadcast = FloatVector.broadcast(FloatVector.SPECIES_512, factorForIndex * factorForIndex3);
                    FloatVector broadcast2 = FloatVector.broadcast(FloatVector.SPECIES_512, factorForIndex * factorForIndex4);
                    FloatVector broadcast3 = FloatVector.broadcast(FloatVector.SPECIES_512, factorForIndex2 * factorForIndex3);
                    FloatVector broadcast4 = FloatVector.broadcast(FloatVector.SPECIES_512, factorForIndex2 * factorForIndex4);
                    Vector convertShape = this.a.mo41getVector(ByteVector.SPECIES_256, i + 0, i).convertShape(VectorOperators.B2S, ShortVector.SPECIES_512, 0);
                    Vector convertShape2 = this.a.mo41getVector(ByteVector.SPECIES_256, i + 1, i).convertShape(VectorOperators.B2S, ShortVector.SPECIES_512, 0);
                    Vector castShape = convertShape.castShape(ShortVector.SPECIES_256, 0);
                    Vector castShape2 = convertShape.castShape(ShortVector.SPECIES_256, 1);
                    Vector castShape3 = convertShape2.castShape(ShortVector.SPECIES_256, 0);
                    Vector castShape4 = convertShape2.castShape(ShortVector.SPECIES_256, 1);
                    ByteVector mo41getVector = this.b.mo41getVector(ByteVector.SPECIES_128, i2 + 0, i2);
                    ByteVector mo41getVector2 = this.b.mo41getVector(ByteVector.SPECIES_128, i2 + 1, i2);
                    Vector convertShape3 = mo41getVector.lanewise(VectorOperators.AND, PanamaTensorOperations.Q4_BYTE_MASK_128).sub(PanamaTensorOperations.Q4_BYTE_SUB_128).convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0);
                    Vector convertShape4 = mo41getVector.lanewise(VectorOperators.ASHR, PanamaTensorOperations.Q4_BYTE_SHIFT_128).lanewise(VectorOperators.AND, PanamaTensorOperations.Q4_BYTE_MASK_128).sub(PanamaTensorOperations.Q4_BYTE_SUB_128).convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0);
                    Vector convertShape5 = mo41getVector2.lanewise(VectorOperators.AND, PanamaTensorOperations.Q4_BYTE_MASK_128).sub(PanamaTensorOperations.Q4_BYTE_SUB_128).convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0);
                    Vector convertShape6 = mo41getVector2.lanewise(VectorOperators.ASHR, PanamaTensorOperations.Q4_BYTE_SHIFT_128).lanewise(VectorOperators.AND, PanamaTensorOperations.Q4_BYTE_MASK_128).sub(PanamaTensorOperations.Q4_BYTE_SUB_128).convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0);
                    Vector convertShape7 = convertShape3.mul(castShape).add(convertShape4.mul(castShape2)).convertShape(VectorOperators.S2F, FloatVector.SPECIES_512, 0);
                    Vector convertShape8 = convertShape5.mul(castShape).add(convertShape6.mul(castShape2)).convertShape(VectorOperators.S2F, FloatVector.SPECIES_512, 0);
                    Vector convertShape9 = convertShape3.mul(castShape3).add(convertShape4.mul(castShape4)).convertShape(VectorOperators.S2F, FloatVector.SPECIES_512, 0);
                    Vector convertShape10 = convertShape5.mul(castShape3).add(convertShape6.mul(castShape4)).convertShape(VectorOperators.S2F, FloatVector.SPECIES_512, 0);
                    zero = broadcast.fma(convertShape7, zero);
                    zero2 = broadcast2.fma(convertShape8, zero2);
                    zero3 = broadcast3.fma(convertShape9, zero3);
                    zero4 = broadcast4.fma(convertShape10, zero4);
                    i3 += 32;
                    i += 32;
                    i2 += 32;
                }
                float reduceLanes = zero.reduceLanes(VectorOperators.ADD);
                float reduceLanes2 = zero2.reduceLanes(VectorOperators.ADD);
                float reduceLanes3 = zero3.reduceLanes(VectorOperators.ADD);
                float reduceLanes4 = zero4.reduceLanes(VectorOperators.ADD);
                this.c.set(reduceLanes, i + 0, i2 + 0 + this.rOffset);
                this.c.set(reduceLanes2, i + 0, i2 + 1 + this.rOffset);
                this.c.set(reduceLanes3, i + 1, i2 + 0 + this.rOffset);
                this.c.set(reduceLanes4, i + 1, i2 + 1 + this.rOffset);
            };
        }
    }

    /**
     * Int8 x Q4 GEMM kernel for 128-bit vector units ({@code ARM_128}).
     *
     * <p>{@code a} is a {@link Q8ByteBufferTensor} and {@code b} a {@link Q4ByteBufferTensor}.
     * Only the 1x1 kernel is implemented. {@code matmul1x4}, {@code matmul3x4} and
     * {@code matmul4x1} are left {@code null}, and {@link #pickKernel} always dispatches to
     * {@code matmul1x1}.
     */
    private class GemmerI8Q4_arm extends Gemmer {
        final BiIntConsumer matmul1x1;
        final BiIntConsumer matmul1x4;
        final BiIntConsumer matmul3x4;
        final BiIntConsumer matmul4x1;
        final Q8ByteBufferTensor a;
        final Q4ByteBufferTensor b;

        GemmerI8Q4_arm(PanamaTensorOperations panamaTensorOperations, int i, AbstractTensor abstractTensor, AbstractTensor abstractTensor2, AbstractTensor abstractTensor3, int i2, int i3, int i4) {
            super(panamaTensorOperations, i, abstractTensor, abstractTensor2, abstractTensor3, i2, i3, i4);
            this.a = (Q8ByteBufferTensor) abstractTensor;
            this.b = (Q4ByteBufferTensor) abstractTensor2;
            this.matmul1x1 = initMatmul1x1();
            this.matmul1x4 = null;
            this.matmul3x4 = null;
            this.matmul4x1 = null;
        }

        /**
         * Runs the 1x1 kernel over the given spans.
         *
         * @return the tile footprint passed to {@code kernel}, encoded as {@code (rows << 4) | cols}
         *     (always 1x1 here)
         */
        @Override // com.github.tjake.jlama.tensor.operations.PanamaTensorOperations.Gemmer
        protected int pickKernel(int i, int i2, int i3, int i4) {
            kernel(i, i2, 1, i3, i4, 1, this.matmul1x1);
            return (1 << 4) | 1;
        }

        /**
         * Builds the 1x1 kernel: the dot product of row {@code i} of {@code a} with row {@code j}
         * of {@code b} over {@code k} columns, written to {@code c[i][j + rOffset]}.
         *
         * <p>Per-block scale factors are fetched {@code SPECIES_128.length()} blocks at a time.
         * Leftover blocks, when the block count is not a multiple of that width, use scalar
         * {@code getFactorForIndex} lookups instead of running past {@code k}. This assumes
         * {@code getFactorForIndex(row, col)} returns the factor of the 32-column block containing
         * {@code col}.
         */
        protected BiIntConsumer initMatmul1x1() {
            return (i, j) -> {
                final int lanes = FloatVector.SPECIES_128.length();
                final int numBlocks = this.k / 32;
                final int vectorBlocks = numBlocks - numBlocks % lanes;
                int ao = this.aColumnOffset;
                int bo = this.bColumnOffset;
                FloatVector acc = FloatVector.zero(FloatVector.SPECIES_128);
                int block = 0;
                for (; block < vectorBlocks; block += lanes) {
                    // Exact integer block index (was (int) (0.03125f * offset), inexact past 2^24).
                    FloatVector scales = this.a.getBlockF().mo41getVector(FloatVector.SPECIES_128, i, ao / 32)
                            .mul(this.b.getBlockF().mo41getVector(FloatVector.SPECIES_128, j, bo / 32));
                    for (int lane = 0; lane < lanes; lane++) {
                        acc = accumulateQ8Q4Block128(acc, scales.lane(lane), i, j, ao, bo);
                        ao += 32;
                        bo += 32;
                    }
                }
                // Tail: fewer blocks remain than one scale vector covers.
                for (; block < numBlocks; block++) {
                    acc = accumulateQ8Q4Block128(acc, this.a.getFactorForIndex(i, ao) * this.b.getFactorForIndex(j, bo), i, j, ao, bo);
                    ao += 32;
                    bo += 32;
                }
                this.c.set(acc.reduceLanes(VectorOperators.ADD), i, j + this.rOffset);
            };
        }

        /**
         * Adds one 32-column block product, scaled by {@code scale}, to {@code acc}. The block is
         * row {@code i} of {@code a} at column {@code ao} and row {@code j} of {@code b} at column
         * {@code bo}.
         */
        private FloatVector accumulateQ8Q4Block128(FloatVector acc, float scale, int i, int j, int ao, int bo) {
            FloatVector factor = FloatVector.broadcast(FloatVector.SPECIES_128, scale);
            ByteVector aFront = this.a.mo41getVector(ByteVector.SPECIES_128, i, ao);
            ByteVector aBack = this.a.mo41getVector(ByteVector.SPECIES_128, i, ao + 16);
            Vector<Short> a0 = aFront.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
            Vector<Short> a1 = aFront.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 1);
            Vector<Short> a2 = aBack.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
            Vector<Short> a3 = aBack.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 1);
            ByteVector packedFront = this.b.mo41getVector(ByteVector.SPECIES_64, j, bo);
            ByteVector packedBack = this.b.mo41getVector(ByteVector.SPECIES_64, j, bo + 16);
            // Low nibbles pair with a's first 16 values (a0, a1), high nibbles with its last 16 (a2, a3).
            Vector<Short> b0 = packedFront.lanewise(VectorOperators.AND, PanamaTensorOperations.Q4_BYTE_MASK_64)
                    .sub(PanamaTensorOperations.Q4_BYTE_SUB_64)
                    .castShape(ShortVector.SPECIES_128, 0);
            Vector<Short> b1 = packedBack.lanewise(VectorOperators.AND, PanamaTensorOperations.Q4_BYTE_MASK_64)
                    .sub(PanamaTensorOperations.Q4_BYTE_SUB_64)
                    .castShape(ShortVector.SPECIES_128, 0);
            Vector<Short> b2 = packedFront.lanewise(VectorOperators.ASHR, PanamaTensorOperations.Q4_BYTE_SHIFT_64)
                    .lanewise(VectorOperators.AND, PanamaTensorOperations.Q4_BYTE_MASK_64)
                    .sub(PanamaTensorOperations.Q4_BYTE_SUB_64)
                    .castShape(ShortVector.SPECIES_128, 0);
            Vector<Short> b3 = packedBack.lanewise(VectorOperators.ASHR, PanamaTensorOperations.Q4_BYTE_SHIFT_64)
                    .lanewise(VectorOperators.AND, PanamaTensorOperations.Q4_BYTE_MASK_64)
                    .sub(PanamaTensorOperations.Q4_BYTE_SUB_64)
                    .castShape(ShortVector.SPECIES_128, 0);
            ShortVector dots = ShortVector.zero(ShortVector.SPECIES_128)
                    .add(a0.mul(b0))
                    .add(a1.mul(b1))
                    .add(a2.mul(b2))
                    .add(a3.mul(b3));
            return acc.add(dots.convertShape(VectorOperators.S2F, FloatVector.SPECIES_128, 0).mul(factor))
                    .add(dots.convertShape(VectorOperators.S2F, FloatVector.SPECIES_128, 1).mul(factor));
        }
    }

    /**
     * Creates a Vector API backed operations instance.
     *
     * @param type vector unit flavour; {@code batchDotProduct} and {@code quantize} switch on it
     *     to select their kernels
     */
    public PanamaTensorOperations(MachineSpec.Type type) {
        vectorType = type;
    }

    /** Returns the fixed human-readable label of this backend. */
    @Override // com.github.tjake.jlama.tensor.operations.TensorOperations
    public String name() {
        final String label = "Panama Vector Operations";
        return label;
    }

    /** Returns the number of parallel chunks to use: the core count of the shared physical-core executor. */
    @Override // com.github.tjake.jlama.tensor.operations.TensorOperations
    public int parallelSplitSize() {
        PhysicalCoreExecutor executor = PhysicalCoreExecutor.instance.get();
        return executor.getCoreCount();
    }

    /**
     * Computes row-by-row dot products of {@code abstractTensor2} (a) and {@code abstractTensor3}
     * (b) into {@code abstractTensor} (c), picking a kernel by the operand dtypes and vector type.
     *
     * <p>Supported combinations are F32×F32, F32×BF16, F32×Q4 (AVX_256/AVX_512), BF16×BF16 and
     * I8×Q4 (AVX_256/AVX_512/ARM_128). The nested switches are expressions, so each supported
     * combination yields its gemmer. The previous statement switches fell through into the next
     * {@code default}/{@code case} and threw even for supported pairs.
     *
     * @param abstractTensor result tensor; must be 2-D with the same row count as {@code abstractTensor2}
     * @param abstractTensor2 left operand; every row of it is processed
     * @param abstractTensor3 right operand
     * @param i passed to the gemmer (presumably a's column offset — confirm against {@code Gemmer})
     * @param i2 passed to the gemmer (presumably b's column offset — confirm against {@code Gemmer})
     * @param i3 passed to the gemmer as its size argument (presumably the column count {@code k})
     * @param i4 result offset; must be 0 or at least {@code i5}
     * @param i5 first row of {@code abstractTensor3} to process
     * @param i6 number of rows of {@code abstractTensor3} to process
     * @throws IllegalArgumentException if a precondition on ranks, row counts or offsets fails
     * @throws UnsupportedOperationException for unsupported dtype or vector-type combinations
     */
    @Override // com.github.tjake.jlama.tensor.operations.TensorOperations
    public void batchDotProduct(AbstractTensor abstractTensor, AbstractTensor abstractTensor2, AbstractTensor abstractTensor3, int i, int i2, int i3, int i4, int i5, int i6) {
        Preconditions.checkArgument(abstractTensor2.dims() == 2 && abstractTensor3.dims() == 2 && abstractTensor.dims() == 2);
        Preconditions.checkArgument(abstractTensor2.shape().dim(0) == abstractTensor.shape().dim(0), "BAD M");
        Preconditions.checkArgument(i4 == 0 || i4 >= i5, "Result offset must be >= b row offset");
        int dim = abstractTensor2.shape().dim(0);
        Gemmer gemmer = switch (abstractTensor2.dType()) {
            case F32 -> switch (abstractTensor3.dType()) {
                case F32 -> new GemmerF32(this, i3, abstractTensor2, abstractTensor3, abstractTensor, i, i2, i4);
                case BF16 -> new GemmerF32BF16(this, i3, abstractTensor2, abstractTensor3, abstractTensor, i, i2, i4);
                case Q4 -> switch (this.vectorType) {
                    case AVX_256 -> new GemmerF32Q4_256(this, i3, abstractTensor2, abstractTensor3, abstractTensor, i, i2, i4);
                    case AVX_512 -> new GemmerF32Q4_512(this, i3, abstractTensor2, abstractTensor3, abstractTensor, i, i2, i4);
                    default -> throw new UnsupportedOperationException(this.vectorType.name());
                };
                default -> throw new UnsupportedOperationException(abstractTensor3.dType().name());
            };
            case BF16 -> switch (abstractTensor3.dType()) {
                case BF16 -> new GemmerBF16(this, i3, abstractTensor2, abstractTensor3, abstractTensor, i, i2, i4);
                default -> throw new UnsupportedOperationException(abstractTensor3.dType().name());
            };
            case I8 -> switch (abstractTensor3.dType()) {
                case Q4 -> switch (this.vectorType) {
                    case AVX_256 -> new GemmerI8Q4_256(this, i3, abstractTensor2, abstractTensor3, abstractTensor, i, i2, i4);
                    case AVX_512 -> new GemmerI8Q4_512(this, i3, abstractTensor2, abstractTensor3, abstractTensor, i, i2, i4);
                    case ARM_128 -> new GemmerI8Q4_arm(this, i3, abstractTensor2, abstractTensor3, abstractTensor, i, i2, i4);
                    default -> throw new UnsupportedOperationException(this.vectorType.name());
                };
                default -> throw new UnsupportedOperationException(abstractTensor3.dType().name());
            };
            default -> throw new UnsupportedOperationException(abstractTensor2.dType().name() + " " + abstractTensor3.dType().name());
        };
        gemmer.matmul(0, dim, i5, i5 + i6);
    }

    /**
     * Converts a 2-D tensor to {@code dType} for columns {@code [i, i + i2)} of every row.
     *
     * <p>Supported conversions are F32 to BF16, F32 to I8, BF16 to F32 and BF16 to I8. The I8 targets
     * pick a SIMD kernel according to {@code vectorType}.
     *
     * @param abstractTensor source tensor; must be 2-D
     * @param dType target element type
     * @param i first column to convert
     * @param i2 number of columns to convert; must be a multiple of 32
     * @return the converted tensor
     * @throws IllegalArgumentException if the tensor is not 2-D or {@code i2 % 32 != 0}
     * @throws UnsupportedOperationException for an unsupported type pair or vector width
     */
    @Override // com.github.tjake.jlama.tensor.operations.TensorOperations
    public AbstractTensor quantize(AbstractTensor abstractTensor, DType dType, int i, int i2) {
        Preconditions.checkArgument(abstractTensor.dims() == 2 && i2 % 32 == 0);
        return switch (abstractTensor.dType()) {
            case F32 -> switch (dType) {
                case BF16 -> quantizeBF16((FloatBufferTensor) abstractTensor, i, i2);
                case I8 -> switch (this.vectorType) {
                    case AVX_256 -> quantizeQ8_256((FloatBufferTensor) abstractTensor, i, i2);
                    case AVX_512 -> quantizeQ8_512((FloatBufferTensor) abstractTensor, i, i2);
                    case ARM_128 -> quantizeQ8_arm((FloatBufferTensor) abstractTensor, i, i2);
                    default -> throw new UnsupportedOperationException();
                };
                default -> throw new UnsupportedOperationException("F32 => " + String.valueOf(dType));
            };
            case BF16 -> switch (dType) {
                case F32 -> quantizeBF16_F32((BFloat16BufferTensor) abstractTensor, i, i2);
                case I8 -> switch (this.vectorType) {
                    case AVX_256 -> quantizeBF16_Q8_256((BFloat16BufferTensor) abstractTensor, i, i2);
                    case AVX_512 -> quantizeBF16_Q8_512((BFloat16BufferTensor) abstractTensor, i, i2);
                    case ARM_128 -> quantizeBF16_Q8_arm((BFloat16BufferTensor) abstractTensor, i, i2);
                    default -> throw new UnsupportedOperationException();
                };
                default -> throw new UnsupportedOperationException("BF16 => " + String.valueOf(dType));
            };
            default -> throw new UnsupportedOperationException(String.valueOf(abstractTensor.dType()));
        };
    }

    /**
     * Converts an F32 tensor to BF16 by delegating to the {@link BFloat16BufferTensor} constructor.
     *
     * <p>The column range arguments {@code i} and {@code i2} are not used by this conversion.
     *
     * @param floatBufferTensor source F32 tensor
     * @param i unused
     * @param i2 unused
     * @return a new BF16 tensor built from {@code floatBufferTensor}
     */
    public BFloat16BufferTensor quantizeBF16(FloatBufferTensor floatBufferTensor, int i, int i2) {
        BFloat16BufferTensor converted = new BFloat16BufferTensor(floatBufferTensor);
        return converted;
    }

    /**
     * Widens BF16 values to F32 for columns {@code [i, i + i2)} of every row.
     *
     * <p>Each short is sign-extended to an int, shifted left by 16 bits and reinterpreted as a float.
     * The sign-extension bits are shifted out, so the BF16 bits become the upper half of the F32.
     * Each step reads one preferred-width short vector and writes its low and high halves as two
     * preferred-width float vectors.
     *
     * @param bFloat16BufferTensor source BF16 tensor
     * @param i first column to convert
     * @param i2 number of columns to convert; assumed to be a multiple of the preferred short vector
     *     length, since there is no scalar tail (TODO confirm against callers)
     * @return the F32 tensor obtained from {@link TensorCache} holding the widened values
     */
    public FloatBufferTensor quantizeBF16_F32(BFloat16BufferTensor bFloat16BufferTensor, int i, int i2) {
        FloatBufferTensor floatBufferTensor = (FloatBufferTensor) TensorCache.instance.get(DType.F32, bFloat16BufferTensor.shape());
        int rowCount = bFloat16BufferTensor.shape().first();
        int shortStep = ShortVector.SPECIES_PREFERRED.length();
        int floatStep = FloatVector.SPECIES_PREFERRED.length();
        int end = i + i2;
        for (int row = 0; row < rowCount; row++) {
            // Bounded loop: it must terminate once every column in [i, i + i2) has been processed.
            for (int col = i; col < end; col += shortStep) {
                ShortVector raw = bFloat16BufferTensor.mo41getVector(ShortVector.SPECIES_PREFERRED, row, col);
                FloatVector lowHalf = raw.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0).lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT).reinterpretAsFloats();
                FloatVector highHalf = raw.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 1).lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT).reinterpretAsFloats();
                floatBufferTensor.intoTensor(lowHalf, row, col);
                floatBufferTensor.intoTensor(highHalf, row, col + floatStep);
            }
        }
        return floatBufferTensor;
    }

    /**
     * Quantizes F32 values to signed 8-bit integers using 512-bit vectors.
     *
     * <p>Columns {@code [i, i + i2)} of each row are processed in 32-value blocks, loaded as two
     * 16-lane float vectors. Each block is quantized as follows:
     * <ul>
     *   <li>The block scale is {@code absMax / 127}.
     *   <li>Values are multiplied by {@code 127 / absMax}, or by 0 for an all-zero block.
     *   <li>{@code F32_ROUND_UP_512} is added and the result is narrowed to bytes.
     *   <li>The scale is stored through {@code getBlockF()} at block index
     *       {@code (int) (col * 0.03125f)}, i.e. {@code col / 32}.
     * </ul>
     *
     * <p>A failure while storing the block scale propagates to the caller, as in the 256-bit and ARM
     * quantizers. Swallowing it would leave that block's scale unset in the returned tensor.
     *
     * @param floatBufferTensor source F32 tensor
     * @param i first column to quantize
     * @param i2 number of columns; expected to be a multiple of 32
     * @return the I8 tensor obtained from {@link TensorCache}
     */
    public Q8ByteBufferTensor quantizeQ8_512(FloatBufferTensor floatBufferTensor, int i, int i2) {
        Q8ByteBufferTensor q8ByteBufferTensor = (Q8ByteBufferTensor) TensorCache.instance.get(DType.I8, floatBufferTensor.shape());
        int rowCount = floatBufferTensor.shape().first();
        for (int row = 0; row < rowCount; row++) {
            for (int col = i; col < i + i2; col += 32) {
                FloatVector lo = floatBufferTensor.mo41getVector(FloatVector.SPECIES_512, row, col);
                FloatVector hi = floatBufferTensor.mo41getVector(FloatVector.SPECIES_512, row, col + 16);
                float absMax = lo.abs().max(hi.abs()).reduceLanes(VectorOperators.MAX);
                float blockScale = absMax / 127.0f;
                FloatVector invScale = FloatVector.broadcast(FloatVector.SPECIES_512, absMax != 0.0f ? 127.0f / absMax : 0.0f);
                FloatVector scaledLo = lo.mul(invScale).add(F32_ROUND_UP_512);
                FloatVector scaledHi = hi.mul(invScale).add(F32_ROUND_UP_512);
                ByteVector bytesLo = scaledLo.convertShape(VectorOperators.F2B, ByteVector.SPECIES_128, 0).reinterpretAsBytes();
                ByteVector bytesHi = scaledHi.convertShape(VectorOperators.F2B, ByteVector.SPECIES_128, 0).reinterpretAsBytes();
                q8ByteBufferTensor.intoTensor(bytesLo, row, col);
                q8ByteBufferTensor.intoTensor(bytesHi, row, col + 16);
                q8ByteBufferTensor.getBlockF().set(blockScale, row, (int) (col * 0.03125f));
            }
        }
        return q8ByteBufferTensor;
    }

    /**
     * Quantizes F32 values to signed 8-bit integers using 256-bit vectors.
     *
     * <p>Columns {@code [i, i + i2)} of each row are handled in 32-value blocks, loaded as four
     * 8-lane float vectors. Each block is quantized as follows:
     * <ul>
     *   <li>The block scale is {@code absMax / 127}.
     *   <li>Values are multiplied by {@code 127 / absMax}, or by 0 for an all-zero block.
     *   <li>{@code F32_ROUND_UP_256} is added and each chunk is narrowed to an 8-byte vector.
     *   <li>The scale is stored through {@code getBlockF()} at block index
     *       {@code (int) (col * 0.03125f)}.
     * </ul>
     *
     * @param floatBufferTensor source F32 tensor
     * @param i first column to quantize
     * @param i2 number of columns; expected to be a multiple of 32
     * @return the I8 tensor obtained from {@link TensorCache}
     */
    public Q8ByteBufferTensor quantizeQ8_256(FloatBufferTensor floatBufferTensor, int i, int i2) {
        final Q8ByteBufferTensor out = (Q8ByteBufferTensor) TensorCache.instance.get(DType.I8, floatBufferTensor.shape());
        final int rowCount = floatBufferTensor.shape().first();
        final int stop = i + i2;
        for (int row = 0; row < rowCount; row++) {
            for (int col = i; col < stop; col += 32) {
                // One 32-value block as four 8-lane chunks.
                FloatVector x0 = floatBufferTensor.mo41getVector(FloatVector.SPECIES_256, row, col);
                FloatVector x1 = floatBufferTensor.mo41getVector(FloatVector.SPECIES_256, row, col + 8);
                FloatVector x2 = floatBufferTensor.mo41getVector(FloatVector.SPECIES_256, row, col + 16);
                FloatVector x3 = floatBufferTensor.mo41getVector(FloatVector.SPECIES_256, row, col + 24);

                FloatVector max01 = x0.abs().max(x1.abs());
                FloatVector max23 = x2.abs().max(x3.abs());
                float absMax = max01.max(max23).reduceLanes(VectorOperators.MAX);
                float blockScale = absMax / 127.0f;
                float invScale = absMax == 0.0f ? 0.0f : 127.0f / absMax;
                FloatVector k = FloatVector.broadcast(FloatVector.SPECIES_256, invScale);

                out.intoTensor(x0.mul(k).add(F32_ROUND_UP_256).convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes(), row, col);
                out.intoTensor(x1.mul(k).add(F32_ROUND_UP_256).convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes(), row, col + 8);
                out.intoTensor(x2.mul(k).add(F32_ROUND_UP_256).convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes(), row, col + 16);
                out.intoTensor(x3.mul(k).add(F32_ROUND_UP_256).convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes(), row, col + 24);
                out.getBlockF().set(blockScale, row, (int) (col * 0.03125f));
            }
        }
        return out;
    }

    /**
     * Quantizes F32 values to signed 8-bit integers using 128-bit vectors (the ARM path).
     *
     * <p>Columns {@code [i, i + i2)} of each row are handled in 32-value blocks, loaded as eight
     * 4-lane float vectors. Each block is quantized as follows:
     * <ul>
     *   <li>The block scale is {@code absMax / 127}.
     *   <li>Values are multiplied by {@code 127 / absMax}, or by 0 for an all-zero block.
     *   <li>{@code F32_ROUND_UP_128} is added and each chunk is narrowed into a 64-bit byte vector.
     *   <li>The scale is stored through {@code getBlockF()} at block index
     *       {@code (int) (col * 0.03125f)}.
     * </ul>
     *
     * <p>Each narrowed vector carries 4 converted bytes, and stores advance by 4 columns. Every store
     * is masked with {@code BYTE_MASK_32}, presumably selecting those first 4 lanes (confirm at its
     * declaration).
     *
     * @param floatBufferTensor source F32 tensor
     * @param i first column to quantize
     * @param i2 number of columns; expected to be a multiple of 32
     * @return the I8 tensor obtained from {@link TensorCache}
     */
    public Q8ByteBufferTensor quantizeQ8_arm(FloatBufferTensor floatBufferTensor, int i, int i2) {
        final Q8ByteBufferTensor out = (Q8ByteBufferTensor) TensorCache.instance.get(DType.I8, floatBufferTensor.shape());
        final int rowCount = floatBufferTensor.shape().first();
        final int stop = i + i2;
        for (int row = 0; row < rowCount; row++) {
            for (int col = i; col < stop; col += 32) {
                FloatVector x0 = floatBufferTensor.mo41getVector(FloatVector.SPECIES_128, row, col + 0);
                FloatVector x1 = floatBufferTensor.mo41getVector(FloatVector.SPECIES_128, row, col + 4);
                FloatVector x2 = floatBufferTensor.mo41getVector(FloatVector.SPECIES_128, row, col + 8);
                FloatVector x3 = floatBufferTensor.mo41getVector(FloatVector.SPECIES_128, row, col + 12);
                FloatVector x4 = floatBufferTensor.mo41getVector(FloatVector.SPECIES_128, row, col + 16);
                FloatVector x5 = floatBufferTensor.mo41getVector(FloatVector.SPECIES_128, row, col + 20);
                FloatVector x6 = floatBufferTensor.mo41getVector(FloatVector.SPECIES_128, row, col + 24);
                FloatVector x7 = floatBufferTensor.mo41getVector(FloatVector.SPECIES_128, row, col + 28);

                // Pairwise max of the absolute values, reduced to a single block maximum.
                FloatVector max01 = x0.abs().max(x1.abs());
                FloatVector max23 = x2.abs().max(x3.abs());
                FloatVector max45 = x4.abs().max(x5.abs());
                FloatVector max67 = x6.abs().max(x7.abs());
                float absMax = max01.max(max23).max(max45.max(max67)).reduceLanes(VectorOperators.MAX);
                float blockScale = absMax / 127.0f;
                float invScale = absMax == 0.0f ? 0.0f : 127.0f / absMax;
                FloatVector k = FloatVector.broadcast(FloatVector.SPECIES_128, invScale);

                out.intoTensor(x0.mul(k).add(F32_ROUND_UP_128).convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes(), BYTE_MASK_32, row, col + 0);
                out.intoTensor(x1.mul(k).add(F32_ROUND_UP_128).convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes(), BYTE_MASK_32, row, col + 4);
                out.intoTensor(x2.mul(k).add(F32_ROUND_UP_128).convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes(), BYTE_MASK_32, row, col + 8);
                out.intoTensor(x3.mul(k).add(F32_ROUND_UP_128).convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes(), BYTE_MASK_32, row, col + 12);
                out.intoTensor(x4.mul(k).add(F32_ROUND_UP_128).convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes(), BYTE_MASK_32, row, col + 16);
                out.intoTensor(x5.mul(k).add(F32_ROUND_UP_128).convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes(), BYTE_MASK_32, row, col + 20);
                out.intoTensor(x6.mul(k).add(F32_ROUND_UP_128).convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes(), BYTE_MASK_32, row, col + 24);
                out.intoTensor(x7.mul(k).add(F32_ROUND_UP_128).convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes(), BYTE_MASK_32, row, col + 28);
                out.getBlockF().set(blockScale, row, (int) (col * 0.03125f));
            }
        }
        return out;
    }

    /**
     * Quantizes BF16 values to signed 8-bit integers using 512-bit vectors.
     *
     * <p>Columns {@code [i, i + i2)} of each row are handled in 32-value blocks, loaded as one
     * 32-lane short vector. Each block is processed as follows:
     * <ul>
     *   <li>The shorts are widened to two 16-lane float vectors: sign-extend to int, shift left
     *       16 bits, reinterpret as float.
     *   <li>The block scale is {@code absMax / 127}.
     *   <li>Values are multiplied by {@code 127 / absMax}, or by 0 for an all-zero block.
     *   <li>{@code F32_ROUND_UP_512} is added and the result is narrowed to bytes.
     *   <li>The scale is stored through {@code getBlockF()} at block index
     *       {@code (int) (col * 0.03125f)}.
     * </ul>
     *
     * <p>A failure while storing the block scale propagates to the caller, as in the 256-bit and ARM
     * quantizers. Swallowing it would leave that block's scale unset in the returned tensor.
     *
     * @param bFloat16BufferTensor source BF16 tensor
     * @param i first column to quantize
     * @param i2 number of columns; expected to be a multiple of 32
     * @return the I8 tensor obtained from {@link TensorCache}
     */
    public Q8ByteBufferTensor quantizeBF16_Q8_512(BFloat16BufferTensor bFloat16BufferTensor, int i, int i2) {
        Q8ByteBufferTensor q8ByteBufferTensor = (Q8ByteBufferTensor) TensorCache.instance.get(DType.I8, bFloat16BufferTensor.shape());
        int rowCount = bFloat16BufferTensor.shape().first();
        for (int row = 0; row < rowCount; row++) {
            for (int col = i; col < i + i2; col += 32) {
                ShortVector raw = bFloat16BufferTensor.mo41getVector(ShortVector.SPECIES_512, row, col);
                FloatVector lo = raw.convertShape(VectorOperators.S2I, IntVector.SPECIES_512, 0).lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_512).reinterpretAsFloats();
                FloatVector hi = raw.convertShape(VectorOperators.S2I, IntVector.SPECIES_512, 1).lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_512).reinterpretAsFloats();
                float absMax = lo.abs().max(hi.abs()).reduceLanes(VectorOperators.MAX);
                float blockScale = absMax / 127.0f;
                FloatVector invScale = FloatVector.broadcast(FloatVector.SPECIES_512, absMax != 0.0f ? 127.0f / absMax : 0.0f);
                FloatVector scaledLo = lo.mul(invScale).add(F32_ROUND_UP_512);
                FloatVector scaledHi = hi.mul(invScale).add(F32_ROUND_UP_512);
                ByteVector bytesLo = scaledLo.convertShape(VectorOperators.F2B, ByteVector.SPECIES_128, 0).reinterpretAsBytes();
                ByteVector bytesHi = scaledHi.convertShape(VectorOperators.F2B, ByteVector.SPECIES_128, 0).reinterpretAsBytes();
                q8ByteBufferTensor.intoTensor(bytesLo, row, col);
                q8ByteBufferTensor.intoTensor(bytesHi, row, col + 16);
                q8ByteBufferTensor.getBlockF().set(blockScale, row, (int) (col * 0.03125f));
            }
        }
        return q8ByteBufferTensor;
    }

    /**
     * Quantizes BF16 values to signed 8-bit integers using 256-bit vectors.
     *
     * <p>Columns {@code [i, i + i2)} of each row are handled in 32-value blocks, loaded as two
     * 16-lane short vectors. Each block is processed as follows:
     * <ul>
     *   <li>Each short vector is widened into two 8-lane float vectors: sign-extend to int, shift
     *       left 16 bits, reinterpret as float.
     *   <li>The block scale is {@code absMax / 127}.
     *   <li>Values are multiplied by {@code 127 / absMax}, or by 0 for an all-zero block.
     *   <li>{@code F32_ROUND_UP_256} is added and each chunk is narrowed to 8 bytes.
     *   <li>The scale is stored through {@code getBlockF()} at block index
     *       {@code (int) (col * 0.03125f)}.
     * </ul>
     *
     * @param bFloat16BufferTensor source BF16 tensor
     * @param i first column to quantize
     * @param i2 number of columns; expected to be a multiple of 32
     * @return the I8 tensor obtained from {@link TensorCache}
     */
    public Q8ByteBufferTensor quantizeBF16_Q8_256(BFloat16BufferTensor bFloat16BufferTensor, int i, int i2) {
        final Q8ByteBufferTensor out = (Q8ByteBufferTensor) TensorCache.instance.get(DType.I8, bFloat16BufferTensor.shape());
        final int rowCount = bFloat16BufferTensor.shape().first();
        final int stop = i + i2;
        for (int row = 0; row < rowCount; row++) {
            for (int col = i; col < stop; col += 32) {
                ShortVector firstHalf = bFloat16BufferTensor.mo41getVector(ShortVector.SPECIES_256, row, col);
                FloatVector x0 = firstHalf.convertShape(VectorOperators.S2I, IntVector.SPECIES_256, 0).lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_256).reinterpretAsFloats();
                FloatVector x1 = firstHalf.convertShape(VectorOperators.S2I, IntVector.SPECIES_256, 1).lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_256).reinterpretAsFloats();
                ShortVector secondHalf = bFloat16BufferTensor.mo41getVector(ShortVector.SPECIES_256, row, col + 16);
                FloatVector x2 = secondHalf.convertShape(VectorOperators.S2I, IntVector.SPECIES_256, 0).lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_256).reinterpretAsFloats();
                FloatVector x3 = secondHalf.convertShape(VectorOperators.S2I, IntVector.SPECIES_256, 1).lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_256).reinterpretAsFloats();

                FloatVector max01 = x0.abs().max(x1.abs());
                FloatVector max23 = x2.abs().max(x3.abs());
                float absMax = max01.max(max23).reduceLanes(VectorOperators.MAX);
                float blockScale = absMax / 127.0f;
                float invScale = absMax == 0.0f ? 0.0f : 127.0f / absMax;
                FloatVector k = FloatVector.broadcast(FloatVector.SPECIES_256, invScale);

                out.intoTensor(x0.mul(k).add(F32_ROUND_UP_256).convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes(), row, col);
                out.intoTensor(x1.mul(k).add(F32_ROUND_UP_256).convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes(), row, col + 8);
                out.intoTensor(x2.mul(k).add(F32_ROUND_UP_256).convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes(), row, col + 16);
                out.intoTensor(x3.mul(k).add(F32_ROUND_UP_256).convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes(), row, col + 24);
                out.getBlockF().set(blockScale, row, (int) (col * 0.03125f));
            }
        }
        return out;
    }

    /**
     * Quantizes BF16 values to signed 8-bit integers using 128-bit vectors (the ARM path).
     *
     * <p>Columns {@code [i, i + i2)} of each row are handled in 32-value blocks, loaded as four
     * 8-lane short vectors. Each block is processed as follows:
     * <ul>
     *   <li>Each short vector is widened into two 4-lane float vectors: sign-extend to int, shift
     *       left 16 bits, reinterpret as float.
     *   <li>The block scale is {@code absMax / 127}.
     *   <li>Values are multiplied by {@code 127 / absMax}, or by 0 for an all-zero block.
     *   <li>{@code F32_ROUND_UP_128} is added and each chunk is narrowed into a 64-bit byte vector.
     *   <li>The scale is stored through {@code getBlockF()} at block index
     *       {@code (int) (col * 0.03125f)}.
     * </ul>
     *
     * <p>Each narrowed vector carries 4 converted bytes, and stores advance by 4 columns. Every store
     * is masked with {@code BYTE_MASK_32}, presumably selecting those first 4 lanes (confirm at its
     * declaration).
     *
     * @param bFloat16BufferTensor source BF16 tensor
     * @param i first column to quantize
     * @param i2 number of columns; expected to be a multiple of 32
     * @return the I8 tensor obtained from {@link TensorCache}
     */
    public Q8ByteBufferTensor quantizeBF16_Q8_arm(BFloat16BufferTensor bFloat16BufferTensor, int i, int i2) {
        final Q8ByteBufferTensor out = (Q8ByteBufferTensor) TensorCache.instance.get(DType.I8, bFloat16BufferTensor.shape());
        final int rowCount = bFloat16BufferTensor.shape().first();
        final int stop = i + i2;
        for (int row = 0; row < rowCount; row++) {
            for (int col = i; col < stop; col += 32) {
                ShortVector s0 = bFloat16BufferTensor.mo41getVector(ShortVector.SPECIES_128, row, col);
                FloatVector x0 = s0.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0).lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_128).reinterpretAsFloats();
                FloatVector x1 = s0.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1).lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_128).reinterpretAsFloats();
                ShortVector s1 = bFloat16BufferTensor.mo41getVector(ShortVector.SPECIES_128, row, col + 8);
                FloatVector x2 = s1.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0).lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_128).reinterpretAsFloats();
                FloatVector x3 = s1.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1).lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_128).reinterpretAsFloats();
                ShortVector s2 = bFloat16BufferTensor.mo41getVector(ShortVector.SPECIES_128, row, col + 16);
                FloatVector x4 = s2.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0).lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_128).reinterpretAsFloats();
                FloatVector x5 = s2.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1).lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_128).reinterpretAsFloats();
                ShortVector s3 = bFloat16BufferTensor.mo41getVector(ShortVector.SPECIES_128, row, col + 24);
                FloatVector x6 = s3.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0).lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_128).reinterpretAsFloats();
                FloatVector x7 = s3.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1).lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_128).reinterpretAsFloats();

                // Pairwise max of the absolute values, reduced to a single block maximum.
                FloatVector max01 = x0.abs().max(x1.abs());
                FloatVector max23 = x2.abs().max(x3.abs());
                FloatVector max45 = x4.abs().max(x5.abs());
                FloatVector max67 = x6.abs().max(x7.abs());
                float absMax = max01.max(max23).max(max45.max(max67)).reduceLanes(VectorOperators.MAX);
                float blockScale = absMax / 127.0f;
                float invScale = absMax == 0.0f ? 0.0f : 127.0f / absMax;
                FloatVector k = FloatVector.broadcast(FloatVector.SPECIES_128, invScale);

                out.intoTensor(x0.mul(k).add(F32_ROUND_UP_128).convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes(), BYTE_MASK_32, row, col + 0);
                out.intoTensor(x1.mul(k).add(F32_ROUND_UP_128).convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes(), BYTE_MASK_32, row, col + 4);
                out.intoTensor(x2.mul(k).add(F32_ROUND_UP_128).convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes(), BYTE_MASK_32, row, col + 8);
                out.intoTensor(x3.mul(k).add(F32_ROUND_UP_128).convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes(), BYTE_MASK_32, row, col + 12);
                out.intoTensor(x4.mul(k).add(F32_ROUND_UP_128).convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes(), BYTE_MASK_32, row, col + 16);
                out.intoTensor(x5.mul(k).add(F32_ROUND_UP_128).convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes(), BYTE_MASK_32, row, col + 20);
                out.intoTensor(x6.mul(k).add(F32_ROUND_UP_128).convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes(), BYTE_MASK_32, row, col + 24);
                out.intoTensor(x7.mul(k).add(F32_ROUND_UP_128).convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes(), BYTE_MASK_32, row, col + 28);
                out.getBlockF().set(blockScale, row, (int) (col * 0.03125f));
            }
        }
        return out;
    }

    /**
     * Multiplies {@code abstractTensor} element-wise by {@code abstractTensor2} over columns
     * {@code [i, i + i2)}, row by row.
     *
     * <p>Results are written into each row slice of {@code abstractTensor}. This assumes slices are
     * views of the parent tensor; confirm in {@code AbstractTensor.slice}. When
     * {@code abstractTensor2} has a single row, that whole tensor is paired with every row.
     *
     * @param abstractTensor tensor updated in place; F32 or BF16
     * @param abstractTensor2 multiplier; must have the same element type
     * @param i first column
     * @param i2 number of columns; must be a multiple of 8
     * @throws IllegalArgumentException if the element types differ or {@code i2 % 8 != 0}
     * @throws UnsupportedOperationException for element types other than F32 and BF16
     */
    @Override // com.github.tjake.jlama.tensor.operations.TensorOperations
    public void maccumulate(AbstractTensor abstractTensor, AbstractTensor abstractTensor2, int i, int i2) {
        Preconditions.checkArgument(abstractTensor.dType() == abstractTensor2.dType());
        Preconditions.checkArgument(i2 % 8 == 0);
        final boolean perRow = abstractTensor2.shape().first() > 1;
        int row = 0;
        while (row < abstractTensor.shape().first()) {
            AbstractTensor target = abstractTensor.slice(row);
            AbstractTensor factor = perRow ? abstractTensor2.slice(row) : abstractTensor2;
            switch (target.dType()) {
                case F32 -> maccumulateF32((FloatBufferTensor) target, (FloatBufferTensor) factor, i, i2);
                case BF16 -> maccumulateBF16((BFloat16BufferTensor) target, (BFloat16BufferTensor) factor, i, i2);
                default -> throw new UnsupportedOperationException(target.dType().name());
            }
            row++;
        }
    }

    /**
     * Multiplies row 0 of {@code floatBufferTensor} by row 0 of {@code floatBufferTensor2} in place,
     * over columns {@code [i, i + i2)}.
     *
     * <p>Whole preferred-width vectors are processed first. Any remaining columns are handled one
     * scalar at a time.
     *
     * @param floatBufferTensor tensor that receives the products
     * @param floatBufferTensor2 multiplier tensor
     * @param i first column
     * @param i2 number of columns
     */
    void maccumulateF32(FloatBufferTensor floatBufferTensor, FloatBufferTensor floatBufferTensor2, int i, int i2) {
        final int lanes = FloatVector.SPECIES_PREFERRED.length();
        final int vectorEnd = i + FloatVector.SPECIES_PREFERRED.loopBound(i2);
        final int end = i + i2;
        int col = i;
        for (; col < vectorEnd; col += lanes) {
            FloatVector lhs = floatBufferTensor.mo41getVector(FloatVector.SPECIES_PREFERRED, 0, col);
            FloatVector rhs = floatBufferTensor2.mo41getVector(FloatVector.SPECIES_PREFERRED, 0, col);
            floatBufferTensor.intoTensor(lhs.mul(rhs), 0, col);
        }
        // Scalar tail for columns that do not fill a whole vector.
        for (; col < end; col++) {
            floatBufferTensor.set(floatBufferTensor.get(0, col) * floatBufferTensor2.get(0, col), 0, col);
        }
    }

    /**
     * Multiplies row 0 of {@code bFloat16BufferTensor} by row 0 of {@code bFloat16BufferTensor2} in
     * place, over columns {@code [i, i + i2)}.
     *
     * <p>The vector path works as follows:
     * <ul>
     *   <li>Both operands are widened to F32 in two halves: sign-extend to int, shift left 16 bits,
     *       reinterpret as float.
     *   <li>The halves are multiplied in F32.
     *   <li>Each product is shifted right 16 bits and narrowed back to shorts. This keeps the upper
     *       16 bits, truncating rather than rounding.
     *   <li>The low-half results land in the lower lanes and the high-half results in the upper
     *       lanes, merged with a blend.
     * </ul>
     * Remaining columns use the tensor's scalar get/set.
     *
     * @param bFloat16BufferTensor tensor that receives the products
     * @param bFloat16BufferTensor2 multiplier tensor
     * @param i first column
     * @param i2 number of columns
     */
    void maccumulateBF16(BFloat16BufferTensor bFloat16BufferTensor, BFloat16BufferTensor bFloat16BufferTensor2, int i, int i2) {
        final int shortLanes = ShortVector.SPECIES_PREFERRED.length();
        final int vectorEnd = i + ShortVector.SPECIES_PREFERRED.loopBound(i2);
        final int half = shortLanes / 2;
        int col = i;
        for (; col < vectorEnd; col += shortLanes) {
            ShortVector lhs = bFloat16BufferTensor.mo41getVector(ShortVector.SPECIES_PREFERRED, 0, col);
            FloatVector lhsLow = lhs.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0).lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT).reinterpretAsFloats();
            FloatVector lhsHigh = lhs.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 1).lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT).reinterpretAsFloats();
            ShortVector rhs = bFloat16BufferTensor2.mo41getVector(ShortVector.SPECIES_PREFERRED, 0, col);
            FloatVector rhsLow = rhs.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0).lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT).reinterpretAsFloats();
            FloatVector rhsHigh = rhs.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 1).lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT).reinterpretAsFloats();

            // Part 0 fills the lower lanes; part -1 places the narrowed values in the upper lanes.
            Vector<Short> packedLow = lhsLow.mul(rhsLow).reinterpretAsInts().lanewise(VectorOperators.ASHR, BF16_BYTE_SHIFT).convertShape(VectorOperators.I2S, ShortVector.SPECIES_PREFERRED, 0);
            Vector<Short> packedHigh = lhsHigh.mul(rhsHigh).reinterpretAsInts().lanewise(VectorOperators.ASHR, BF16_BYTE_SHIFT).convertShape(VectorOperators.I2S, ShortVector.SPECIES_PREFERRED, -1);
            VectorMask<Short> upperLanes = VectorMask.fromLong(ShortVector.SPECIES_PREFERRED, (1 << half) - 1).not();
            bFloat16BufferTensor.intoTensor((ShortVector) packedLow.blend(packedHigh, upperLanes), 0, col);
        }
        for (; col < i + i2; col++) {
            bFloat16BufferTensor.set(bFloat16BufferTensor.get(0, col) * bFloat16BufferTensor2.get(0, col), 0, col);
        }
    }

    /**
     * Accumulates {@code abstractTensor2} into {@code abstractTensor} over columns
     * {@code [i, i + i2)}, row by row.
     *
     * <p>The kernel is chosen from the two element types and {@code vectorType}. The exact combine
     * operation lives in the per-type kernels; it is presumably element-wise addition, so confirm
     * there. When {@code abstractTensor2} has a single row, that whole tensor is paired with every
     * row of {@code abstractTensor}.
     *
     * <p>Supported pairs and kernels:
     * <ul>
     *   <li>F32 with F32.
     *   <li>F32 with BF16 or Q4: AVX-512 uses the 256-bit kernels.
     *   <li>BF16 with BF16.
     * </ul>
     *
     * @param abstractTensor tensor whose row slices receive the result
     * @param abstractTensor2 tensor accumulated into {@code abstractTensor}
     * @param i first column
     * @param i2 number of columns
     * @throws UnsupportedOperationException for an unsupported type pair or vector width
     */
    @Override // com.github.tjake.jlama.tensor.operations.TensorOperations
    public void accumulate(AbstractTensor abstractTensor, AbstractTensor abstractTensor2, int i, int i2) {
        boolean perRow = abstractTensor2.shape().first() > 1;
        for (int row = 0; row < abstractTensor.shape().first(); row++) {
            AbstractTensor slice = abstractTensor.slice(row);
            AbstractTensor slice2 = perRow ? abstractTensor2.slice(row) : abstractTensor2;
            // Arrow labels never fall through. With colon labels, a `break` in a nested switch exits
            // only the inner switch and execution continues into the next outer case.
            switch (slice.dType()) {
                case F32 -> {
                    switch (slice2.dType()) {
                        case F32 -> accumulateF32((FloatBufferTensor) slice, (FloatBufferTensor) slice2, i, i2);
                        case BF16 -> {
                            switch (this.vectorType) {
                                case AVX_256, AVX_512 -> accumulateF32BF16_256((FloatBufferTensor) slice, (BFloat16BufferTensor) slice2, i, i2);
                                case ARM_128 -> accumulateF32BF16_arm((FloatBufferTensor) slice, (BFloat16BufferTensor) slice2, i, i2);
                                default -> throw new UnsupportedOperationException();
                            }
                        }
                        case Q4 -> {
                            switch (this.vectorType) {
                                case AVX_256, AVX_512 -> accumulateF32Q4_256((FloatBufferTensor) slice, (Q4ByteBufferTensor) slice2, i, i2);
                                case ARM_128 -> accumulateF32Q4_arm((FloatBufferTensor) slice, (Q4ByteBufferTensor) slice2, i, i2);
                                default -> throw new UnsupportedOperationException();
                            }
                        }
                        default -> throw new UnsupportedOperationException("F32 => " + String.valueOf(slice2.dType()));
                    }
                }
                case BF16 -> {
                    switch (slice2.dType()) {
                        case BF16 -> {
                            switch (this.vectorType) {
                                case AVX_256 -> accumulateBF16_256((BFloat16BufferTensor) slice, (BFloat16BufferTensor) slice2, i, i2);
                                case AVX_512 -> accumulateBF16_512((BFloat16BufferTensor) slice, (BFloat16BufferTensor) slice2, i, i2);
                                case ARM_128 -> accumulateBF16_arm((BFloat16BufferTensor) slice, (BFloat16BufferTensor) slice2, i, i2);
                                default -> throw new UnsupportedOperationException();
                            }
                        }
                        default -> throw new UnsupportedOperationException();
                    }
                }
                default -> throw new UnsupportedOperationException(String.valueOf(slice.dType()));
            }
        }
    }

    /**
     * Adds a dequantized Q4 tensor into an F32 tensor in place, using 128-bit vector species:
     * {@code a[0, j] += (nibble(b, j) - 8) * factor(b, j)} for {@code j} in {@code [i, i + i2)}.
     *
     * <p>The vector loop consumes one group of 32 elements per iteration. It reads a single scale
     * factor per group through {@code getFactorForIndex}. Within a group starting at element {@code g}:
     * <ul>
     *   <li>low nibbles of the byte vector loaded at element offset {@code g} become elements g..g+7;</li>
     *   <li>low nibbles of the byte vector loaded at element offset {@code g + 16} become g+8..g+15;</li>
     *   <li>high nibbles of the first load become g+16..g+23;</li>
     *   <li>high nibbles of the second load become g+24..g+31.</li>
     * </ul>
     *
     * <p>The vector loop is now bounded by whole groups. Previously the bound was only a multiple of
     * 4, so a final partial group was read and written past {@code i + i2}. Any trailing partial group
     * is handled by a scalar loop.
     *
     * @param floatBufferTensor destination/accumulator (row 0 is updated)
     * @param q4ByteBufferTensor quantized addend (row 0 is read)
     * @param i first element index
     * @param i2 number of elements
     */
    private void accumulateF32Q4_arm(FloatBufferTensor floatBufferTensor, Q4ByteBufferTensor q4ByteBufferTensor, int i, int i2) {
        // Elements consumed per vector iteration; they share one scale factor.
        final int groupSize = 32;
        final int end = i + i2;
        final int vectorEnd = i + (i2 - (i2 % groupSize));
        int j = i;
        for (; j < vectorEnd; j += groupSize) {
            final FloatVector scale = FloatVector.broadcast(FloatVector.SPECIES_128, q4ByteBufferTensor.getFactorForIndex(0, j));

            // Two 8-byte loads cover the group's 32 packed nibbles.
            final ByteVector packedA = q4ByteBufferTensor.mo41getVector(ByteVector.SPECIES_64, 0, j);
            final ByteVector packedB = q4ByteBufferTensor.mo41getVector(ByteVector.SPECIES_64, 0, j + 16);

            // Isolate each nibble, re-centre it by subtracting 8, then widen bytes to shorts.
            final Vector<Short> lowA = packedA.lanewise(VectorOperators.AND, Q4_BYTE_MASK_64)
                    .sub(Q4_BYTE_SUB_64)
                    .castShape(ShortVector.SPECIES_128, 0);
            final Vector<Short> highA = packedA.lanewise(VectorOperators.ASHR, Q4_BYTE_SHIFT_64)
                    .lanewise(VectorOperators.AND, Q4_BYTE_MASK_64)
                    .sub(Q4_BYTE_SUB_64)
                    .castShape(ShortVector.SPECIES_128, 0);
            final Vector<Short> lowB = packedB.lanewise(VectorOperators.AND, Q4_BYTE_MASK_64)
                    .sub(Q4_BYTE_SUB_64)
                    .castShape(ShortVector.SPECIES_128, 0);
            final Vector<Short> highB = packedB.lanewise(VectorOperators.ASHR, Q4_BYTE_SHIFT_64)
                    .lanewise(VectorOperators.AND, Q4_BYTE_MASK_64)
                    .sub(Q4_BYTE_SUB_64)
                    .castShape(ShortVector.SPECIES_128, 0);

            // Each 8-lane short vector expands to two 4-lane float vectors (parts 0 and 1).
            final Vector<Float> q0 = lowA.convertShape(VectorOperators.S2F, FloatVector.SPECIES_128, 0);
            final Vector<Float> q1 = lowA.convertShape(VectorOperators.S2F, FloatVector.SPECIES_128, 1);
            final Vector<Float> q2 = lowB.convertShape(VectorOperators.S2F, FloatVector.SPECIES_128, 0);
            final Vector<Float> q3 = lowB.convertShape(VectorOperators.S2F, FloatVector.SPECIES_128, 1);
            final Vector<Float> q4 = highA.convertShape(VectorOperators.S2F, FloatVector.SPECIES_128, 0);
            final Vector<Float> q5 = highA.convertShape(VectorOperators.S2F, FloatVector.SPECIES_128, 1);
            final Vector<Float> q6 = highB.convertShape(VectorOperators.S2F, FloatVector.SPECIES_128, 0);
            final Vector<Float> q7 = highB.convertShape(VectorOperators.S2F, FloatVector.SPECIES_128, 1);

            floatBufferTensor.intoTensor(floatBufferTensor.mo41getVector(FloatVector.SPECIES_128, 0, j).add(q0.mul(scale)), 0, j);
            floatBufferTensor.intoTensor(floatBufferTensor.mo41getVector(FloatVector.SPECIES_128, 0, j + 4).add(q1.mul(scale)), 0, j + 4);
            floatBufferTensor.intoTensor(floatBufferTensor.mo41getVector(FloatVector.SPECIES_128, 0, j + 8).add(q2.mul(scale)), 0, j + 8);
            floatBufferTensor.intoTensor(floatBufferTensor.mo41getVector(FloatVector.SPECIES_128, 0, j + 12).add(q3.mul(scale)), 0, j + 12);
            floatBufferTensor.intoTensor(floatBufferTensor.mo41getVector(FloatVector.SPECIES_128, 0, j + 16).add(q4.mul(scale)), 0, j + 16);
            floatBufferTensor.intoTensor(floatBufferTensor.mo41getVector(FloatVector.SPECIES_128, 0, j + 20).add(q5.mul(scale)), 0, j + 20);
            floatBufferTensor.intoTensor(floatBufferTensor.mo41getVector(FloatVector.SPECIES_128, 0, j + 24).add(q6.mul(scale)), 0, j + 24);
            floatBufferTensor.intoTensor(floatBufferTensor.mo41getVector(FloatVector.SPECIES_128, 0, j + 28).add(q7.mul(scale)), 0, j + 28);
        }
        // Scalar tail for a trailing partial group.
        // NOTE(review): assumes Q4ByteBufferTensor.get returns the dequantized value — TODO confirm.
        for (; j < end; j++) {
            floatBufferTensor.set(floatBufferTensor.get(0, j) + q4ByteBufferTensor.get(0, j), 0, j);
        }
    }

    /**
     * Element-wise in-place addition on row 0: {@code a[0, j] += b[0, j]} for {@code j} in {@code [i, i + i2)}.
     * Uses the preferred float species for the bulk of the range and a scalar loop for the remainder.
     */
    void accumulateF32(FloatBufferTensor floatBufferTensor, FloatBufferTensor floatBufferTensor2, int i, int i2) {
        final int step = FloatVector.SPECIES_PREFERRED.length();
        final int vectorEnd = i + FloatVector.SPECIES_PREFERRED.loopBound(i2);
        final int end = i + i2;
        int pos = i;
        for (; pos < vectorEnd; pos += step) {
            FloatVector lhs = floatBufferTensor.mo41getVector(FloatVector.SPECIES_PREFERRED, 0, pos);
            FloatVector rhs = floatBufferTensor2.mo41getVector(FloatVector.SPECIES_PREFERRED, 0, pos);
            floatBufferTensor.intoTensor(lhs.add(rhs), 0, pos);
        }
        for (; pos < end; pos++) {
            floatBufferTensor.set(floatBufferTensor.get(0, pos) + floatBufferTensor2.get(0, pos), 0, pos);
        }
    }

    /**
     * Adds a dequantized Q4 tensor into an F32 tensor in place, using 256-bit float vectors:
     * {@code a[0, j] += (nibble(b, j) - 8) * factor(b, j)} for {@code j} in {@code [i, i + i2)}.
     *
     * <p>Each vector iteration handles one 32-element group sharing one scale factor. One 16-byte load
     * supplies the group. Its low nibbles become elements g..g+15 and its high nibbles become
     * g+16..g+31.
     *
     * <p>The vector loop is bounded by whole groups. Previously the bound was only a multiple of 8,
     * so a partial final group was read and written past {@code i + i2}. A trailing partial group is
     * handled by a scalar loop.
     *
     * @param floatBufferTensor destination/accumulator (row 0 is updated)
     * @param q4ByteBufferTensor quantized addend (row 0 is read)
     * @param i first element index
     * @param i2 number of elements
     */
    void accumulateF32Q4_256(FloatBufferTensor floatBufferTensor, Q4ByteBufferTensor q4ByteBufferTensor, int i, int i2) {
        // Elements consumed per vector iteration; they share one scale factor.
        final int groupSize = 32;
        final int end = i + i2;
        final int vectorEnd = i + (i2 - (i2 % groupSize));
        int j = i;
        for (; j < vectorEnd; j += groupSize) {
            final FloatVector scale = FloatVector.broadcast(FloatVector.SPECIES_256, q4ByteBufferTensor.getFactorForIndex(0, j));
            final ByteVector packed = q4ByteBufferTensor.mo41getVector(ByteVector.SPECIES_128, 0, j);
            // Low nibble: (n & 0xF) - 8. High nibble: unsigned shift right by 4, then - 8.
            final ByteVector low = packed.and((byte) 15).sub((byte) 8);
            final ByteVector high = packed.lanewise(VectorOperators.LSHR, 4L).sub((byte) 8);

            final FloatVector out0 = floatBufferTensor.mo41getVector(FloatVector.SPECIES_256, 0, j)
                    .add(low.castShape(FloatVector.SPECIES_256, 0).mul(scale));
            final FloatVector out1 = floatBufferTensor.mo41getVector(FloatVector.SPECIES_256, 0, j + 8)
                    .add(low.castShape(FloatVector.SPECIES_256, 1).mul(scale));
            final FloatVector out2 = floatBufferTensor.mo41getVector(FloatVector.SPECIES_256, 0, j + 16)
                    .add(high.castShape(FloatVector.SPECIES_256, 0).mul(scale));
            final FloatVector out3 = floatBufferTensor.mo41getVector(FloatVector.SPECIES_256, 0, j + 24)
                    .add(high.castShape(FloatVector.SPECIES_256, 1).mul(scale));
            floatBufferTensor.intoTensor(out0, 0, j);
            floatBufferTensor.intoTensor(out1, 0, j + 8);
            floatBufferTensor.intoTensor(out2, 0, j + 16);
            floatBufferTensor.intoTensor(out3, 0, j + 24);
        }
        // Scalar tail for a trailing partial group.
        // NOTE(review): assumes Q4ByteBufferTensor.get returns the dequantized value — TODO confirm.
        for (; j < end; j++) {
            floatBufferTensor.set(floatBufferTensor.get(0, j) + q4ByteBufferTensor.get(0, j), 0, j);
        }
    }

    /**
     * In-place {@code a[0, j] += bf16ToFloat(b[0, j])} for {@code j} in {@code [i, i + i2)}, using 256-bit float vectors.
     * A bf16 value is widened by sign-extending it to an int and shifting it left 16 bits. The result
     * is then reinterpreted as float bits.
     */
    void accumulateF32BF16_256(FloatBufferTensor floatBufferTensor, BFloat16BufferTensor bFloat16BufferTensor, int i, int i2) {
        final int step = FloatVector.SPECIES_256.length();
        final int vectorEnd = i + FloatVector.SPECIES_256.loopBound(i2);
        int pos = i;
        while (pos < vectorEnd) {
            FloatVector acc = floatBufferTensor.mo41getVector(FloatVector.SPECIES_256, 0, pos);
            FloatVector widened = bFloat16BufferTensor.mo41getVector(ShortVector.SPECIES_128, 0, pos)
                    .convertShape(VectorOperators.S2I, IntVector.SPECIES_256, 0)
                    .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_256)
                    .reinterpretAsFloats();
            floatBufferTensor.intoTensor(acc.add(widened), 0, pos);
            pos += step;
        }
        final int end = i + i2;
        while (pos < end) {
            floatBufferTensor.set(floatBufferTensor.get(0, pos) + bFloat16BufferTensor.get(0, pos), 0, pos);
            pos++;
        }
    }

    /**
     * In-place {@code a[0, j] += bf16ToFloat(b[0, j])} for {@code j} in {@code [i, i + i2)}, using 128-bit float vectors.
     * A bf16 value is widened by sign-extending it to an int and shifting it left 16 bits. The result
     * is then reinterpreted as float bits.
     */
    void accumulateF32BF16_arm(FloatBufferTensor floatBufferTensor, BFloat16BufferTensor bFloat16BufferTensor, int i, int i2) {
        final int step = FloatVector.SPECIES_128.length();
        final int vectorEnd = i + FloatVector.SPECIES_128.loopBound(i2);
        int pos = i;
        while (pos < vectorEnd) {
            FloatVector acc = floatBufferTensor.mo41getVector(FloatVector.SPECIES_128, 0, pos);
            FloatVector widened = bFloat16BufferTensor.mo41getVector(ShortVector.SPECIES_64, 0, pos)
                    .convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0)
                    .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_128)
                    .reinterpretAsFloats();
            floatBufferTensor.intoTensor(acc.add(widened), 0, pos);
            pos += step;
        }
        final int end = i + i2;
        while (pos < end) {
            floatBufferTensor.set(floatBufferTensor.get(0, pos) + bFloat16BufferTensor.get(0, pos), 0, pos);
            pos++;
        }
    }

    /**
     * In-place bf16 addition {@code a[0, j] += b[0, j]} for {@code j} in {@code [i, i + i2)}, using 128-bit vectors.
     * Both operands are widened to float and added. The sum is narrowed back by taking the upper
     * 16 bits of the float bit pattern (truncation, no rounding).
     */
    void accumulateBF16_arm(BFloat16BufferTensor bFloat16BufferTensor, BFloat16BufferTensor bFloat16BufferTensor2, int i, int i2) {
        final int step = FloatVector.SPECIES_128.length();
        final int vectorEnd = i + FloatVector.SPECIES_128.loopBound(i2);
        int pos = i;
        for (; pos < vectorEnd; pos += step) {
            FloatVector lhs = bFloat16BufferTensor.mo41getVector(ShortVector.SPECIES_64, 0, pos)
                    .convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0)
                    .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_128)
                    .reinterpretAsFloats();
            FloatVector rhs = bFloat16BufferTensor2.mo41getVector(ShortVector.SPECIES_64, 0, pos)
                    .convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0)
                    .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_128)
                    .reinterpretAsFloats();
            ShortVector narrowed = (ShortVector) lhs.add(rhs)
                    .reinterpretAsInts()
                    .lanewise(VectorOperators.ASHR, BF16_BYTE_SHIFT_128)
                    .convertShape(VectorOperators.I2S, ShortVector.SPECIES_64, 0);
            bFloat16BufferTensor.intoTensor(narrowed, 0, pos);
        }
        for (final int end = i + i2; pos < end; pos++) {
            bFloat16BufferTensor.set(bFloat16BufferTensor.get(0, pos) + bFloat16BufferTensor2.get(0, pos), 0, pos);
        }
    }

    /**
     * In-place bf16 addition {@code a[0, j] += b[0, j]} for {@code j} in {@code [i, i + i2)}, using 256-bit float vectors.
     * Both operands are widened to float and added. The sum is narrowed back by taking the upper
     * 16 bits of the float bit pattern (truncation, no rounding).
     */
    void accumulateBF16_256(BFloat16BufferTensor bFloat16BufferTensor, BFloat16BufferTensor bFloat16BufferTensor2, int i, int i2) {
        final int step = FloatVector.SPECIES_256.length();
        final int vectorEnd = i + FloatVector.SPECIES_256.loopBound(i2);
        int pos = i;
        for (; pos < vectorEnd; pos += step) {
            FloatVector lhs = bFloat16BufferTensor.mo41getVector(ShortVector.SPECIES_128, 0, pos)
                    .convertShape(VectorOperators.S2I, IntVector.SPECIES_256, 0)
                    .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_256)
                    .reinterpretAsFloats();
            FloatVector rhs = bFloat16BufferTensor2.mo41getVector(ShortVector.SPECIES_128, 0, pos)
                    .convertShape(VectorOperators.S2I, IntVector.SPECIES_256, 0)
                    .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_256)
                    .reinterpretAsFloats();
            ShortVector narrowed = (ShortVector) lhs.add(rhs)
                    .reinterpretAsInts()
                    .lanewise(VectorOperators.ASHR, BF16_BYTE_SHIFT_256)
                    .convertShape(VectorOperators.I2S, ShortVector.SPECIES_128, 0);
            bFloat16BufferTensor.intoTensor(narrowed, 0, pos);
        }
        for (final int end = i + i2; pos < end; pos++) {
            bFloat16BufferTensor.set(bFloat16BufferTensor.get(0, pos) + bFloat16BufferTensor2.get(0, pos), 0, pos);
        }
    }

    /**
     * In-place bf16 addition {@code a[0, j] += b[0, j]} for {@code j} in {@code [i, i + i2)}, using 512-bit float vectors.
     * Both operands are widened to float and added. The sum is narrowed back by taking the upper
     * 16 bits of the float bit pattern (truncation, no rounding).
     */
    void accumulateBF16_512(BFloat16BufferTensor bFloat16BufferTensor, BFloat16BufferTensor bFloat16BufferTensor2, int i, int i2) {
        final int step = FloatVector.SPECIES_512.length();
        final int vectorEnd = i + FloatVector.SPECIES_512.loopBound(i2);
        int pos = i;
        for (; pos < vectorEnd; pos += step) {
            FloatVector lhs = bFloat16BufferTensor.mo41getVector(ShortVector.SPECIES_256, 0, pos)
                    .convertShape(VectorOperators.S2I, IntVector.SPECIES_512, 0)
                    .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_512)
                    .reinterpretAsFloats();
            FloatVector rhs = bFloat16BufferTensor2.mo41getVector(ShortVector.SPECIES_256, 0, pos)
                    .convertShape(VectorOperators.S2I, IntVector.SPECIES_512, 0)
                    .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_512)
                    .reinterpretAsFloats();
            ShortVector narrowed = (ShortVector) lhs.add(rhs)
                    .reinterpretAsInts()
                    .lanewise(VectorOperators.ASHR, BF16_BYTE_SHIFT_512)
                    .convertShape(VectorOperators.I2S, ShortVector.SPECIES_256, 0);
            bFloat16BufferTensor.intoTensor(narrowed, 0, pos);
        }
        for (final int end = i + i2; pos < end; pos++) {
            bFloat16BufferTensor.set(bFloat16BufferTensor.get(0, pos) + bFloat16BufferTensor2.get(0, pos), 0, pos);
        }
    }

    /**
     * Multiplies elements {@code [i, i + i2)} of every slice along the first dimension of
     * {@code abstractTensor} by {@code f}, in place.
     *
     * <p>F32 slices use {@code scaleF32}. BF16 slices are only supported on the AVX_256 and AVX_512
     * vector types.
     *
     * <p>Arrow-form cases are used so no branch can fall through. The previous colon-form BF16 case
     * had no {@code break} after its inner switch. After scaling the first BF16 slice it fell into the
     * outer {@code default} and threw.
     *
     * @throws UnsupportedOperationException for other dtypes, or for BF16 on any other vector type
     */
    @Override // com.github.tjake.jlama.tensor.operations.TensorOperations
    public void scale(float f, AbstractTensor abstractTensor, int i, int i2) {
        for (int i3 = 0; i3 < abstractTensor.shape().first(); i3++) {
            AbstractTensor slice = abstractTensor.slice(i3);
            switch (slice.dType()) {
                case F32 -> scaleF32(f, (FloatBufferTensor) slice, i, i2);
                case BF16 -> {
                    switch (this.vectorType) {
                        case AVX_256 -> scaleBF16_256(f, (BFloat16BufferTensor) slice, i, i2);
                        case AVX_512 -> scaleBF16_512(f, (BFloat16BufferTensor) slice, i, i2);
                        default -> throw new UnsupportedOperationException("BF16 scale on " + this.vectorType);
                    }
                }
                default -> throw new UnsupportedOperationException(String.valueOf(slice.dType()));
            }
        }
    }

    /**
     * In-place {@code t[0, j] *= f} for {@code j} in {@code [i, i + i2)}.
     * Uses the preferred float species for the bulk of the range and a scalar loop for the remainder.
     */
    public void scaleF32(float f, FloatBufferTensor floatBufferTensor, int i, int i2) {
        final int vectorEnd = i + FloatVector.SPECIES_PREFERRED.loopBound(i2);
        final FloatVector factor = FloatVector.broadcast(FloatVector.SPECIES_PREFERRED, f);
        final int step = FloatVector.SPECIES_PREFERRED.length();
        int pos = i;
        for (; pos < vectorEnd; pos += step) {
            FloatVector scaled = floatBufferTensor.mo41getVector(FloatVector.SPECIES_PREFERRED, 0, pos).mul(factor);
            floatBufferTensor.intoTensor(scaled, 0, pos);
        }
        for (final int end = i + i2; pos < end; pos++) {
            floatBufferTensor.set(floatBufferTensor.get(0, pos) * f, 0, pos);
        }
    }

    /**
     * In-place bf16 scaling {@code t[0, j] *= f} for {@code j} in {@code [i, i + i2)}, using 512-bit float vectors.
     * Values are widened to float, multiplied, then narrowed by keeping the upper 16 bits of the
     * float bit pattern (truncation).
     */
    public void scaleBF16_512(float f, BFloat16BufferTensor bFloat16BufferTensor, int i, int i2) {
        final int vectorEnd = i + FloatVector.SPECIES_512.loopBound(i2);
        final FloatVector factor = FloatVector.broadcast(FloatVector.SPECIES_512, f);
        final int step = FloatVector.SPECIES_512.length();
        int pos = i;
        while (pos < vectorEnd) {
            FloatVector asFloats = bFloat16BufferTensor.mo41getVector(ShortVector.SPECIES_256, 0, pos)
                    .convertShape(VectorOperators.S2I, IntVector.SPECIES_512, 0)
                    .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_512)
                    .reinterpretAsFloats();
            ShortVector narrowed = (ShortVector) asFloats.mul(factor)
                    .reinterpretAsInts()
                    .lanewise(VectorOperators.ASHR, BF16_BYTE_SHIFT_512)
                    .convertShape(VectorOperators.I2S, ShortVector.SPECIES_256, 0);
            bFloat16BufferTensor.intoTensor(narrowed, 0, pos);
            pos += step;
        }
        final int end = i + i2;
        while (pos < end) {
            bFloat16BufferTensor.set(bFloat16BufferTensor.get(0, pos) * f, 0, pos);
            pos++;
        }
    }

    /**
     * In-place bf16 scaling {@code t[0, j] *= f} for {@code j} in {@code [i, i + i2)}, using 256-bit float vectors.
     * Values are widened to float, multiplied, then narrowed by keeping the upper 16 bits of the
     * float bit pattern (truncation).
     */
    public void scaleBF16_256(float f, BFloat16BufferTensor bFloat16BufferTensor, int i, int i2) {
        final int vectorEnd = i + FloatVector.SPECIES_256.loopBound(i2);
        final FloatVector factor = FloatVector.broadcast(FloatVector.SPECIES_256, f);
        final int step = FloatVector.SPECIES_256.length();
        int pos = i;
        while (pos < vectorEnd) {
            FloatVector asFloats = bFloat16BufferTensor.mo41getVector(ShortVector.SPECIES_128, 0, pos)
                    .convertShape(VectorOperators.S2I, IntVector.SPECIES_256, 0)
                    .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_256)
                    .reinterpretAsFloats();
            ShortVector narrowed = (ShortVector) asFloats.mul(factor)
                    .reinterpretAsInts()
                    .lanewise(VectorOperators.ASHR, BF16_BYTE_SHIFT_256)
                    .convertShape(VectorOperators.I2S, ShortVector.SPECIES_128, 0);
            bFloat16BufferTensor.intoTensor(narrowed, 0, pos);
            pos += step;
        }
        final int end = i + i2;
        while (pos < end) {
            bFloat16BufferTensor.set(bFloat16BufferTensor.get(0, pos) * f, 0, pos);
            pos++;
        }
    }

    /**
     * Single-row saxpy dispatcher: {@code y[i2 + k] += f * x[i + k]} for {@code k} in {@code [0, i3)}.
     *
     * <p>Requires {@code abstractTensor2} (y) to have a first dimension of 1 and {@code i3} to be even.
     * The dtypes must either match or be BF16 x with F32 y.
     *
     * @throws IllegalArgumentException if a precondition fails
     * @throws UnsupportedOperationException for unsupported dtype combinations
     */
    @Override
    public void saxpy(float f, AbstractTensor abstractTensor, AbstractTensor abstractTensor2, int i, int i2, int i3) {
        Preconditions.checkArgument(abstractTensor2.shape().first() == 1);
        final DType xType = abstractTensor.dType();
        final DType yType = abstractTensor2.dType();
        Preconditions.checkArgument(xType == yType || (xType == DType.BF16 && yType == DType.F32));
        Preconditions.checkArgument(i3 % 2 == 0);
        switch (xType) {
            case F32 -> saxpyF32(f, (FloatBufferTensor) abstractTensor, (FloatBufferTensor) abstractTensor2, i, i2, i3);
            case BF16 -> {
                switch (yType) {
                    case F32 -> saxpyBF16F32(f, (BFloat16BufferTensor) abstractTensor, (FloatBufferTensor) abstractTensor2, i, i2, i3);
                    case BF16 -> saxpyBF16(f, (BFloat16BufferTensor) abstractTensor, (BFloat16BufferTensor) abstractTensor2, i, i2, i3);
                    default -> throw new UnsupportedOperationException();
                }
            }
            default -> throw new UnsupportedOperationException();
        }
    }

    /**
     * F32 saxpy on row 0: {@code y[0, i2 + k] += f * x[0, i + k]} for {@code k} in {@code [0, i3)}.
     * The vector body uses fused multiply-add. The scalar remainder uses a separate multiply and add.
     */
    void saxpyF32(float f, FloatBufferTensor floatBufferTensor, FloatBufferTensor floatBufferTensor2, int i, int i2, int i3) {
        final int vectorLen = FloatVector.SPECIES_PREFERRED.loopBound(i3);
        final int step = FloatVector.SPECIES_PREFERRED.length();
        final FloatVector alpha = FloatVector.broadcast(FloatVector.SPECIES_PREFERRED, f);
        final int xVectorEnd = i + vectorLen;
        final int yVectorEnd = i2 + vectorLen;
        int xi = i;
        int yi = i2;
        for (; xi < xVectorEnd && yi < yVectorEnd; xi += step, yi += step) {
            FloatVector x = floatBufferTensor.mo41getVector(FloatVector.SPECIES_PREFERRED, 0, xi);
            FloatVector y = floatBufferTensor2.mo41getVector(FloatVector.SPECIES_PREFERRED, 0, yi);
            floatBufferTensor2.intoTensor(x.fma(alpha, y), 0, yi);
        }
        final int xEnd = i + i3;
        final int yEnd = i2 + i3;
        for (; xi < xEnd && yi < yEnd; xi++, yi++) {
            floatBufferTensor2.set(floatBufferTensor2.get(0, yi) + (f * floatBufferTensor.get(0, xi)), 0, yi);
        }
    }

    /**
     * Batched saxpy dispatcher. For each of {@code i6} rows, it adds
     * {@code alpha[0, i4 + k] * x[i5 + k, i .. i + i3)} into {@code y[0, i2 .. i2 + i3)}.
     * Dispatches on the dtype of {@code abstractTensor2} (x), then on {@code abstractTensor3} (y) for BF16.
     *
     * @throws IllegalArgumentException if {@code i3} is odd
     * @throws UnsupportedOperationException for unsupported dtype combinations
     */
    @Override
    public void saxpy(AbstractTensor abstractTensor, AbstractTensor abstractTensor2, AbstractTensor abstractTensor3, int i, int i2, int i3, int i4, int i5, int i6) {
        Preconditions.checkArgument(i3 % 2 == 0);
        switch (abstractTensor2.dType()) {
            case F32 -> saxpyF32(abstractTensor, (FloatBufferTensor) abstractTensor2, (FloatBufferTensor) abstractTensor3, i, i2, i3, i4, i5, i6);
            case BF16 -> {
                switch (abstractTensor3.dType()) {
                    case F32 -> saxpyBF16F32(abstractTensor, abstractTensor2, abstractTensor3, i, i2, i3, i4, i5, i6);
                    case BF16 -> saxpyBF16(abstractTensor, abstractTensor2, abstractTensor3, i, i2, i3, i4, i5, i6);
                    default -> throw new UnsupportedOperationException();
                }
            }
            default -> throw new UnsupportedOperationException();
        }
    }

    /**
     * Batched F32 saxpy: for each {@code k} in {@code [0, i6)} and {@code c} in {@code [0, i3)},
     * {@code y[0, i2 + c] += alpha[0, i4 + k] * x[i5 + k, i + c]}.
     *
     * <p>Rows are processed four at a time with chained FMAs into one accumulator. Rows left over
     * after the last group of four go through the single-row {@code saxpyF32} overload.
     *
     * <p>The four-row path now has a scalar column tail. Previously, columns in
     * {@code [loopBound(i3), i3)} were skipped for every row handled in groups of four, even though
     * leftover rows did receive them.
     *
     * @param abstractTensor alpha coefficients (row 0, starting at {@code i4})
     * @param floatBufferTensor x, indexed by row starting at {@code i5}
     * @param floatBufferTensor2 y, accumulated in place on row 0
     */
    public void saxpyF32(AbstractTensor abstractTensor, FloatBufferTensor floatBufferTensor, FloatBufferTensor floatBufferTensor2, int i, int i2, int i3, int i4, int i5, int i6) {
        final int step = FloatVector.SPECIES_PREFERRED.length();
        final int vectorLen = FloatVector.SPECIES_PREFERRED.loopBound(i3);
        final int groupedEnd = (i6 - (i6 % 4)) + i4;
        int ai = i4;   // index into alpha
        int row = i5;  // first x row of the current group
        for (; ai < groupedEnd; ai += 4, row += 4) {
            final float a0 = abstractTensor.get(0, ai + 0);
            final float a1 = abstractTensor.get(0, ai + 1);
            final float a2 = abstractTensor.get(0, ai + 2);
            final float a3 = abstractTensor.get(0, ai + 3);
            final FloatVector b0 = FloatVector.broadcast(FloatVector.SPECIES_PREFERRED, a0);
            final FloatVector b1 = FloatVector.broadcast(FloatVector.SPECIES_PREFERRED, a1);
            final FloatVector b2 = FloatVector.broadcast(FloatVector.SPECIES_PREFERRED, a2);
            final FloatVector b3 = FloatVector.broadcast(FloatVector.SPECIES_PREFERRED, a3);
            int xi = i;
            int yi = i2;
            for (; xi < i + vectorLen && yi < i2 + vectorLen; xi += step, yi += step) {
                FloatVector acc = floatBufferTensor2.mo41getVector(FloatVector.SPECIES_PREFERRED, 0, yi);
                acc = floatBufferTensor.mo41getVector(FloatVector.SPECIES_PREFERRED, row + 0, xi).fma(b0, acc);
                acc = floatBufferTensor.mo41getVector(FloatVector.SPECIES_PREFERRED, row + 1, xi).fma(b1, acc);
                acc = floatBufferTensor.mo41getVector(FloatVector.SPECIES_PREFERRED, row + 2, xi).fma(b2, acc);
                acc = floatBufferTensor.mo41getVector(FloatVector.SPECIES_PREFERRED, row + 3, xi).fma(b3, acc);
                floatBufferTensor2.intoTensor(acc, 0, yi);
            }
            // Scalar column tail. Same accumulation order as the vector path, with the
            // unfused multiply-add used by the single-row overload's tail.
            for (; xi < i + i3 && yi < i2 + i3; xi++, yi++) {
                float acc = floatBufferTensor2.get(0, yi);
                acc += a0 * floatBufferTensor.get(row + 0, xi);
                acc += a1 * floatBufferTensor.get(row + 1, xi);
                acc += a2 * floatBufferTensor.get(row + 2, xi);
                acc += a3 * floatBufferTensor.get(row + 3, xi);
                floatBufferTensor2.set(acc, 0, yi);
            }
        }
        // Leftover rows (fewer than four) use the single-row kernel, which already has a scalar tail.
        for (; ai < i4 + i6; ai++, row++) {
            saxpyF32(abstractTensor.get(0, ai), (FloatBufferTensor) floatBufferTensor.slice(row), floatBufferTensor2, i, i2, i3);
        }
    }

    /**
     * Batched bf16 saxpy. For each of {@code i6} rows, calls the single-row {@code saxpyBF16} with
     * coefficient {@code alpha[0, i4 + k]} and x row {@code i5 + k}, accumulating into the same y.
     * Both tensors are cast to {@link BFloat16BufferTensor} up front.
     */
    public void saxpyBF16(AbstractTensor abstractTensor, AbstractTensor abstractTensor2, AbstractTensor abstractTensor3, int i, int i2, int i3, int i4, int i5, int i6) {
        final BFloat16BufferTensor x = (BFloat16BufferTensor) abstractTensor2;
        final BFloat16BufferTensor y = (BFloat16BufferTensor) abstractTensor3;
        final int alphaEnd = i4 + i6;
        for (int ai = i4, row = i5; ai < alphaEnd; ai++, row++) {
            float coefficient = abstractTensor.get(0, ai);
            saxpyBF16(coefficient, (BFloat16BufferTensor) x.slice(row), y, i, i2, i3);
        }
    }

    /**
     * Batched bf16-to-f32 saxpy. For each of {@code i6} rows, calls the single-row {@code saxpyBF16F32}
     * with coefficient {@code alpha[0, i4 + k]} and bf16 x row {@code i5 + k}, accumulating into the same
     * F32 y. Both tensors are cast up front.
     */
    public void saxpyBF16F32(AbstractTensor abstractTensor, AbstractTensor abstractTensor2, AbstractTensor abstractTensor3, int i, int i2, int i3, int i4, int i5, int i6) {
        final BFloat16BufferTensor x = (BFloat16BufferTensor) abstractTensor2;
        final FloatBufferTensor y = (FloatBufferTensor) abstractTensor3;
        final int alphaEnd = i4 + i6;
        for (int ai = i4, row = i5; ai < alphaEnd; ai++, row++) {
            float coefficient = abstractTensor.get(0, ai);
            saxpyBF16F32(coefficient, (BFloat16BufferTensor) x.slice(row), y, i, i2, i3);
        }
    }

    /**
     * bf16 saxpy on row 0: {@code y[0, i2 + k] += f * x[0, i + k]} for {@code k} in {@code [0, i3)}.
     *
     * <p>{@code i3} must be a multiple of the preferred short-vector length; no scalar tail is run.
     * Each short vector is split into two int/float halves (convert parts 0 and 1). Both halves are
     * updated in float, and the results are narrowed back to shorts by keeping the upper 16 bits.
     * The halves are recombined with a blend: lanes {@code [0, len/2)} come from the lower half, and the
     * remaining lanes come from the upper half, which is placed there via convert part -1.
     *
     * @throws IllegalArgumentException if {@code i3} is not a multiple of the short species length
     */
    void saxpyBF16(float f, BFloat16BufferTensor bFloat16BufferTensor, BFloat16BufferTensor bFloat16BufferTensor2, int i, int i2, int i3) {
        final int vectorLen = ShortVector.SPECIES_PREFERRED.loopBound(i3);
        Preconditions.checkArgument(vectorLen == i3);
        final int step = ShortVector.SPECIES_PREFERRED.length();
        final int half = step / 2;
        // True for lanes [half, step): those lanes are taken from the upper-half result.
        final VectorMask<Short> upperLanes = VectorMask.fromLong(ShortVector.SPECIES_PREFERRED, (1 << half) - 1).not();
        int xi = i;
        int yi = i2;
        while (xi < i + vectorLen && yi < i2 + vectorLen) {
            ShortVector xPacked = bFloat16BufferTensor.mo41getVector(ShortVector.SPECIES_PREFERRED, 0, xi);
            FloatVector xLo = xPacked.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0)
                    .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT)
                    .reinterpretAsFloats();
            FloatVector xHi = xPacked.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 1)
                    .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT)
                    .reinterpretAsFloats();

            ShortVector yPacked = bFloat16BufferTensor2.mo41getVector(ShortVector.SPECIES_PREFERRED, 0, yi);
            FloatVector yLo = yPacked.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0)
                    .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT)
                    .reinterpretAsFloats();
            FloatVector yHi = yPacked.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 1)
                    .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT)
                    .reinterpretAsFloats();

            Vector<Short> lowResult = yLo.add(xLo.mul(f))
                    .reinterpretAsInts()
                    .lanewise(VectorOperators.ASHR, BF16_BYTE_SHIFT)
                    .convertShape(VectorOperators.I2S, ShortVector.SPECIES_PREFERRED, 0);
            Vector<Short> highResult = yHi.add(xHi.mul(f))
                    .reinterpretAsInts()
                    .lanewise(VectorOperators.ASHR, BF16_BYTE_SHIFT)
                    .convertShape(VectorOperators.I2S, ShortVector.SPECIES_PREFERRED, -1);

            bFloat16BufferTensor2.intoTensor((ShortVector) lowResult.blend(highResult, upperLanes), 0, yi);
            xi += step;
            yi += step;
        }
    }

    /**
     * Mixed-precision saxpy on row 0: {@code y[0, i2 + k] += f * bf16ToFloat(x[0, i + k])} for {@code k}
     * in {@code [0, i3)}, where y is F32.
     *
     * <p>{@code i3} must be a multiple of the preferred short-vector length; no scalar tail is run.
     * Each short vector of x is widened into two float vectors. These update two consecutive float
     * vectors of y.
     *
     * @throws IllegalArgumentException if {@code i3} is not a multiple of the short species length
     */
    void saxpyBF16F32(float f, BFloat16BufferTensor bFloat16BufferTensor, FloatBufferTensor floatBufferTensor, int i, int i2, int i3) {
        final int vectorLen = ShortVector.SPECIES_PREFERRED.loopBound(i3);
        Preconditions.checkArgument(vectorLen == i3);
        final int step = ShortVector.SPECIES_PREFERRED.length();
        final int floatLanes = FloatVector.SPECIES_PREFERRED.length();
        int xi = i;
        int yi = i2;
        while (xi < i + vectorLen && yi < i2 + vectorLen) {
            ShortVector xPacked = bFloat16BufferTensor.mo41getVector(ShortVector.SPECIES_PREFERRED, 0, xi);
            FloatVector xLo = xPacked.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0)
                    .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT)
                    .reinterpretAsFloats();
            FloatVector xHi = xPacked.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 1)
                    .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT)
                    .reinterpretAsFloats();
            FloatVector yLo = floatBufferTensor.mo41getVector(FloatVector.SPECIES_PREFERRED, 0, yi);
            FloatVector yHi = floatBufferTensor.mo41getVector(FloatVector.SPECIES_PREFERRED, 0, yi + floatLanes);
            FloatVector sumLo = yLo.add(xLo.mul(f));
            FloatVector sumHi = yHi.add(xHi.mul(f));
            floatBufferTensor.intoTensor(sumLo, 0, yi);
            floatBufferTensor.intoTensor(sumHi, 0, yi + floatLanes);
            xi += step;
            yi += step;
        }
    }
}
