/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.model.bert;

import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.CausalSelfAttention;
import com.github.tjake.jlama.model.LayerNorm;
import com.github.tjake.jlama.model.MLPBlock;
import com.github.tjake.jlama.model.ModelSupport;
import com.github.tjake.jlama.model.TransformerBlock;
import com.github.tjake.jlama.model.functions.ClassifyOutput;
import com.github.tjake.jlama.model.functions.EmbedInput;
import com.github.tjake.jlama.model.functions.PoolingLayer;
import com.github.tjake.jlama.model.functions.SampleOutput;
import com.github.tjake.jlama.safetensors.Config;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.WeightLoader;
import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer;
import com.github.tjake.jlama.tensor.AbstractTensor;
import java.util.Arrays;
import java.util.NoSuchElementException;
import java.util.Optional;

/**
 * BERT encoder model: wires BERT-named weights into the generic
 * {@link AbstractModel} building blocks (input embedding, transformer blocks,
 * pooling layer and optional classifier head).
 *
 * <p>Weight names are resolved both with and without a leading {@code "bert."}
 * prefix, see {@link #loadWeight(String)}.
 */
public class BertModel extends AbstractModel {
    /** Name prefixes tried, in order, when resolving a weight name. */
    private static final String[] WEIGHT_NAME_PREFIXES = new String[]{"", "bert."};

    /**
     * Creates a model that runs with {@link AbstractModel.InferenceType#FORWARD_PASS}.
     * All arguments are passed unchanged to the {@link AbstractModel} constructor.
     */
    public BertModel(Config c, WeightLoader w, Tokenizer tokenizer, DType workingDType, DType workingQType, Optional<DType> modelQType) {
        super(AbstractModel.InferenceType.FORWARD_PASS, c, w, tokenizer, workingDType, workingQType, modelQType);
    }

    /**
     * Creates a model with an explicit inference type.
     * All arguments are passed unchanged to the {@link AbstractModel} constructor.
     */
    public BertModel(AbstractModel.InferenceType inferenceType, Config c, WeightLoader w, Tokenizer tokenizer, DType workingDType, DType workingQType, Optional<DType> modelQType) {
        super(inferenceType, c, w, tokenizer, workingDType, workingQType, modelQType);
    }

    /**
     * Loads a weight tensor by name, trying each known prefix in order and
     * returning the first one the weight loader reports as present.
     *
     * @param name weight name without prefix, e.g. {@code "pooler.dense.weight"}
     * @return the loaded tensor
     * @throws NoSuchElementException if the name is not present under any prefix
     */
    protected AbstractTensor loadWeight(String name) {
        for (String prefix : WEIGHT_NAME_PREFIXES) {
            String key = prefix + name;
            if (this.weights.isWeightPresent(key)) {
                return this.weights.load(key);
            }
        }
        throw new NoSuchElementException(Arrays.toString(WEIGHT_NAME_PREFIXES) + " " + name + " not found in weights");
    }

    /** @return {@link ModelSupport.ModelType#BERT} */
    @Override
    public ModelSupport.ModelType getModelType() {
        return ModelSupport.ModelType.BERT;
    }

    /**
     * Builds the input embedding function. For every component of the embedding
     * it sums the word, token-type and position embeddings and then applies the
     * embedding layer norm.
     *
     * @return a function mapping (token id, position) to the normalized embedding
     */
    @Override
    protected EmbedInput loadInputWeights() {
        AbstractTensor wordEmbeddings = this.loadWeight("embeddings.word_embeddings.weight");
        AbstractTensor tokenTypeEmbeddings = this.loadWeight("embeddings.token_type_embeddings.weight");
        AbstractTensor positionEmbeddings = this.loadWeight("embeddings.position_embeddings.weight");
        LayerNorm inputLayerNorm = new LayerNorm(this, this.loadWeight("embeddings.LayerNorm.bias"), this.loadWeight("embeddings.LayerNorm.weight"));
        return (inputToken, position) -> {
            int length = this.c.embeddingLength;
            AbstractTensor embedding = this.makeDenseTensor(length);
            try {
                for (int i = 0; i < length; i++) {
                    // Row 0 of the token-type table is always used
                    // (presumably single-segment input - confirm against callers).
                    float v = wordEmbeddings.get(inputToken, i) + tokenTypeEmbeddings.get(0, i) + positionEmbeddings.get(position, i);
                    embedding.set(v, 0, i);
                }
                return inputLayerNorm.forward(embedding);
            } finally {
                // Release the scratch tensor even if filling or normalizing throws.
                embedding.close();
            }
        };
    }

    /**
     * Loads attention, MLP and layer-norm weights for every layer in
     * {@code [layerStart, layerEnd)} of the distributed context and assembles
     * one {@link TransformerBlock} per layer, stored at the layer's index.
     *
     * @return the block array; entries outside the loaded layer range stay {@code null}
     */
    @Override
    protected TransformerBlock[] loadTransformerBlockWeights() {
        // NOTE(review): the array is sized by embeddingSegmentLength but indexed by
        // layer number - confirm this is intended and always >= layerEnd.
        TransformerBlock[] transformerBlocks = new TransformerBlock[this.c.dctx().embeddingSegmentLength];
        for (int i = this.c.dctx().layerStart; i < this.c.dctx().layerEnd; ++i) {
            String layerPrefix = "encoder.layer." + i + ".";
            String attentionPrefix = layerPrefix + "attention.";

            AbstractTensor keyBias = this.loadWeight(attentionPrefix + "self.key.bias");
            AbstractTensor keyWeight = this.loadWeight(attentionPrefix + "self.key.weight");
            AbstractTensor queryBias = this.loadWeight(attentionPrefix + "self.query.bias");
            AbstractTensor queryWeight = this.loadWeight(attentionPrefix + "self.query.weight");
            AbstractTensor valueBias = this.loadWeight(attentionPrefix + "self.value.bias");
            AbstractTensor valueWeight = this.loadWeight(attentionPrefix + "self.value.weight");
            AbstractTensor outputBias = this.loadWeight(attentionPrefix + "output.dense.bias");
            AbstractTensor outputWeight = this.loadWeight(attentionPrefix + "output.dense.weight");
            CausalSelfAttention attention = new CausalSelfAttention((AbstractModel)this, i, keyBias, queryBias, valueBias, keyWeight, queryWeight, valueWeight, outputBias, outputWeight);

            // The MLP weights live directly under the layer prefix, not under "attention.".
            MLPBlock mlpBlock = new MLPBlock(this, this.c.activationFunction, this.loadWeight(layerPrefix + "intermediate.dense.bias"), this.loadWeight(layerPrefix + "intermediate.dense.weight"), this.loadWeight(layerPrefix + "output.dense.bias"), this.loadWeight(layerPrefix + "output.dense.weight"));

            LayerNorm postAttentionNorm = new LayerNorm(this, this.loadWeight(layerPrefix + "attention.output.LayerNorm.bias"), this.loadWeight(layerPrefix + "attention.output.LayerNorm.weight"));
            LayerNorm postMlpNorm = new LayerNorm(this, this.loadWeight(layerPrefix + "output.LayerNorm.bias"), this.loadWeight(layerPrefix + "output.LayerNorm.weight"));

            transformerBlocks[i] = new TransformerBlock((AbstractModel)this, i, attention, postAttentionNorm, mlpBlock, postMlpNorm);
        }
        return transformerBlocks;
    }

    /**
     * This model does not provide sampling output weights.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    protected SampleOutput loadOutputWeights() {
        throw new UnsupportedOperationException();
    }

    /**
     * Loads the {@code pooler.dense} weight and bias.
     *
     * @return a pooling layer exposing the loaded weight and a present bias
     */
    @Override
    protected PoolingLayer loadPoolingWeights() {
        final AbstractTensor poolerDenseWeight = this.loadWeight("pooler.dense.weight");
        final AbstractTensor poolerDenseBias = this.loadWeight("pooler.dense.bias");
        return new PoolingLayer() {

            @Override
            public AbstractTensor getPoolingWeights() {
                return poolerDenseWeight;
            }

            @Override
            public Optional<AbstractTensor> getPoolingBias() {
                return Optional.of(poolerDenseBias);
            }
        };
    }

    /**
     * Loads the {@code classifier} weight and bias when the config reports a
     * classifier model.
     *
     * @return a classifier head exposing the loaded weight and a present bias
     * @throws UnsupportedOperationException if the config is not a classifier
     */
    @Override
    protected ClassifyOutput loadClassifierWeights() {
        if (!this.c.isClassifier()) {
            throw new UnsupportedOperationException("Classification not supported by this model");
        }
        final AbstractTensor classifierWeight = this.loadWeight("classifier.weight");
        final AbstractTensor classifierBias = this.loadWeight("classifier.bias");
        return new ClassifyOutput() {

            @Override
            public AbstractTensor getClassificationWeights() {
                return classifierWeight;
            }

            @Override
            public Optional<AbstractTensor> getClassificationBias() {
                return Optional.of(classifierBias);
            }
        };
    }
}

