/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.model.tool;

import io.micrometer.observation.ObservationRegistry;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor;
import org.springframework.ai.tool.execution.ToolExecutionException;
import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor;
import org.springframework.ai.tool.observation.DefaultToolCallingObservationConvention;
import org.springframework.ai.tool.observation.ToolCallingObservationContext;
import org.springframework.ai.tool.observation.ToolCallingObservationConvention;
import org.springframework.ai.tool.observation.ToolCallingObservationDocumentation;
import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver;
import org.springframework.ai.tool.resolution.ToolCallbackResolver;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

/**
 * {@link ToolCallingManager} implementation that resolves tool definitions from chat
 * options and executes the tool calls requested by a chat model response.
 *
 * <p>Tool callbacks are looked up first among the callbacks carried by the prompt's
 * {@link ToolCallingChatOptions} and, failing that, through the configured
 * {@link ToolCallbackResolver}. Each tool invocation is wrapped in an observation, and
 * {@link ToolExecutionException}s thrown by a tool are handed to the configured
 * {@link ToolExecutionExceptionProcessor}.
 *
 * <p>Instances are created through the public constructor or via {@link #builder()}.
 */
public final class DefaultToolCallingManager
implements ToolCallingManager {
    private static final Logger logger = LoggerFactory.getLogger(DefaultToolCallingManager.class);
    private static final ObservationRegistry DEFAULT_OBSERVATION_REGISTRY = ObservationRegistry.NOOP;
    private static final ToolCallingObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultToolCallingObservationConvention();
    private static final ToolCallbackResolver DEFAULT_TOOL_CALLBACK_RESOLVER = new DelegatingToolCallbackResolver(List.of());
    private static final ToolExecutionExceptionProcessor DEFAULT_TOOL_EXECUTION_EXCEPTION_PROCESSOR = DefaultToolExecutionExceptionProcessor.builder().build();
    private static final String POSSIBLE_LLM_TOOL_NAME_CHANGE_WARNING = "LLM may have adapted the tool name '{}', especially if the name was truncated due to length limits. If this is the case, you can customize the prefixing and processing logic using McpToolNamePrefixGenerator";
    private final ObservationRegistry observationRegistry;
    private final ToolCallbackResolver toolCallbackResolver;
    private final ToolExecutionExceptionProcessor toolExecutionExceptionProcessor;
    private ToolCallingObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

    /**
     * Creates a manager with the given collaborators.
     *
     * @param observationRegistry registry used to record tool-call observations; must not be null
     * @param toolCallbackResolver resolver consulted for tool names that are not among the
     *        callbacks supplied in the chat options; must not be null
     * @param toolExecutionExceptionProcessor processor that turns a {@link ToolExecutionException}
     *        into the tool result string; must not be null
     * @throws IllegalArgumentException if any argument is null (raised by {@link Assert#notNull})
     */
    public DefaultToolCallingManager(ObservationRegistry observationRegistry, ToolCallbackResolver toolCallbackResolver, ToolExecutionExceptionProcessor toolExecutionExceptionProcessor) {
        Assert.notNull(observationRegistry, "observationRegistry cannot be null");
        Assert.notNull(toolCallbackResolver, "toolCallbackResolver cannot be null");
        Assert.notNull(toolExecutionExceptionProcessor, "toolExecutionExceptionProcessor cannot be null");
        this.observationRegistry = observationRegistry;
        this.toolCallbackResolver = toolCallbackResolver;
        this.toolExecutionExceptionProcessor = toolExecutionExceptionProcessor;
    }

    /**
     * Collects the definitions of every tool referenced by the given options: all explicitly
     * supplied callbacks, followed by one resolved callback for each tool name that is not
     * already covered by an explicit callback.
     *
     * @param chatOptions options carrying tool callbacks and tool names; must not be null
     * @return the tool definitions, explicit callbacks first
     * @throws IllegalStateException if a tool name cannot be resolved to a callback
     */
    @Override
    public List<ToolDefinition> resolveToolDefinitions(ToolCallingChatOptions chatOptions) {
        Assert.notNull(chatOptions, "chatOptions cannot be null");
        List<ToolCallback> toolCallbacks = new ArrayList<>(chatOptions.getToolCallbacks());
        for (String toolName : chatOptions.getToolNames()) {
            // An explicitly supplied callback takes precedence over the resolver.
            boolean alreadyProvided = chatOptions.getToolCallbacks().stream()
                    .anyMatch(tool -> tool.getToolDefinition().name().equals(toolName));
            if (alreadyProvided) {
                continue;
            }
            ToolCallback toolCallback = this.toolCallbackResolver.resolve(toolName);
            if (toolCallback == null) {
                logger.warn(POSSIBLE_LLM_TOOL_NAME_CHANGE_WARNING, toolName);
                throw new IllegalStateException("No ToolCallback found for tool name: " + toolName);
            }
            toolCallbacks.add(toolCallback);
        }
        return toolCallbacks.stream().map(ToolCallback::getToolDefinition).toList();
    }

    /**
     * Executes the tool calls of the first generation in the response that requests any, and
     * returns the conversation history extended with the assistant message and the tool
     * responses.
     *
     * @param prompt the prompt that produced the response; must not be null
     * @param chatResponse the model response containing the tool calls; must not be null
     * @return the execution result holding the updated history and the return-direct flag
     * @throws IllegalStateException if no generation contains tool calls, or if a requested
     *         tool cannot be resolved to a callback
     */
    @Override
    public ToolExecutionResult executeToolCalls(Prompt prompt, ChatResponse chatResponse) {
        Assert.notNull(prompt, "prompt cannot be null");
        Assert.notNull(chatResponse, "chatResponse cannot be null");
        Optional<Generation> toolCallGeneration = chatResponse.getResults().stream()
                .filter(generation -> !CollectionUtils.isEmpty(generation.getOutput().getToolCalls()))
                .findFirst();
        if (toolCallGeneration.isEmpty()) {
            throw new IllegalStateException("No tool call requested by the chat model");
        }
        AssistantMessage assistantMessage = toolCallGeneration.get().getOutput();
        ToolContext toolContext = buildToolContext(prompt, assistantMessage);
        InternalToolExecutionResult internalToolExecutionResult = this.executeToolCall(prompt, assistantMessage, toolContext);
        List<Message> conversationHistory = this.buildConversationHistoryAfterToolExecution(prompt.getInstructions(), assistantMessage, internalToolExecutionResult.toolResponseMessage());
        return ToolExecutionResult.builder()
                .conversationHistory(conversationHistory)
                .returnDirect(internalToolExecutionResult.returnDirect())
                .build();
    }

    /**
     * Builds the context passed to each tool. When the prompt options are
     * {@link ToolCallingChatOptions} with a non-empty tool context, that context is copied and
     * the conversation history is added under the {@code "TOOL_CALL_HISTORY"} key; otherwise
     * the context is empty.
     */
    private static ToolContext buildToolContext(Prompt prompt, AssistantMessage assistantMessage) {
        Map<String, Object> toolContextMap = Map.of();
        if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions
                && !CollectionUtils.isEmpty(toolCallingChatOptions.getToolContext())) {
            // Copy so the caller-supplied context map is never mutated.
            toolContextMap = new HashMap<>(toolCallingChatOptions.getToolContext());
            toolContextMap.put("TOOL_CALL_HISTORY", buildConversationHistoryBeforeToolExecution(prompt, assistantMessage));
        }
        return new ToolContext(toolContextMap);
    }

    /**
     * Returns the instructions of a copy of the prompt followed by a rebuilt copy of the
     * assistant message (text, metadata and tool calls).
     */
    private static List<Message> buildConversationHistoryBeforeToolExecution(Prompt prompt, AssistantMessage assistantMessage) {
        List<Message> messageHistory = new ArrayList<>(prompt.copy().getInstructions());
        messageHistory.add(AssistantMessage.builder()
                .content(assistantMessage.getText())
                .properties(assistantMessage.getMetadata())
                .toolCalls(assistantMessage.getToolCalls())
                .build());
        return messageHistory;
    }

    /**
     * Invokes every tool call of the assistant message in order and gathers the responses.
     *
     * <p>The resulting return-direct flag is {@code true} only when every invoked tool's
     * metadata requests it; it is {@code false} when the message carries no tool calls.
     *
     * @throws IllegalStateException if a tool name matches neither a callback in the prompt
     *         options nor one known to the resolver
     */
    private InternalToolExecutionResult executeToolCall(Prompt prompt, AssistantMessage assistantMessage, ToolContext toolContext) {
        List<ToolCallback> toolCallbacks = prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions
                ? toolCallingChatOptions.getToolCallbacks()
                : List.of();
        List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList<>();
        // Stays null until the first tool has been looked up, so that the first tool seeds
        // the flag and later tools can only turn it off.
        Boolean returnDirect = null;
        for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
            String toolName = toolCall.name();
            logger.debug("Executing tool call: {}", toolName);
            String toolInputArguments = toolCall.arguments();
            final String finalToolInputArguments;
            if (StringUtils.hasText(toolInputArguments)) {
                finalToolInputArguments = toolInputArguments;
            } else {
                logger.warn("Tool call arguments are null or empty for tool: {}. Using empty JSON object as default.", toolName);
                finalToolInputArguments = "{}";
            }
            // Callbacks supplied with the prompt win; the resolver is only a fallback.
            ToolCallback toolCallback = toolCallbacks.stream()
                    .filter(tool -> toolName.equals(tool.getToolDefinition().name()))
                    .findFirst()
                    .orElseGet(() -> this.toolCallbackResolver.resolve(toolName));
            if (toolCallback == null) {
                logger.warn(POSSIBLE_LLM_TOOL_NAME_CHANGE_WARNING, toolName);
                throw new IllegalStateException("No ToolCallback found for tool name: " + toolName);
            }
            returnDirect = returnDirect == null
                    ? toolCallback.getToolMetadata().returnDirect()
                    : returnDirect && toolCallback.getToolMetadata().returnDirect();
            ToolCallingObservationContext observationContext = ToolCallingObservationContext.builder()
                    .toolDefinition(toolCallback.getToolDefinition())
                    .toolMetadata(toolCallback.getToolMetadata())
                    .toolCallArguments(finalToolInputArguments)
                    .build();
            String toolCallResult = ToolCallingObservationDocumentation.TOOL_CALL
                    .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry)
                    .observe(() -> {
                        String toolResult;
                        try {
                            toolResult = toolCallback.call(finalToolInputArguments, toolContext);
                        }
                        catch (ToolExecutionException ex) {
                            // The processor decides what result string the failure maps to.
                            toolResult = this.toolExecutionExceptionProcessor.process(ex);
                        }
                        observationContext.setToolCallResult(toolResult);
                        return toolResult;
                    });
            // A null tool result is reported as an empty string.
            toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), toolName, toolCallResult != null ? toolCallResult : ""));
        }
        // Null-safe unboxing: with no tool calls the flag was never seeded.
        return new InternalToolExecutionResult(new ToolResponseMessage(toolResponses, Map.of()), Boolean.TRUE.equals(returnDirect));
    }

    /**
     * Returns a new list holding the previous messages followed by the assistant message and
     * the tool response message.
     */
    private List<Message> buildConversationHistoryAfterToolExecution(List<Message> previousMessages, AssistantMessage assistantMessage, ToolResponseMessage toolResponseMessage) {
        List<Message> messages = new ArrayList<>(previousMessages);
        messages.add(assistantMessage);
        messages.add(toolResponseMessage);
        return messages;
    }

    /**
     * Replaces the observation convention passed as the custom convention when tool-call
     * observations are created.
     *
     * @param observationConvention the convention to use
     */
    public void setObservationConvention(ToolCallingObservationConvention observationConvention) {
        this.observationConvention = observationConvention;
    }

    /**
     * Creates a builder pre-populated with the default registry, resolver and exception
     * processor.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Outcome of executing the tool calls of one assistant message.
     *
     * @param toolResponseMessage message aggregating the individual tool responses
     * @param returnDirect whether every executed tool requested a direct return
     */
    private record InternalToolExecutionResult(ToolResponseMessage toolResponseMessage, boolean returnDirect) {
    }

    /**
     * Fluent builder for {@link DefaultToolCallingManager}. Null checks on the supplied
     * collaborators are performed by the manager's constructor when {@link #build()} is called.
     */
    public static final class Builder {
        private ObservationRegistry observationRegistry = DEFAULT_OBSERVATION_REGISTRY;
        private ToolCallbackResolver toolCallbackResolver = DEFAULT_TOOL_CALLBACK_RESOLVER;
        private ToolExecutionExceptionProcessor toolExecutionExceptionProcessor = DEFAULT_TOOL_EXECUTION_EXCEPTION_PROCESSOR;

        private Builder() {
        }

        /**
         * Sets the observation registry.
         *
         * @param observationRegistry the registry to use
         * @return this builder
         */
        public Builder observationRegistry(ObservationRegistry observationRegistry) {
            this.observationRegistry = observationRegistry;
            return this;
        }

        /**
         * Sets the resolver used for tool names without an explicit callback.
         *
         * @param toolCallbackResolver the resolver to use
         * @return this builder
         */
        public Builder toolCallbackResolver(ToolCallbackResolver toolCallbackResolver) {
            this.toolCallbackResolver = toolCallbackResolver;
            return this;
        }

        /**
         * Sets the processor applied to {@link ToolExecutionException}s thrown by tools.
         *
         * @param toolExecutionExceptionProcessor the processor to use
         * @return this builder
         */
        public Builder toolExecutionExceptionProcessor(ToolExecutionExceptionProcessor toolExecutionExceptionProcessor) {
            this.toolExecutionExceptionProcessor = toolExecutionExceptionProcessor;
            return this;
        }

        /**
         * Builds the manager from the configured collaborators.
         *
         * @return a new {@link DefaultToolCallingManager}
         * @throws IllegalArgumentException if any collaborator was set to null
         */
        public DefaultToolCallingManager build() {
            return new DefaultToolCallingManager(this.observationRegistry, this.toolCallbackResolver, this.toolExecutionExceptionProcessor);
        }
    }
}

