package ai.vespa.search.llm;

import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.LanguageModel;
import ai.vespa.llm.LanguageModelException;
import ai.vespa.llm.completion.Completion;
import ai.vespa.llm.completion.Prompt;
import ai.vespa.llm.completion.StringPrompt;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.annotation.Inject;
import com.yahoo.component.provider.ComponentRegistry;
import com.yahoo.processing.Request;
import com.yahoo.processing.rendering.Renderer;
import com.yahoo.search.Query;
import com.yahoo.search.Result;
import com.yahoo.search.Searcher;
import com.yahoo.search.rendering.JsonRenderer;
import com.yahoo.search.result.ErrorMessage;
import com.yahoo.search.result.EventStream;
import com.yahoo.search.result.Hit;
import com.yahoo.search.result.HitGroup;
import com.yahoo.search.searchchain.Execution;
import com.yahoo.text.Utf8;
import com.yahoo.yolean.Exceptions;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.RejectedExecutionException;
import java.util.function.Function;
import java.util.logging.Logger;
import java.util.stream.Collectors;

@Beta
/* loaded from: input_file:ai/vespa/search/llm/LLMSearcher.class */
public class LLMSearcher extends Searcher {
    // Class logger; used to report the fallback to the first available language model.
    private static final Logger log = Logger.getLogger(LLMSearcher.class.getName());
    // Name of the HTTP request header carrying the LLM API key (looked up with and without the property prefix).
    private static final String API_KEY_HEADER = "X-LLM-API-KEY";
    // Query property (under the property prefix) that overrides the configured streaming default.
    private static final String STREAM_PROPERTY = "stream";
    // Query property holding the prompt; also used as the event label when the prompt is echoed in the result.
    private static final String PROMPT_PROPERTY = "prompt";
    // Query property (under the property prefix) requesting that the prompt is added to the result.
    private static final String INCLUDE_PROMPT_IN_RESULT = "includePrompt";
    // Query property (under the property prefix) requesting that rendered hits are added to the result.
    private static final String INCLUDE_HITS_IN_RESULT = "includeHits";
    // Template renderer; renderHits() clones it for every rendering instead of using it directly.
    private final JsonRenderer jsonRenderer = new JsonRenderer();
    // Prefix from config, joined with "." in front of property names when looking them up.
    private final String propertyPrefix;
    // Default prompt from config (inline text or template file content); null when none is configured.
    private final String prompt;
    // Configured default for whether completions are streamed (asynchronous) or returned in one piece.
    private final boolean stream;
    // The language model all completions are delegated to, resolved once in the constructor.
    private final LanguageModel languageModel;
    // Provider id from config; passed as the source of errors reported on the event stream.
    private final String languageModelId;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:ai/vespa/search/llm/LLMSearcher$TokenStats.class */
    /**
     * Collects wall-clock timing statistics for a single text generation.
     *
     * <p>The clock starts when the instance is created. {@link #onToken()} is expected to be
     * called once per generated token and {@link #onCompletion()} once when generation ends.
     * All times are in milliseconds relative to creation.</p>
     */
    public static class TokenStats {
        private long timeToFirstToken;
        private long timeToLastToken;
        private long tokens = 0;
        private final long start = System.currentTimeMillis();

        TokenStats() {
        }

        /** Registers one generated token; the first call also records the time to first token. */
        void onToken() {
            if (this.tokens == 0) {
                this.timeToFirstToken = System.currentTimeMillis() - this.start;
            }
            this.tokens++;
        }

        /** Records the time at which generation finished. */
        void onCompletion() {
            this.timeToLastToken = System.currentTimeMillis() - this.start;
        }

        /**
         * Builds a one-line, human-readable summary of the collected statistics.
         *
         * @return a string containing time to first token, generation time (time between the first
         *         token and completion), the number of tokens, and the generation rate in tokens/sec
         */
        String report() {
            long generationTime = this.timeToLastToken - this.timeToFirstToken;
            // Guard the division: with no measurable generation time the rate would be NaN or Infinity.
            double tokensPerSecond = generationTime > 0 ? this.tokens / (generationTime / 1000.0d) : 0.0d;
            String rate = String.format("(%.2f tokens/sec)", tokensPerSecond);
            return "Time to first token: " + this.timeToFirstToken + " ms, Generation time: " + generationTime
                    + " ms, Generated tokens: " + this.tokens + " " + rate;
        }
    }

    @Inject
    public LLMSearcher(LlmSearcherConfig llmSearcherConfig, ComponentRegistry<LanguageModel> componentRegistry) {
        this.stream = llmSearcherConfig.stream();
        this.languageModelId = llmSearcherConfig.providerId();
        this.languageModel = findLanguageModel(this.languageModelId, componentRegistry);
        this.propertyPrefix = llmSearcherConfig.propertyPrefix();
        this.prompt = loadDefaultPrompt(llmSearcherConfig);
    }

    /**
     * Resolves the prompt for the query and returns the language model's completion of it.
     * No previous result is passed on, so no hits can be included in the returned result.
     */
    @Override // com.yahoo.search.Searcher
    public Result search(Query query, Execution execution) {
        Prompt resolvedPrompt = StringPrompt.from(getPrompt(query));
        return complete(query, resolvedPrompt, null, execution);
    }

    /**
     * Determines the default prompt from config.
     *
     * <p>An inline, non-empty {@code prompt} takes precedence. Otherwise the content of the
     * {@code promptTemplate} file is used when one is configured.</p>
     *
     * @param llmSearcherConfig the searcher config
     * @return the default prompt, or {@code null} if neither an inline prompt nor a template is configured
     * @throws IllegalArgumentException if the template file cannot be read
     */
    private String loadDefaultPrompt(LlmSearcherConfig llmSearcherConfig) {
        String inlinePrompt = llmSearcherConfig.prompt();
        if (inlinePrompt != null && !inlinePrompt.isEmpty()) {
            return inlinePrompt;
        }
        if (!llmSearcherConfig.promptTemplate().isPresent()) {
            return null;
        }
        Path path = llmSearcherConfig.promptTemplate().get();
        try {
            // Decode explicitly as UTF-8 so the prompt does not depend on the platform default charset.
            return Utf8.toString(Files.readAllBytes(path));
        } catch (IOException e) {
            throw new IllegalArgumentException("Could not read prompt template file: " + path, e);
        }
    }

    /**
     * Resolves the language model to use.
     *
     * <p>When an id is given, the component with that id must exist. When the id is null or empty,
     * the first component returned by the registry is used and the choice is logged.</p>
     *
     * @param str               the configured provider id; may be null or empty
     * @param componentRegistry the registry of available language models
     * @return the selected language model, never null
     * @throws IllegalArgumentException if the registry is empty or has no component with the given id
     */
    private LanguageModel findLanguageModel(String str, ComponentRegistry<LanguageModel> componentRegistry) throws IllegalArgumentException {
        if (componentRegistry.allComponents().isEmpty()) {
            throw new IllegalArgumentException("No language models were found");
        }
        if (str != null && !str.isEmpty()) {
            LanguageModel languageModel = componentRegistry.getComponent(str);
            if (languageModel == null) {
                String availableIds = componentRegistry.allComponentsById().keySet().stream()
                        .map(Object::toString)
                        .collect(Collectors.joining(","));
                throw new IllegalArgumentException("No component with id '" + str + "' was found. Available LLM components are: " + availableIds);
            }
            return languageModel;
        }
        // No id configured: fall back to whichever model the registry lists first.
        var firstEntry = componentRegistry.allComponentsById().entrySet().stream().findFirst();
        if (firstEntry.isEmpty()) {
            throw new IllegalArgumentException("No language models were found");
        }
        log.info("Language model provider was not found in config. Fallback to using first available language model: " + firstEntry.get().getKey());
        return firstEntry.get().getValue();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Completes the prompt with the language model, streaming or blocking depending on the
     * {@code stream} query property (falling back to the configured default).
     *
     * @param query     the query supplying the API key header and inference properties
     * @param prompt    the prompt to complete
     * @param result    an earlier result whose hits may be included in the output; may be null
     * @param execution the execution, used when rendering hits
     * @return the completion result, or a result carrying a 429 error if the work was rejected
     */
    public Result complete(Query query, Prompt prompt, Result result, Execution execution) {
        InferenceParameters inferenceParameters =
                new InferenceParameters(getApiKeyHeader(query), name -> lookupProperty(name, query));
        try {
            if (lookupPropertyBool(STREAM_PROPERTY, query, this.stream)) {
                return completeAsync(query, prompt, inferenceParameters, result, execution);
            }
            return completeSync(query, prompt, inferenceParameters, result, execution);
        } catch (RejectedExecutionException rejected) {
            // The request could not be accepted for execution; report it as "too many requests".
            return new Result(query, new ErrorMessage(429, rejected.getMessage()));
        }
    }

    /** Returns whether the prompt should be echoed in the result: when tracing is on or {@code includePrompt} is set. */
    private boolean shouldAddPrompt(Query query) {
        if (query.getTrace().getLevel() >= 1) {
            return true;
        }
        return lookupPropertyBool(INCLUDE_PROMPT_IN_RESULT, query, false);
    }

    /** Returns whether token statistics should be added to the result, which is the case at trace level 1 and above. */
    private boolean shouldAddTokenStats(Query query) {
        int traceLevel = query.getTrace().getLevel();
        return traceLevel >= 1;
    }

    /**
     * Starts an asynchronous completion and returns immediately with a result whose
     * {@code token_stream} hit group contains an event stream that is filled as tokens arrive.
     *
     * <p>The stream optionally starts with the prompt and the rendered hits, then receives one
     * event per completion callback, and finally (when tracing) a "stats" event.</p>
     */
    private Result completeAsync(Query query, Prompt prompt, InferenceParameters inferenceParameters, Result result, Execution execution) {
        EventStream events = new EventStream();
        if (shouldAddPrompt(query)) {
            events.add(prompt.asString(), PROMPT_PROPERTY);
        }
        if (shouldAddHits(query) && result != null) {
            events.add(renderHits(result, execution), "hits");
        }

        TokenStats stats = new TokenStats();
        this.languageModel
                .completeAsync(prompt, inferenceParameters, token -> {
                    stats.onToken();
                    handleCompletion(events, token);
                })
                .exceptionally(failure -> {
                    handleException(events, failure);
                    // NOTE(review): the stage below runs after this one and calls markComplete() again,
                    // and may add the "stats" event after this call - confirm EventStream tolerates both.
                    events.markComplete();
                    return Completion.FinishReason.error;
                })
                .thenAccept(ignoredFinishReason -> {
                    stats.onCompletion();
                    if (shouldAddTokenStats(query)) {
                        events.add(stats.report(), "stats");
                    }
                    events.markComplete();
                });

        HitGroup tokenStream = new HitGroup("token_stream");
        tokenStream.add(events);
        return new Result(query, tokenStream);
    }

    /** Adds the completion text to the stream, labelled "error" when the completion finished with an error. */
    private void handleCompletion(EventStream eventStream, Completion completion) {
        boolean failed = completion.finishReason() == Completion.FinishReason.error;
        if (!failed) {
            eventStream.add(completion.text());
            return;
        }
        eventStream.add(completion.text(), "error");
    }

    /**
     * Reports a failed generation as an error event on the stream.
     *
     * <p>The error code is taken from the {@link LanguageModelException}, whether it is the
     * throwable itself or its direct cause; any other failure is reported with code 400.</p>
     *
     * @param eventStream the stream to report the error on
     * @param th          the failure raised during generation
     */
    private void handleException(EventStream eventStream, Throwable th) {
        int errorCode = 400;
        if (th instanceof LanguageModelException) {
            errorCode = ((LanguageModelException) th).code();
        } else if (th.getCause() instanceof LanguageModelException) {
            // The cause must be cast before use: Throwable itself has no code() accessor.
            errorCode = ((LanguageModelException) th.getCause()).code();
        }
        eventStream.error(this.languageModelId, new ErrorMessage(errorCode, "Error in LLM text generation", Exceptions.toMessageString(th)));
    }

    /**
     * Runs a blocking completion and returns a result whose {@code completion} hit group contains
     * an already completed event stream: optionally the prompt and rendered hits, followed by the
     * text of the first completion returned by the language model.
     */
    private Result completeSync(Query query, Prompt prompt, InferenceParameters inferenceParameters, Result result, Execution execution) {
        EventStream events = new EventStream();
        if (shouldAddPrompt(query)) {
            events.add(prompt.asString(), PROMPT_PROPERTY);
        }
        if (shouldAddHits(query) && result != null) {
            events.add(renderHits(result, execution), "hits");
        }

        // Only the first completion is used.
        var completions = this.languageModel.complete(prompt, inferenceParameters);
        Completion first = completions.get(0);
        events.add(first.text(), "completion");
        events.markComplete();

        HitGroup completionGroup = new HitGroup("completion");
        completionGroup.add(events);
        return new Result(query, completionGroup);
    }

    /**
     * Resolves the prompt to use for a query. The sources are tried in this order:
     * <ol>
     *   <li>the {@code prompt} query property, first with the property prefix and then without it,</li>
     *   <li>the default prompt from config,</li>
     *   <li>the query string of the query model.</li>
     * </ol>
     *
     * @param query the query to resolve the prompt for
     * @return the prompt text, never null
     * @throws IllegalArgumentException if none of the sources yields a prompt
     */
    public String getPrompt(Query query) {
        String fromProperties = lookupPropertyWithOrWithoutPrefix(PROMPT_PROPERTY, name -> query.properties().getString(name));
        if (fromProperties != null) {
            return fromProperties;
        }
        if (this.prompt != null && !this.prompt.isEmpty()) {
            return this.prompt;
        }
        String queryString = query.getModel().getQueryString();
        if (queryString != null) {
            return queryString;
        }
        throw new IllegalArgumentException("Could not find prompt found for query. Tried looking for '" + this.propertyPrefix + ".prompt', 'prompt' or '@query'.");
    }

    /** Returns the prefix that is joined with "." in front of property names when looking them up. */
    public String getPropertyPrefix() {
        return propertyPrefix;
    }

    /**
     * Looks up a string query property under this searcher's property prefix.
     *
     * @param str   the property name without prefix
     * @param query the query whose properties are read
     * @return the value of {@code <prefix>.<str>}, or {@code null} if it is not set
     */
    public String lookupProperty(String str, Query query) {
        return query.properties().getString(this.propertyPrefix + "." + str, null);
    }

    /**
     * Looks up a boolean query property under this searcher's property prefix.
     *
     * @param str   the property name without prefix
     * @param query the query whose properties are read
     * @param z     the value to return when the property is not set
     * @return the value of {@code <prefix>.<str>}, or {@code z} if it is not set
     */
    public Boolean lookupPropertyBool(String str, Query query, boolean z) {
        return Boolean.valueOf(query.properties().getBoolean(this.propertyPrefix + "." + str, z));
    }

    /**
     * Looks a name up through the given function, preferring the prefixed form.
     *
     * @param str      the name without prefix
     * @param function maps a name to its value, returning {@code null} when there is none
     * @return the value for {@code <prefix>.<str>} if non-null, otherwise the value for {@code str}
     *         (which may itself be {@code null})
     */
    public String lookupPropertyWithOrWithoutPrefix(String str, Function<String, String> function) {
        String prefixedValue = function.apply(getPropertyPrefix() + "." + str);
        if (prefixedValue == null) {
            return function.apply(str);
        }
        return prefixedValue;
    }

    /**
     * Reads the LLM API key from the HTTP request headers of the query, trying
     * {@code <prefix>.X-LLM-API-KEY} first and then {@code X-LLM-API-KEY}.
     *
     * @return the header value, or {@code null} if neither header is present
     */
    public String getApiKeyHeader(Query query) {
        Function<String, String> headerLookup = headerName -> query.getHttpRequest().getHeader(headerName);
        return lookupPropertyWithOrWithoutPrefix(API_KEY_HEADER, headerLookup);
    }

    /** Returns whether rendered hits should be added to the result; off unless {@code includeHits} is set. */
    private boolean shouldAddHits(Query query) {
        Boolean includeHits = lookupPropertyBool(INCLUDE_HITS_IN_RESULT, query, false);
        return includeHits;
    }

    /**
     * Renders the given result to a JSON string using a fresh clone of the JSON renderer.
     * Blocks until rendering has finished.
     */
    private String renderHits(Result result, Execution execution) {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        // Work on a clone that is initialised here; the shared renderer instance is left untouched.
        var renderer = this.jsonRenderer.clone();
        renderer.init();
        renderer.renderResponse(buffer, result, execution, (Request) null).join();
        byte[] rendered = buffer.toByteArray();
        return Utf8.toString(rendered);
    }
}
