/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.rag.query.router;

import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.rag.query.router.QueryRouter;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * A {@link QueryRouter} that asks a {@link ChatLanguageModel} which of the configured
 * {@link ContentRetriever}s a {@link Query} should be routed to.
 * <p>
 * Each retriever is assigned a 1-based numeric id, in the iteration order of the map passed to the
 * constructor. It is listed in the prompt as {@code "<id>: <description>"}, one per line. The model
 * is expected to answer with a single id or several comma-separated ids, which are mapped back to
 * the corresponding retrievers.
 */
public class LanguageModelQueryRouter
implements QueryRouter {
    /**
     * Prompt used when no custom template is supplied.
     * It expects the variables {@code {{options}}} and {@code {{query}}}.
     */
    public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from("Based on the user query, determine the most suitable data source(s) to retrieve relevant information from the following options:\n{{options}}\nIt is very important that your answer consists of either a single number or multiple numbers separated by commas and nothing else!\nUser query: {{query}}");
    private final ChatLanguageModel chatLanguageModel;
    private final PromptTemplate promptTemplate;
    /** Pre-rendered option list, one {@code "<id>: <description>"} entry per line. */
    private final String options;
    /** Maps the 1-based ids shown to the model back to their retrievers. */
    private final Map<Integer, ContentRetriever> idToRetriever;

    /**
     * Creates a router that uses {@link #DEFAULT_PROMPT_TEMPLATE}.
     *
     * @param chatLanguageModel      the model that selects the retriever(s); must not be null
     * @param retrieverToDescription retrievers with a non-blank description each; must not be empty
     */
    public LanguageModelQueryRouter(ChatLanguageModel chatLanguageModel, Map<ContentRetriever, String> retrieverToDescription) {
        this(chatLanguageModel, retrieverToDescription, DEFAULT_PROMPT_TEMPLATE);
    }

    /**
     * Creates a router with a custom prompt template.
     *
     * @param chatLanguageModel      the model that selects the retriever(s); must not be null
     * @param retrieverToDescription retrievers with a non-blank description each; must not be empty
     * @param promptTemplate         template with {@code {{options}}} and {@code {{query}}} variables;
     *                               falls back to {@link #DEFAULT_PROMPT_TEMPLATE} when null
     */
    public LanguageModelQueryRouter(ChatLanguageModel chatLanguageModel, Map<ContentRetriever, String> retrieverToDescription, PromptTemplate promptTemplate) {
        this.chatLanguageModel = ValidationUtils.ensureNotNull(chatLanguageModel, "chatLanguageModel");
        ValidationUtils.ensureNotEmpty(retrieverToDescription, "retrieverToDescription");
        this.promptTemplate = Utils.getOrDefault(promptTemplate, DEFAULT_PROMPT_TEMPLATE);
        HashMap<Integer, ContentRetriever> idToRetriever = new HashMap<Integer, ContentRetriever>();
        StringBuilder optionsBuilder = new StringBuilder();
        int id = 1;
        // Ids and option lines are produced in the same pass, so they stay consistent
        // even though the map's iteration order is not otherwise specified.
        for (Map.Entry<ContentRetriever, String> entry : retrieverToDescription.entrySet()) {
            idToRetriever.put(id, ValidationUtils.ensureNotNull(entry.getKey(), "ContentRetriever"));
            if (id > 1) {
                optionsBuilder.append("\n");
            }
            optionsBuilder.append(id);
            optionsBuilder.append(": ");
            optionsBuilder.append(ValidationUtils.ensureNotBlank(entry.getValue(), "ContentRetriever description"));
            ++id;
        }
        this.idToRetriever = idToRetriever;
        this.options = optionsBuilder.toString();
    }

    /**
     * Asks the model which retriever(s) fit the query and returns them.
     *
     * @param query the user query to route
     * @return the selected retrievers, without duplicates and without null elements
     */
    @Override
    public Collection<ContentRetriever> route(Query query) {
        Prompt prompt = this.createPrompt(query);
        String response = this.chatLanguageModel.generate(prompt.text());
        return this.parse(response);
    }

    /**
     * Builds the routing prompt by filling the {@code query} and {@code options} template variables.
     *
     * @param query the user query
     * @return the rendered prompt
     */
    protected Prompt createPrompt(Query query) {
        HashMap<String, Object> variables = new HashMap<String, Object>();
        variables.put("query", query.text());
        variables.put("options", this.options);
        return this.promptTemplate.apply(variables);
    }

    /**
     * Parses the model's comma-separated id answer into retrievers.
     * <p>
     * The following inputs are tolerated so that a slightly malformed answer does not break routing:
     * <ul>
     *   <li>Blank tokens (e.g. from {@code "1, 2, "}) are skipped.</li>
     *   <li>Repeated ids yield their retriever only once; order of first appearance is kept.</li>
     *   <li>Ids that match no configured retriever are ignored, so the result never contains {@code null}.</li>
     * </ul>
     *
     * @param choices the raw model response
     * @return the selected retrievers; empty if the response contained no usable id
     * @throws NumberFormatException if a non-blank token is not an integer
     */
    protected Collection<ContentRetriever> parse(String choices) {
        return Arrays.stream(choices.split(","))
                .map(String::trim)
                .filter(token -> !token.isEmpty())
                .map(Integer::parseInt)
                .distinct()
                .map(this.idToRetriever::get)
                .filter(retriever -> retriever != null)
                .collect(Collectors.toList());
    }

    /**
     * @return a new builder for {@link LanguageModelQueryRouter}
     */
    public static LanguageModelQueryRouterBuilder builder() {
        return new LanguageModelQueryRouterBuilder();
    }

    /**
     * Builder for {@link LanguageModelQueryRouter}. Validation is done by the router's constructor
     * when {@link #build()} is called.
     */
    public static class LanguageModelQueryRouterBuilder {
        private ChatLanguageModel chatLanguageModel;
        private Map<ContentRetriever, String> retrieverToDescription;
        private PromptTemplate promptTemplate;

        LanguageModelQueryRouterBuilder() {
        }

        /** Sets the model used to pick retrievers. */
        public LanguageModelQueryRouterBuilder chatLanguageModel(ChatLanguageModel chatLanguageModel) {
            this.chatLanguageModel = chatLanguageModel;
            return this;
        }

        /** Sets the retrievers together with the descriptions shown to the model. */
        public LanguageModelQueryRouterBuilder retrieverToDescription(Map<ContentRetriever, String> retrieverToDescription) {
            this.retrieverToDescription = retrieverToDescription;
            return this;
        }

        /** Sets a custom prompt template; when left null, the default template is used. */
        public LanguageModelQueryRouterBuilder promptTemplate(PromptTemplate promptTemplate) {
            this.promptTemplate = promptTemplate;
            return this;
        }

        /**
         * @return a router configured with the values set on this builder
         */
        public LanguageModelQueryRouter build() {
            return new LanguageModelQueryRouter(this.chatLanguageModel, this.retrieverToDescription, this.promptTemplate);
        }

        @Override
        public String toString() {
            return "LanguageModelQueryRouter.LanguageModelQueryRouterBuilder(chatLanguageModel=" + this.chatLanguageModel + ", retrieverToDescription=" + this.retrieverToDescription + ", promptTemplate=" + this.promptTemplate + ")";
        }
    }
}

