/*
 * Decompiled with CFR 0.152.
 */
package io.quarkiverse.langchain4j.runtime.aiservice;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.guardrail.ChatExecutor;
import dev.langchain4j.guardrail.GuardrailRequestParams;
import dev.langchain4j.guardrail.InputGuardrailRequest;
import dev.langchain4j.guardrail.OutputGuardrailException;
import dev.langchain4j.guardrail.OutputGuardrailRequest;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.chat.response.ChatResponseMetadata;
import dev.langchain4j.rag.AugmentationResult;
import dev.langchain4j.service.guardrail.GuardrailService;
import io.quarkiverse.langchain4j.guardrails.NoopChatExecutor;
import io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodImplementationSupport;
import io.quarkiverse.langchain4j.runtime.aiservice.ChatEvent;
import io.quarkiverse.langchain4j.runtime.aiservice.CommittableChatMemory;
import io.smallrye.mutiny.Multi;
import jakarta.enterprise.inject.spi.CDI;
import java.lang.annotation.Annotation;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;

public class GuardrailsSupport {

    /**
     * Lower-case text that {@link #isOutputGuardrailRetry(Throwable)} searches for (case-insensitively)
     * in an {@link OutputGuardrailException} message.
     */
    private static final String MAX_RETRIES_MESSAGE = "the guardrails have reached the maximum number of retries.";

    /**
     * Runs the input guardrails configured for the given AI service method, if it has any.
     *
     * @param guardrailService   service that resolves and executes guardrails per method
     * @param userMessage        the user message to validate
     * @param methodCreateInfo   the AI service method; also supplies the user message template
     * @param chatMemory         chat memory passed to the guardrails
     * @param augmentationResult augmentation result passed to the guardrails
     * @param templateVariables  template variables passed to the guardrails
     * @return the user message returned by the guardrail service, or {@code userMessage} unchanged when
     *         the method has no input guardrails
     */
    static UserMessage executeInputGuardrails(GuardrailService guardrailService, UserMessage userMessage,
            AiServiceMethodCreateInfo methodCreateInfo, ChatMemory chatMemory, AugmentationResult augmentationResult,
            Map<String, Object> templateVariables) {
        if (!guardrailService.hasInputGuardrails(methodCreateInfo)) {
            return userMessage;
        }
        GuardrailRequestParams params = GuardrailRequestParams.builder()
                .chatMemory(chatMemory)
                .augmentationResult(augmentationResult)
                .userMessageTemplate(methodCreateInfo.getUserMessageTemplate())
                .variables(templateVariables)
                .build();
        InputGuardrailRequest request = InputGuardrailRequest.builder()
                .userMessage(userMessage)
                .commonParams(params)
                .build();
        return guardrailService.executeGuardrails(methodCreateInfo, request);
    }

    /**
     * Runs the output guardrails configured for the given AI service method, if it has any.
     *
     * @param guardrailService      service that resolves and executes guardrails per method
     * @param methodCreateInfo      the AI service method; also supplies the user message template
     * @param response              the model response to validate
     * @param chatExecutor          executor handed to the guardrails
     * @param committableChatMemory chat memory passed to the guardrails
     * @param augmentationResult    augmentation result passed to the guardrails
     * @param templateVariables     template variables passed to the guardrails
     * @param <T>                   the type the caller expects the guardrail result to have (unchecked)
     * @return whatever the guardrail service returns, or {@code null} when the method has no output
     *         guardrails
     */
    @SuppressWarnings("unchecked") // the guardrail service result type is only known to the caller
    static <T> T executeOutputGuardrails(GuardrailService guardrailService, AiServiceMethodCreateInfo methodCreateInfo,
            ChatResponse response, ChatExecutor chatExecutor, CommittableChatMemory committableChatMemory,
            AugmentationResult augmentationResult, Map<String, Object> templateVariables) {
        if (!guardrailService.hasOutputGuardrails(methodCreateInfo)) {
            return null;
        }
        GuardrailRequestParams params = GuardrailRequestParams.builder()
                .chatMemory((ChatMemory) committableChatMemory)
                .augmentationResult(augmentationResult)
                .userMessageTemplate(methodCreateInfo.getUserMessageTemplate())
                .variables(templateVariables)
                .build();
        OutputGuardrailRequest request = OutputGuardrailRequest.builder()
                .responseFromLLM(response)
                .chatExecutor(chatExecutor)
                .requestParams(params)
                .build();
        Object result = guardrailService.executeGuardrails(methodCreateInfo, request);
        return (T) result;
    }

    /**
     * Tells whether {@code t} is an {@link OutputGuardrailException} whose message reports that the
     * guardrails reached their maximum number of retries.
     *
     * @param t the failure to inspect; may be {@code null}
     * @return {@code true} only for an {@code OutputGuardrailException} with a non-null message containing
     *         the max-retries text (compared case-insensitively)
     */
    static boolean isOutputGuardrailRetry(Throwable t) {
        if (!(t instanceof OutputGuardrailException)) {
            return false;
        }
        String message = t.getMessage();
        // Locale.ROOT keeps the comparison stable regardless of the JVM default locale.
        return message != null && message.toLowerCase(Locale.ROOT).contains(MAX_RETRIES_MESSAGE);
    }

    /**
     * Turns a stream of chat events into accumulated-response events.
     * <p>
     * When the method has no output token accumulator class configured, all partial-response chunks are
     * concatenated into a single event emitted after {@code upstream} completes. Otherwise the configured
     * {@link OutputTokenAccumulator} (looked up through CDI and cached on {@code methodCreateInfo}) decides
     * how chunks are grouped, and one event is emitted per string it produces. In both cases the event
     * carries the metadata of the Completed event seen so far, or empty metadata when none has been seen.
     *
     * @param upstream         the raw chat event stream
     * @param methodCreateInfo the AI service method whose accumulator configuration is used
     * @return the stream of accumulated-response events
     * @throws RuntimeException if the configured accumulator class cannot be loaded or obtained from CDI
     */
    public static Multi<ChatEvent.AccumulatedResponseEvent> accumulate(Multi<ChatEvent> upstream,
            AiServiceMethodCreateInfo methodCreateInfo) {
        OutputTokenAccumulator accumulator = resolveOutputTokenAccumulator(methodCreateInfo);
        if (accumulator == null) {
            return accumulateWholeResponse(upstream);
        }
        return accumulateWith(upstream, accumulator);
    }

    /**
     * Returns the method's cached accumulator, resolving and caching it on first use; {@code null} when no
     * accumulator class name is configured.
     */
    private static OutputTokenAccumulator resolveOutputTokenAccumulator(AiServiceMethodCreateInfo methodCreateInfo) {
        // Only the lookup/cache step needs the lock; the stream pipelines are built outside it.
        synchronized (AiServiceMethodImplementationSupport.class) {
            OutputTokenAccumulator accumulator = methodCreateInfo.getOutputTokenAccumulator();
            if (accumulator != null) {
                return accumulator;
            }
            String cn = methodCreateInfo.getOutputTokenAccumulatorClassName();
            if (cn == null) {
                return null;
            }
            try {
                Class<? extends OutputTokenAccumulator> accumulatorClass = Class
                        .forName(cn, true, Thread.currentThread().getContextClassLoader())
                        .asSubclass(OutputTokenAccumulator.class);
                accumulator = CDI.current().select(accumulatorClass).get();
                methodCreateInfo.setOutputTokenAccumulator(accumulator);
            } catch (Exception e) {
                throw new RuntimeException("Could not find " + OutputTokenAccumulator.class.getSimpleName()
                        + " implementation class: " + cn, e);
            }
            return accumulator;
        }
    }

    /** Concatenates every partial-response chunk into a single event emitted when upstream completes. */
    private static Multi<ChatEvent.AccumulatedResponseEvent> accumulateWholeResponse(Multi<ChatEvent> upstream) {
        return upstream.collect()
                .in(ChatResponseAccumulator::new, (acc, event) -> {
                    if (event.getEventType() == ChatEvent.ChatEventType.PartialResponse) {
                        acc.stringBuilder.append(((ChatEvent.PartialResponseEvent) event).getChunk());
                    }
                    if (event.getEventType() == ChatEvent.ChatEventType.Completed) {
                        acc.metadata = ((ChatEvent.ChatCompletedEvent) event).getChatResponse().metadata();
                    }
                })
                .map(acc -> new ChatEvent.AccumulatedResponseEvent(acc.stringBuilder.toString(), orEmpty(acc.metadata)))
                .toMulti();
    }

    /** Lets {@code accumulator} group the partial-response chunks; emits one event per produced string. */
    private static Multi<ChatEvent.AccumulatedResponseEvent> accumulateWith(Multi<ChatEvent> upstream,
            OutputTokenAccumulator accumulator) {
        AtomicReference<ChatResponseMetadata> latestMetadata = new AtomicReference<>();
        return upstream
                .invoke(event -> {
                    if (event.getEventType() == ChatEvent.ChatEventType.Completed) {
                        latestMetadata.set(((ChatEvent.ChatCompletedEvent) event).getChatResponse().metadata());
                    }
                })
                .filter(event -> event.getEventType() == ChatEvent.ChatEventType.PartialResponse)
                .map(event -> ((ChatEvent.PartialResponseEvent) event).getChunk())
                .plug(accumulator::accumulate)
                .map(text -> new ChatEvent.AccumulatedResponseEvent(text, orEmpty(latestMetadata.get())));
    }

    /** Returns {@code metadata}, or a freshly built empty metadata instance when it is {@code null}. */
    private static ChatResponseMetadata orEmpty(ChatResponseMetadata metadata) {
        return metadata != null ? metadata : ChatResponseMetadata.builder().build();
    }

    /** Mutable state used while collecting a whole streamed response into one event. */
    private static class ChatResponseAccumulator {
        private final StringBuilder stringBuilder = new StringBuilder();
        private ChatResponseMetadata metadata;
    }

    /**
     * Applies the method's output guardrails to streamed items.
     * <p>
     * {@link ChatEvent}s of type {@code AccumulatedResponse} and plain {@link String} chunks are checked;
     * when the guardrails return a result, its text replaces the item. For accumulated events the text is
     * re-wrapped in an event with the original metadata, unless {@code isStringMulti} is set, in which case
     * the bare text is returned. All other items pass through unchanged.
     */
    static class OutputGuardrailStreamingMapper implements Function<Object, Object> {
        private final GuardrailService guardrailService;
        private final AiServiceMethodCreateInfo methodCreateInfo;
        private final CommittableChatMemory committableChatMemory;
        private final AugmentationResult augmentationResult;
        private final Map<String, Object> templateVariables;
        private final boolean isStringMulti;

        OutputGuardrailStreamingMapper(GuardrailService guardrailService, AiServiceMethodCreateInfo methodCreateInfo,
                CommittableChatMemory committableChatMemory, AugmentationResult augmentationResult,
                Map<String, Object> templateVariables, boolean isStringMulti) {
            this.guardrailService = guardrailService;
            this.methodCreateInfo = methodCreateInfo;
            this.committableChatMemory = committableChatMemory;
            this.augmentationResult = augmentationResult;
            this.templateVariables = templateVariables;
            this.isStringMulti = isStringMulti;
        }

        @Override
        public Object apply(Object chunk) {
            if (chunk instanceof ChatEvent) {
                return apply((ChatEvent) chunk);
            }
            if (chunk instanceof String) {
                return apply((String) chunk);
            }
            return chunk;
        }

        /** Guards an accumulated-response event; other event types pass through. */
        private Object apply(ChatEvent chunk) {
            if (chunk.getEventType() != ChatEvent.ChatEventType.AccumulatedResponse) {
                return chunk;
            }
            ChatEvent.AccumulatedResponseEvent accumulatedChunk = (ChatEvent.AccumulatedResponseEvent) chunk;
            ChatResponseMetadata metadata = accumulatedChunk.getMetadata();
            Object guardrailResult = runOutputGuardrails(accumulatedChunk.getMessage());
            if (guardrailResult == null) {
                return chunk;
            }
            String message = asText(guardrailResult);
            return isStringMulti ? message : new ChatEvent.AccumulatedResponseEvent(message, metadata);
        }

        /** Guards a plain string chunk, returning the guardrail text or the chunk itself. */
        private Object apply(String chunk) {
            Object guardrailResult = runOutputGuardrails(chunk);
            return guardrailResult == null ? chunk : asText(guardrailResult);
        }

        /** Runs the output guardrails on {@code text} wrapped as an AI message; may return {@code null}. */
        private Object runOutputGuardrails(String text) {
            return GuardrailsSupport.executeOutputGuardrails(guardrailService, methodCreateInfo,
                    ChatResponse.builder().aiMessage(AiMessage.from(text)).build(), new NoopChatExecutor(),
                    committableChatMemory, augmentationResult, templateVariables);
        }

        /**
         * Converts a non-null guardrail result to text: the AI message text for a {@link ChatResponse},
         * otherwise {@code toString()} (which is the string itself for a {@link String}).
         */
        private static String asText(Object guardrailResult) {
            if (guardrailResult instanceof ChatResponse) {
                return ((ChatResponse) guardrailResult).aiMessage().text();
            }
            return guardrailResult.toString();
        }
    }
}

