package com.alibaba.cloud.ai.mcp.nacos2.client.transport;

import com.alibaba.cloud.ai.mcp.nacos2.registry.model.McpNacosConstant;
import com.alibaba.cloud.ai.mcp.nacos2.registry.model.McpToolsInfo;
import com.alibaba.nacos.api.exception.NacosException;
import com.alibaba.nacos.api.naming.NamingService;
import com.alibaba.nacos.api.naming.listener.Event;
import com.alibaba.nacos.api.naming.listener.EventListener;
import com.alibaba.nacos.api.naming.listener.NamingEvent;
import com.alibaba.nacos.api.naming.pojo.Instance;
import com.alibaba.nacos.client.config.NacosConfigService;
import com.alibaba.nacos.common.utils.JacksonUtils;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.client.McpAsyncClient;
import io.modelcontextprotocol.client.McpClient;
import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
import io.modelcontextprotocol.spec.McpSchema;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.mcp.client.autoconfigure.NamedClientMcpTransport;
import org.springframework.ai.mcp.client.autoconfigure.configurer.McpAsyncClientConfigurer;
import org.springframework.ai.mcp.client.autoconfigure.properties.McpClientCommonProperties;
import org.springframework.context.ApplicationContext;
import org.springframework.util.Assert;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Mono;

/* loaded from: input_file:com/alibaba/cloud/ai/mcp/nacos2/client/transport/LoadbalancedMcpAsyncClient.class */
/**
 * Facade over one {@link McpAsyncClient} per healthy Nacos instance of an MCP server.
 *
 * <p>Instances are discovered in the Nacos naming service under {@code <serviceName>-mcp-service}.
 * Each instance gets its own SSE client. Clients are grouped by the instance's {@code server.md5}
 * metadata, and the tool names each group advertises come from the {@code tools.names} metadata.
 * The routing rules are:
 * <ul>
 *   <li>Single-target operations are dispatched round-robin.</li>
 *   <li>{@link #callTool} only targets clients whose group advertises the requested tool.</li>
 *   <li>Root, subscription and logging operations are broadcast to every client.</li>
 * </ul>
 *
 * <p>The client set is updated from {@link #onEvent} after {@link #subscribe()} has been called.
 * Nacos may invoke that callback concurrently with caller threads, so the client maps are
 * concurrent structures.
 */
public class LoadbalancedMcpAsyncClient implements EventListener {

    private static final Logger logger = LoggerFactory.getLogger(LoadbalancedMcpAsyncClient.class);

    /** Suffix appended to the service name to form the Nacos naming-service name. */
    private static final String MCP_SERVICE_SUFFIX = "-mcp-service";

    /** Suffix appended to the service name to form the Nacos config dataId of the tool list. */
    private static final String MCP_TOOLS_DATA_ID_SUFFIX = "-mcp-tools.json";

    /** Timeout in milliseconds for reading the tool config from Nacos. */
    private static final long TIME_OUT_MS = 3000L;

    private static final String NO_CLIENT_MESSAGE = "No McpAsyncClient available";

    private final String serviceName;
    private final String serviceGroup;
    private final NamingService namingService;
    private final NacosConfigService nacosConfigService;
    private final McpClientCommonProperties commonProperties;
    private final WebClient.Builder webClientBuilderTemplate;
    private final McpAsyncClientConfigurer mcpAsyncClientConfigurer;
    private final ObjectMapper objectMapper;
    private final ApplicationContext applicationContext;

    // server.md5 -> tool names advertised by the instances carrying that md5.
    private final Map<String, List<String>> md5ToToolsMap = new ConcurrentHashMap<>();

    // server.md5 -> clients for the instances carrying that md5. The values are copy-on-write
    // lists because callers stream over them while Nacos notifications add or remove entries.
    private final Map<String, List<McpAsyncClient>> md5ToClientMap = new ConcurrentHashMap<>();

    // Instances currently backed by clients. It is replaced wholesale on every naming event.
    private volatile List<Instance> instances;

    // Shared round-robin cursor. It is always reduced with floorMod against the current list size.
    private final AtomicInteger index = new AtomicInteger(0);

    /** Fluent builder for {@link LoadbalancedMcpAsyncClient}. All properties are required. */
    public static class Builder {
        private String serviceName;
        private String serviceGroup;
        private NamingService namingService;
        private NacosConfigService nacosConfigService;
        private ApplicationContext applicationContext;

        public Builder serviceName(String str) {
            this.serviceName = str;
            return this;
        }

        public Builder serviceGroup(String str) {
            this.serviceGroup = str;
            return this;
        }

        public Builder namingService(NamingService namingService) {
            this.namingService = namingService;
            return this;
        }

        public Builder nacosConfigService(NacosConfigService nacosConfigService) {
            this.nacosConfigService = nacosConfigService;
            return this;
        }

        public Builder applicationContext(ApplicationContext applicationContext) {
            this.applicationContext = applicationContext;
            return this;
        }

        /**
         * Creates the client.
         *
         * @throws IllegalArgumentException if any property is null
         * @throws RuntimeException if the initial instance lookup in Nacos fails
         */
        public LoadbalancedMcpAsyncClient build() {
            return new LoadbalancedMcpAsyncClient(this.serviceName, this.serviceGroup, this.namingService,
                    this.nacosConfigService, this.applicationContext);
        }
    }

    /**
     * Resolves the collaborating beans and fetches the current healthy instances.
     *
     * <p>Clients are not created here. Call {@link #init()} to create them.
     *
     * @param serviceName logical MCP server name (without the {@code -mcp-service} suffix)
     * @param serviceGroup Nacos group used for both the instance lookup and the subscription
     * @param namingService Nacos naming service
     * @param nacosConfigService Nacos config service holding the tool list
     * @param applicationContext context supplying MCP properties, configurer, ObjectMapper and WebClient.Builder
     * @throws IllegalArgumentException if any argument is null
     * @throws RuntimeException if the instance lookup fails (the {@link NacosException} is the cause)
     */
    public LoadbalancedMcpAsyncClient(String serviceName, String serviceGroup, NamingService namingService,
            NacosConfigService nacosConfigService, ApplicationContext applicationContext) {
        Assert.notNull(serviceName, "serviceName cannot be null");
        Assert.notNull(serviceGroup, "serviceGroup cannot be null");
        Assert.notNull(namingService, "namingService cannot be null");
        Assert.notNull(nacosConfigService, "nacosConfigService cannot be null");
        Assert.notNull(applicationContext, "applicationContext cannot be null");
        this.serviceName = serviceName;
        this.serviceGroup = serviceGroup;
        this.namingService = namingService;
        this.nacosConfigService = nacosConfigService;
        this.applicationContext = applicationContext;
        this.commonProperties = applicationContext.getBean(McpClientCommonProperties.class);
        this.mcpAsyncClientConfigurer = applicationContext.getBean(McpAsyncClientConfigurer.class);
        this.objectMapper = applicationContext.getBean(ObjectMapper.class);
        this.webClientBuilderTemplate = applicationContext.getBean(WebClient.Builder.class);
        try {
            this.instances = namingService.selectInstances(mcpServiceName(), serviceGroup, true);
        } catch (NacosException e) {
            throw new RuntimeException(String.format("Failed to get instances for service: %s", serviceName), e);
        }
    }

    /**
     * (Re)builds the client maps from the instances fetched at construction time.
     *
     * <p>Existing map entries are discarded without being closed, as the original implementation did.
     */
    public void init() {
        this.md5ToToolsMap.clear();
        this.md5ToClientMap.clear();
        for (Instance instance : this.instances) {
            updateByAddInstance(instance);
        }
    }

    /**
     * Registers this object as a Nacos listener for instance changes of the MCP service.
     *
     * <p>It subscribes in the same group that was used for the initial instance lookup.
     *
     * @throws RuntimeException if the subscription fails (the {@link NacosException} is the cause)
     */
    public void subscribe() {
        try {
            this.namingService.subscribe(mcpServiceName(), this.serviceGroup, this);
        } catch (NacosException e) {
            throw new RuntimeException(String.format("Failed to subscribe to service: %s", this.serviceName), e);
        }
    }

    /**
     * Returns the next client in round-robin order.
     *
     * @throws IllegalStateException if no client is available
     */
    public McpAsyncClient getMcpAsyncClient() {
        List<McpAsyncClient> clients = getMcpAsyncClientList();
        if (clients.isEmpty()) {
            throw new IllegalStateException(NO_CLIENT_MESSAGE);
        }
        return clients.get(nextIndex(clients.size()));
    }

    /** Returns an unmodifiable snapshot of all current clients across every md5 group. */
    public List<McpAsyncClient> getMcpAsyncClientList() {
        return this.md5ToClientMap.values().stream().flatMap(List::stream).toList();
    }

    public String getServiceName() {
        return this.serviceName;
    }

    public NamingService getNamingService() {
        return this.namingService;
    }

    /** Returns the instances currently backed by clients (healthy, enabled, positive weight). */
    public List<Instance> getInstances() {
        return this.instances;
    }

    /** @throws IllegalStateException if no client is available */
    public McpSchema.ServerCapabilities getServerCapabilities() {
        return firstClient().getServerCapabilities();
    }

    /** @throws IllegalStateException if no client is available */
    public McpSchema.Implementation getServerInfo() {
        return firstClient().getServerInfo();
    }

    /** @throws IllegalStateException if no client is available */
    public McpSchema.ClientCapabilities getClientCapabilities() {
        return firstClient().getClientCapabilities();
    }

    /** @throws IllegalStateException if no client is available */
    public McpSchema.Implementation getClientInfo() {
        return firstClient().getClientInfo();
    }

    /**
     * Closes every client and forgets all of them.
     *
     * <p>A failure to close one client is logged and does not prevent the others from closing.
     */
    public void close() {
        for (McpAsyncClient client : getMcpAsyncClientList()) {
            String name = client.getClientInfo().name();
            try {
                client.close();
                logger.info("Closed and removed McpAsyncClient: {}", name);
            } catch (RuntimeException e) {
                logger.warn("Failed to close McpAsyncClient: {}", name, e);
            }
        }
        this.md5ToClientMap.clear();
        this.md5ToToolsMap.clear();
    }

    /**
     * Gracefully closes every client.
     *
     * <p>Each client is removed from the maps once its close completes successfully.
     *
     * @return a Mono that completes when all clients have closed
     */
    public Mono<Void> closeGracefully() {
        List<Mono<Void>> closings = new ArrayList<>();
        this.md5ToClientMap.forEach((md5, clients) -> {
            for (McpAsyncClient client : clients) {
                closings.add(client.closeGracefully().doOnSuccess(ignored -> {
                    removeClient(md5, client);
                    logger.info("Closed and removed McpAsyncClient: {}", client.getClientInfo().name());
                }));
            }
        });
        return Mono.when(closings);
    }

    public Mono<Object> ping() {
        return getMcpAsyncClient().ping();
    }

    /** Adds the root on every client. */
    public Mono<Void> addRoot(McpSchema.Root root) {
        return broadcast(client -> client.addRoot(root));
    }

    /** Removes the root with the given URI from every client. */
    public Mono<Void> removeRoot(String str) {
        return broadcast(client -> client.removeRoot(str));
    }

    /** Sends a roots-list-changed notification through every client. */
    public Mono<Void> rootsListChangedNotification() {
        return broadcast(McpAsyncClient::rootsListChangedNotification);
    }

    /**
     * Calls a tool on one of the clients whose md5 group advertises it.
     *
     * <p>The client is picked round-robin among those candidates.
     *
     * @return the tool result, or an error Mono with {@link IllegalStateException} if no client offers the tool
     */
    public Mono<McpSchema.CallToolResult> callTool(McpSchema.CallToolRequest callToolRequest) {
        String toolName = callToolRequest.name();
        List<McpAsyncClient> candidates = new ArrayList<>();
        this.md5ToToolsMap.forEach((md5, toolNames) -> {
            if (toolNames.contains(toolName)) {
                candidates.addAll(this.md5ToClientMap.getOrDefault(md5, List.of()));
            }
        });
        if (candidates.isEmpty()) {
            return Mono.error(new IllegalStateException(NO_CLIENT_MESSAGE + " for tool: " + toolName));
        }
        return candidates.get(nextIndex(candidates.size())).callTool(callToolRequest);
    }

    /** Lists the tools published in the Nacos config {@code <serviceName>-mcp-tools.json}. */
    public Mono<McpSchema.ListToolsResult> listTools() {
        return listToolsInternal(null);
    }

    /**
     * Lists the tools published in Nacos.
     *
     * <p>The whole list is returned as one page, so the cursor is ignored.
     */
    public Mono<McpSchema.ListToolsResult> listTools(String str) {
        return listToolsInternal(str);
    }

    private Mono<McpSchema.ListToolsResult> listToolsInternal(String cursor) {
        // The tool list is stored as a single Nacos config entry, so there is no pagination.
        // Echoing the cursor back as nextCursor would make a paginating caller loop forever.
        return loadConfig().flatMap(this::parseConfig);
    }

    /**
     * Reads the raw tool config from Nacos.
     *
     * <p>It errors if the lookup fails or the content is null or empty.
     */
    private Mono<String> loadConfig() {
        String dataId = toolsDataId();
        String group = McpNacosConstant.TOOLS_GROUP;
        return Mono.fromCallable(() -> this.nacosConfigService.getConfig(dataId, group, TIME_OUT_MS))
            .onErrorMap(th -> new RuntimeException(
                    String.format("Failed to load tool config for dataId: %s, group: %s", dataId, group), th))
            // fromCallable completes empty when getConfig returns null, so null and "" both end up here.
            .filter(config -> !config.isEmpty())
            .switchIfEmpty(Mono.defer(() -> Mono.error(new RuntimeException(
                    String.format("Empty tool config content for dataId: %s, group: %s", dataId, group)))));
    }

    /**
     * Deserializes the tool config JSON into a single-page {@link McpSchema.ListToolsResult}.
     *
     * <p>Failures are logged once and wrapped once.
     */
    private Mono<McpSchema.ListToolsResult> parseConfig(String config) {
        String dataId = toolsDataId();
        String group = McpNacosConstant.TOOLS_GROUP;
        return Mono.fromCallable(() -> this.objectMapper.readValue(config, McpToolsInfo.class))
            .map(toolsInfo -> new McpSchema.ListToolsResult(toolsInfo.getTools(), null))
            .onErrorMap(th -> {
                logger.error("Failed to parse config for dataId: {}, group: {}", dataId, group, th);
                return new RuntimeException(
                        String.format("Failed to parse tool list, dataId: %s, group: %s", dataId, group), th);
            });
    }

    public Mono<McpSchema.ListResourcesResult> listResources() {
        return getMcpAsyncClient().listResources();
    }

    public Mono<McpSchema.ListResourcesResult> listResources(String str) {
        return getMcpAsyncClient().listResources(str);
    }

    public Mono<McpSchema.ReadResourceResult> readResource(McpSchema.Resource resource) {
        return getMcpAsyncClient().readResource(resource);
    }

    public Mono<McpSchema.ReadResourceResult> readResource(McpSchema.ReadResourceRequest readResourceRequest) {
        return getMcpAsyncClient().readResource(readResourceRequest);
    }

    public Mono<McpSchema.ListResourceTemplatesResult> listResourceTemplates() {
        return getMcpAsyncClient().listResourceTemplates();
    }

    public Mono<McpSchema.ListResourceTemplatesResult> listResourceTemplates(String str) {
        return getMcpAsyncClient().listResourceTemplates(str);
    }

    /** Subscribes to the resource on every client. */
    public Mono<Void> subscribeResource(McpSchema.SubscribeRequest subscribeRequest) {
        return broadcast(client -> client.subscribeResource(subscribeRequest));
    }

    /** Unsubscribes from the resource on every client. */
    public Mono<Void> unsubscribeResource(McpSchema.UnsubscribeRequest unsubscribeRequest) {
        return broadcast(client -> client.unsubscribeResource(unsubscribeRequest));
    }

    public Mono<McpSchema.ListPromptsResult> listPrompts() {
        return getMcpAsyncClient().listPrompts();
    }

    public Mono<McpSchema.ListPromptsResult> listPrompts(String str) {
        return getMcpAsyncClient().listPrompts(str);
    }

    public Mono<McpSchema.GetPromptResult> getPrompt(McpSchema.GetPromptRequest getPromptRequest) {
        return getMcpAsyncClient().getPrompt(getPromptRequest);
    }

    /** Sets the logging level on every client. */
    public Mono<Void> setLoggingLevel(McpSchema.LoggingLevel loggingLevel) {
        return broadcast(client -> client.setLoggingLevel(loggingLevel));
    }

    /**
     * Nacos callback: reconciles the client set with the instances carried by a naming event.
     *
     * <p>Only events for this MCP service are handled. As with the constructor's healthy-only lookup,
     * only instances that are healthy, enabled and have a positive weight get clients.
     */
    @Override
    public void onEvent(Event event) {
        if (!(event instanceof NamingEvent namingEvent)) {
            return;
        }
        if (!mcpServiceName().equals(namingEvent.getServiceName())) {
            return;
        }
        logger.info("Received service instance change event for service: {}", namingEvent.getServiceName());
        List<Instance> eventInstances = namingEvent.getInstances();
        logger.info("Updated instances count: {}", eventInstances.size());
        eventInstances.forEach(instance -> logger.info("Instance: {}:{} (Healthy: {}, Enabled: {}, Metadata: {})",
                instance.getIp(), instance.getPort(), instance.isHealthy(), instance.isEnabled(),
                JacksonUtils.toJson(instance.getMetadata())));
        updateClientList(eventInstances.stream().filter(LoadbalancedMcpAsyncClient::isAvailable).toList());
    }

    /** Mirrors the healthy-only filter applied by Nacos {@code selectInstances(..., true)}. */
    private static boolean isAvailable(Instance instance) {
        return instance.isHealthy() && instance.isEnabled() && instance.getWeight() > 0;
    }

    /**
     * Applies the difference between the current and latest instance lists.
     *
     * <p>It creates clients for new instances, closes clients for vanished ones, then stores the latest list.
     */
    private void updateClientList(List<Instance> latest) {
        List<Instance> current = this.instances;
        List<Instance> added = latest.stream().filter(instance -> !current.contains(instance)).toList();
        List<Instance> removed = current.stream().filter(instance -> !latest.contains(instance)).toList();
        added.forEach(this::updateByAddInstance);
        removed.forEach(this::updateByRemoveInstance);
        this.instances = latest;
    }

    /**
     * Builds and configures an SSE-based client for one instance.
     *
     * <p>The client is initialized when the common properties ask for it. If initialization fails,
     * the client is closed before the error propagates.
     */
    private McpAsyncClient clientByInstance(Instance instance) {
        String scheme = instance.getMetadata().getOrDefault("scheme", "http");
        String baseUrl = scheme + "://" + instance.getIp() + ":" + instance.getPort();
        NamedClientMcpTransport namedTransport = new NamedClientMcpTransport(transportName(instance),
                new WebFluxSseClientTransport(this.webClientBuilderTemplate.clone().baseUrl(baseUrl), this.objectMapper));
        McpSchema.Implementation clientInfo = new McpSchema.Implementation(
                connectedClientName(this.commonProperties.getName(), namedTransport.name()),
                this.commonProperties.getVersion());
        McpAsyncClient client = this.mcpAsyncClientConfigurer
            .configure(namedTransport.name(), McpClient.async(namedTransport.transport())
                .clientInfo(clientInfo)
                .requestTimeout(this.commonProperties.getRequestTimeout()))
            .build();
        if (this.commonProperties.isInitialized()) {
            try {
                client.initialize().block();
            } catch (RuntimeException e) {
                // Do not leak the transport of a client that never became usable.
                client.close();
                throw e;
            }
        }
        logger.info("Added McpAsyncClient: {}", clientInfo.name());
        return client;
    }

    /**
     * Creates a client for the instance and files it under the instance's {@code server.md5}.
     *
     * <p>Instances without that metadata are skipped with a warning.
     */
    private void updateByAddInstance(Instance instance) {
        Map<String, String> metadata = instance.getMetadata();
        String md5 = metadata.get("server.md5");
        if (md5 == null) {
            logger.warn("Skipping instance {}:{} without 'server.md5' metadata", instance.getIp(), instance.getPort());
            return;
        }
        List<String> toolNames = parseToolNames(metadata.get("tools.names"), instance);
        McpAsyncClient client = clientByInstance(instance);
        // Update both maps inside one compute() so a concurrent removal of the same md5
        // cannot drop the new client or leave the tool entry out of sync.
        this.md5ToClientMap.compute(md5, (key, clients) -> {
            List<McpAsyncClient> target = (clients != null) ? clients : new CopyOnWriteArrayList<>();
            target.add(client);
            this.md5ToToolsMap.putIfAbsent(key, toolNames);
            return target;
        });
    }

    /** Splits the comma-separated {@code tools.names} metadata; missing or blank input yields an empty list. */
    private static List<String> parseToolNames(String toolNames, Instance instance) {
        if (toolNames == null || toolNames.isBlank()) {
            logger.warn("Instance {}:{} has no 'tools.names' metadata", instance.getIp(), instance.getPort());
            return List.of();
        }
        return List.of(toolNames.split(","));
    }

    /**
     * Gracefully closes the client created for the instance and removes it from the maps.
     *
     * <p>The client is identified by its client-info name. It is removed even if closing fails.
     */
    private void updateByRemoveInstance(Instance instance) {
        String md5 = instance.getMetadata().get("server.md5");
        if (md5 == null) {
            return;
        }
        String clientName = connectedClientName(this.commonProperties.getName(), transportName(instance));
        for (McpAsyncClient client : this.md5ToClientMap.getOrDefault(md5, List.of())) {
            String name = client.getClientInfo().name();
            if (clientName.equals(name)) {
                logger.info("Removing McpAsyncClient: {}", name);
                try {
                    client.closeGracefully().block();
                    logger.info("Removed McpAsyncClient: {} Success", name);
                } catch (RuntimeException e) {
                    logger.warn("Failed to gracefully close McpAsyncClient: {}", name, e);
                } finally {
                    removeClient(md5, client);
                }
                return;
            }
        }
    }

    /**
     * Atomically removes a client from its md5 group.
     *
     * <p>When the group becomes empty, the group and its tool list are both dropped.
     */
    private void removeClient(String md5, McpAsyncClient client) {
        this.md5ToClientMap.computeIfPresent(md5, (key, clients) -> {
            clients.remove(client);
            if (!clients.isEmpty()) {
                return clients;
            }
            this.md5ToToolsMap.remove(key);
            return null;
        });
    }

    /** Returns the first available client, used for capability and info queries. */
    private McpAsyncClient firstClient() {
        List<McpAsyncClient> clients = getMcpAsyncClientList();
        if (clients.isEmpty()) {
            throw new IllegalStateException(NO_CLIENT_MESSAGE);
        }
        return clients.get(0);
    }

    /** Runs the action on every client and completes when all of the resulting Monos complete. */
    private Mono<Void> broadcast(Function<McpAsyncClient, Mono<?>> action) {
        List<Mono<?>> calls = getMcpAsyncClientList().stream().map(action).toList();
        return Mono.when(calls);
    }

    /** Next round-robin slot in {@code [0, size)}; floorMod keeps it valid across size changes and overflow. */
    private int nextIndex(int size) {
        return Math.floorMod(this.index.getAndIncrement(), size);
    }

    private String mcpServiceName() {
        return this.serviceName + MCP_SERVICE_SUFFIX;
    }

    private String toolsDataId() {
        return this.serviceName + MCP_TOOLS_DATA_ID_SUFFIX;
    }

    private String transportName(Instance instance) {
        return this.serviceName + "-" + instance.getInstanceId();
    }

    private String connectedClientName(String str, String str2) {
        return str + " - " + str2;
    }

    public static Builder builder() {
        return new Builder();
    }
}
