package ai.vespa.embedding;

import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.embedding.ColBertEmbedderConfig;
import com.yahoo.language.Language;
import com.yahoo.language.huggingface.Encoding;
import com.yahoo.language.huggingface.HuggingFaceTokenizer;
import com.yahoo.language.huggingface.ModelInfo;
import com.yahoo.language.process.Embedder;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.UnpackBitsNode;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.ObjectMethods;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * ColBERT embedder: embeds a text into a tensor holding one vector per input token.
 *
 * <p>The destination tensor type must have two dimensions, exactly one of which is indexed; cells are
 * addressed as {@code (token position, vector index)}. A destination whose name starts with
 * {@code "query"} receives a query embedding, any other destination a document embedding. Document
 * embeddings may target an {@code int8} tensor, in which case each token vector is binarized (values
 * &gt; 0 become set bits) and packed eight values per byte, most significant bit first.
 */
@Beta
public class ColBertEmbedder extends AbstractComponent implements Embedder {

    /** ASCII punctuation; the token ids of these characters are removed from document input. */
    private static final String PUNCTUATION = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~";

    private final Embedder.Runtime runtime;
    private final String inputIdsName;
    private final String attentionMaskName;
    private final String outputName;
    private final HuggingFaceTokenizer tokenizer;
    private final OnnxEvaluator evaluator;
    private final int maxTransformerTokens;
    private final int maxQueryTokens;
    private final int maxDocumentTokens;
    private final long startSequenceToken;
    private final long endSequenceToken;
    private final long maskSequenceToken;
    private final long padSequenceToken;
    private final long querySequenceToken;
    private final long documentSequenceToken;
    /** Unmodifiable set of punctuation token ids dropped from document input. */
    private final Set<Long> skipTokens;

    /**
     * Key under which evaluation results are cached in an {@link Embedder.Context}.
     *
     * @param embedderId    id of the embedder instance that produced the cached result
     * @param embeddedValue value identifying the embedded input
     */
    public record EmbedderCacheKey(String embedderId, Object embeddedValue) { }

    /**
     * Raw result of one transformer evaluation.
     *
     * @param inputIdSize number of token ids fed to the model, including special and padding tokens
     * @param outputs     model outputs keyed by output name
     */
    public record EmbeddingResult(int inputIdSize, Map<String, Tensor> outputs) { }

    /**
     * Token ids and matching attention mask fed to the transformer model.
     *
     * @param inputIds      token ids, including the start, query/document marker and end tokens plus any padding
     * @param attentionMask 1 for each non-padding token and 0 for each padding token; same length as inputIds
     */
    public record TransformerInput(List<Long> inputIds, List<Long> attentionMask) { }

    /**
     * Creates the embedder: builds the tokenizer, collects the punctuation token ids, loads the ONNX model
     * and verifies that it exposes the configured input and output names.
     *
     * @throws IllegalArgumentException if the model lacks a configured input or output
     */
    @Inject
    public ColBertEmbedder(OnnxRuntime onnxRuntime, Embedder.Runtime runtime, ColBertEmbedderConfig config) {
        this.runtime = runtime;
        this.inputIdsName = config.transformerInputIds();
        this.attentionMaskName = config.transformerAttentionMask();
        this.outputName = config.transformerOutput();
        this.maxTransformerTokens = config.transformerMaxTokens();
        // The query and document limits are capped by what the transformer model accepts.
        this.maxDocumentTokens = Math.min(config.maxDocumentTokens(), maxTransformerTokens);
        this.maxQueryTokens = Math.min(config.maxQueryTokens(), maxTransformerTokens);
        this.startSequenceToken = config.transformerStartSequenceToken();
        this.endSequenceToken = config.transformerEndSequenceToken();
        this.maskSequenceToken = config.transformerMaskToken();
        this.padSequenceToken = config.transformerPadToken();
        this.querySequenceToken = config.queryTokenId();
        this.documentSequenceToken = config.documentTokenId();
        this.tokenizer = createTokenizer(config);
        this.skipTokens = punctuationTokenIds(tokenizer);
        this.evaluator = onnxRuntime.evaluatorOf(config.transformerModel().toString(), onnxOptions(config));
        validateModel();
    }

    /** Builds the tokenizer, forcing truncation when the tokenizer model does not configure LONGEST_FIRST itself. */
    private static HuggingFaceTokenizer createTokenizer(ColBertEmbedderConfig config) {
        Path tokenizerPath = Paths.get(config.tokenizerPath().toString());
        HuggingFaceTokenizer.Builder builder = new HuggingFaceTokenizer.Builder()
                .addSpecialTokens(false)
                .addDefaultModel(tokenizerPath)
                .setPadding(false);
        ModelInfo info = HuggingFaceTokenizer.getModelInfo(tokenizerPath);
        if (info.maxLength() == -1 || info.truncation() != ModelInfo.TruncationStrategy.LONGEST_FIRST) {
            // Use the tokenizer's own max length if it is positive and within the transformer limit,
            // otherwise fall back to the transformer limit.
            int maxLength = (info.maxLength() > 0 && info.maxLength() <= config.transformerMaxTokens())
                    ? info.maxLength()
                    : config.transformerMaxTokens();
            builder.setTruncation(true).setMaxLength(maxLength);
        }
        return builder.build();
    }

    /** Returns an unmodifiable set of the token ids obtained by tokenizing each punctuation character alone. */
    private static Set<Long> punctuationTokenIds(HuggingFaceTokenizer tokenizer) {
        Set<Long> ids = new HashSet<>();
        for (char c : PUNCTUATION.toCharArray()) {
            ids.addAll(tokenizer.encode(Character.toString(c), (Language) null).ids());
        }
        return Collections.unmodifiableSet(ids);
    }

    /** Translates the ONNX-related parts of the config into evaluator options. */
    private static OnnxEvaluatorOptions onnxOptions(ColBertEmbedderConfig config) {
        OnnxEvaluatorOptions options = new OnnxEvaluatorOptions();
        // Only pass a GPU device on when a non-negative device is configured.
        if (config.transformerGpuDevice() >= 0) {
            options.setGpuDevice(config.transformerGpuDevice());
        }
        options.setExecutionMode(config.transformerExecutionMode().toString());
        options.setThreads(config.transformerInterOpThreads(), config.transformerIntraOpThreads());
        return options;
    }

    /** Verifies that the loaded model has the configured input id, attention mask and output names. */
    private void validateModel() {
        Map<String, TensorType> inputs = evaluator.getInputInfo();
        validateName(inputs, inputIdsName, "input");
        validateName(inputs, attentionMaskName, "input");
        validateName(evaluator.getOutputInfo(), outputName, "output");
    }

    /**
     * @throws IllegalArgumentException if {@code name} is not a key of {@code types}
     */
    private static void validateName(Map<String, TensorType> types, String name, String kind) {
        if (!types.containsKey(name)) {
            throw new IllegalArgumentException("Model does not contain required " + kind + ": '" + name + "'. Model contains: " + String.join(",", types.keySet()));
        }
    }

    /** Not supported: this embedder only produces tensors. */
    @Override
    public List<Integer> embed(String text, Embedder.Context context) {
        throw new UnsupportedOperationException("This embedder only supports embed with tensor type");
    }

    /**
     * Embeds {@code text} into a tensor of the given type. A destination starting with "query" produces
     * a query embedding; any other destination produces a document embedding.
     *
     * @throws IllegalArgumentException if the tensor type is not 2-dimensional with exactly one indexed dimension
     */
    @Override
    public Tensor embed(String text, Embedder.Context context, TensorType tensorType) {
        if (!validTensorType(tensorType)) {
            throw new IllegalArgumentException("Invalid colbert embedder tensor target destination. Wanted a mixed 2-d mapped-indexed tensor, got " + tensorType);
        }
        return context.getDestination().startsWith("query")
                ? embedQuery(text, context, tensorType)
                : embedDocument(text, context, tensorType);
    }

    /** Releases the ONNX evaluator and the tokenizer; the tokenizer is closed even if closing the evaluator fails. */
    @Override
    public void deconstruct() {
        try {
            evaluator.close();
        } finally {
            tokenizer.close();
        }
    }

    /**
     * Builds the model input: {@code [start, query-or-document marker, tokens..., end]}, truncated to
     * {@code maxTokens}. Document input has punctuation tokens removed and is not padded. Query input is
     * padded up to {@code maxTokens} with the mask token, and the padding is excluded by the attention mask.
     *
     * @param tokens    token ids from the tokenizer
     * @param maxTokens maximum length of the returned input, including the three special tokens
     * @param isQuery   whether to build query (true) or document (false) input
     */
    protected TransformerInput buildTransformerInput(List<Long> tokens, int maxTokens, boolean isQuery) {
        if (!isQuery) {
            tokens = tokens.stream().filter(token -> !skipTokens.contains(token)).toList();
        }
        // Leave room for the start token, the query/document marker and the end token.
        int maxContentTokens = maxTokens - 3;
        if (tokens.size() > maxContentTokens) {
            tokens = tokens.subList(0, maxContentTokens);
        }

        List<Long> inputIds = new ArrayList<>(maxTokens);
        List<Long> attentionMask = new ArrayList<>(maxTokens);
        inputIds.add(startSequenceToken);
        inputIds.add(isQuery ? querySequenceToken : documentSequenceToken);
        inputIds.addAll(tokens);
        inputIds.add(endSequenceToken);

        int realLength = inputIds.size();
        long paddingToken = isQuery ? maskSequenceToken : padSequenceToken;
        int paddingLength = isQuery ? maxTokens - realLength : 0;
        for (int i = 0; i < paddingLength; i++) {
            inputIds.add(paddingToken);
        }
        for (int i = 0; i < realLength; i++) {
            attentionMask.add(1L);
        }
        for (int i = 0; i < paddingLength; i++) {
            attentionMask.add(0L);
        }
        return new TransformerInput(inputIds, attentionMask);
    }

    /**
     * Embeds a query into a float-valued tensor.
     *
     * @throws IllegalArgumentException if the tensor value type is int8
     */
    protected Tensor embedQuery(String text, Embedder.Context context, TensorType tensorType) {
        if (tensorType.valueType() == TensorType.Value.INT8) {
            throw new IllegalArgumentException("ColBert query embed does not accept int8 tensor value type");
        }
        EmbeddingResult result = lookupOrEvaluate(context, text, true);
        IndexedTensor modelOutput = (IndexedTensor) result.outputs().get(outputName);
        return toFloatTensor(modelOutput, tensorType, result.inputIdSize());
    }

    /** Embeds a document, bit-packing the vectors when the tensor value type is int8. */
    protected Tensor embedDocument(String text, Embedder.Context context, TensorType tensorType) {
        EmbeddingResult result = lookupOrEvaluate(context, text, false);
        IndexedTensor modelOutput = (IndexedTensor) result.outputs().get(outputName);
        return tensorType.valueType() == TensorType.Value.INT8
                ? toBitTensor(modelOutput, tensorType, result.inputIdSize())
                : toFloatTensor(modelOutput, tensorType, result.inputIdSize());
    }

    /**
     * Returns the cached evaluation for this embedder, text and query/document mode, evaluating the model
     * if the context holds none. The mode is part of the key because query and document inputs differ.
     */
    protected EmbeddingResult lookupOrEvaluate(Embedder.Context context, String text, boolean isQuery) {
        EmbedderCacheKey key = new EmbedderCacheKey(context.getEmbedderId(), List.of(text, isQuery));
        return (EmbeddingResult) context.computeCachedValueIfAbsent(key, () -> evaluate(context, text, isQuery));
    }

    /** Tokenizes the text, runs the model, and records sequence length and latency with the runtime. */
    private EmbeddingResult evaluate(Embedder.Context context, String text, boolean isQuery) {
        long startNanos = System.nanoTime();
        Encoding encoding = tokenizer.encode(text, context.getLanguage());
        runtime.sampleSequenceLength(encoding.ids().size(), context);
        TransformerInput input = buildTransformerInput(encoding.ids(), isQuery ? maxQueryTokens : maxDocumentTokens, isQuery);
        // The model takes a batch of one: add a leading d0 dimension to both inputs.
        Tensor inputIds = createTensorRepresentation(input.inputIds(), "d1").expand("d0");
        Tensor attentionMask = createTensorRepresentation(input.attentionMask(), "d1").expand("d0");
        Map<String, Tensor> outputs = evaluator.evaluate(Map.of(inputIdsName, inputIds, attentionMaskName, attentionMask));
        runtime.sampleEmbeddingLatency((System.nanoTime() - startNanos) / 1_000_000.0d, context);
        return new EmbeddingResult(input.inputIds().size(), outputs);
    }

    /**
     * Copies the first {@code numTokens} token vectors of a {@code [batch, sequence, dim]} model output
     * (batch index 0) into a tensor of the given type, keeping the first N values of each vector, where N
     * is the size of the target's indexed dimension.
     *
     * @throws IllegalArgumentException if the output is not 3-dimensional, the target's indexed part is not a
     *         single bound dimension, or the target dimension is larger than the model's vectors
     */
    public static Tensor toFloatTensor(IndexedTensor modelOutput, TensorType tensorType, int numTokens) {
        if (modelOutput.shape().length != 3) {
            throw new IllegalArgumentException("Expected onnx result to have 3-dimensions [batch, sequence, dim]");
        }
        int targetDims = indexedDimensionSize(tensorType);
        int modelDims = (int) modelOutput.shape()[2];
        if (targetDims > modelDims) {
            throw new IllegalArgumentException("Not possible to map token vector embedding with " + modelDims + " dimensions into tensor with " + targetDims);
        }
        Tensor.Builder builder = Tensor.Builder.of(tensorType);
        for (int token = 0; token < numTokens; token++) {
            for (int d = 0; d < targetDims; d++) {
                builder.cell(TensorAddress.of(token, d), modelOutput.get(new long[] {0, token, d}));
            }
        }
        return builder.build();
    }

    /**
     * Binarizes the first {@code numTokens} token vectors of a {@code [batch, sequence, dim]} model output
     * (batch index 0) and packs them into an int8 tensor. Each output cell holds 8 consecutive model values,
     * most significant bit first, with a bit set when the value is &gt; 0.
     *
     * @throws IllegalArgumentException if the target is not int8, the output is not 3-dimensional, the target's
     *         indexed part is not a single bound dimension, or the model vectors have fewer than 8 * target values
     */
    public static Tensor toBitTensor(IndexedTensor modelOutput, TensorType tensorType, int numTokens) {
        if (tensorType.valueType() != TensorType.Value.INT8) {
            throw new IllegalArgumentException("Only a int8 tensor type can be the destination of bit packing");
        }
        if (modelOutput.shape().length != 3) {
            throw new IllegalArgumentException("Expected onnx result to have 3-dimensions [batch, sequence, dim]");
        }
        int packedDims = indexedDimensionSize(tensorType);
        int unpackedDims = 8 * packedDims;
        int modelDims = (int) modelOutput.shape()[2];
        if (unpackedDims > modelDims) {
            throw new IllegalArgumentException("Not possible to pack " + modelDims + " dimensions into " + packedDims + " dimensions");
        }
        Tensor.Builder builder = Tensor.Builder.of(tensorType);
        for (int token = 0; token < numTokens; token++) {
            for (int packedIndex = 0; packedIndex < packedDims; packedIndex++) {
                int packed = 0;
                for (int bit = 0; bit < 8; bit++) {
                    long d = 8L * packedIndex + bit;
                    if (modelOutput.get(new long[] {0, token, d}) > 0.0d) {
                        packed |= 1 << (7 - bit); // big-endian: the first value is the most significant bit
                    }
                }
                builder.cell(TensorAddress.of(token, packedIndex), (byte) packed);
            }
        }
        return builder.build();
    }

    /**
     * Returns the size of the single indexed dimension of {@code tensorType}.
     *
     * @throws IllegalArgumentException if there is not exactly one indexed dimension, or it has no bound size
     */
    private static int indexedDimensionSize(TensorType tensorType) {
        List<TensorType.Dimension> dimensions = tensorType.indexedSubtype().dimensions();
        if (dimensions.size() != 1) {
            throw new IllegalArgumentException("Target indexed sub-type must have one dimension");
        }
        TensorType.Dimension dimension = dimensions.get(0);
        return dimension.size()
                .orElseThrow(() -> new IllegalArgumentException("Target indexed dimension '" + dimension.name() + "' must have a bound size"))
                .intValue();
    }

    /** Returns an unmodifiable view of the punctuation token ids removed from document input. */
    public Set<Long> getSkipTokens() {
        return skipTokens;
    }

    /** Unpacks an int8 bit-packed tensor into a float tensor, reading the bits big-endian. */
    public static Tensor expandBitTensor(Tensor packed) {
        UnpackBitsNode unpack = new UnpackBitsNode(new ReferenceNode("input"), TensorType.Value.FLOAT, "big");
        MapContext evaluationContext = new MapContext();
        evaluationContext.put("input", new TensorValue(packed));
        return unpack.evaluate(evaluationContext).asTensor();
    }

    /** Returns whether the type has two dimensions, exactly one of which is indexed. */
    protected boolean validTensorType(TensorType tensorType) {
        return tensorType.dimensions().size() == 2 && tensorType.indexedSubtype().rank() == 1;
    }

    /** Creates a 1-d float tensor over {@code dimension} holding the given values in order. */
    private IndexedTensor createTensorRepresentation(List<Long> values, String dimension) {
        int size = values.size();
        IndexedTensor.Builder builder = IndexedTensor.Builder.of(new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, size).build());
        for (int i = 0; i < size; i++) {
            builder.cell((float) values.get(i).longValue(), new long[] {i});
        }
        return builder.build();
    }
}
