/*
 * Decompiled with CFR 0.152.
 */
package org.bsc.langgraph4j.spring.ai.agentexecutor;

import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Stream;
import org.bsc.langgraph4j.GraphStateException;
import org.bsc.langgraph4j.StateGraph;
import org.bsc.langgraph4j.action.AsyncCommandAction;
import org.bsc.langgraph4j.action.AsyncNodeActionWithConfig;
import org.bsc.langgraph4j.action.Command;
import org.bsc.langgraph4j.action.InterruptionMetadata;
import org.bsc.langgraph4j.agent.AgentEx;
import org.bsc.langgraph4j.prebuilt.MessagesState;
import org.bsc.langgraph4j.spring.ai.agent.CallModelAction;
import org.bsc.langgraph4j.spring.ai.agent.ReactAgent;
import org.bsc.langgraph4j.spring.ai.agent.ReactAgentBuilder;
import org.bsc.langgraph4j.spring.ai.serializer.std.SpringAIStateSerializer;
import org.bsc.langgraph4j.spring.ai.tool.SpringAIToolService;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.Channel;
import org.bsc.langgraph4j.state.Channels;
import org.bsc.langgraph4j.utils.CollectionsUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.ToolResponseMessage;

/**
 * Node and edge actions for a Spring AI based agent-executor graph in which selected tools
 * can be gated behind an approval step.
 *
 * <p>The graph topology itself is assembled by {@code AgentEx.builder()} in
 * {@link Builder#build(Function)}; this interface supplies the individual actions and the
 * {@link State} type they operate on.
 */
public interface AgentExecutorEx {
    /** Shared logger (interface fields are implicitly {@code public static final}). */
    public static final Logger log = LoggerFactory.getLogger(AgentExecutorEx.class);

    /**
     * Creates a new {@link Builder}.
     *
     * @return a fresh builder instance
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Edge action that routes to the node named by the {@code next_action} state value,
     * falling back to {@code "model"} when no next action is set.
     */
    private static AsyncCommandAction<State> dispatchAction() {
        return AsyncCommandAction.command_async((state, config) ->
                state.nextAction().map(Command::new).orElseGet(() -> new Command("model")));
    }

    /**
     * Edge action evaluated after an approval step; it reads the {@code approval_result} state value.
     *
     * <ul>
     *   <li>If the value equals {@code AgentEx.ApprovalState.APPROVED.name()}, routes to that value
     *       and removes {@code approval_result} from the state.</li>
     *   <li>Otherwise every tool call of the last assistant message is answered: calls that already
     *       have a result keep it, the call whose name matches the pending action (the
     *       {@code next_action} value without its leading {@code approval_} prefix) is answered as
     *       denied, and all other calls as suspended. The responses are written to both
     *       {@code messages} and {@code tool_execution_results}, {@code approval_result} is removed,
     *       and the edge routes to the {@code approval_result} value.</li>
     * </ul>
     *
     * <p>A missing {@code approval_result}, a missing {@code next_action} or the absence of tool
     * calls are reported by completing the returned future exceptionally with an
     * {@link IllegalStateException} (the original implementation threw the latter two synchronously).
     */
    private static AsyncCommandAction<State> approvalAction() {
        return (state, config) -> {
            Optional<String> approvalResult = state.value("approval_result");
            if (approvalResult.isEmpty()) {
                return CompletableFuture.failedFuture(new IllegalStateException(
                        String.format("resume property '%s' not found!", "approval_result")));
            }
            String resumeState = approvalResult.get();

            if (Objects.equals(resumeState, AgentEx.ApprovalState.APPROVED.name())) {
                return CompletableFuture.completedFuture(
                        new Command(resumeState, Map.of("approval_result", AgentState.MARK_FOR_REMOVAL)));
            }

            Optional<String> deniedAction = state.nextAction().map(AgentExecutorEx::stripApprovalPrefix);
            if (deniedAction.isEmpty()) {
                return CompletableFuture.failedFuture(new IllegalStateException("no next action found!"));
            }
            List<AssistantMessage.ToolCall> toolCalls = state.toolCallsAsStream().toList();
            if (toolCalls.isEmpty()) {
                return CompletableFuture.failedFuture(new IllegalStateException("no tool execution request found!"));
            }

            String deniedActionName = deniedAction.get();
            List<ToolResponseMessage.ToolResponse> toolResponses = toolCalls.stream()
                    .map(toolCall -> responseAfterDenial(state, toolCall, deniedActionName))
                    .toList();
            ToolResponseMessage toolResponseMessage = new ToolResponseMessage(toolResponses);
            return CompletableFuture.completedFuture(new Command(resumeState, Map.of(
                    "messages", toolResponseMessage,
                    "tool_execution_results", toolResponseMessage,
                    "approval_result", AgentState.MARK_FOR_REMOVAL)));
        };
    }

    /**
     * Removes a single leading {@code approval_} prefix (the one added by {@code dispatchTools})
     * from an action id. Unlike {@code String.replace}, occurrences elsewhere in the tool name are kept.
     */
    private static String stripApprovalPrefix(String actionId) {
        String prefix = "approval_";
        return actionId.startsWith(prefix) ? actionId.substring(prefix.length()) : actionId;
    }

    /**
     * Builds the response for one tool call when the pending action was not approved: an existing
     * result is reused; otherwise the denied tool is answered as DENIED and every other one as SUSPENDED.
     */
    private static ToolResponseMessage.ToolResponse responseAfterDenial(
            State state, AssistantMessage.ToolCall toolCall, String deniedAction) {
        return state.findToolResponseByToolCall(toolCall).orElseGet(() -> toolCall.name().equals(deniedAction)
                ? new ToolResponseMessage.ToolResponse(toolCall.id(), deniedAction, "execution has been DENIED!")
                : new ToolResponseMessage.ToolResponse(toolCall.id(), toolCall.name(),
                        "execution has been SUSPENDED, please re-executed it!"));
    }

    /**
     * Node action that picks the next tool to run from the last assistant message.
     *
     * <ul>
     *   <li>No tool calls: sets {@code agent_response} to an explanatory text.</li>
     *   <li>A tool call whose name has no entry in {@code tool_execution_results}: sets
     *       {@code next_action} to that name, or to {@code approval_<name>} when the name is in
     *       {@code approvals}.</li>
     *   <li>All tool calls answered: appends the collected results to {@code messages}, resets
     *       {@code tool_execution_results} and removes {@code next_action}.</li>
     * </ul>
     *
     * @param approvals names of the tools that require approval before execution
     */
    private static AsyncNodeActionWithConfig<State> dispatchTools(Set<String> approvals) {
        return AsyncNodeActionWithConfig.node_async((state, config) -> {
            log.trace("DispatchTools");
            List<AssistantMessage.ToolCall> requests = state.toolCallsAsStream().toList();
            if (requests.isEmpty()) {
                return Map.of("agent_response", "no tool execution request found!");
            }
            // Matched by tool NAME (not id): executeTool runs every call sharing a name in one step.
            // Computed once instead of once per request (loop-invariant).
            List<String> executedToolNames = state.toolExecutionResults().stream()
                    .flatMap(r -> r.getResponses().stream())
                    .map(ToolResponseMessage.ToolResponse::name)
                    .toList();
            return requests.stream()
                    .filter(request -> !executedToolNames.contains(request.name()))
                    .findFirst()
                    .map(AssistantMessage.ToolCall::name)
                    .map(name -> approvals.contains(name) ? String.format("approval_%s", name) : name)
                    .map(actionId -> Map.<String, Object>of("next_action", actionId))
                    .orElseGet(() -> Map.of(
                            "messages", state.toolExecutionResults(),
                            "tool_execution_results", AgentState.MARK_FOR_RESET,
                            "next_action", AgentState.MARK_FOR_REMOVAL));
        });
    }

    /**
     * Node action that executes every tool call of the last assistant message whose name equals
     * {@code actionName}, delegating to {@code toolService.executeFunctions(...)}.
     *
     * @param toolService service used to run the tool calls
     * @param actionName  name of the tool to execute
     * @return an async node action producing the state update of the command returned by the tool
     *         service; the future fails with {@link IllegalArgumentException} when no matching call exists
     */
    public static AsyncNodeActionWithConfig<State> executeTool(SpringAIToolService toolService, String actionName) {
        return (state, config) -> {
            log.trace("ExecuteTool");
            List<AssistantMessage.ToolCall> toolCalls = state.toolCallsAsStream()
                    .filter(t -> t.name().equals(actionName))
                    .toList();
            if (toolCalls.isEmpty()) {
                return CompletableFuture.failedFuture(new IllegalArgumentException("no tool execution request found!"));
            }
            // NOTE(review): "tool_execution_results" presumably names the state key the service
            // writes its results to — confirm against SpringAIToolService.
            return toolService.executeFunctions(toolCalls, state.data(), "tool_execution_results")
                    .thenApply(Command::update);
        };
    }

    /**
     * Edge action deciding whether the agent loop continues after a model call.
     *
     * @return a command routing to {@code "end"} when the last message's {@code finishReason}
     *         metadata is {@code "STOP"}, to {@code "continue"} when the last message is an
     *         assistant message carrying tool calls, and to {@code "end"} otherwise
     */
    public static AsyncCommandAction<State> shouldContinue() {
        return AsyncCommandAction.command_async((state, config) -> {
            Message message = state.lastMessage().orElseThrow();
            // Metadata values are untyped; compare as Object instead of forcing a String cast.
            Object finishReason = message.getMetadata().getOrDefault("finishReason", "");
            if (Objects.equals(finishReason, "STOP")) {
                return new Command("end");
            }
            if (message instanceof AssistantMessage assistantMessage && assistantMessage.hasToolCalls()) {
                return new Command("continue");
            }
            return new Command("end");
        });
    }

    /**
     * Builder for the approval-aware agent-executor graph.
     */
    public static class Builder
    extends ReactAgentBuilder<Builder, State> {
        /** Approval actions keyed by the tool name they guard, in registration order. */
        private final Map<String, AgentEx.ApprovalNodeAction<Message, State>> approvals = new LinkedHashMap<String, AgentEx.ApprovalNodeAction<Message, State>>();

        /**
         * Requires approval before the tool {@code actionId} is executed.
         *
         * @param actionId                     name of the tool to guard
         * @param interruptionMetadataProvider supplies the interruption metadata for the approval step
         * @return this builder
         * @throws NullPointerException if either argument is {@code null}
         */
        @SuppressWarnings({"unchecked", "rawtypes"})
        public Builder approvalOn(String actionId, BiFunction<String, State, InterruptionMetadata<State>> interruptionMetadataProvider) {
            Objects.requireNonNull(actionId, "actionId cannot be null!");
            Objects.requireNonNull(interruptionMetadataProvider, "interruptionMetadataProvider cannot be null!");
            AgentEx.ApprovalNodeAction action = AgentEx.ApprovalNodeAction.builder()
                    .interruptionMetadataProvider(interruptionMetadataProvider)
                    .build();
            this.approvals.put(actionId, (AgentEx.ApprovalNodeAction<Message, State>) action);
            return this;
        }

        /**
         * Builds the agent graph.
         *
         * <p>Defaults the state serializer to a {@link SpringAIStateSerializer} when none was set,
         * then wires the model call, tool dispatch/execution, continuation and approval actions into
         * {@code AgentEx.builder()}.
         *
         * @param chatServiceFactory creates the chat service from this builder
         * @return the assembled state graph
         * @throws GraphStateException  if the graph cannot be built
         * @throws NullPointerException if {@code chatServiceFactory} is {@code null}
         */
        @Override
        @SuppressWarnings({"unchecked", "rawtypes"})
        public StateGraph<State> build(Function<ReactAgentBuilder<?, ?>, ReactAgent.ChatService> chatServiceFactory) throws GraphStateException {
            if (this.stateSerializer == null) {
                this.stateSerializer = new SpringAIStateSerializer(State::new);
            }
            ReactAgent.ChatService chatService = Objects.requireNonNull(chatServiceFactory, "chatServiceFactory cannot be null!").apply(this);
            // NOTE(review): tools() feeds the tool service while the raw `tools` field is handed to
            // AgentEx — confirm both expose the same tool set.
            SpringAIToolService toolService = new SpringAIToolService(this.tools());
            CallModelAction callModelAction = new CallModelAction(chatService, this.streaming);
            return AgentEx.builder()
                    .stateSerializer(this.stateSerializer)
                    .schema(State.SCHEMA)
                    .toolName(tool -> tool.getToolDefinition().name())
                    .callModelAction(callModelAction)
                    .dispatchToolsAction(AgentExecutorEx.dispatchTools(this.approvals.keySet()))
                    .executeToolFactory(toolName -> AgentExecutorEx.executeTool(toolService, toolName))
                    .shouldContinueEdge(AgentExecutorEx.shouldContinue())
                    .approvalActionEdge(AgentExecutorEx.approvalAction())
                    .dispatchActionEdge(AgentExecutorEx.dispatchAction())
                    .build((Collection) this.tools, this.approvals);
        }
    }

    /**
     * Agent state: the {@link MessagesState} channels plus a {@code tool_execution_results}
     * appender channel collecting tool responses.
     */
    public static class State
    extends MessagesState<Message> {
        /** State schema: {@code MessagesState.SCHEMA} merged with the {@code tool_execution_results} appender. */
        @SuppressWarnings({"unchecked", "rawtypes"})
        static final Map<String, Channel<?>> SCHEMA = CollectionsUtils.mergeMap((Map) MessagesState.SCHEMA, Map.of("tool_execution_results", Channels.appender(ArrayList::new)));

        /**
         * @param initData initial state values
         */
        public State(Map<String, Object> initData) {
            super(initData);
        }

        /**
         * @return the tool responses collected so far
         * @throws RuntimeException if the {@code tool_execution_results} value is absent
         */
        public List<ToolResponseMessage> toolExecutionResults() {
            return this.<List<ToolResponseMessage>>value("tool_execution_results")
                    .orElseThrow(() -> new RuntimeException("tool_execution_results not found"));
        }

        /**
         * @return the {@code next_action} value, if set
         */
        public Optional<String> nextAction() {
            return this.value("next_action");
        }

        /**
         * Looks up an already collected response for {@code toolCall}, matching by tool-call id.
         *
         * @param toolCall the tool call to look up
         * @return the matching response, or empty if none was collected
         */
        public Optional<ToolResponseMessage.ToolResponse> findToolResponseByToolCall(AssistantMessage.ToolCall toolCall) {
            return this.toolExecutionResults().stream()
                    .flatMap(r -> r.getResponses().stream())
                    .filter(r -> Objects.equals(toolCall.id(), r.id()))
                    .findAny();
        }

        /**
         * @return the tool calls of the last message when it is an assistant message with tool calls;
         *         an empty stream otherwise
         */
        private Stream<AssistantMessage.ToolCall> toolCallsAsStream() {
            return this.lastMessage()
                    .filter(m -> MessageType.ASSISTANT == m.getMessageType())
                    .map(AssistantMessage.class::cast)
                    .filter(AssistantMessage::hasToolCalls)
                    .map(AssistantMessage::getToolCalls)
                    .stream()
                    .flatMap(Collection::stream);
        }
    }
}

