package dev.langchain4j.mcp.client.transport.http;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.mcp.client.protocol.InitializationNotification;
import dev.langchain4j.mcp.client.protocol.McpClientMessage;
import dev.langchain4j.mcp.client.protocol.McpInitializeRequest;
import dev.langchain4j.mcp.client.transport.McpOperationHandler;
import dev.langchain4j.mcp.client.transport.McpTransport;
import java.io.IOException;
import java.net.URI;
import java.time.Duration;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import okhttp3.Call;
import okhttp3.Callback;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSources;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link McpTransport} implementation for the MCP HTTP + Server-Sent Events transport.
 *
 * <p>{@link #start(McpOperationHandler)} opens an SSE channel to the configured {@code sseUrl} and
 * blocks until the SSE listener reports the endpoint to which client messages must be POSTed. Every
 * outgoing {@link McpClientMessage} is then serialized to JSON and sent as an HTTP POST to that
 * endpoint. The body of the POST reply is not read. For messages that carry an id, the pending
 * future is registered with the {@link McpOperationHandler}. NOTE(review): it is presumably
 * completed by the SSE listener when the matching server message arrives; confirm in
 * {@code SseEventListener}.
 */
public class HttpMcpTransport implements McpTransport {
    private final String sseUrl;
    private final OkHttpClient client;
    private final boolean logResponses;
    private final boolean logRequests;
    private EventSource mcpSseEventListener;
    private volatile String postUrl;
    private volatile McpOperationHandler messageHandler;
    private static final Logger log = LoggerFactory.getLogger(HttpMcpTransport.class);
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    /** Builder for {@link HttpMcpTransport}. Only {@link #sseUrl(String)} is mandatory. */
    public static class Builder {
        private String sseUrl;
        private Duration timeout;
        private boolean logRequests = false;
        private boolean logResponses = false;

        /**
         * Sets the URL of the server's SSE endpoint (required).
         *
         * @param sseUrl the SSE endpoint URL; relative POST URLs announced by the server are resolved
         *     against it
         * @return this builder
         */
        public Builder sseUrl(String sseUrl) {
            this.sseUrl = sseUrl;
            return this;
        }

        /**
         * Sets the timeout applied to the HTTP client's call, connect, read and write timeouts.
         * This value also bounds how long {@code start} waits for the POST URL. Defaults to 60
         * seconds when not set.
         *
         * @param timeout the timeout, or {@code null} for the default
         * @return this builder
         */
        public Builder timeout(Duration timeout) {
            this.timeout = timeout;
            return this;
        }

        /**
         * Enables logging of outgoing HTTP requests via {@code McpRequestLoggingInterceptor}.
         *
         * @param logRequests whether to log requests
         * @return this builder
         */
        public Builder logRequests(boolean logRequests) {
            this.logRequests = logRequests;
            return this;
        }

        /**
         * Enables logging of events received on the SSE channel.
         *
         * @param logResponses whether to log responses
         * @return this builder
         */
        public Builder logResponses(boolean logResponses) {
            this.logResponses = logResponses;
            return this;
        }

        /**
         * Creates the transport.
         *
         * @return a new, not yet started transport
         */
        public HttpMcpTransport build() {
            return new HttpMcpTransport(this);
        }
    }

    /**
     * Creates a transport from the given builder. No network activity happens until
     * {@link #start(McpOperationHandler)} is called.
     *
     * @param builder the configured builder; its {@code sseUrl} must not be {@code null}
     */
    public HttpMcpTransport(Builder builder) {
        OkHttpClient.Builder httpClientBuilder = new OkHttpClient.Builder();
        Duration timeout = Utils.getOrDefault(builder.timeout, Duration.ofSeconds(60L));
        httpClientBuilder.callTimeout(timeout);
        httpClientBuilder.connectTimeout(timeout);
        httpClientBuilder.readTimeout(timeout);
        httpClientBuilder.writeTimeout(timeout);
        this.logRequests = builder.logRequests;
        if (builder.logRequests) {
            httpClientBuilder.addInterceptor(new McpRequestLoggingInterceptor());
        }
        this.logResponses = builder.logResponses;
        this.sseUrl = ValidationUtils.ensureNotNull(builder.sseUrl, "Missing SSE endpoint URL");
        this.client = httpClientBuilder.build();
    }

    /**
     * Opens the SSE channel and blocks until the server's POST URL has been received.
     *
     * @param mcpOperationHandler handler that receives server messages and completes pending operations
     * @throws RuntimeException if the POST URL cannot be obtained (timeout, interruption, SSE failure)
     */
    @Override
    public void start(McpOperationHandler mcpOperationHandler) {
        this.messageHandler = mcpOperationHandler;
        this.mcpSseEventListener = startSseChannel(this.logResponses);
    }

    /**
     * Sends the initialize request and, once its response is available, the initialized notification.
     *
     * @param mcpInitializeRequest the initialize request
     * @return a future completed with the server's response to the initialize request
     */
    @Override
    public CompletableFuture<JsonNode> initialize(McpInitializeRequest mcpInitializeRequest) {
        try {
            Request initializeRequest = createRequest(mcpInitializeRequest);
            Request initializedNotification = createRequest(new InitializationNotification());
            // Return the initialize response itself, not the (null) result of the notification.
            return execute(initializeRequest, mcpInitializeRequest.getId())
                    .thenCompose(initializeResponse -> execute(initializedNotification, null)
                            .thenApply(ignored -> initializeResponse));
        } catch (JsonProcessingException e) {
            return CompletableFuture.failedFuture(e);
        }
    }

    /**
     * Sends a message whose response is delivered asynchronously through the message handler.
     *
     * @param mcpClientMessage the message to send; its id is used to correlate the response
     * @return a future completed with the response, or exceptionally on transport/serialization failure
     */
    @Override
    public CompletableFuture<JsonNode> executeOperationWithResponse(McpClientMessage mcpClientMessage) {
        try {
            return execute(createRequest(mcpClientMessage), mcpClientMessage.getId());
        } catch (JsonProcessingException e) {
            return CompletableFuture.failedFuture(e);
        }
    }

    /**
     * Sends a message without waiting for or correlating any response (fire and forget).
     *
     * @param mcpClientMessage the message to send
     * @throws RuntimeException if the message cannot be serialized
     */
    @Override
    public void executeOperationWithoutResponse(McpClientMessage mcpClientMessage) {
        try {
            execute(createRequest(mcpClientMessage), null);
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }

    /** No-op: this transport performs no health check. */
    @Override
    public void checkHealth() {
    }

    /**
     * POSTs the request asynchronously.
     *
     * @param request the prepared HTTP request
     * @param id the operation id, or {@code null} when no response is expected. When non-null, the
     *     returned future is registered with the message handler, which is responsible for completing
     *     it with the actual response.
     * @return a future that fails on network error or a non-2xx status; when {@code id} is
     *     {@code null} it is completed with {@code null} on a 2xx status
     */
    private CompletableFuture<JsonNode> execute(Request request, final Long id) {
        final CompletableFuture<JsonNode> future = new CompletableFuture<>();
        if (id != null) {
            this.messageHandler.startOperation(id, future);
        }
        this.client.newCall(request).enqueue(new Callback() {
            @Override
            public void onFailure(Call call, IOException e) {
                future.completeExceptionally(e);
            }

            @Override
            public void onResponse(Call call, Response response) throws IOException {
                // The reply body is never consumed, but the response must still be closed so the
                // connection is released back to OkHttp's pool.
                try (response) {
                    int code = response.code();
                    if (!isExpectedStatusCode(code)) {
                        future.completeExceptionally(new RuntimeException("Unexpected status code: " + code));
                    } else if (id == null) {
                        future.complete(null);
                    }
                }
            }
        });
        return future;
    }

    /** Returns {@code true} for any 2xx HTTP status code. */
    private boolean isExpectedStatusCode(int statusCode) {
        return statusCode >= 200 && statusCode < 300;
    }

    /**
     * Opens the SSE connection and waits, bounded by the client's call timeout, for the POST URL.
     * If no call timeout is configured, the wait is bounded by {@code Integer.MAX_VALUE} ms.
     * The event source is cancelled if the wait fails, so the SSE connection is not leaked.
     *
     * @param logEvents whether the SSE listener should log received events
     * @return the open event source
     * @throws RuntimeException wrapping the cause if the POST URL could not be obtained
     */
    private EventSource startSseChannel(boolean logEvents) {
        Request request = new Request.Builder().url(this.sseUrl).build();
        CompletableFuture<String> postUrlFuture = new CompletableFuture<>();
        EventSource eventSource = EventSources.createFactory(this.client)
                .newEventSource(request, new SseEventListener(this.messageHandler, logEvents, postUrlFuture));
        long waitMillis = this.client.callTimeoutMillis() > 0 ? this.client.callTimeoutMillis() : Integer.MAX_VALUE;
        try {
            this.postUrl = buildAbsolutePostUrl(postUrlFuture.get(waitMillis, TimeUnit.MILLISECONDS));
            log.debug("Received the server's POST URL: {}", this.postUrl);
            return eventSource;
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Preserve the caller's interrupt status.
                Thread.currentThread().interrupt();
            }
            eventSource.cancel();
            throw new RuntimeException(e);
        }
    }

    /**
     * Resolves the (possibly relative) POST URL announced by the server against the SSE URL.
     *
     * @param announcedUrl the URL received from the server
     * @return the absolute POST URL
     * @throws RuntimeException if either URL is malformed
     */
    private String buildAbsolutePostUrl(String announcedUrl) {
        try {
            return URI.create(this.sseUrl).resolve(announcedUrl).toString();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Serializes the message to JSON and builds a POST request to the server's POST URL.
     *
     * @throws JsonProcessingException if the message cannot be serialized
     */
    private Request createRequest(McpClientMessage mcpClientMessage) throws JsonProcessingException {
        return new Request.Builder()
                .url(this.postUrl)
                .header("Content-Type", "application/json")
                .post(RequestBody.create(OBJECT_MAPPER.writeValueAsBytes(mcpClientMessage)))
                .build();
    }

    /**
     * Cancels the SSE channel, shuts down the HTTP client's dispatcher threads and evicts its idle
     * pooled connections.
     */
    @Override
    public void close() throws IOException {
        if (this.mcpSseEventListener != null) {
            this.mcpSseEventListener.cancel();
        }
        this.client.dispatcher().executorService().shutdown();
        this.client.connectionPool().evictAll();
    }
}
