/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.embedding;

import ai.vespa.embedding.EmbeddingNormalizer;
import ai.vespa.embedding.PoolingStrategy;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import ai.vespa.modelintegration.evaluator.config.OnnxEvaluatorConfig;
import ai.vespa.modelintegration.utils.ModelPathHelper;
import ai.vespa.modelintegration.utils.OnnxExternalDataResolver;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
import com.yahoo.language.huggingface.Encoding;
import com.yahoo.language.huggingface.HuggingFaceTokenizer;
import com.yahoo.language.huggingface.ModelInfo;
import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.Tensors;
import java.nio.file.Path;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;

/**
 * Embedder that tokenizes text with a Hugging Face tokenizer, evaluates an ONNX transformer model on the
 * token ids (and, when the model accepts them, the attention mask and token type ids), and reduces the model
 * output to a single tensor of the requested type using the configured pooling strategy.
 * An int8 target type produces a bit-packed ({@link Tensors#packBits}) embedding.
 */
@Beta
public class HuggingFaceEmbedder
extends AbstractComponent
implements Embedder {
    private static final Logger log = Logger.getLogger(HuggingFaceEmbedder.class.getName());
    private final Embedder.Runtime runtime;
    private final ModelAnalysis analysis;
    private final boolean normalize;
    private final HuggingFaceTokenizer tokenizer;
    private final OnnxEvaluator evaluator;
    private final String prependQuery;
    private final String prependDocument;

    /**
     * Validates that the model exposes the inputs and output named in the config, and that the configured
     * pooling strategy fits the output rank.
     *
     * <p>The model must have 1 to 3 inputs. The first is always the token ids. A second input is taken to be
     * the attention mask and a third the token type ids, each by the name given in the config.</p>
     *
     * @param evaluator the ONNX evaluator to inspect
     * @param config    the embedder config holding input/output names and the pooling strategy
     * @return the resolved input/output names, output rank and pooling strategy
     * @throws IllegalArgumentException if a configured name is missing from the model, the model has more than
     *                                  3 inputs, the output is not 2- or 3-dimensional, or the pooling strategy
     *                                  does not match the output rank ('none' iff 2 dimensions)
     */
    static ModelAnalysis analyze(OnnxEvaluator evaluator, HuggingFaceEmbedderConfig config) {
        Map<String, TensorType> inputs = evaluator.getInputInfo();
        int numInputs = inputs.size();
        String inputIdsName = config.transformerInputIds();
        String attentionMaskName = "";
        String tokenTypeIdsName = "";
        validateName(inputs, inputIdsName, "input");
        if (numInputs > 1) {
            attentionMaskName = config.transformerAttentionMask();
            validateName(inputs, attentionMaskName, "input");
            if (numInputs > 2) {
                tokenTypeIdsName = config.transformerTokenTypeIds();
                validateName(inputs, tokenTypeIdsName, "input");
                if (numInputs > 3) {
                    throw new IllegalArgumentException("Model needs more than 3 inputs: " + String.valueOf(inputs.keySet()));
                }
            }
        }
        Map<String, TensorType> outputs = evaluator.getOutputInfo();
        String outputName = config.transformerOutput();
        validateName(outputs, outputName, "output");
        int outputDimensions = outputs.get(outputName).dimensions().size();
        PoolingStrategy poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString());
        if (outputDimensions == 2) {
            if (poolingStrategy != PoolingStrategy.NONE) {
                throw new IllegalArgumentException("Expected pooling-strategy 'none' with 2 output dimensions");
            }
        } else if (outputDimensions == 3) {
            if (poolingStrategy == PoolingStrategy.NONE) {
                throw new IllegalArgumentException("Unexpected pooling-strategy 'none' with 3 output dimensions");
            }
        } else {
            throw new IllegalArgumentException("Expected 2 or 3 output dimensions for '" + outputName + "', but got type: " + String.valueOf(outputs.get(outputName)));
        }
        return new ModelAnalysis(numInputs, inputIdsName, attentionMaskName, tokenTypeIdsName, outputName, outputDimensions, poolingStrategy);
    }

    /**
     * Creates the ONNX evaluator and the tokenizer described by the given config.
     *
     * <p>If validation or tokenizer setup fails after the evaluator has been created, the evaluator is closed
     * before the exception propagates, so a failed construction does not leak it.</p>
     */
    @Inject
    public HuggingFaceEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, HuggingFaceEmbedderConfig embedderConfig, OnnxEvaluatorConfig onnxConfig, ModelPathHelper modelHelper) {
        this.runtime = runtime;
        this.normalize = embedderConfig.normalize();
        this.prependQuery = embedderConfig.prependQuery();
        this.prependDocument = embedderConfig.prependDocument();
        OnnxExternalDataResolver resolver = new OnnxExternalDataResolver(modelHelper);
        String modelPath = resolver.resolveOnnxModel(embedderConfig.transformerModelReference()).toString();
        OnnxEvaluatorOptions onnxOpts = OnnxEvaluatorOptions.of(onnxConfig);
        this.evaluator = onnx.evaluatorOf(modelPath, onnxOpts);
        try {
            this.analysis = analyze(this.evaluator, embedderConfig);
            this.tokenizer = buildTokenizer(embedderConfig, modelHelper);
        } catch (RuntimeException e) {
            // The evaluator was already created; release it instead of leaking it on failed construction.
            try {
                this.evaluator.close();
            } catch (RuntimeException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
    }

    /**
     * Builds the tokenizer: special tokens are added and padding is disabled. Truncation is enabled explicitly
     * when the tokenizer model reports no max length (-1) or a truncation strategy other than LONGEST_FIRST.
     * In that case the max length is the model's own max length when it is positive and no larger than
     * {@code transformerMaxTokens}, otherwise {@code transformerMaxTokens}.
     */
    private static HuggingFaceTokenizer buildTokenizer(HuggingFaceEmbedderConfig embedderConfig, ModelPathHelper modelHelper) {
        Path tokenizerPath = modelHelper.getModelPathResolvingIfNecessary(embedderConfig.tokenizerPathReference());
        HuggingFaceTokenizer.Builder builder = new HuggingFaceTokenizer.Builder().addSpecialTokens(true).addDefaultModel(tokenizerPath).setPadding(false);
        ModelInfo info = HuggingFaceTokenizer.getModelInfo(tokenizerPath);
        log.fine(() -> "'%s' has info '%s'".formatted(tokenizerPath, info));
        if (info.maxLength() == -1 || info.truncation() != ModelInfo.TruncationStrategy.LONGEST_FIRST) {
            int maxLength = info.maxLength() > 0 && info.maxLength() <= embedderConfig.transformerMaxTokens() ? info.maxLength() : embedderConfig.transformerMaxTokens();
            builder.setTruncation(true).setMaxLength(maxLength);
        }
        return builder.build();
    }

    /** Throws if {@code name} is not a key of {@code types}; {@code type} ("input"/"output") is used in the message. */
    private static void validateName(Map<String, TensorType> types, String name, String type) {
        if (!types.containsKey(name)) {
            throw new IllegalArgumentException("Model does not contain required " + type + ": '" + name + "'. Model contains: " + String.join(",", types.keySet()));
        }
    }

    /**
     * Returns the token ids the tokenizer produces for {@code s}, recording sequence length and latency.
     */
    @Override
    public List<Integer> embed(String s, Embedder.Context context) {
        long start = System.nanoTime();
        List<Integer> tokens = this.tokenizer.embed(s, context);
        this.runtime.sampleSequenceLength((long)tokens.size(), context);
        this.runtime.sampleEmbeddingLatency((double)(System.nanoTime() - start) / 1000000.0, context);
        return tokens;
    }

    /** Releases the ONNX evaluator and the tokenizer; the tokenizer is closed even if closing the evaluator fails. */
    @Override
    public void deconstruct() {
        try {
            this.evaluator.close();
        } finally {
            this.tokenizer.close();
        }
    }

    /**
     * Embeds {@code text}, after applying any configured query/document prefix, into a tensor of {@code targetType}.
     *
     * <p>For an int8 target the pooled embedding is bit-packed. For other value types the pooled embedding is
     * returned, normalized if the config requests it.</p>
     *
     * @throws IllegalArgumentException if {@code targetType} does not have exactly one indexed dimension
     */
    @Override
    public Tensor embed(String text, Embedder.Context context, TensorType targetType) {
        if (targetType.dimensions().size() != 1) {
            throw new IllegalArgumentException("Error in embedding to type '" + String.valueOf(targetType) + "': should only have one dimension.");
        }
        if (!targetType.dimensions().get(0).isIndexed()) {
            throw new IllegalArgumentException("Error in embedding to type '" + String.valueOf(targetType) + "': dimension should be indexed.");
        }
        HFEmbeddingResult embeddingResult = this.lookupOrEvaluate(context, this.prependInstruction(text, context));
        if (targetType.valueType() == TensorType.Value.INT8) {
            return this.binaryQuantization(embeddingResult, targetType);
        }
        Tensor result = this.analysis.poolingStrategy().toSentenceEmbedding(targetType, embeddingResult.output(), embeddingResult.attentionMask());
        return this.normalize ? EmbeddingNormalizer.normalize(result, targetType) : result;
    }

    /**
     * Prepends the configured instruction prefix, separated by a single space.
     *
     * <p>The query prefix is used when it is non-empty and the context destination starts with "query".
     * Otherwise the document prefix is used when it is non-empty. Otherwise the text is returned unchanged.</p>
     *
     * <p>NOTE(review): because of the fall-through, a query destination gets the document prefix when no query
     * prefix is configured — confirm this is intended.</p>
     */
    String prependInstruction(String text, Embedder.Context context) {
        if (this.prependQuery != null && !this.prependQuery.isEmpty() && context.getDestination().startsWith("query")) {
            return this.prependQuery + " " + text;
        }
        if (this.prependDocument != null && !this.prependDocument.isEmpty()) {
            return this.prependDocument + " " + text;
        }
        return text;
    }

    /** Memoizes {@link #evaluate} through the context's value cache, keyed by embedder id and the (prefixed) text. */
    private HFEmbeddingResult lookupOrEvaluate(Embedder.Context context, String text) {
        HFEmbedderCacheKey key = new HFEmbedderCacheKey(context.getEmbedderId(), text);
        return (HFEmbeddingResult)context.computeCachedValueIfAbsent(key, () -> this.evaluate(context, text));
    }

    /**
     * Tokenizes {@code text} and runs the model on a batch of one (dimension "d0") sequence (dimension "d1").
     * Only the inputs the model declares are passed.
     *
     * @return the raw model output together with the attention mask used for pooling
     * @throws IllegalArgumentException if the output rank differs from the one found by {@link #analyze}
     */
    private HFEmbeddingResult evaluate(Embedder.Context context, String text) {
        long start = System.nanoTime();
        Encoding encoding = this.tokenizer.encode(text, context.getLanguage());
        this.runtime.sampleSequenceLength((long)encoding.ids().size(), context);
        Tensor inputSequence = this.createTensorRepresentation(encoding.ids(), "d1").expand("d0");
        // Always built: besides being a possible model input, the mask is needed for pooling.
        Tensor attentionMask = this.createTensorRepresentation(encoding.attentionMask(), "d1").expand("d0");
        Map<String, Tensor> inputs;
        if (this.analysis.useAttentionMask()) {
            if (this.analysis.useTokenTypeIds()) {
                Tensor tokenTypeIds = this.createTensorRepresentation(encoding.typeIds(), "d1").expand("d0");
                inputs = Map.of(this.analysis.inputIdsName(), inputSequence, this.analysis.attentionMaskName(), attentionMask, this.analysis.tokenTypeIdsName(), tokenTypeIds);
            } else {
                inputs = Map.of(this.analysis.inputIdsName(), inputSequence, this.analysis.attentionMaskName(), attentionMask);
            }
        } else {
            inputs = Map.of(this.analysis.inputIdsName(), inputSequence);
        }
        IndexedTensor tokenEmbeddings = (IndexedTensor)this.evaluator.evaluate(inputs).get(this.analysis.outputName());
        long[] resultShape = tokenEmbeddings.shape();
        if (resultShape.length != this.analysis.outputDimensions()) {
            String expectedLayout = this.analysis.outputDimensions() == 2 ? "[batch, embedding]" : "[batch, sequence, embedding]";
            throw new IllegalArgumentException("Expected " + this.analysis.outputDimensions() + " output dimensions for output name '" + this.analysis.outputName() + "': " + expectedLayout + ", got " + resultShape.length);
        }
        this.runtime.sampleEmbeddingLatency((double)(System.nanoTime() - start) / 1000000.0, context);
        return new HFEmbeddingResult(tokenEmbeddings, attentionMask, context.getEmbedderId());
    }

    /**
     * Pools the model output to 8 x (target size) float values, normalizes if configured, and packs them into
     * the int8 {@code targetType} with {@link Tensors#packBits}.
     *
     * @throws IllegalArgumentException if the target dimension has no size, or if 8 x its size exceeds the model's
     *                                  embedding size
     * @throws IllegalStateException    if packing does not produce exactly {@code targetType}
     */
    private Tensor binaryQuantization(HFEmbeddingResult embeddingResult, TensorType targetType) {
        long[] outputShape = embeddingResult.output().shape();
        // The embedding size is the last axis; indexing [2] directly would fail for 2-dimensional outputs,
        // which analyze() accepts (with pooling-strategy 'none').
        long outputDimensions = outputShape[outputShape.length - 1];
        TensorType.Dimension targetDimension = targetType.dimensions().get(0);
        long targetDimensions = targetDimension.size().orElseThrow(() -> new IllegalArgumentException(
                "Error in embedding to type '" + String.valueOf(targetType) + "': dimension must have a size to pack bits into."));
        long targetUnpackagedDimensions = 8L * targetDimensions;
        if (targetUnpackagedDimensions > outputDimensions) {
            throw new IllegalArgumentException("Cannot pack " + outputDimensions + " into " + targetDimensions + " int8's");
        }
        TensorType poolingType = new TensorType.Builder(TensorType.Value.FLOAT).indexed(targetDimension.name(), targetUnpackagedDimensions).build();
        Tensor result = this.analysis.poolingStrategy().toSentenceEmbedding(poolingType, embeddingResult.output(), embeddingResult.attentionMask());
        result = this.normalize ? EmbeddingNormalizer.normalize(result, poolingType) : result;
        Tensor packedResult = Tensors.packBits(result);
        if (!packedResult.type().equals(targetType)) {
            throw new IllegalStateException("Expected pack_bits to produce " + String.valueOf(targetType) + ", but got " + String.valueOf(packedResult.type()));
        }
        return packedResult;
    }

    /** Builds a 1-dimensional float tensor over {@code dimension} whose cells are the given values, in order. */
    private IndexedTensor createTensorRepresentation(List<Long> input, String dimension) {
        int size = input.size();
        TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, (long)size).build();
        IndexedTensor.Builder builder = IndexedTensor.Builder.of(type);
        for (int i = 0; i < size; ++i) {
            builder.cell((float)input.get(i).longValue(), new long[]{i});
        }
        return builder.build();
    }

    /**
     * Result of {@link #analyze}. An empty attention-mask or token-type-ids name means the model does not take
     * that input.
     */
    record ModelAnalysis(int numInputs, String inputIdsName, String attentionMaskName, String tokenTypeIdsName, String outputName, int outputDimensions, PoolingStrategy poolingStrategy) {
        boolean useAttentionMask() {
            return !this.attentionMaskName.isEmpty();
        }

        boolean useTokenTypeIds() {
            return !this.tokenTypeIdsName.isEmpty();
        }
    }

    /** Raw model output plus the attention mask used to produce it, as cached per embedder id and text. */
    protected record HFEmbeddingResult(IndexedTensor output, Tensor attentionMask, String embedderId) {
    }

    /** Cache key for {@link HFEmbeddingResult}s stored in the embedder context. */
    protected record HFEmbedderCacheKey(String embedderId, Object embeddedValue) {
    }
}

