/*
 * Decompiled with CFR 0.152.
 */
package io.modelcontextprotocol.spec;

import com.fasterxml.jackson.core.type.TypeReference;
import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSession;
import io.modelcontextprotocol.util.Assert;
import java.time.Duration;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoSink;

public class McpClientSession
implements McpSession {
    private static final Logger logger = LoggerFactory.getLogger(McpClientSession.class);
    private final Duration requestTimeout;
    private final McpClientTransport transport;
    private final ConcurrentHashMap<Object, MonoSink<McpSchema.JSONRPCResponse>> pendingResponses = new ConcurrentHashMap();
    private final ConcurrentHashMap<String, RequestHandler<?>> requestHandlers = new ConcurrentHashMap();
    private final ConcurrentHashMap<String, NotificationHandler> notificationHandlers = new ConcurrentHashMap();
    private final String sessionPrefix = UUID.randomUUID().toString().substring(0, 8);
    private final AtomicLong requestCounter = new AtomicLong(0L);

    @Deprecated
    public McpClientSession(Duration requestTimeout, McpClientTransport transport, Map<String, RequestHandler<?>> requestHandlers, Map<String, NotificationHandler> notificationHandlers) {
        this(requestTimeout, transport, requestHandlers, notificationHandlers, Function.identity());
    }

    public McpClientSession(Duration requestTimeout, McpClientTransport transport, Map<String, RequestHandler<?>> requestHandlers, Map<String, NotificationHandler> notificationHandlers, Function<? super Mono<Void>, ? extends Publisher<Void>> connectHook) {
        Assert.notNull(requestTimeout, "The requestTimeout can not be null");
        Assert.notNull(transport, "The transport can not be null");
        Assert.notNull(requestHandlers, "The requestHandlers can not be null");
        Assert.notNull(notificationHandlers, "The notificationHandlers can not be null");
        this.requestTimeout = requestTimeout;
        this.transport = transport;
        this.requestHandlers.putAll(requestHandlers);
        this.notificationHandlers.putAll(notificationHandlers);
        this.transport.connect(mono -> mono.doOnNext(this::handle)).transform(connectHook).subscribe();
    }

    private void dismissPendingResponses() {
        this.pendingResponses.forEach((id, sink) -> {
            logger.warn("Abruptly terminating exchange for request {}", id);
            sink.error((Throwable)new RuntimeException("MCP session with server terminated"));
        });
        this.pendingResponses.clear();
    }

    private void handle(McpSchema.JSONRPCMessage message) {
        if (message instanceof McpSchema.JSONRPCResponse) {
            McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse)message;
            logger.debug("Received Response: {}", (Object)response);
            MonoSink<McpSchema.JSONRPCResponse> sink = this.pendingResponses.remove(response.id());
            if (sink == null) {
                logger.warn("Unexpected response for unknown id {}", response.id());
            } else {
                sink.success((Object)response);
            }
        } else if (message instanceof McpSchema.JSONRPCRequest) {
            McpSchema.JSONRPCRequest request = (McpSchema.JSONRPCRequest)message;
            logger.debug("Received request: {}", (Object)request);
            this.handleIncomingRequest(request).onErrorResume(error -> {
                McpSchema.JSONRPCResponse errorResponse = new McpSchema.JSONRPCResponse("2.0", request.id(), null, new McpSchema.JSONRPCResponse.JSONRPCError(-32603, error.getMessage(), null));
                return Mono.just((Object)errorResponse);
            }).flatMap(this.transport::sendMessage).onErrorComplete(t -> {
                logger.warn("Issue sending response to the client, ", t);
                return true;
            }).subscribe();
        } else if (message instanceof McpSchema.JSONRPCNotification) {
            McpSchema.JSONRPCNotification notification = (McpSchema.JSONRPCNotification)message;
            logger.debug("Received notification: {}", (Object)notification);
            this.handleIncomingNotification(notification).onErrorComplete(t -> {
                logger.error("Error handling notification: {}", (Object)t.getMessage());
                return true;
            }).subscribe();
        } else {
            logger.warn("Received unknown message type: {}", (Object)message);
        }
    }

    private Mono<McpSchema.JSONRPCResponse> handleIncomingRequest(McpSchema.JSONRPCRequest request) {
        return Mono.defer(() -> {
            RequestHandler<?> handler = this.requestHandlers.get(request.method());
            if (handler == null) {
                MethodNotFoundError error = this.getMethodNotFoundError(request.method());
                return Mono.just((Object)new McpSchema.JSONRPCResponse("2.0", request.id(), null, new McpSchema.JSONRPCResponse.JSONRPCError(-32601, error.message(), error.data())));
            }
            return handler.handle(request.params()).map(result -> new McpSchema.JSONRPCResponse("2.0", request.id(), result, null));
        });
    }

    private MethodNotFoundError getMethodNotFoundError(String method) {
        switch (method) {
            case "roots/list": {
                return new MethodNotFoundError(method, "Roots not supported", Map.of("reason", "Client does not have roots capability"));
            }
        }
        return new MethodNotFoundError(method, "Method not found: " + method, null);
    }

    private Mono<Void> handleIncomingNotification(McpSchema.JSONRPCNotification notification) {
        return Mono.defer(() -> {
            NotificationHandler handler = this.notificationHandlers.get(notification.method());
            if (handler == null) {
                logger.warn("No handler registered for notification method: {}", (Object)notification);
                return Mono.empty();
            }
            return handler.handle(notification.params());
        });
    }

    private String generateRequestId() {
        return this.sessionPrefix + "-" + this.requestCounter.getAndIncrement();
    }

    @Override
    public <T> Mono<T> sendRequest(String method, Object requestParams, TypeReference<T> typeRef) {
        String requestId = this.generateRequestId();
        return Mono.deferContextual(ctx -> Mono.create(pendingResponseSink -> {
            logger.debug("Sending message for method {}", (Object)method);
            this.pendingResponses.put(requestId, (MonoSink<McpSchema.JSONRPCResponse>)pendingResponseSink);
            McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest("2.0", method, requestId, requestParams);
            this.transport.sendMessage(jsonrpcRequest).contextWrite(ctx).subscribe(v -> {}, error -> {
                this.pendingResponses.remove(requestId);
                pendingResponseSink.error(error);
            });
        })).timeout(this.requestTimeout).handle((jsonRpcResponse, deliveredResponseSink) -> {
            if (jsonRpcResponse.error() != null) {
                logger.error("Error handling request: {}", (Object)jsonRpcResponse.error());
                deliveredResponseSink.error((Throwable)new McpError(jsonRpcResponse.error()));
            } else if (typeRef.getType().equals(Void.class)) {
                deliveredResponseSink.complete();
            } else {
                deliveredResponseSink.next(this.transport.unmarshalFrom(jsonRpcResponse.result(), typeRef));
            }
        });
    }

    @Override
    public Mono<Void> sendNotification(String method, Object params) {
        McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification("2.0", method, params);
        return this.transport.sendMessage(jsonrpcNotification);
    }

    @Override
    public Mono<Void> closeGracefully() {
        return Mono.fromRunnable(this::dismissPendingResponses);
    }

    @Override
    public void close() {
        this.dismissPendingResponses();
    }

    record MethodNotFoundError(String method, String message, Object data) {
    }

    @FunctionalInterface
    public static interface NotificationHandler {
        public Mono<Void> handle(Object var1);
    }

    @FunctionalInterface
    public static interface RequestHandler<T> {
        public Mono<T> handle(Object var1);
    }
}

