package ai.vespa.llm.generation;

import ai.vespa.llm.completion.Prompt;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.process.TextGenerator;
import com.yahoo.language.sentencepiece.SentencePieceEmbedder;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.PartialAddress;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * Text generator backed by two ONNX models, an encoder and a decoder, and a
 * SentencePiece tokenizer.
 *
 * <p>The prompt is tokenized, truncated to the configured maximum number of tokens and
 * terminated with {@link #TOKEN_EOS}. The encoder is evaluated once per prompt; the decoder
 * is then evaluated repeatedly, each time appending the most probable next token, until the
 * configured maximum output length is reached. Only greedy search is implemented.
 */
@Beta
public class OnnxEncoderDecoderTextGenerator extends AbstractComponent implements TextGenerator {

    /** Token id of the end-of-sequence marker. */
    private static final int TOKEN_EOS = 1;

    /** Name of the batch dimension added to tensors before they are handed to the models. */
    private static final String BATCH_DIMENSION = "d0";

    /** Name of the sequence (token position) dimension. */
    private static final String SEQUENCE_DIMENSION = "d1";

    private final int tokenizerMaxTokens;
    private final String encoderInputIdsName;
    private final String encoderAttentionMaskName;
    private final String encoderOutputName;
    private final String decoderInputIdsName;
    private final String decoderAttentionMaskName;
    private final String decoderEncoderHiddenStateName;
    private final String decoderOutputName;
    private final SentencePieceEmbedder tokenizer;
    private final OnnxEvaluator encoder;
    private final OnnxEvaluator decoder;

    /**
     * Loads the tokenizer and both ONNX models as described by the config, then verifies
     * that the configured input and output names exist in the loaded models.
     *
     * @param onnxRuntime runtime used to create the encoder and decoder evaluators
     * @param config      tokenizer, model paths, input/output names and ONNX execution settings
     * @throws IllegalArgumentException if a configured input or output name is missing from a model
     */
    @Inject
    public OnnxEncoderDecoderTextGenerator(OnnxRuntime onnxRuntime, OnnxEncoderDecoderTextGeneratorConfig config) {
        this.tokenizer = new SentencePieceEmbedder.Builder(config.tokenizerModel().toString()).build();
        this.tokenizerMaxTokens = config.tokenizerMaxTokens();

        this.encoderInputIdsName = config.encoderModelInputIdsName();
        this.encoderAttentionMaskName = config.encoderModelAttentionMaskName();
        this.encoderOutputName = config.encoderModelOutputName();
        OnnxEvaluatorOptions encoderOptions = new OnnxEvaluatorOptions();
        encoderOptions.setExecutionMode(config.encoderOnnxExecutionMode().toString());
        encoderOptions.setThreads(config.encoderOnnxInterOpThreads(), config.encoderOnnxIntraOpThreads());
        this.encoder = onnxRuntime.evaluatorOf(config.encoderModel().toString(), encoderOptions);

        this.decoderInputIdsName = config.decoderModelInputIdsName();
        this.decoderAttentionMaskName = config.decoderModelAttentionMaskName();
        this.decoderEncoderHiddenStateName = config.decoderModelEncoderHiddenStateName();
        this.decoderOutputName = config.decoderModelOutputName();
        OnnxEvaluatorOptions decoderOptions = new OnnxEvaluatorOptions();
        decoderOptions.setExecutionMode(config.decoderOnnxExecutionMode().toString());
        decoderOptions.setThreads(config.decoderOnnxInterOpThreads(), config.decoderOnnxIntraOpThreads());
        this.decoder = onnxRuntime.evaluatorOf(config.decoderModel().toString(), decoderOptions);

        validateModels();
    }

    /**
     * Generates text for the given prompt using the search method selected in the options.
     *
     * @param prompt  the text to generate from
     * @param options decoding options; supplies the search method and the maximum output length
     * @return the generated text
     * @throws UnsupportedOperationException if the selected search method is not greedy search
     */
    public String generate(String prompt, TextGeneratorDecoderOptions options) {
        switch (options.getSearchMethod()) {
            case GREEDY:
                return generateGreedy(prompt, options);
            default:
                return generateNotImplemented(options);
        }
    }

    /**
     * Generates text for the given prompt using default decoder options.
     *
     * @param prompt the text to generate from
     * @return the generated text
     */
    public String generate(String prompt) {
        return generate(prompt, new TextGeneratorDecoderOptions());
    }

    /**
     * Generates text for the string form of the given prompt using default decoder options.
     * The context argument is not used.
     *
     * @param prompt  the prompt; only {@code prompt.asString()} is read
     * @param context ignored
     * @return the generated text
     */
    public String generate(Prompt prompt, TextGenerator.Context context) {
        return generate(prompt.asString());
    }

    /** Closes both model evaluators; the decoder is closed even if closing the encoder fails. */
    public void deconstruct() {
        try {
            this.encoder.close();
        } finally {
            this.decoder.close();
        }
    }

    /** Always throws; the {@code String} return type only lets it be used in a return statement. */
    private String generateNotImplemented(TextGeneratorDecoderOptions options) {
        throw new UnsupportedOperationException("Search method '" + options.getSearchMethod() + "' is currently not implemented");
    }

    /**
     * Greedy decoding: evaluates the encoder once, then repeatedly evaluates the decoder on
     * the tokens generated so far and appends the most probable next token.
     * There is no early stop; generation runs until the sequence holds
     * {@code options.getMaxLength()} tokens.
     */
    private String generateGreedy(String prompt, TextGeneratorDecoderOptions options) {
        List<Integer> generatedTokens = new ArrayList<>();
        // Seed the decoder sequence with token id 0.
        // NOTE(review): presumably the model's decoder start token - confirm against the model.
        generatedTokens.add(0);

        Tensor encoderInput = createTensorRepresentation(tokenize(prompt), SEQUENCE_DIMENSION);
        Tensor encoderMask = createAttentionMask(encoderInput).expand(BATCH_DIMENSION);
        Tensor encoderOutput = evaluateEncoder(encoderInput.expand(BATCH_DIMENSION), encoderMask);

        while (generatedTokens.size() < options.getMaxLength()) {
            Tensor decoderInput = createTensorRepresentation(generatedTokens, SEQUENCE_DIMENSION).expand(BATCH_DIMENSION);
            IndexedTensor logits = evaluateDecoder(decoderInput, encoderMask, encoderOutput);
            // The prediction for the next token is read at the last position of the current sequence.
            int nextToken = findMostProbableToken(logits, generatedTokens.size() - 1, BATCH_DIMENSION, SEQUENCE_DIMENSION);
            generatedTokens.add(nextToken);
        }
        return detokenize(generatedTokens);
    }

    /** Evaluates the encoder on the input ids and attention mask and returns the configured output. */
    private Tensor evaluateEncoder(Tensor inputIds, Tensor attentionMask) {
        return this.encoder.evaluate(
                Map.of(this.encoderInputIdsName, inputIds,
                       this.encoderAttentionMaskName, attentionMask),
                this.encoderOutputName);
    }

    /**
     * Evaluates the decoder and returns the configured output as an indexed tensor.
     *
     * @param inputIds             the tokens generated so far
     * @param encoderAttentionMask the attention mask that was passed to the encoder
     * @param encoderHiddenState   the encoder output
     * @throws IllegalArgumentException if the decoder output is not an {@code IndexedTensor}
     */
    private IndexedTensor evaluateDecoder(Tensor inputIds, Tensor encoderAttentionMask, Tensor encoderHiddenState) {
        Tensor logits = this.decoder.evaluate(
                Map.of(this.decoderInputIdsName, inputIds,
                       this.decoderAttentionMaskName, encoderAttentionMask,
                       this.decoderEncoderHiddenStateName, encoderHiddenState),
                this.decoderOutputName);
        // Keep the result typed as Tensor until it is checked, so the check is meaningful.
        if (!(logits instanceof IndexedTensor)) {
            throw new IllegalArgumentException("Output of decoder model is not an 'IndexedTensor'");
        }
        return (IndexedTensor) logits;
    }

    /**
     * Returns the index of the largest value among the cells at batch index 0 and the given
     * sequence index. {@link #TOKEN_EOS} is never selected, and on ties the highest index wins.
     *
     * @param logits   rank 3 tensor: batch, sequence, and vocabulary size
     * @param seqIndex position in the sequence dimension to read
     * @param batchDim name of the batch dimension
     * @param seqDim   name of the sequence dimension
     * @throws IllegalArgumentException if the tensor does not have rank 3
     */
    private static int findMostProbableToken(IndexedTensor logits, int seqIndex, String batchDim, String seqDim) {
        if (logits.type().rank() != 3) {
            throw new IllegalArgumentException("Expected a tensor with rank 3: batch, sequence, and vocabulary size. Got: " + logits.type());
        }
        IndexedTensor.SubspaceIterator cells = logits.cellIterator(
                new PartialAddress.Builder(2).add(batchDim, 0L).add(seqDim, seqIndex).build(),
                DimensionSizes.of(logits.type()));

        // Index 0 is the initial candidate; every later cell is compared against the running best.
        double maxValue = cells.next().getValue();
        int maxIndex = 0;
        for (int index = 1; cells.hasNext(); index++) {
            double value = cells.next().getValue();
            if (index != TOKEN_EOS && value >= maxValue) {
                maxValue = value;
                maxIndex = index;
            }
        }
        return maxIndex;
    }

    /**
     * Tokenizes the text, keeps at most {@code tokenizerMaxTokens - 1} tokens and appends
     * {@link #TOKEN_EOS}. The result is a new list; the list returned by the tokenizer is
     * left untouched.
     */
    private List<Integer> tokenize(String text) {
        List<Integer> encoded = this.tokenizer.embed(text, new Embedder.Context("tokenizer"));
        int keep = encoded.size() >= this.tokenizerMaxTokens ? this.tokenizerMaxTokens - 1 : encoded.size();
        // Copy instead of appending to the tokenizer's list (or a subList view of it).
        List<Integer> tokens = new ArrayList<>(encoded.subList(0, keep));
        tokens.add(TOKEN_EOS);
        return tokens;
    }

    /** Converts token ids back to text using the tokenizer. */
    private String detokenize(List<Integer> tokens) {
        return this.tokenizer.decode(tokens, new Embedder.Context("tokenizer"), true);
    }

    /** Builds a one-dimensional float tensor over the given dimension holding the token ids in order. */
    private static Tensor createTensorRepresentation(List<Integer> tokens, String dimension) {
        int size = tokens.size();
        TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, size).build();
        IndexedTensor.Builder builder = IndexedTensor.Builder.of(type);
        for (int i = 0; i < size; i++) {
            builder.cell(tokens.get(i).intValue(), new long[]{i});
        }
        return builder.build();
    }

    /** Maps every cell greater than zero to 1 and every other cell to 0. */
    private static Tensor createAttentionMask(Tensor inputIds) {
        return inputIds.map(value -> value > 0.0d ? 1.0d : 0.0d);
    }

    /**
     * Checks that every configured input and output name exists in the loaded models.
     *
     * @throws IllegalArgumentException if a name is missing
     */
    private void validateModels() {
        Map<String, TensorType> encoderInputs = this.encoder.getInputInfo();
        validateName(encoderInputs, this.encoderInputIdsName, "input");
        validateName(encoderInputs, this.encoderAttentionMaskName, "input");
        validateName(this.encoder.getOutputInfo(), this.encoderOutputName, "output");

        Map<String, TensorType> decoderInputs = this.decoder.getInputInfo();
        validateName(decoderInputs, this.decoderInputIdsName, "input");
        validateName(decoderInputs, this.decoderAttentionMaskName, "input");
        validateName(decoderInputs, this.decoderEncoderHiddenStateName, "input");
        validateName(this.decoder.getOutputInfo(), this.decoderOutputName, "output");
    }

    /**
     * Throws if {@code name} is not a key of {@code types}.
     *
     * @param types the model's inputs or outputs, by name
     * @param name  the configured name to look up
     * @param kind  "input" or "output"; only used in the error message
     */
    private void validateName(Map<String, TensorType> types, String name, String kind) {
        if (!types.containsKey(name)) {
            throw new IllegalArgumentException("Model does not contain required " + kind + ": '" + name + "'. Model contains: " + String.join(",", types.keySet()));
        }
    }
}
