package ai.vespa.embedding;

import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.embedding.SpladeEmbedderConfig;
import com.yahoo.language.huggingface.Encoding;
import com.yahoo.language.huggingface.HuggingFaceTokenizer;
import com.yahoo.language.huggingface.ModelInfo;
import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.DirectIndexedAddress;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Reduce;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;

/**
 * Embedder that turns a text into a sparse, mapped 1-d tensor of term scores.
 *
 * <p>The text is tokenized with a HuggingFace tokenizer, the token ids, attention mask and token type ids
 * are fed to an ONNX model, and the model output is max-reduced per vocabulary entry, passed through
 * {@code log(1 + max(x, 0))}, and filtered by the configured term score threshold. Each surviving vocabulary
 * entry becomes one cell whose label is the decoded token.
 */
@Beta
public class SpladeEmbedder extends AbstractComponent implements Embedder {
    private final Embedder.Runtime runtime;
    private final String inputIdsName;
    private final String attentionMaskName;
    private final String tokenTypeIdsName;
    private final String outputName;
    private final double termScoreThreshold;
    private final boolean useCustomReduce;
    private final HuggingFaceTokenizer tokenizer;
    private final OnnxEvaluator evaluator;

    /**
     * Creates an embedder from injected components, using the hand-written reduce implementation.
     *
     * @param onnxRuntime          runtime used to create the model evaluator
     * @param runtime              receiver of sequence length and latency samples
     * @param spladeEmbedderConfig model, tokenizer and threshold configuration
     */
    @Inject
    public SpladeEmbedder(OnnxRuntime onnxRuntime, Embedder.Runtime runtime, SpladeEmbedderConfig spladeEmbedderConfig) {
        this(onnxRuntime, runtime, spladeEmbedderConfig, true);
    }

    /**
     * Creates an embedder, choosing which sparsification implementation {@link #embed(String, Embedder.Context, TensorType)} uses.
     *
     * @param onnxRuntime          runtime used to create the model evaluator
     * @param runtime              receiver of sequence length and latency samples
     * @param spladeEmbedderConfig model, tokenizer and threshold configuration
     * @param useCustomReduce      {@code true} to use {@link #sparsifyCustomReduce}, {@code false} to use the
     *                             generic tensor reduce/map based implementation
     * @throws IllegalArgumentException if the model lacks one of the configured input or output names
     */
    SpladeEmbedder(OnnxRuntime onnxRuntime, Embedder.Runtime runtime, SpladeEmbedderConfig spladeEmbedderConfig, boolean useCustomReduce) {
        this.runtime = runtime;
        this.inputIdsName = spladeEmbedderConfig.transformerInputIds();
        this.attentionMaskName = spladeEmbedderConfig.transformerAttentionMask();
        this.outputName = spladeEmbedderConfig.transformerOutput();
        this.tokenTypeIdsName = spladeEmbedderConfig.transformerTokenTypeIds();
        this.termScoreThreshold = spladeEmbedderConfig.termScoreThreshold();
        this.useCustomReduce = useCustomReduce;

        Path tokenizerPath = Paths.get(spladeEmbedderConfig.tokenizerPath().toString());
        HuggingFaceTokenizer.Builder tokenizerBuilder = new HuggingFaceTokenizer.Builder()
                .addSpecialTokens(true)
                .addDefaultModel(tokenizerPath)
                .setPadding(false);
        ModelInfo modelInfo = HuggingFaceTokenizer.getModelInfo(tokenizerPath);
        if (modelInfo.maxLength() == -1 || modelInfo.truncation() != ModelInfo.TruncationStrategy.LONGEST_FIRST) {
            // The tokenizer model does not itself specify longest-first truncation with a max length, so force
            // truncation: use the tokenizer's max length when it is positive and within the configured maximum,
            // otherwise the configured maximum.
            int configuredMaxTokens = spladeEmbedderConfig.transformerMaxTokens();
            boolean tokenizerLengthUsable = modelInfo.maxLength() > 0 && modelInfo.maxLength() <= configuredMaxTokens;
            tokenizerBuilder.setTruncation(true)
                    .setMaxLength(tokenizerLengthUsable ? modelInfo.maxLength() : configuredMaxTokens);
        }
        HuggingFaceTokenizer builtTokenizer = tokenizerBuilder.build();

        OnnxEvaluator createdEvaluator = null;
        try {
            OnnxEvaluatorOptions evaluatorOptions = new OnnxEvaluatorOptions();
            if (spladeEmbedderConfig.transformerGpuDevice() >= 0) {
                evaluatorOptions.setGpuDevice(spladeEmbedderConfig.transformerGpuDevice());
            }
            evaluatorOptions.setExecutionMode(spladeEmbedderConfig.transformerExecutionMode().toString());
            evaluatorOptions.setThreads(spladeEmbedderConfig.transformerInterOpThreads(),
                                        spladeEmbedderConfig.transformerIntraOpThreads());
            createdEvaluator = onnxRuntime.evaluatorOf(spladeEmbedderConfig.transformerModel().toString(), evaluatorOptions);
            validateModel(createdEvaluator);
        } catch (RuntimeException | Error failure) {
            // Construction failed after resources were created: release them here, since deconstruct()
            // cannot be called on an instance whose constructor did not complete.
            if (createdEvaluator != null) {
                try {
                    createdEvaluator.close();
                } catch (RuntimeException closeFailure) {
                    failure.addSuppressed(closeFailure);
                }
            }
            try {
                builtTokenizer.close();
            } catch (RuntimeException closeFailure) {
                failure.addSuppressed(closeFailure);
            }
            throw failure;
        }
        this.tokenizer = builtTokenizer;
        this.evaluator = createdEvaluator;
    }

    /**
     * Verifies that the model exposes the configured input ids, attention mask and output names.
     *
     * @param modelEvaluator evaluator of the model to check
     * @throws IllegalArgumentException if any of the checked names is missing
     */
    private void validateModel(OnnxEvaluator modelEvaluator) {
        Map<String, TensorType> inputInfo = modelEvaluator.getInputInfo();
        validateName(inputInfo, this.inputIdsName, "input");
        validateName(inputInfo, this.attentionMaskName, "input");
        validateName(modelEvaluator.getOutputInfo(), this.outputName, "output");
    }

    /**
     * Returns whether the given type is an acceptable embedding destination.
     *
     * @param tensorType the requested destination type
     * @return {@code true} if the type has exactly one dimension and that dimension is mapped
     */
    protected boolean verifyTensorType(TensorType tensorType) {
        return tensorType.dimensions().size() == 1 && tensorType.dimensions().get(0).isMapped();
    }

    /**
     * Throws if {@code str} is not a key of {@code map}.
     *
     * @param map  names available in the model, keyed by name
     * @param str  the required name
     * @param str2 what kind of name this is ("input" or "output"), used in the error message
     * @throws IllegalArgumentException if the name is missing; the message lists the available names
     */
    private static void validateName(Map<String, TensorType> map, String str, String str2) {
        if (!map.containsKey(str)) {
            throw new IllegalArgumentException("Model does not contain required " + str2 + ": '" + str + "'. Model contains: " + String.join(",", map.keySet()));
        }
    }

    /**
     * Not supported by this embedder.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public List<Integer> embed(String str, Embedder.Context context) {
        throw new UnsupportedOperationException("This embedder only supports embed with tensor type");
    }

    /**
     * Embeds a text into a sparse tensor of term scores.
     *
     * @param str        the text to embed
     * @param context    embedding context; its language is passed to the tokenizer and it is forwarded with
     *                   the sequence length and latency samples
     * @param tensorType the destination type, which must be a mapped 1-d tensor
     * @return a tensor of {@code tensorType} with one cell per term scoring above the threshold
     * @throws IllegalArgumentException if {@code tensorType} is not a mapped 1-d tensor
     */
    @Override
    public Tensor embed(String str, Embedder.Context context, TensorType tensorType) {
        if (!verifyTensorType(tensorType)) {
            throw new IllegalArgumentException("Invalid splade embedder tensor destination. Wanted a mapped 1-d tensor, got " + tensorType);
        }
        long startNanos = System.nanoTime();
        Encoding encoding = this.tokenizer.encode(str, context.getLanguage());
        this.runtime.sampleSequenceLength(encoding.ids().size(), context);

        // Each input is built as a 1-d tensor over "d1" and then expanded with a leading "d0" dimension.
        Tensor inputIds = createTensorRepresentation(encoding.ids(), "d1").expand("d0");
        Tensor attentionMask = createTensorRepresentation(encoding.attentionMask(), "d1").expand("d0");
        Tensor tokenTypeIds = createTensorRepresentation(encoding.typeIds(), "d1").expand("d0");
        Map<String, Tensor> modelInputs = Map.of(this.inputIdsName, inputIds,
                                                 this.attentionMaskName, attentionMask,
                                                 this.tokenTypeIdsName, tokenTypeIds);

        IndexedTensor modelOutput = (IndexedTensor) this.evaluator.evaluate(modelInputs).get(this.outputName);
        Tensor result = this.useCustomReduce
                ? sparsifyCustomReduce(modelOutput, tensorType)
                : sparsifyReduce(modelOutput, tensorType);
        // Latency is sampled in milliseconds.
        this.runtime.sampleEmbeddingLatency((System.nanoTime() - startNanos) / 1000000.0d, context);
        return result;
    }

    /**
     * Sparsifies the model output using the generic tensor reduce and map functions.
     *
     * @param tensor     the model output, with dimensions "d0" and "d1" to be max-reduced away
     * @param tensorType the mapped 1-d destination type
     * @return a tensor with one cell per vocabulary entry scoring above the threshold
     */
    private Tensor sparsifyReduce(Tensor tensor, TensorType tensorType) {
        // Tensor.map returns a Tensor; the result is read through IndexedTensor's positional get below,
        // so it must be cast explicitly.
        IndexedTensor vocabularyScores = (IndexedTensor) tensor
                .reduce(Reduce.Aggregator.max, "d0", "d1")
                .map(value -> Math.log(1.0d + (value > 0.0d ? value : 0.0d)));
        Tensor.Builder builder = Tensor.Builder.of(tensorType);
        String dimensionName = tensorType.dimensions().get(0).name();
        long[] tokenId = new long[1]; // reused single-element buffer for decoding one token at a time
        for (int vocabIndex = 0; vocabIndex < vocabularyScores.size(); vocabIndex++) {
            double score = vocabularyScores.get(vocabIndex);
            if (score > this.termScoreThreshold) {
                tokenId[0] = vocabIndex;
                builder.cell().label(dimensionName, this.tokenizer.decode(tokenId)).value(score);
            }
        }
        return builder.build();
    }

    /**
     * Sparsifies the model output by scanning the tensor directly rather than using generic tensor functions.
     *
     * <p>For each index of the third dimension, the maximum over the second dimension is taken (with a floor
     * of 0), transformed by {@code log(1 + max)}, and kept if it exceeds the term score threshold.
     *
     * @param indexedTensor the model output; must have exactly three dimensions with the first of size 1
     * @param tensorType    the mapped 1-d destination type
     * @return a tensor with one cell per vocabulary entry scoring above the threshold
     * @throws IllegalArgumentException if the tensor is not 3-dimensional, the first dimension is not of size 1,
     *                                  or the second or third dimension size exceeds {@code Integer.MAX_VALUE}
     */
    public Tensor sparsifyCustomReduce(IndexedTensor indexedTensor, TensorType tensorType) {
        Tensor.Builder builder = Tensor.Builder.of(tensorType);
        long[] shape = indexedTensor.shape();
        if (shape.length != 3) {
            throw new IllegalArgumentException("The indexed tensor must be 3-dimensional");
        }
        if (shape[0] != 1) {
            throw new IllegalArgumentException("Batch size must be 1");
        }
        if (shape[1] > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("sequenceLength=" + shape[1] + " larger than an int");
        }
        if (shape[2] > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("vocabSize=" + shape[2] + " larger than an int");
        }
        int sequenceLength = (int) shape[1];
        int vocabSize = (int) shape[2];
        String dimensionName = tensorType.dimensions().get(0).name();
        long[] tokenId = new long[1]; // reused single-element buffer for decoding one token at a time
        DirectIndexedAddress directAddress = indexedTensor.directAddress();
        directAddress.setIndex(0, 0);
        // The stride of dimension 1 is the same for every vocabulary index, so it is looked up once.
        long sequenceStride = directAddress.getStride(1);
        for (int vocabIndex = 0; vocabIndex < vocabSize; vocabIndex++) {
            directAddress.setIndex(2, vocabIndex);
            long baseIndex = directAddress.getDirectIndex();
            // Starting at 0 means negative values never win, which gives the max(x, 0) floor.
            double maxValue = 0.0d;
            for (int position = 0; position < sequenceLength; position++) {
                double value = indexedTensor.get(baseIndex + (position * sequenceStride));
                if (value > maxValue) {
                    maxValue = value;
                }
            }
            double score = Math.log(1.0d + maxValue);
            if (score > this.termScoreThreshold) {
                tokenId[0] = vocabIndex;
                builder.cell().label(dimensionName, this.tokenizer.decode(tokenId)).value(score);
            }
        }
        return builder.build();
    }

    /**
     * Builds a 1-d indexed float tensor holding the given values in order.
     *
     * @param list the values to store; each is narrowed to {@code float}
     * @param str  the name of the single indexed dimension
     * @return a tensor whose dimension size equals {@code list.size()}
     */
    private IndexedTensor createTensorRepresentation(List<Long> list, String str) {
        int size = list.size();
        TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(str, size).build();
        IndexedTensor.Builder builder = IndexedTensor.Builder.of(type);
        for (int i = 0; i < size; i++) {
            builder.cell((float) list.get(i).longValue(), new long[]{i});
        }
        return builder.build();
    }

    /** Closes the model evaluator and the tokenizer; the tokenizer is closed even if closing the evaluator fails. */
    @Override
    public void deconstruct() {
        try {
            this.evaluator.close();
        } finally {
            this.tokenizer.close();
        }
    }
}
