package ai.vespa.embedding.huggingface;

import ai.vespa.embedding.PoolingStrategy;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
import com.yahoo.language.huggingface.Encoding;
import com.yahoo.language.huggingface.HuggingFaceTokenizer;
import com.yahoo.language.huggingface.ModelInfo;
import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.ObjectMethods;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.BitSet;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;

@Beta
/* loaded from: input_file:ai/vespa/embedding/huggingface/HuggingFaceEmbedder.class */
public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
    /** Class logger; the constructor uses it to report tokenizer model info at FINE level. */
    private static final Logger log = Logger.getLogger(HuggingFaceEmbedder.class.getName());
    /** Receives the sequence-length and embedding-latency samples taken while embedding. */
    private final Embedder.Runtime runtime;
    /** Name of the model input that receives the token ids (from config {@code transformerInputIds}). */
    private final String inputIdsName;
    /** Name of the model input that receives the attention mask (from config {@code transformerAttentionMask}). */
    private final String attentionMaskName;
    /** Name of the token-type-ids model input, or the empty string when the model has fewer than three inputs. */
    private final String tokenTypeIdsName;
    /** Name of the model output that is read after evaluation (from config {@code transformerOutput}). */
    private final String outputName;
    /** Whether the pooled sentence embedding is L2-normalized before it is returned or binarized. */
    private final boolean normalize;
    /** Tokenizer used to encode the input text; closed in {@link #deconstruct()}. */
    private final HuggingFaceTokenizer tokenizer;
    /** Evaluator for the transformer model; closed in {@link #deconstruct()}. */
    private final OnnxEvaluator evaluator;
    /** Strategy that turns the model output and the attention mask into one sentence embedding. */
    private final PoolingStrategy poolingStrategy;
    /** Text prepended (followed by a space) when the context destination starts with "query"; may be null or empty. */
    private final String prependQuery;
    /** Text prepended (followed by a space) when the query prefix does not apply; may be null or empty. */
    private final String prependDocument;

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:ai/vespa/embedding/huggingface/HuggingFaceEmbedder$HFEmbedderCacheKey.class */
    /**
     * Key under which an evaluation result is stored in the per-request cache of an
     * {@code Embedder.Context}. Two keys are equal when both components are equal, which is what
     * allows a repeated embedding of the same value by the same embedder to reuse the cached result.
     *
     * <p>Declared as a real record: equality, hash code and string form are generated by the
     * compiler, and a nested record is implicitly static and final.
     *
     * @param embedderId    id of the embedder that produced the cached value
     * @param embeddedValue the value that was embedded (the text handed to the model)
     */
    public record HFEmbedderCacheKey(String embedderId, Object embeddedValue) {
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:ai/vespa/embedding/huggingface/HuggingFaceEmbedder$HFEmbeddingResult.class */
    /**
     * Result of one model evaluation, kept so that several embed calls for the same input can share
     * it through the context cache.
     *
     * <p>Declared as a real record: equality, hash code and string form are generated by the
     * compiler, and a nested record is implicitly static and final.
     *
     * @param output        the model output tensor; the evaluation path only stores outputs that
     *                      have three dimensions ([batch, sequence, embedding])
     * @param attentionMask the attention mask tensor that was fed to the model
     * @param embedderId    id of the embedder that produced this result
     */
    public record HFEmbeddingResult(IndexedTensor output, Tensor attentionMask, String embedderId) {
    }

    /**
     * Creates the embedder: copies names and flags from the config, configures the tokenizer,
     * creates the evaluator for the transformer model and finally checks that the model exposes
     * the configured inputs and output.
     *
     * @param onnxRuntime               factory for the model evaluator
     * @param runtime                   receiver of sequence-length and latency samples
     * @param huggingFaceEmbedderConfig configuration of this embedder
     * @throws IllegalArgumentException if the model lacks one of the configured input/output names
     */
    @Inject
    public HuggingFaceEmbedder(OnnxRuntime onnxRuntime, Embedder.Runtime runtime, HuggingFaceEmbedderConfig huggingFaceEmbedderConfig) {
        HuggingFaceEmbedderConfig config = huggingFaceEmbedderConfig;
        this.runtime = runtime;
        this.inputIdsName = config.transformerInputIds();
        this.attentionMaskName = config.transformerAttentionMask();
        this.outputName = config.transformerOutput();
        this.normalize = config.normalize();
        this.prependQuery = config.prependQuery();
        this.prependDocument = config.prependDocument();

        // Tokenizer: special tokens on, padding off, truncation decided below.
        Path tokenizerPath = Paths.get(config.tokenizerPath().toString());
        HuggingFaceTokenizer.Builder tokenizerBuilder = new HuggingFaceTokenizer.Builder()
                .addSpecialTokens(true)
                .addDefaultModel(tokenizerPath)
                .setPadding(false);
        ModelInfo tokenizerInfo = HuggingFaceTokenizer.getModelInfo(tokenizerPath);
        log.fine(() -> "'%s' has info '%s'".formatted(tokenizerPath, tokenizerInfo));
        boolean lacksUsableTruncation = tokenizerInfo.maxLength() == -1
                || tokenizerInfo.truncation() != ModelInfo.TruncationStrategy.LONGEST_FIRST;
        if (lacksUsableTruncation) {
            // Force truncation; use the tokenizer's own limit only when it is positive and does
            // not exceed the configured maximum, otherwise fall back to the configured maximum.
            var tokenizerLimit = tokenizerInfo.maxLength();
            var configuredLimit = config.transformerMaxTokens();
            var maxLength = (tokenizerLimit > 0 && tokenizerLimit <= configuredLimit) ? tokenizerLimit : configuredLimit;
            tokenizerBuilder.setTruncation(true).setMaxLength(maxLength);
        }
        this.tokenizer = tokenizerBuilder.build();
        this.poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString());

        // Evaluator: a GPU device is only requested when the configured device number is non-negative.
        OnnxEvaluatorOptions options = new OnnxEvaluatorOptions();
        var gpuDevice = config.transformerGpuDevice();
        if (gpuDevice >= 0) {
            options.setGpuDevice(gpuDevice);
        }
        options.setExecutionMode(config.transformerExecutionMode().toString());
        options.setThreads(config.transformerInterOpThreads(), config.transformerIntraOpThreads());
        this.evaluator = onnxRuntime.evaluatorOf(config.transformerModel().toString(), options);

        this.tokenTypeIdsName = detectTokenTypeIds(config, this.evaluator);
        validateModel();
    }

    /**
     * Decides which input name to use for token type ids.
     *
     * @return the configured token-type-ids input name when the model declares at least three
     *         inputs, otherwise the empty string
     */
    private static String detectTokenTypeIds(HuggingFaceEmbedderConfig huggingFaceEmbedderConfig, OnnxEvaluator onnxEvaluator) {
        boolean modelHasThirdInput = onnxEvaluator.getInputInfo().size() >= 3;
        if (modelHasThirdInput) {
            return huggingFaceEmbedderConfig.transformerTokenTypeIds();
        }
        return "";
    }

    /**
     * Verifies that the loaded model declares every configured input name and the configured
     * output name. The token-type-ids input is only checked when one is in use.
     *
     * @throws IllegalArgumentException if any required name is missing from the model
     */
    private void validateModel() {
        Map<String, TensorType> modelInputs = evaluator.getInputInfo();
        validateName(modelInputs, inputIdsName, "input");
        validateName(modelInputs, attentionMaskName, "input");
        boolean usesTokenTypeIds = !tokenTypeIdsName.isEmpty();
        if (usesTokenTypeIds) {
            validateName(modelInputs, tokenTypeIdsName, "input");
        }
        Map<String, TensorType> modelOutputs = evaluator.getOutputInfo();
        validateName(modelOutputs, outputName, "output");
    }

    /**
     * Checks that {@code str} is one of the keys of {@code map}.
     *
     * @param map  the names the model declares, keyed by name
     * @param str  the name that must be present
     * @param str2 what kind of name this is ("input" or "output"); only used in the error message
     * @throws IllegalArgumentException if the name is absent; the message lists the available names
     */
    private static void validateName(Map<String, TensorType> map, String str, String str2) {
        if (map.containsKey(str)) {
            return;
        }
        String availableNames = String.join(",", map.keySet());
        throw new IllegalArgumentException("Model does not contain required " + str2 + ": '" + str + "'. Model contains: " + availableNames);
    }

    /**
     * Tokenizes the text by delegating to the tokenizer, and reports the number of tokens and the
     * elapsed time (in milliseconds) to the runtime.
     *
     * @param str     the text to tokenize
     * @param context the embedding context, forwarded to the tokenizer and to the runtime samples
     * @return the token ids produced by the tokenizer
     */
    public List<Integer> embed(String str, Embedder.Context context) {
        final long startNanos = System.nanoTime();
        final List<Integer> tokenIds = tokenizer.embed(str, context);
        runtime.sampleSequenceLength(tokenIds.size(), context);
        final long elapsedNanos = System.nanoTime() - startNanos;
        runtime.sampleEmbeddingLatency(elapsedNanos / 1000000.0d, context);
        return tokenIds;
    }

    /** Releases the resources held by this component: first the model evaluator, then the tokenizer. */
    public void deconstruct() {
        evaluator.close();
        tokenizer.close();
    }

    /**
     * Embeds the text into a tensor of the requested type.
     *
     * <p>The target type must have exactly one dimension and that dimension must be indexed. When
     * the target value type is INT8 the result is the bit-packed (binarized) embedding; otherwise
     * it is the pooled sentence embedding, normalized when normalization is enabled.
     *
     * @param str        the text to embed; a configured prefix may be prepended first
     * @param context    the embedding context, also used as the cache for model evaluations
     * @param tensorType the type of the tensor to produce
     * @return the embedding as a tensor of {@code tensorType}
     * @throws IllegalArgumentException if the target type is not a single indexed dimension
     */
    public Tensor embed(String str, Embedder.Context context, TensorType tensorType) {
        if (tensorType.dimensions().size() != 1) {
            throw new IllegalArgumentException("Error in embedding to type '" + tensorType + "': should only have one dimension.");
        }
        boolean indexed = tensorType.dimensions().get(0).isIndexed();
        if (!indexed) {
            throw new IllegalArgumentException("Error in embedding to type '" + tensorType + "': dimension should be indexed.");
        }

        String textWithInstruction = prependInstruction(str, context);
        HFEmbeddingResult result = lookupOrEvaluate(context, textWithInstruction);
        Tensor tokenEmbeddings = result.output();
        if (tensorType.valueType() == TensorType.Value.INT8) {
            return binaryQuantization(result, tensorType);
        }

        Tensor pooled = this.poolingStrategy.toSentenceEmbedding(tensorType, tokenEmbeddings, result.attentionMask());
        if (this.normalize) {
            return normalize(pooled, tensorType);
        }
        return pooled;
    }

    /**
     * Prepends the configured instruction text, separated by a single space, to the input.
     *
     * <p>The query prefix is used when it is set (non-null, non-empty) and the context destination
     * starts with "query". In every other case the document prefix is used when it is set.
     * When neither applies the text is returned unchanged.
     *
     * @param str     the text to embed
     * @param context the embedding context; its destination selects the prefix
     * @return the text, possibly with a prefix and a space in front
     */
    String prependInstruction(String str, Embedder.Context context) {
        boolean useQueryPrefix = prependQuery != null
                && !prependQuery.isEmpty()
                && context.getDestination().startsWith("query");
        if (useQueryPrefix) {
            return prependQuery + " " + str;
        }
        boolean useDocumentPrefix = prependDocument != null && !prependDocument.isEmpty();
        if (useDocumentPrefix) {
            return prependDocument + " " + str;
        }
        return str;
    }

    /**
     * Scales a one-dimensional tensor to unit Euclidean length.
     *
     * <p>Reads as many cells as the first dimension of {@code tensorType} declares, divides each by
     * the square root of the sum of their squares, and writes the results into a new tensor of
     * {@code tensorType}. No special handling is done for an all-zero input.
     *
     * @param tensor     the tensor to normalize
     * @param tensorType the type of the result; its first dimension must have a size
     * @return a new tensor holding the scaled values
     */
    Tensor normalize(Tensor tensor, TensorType tensorType) {
        Tensor.Builder normalized = Tensor.Builder.of(tensorType);
        long dimensionSize = tensorType.dimensions().get(0).size().get();

        // First pass: accumulate the squared length.
        double sumOfSquares = 0.0d;
        for (int i = 0; i < dimensionSize; i++) {
            double value = tensor.get(TensorAddress.of(new int[]{i}));
            sumOfSquares += value * value;
        }
        double length = Math.sqrt(sumOfSquares);

        // Second pass: write every value divided by the length.
        for (int i = 0; i < dimensionSize; i++) {
            double scaled = tensor.get(TensorAddress.of(new int[]{i})) / length;
            normalized.cell(scaled, new long[]{i});
        }
        return normalized.build();
    }

    /**
     * Returns the evaluation result for the text, running the model only when the context cache
     * has no value for the key built from the context's embedder id and the text.
     *
     * @param context the embedding context that owns the cache
     * @param str     the text to evaluate
     * @return the cached or freshly computed result
     */
    private HFEmbeddingResult lookupOrEvaluate(Embedder.Context context, String str) {
        HFEmbedderCacheKey cacheKey = new HFEmbedderCacheKey(context.getEmbedderId(), str);
        return (HFEmbeddingResult) context.computeCachedValueIfAbsent(cacheKey, () -> evaluate(context, str));
    }

    /**
     * Tokenizes the text, runs the model and wraps the output together with the attention mask.
     *
     * <p>The token ids, attention mask and (when a token-type-ids input is in use) type ids are
     * each turned into a tensor over dimension "d1" and expanded with a leading "d0" dimension
     * before being handed to the evaluator. The number of tokens and the elapsed time in
     * milliseconds are reported to the runtime.
     *
     * @param context the embedding context; supplies the language and the embedder id
     * @param str     the text to evaluate
     * @return the model output, the attention mask and the embedder id
     * @throws IllegalArgumentException if the model output does not have exactly three dimensions
     */
    private HFEmbeddingResult evaluate(Embedder.Context context, String str) {
        long startNanos = System.nanoTime();
        Encoding encoding = this.tokenizer.encode(str, context.getLanguage());
        this.runtime.sampleSequenceLength(encoding.ids().size(), context);

        IndexedTensor inputIds = createTensorRepresentation(encoding.ids(), "d1");
        IndexedTensor attentionMask = createTensorRepresentation(encoding.attentionMask(), "d1");
        // Null exactly when no token-type-ids input is in use.
        IndexedTensor tokenTypeIds = this.tokenTypeIdsName.isEmpty() ? null : createTensorRepresentation(encoding.typeIds(), "d1");

        Map<String, Tensor> inputs;
        if (this.tokenTypeIdsName.isEmpty() || tokenTypeIds.isEmpty()) {
            inputs = Map.of(this.inputIdsName, inputIds.expand("d0"),
                            this.attentionMaskName, attentionMask.expand("d0"));
        } else {
            inputs = Map.of(this.inputIdsName, inputIds.expand("d0"),
                            this.attentionMaskName, attentionMask.expand("d0"),
                            this.tokenTypeIdsName, tokenTypeIds.expand("d0"));
        }

        // The evaluator hands back general tensors; the pooling and quantization code needs the
        // indexed view, so the downcast has to be explicit.
        IndexedTensor tokenEmbeddings = (IndexedTensor) this.evaluator.evaluate(inputs).get(this.outputName);
        long[] shape = tokenEmbeddings.shape();
        if (shape.length != 3) {
            throw new IllegalArgumentException("Expected 3 output dimensions for output name '" + this.outputName + "': [batch, sequence, embedding], got " + shape.length);
        }
        this.runtime.sampleEmbeddingLatency((System.nanoTime() - startNanos) / 1000000.0d, context);
        return new HFEmbeddingResult(tokenEmbeddings, attentionMask, context.getEmbedderId());
    }

    /**
     * Produces the bit-packed embedding for an INT8 target type.
     *
     * <p>Each int8 cell of the target holds eight bits, so eight times the target size float
     * values are pooled from the model output (into a float tensor over the same dimension name),
     * optionally normalized, and then binarized.
     *
     * @param hFEmbeddingResult the model evaluation to quantize
     * @param tensorType        the INT8 target type; its first dimension must have a size
     * @return the packed tensor of {@code tensorType}
     * @throws IllegalArgumentException if the model output has fewer embedding values than the
     *                                  target needs (eight per int8 cell)
     */
    private Tensor binaryQuantization(HFEmbeddingResult hFEmbeddingResult, TensorType tensorType) {
        long outputDimensions = hFEmbeddingResult.output().shape()[2];
        long targetDimensions = tensorType.dimensions().get(0).size().get();
        long floatDimensions = 8 * targetDimensions;
        if (floatDimensions > outputDimensions) {
            // Report the requested number of int8 cells, not the exception object itself.
            throw new IllegalArgumentException("Cannot pack " + outputDimensions + " into " + targetDimensions + " int8's");
        }
        String dimensionName = tensorType.indexedSubtype().dimensions().get(0).name();
        TensorType floatType = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimensionName, floatDimensions).build();
        Tensor pooled = this.poolingStrategy.toSentenceEmbedding(floatType, hFEmbeddingResult.output(), hFEmbeddingResult.attentionMask());
        Tensor toPack = this.normalize ? normalize(pooled, floatType) : pooled;
        return binarize((IndexedTensor) toPack, tensorType);
    }

    /**
     * Packs the signs of the input values into bytes, eight values per byte.
     *
     * <p>A value greater than zero becomes a 1 bit, anything else a 0 bit. Within each group of
     * eight consecutive values the first value maps to the most significant bit (bit 7) and the
     * last to bit 0. One cell is written per complete group; trailing values that do not fill a
     * whole group are not written.
     *
     * @param indexedTensor the values to binarize
     * @param tensorType    the type of the tensor to build
     * @return a tensor of {@code tensorType} with one cell per packed byte
     */
    public static Tensor binarize(IndexedTensor indexedTensor, TensorType tensorType) {
        Tensor.Builder builder = Tensor.Builder.of(tensorType);
        BitSet bits = new BitSet(8);
        int byteIndex = 0;
        for (int i = 0; i < indexedTensor.sizeAsInt(); i++) {
            int bitIndex = 7 - (i % 8);
            if (indexedTensor.get(i) > 0.0d) {
                bits.set(bitIndex);
            } else {
                bits.clear(bitIndex);
            }
            if ((i + 1) % 8 == 0) {
                // toByteArray() returns an empty array when no bit is set, which stands for 0.
                byte[] bytes = bits.toByteArray();
                byte packed = bytes.length == 0 ? (byte) 0 : bytes[0];
                builder.cell(TensorAddress.of(new int[]{byteIndex}), packed);
                byteIndex++;
                bits = new BitSet(8);
            }
        }
        return builder.build();
    }

    /**
     * Copies a list of numbers into a one-dimensional float tensor.
     *
     * @param list the values, in cell order
     * @param str  the name of the single indexed dimension; its size is the list size
     * @return a float tensor whose cell {@code i} holds {@code list.get(i)}
     */
    private IndexedTensor createTensorRepresentation(List<Long> list, String str) {
        TensorType vectorType = new TensorType.Builder(TensorType.Value.FLOAT).indexed(str, list.size()).build();
        IndexedTensor.Builder builder = IndexedTensor.Builder.of(vectorType);
        int position = 0;
        for (Long value : list) {
            builder.cell((float) value.longValue(), new long[]{position});
            position++;
        }
        return builder.build();
    }
}
