package com.github.tjake.jlama.model.gemma;

import com.github.tjake.jlama.math.FloatConversions;
import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.CausalSelfAttention;
import com.github.tjake.jlama.model.LayerNorm;
import com.github.tjake.jlama.model.MLPBlock;
import com.github.tjake.jlama.model.ModelSupport;
import com.github.tjake.jlama.model.RMSNorm;
import com.github.tjake.jlama.model.TransformerBlock;
import com.github.tjake.jlama.model.functions.EmbedInput;
import com.github.tjake.jlama.model.functions.SampleOutput;
import com.github.tjake.jlama.model.llama.LlamaModel;
import com.github.tjake.jlama.safetensors.Config;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.WeightLoader;
import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
import java.util.Optional;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/github/tjake/jlama/model/gemma/GemmaModel.class */
/**
 * Gemma variant of {@link LlamaModel}.
 *
 * <p>Compared with the parent class, this model:
 * <ul>
 *   <li>multiplies every input embedding by {@code sqrt(embeddingLength)}, with the factor first
 *       rounded through bfloat16;</li>
 *   <li>uses the {@code model.embed_tokens.weight} tensor both as the input embedding table and
 *       as the output logits weights;</li>
 *   <li>constructs every {@link RMSNorm} with an extra {@code 1.0f} argument.</li>
 * </ul>
 */
public class GemmaModel extends LlamaModel {
    private static final Logger logger = LoggerFactory.getLogger(GemmaModel.class);

    /** sqrt(embeddingLength) after a float32 -> bfloat16 -> float32 round trip. */
    private final float embeddingScalingFactor;

    /**
     * Embedding table, loaded lazily and shared by the input and output sides.
     * Deliberately has no field initializer: the load hooks below read and write it, and a field
     * initializer would overwrite anything they stored before this class's own initialisation ran.
     */
    private AbstractTensor wte;

    /**
     * Creates a model using {@link AbstractModel.InferenceType#FULL_GENERATION}.
     * All other arguments are forwarded unchanged to the seven-argument constructor.
     */
    public GemmaModel(Config config, WeightLoader weightLoader, Tokenizer tokenizer, DType dType, DType dType2, Optional<DType> optional) {
        this(AbstractModel.InferenceType.FULL_GENERATION, config, weightLoader, tokenizer, dType, dType2, optional);
    }

    /**
     * Creates a model for the given inference type.
     * All arguments are forwarded unchanged to the {@link LlamaModel} constructor; afterwards the
     * embedding scaling factor is derived from {@code c.embeddingLength}.
     */
    public GemmaModel(AbstractModel.InferenceType inferenceType, Config config, WeightLoader weightLoader, Tokenizer tokenizer, DType dType, DType dType2, Optional<DType> optional) {
        super(inferenceType, config, weightLoader, tokenizer, dType, dType2, optional);
        // Square root of the embedding length, truncated to bfloat16 precision before use.
        float sqrtEmbeddingLength = (float) Math.pow(this.c.embeddingLength, 0.5d);
        this.embeddingScalingFactor = FloatConversions.bFloat16ToFloat32(FloatConversions.float32ToBFloat16(sqrtEmbeddingLength));
    }

    /** @return always {@link ModelSupport.ModelType#GEMMA} */
    @Override
    public ModelSupport.ModelType getModelType() {
        return ModelSupport.ModelType.GEMMA;
    }

    /**
     * Loads the weights for every layer in {@code [dctx().layerStart, dctx().layerEnd)} and wraps
     * each layer in a {@link TransformerBlock}. Layers are loaded in parallel; each task writes
     * only its own array slot.
     *
     * @return an array of length {@code dctx().numberOfLayers} holding the loaded blocks
     */
    @Override
    protected TransformerBlock[] loadTransformerBlockWeights() {
        DType qType = this.modelQType.orElse(this.modelDType);
        if (qType != this.modelDType) {
            logger.info("Quantizing model with {} - Please hold...", qType);
        }

        TransformerBlock[] blocks = new TransformerBlock[this.c.dctx().numberOfLayers];
        IntStream.range(this.c.dctx().layerStart, this.c.dctx().layerEnd).parallel().forEach(layer -> {
            String layerPrefix = "model.layers." + layer + ".";
            String attnPrefix = layerPrefix + "self_attn.";
            String mlpPrefix = layerPrefix + "mlp.";

            CausalSelfAttention attention = new CausalSelfAttention(
                this,
                layer,
                this.weights.load(attnPrefix + "q_proj.weight", this.c.dctx(), true, false).quantize(qType),
                this.weights.load(attnPrefix + "k_proj.weight", this.c.dctx(), true, false).quantize(qType),
                this.weights.load(attnPrefix + "v_proj.weight", this.c.dctx(), true, false).quantize(qType),
                this.weights.load(attnPrefix + "o_proj.weight", this.c.dctx(), false, true).quantize(qType)
            );

            // NOTE(review): the 1.0f argument is presumably Gemma's weight offset for RMSNorm;
            // confirm against the RMSNorm constructor.
            RMSNorm inputNorm = new RMSNorm(this, this.weights.load(layerPrefix + "input_layernorm.weight").quantize(qType), 1.0f);
            RMSNorm postAttentionNorm = new RMSNorm(this, this.weights.load(layerPrefix + "post_attention_layernorm.weight").quantize(qType), 1.0f);

            // Argument order is gate, down, up.
            MLPBlock mlp = new MLPBlock(
                this,
                this.c.activationFunction,
                this.weights.load(mlpPrefix + "gate_proj.weight", this.c.dctx(), true, false).quantize(qType),
                this.weights.load(mlpPrefix + "down_proj.weight", this.c.dctx(), false, true).quantize(qType),
                this.weights.load(mlpPrefix + "up_proj.weight", this.c.dctx(), true, false).quantize(qType)
            );

            // NOTE(review): the slot is the absolute layer number while the array is sized by
            // dctx().numberOfLayers. If layerStart can be non-zero with a smaller numberOfLayers
            // this would index out of bounds; verify against the distributed context.
            blocks[layer] = new TransformerBlock(this, layer, inputNorm, attention, postAttentionNorm, mlp);
        });
        return blocks;
    }

    /**
     * Builds the input embedding function.
     *
     * <p>The returned function copies row {@code tokenIndex} of the embedding table into a fresh
     * dense tensor of length {@code c.embeddingLength}, converting the row to the dense tensor's
     * dtype when the two differ, and then scales it by the embedding scaling factor. Its second
     * argument is not used here.
     *
     * <p>Declared {@code public} (wider than the overridden method) to stay compatible with
     * existing callers of this class.
     */
    @Override
    public EmbedInput loadInputWeights() {
        loadTokenEmbeddingsOnce();
        return (tokenIndex, ignored) -> {
            int width = this.c.embeddingLength;
            AbstractTensor embedding = makeDenseTensor(width);
            AbstractTensor row = this.wte.slice(true, tokenIndex);
            if (this.wte.dType() != embedding.dType()) {
                row = TensorOperationsProvider.get().quantize(row, embedding.dType(), 0, width);
            }
            embedding.copyFrom(row, 0, 0, width);
            TensorOperationsProvider.get().scale(this.embeddingScalingFactor, embedding, 0, width);
            return embedding;
        };
    }

    /**
     * Builds the output sampling weights: the final {@code model.norm.weight} RMS norm plus the
     * shared embedding table, which doubles as the logits projection.
     */
    @Override
    protected SampleOutput loadOutputWeights() {
        DType qType = this.modelQType.orElse(this.modelDType);
        loadTokenEmbeddingsOnce();
        final RMSNorm outputNorm = new RMSNorm(this, this.weights.load("model.norm.weight").quantize(qType), 1.0f);
        return new SampleOutput() {
            @Override
            public LayerNorm getOutputLayerNorm() {
                return outputNorm;
            }

            @Override
            public AbstractTensor getOutputLogitsWeights() {
                // Qualified this: read the enclosing model's field at call time.
                return GemmaModel.this.wte;
            }
        };
    }

    /**
     * Loads {@code model.embed_tokens.weight}, quantized to the working dtype, into {@link #wte}
     * the first time it is needed; later calls leave the already-loaded tensor in place.
     */
    private void loadTokenEmbeddingsOnce() {
        if (this.wte == null) {
            this.wte = this.weights.load("model.embed_tokens.weight").quantize(this.workingDType);
        }
    }
}
