/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.a2a.server;

import com.alibaba.cloud.ai.graph.NodeOutput;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.RunnableConfig;
import com.alibaba.cloud.ai.graph.agent.BaseAgent;
import com.alibaba.cloud.ai.graph.exception.GraphRunnerException;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.streaming.StreamingOutput;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import io.a2a.A2A;
import io.a2a.server.agentexecution.AgentExecutor;
import io.a2a.server.agentexecution.RequestContext;
import io.a2a.server.events.EventQueue;
import io.a2a.server.tasks.TaskUpdater;
import io.a2a.spec.Event;
import io.a2a.spec.JSONRPCError;
import io.a2a.spec.Message;
import io.a2a.spec.MessageSendParams;
import io.a2a.spec.Part;
import io.a2a.spec.Task;
import io.a2a.spec.TaskState;
import io.a2a.spec.TaskStatus;
import io.a2a.spec.TextPart;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Flux;

/**
 * {@link AgentExecutor} that feeds the text of an incoming A2A message to a
 * {@link BaseAgent} and publishes the agent's output through the request's
 * {@link EventQueue}.
 *
 * <p>The request is run in streaming mode when its metadata carries a true value under
 * {@link #STREAMING_METADATA_KEY}; otherwise the agent is invoked once and its result is
 * published as a single artifact.
 */
public class GraphAgentExecutor implements AgentExecutor {
    private static final Logger LOGGER = LoggerFactory.getLogger(GraphAgentExecutor.class);

    /** Node names whose streamed outputs are not forwarded to the client as artifacts. */
    private static final Set<String> IGNORE_NODE_TYPE = Set.of("preLlm", "postLlm", "preTool", "tool", "postTool");

    /** Request metadata key that selects streaming execution. */
    public static final String STREAMING_METADATA_KEY = "isStreaming";

    private final BaseAgent executeAgent;

    /**
     * @param executeAgent the agent every request is delegated to
     */
    public GraphAgentExecutor(BaseAgent executeAgent) {
        this.executeAgent = executeAgent;
    }

    /**
     * Builds a task in the {@code SUBMITTED} state for the given request message.
     * The message's own context id and task id are reused when they are non-empty;
     * otherwise random UUIDs are generated.
     *
     * @param request the message that started the task; it becomes the task's only history entry
     * @return the newly created task
     */
    private Task newTask(Message request) {
        String contextId = request.getContextId();
        if (contextId == null || contextId.isEmpty()) {
            contextId = UUID.randomUUID().toString();
        }
        String taskId = request.getTaskId();
        if (taskId == null || taskId.isEmpty()) {
            taskId = UUID.randomUUID().toString();
        }
        return new Task(taskId, contextId, new TaskStatus(TaskState.SUBMITTED), null, List.of(request), null);
    }

    /**
     * Concatenates the text parts of the request message (one per line, trimmed as a whole),
     * wraps the result in a single {@link UserMessage} under the {@code "messages"} input key
     * and runs the agent in streaming or non-streaming mode.
     *
     * <p>Any exception raised while doing so is logged and reported to the client as an agent
     * message on the event queue; it is not rethrown.
     *
     * @param context the request being executed
     * @param eventQueue the queue that receives task, artifact and message events
     */
    @Override
    public void execute(RequestContext context, EventQueue eventQueue) throws JSONRPCError {
        try {
            Message message = context.getParams().message();
            StringBuilder text = new StringBuilder();
            for (Part<?> part : message.getParts()) {
                if (Part.Kind.TEXT.equals(part.getKind())) {
                    text.append(((TextPart) part).getText()).append("\n");
                }
            }
            Map<String, Object> input = Map.of("messages", List.of(new UserMessage(text.toString().trim())));
            if (isStreamRequest(context)) {
                executeStreamTask(input, context, eventQueue);
            } else {
                executeForNonStreamTask(input, context, eventQueue);
            }
        } catch (Exception e) {
            LOGGER.error("Agent execution failed", e);
            eventQueue.enqueueEvent(A2A.toAgentMessage("Agent execution failed: " + describe(e)));
        }
    }

    /**
     * Cancellation is currently a no-op: nothing is published to the queue and a running
     * agent is not interrupted.
     */
    @Override
    public void cancel(RequestContext context, EventQueue eventQueue) throws JSONRPCError {
    }

    /**
     * Tells whether the request asked for streaming execution.
     *
     * @return {@code true} only when the request metadata maps {@link #STREAMING_METADATA_KEY}
     *         to {@link Boolean#TRUE} or to a string equal to {@code "true"} (ignoring case);
     *         {@code false} when params, metadata or the key are missing, or the value has
     *         any other type
     */
    private boolean isStreamRequest(RequestContext context) {
        MessageSendParams params = context.getParams();
        if (params == null || params.metadata() == null) {
            return false;
        }
        // Metadata values are untyped, so inspect the runtime type instead of casting blindly.
        Object flag = params.metadata().get(STREAMING_METADATA_KEY);
        if (flag instanceof Boolean) {
            return (Boolean) flag;
        }
        if (flag instanceof String) {
            return Boolean.parseBoolean((String) flag);
        }
        return false;
    }

    /**
     * Builds the agent run configuration from the request metadata: a string-valued
     * {@code "threadId"} entry becomes the thread id, and every metadata entry (including
     * {@code "threadId"}) is copied into the config metadata.
     *
     * @return the built config; an empty one when the request has no params or metadata
     */
    private RunnableConfig getRunnableConfig(RequestContext context) {
        RunnableConfig.Builder builder = RunnableConfig.builder();
        MessageSendParams params = context.getParams();
        if (params == null || params.metadata() == null) {
            return builder.build();
        }
        Map<String, Object> metadata = params.metadata();
        Object threadId = metadata.get("threadId");
        if (threadId instanceof String) {
            builder.threadId((String) threadId);
        }
        for (Map.Entry<String, Object> entry : metadata.entrySet()) {
            builder.addMetadata(entry.getKey(), entry.getValue());
        }
        return builder.build();
    }

    /**
     * Runs the agent in streaming mode. Each streamed output is forwarded as an artifact by
     * {@link ReactAgentNodeOutputConsumer}; the task is marked completed when the stream
     * completes and failed when it errors. The calling thread blocks until the stream has
     * terminated one way or the other.
     */
    private void executeStreamTask(Map<String, Object> input, RequestContext context, EventQueue eventQueue) throws GraphStateException, GraphRunnerException {
        RunnableConfig runnableConfig = getRunnableConfig(context);
        Flux<? extends NodeOutput> stream = this.executeAgent.stream(input, runnableConfig);
        if (context.getTask() == null) {
            eventQueue.enqueueEvent(newTask(context.getMessage()));
        }
        TaskUpdater taskUpdater = new TaskUpdater(context, eventQueue);
        taskUpdater.submit();

        // Wait on the stream's own terminal signal rather than polling a Task reference:
        // the task obtained above is never updated by this class, and a failed stream must
        // release the caller just like a completed one.
        CountDownLatch finished = new CountDownLatch(1);
        stream.subscribe(new ReactAgentNodeOutputConsumer(taskUpdater), error -> {
            try {
                LOGGER.error("Agent execution failed", error);
                taskUpdater.fail(A2A.toAgentMessage(describe(error)));
            } finally {
                finished.countDown();
            }
        }, () -> {
            try {
                taskUpdater.complete();
            } finally {
                finished.countDown();
            }
        });
        awaitStreamTermination(finished);
    }

    /**
     * Invokes the agent once and publishes the value stored under the agent's output key as a
     * {@code "conversation_result"} artifact, then marks the task completed. When the result
     * state has no such key, the placeholder text {@code "No output key in result."} is
     * published instead.
     *
     * @throws IllegalStateException if the agent returns an empty result
     */
    private void executeForNonStreamTask(Map<String, Object> input, RequestContext context, EventQueue eventQueue) throws GraphStateException, GraphRunnerException {
        RunnableConfig runnableConfig = getRunnableConfig(context);
        Optional<OverAllState> result = this.executeAgent.invoke(input, runnableConfig);
        OverAllState state = result.orElseThrow(() -> new IllegalStateException("Agent returned no result state"));
        String outputText = state.data().containsKey(this.executeAgent.outputKey())
                ? String.valueOf(state.data().get(this.executeAgent.outputKey()))
                : "No output key in result.";
        if (context.getTask() == null) {
            eventQueue.enqueueEvent(newTask(context.getMessage()));
        }
        TaskUpdater taskUpdater = new TaskUpdater(context, eventQueue);
        taskUpdater.addArtifact(List.of(new TextPart(outputText)), UUID.randomUUID().toString(), "conversation_result", Map.of("output", outputText));
        taskUpdater.complete();
    }

    /**
     * Blocks until the streaming run has signalled completion or failure. If the waiting
     * thread is interrupted, the interrupt flag is restored and the wait is abandoned.
     */
    private void awaitStreamTermination(CountDownLatch finished) {
        try {
            finished.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            LOGGER.warn("Interrupted while waiting for the streaming agent run to finish", e);
        }
    }

    /**
     * Returns a non-null description of a failure: its message when it has one, otherwise
     * its {@link Throwable#toString()} form.
     */
    private static String describe(Throwable error) {
        String message = error.getMessage();
        return message != null ? message : error.toString();
    }

    /**
     * Forwards streamed agent outputs to the client as text artifacts, named with an
     * increasing counter starting at 1.
     */
    private static class ReactAgentNodeOutputConsumer implements Consumer<NodeOutput> {
        private final TaskUpdater taskUpdater;
        private final AtomicInteger artifactNum;

        private ReactAgentNodeOutputConsumer(TaskUpdater taskUpdater) {
            this.taskUpdater = taskUpdater;
            this.artifactNum = new AtomicInteger();
        }

        /**
         * Publishes the chunk of a {@link StreamingOutput} as an artifact. Start/end outputs
         * and outputs of the ignored node types are only logged at debug level; outputs
         * without a non-empty chunk are dropped.
         */
        @Override
        public void accept(NodeOutput nodeOutput) {
            if (nodeOutput.isSTART() || nodeOutput.isEND() || IGNORE_NODE_TYPE.contains(nodeOutput.node())) {
                if (LOGGER.isDebugEnabled()) {
                    LOGGER.debug("Agent parts output: {}", buildDebugDetailInfo(nodeOutput));
                }
                return;
            }
            String content = "";
            if (nodeOutput instanceof StreamingOutput) {
                content = ((StreamingOutput) nodeOutput).chunk();
            }
            if (!StringUtils.hasLength(content)) {
                return;
            }
            this.taskUpdater.addArtifact(Collections.singletonList(new TextPart(content)), null, String.valueOf(this.artifactNum.incrementAndGet()), Map.of());
        }

        /** Serializes the output's node name and state data to a JSON string for debug logging. */
        private String buildDebugDetailInfo(NodeOutput nodeOutput) {
            JSONObject outputJson = new JSONObject();
            outputJson.put("data", nodeOutput.state().data());
            outputJson.put("node", nodeOutput.node());
            return JSON.toJSONString(outputJson);
        }
    }
}

