package com.github.tjake.jlama.model.mixtral;

import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.CausalSelfAttention;
import com.github.tjake.jlama.model.MoEBlock;
import com.github.tjake.jlama.model.ModelSupport;
import com.github.tjake.jlama.model.RMSNorm;
import com.github.tjake.jlama.model.TransformerBlock;
import com.github.tjake.jlama.model.mistral.MistralModel;
import com.github.tjake.jlama.safetensors.Config;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.WeightLoader;
import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer;
import com.github.tjake.jlama.tensor.AbstractTensor;
import java.util.Optional;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Mixtral model support.
 *
 * <p>Reuses the {@link MistralModel} machinery, but assembles each layer's {@link TransformerBlock}
 * with a {@link MoEBlock}. The MoE block holds a router gate plus {@code numberOfExperts} sets of
 * expert weights ({@code w1}/{@code w2}/{@code w3}) read from the {@code block_sparse_moe.*}
 * tensors.
 */
public class MixtralModel extends MistralModel {
    private static final Logger logger = LoggerFactory.getLogger(MixtralModel.class);

    /**
     * Creates a Mixtral model for {@link AbstractModel.InferenceType#FULL_GENERATION}.
     *
     * <p>All arguments are forwarded unchanged to the {@link MistralModel} constructor.
     *
     * @param config model configuration; {@link #loadTransformerBlockWeights()} casts it to
     *     {@code MixtralConfig}
     * @param weightLoader source of the named weight tensors
     * @param tokenizer tokenizer for this model
     * @param dType first {@link DType} forwarded to the parent constructor
     * @param dType2 second {@link DType} forwarded to the parent constructor
     * @param optional optional {@link DType} forwarded to the parent constructor
     */
    public MixtralModel(Config config, WeightLoader weightLoader, Tokenizer tokenizer, DType dType, DType dType2, Optional<DType> optional) {
        this(AbstractModel.InferenceType.FULL_GENERATION, config, weightLoader, tokenizer, dType, dType2, optional);
    }

    /**
     * Creates a Mixtral model for the given inference type.
     *
     * <p>All arguments are forwarded unchanged to the {@link MistralModel} constructor.
     *
     * @param inferenceType which inference mode to build the model for
     * @param config model configuration; {@link #loadTransformerBlockWeights()} casts it to
     *     {@code MixtralConfig}
     * @param weightLoader source of the named weight tensors
     * @param tokenizer tokenizer for this model
     * @param dType first {@link DType} forwarded to the parent constructor
     * @param dType2 second {@link DType} forwarded to the parent constructor
     * @param optional optional {@link DType} forwarded to the parent constructor
     */
    public MixtralModel(AbstractModel.InferenceType inferenceType, Config config, WeightLoader weightLoader, Tokenizer tokenizer, DType dType, DType dType2, Optional<DType> optional) {
        super(inferenceType, config, weightLoader, tokenizer, dType, dType2, optional);
    }

    /** Returns {@link ModelSupport.ModelType#MIXTRAL}. */
    @Override
    public ModelSupport.ModelType getModelType() {
        return ModelSupport.ModelType.MIXTRAL;
    }

    /**
     * Loads the weights for every layer in {@code [c.dctx().layerStart, c.dctx().layerEnd)}.
     *
     * <p>Each layer is wrapped in a {@link TransformerBlock} built from causal self-attention and a
     * {@link MoEBlock}. Every tensor is converted with {@code quantize(...)} to {@code modelQType}
     * when present, otherwise to {@code modelDType}.
     *
     * <p>Layers are loaded in parallel; each task writes only its own array slot.
     *
     * @return an array of length {@code c.dctx().numberOfLayers}, with one block stored at each
     *     layer index in {@code [layerStart, layerEnd)}; other slots are left {@code null}
     * @throws ClassCastException if the model config is not a {@code MixtralConfig}
     */
    @Override
    protected TransformerBlock[] loadTransformerBlockWeights() {
        MixtralConfig mixtralConfig = (MixtralConfig) this.c;
        DType qType = this.modelQType.orElse(this.modelDType);
        if (qType != this.modelDType) {
            logger.info("Quantizing model with {} - Please hold...", qType);
        }

        // NOTE(review): the array is sized by numberOfLayers but indexed by the absolute layer
        // number. This is only in bounds if numberOfLayers >= layerEnd -- confirm against
        // DistributedContext when layerStart > 0.
        TransformerBlock[] transformerBlocks = new TransformerBlock[this.c.dctx().numberOfLayers];
        IntStream.range(this.c.dctx().layerStart, this.c.dctx().layerEnd)
                .parallel()
                .forEach(layer -> transformerBlocks[layer] = loadLayer(mixtralConfig, layer, qType));
        return transformerBlocks;
    }

    /** Loads and assembles the transformer block for absolute layer index {@code layer}. */
    private TransformerBlock loadLayer(MixtralConfig mixtralConfig, int layer, DType qType) {
        String base = "model.layers." + layer + ".";

        // q/k/v use the DistributedContext-aware load() overload; o_proj uses the plain one.
        // NOTE(review): presumably this partitions q/k/v across workers -- confirm in WeightLoader.
        String attnPrefix = base + "self_attn.";
        CausalSelfAttention attention = new CausalSelfAttention(
                this,
                layer,
                this.weights.load(attnPrefix + "q_proj.weight", this.c.dctx(), true, false).quantize(qType),
                this.weights.load(attnPrefix + "k_proj.weight", this.c.dctx(), true, false).quantize(qType),
                this.weights.load(attnPrefix + "v_proj.weight", this.c.dctx(), true, false).quantize(qType),
                this.weights.load(attnPrefix + "o_proj.weight").quantize(qType));

        // Per-expert weights. The loop index is deliberately not named after the layer index:
        // redeclaring an enclosing lambda parameter is a compile error (JLS 15.27.2).
        String moePrefix = base + "block_sparse_moe.";
        int numExperts = mixtralConfig.numberOfExperts;
        AbstractTensor[] w1 = new AbstractTensor[numExperts];
        AbstractTensor[] w2 = new AbstractTensor[numExperts];
        AbstractTensor[] w3 = new AbstractTensor[numExperts];
        for (int expert = 0; expert < numExperts; expert++) {
            String expertPrefix = moePrefix + "experts." + expert + ".";
            w1[expert] = this.weights.load(expertPrefix + "w1.weight", this.c.dctx(), true, false).quantize(qType);
            w2[expert] = this.weights.load(expertPrefix + "w2.weight").quantize(qType);
            w3[expert] = this.weights.load(expertPrefix + "w3.weight", this.c.dctx(), true, false).quantize(qType);
        }

        // Same load order as before: input norm, post-attention norm, then the router gate.
        RMSNorm inputNorm = new RMSNorm(this, this.weights.load(base + "input_layernorm.weight").quantize(qType));
        RMSNorm postAttentionNorm =
                new RMSNorm(this, this.weights.load(base + "post_attention_layernorm.weight").quantize(qType));
        MoEBlock moe = new MoEBlock(
                this,
                numExperts,
                mixtralConfig.numberOfExpertsPerToken,
                this.c.activationFunction,
                this.weights.load(moePrefix + "gate.weight").quantize(qType),
                w1,
                w2,
                w3);

        return new TransformerBlock(this, layer, inputNorm, attention, postAttentionNorm, moe);
    }
}
