package com.alibaba.cloud.ai.mcp.nacos.client.transport;

import com.alibaba.cloud.ai.mcp.nacos.client.utils.NacosMcpClientUtils;
import com.alibaba.cloud.ai.mcp.nacos.service.NacosMcpOperationService;
import com.alibaba.cloud.ai.mcp.nacos.service.model.NacosMcpServerEndpoint;
import com.alibaba.nacos.api.ai.model.mcp.McpEndpointInfo;
import com.alibaba.nacos.api.exception.NacosException;
import com.alibaba.nacos.api.utils.StringUtils;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.client.McpAsyncClient;
import io.modelcontextprotocol.client.McpClient;
import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
import io.modelcontextprotocol.spec.McpSchema;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.mcp.client.autoconfigure.NamedClientMcpTransport;
import org.springframework.ai.mcp.client.autoconfigure.configurer.McpAsyncClientConfigurer;
import org.springframework.ai.mcp.client.autoconfigure.properties.McpClientCommonProperties;
import org.springframework.context.ApplicationContext;
import org.springframework.util.Assert;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Mono;

/* loaded from: input_file:com/alibaba/cloud/ai/mcp/nacos/client/transport/LoadbalancedMcpAsyncClient.class */
/**
 * Client-side load balancer over a set of {@link McpAsyncClient}s, one per backend endpoint of an
 * MCP server registered in Nacos.
 *
 * <p>The constructor resolves the server endpoint from Nacos; {@link #init()} creates one client
 * per endpoint, and {@link #subscribe()} keeps the client set in sync with later Nacos changes.
 * Request-style operations ({@code callTool}, {@code listTools}, ...) are sent to a single client
 * chosen round-robin, while state-changing operations ({@code addRoot}, {@code setLoggingLevel},
 * ...) are broadcast to every client. Only servers whose protocol is {@code mcp-sse} are accepted.
 */
public class LoadbalancedMcpAsyncClient {
    private static final Logger logger = LoggerFactory.getLogger(LoadbalancedMcpAsyncClient.class);
    private final String serverName;
    private final NacosMcpOperationService nacosMcpOperationService;
    private final McpClientCommonProperties commonProperties;
    private final WebClient.Builder webClientBuilderTemplate;
    private final McpAsyncClientConfigurer mcpAsyncClientConfigurer;
    private final ObjectMapper objectMapper;
    private final ApplicationContext applicationContext;

    // Round-robin cursor used by getMcpAsyncClient().
    private final AtomicInteger index = new AtomicInteger(0);

    // Endpoint id -> client. The reference is replaced wholesale by updateAll(), which runs from the
    // Nacos subscription callback (presumably on a Nacos listener thread), so it is volatile for
    // visibility. Eagerly initialized so calls made before init() fail with a clear
    // IllegalStateException instead of a NullPointerException.
    private volatile Map<String, McpAsyncClient> keyToClientMap = new ConcurrentHashMap<>();

    // Last endpoint snapshot seen; replaced from the subscription callback, hence volatile.
    private volatile NacosMcpServerEndpoint serverEndpoint;

    /** Fluent builder for {@link LoadbalancedMcpAsyncClient}. */
    public static class Builder {
        private String serverName;
        private String version;
        private NacosMcpOperationService nacosMcpOperationService;
        private ApplicationContext applicationContext;

        /** Sets the name of the MCP server to look up in Nacos. */
        public Builder serverName(String str) {
            this.serverName = str;
            return this;
        }

        /** Sets the version of the MCP server to look up in Nacos. */
        public Builder version(String str) {
            this.version = str;
            return this;
        }

        /** Sets the service used to query and subscribe to Nacos MCP server data. */
        public Builder nacosMcpOperationService(NacosMcpOperationService nacosMcpOperationService) {
            this.nacosMcpOperationService = nacosMcpOperationService;
            return this;
        }

        /** Sets the Spring context from which client-building beans are resolved. */
        public Builder applicationContext(ApplicationContext applicationContext) {
            this.applicationContext = applicationContext;
            return this;
        }

        /**
         * Builds the client; see the {@link LoadbalancedMcpAsyncClient} constructor for failure modes.
         */
        public LoadbalancedMcpAsyncClient build() {
            return new LoadbalancedMcpAsyncClient(this.serverName, this.version, this.nacosMcpOperationService, this.applicationContext);
        }
    }

    /**
     * Resolves the server endpoint from Nacos and looks up the Spring beans needed to build clients.
     * No MCP clients are created here; call {@link #init()} for that.
     *
     * @param str server name in Nacos
     * @param str2 server version in Nacos
     * @param nacosMcpOperationService service used to query Nacos
     * @param applicationContext Spring context providing {@link McpClientCommonProperties},
     *     {@link McpAsyncClientConfigurer}, {@link ObjectMapper} and {@link WebClient.Builder}
     * @throws IllegalArgumentException if any argument is null
     * @throws RuntimeException wrapping the cause if the server is not found, its protocol is not
     *     {@code mcp-sse}, or a required bean is missing
     */
    public LoadbalancedMcpAsyncClient(String str, String str2, NacosMcpOperationService nacosMcpOperationService, ApplicationContext applicationContext) {
        Assert.notNull(str, "serviceName cannot be null");
        Assert.notNull(str2, "version cannot be null");
        Assert.notNull(nacosMcpOperationService, "nacosMcpOperationService cannot be null");
        Assert.notNull(applicationContext, "applicationContext cannot be null");
        this.serverName = str;
        this.nacosMcpOperationService = nacosMcpOperationService;
        this.applicationContext = applicationContext;
        try {
            this.serverEndpoint = this.nacosMcpOperationService.getServerEndpoint(this.serverName, str2);
            if (this.serverEndpoint == null) {
                throw new NacosException(404, String.format("Can not find mcp server from nacos: %s", str));
            }
            if (!StringUtils.equals(this.serverEndpoint.getProtocol(), "mcp-sse")) {
                throw new RuntimeException("mcp server protocol must be sse");
            }
            this.commonProperties = this.applicationContext.getBean(McpClientCommonProperties.class);
            this.mcpAsyncClientConfigurer = this.applicationContext.getBean(McpAsyncClientConfigurer.class);
            this.objectMapper = this.applicationContext.getBean(ObjectMapper.class);
            this.webClientBuilderTemplate = this.applicationContext.getBean(WebClient.Builder.class);
        } catch (Exception e) {
            throw new RuntimeException(String.format("Failed to get instances for service: %s", str), e);
        }
    }

    /** Creates (and, if configured, initializes) one client per endpoint of the resolved server. */
    public void init() {
        this.keyToClientMap = new ConcurrentHashMap<>();
        String exportPath = this.serverEndpoint.getExportPath();
        for (McpEndpointInfo endpoint : this.serverEndpoint.getMcpEndpointInfoList()) {
            updateByAddEndpoint(endpoint, exportPath);
        }
    }

    /**
     * Subscribes to Nacos changes of this server and reconciles the client set on each update.
     * Updates whose protocol is not {@code mcp-sse} are ignored.
     */
    public void subscribe() {
        this.nacosMcpOperationService.subscribeNacosMcpServer(this.serverName, detail -> {
            List<McpEndpointInfo> backendEndpoints = detail.getBackendEndpoints() == null
                    ? new ArrayList<>()
                    : detail.getBackendEndpoints();
            String exportPath = detail.getRemoteServerConfig().getExportPath();
            String protocol = detail.getProtocol();
            NacosMcpServerEndpoint updated = new NacosMcpServerEndpoint(backendEndpoints, exportPath, protocol, detail.getVersionDetail().getVersion());
            if (StringUtils.equals(protocol, "mcp-sse")) {
                updateClientList(updated);
            }
        });
    }

    /**
     * Returns the next client in round-robin order.
     *
     * @throws IllegalStateException if no client is currently available
     */
    public McpAsyncClient getMcpAsyncClient() {
        List<McpAsyncClient> clients = getMcpAsyncClientList();
        if (clients.isEmpty()) {
            throw new IllegalStateException("No McpAsyncClient available");
        }
        // floorMod keeps the slot in range even if the list shrank since the previous call or the
        // counter overflowed into negative values.
        return clients.get(Math.floorMod(this.index.getAndIncrement(), clients.size()));
    }

    /** Returns an unmodifiable snapshot of the current clients. */
    public List<McpAsyncClient> getMcpAsyncClientList() {
        return this.keyToClientMap.values().stream().toList();
    }

    public String getServerName() {
        return this.serverName;
    }

    public NacosMcpServerEndpoint getNacosMcpServerEndpoint() {
        return this.serverEndpoint;
    }

    /** Server capabilities as reported by the first available client. */
    public McpSchema.ServerCapabilities getServerCapabilities() {
        return firstClient().getServerCapabilities();
    }

    /** Server implementation info as reported by the first available client. */
    public McpSchema.Implementation getServerInfo() {
        return firstClient().getServerInfo();
    }

    /** Client capabilities of the first available client. */
    public McpSchema.ClientCapabilities getClientCapabilities() {
        return firstClient().getClientCapabilities();
    }

    /** Client implementation info of the first available client. */
    public McpSchema.Implementation getClientInfo() {
        return firstClient().getClientInfo();
    }

    /** Closes every client immediately and removes it from the pool. */
    public void close() {
        // ConcurrentHashMap's values view supports Iterator.remove(); the list snapshot does not.
        Iterator<McpAsyncClient> it = this.keyToClientMap.values().iterator();
        while (it.hasNext()) {
            McpAsyncClient client = it.next();
            client.close();
            it.remove();
            logger.info("Closed and removed McpAsyncClient: {}", client.getClientInfo().name());
        }
    }

    /**
     * Gracefully closes every client; each one is removed from the pool once its close succeeds.
     *
     * @return a {@link Mono} completing when all close operations have completed
     */
    public Mono<Void> closeGracefully() {
        // Capture the map: updateAll() may swap the field while closes are in flight.
        Map<String, McpAsyncClient> clients = this.keyToClientMap;
        List<Mono<Void>> closings = new ArrayList<>();
        for (Map.Entry<String, McpAsyncClient> entry : clients.entrySet()) {
            String key = entry.getKey();
            McpAsyncClient client = entry.getValue();
            closings.add(client.closeGracefully().doOnSuccess(ignored -> {
                clients.remove(key, client);
                logger.info("Closed and removed McpAsyncClient: {}", client.getClientInfo().name());
            }));
        }
        return Mono.when(closings);
    }

    public Mono<Object> ping() {
        return getMcpAsyncClient().ping();
    }

    /** Adds a root on every client. */
    public Mono<Void> addRoot(McpSchema.Root root) {
        return whenAllClients(client -> client.addRoot(root));
    }

    /** Removes a root on every client. */
    public Mono<Void> removeRoot(String str) {
        return whenAllClients(client -> client.removeRoot(str));
    }

    /** Sends the roots-list-changed notification from every client. */
    public Mono<Void> rootsListChangedNotification() {
        return whenAllClients(McpAsyncClient::rootsListChangedNotification);
    }

    public Mono<McpSchema.CallToolResult> callTool(McpSchema.CallToolRequest callToolRequest) {
        return getMcpAsyncClient().callTool(callToolRequest);
    }

    public Mono<McpSchema.ListToolsResult> listTools() {
        return listToolsInternal(null);
    }

    public Mono<McpSchema.ListToolsResult> listTools(String str) {
        return listToolsInternal(str);
    }

    private Mono<McpSchema.ListToolsResult> listToolsInternal(String str) {
        return getMcpAsyncClient().listTools(str);
    }

    public Mono<McpSchema.ListResourcesResult> listResources() {
        return getMcpAsyncClient().listResources();
    }

    public Mono<McpSchema.ListResourcesResult> listResources(String str) {
        return getMcpAsyncClient().listResources(str);
    }

    public Mono<McpSchema.ReadResourceResult> readResource(McpSchema.Resource resource) {
        return getMcpAsyncClient().readResource(resource);
    }

    public Mono<McpSchema.ReadResourceResult> readResource(McpSchema.ReadResourceRequest readResourceRequest) {
        return getMcpAsyncClient().readResource(readResourceRequest);
    }

    public Mono<McpSchema.ListResourceTemplatesResult> listResourceTemplates() {
        return getMcpAsyncClient().listResourceTemplates();
    }

    public Mono<McpSchema.ListResourceTemplatesResult> listResourceTemplates(String str) {
        return getMcpAsyncClient().listResourceTemplates(str);
    }

    /** Subscribes to a resource on every client. */
    public Mono<Void> subscribeResource(McpSchema.SubscribeRequest subscribeRequest) {
        return whenAllClients(client -> client.subscribeResource(subscribeRequest));
    }

    /** Unsubscribes from a resource on every client. */
    public Mono<Void> unsubscribeResource(McpSchema.UnsubscribeRequest unsubscribeRequest) {
        return whenAllClients(client -> client.unsubscribeResource(unsubscribeRequest));
    }

    public Mono<McpSchema.ListPromptsResult> listPrompts() {
        return getMcpAsyncClient().listPrompts();
    }

    public Mono<McpSchema.ListPromptsResult> listPrompts(String str) {
        return getMcpAsyncClient().listPrompts(str);
    }

    public Mono<McpSchema.GetPromptResult> getPrompt(McpSchema.GetPromptRequest getPromptRequest) {
        return getMcpAsyncClient().getPrompt(getPromptRequest);
    }

    /** Sets the logging level on every client. */
    public Mono<Void> setLoggingLevel(McpSchema.LoggingLevel loggingLevel) {
        return whenAllClients(client -> client.setLoggingLevel(loggingLevel));
    }

    /** Applies {@code action} to every current client and completes when all results complete. */
    private Mono<Void> whenAllClients(Function<McpAsyncClient, Mono<?>> action) {
        return Mono.when(getMcpAsyncClientList().stream().map(action).toList());
    }

    /** Returns the first current client, failing with the same error as getMcpAsyncClient(). */
    private McpAsyncClient firstClient() {
        List<McpAsyncClient> clients = getMcpAsyncClientList();
        if (clients.isEmpty()) {
            throw new IllegalStateException("No McpAsyncClient available");
        }
        return clients.get(0);
    }

    /**
     * Reconciles the client set with a new endpoint snapshot. If the export path and version are
     * unchanged, only added/removed endpoints (matched by address and port) are touched; otherwise
     * all clients are rebuilt.
     */
    private void updateClientList(NacosMcpServerEndpoint nacosMcpServerEndpoint) {
        NacosMcpServerEndpoint current = this.serverEndpoint;
        String exportPath = nacosMcpServerEndpoint.getExportPath();
        if (StringUtils.equals(current.getExportPath(), exportPath) && StringUtils.equals(current.getVersion(), nacosMcpServerEndpoint.getVersion())) {
            List<McpEndpointInfo> oldEndpoints = current.getMcpEndpointInfoList();
            List<McpEndpointInfo> newEndpoints = nacosMcpServerEndpoint.getMcpEndpointInfoList();
            List<McpEndpointInfo> added = endpointsNotIn(newEndpoints, oldEndpoints);
            List<McpEndpointInfo> removed = endpointsNotIn(oldEndpoints, newEndpoints);
            for (McpEndpointInfo endpoint : added) {
                updateByAddEndpoint(endpoint, exportPath);
            }
            for (McpEndpointInfo endpoint : removed) {
                updateByRemoveEndpoint(endpoint, exportPath);
            }
        } else {
            updateAll(nacosMcpServerEndpoint);
        }
        this.serverEndpoint = nacosMcpServerEndpoint;
    }

    /** Returns the endpoints of {@code source} whose address/port appears nowhere in {@code exclude}. */
    private static List<McpEndpointInfo> endpointsNotIn(List<McpEndpointInfo> source, List<McpEndpointInfo> exclude) {
        return source.stream()
                .filter(candidate -> exclude.stream().noneMatch(other -> sameAddress(candidate, other)))
                .toList();
    }

    /** True if both endpoints have equal address and port (null-safe; works for boxed or primitive ports). */
    private static boolean sameAddress(McpEndpointInfo a, McpEndpointInfo b) {
        return Objects.equals(a.getAddress(), b.getAddress()) && Objects.equals(a.getPort(), b.getPort());
    }

    /** Builds a fresh client map for the snapshot, swaps it in, then closes every previous client. */
    private void updateAll(NacosMcpServerEndpoint nacosMcpServerEndpoint) {
        String exportPath = nacosMcpServerEndpoint.getExportPath();
        Map<String, McpAsyncClient> newClients = new ConcurrentHashMap<>();
        for (McpEndpointInfo endpoint : nacosMcpServerEndpoint.getMcpEndpointInfoList()) {
            String id = NacosMcpClientUtils.getMcpEndpointInfoId(endpoint, exportPath);
            // Connect only once per id; building first and then putIfAbsent leaked duplicates.
            if (!newClients.containsKey(id)) {
                newClients.put(id, clientByEndpoint(endpoint, exportPath));
            }
        }
        Map<String, McpAsyncClient> oldClients = this.keyToClientMap;
        this.keyToClientMap = newClients;
        for (McpAsyncClient client : oldClients.values()) {
            closeGracefullyAndLog(client);
        }
    }

    /**
     * Builds a client for one endpoint over an SSE transport at {@code http://address:port} with
     * the given export path, initializing it (blocking) when the common properties request it.
     */
    private McpAsyncClient clientByEndpoint(McpEndpointInfo mcpEndpointInfo, String str) {
        NamedClientMcpTransport namedClientMcpTransport = new NamedClientMcpTransport(this.serverName + "-" + NacosMcpClientUtils.getMcpEndpointInfoId(mcpEndpointInfo, str), new WebFluxSseClientTransport(this.webClientBuilderTemplate.clone().baseUrl("http://" + mcpEndpointInfo.getAddress() + ":" + mcpEndpointInfo.getPort()), this.objectMapper, str));
        McpSchema.Implementation implementation = new McpSchema.Implementation(connectedClientName(this.commonProperties.getName(), namedClientMcpTransport.name()), this.commonProperties.getVersion());
        McpAsyncClient client = this.mcpAsyncClientConfigurer.configure(namedClientMcpTransport.name(), McpClient.async(namedClientMcpTransport.transport()).clientInfo(implementation).requestTimeout(this.commonProperties.getRequestTimeout())).build();
        if (this.commonProperties.isInitialized()) {
            try {
                client.initialize().block();
            } catch (RuntimeException e) {
                // Release a client that never became usable before propagating the failure.
                try {
                    client.close();
                } catch (RuntimeException closeFailure) {
                    e.addSuppressed(closeFailure);
                }
                throw e;
            }
        }
        logger.info("Added McpAsyncClient: {}", implementation.name());
        return client;
    }

    /** Adds a client for the endpoint unless one with the same id is already registered. */
    private void updateByAddEndpoint(McpEndpointInfo mcpEndpointInfo, String str) {
        String id = NacosMcpClientUtils.getMcpEndpointInfoId(mcpEndpointInfo, str);
        Map<String, McpAsyncClient> clients = this.keyToClientMap;
        if (clients.containsKey(id)) {
            return;
        }
        McpAsyncClient created = clientByEndpoint(mcpEndpointInfo, str);
        if (clients.putIfAbsent(id, created) != null) {
            // Another update registered this id concurrently; discard our duplicate connection.
            closeGracefullyAndLog(created);
        }
    }

    /** Removes and gracefully closes the client registered for the endpoint, if any. */
    private void updateByRemoveEndpoint(McpEndpointInfo mcpEndpointInfo, String str) {
        String id = NacosMcpClientUtils.getMcpEndpointInfoId(mcpEndpointInfo, str);
        McpAsyncClient removed = this.keyToClientMap.remove(id);
        if (removed != null) {
            closeGracefullyAndLog(removed);
        }
    }

    /**
     * Blocks until {@code client} is gracefully closed. A failure is logged rather than thrown so
     * that one bad client does not abort closing the rest of an update.
     */
    private void closeGracefullyAndLog(McpAsyncClient client) {
        String name = client.getClientInfo().name();
        logger.info("Removing McpAsyncClient: {}", name);
        try {
            client.closeGracefully().block();
            logger.info("Removed McpAsyncClient: {} Success", name);
        } catch (RuntimeException e) {
            logger.warn("Failed to close McpAsyncClient: {}", name, e);
        }
    }

    private String connectedClientName(String str, String str2) {
        return str + " - " + str2;
    }

    public static Builder builder() {
        return new Builder();
    }
}
