/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.community.chain;

import dev.langchain4j.Experimental;
import dev.langchain4j.chain.Chain;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.rag.AugmentationRequest;
import dev.langchain4j.rag.AugmentationResult;
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.content.injector.ContentInjector;
import dev.langchain4j.rag.content.injector.DefaultContentInjector;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.query.Metadata;
import dev.langchain4j.rag.query.Query;

import java.util.Objects;

@Experimental
public class RetrievalQAChain
implements Chain<Query, String> {
    private final ChatModel chatModel;
    private final RetrievalAugmentor retrievalAugmentor;

    public RetrievalQAChain(ChatModel chatModel, RetrievalAugmentor retrievalAugmentor) {
        this.chatModel = chatModel;
        this.retrievalAugmentor = retrievalAugmentor;
    }

    public String execute(Query query) {
        UserMessage userMessage = this.augment(query);
        return this.chatModel.chat(userMessage.singleText());
    }

    private UserMessage augment(Query query) {
        UserMessage from = UserMessage.from((String)query.text());
        Metadata metadata = query.metadata() == null ? Metadata.from((ChatMessage)from, null, null) : query.metadata();
        AugmentationRequest request = new AugmentationRequest((ChatMessage)from, metadata);
        AugmentationResult result = this.retrievalAugmentor.augment(request);
        return (UserMessage)result.chatMessage();
    }

    public static Builder builder() {
        return new Builder();
    }

    public static class Builder {
        private ChatModel chatModel;
        private final DefaultRetrievalAugmentor.DefaultRetrievalAugmentorBuilder augmentorBuilder = DefaultRetrievalAugmentor.builder();
        private RetrievalAugmentor retrievalAugmentor;

        public Builder chatModel(ChatModel chatModel) {
            this.chatModel = chatModel;
            return this;
        }

        public Builder contentRetriever(ContentRetriever contentRetriever) {
            if (contentRetriever != null) {
                this.augmentorBuilder.contentRetriever(contentRetriever);
            }
            return this;
        }

        public Builder prompt(PromptTemplate promptTemplate) {
            DefaultContentInjector contentInjector = DefaultContentInjector.builder().promptTemplate(promptTemplate).build();
            this.augmentorBuilder.contentInjector((ContentInjector)contentInjector);
            return this;
        }

        public Builder retrievalAugmentor(RetrievalAugmentor retrievalAugmentor) {
            this.retrievalAugmentor = retrievalAugmentor;
            return this;
        }

        public RetrievalQAChain build() {
            if (this.retrievalAugmentor == null) {
                return new RetrievalQAChain(this.chatModel, (RetrievalAugmentor)this.augmentorBuilder.build());
            }
            return new RetrievalQAChain(this.chatModel, this.retrievalAugmentor);
        }
    }
}

