package com.github.tjake.jlama.model;

import com.github.tjake.jlama.math.ActivationFunction;
import com.github.tjake.jlama.math.VectorMath;
import com.github.tjake.jlama.model.functions.FeedForward;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.FloatBufferTensor;
import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.function.Consumer;

/* loaded from: input_file:com/github/tjake/jlama/model/MoEBlock.class */
/**
 * Mixture-of-experts feed-forward block.
 *
 * <p>For every row of the input, {@link #forward} scores all experts against the gate weights,
 * applies a softmax over the scores, keeps the {@code numberOfExpertsPerToken} highest-scoring
 * experts and sums their outputs into the matching row of the result.
 *
 * <p>NOTE(review): the instance keeps per-call scratch state in fields ({@code expertResults},
 * {@code selectedExperts}, {@code batchResults}, {@code batchWeights}) that {@link #forward}
 * overwrites on every call, so concurrent calls on one instance would interfere with each
 * other. Confirm that callers never share an instance across threads.
 *
 * <p>NOTE(review): the gate probabilities are only used to pick experts; the selected experts'
 * outputs are summed without being scaled by those probabilities. Confirm this is intended.
 */
public class MoEBlock implements FeedForward {
    private final AbstractModel model;
    private final DistributedContext dctx;
    private final AbstractTensor moeGateWeight;
    private final int numberOfExperts;
    private final int numberOfExpertsPerToken;
    private final AbstractTensor[] fullyConnectedWeights;
    private final AbstractTensor[] projectionWeights;
    private final AbstractTensor[] upProjectionWeights;
    /** Scratch: one gate score per expert for the row currently being processed. */
    private final FloatBufferTensor expertResults;
    /** Scratch: indices of the experts chosen by {@link #topk} for the current row. */
    private final int[] selectedExperts;
    private final ActivationFunction.Type activationFunction;
    /** Scratch: output tensors handed to the batched dot product (slot 0 and slot 1). */
    private final AbstractTensor[] batchResults = new AbstractTensor[2];
    /** Scratch: weight tensors handed to the batched dot product (slot 0 and slot 1). */
    private final AbstractTensor[] batchWeights = new AbstractTensor[2];

    /**
     * Creates a mixture-of-experts block.
     *
     * @param model                   owning model; supplies the config, tensor allocation and quantization
     * @param numberOfExperts         total number of experts; also the length of the gate-score buffer
     * @param numberOfExpertsPerToken how many experts are evaluated for each input row
     * @param activationFunction      activation applied to the fully-connected projection
     * @param moeGateWeight           gate weights; row {@code e} is dotted with the input to score expert {@code e}
     * @param fullyConnectedWeights   per-expert weights paired with slot 0 of the batched dot product
     * @param projectionWeights       per-expert weights of the final projection back to the embedding length
     * @param upProjectionWeights     per-expert weights paired with slot 1 of the batched dot product
     */
    public MoEBlock(
        AbstractModel model,
        int numberOfExperts,
        int numberOfExpertsPerToken,
        ActivationFunction.Type activationFunction,
        AbstractTensor moeGateWeight,
        AbstractTensor[] fullyConnectedWeights,
        AbstractTensor[] projectionWeights,
        AbstractTensor[] upProjectionWeights
    ) {
        this.model = model;
        this.dctx = model.c.dctx();
        this.numberOfExperts = numberOfExperts;
        this.numberOfExpertsPerToken = numberOfExpertsPerToken;
        this.moeGateWeight = moeGateWeight;
        this.activationFunction = activationFunction;
        this.fullyConnectedWeights = fullyConnectedWeights;
        this.projectionWeights = projectionWeights;
        this.upProjectionWeights = upProjectionWeights;
        this.expertResults = new FloatBufferTensor(numberOfExperts);
        this.selectedExperts = new int[numberOfExpertsPerToken];
    }

    /**
     * Runs the block over every row of {@code input}.
     *
     * @param input         tensor whose first dimension is the number of rows to process
     * @param tensorReducer if present, invoked once per selected expert with the activated hidden
     *                      buffer, before that buffer is quantized and projected
     * @return a newly allocated tensor of shape (rows, embeddingLength); the caller owns it
     */
    @Override
    public AbstractTensor forward(AbstractTensor input, Optional<Consumer<List<AbstractTensor>>> tensorReducer) {
        int batchSize = input.shape().first();
        int hiddenLength = this.model.c.hiddenLength;
        int embeddingLength = this.model.c.embeddingLength;
        AbstractTensor result = this.model.makeTensor(batchSize, embeddingLength);

        // Scratch buffers are reused for every row and expert; try-with-resources releases them
        // (in reverse order of allocation) on both the normal and the exceptional path.
        try (
            AbstractTensor fcBuf = this.model.makeTensor(1, hiddenLength);
            AbstractTensor upBuf = this.model.makeTensor(1, hiddenLength);
            AbstractTensor expertOut = this.model.makeTensor(1, embeddingLength)
        ) {
            for (int b = 0; b < batchSize; b++) {
                AbstractTensor row = input.slice(true, b);

                // Score every expert: dot product of this row with the expert's gate weights.
                VectorMath.pfor(0, this.numberOfExperts, expert -> {
                    float score = TensorOperationsProvider.get()
                        .dotProduct(row, this.moeGateWeight.slice(true, expert), 0, 0, embeddingLength);
                    this.expertResults.set(score, 0, expert);
                });

                // Normalize the scores, then record the best experts in selectedExperts.
                VectorMath.softMax(this.expertResults, 0, this.numberOfExperts);
                topk(this.expertResults);

                for (int e = 0; e < this.numberOfExpertsPerToken; e++) {
                    int expert = this.selectedExperts[e];
                    this.batchWeights[0] = this.fullyConnectedWeights[expert];
                    this.batchWeights[1] = this.upProjectionWeights[expert];
                    AbstractTensor projectionWeight = this.projectionWeights[expert];
                    this.batchResults[0] = fcBuf;
                    this.batchResults[1] = upBuf;

                    // Presumably fills batchResults[k] with row x batchWeights[k] for this node's
                    // hidden segment — TODO confirm against TensorOperations.dotProductBatchChunk.
                    VectorMath.pchunk(this.dctx.hiddenSegmentStart, this.dctx.hiddenSegmentLength, (chunkStart, chunkSize) -> {
                        TensorOperationsProvider.get()
                            .dotProductBatchChunk(this.batchResults, row, this.batchWeights, 0, embeddingLength, chunkStart, chunkSize);
                    });

                    // Apply the activation in place to the fully-connected projection.
                    VectorMath.pfor(this.dctx.hiddenSegmentStart, this.dctx.hiddenSegmentEnd, h -> {
                        float activated = ActivationFunction.eval(this.activationFunction, fcBuf.get(0, h));
                        fcBuf.set(activated, 0, h);
                    });

                    // Combine the activated projection with the up projection into fcBuf
                    // (presumably an element-wise multiply — TODO confirm maccumulate semantics).
                    TensorOperationsProvider.get()
                        .maccumulate(fcBuf, upBuf, this.dctx.hiddenSegmentStart, this.dctx.hiddenSegmentLength);

                    tensorReducer.ifPresent(reducer -> reducer.accept(Collections.singletonList(fcBuf)));

                    // Project back to the embedding length. The quantized copy is now released
                    // even if the projection throws; previously it was closed only on success.
                    try (AbstractTensor quantized = this.model.maybeQuantize(fcBuf)) {
                        VectorMath.pchunk(0, embeddingLength, (chunkStart, chunkSize) -> {
                            TensorOperationsProvider.get()
                                .dotProductChunk(expertOut, quantized, projectionWeight, 0, hiddenLength, chunkStart, chunkSize);
                        });
                    }

                    // Both paths must target row b of the result. The first expert's output used
                    // to be copied to destination offset 0 regardless of b, while later experts
                    // accumulated into result.slice(b); with more than one input row that put
                    // the first expert's contribution in the wrong row. This relies on slice(b)
                    // being a view onto result's storage, as the accumulate call already did.
                    AbstractTensor resultRow = result.slice(b);
                    if (e == 0) {
                        resultRow.copyFrom(expertOut, 0, 0, embeddingLength);
                    } else {
                        TensorOperationsProvider.get().accumulate(resultRow, expertOut, 0, embeddingLength);
                    }
                }
            }
            return result;
        }
    }

    /**
     * Stores in {@code selectedExperts} the indices of the {@code numberOfExpertsPerToken}
     * largest values in row 0 of {@code scores}. The stored indices are not sorted by score.
     *
     * @param scores one score per expert
     * @return the shared {@code selectedExperts} array (overwritten by the next call)
     */
    private int[] topk(FloatBufferTensor scores) {
        long size = scores.size();

        // Seed the selection with the first k experts.
        for (int slot = 0; slot < this.numberOfExpertsPerToken; slot++) {
            this.selectedExperts[slot] = slot;
        }

        // For each remaining expert, replace the weakest selected one if this expert beats it.
        for (int candidate = this.numberOfExpertsPerToken; candidate < size; candidate++) {
            int weakestSlot = 0;
            for (int slot = 1; slot < this.numberOfExpertsPerToken; slot++) {
                if (scores.get(0, this.selectedExperts[slot]) < scores.get(0, this.selectedExperts[weakestSlot])) {
                    weakestSlot = slot;
                }
            }
            if (scores.get(0, candidate) > scores.get(0, this.selectedExperts[weakestSlot])) {
                this.selectedExperts[weakestSlot] = candidate;
            }
        }
        return this.selectedExperts;
    }
}
