package ai.vespa.search.llm;

import ai.vespa.llm.LanguageModel;
import ai.vespa.llm.completion.Prompt;
import ai.vespa.llm.completion.StringPrompt;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.annotation.Inject;
import com.yahoo.component.provider.ComponentRegistry;
import com.yahoo.search.Query;
import com.yahoo.search.Result;
import com.yahoo.search.result.Hit;
import com.yahoo.search.searchchain.Execution;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import java.util.logging.Logger;
import java.util.stream.Collectors;

/**
 * Searcher which runs the query, fills the returned hits, and sends a prompt containing
 * those hits as context to the language model configured for {@link LLMSearcher}.
 *
 * <p>Query properties read through {@code lookupProperty}:
 * <ul>
 *   <li>{@code context}: if set to {@code skip}, no document context is added to the prompt.</li>
 *   <li>{@code fields}: comma-separated list of hit field names to include in the context;
 *       when absent (or containing no non-blank names) all fields of each hit are included.</li>
 * </ul>
 */
@Beta
public class RAGSearcher extends LLMSearcher {

    private static final Logger log = Logger.getLogger(RAGSearcher.class.getName());
    private static final String CONTEXT_PROPERTY = "context";
    private static final String FIELDS_TO_INCLUDE_PROPERTY = "fields";

    /** Value of the {@code context} property which disables context injection. */
    private static final String SKIP_CONTEXT = "skip";
    /** Placeholder in the prompt template which is replaced by the document context. */
    private static final String CONTEXT_PLACEHOLDER = "{context}";
    /** Placeholder in the prompt template which is replaced by the query string. */
    private static final String QUERY_PLACEHOLDER = "@query";

    @Inject
    public RAGSearcher(LlmSearcherConfig llmSearcherConfig, ComponentRegistry<LanguageModel> componentRegistry) {
        super(llmSearcherConfig, componentRegistry);
        log.info("Starting " + RAGSearcher.class.getName() + " with language model " + llmSearcherConfig.providerId());
    }

    /**
     * Runs the query further down the chain, fills the result so hit fields are available
     * for the context, and completes the prompt built from the query and the filled hits.
     */
    @Override
    public Result search(Query query, Execution execution) {
        Result result = execution.search(query);
        execution.fill(result);
        return complete(query, buildPrompt(query, result), result, execution);
    }

    /**
     * Builds the prompt sent to the language model.
     *
     * <p>Occurrences of {@code @query} in the prompt template are replaced by the query string.
     * Unless the {@code context} property is {@code skip}, the {@code {context}} placeholder is
     * replaced by the rendered hits; if the template has no such placeholder, the context is
     * prepended on its own line.
     *
     * @param query the query whose prompt template and query string are used
     * @param result the (filled) result whose hits make up the context
     * @return the final prompt
     */
    protected Prompt buildPrompt(Query query, Result result) {
        String prompt = getPrompt(query);
        if (prompt.contains(QUERY_PLACEHOLDER)) {
            // The model's query string may be null; String.replace throws NPE on a null replacement.
            String queryString = query.getModel().getQueryString();
            prompt = prompt.replace(QUERY_PLACEHOLDER, queryString != null ? queryString : "");
        }
        if ( ! SKIP_CONTEXT.equals(lookupProperty(CONTEXT_PROPERTY, query))) {
            // NOTE(review): the query string is already substituted at this point, so a literal
            // "{context}" typed by the user is also replaced by the context.
            if ( ! prompt.contains(CONTEXT_PLACEHOLDER)) {
                prompt = CONTEXT_PLACEHOLDER + "\n" + prompt;
            }
            prompt = prompt.replace(CONTEXT_PLACEHOLDER, buildContext(result));
        }
        return StringPrompt.from(prompt);
    }

    /**
     * Renders the top-level hits of the result as numbered documents, one
     * {@code name: value} line per included field, each document followed by a blank line.
     */
    private String buildContext(Result result) {
        Set<String> fieldsToInclude = getFieldsToInclude(result.getQuery());
        StringBuilder context = new StringBuilder();
        int documentNumber = 1;
        for (Hit hit : result.hits()) {
            context.append("document [").append(documentNumber++).append("]:\n");
            hit.fields().forEach((name, value) -> {
                if (fieldsToInclude.isEmpty() || fieldsToInclude.contains(name)) {
                    context.append(name).append(": ").append(value).append("\n");
                }
            });
            context.append("\n");
        }
        return context.toString();
    }

    /**
     * Returns the trimmed, non-blank field names listed in the {@code fields} property.
     * An empty set means "include all fields".
     */
    private Set<String> getFieldsToInclude(Query query) {
        String property = lookupProperty(FIELDS_TO_INCLUDE_PROPERTY, query);
        if (property == null) {
            return new HashSet<>();
        }
        // Skip blank entries (e.g. "title,,body" or a trailing comma) so they never act as a filter.
        return Arrays.stream(property.split(","))
                     .map(String::trim)
                     .filter(name -> ! name.isEmpty())
                     .collect(Collectors.toSet());
    }

}
