package org.springframework.graphql.server.webmvc;

import graphql.GraphqlErrorBuilder;
import io.micrometer.context.ContextSnapshot;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.URI;
import java.security.Principal;
import java.time.Duration;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscription;
import org.springframework.aot.hint.annotation.RegisterReflectionForBinding;
import org.springframework.graphql.execution.ErrorType;
import org.springframework.graphql.execution.SubscriptionPublisherException;
import org.springframework.graphql.server.WebGraphQlHandler;
import org.springframework.graphql.server.WebGraphQlResponse;
import org.springframework.graphql.server.WebSocketGraphQlInterceptor;
import org.springframework.graphql.server.WebSocketGraphQlRequest;
import org.springframework.graphql.server.WebSocketSessionInfo;
import org.springframework.graphql.server.support.GraphQlWebSocketMessage;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpInputMessage;
import org.springframework.http.HttpOutputMessage;
import org.springframework.http.MediaType;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.SubProtocolCapable;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.ExceptionWebSocketHandlerDecorator;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
import reactor.core.publisher.BaseSubscriber;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;

@RegisterReflectionForBinding({GraphQlWebSocketMessage.class})
/* loaded from: input_file:org/springframework/graphql/server/webmvc/GraphQlWebSocketHandler.class */
public class GraphQlWebSocketHandler extends TextWebSocketHandler implements SubProtocolCapable {
    /** Logger for this handler; the nested helper classes log through it as well. */
    private static final Log logger = LogFactory.getLog(GraphQlWebSocketHandler.class);

    /**
     * Sub-protocols reported by {@link #getSubProtocols()}. Sessions that end up
     * with "graphql-ws" as the accepted protocol are closed as soon as the
     * connection is established.
     */
    private static final List<String> SUB_PROTOCOL_LIST = Arrays.asList("graphql-transport-ws", "graphql-ws");

    /** Handler that executes the GraphQL requests built from "subscribe" messages. */
    private final WebGraphQlHandler graphQlHandler;

    /** Interceptor registered on the request handler created by this class. */
    private final ContextHandshakeInterceptor contextHandshakeInterceptor;

    /** Interceptor obtained from the {@link WebGraphQlHandler}; notified of session-level events. */
    private final WebSocketGraphQlInterceptor webSocketGraphQlInterceptor;

    /** How long a new session may stay without a connection-init message before it is closed. */
    private final Duration initTimeoutDuration;

    /** Converter used to read and write {@link GraphQlWebSocketMessage} instances. */
    private final HttpMessageConverter<?> converter;

    /** Per-session state, keyed by WebSocket session id. */
    private final Map<String, SessionState> sessionInfoMap = new ConcurrentHashMap<>();

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/springframework/graphql/server/webmvc/GraphQlWebSocketHandler$ContextHandshakeInterceptor.class */
    /**
     * Handshake interceptor that captures a {@link ContextSnapshot} when the
     * handshake starts and stores it in the attributes map it is given, so that
     * {@link #setThreadLocals(WebSocketSession)} can later look it up from the
     * session attributes under the same key.
     */
    public static class ContextHandshakeInterceptor implements HandshakeInterceptor {
        /** Attribute name under which the captured snapshot is stored. */
        private static final String KEY = ContextSnapshot.class.getName();

        private ContextHandshakeInterceptor() {
        }

        /**
         * Captures a context snapshot and puts it into the given attributes map.
         *
         * @return always {@code true}
         */
        public boolean beforeHandshake(ServerHttpRequest serverHttpRequest, ServerHttpResponse serverHttpResponse, WebSocketHandler webSocketHandler, Map<String, Object> map) {
            ContextSnapshot snapshot = ContextSnapshot.captureAll(new Object[0]);
            map.put(KEY, snapshot);
            return true;
        }

        /** No-op: nothing to do once the handshake has completed. */
        public void afterHandshake(ServerHttpRequest serverHttpRequest, ServerHttpResponse serverHttpResponse, WebSocketHandler webSocketHandler, @Nullable Exception exc) {
        }

        /**
         * Looks up the snapshot stored in the session attributes and applies it
         * to the current thread.
         *
         * @param webSocketSession the session whose attributes hold the snapshot
         * @return the handle returned by {@code ContextSnapshot.setThreadLocals()};
         * the caller is expected to close it
         * @throws IllegalArgumentException if the session has no snapshot attribute
         */
        public static AutoCloseable setThreadLocals(WebSocketSession webSocketSession) {
            Map<String, Object> attributes = webSocketSession.getAttributes();
            ContextSnapshot snapshot = (ContextSnapshot) attributes.get(KEY);
            Assert.notNull(snapshot, "Expected ContextSnapshot in WebSocketSession attributes");
            return snapshot.setThreadLocals();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/springframework/graphql/server/webmvc/GraphQlWebSocketHandler$GraphQlStatus.class */
    public static class GraphQlStatus {
        private static final CloseStatus INVALID_MESSAGE_STATUS = new CloseStatus(4400, "Invalid message");
        private static final CloseStatus UNAUTHORIZED_STATUS = new CloseStatus(4401, "Unauthorized");
        private static final CloseStatus INIT_TIMEOUT_STATUS = new CloseStatus(4408, "Connection initialisation timeout");
        private static final CloseStatus TOO_MANY_INIT_REQUESTS_STATUS = new CloseStatus(4429, "Too many initialisation requests");

        private GraphQlStatus() {
        }

        static void closeSession(WebSocketSession webSocketSession, CloseStatus closeStatus) {
            try {
                webSocketSession.close(closeStatus);
            } catch (IOException e) {
                if (GraphQlWebSocketHandler.logger.isDebugEnabled()) {
                    GraphQlWebSocketHandler.logger.debug("Error while closing session with status: " + closeStatus, e);
                }
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/springframework/graphql/server/webmvc/GraphQlWebSocketHandler$HttpInputMessageAdapter.class */
    /**
     * Presents the bytes of a {@link TextMessage} as an {@link HttpInputMessage}:
     * the stream itself is the body and the headers are always empty.
     */
    public static class HttpInputMessageAdapter extends ByteArrayInputStream implements HttpInputMessage {

        /**
         * @param textMessage the message whose bytes become the body
         */
        HttpInputMessageAdapter(TextMessage textMessage) {
            super(textMessage.asBytes());
        }

        /** Returns this stream, which reads the message bytes. */
        public InputStream getBody() {
            return this;
        }

        /** Returns the shared empty headers instance. */
        public HttpHeaders getHeaders() {
            return HttpHeaders.EMPTY;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/springframework/graphql/server/webmvc/GraphQlWebSocketHandler$HttpOutputMessageAdapter.class */
    /**
     * In-memory {@link HttpOutputMessage}: the stream itself collects the body
     * bytes, which the caller retrieves with {@code toByteArray()}.
     *
     * <p>Each adapter owns its own {@link HttpHeaders}. The headers object is
     * mutable and is handed to whoever writes the message, so a single static
     * instance would be shared by every message written on any thread.
     */
    public static class HttpOutputMessageAdapter extends ByteArrayOutputStream implements HttpOutputMessage {
        /** Scratch headers for this message only; their content is never read back here. */
        private final HttpHeaders headers = new HttpHeaders();

        private HttpOutputMessageAdapter() {
        }

        /** Returns this stream, which accumulates the written bytes. */
        public OutputStream getBody() {
            return this;
        }

        /** Returns the headers private to this adapter instance. */
        public HttpHeaders getHeaders() {
            return this.headers;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/springframework/graphql/server/webmvc/GraphQlWebSocketHandler$SendMessageSubscriber.class */
    /**
     * Subscriber that sends each received {@link TextMessage} to the WebSocket
     * session, requesting one message at a time: the next item is requested
     * only after the previous one has been sent.
     */
    public static class SendMessageSubscriber extends BaseSubscriber<TextMessage> {
        private final String subscriptionId;
        private final WebSocketSession session;
        private final SessionState sessionState;

        /**
         * @param str the id under which the subscription is tracked in the session state
         * @param webSocketSession the session to send messages to
         * @param sessionState the state holding the session's active subscriptions
         */
        SendMessageSubscriber(String str, WebSocketSession webSocketSession, SessionState sessionState) {
            this.sessionState = sessionState;
            this.session = webSocketSession;
            this.subscriptionId = str;
        }

        /** Requests the first message only. */
        protected void hookOnSubscribe(Subscription subscription) {
            subscription.request(1L);
        }

        /**
         * Sends the message and then requests the next one. If sending fails
         * with an {@link IOException}, the session is closed with an error and
         * nothing further is requested.
         */
        public void hookOnNext(TextMessage textMessage) {
            try {
                this.session.sendMessage(textMessage);
            } catch (IOException ex) {
                ExceptionWebSocketHandlerDecorator.tryCloseWithError(this.session, ex, GraphQlWebSocketHandler.logger);
                return;
            }
            request(1L);
        }

        /** Closes the session with an error when the upstream fails. */
        public void hookOnError(Throwable th) {
            ExceptionWebSocketHandlerDecorator.tryCloseWithError(this.session, th, GraphQlWebSocketHandler.logger);
        }

        /** Removes this subscription's id from the session state once the stream completes. */
        public void hookOnComplete() {
            Map<String, Subscription> active = this.sessionState.getSubscriptions();
            active.remove(this.subscriptionId);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/springframework/graphql/server/webmvc/GraphQlWebSocketHandler$SessionState.class */
    public static class SessionState {
        private final WebSocketSessionInfo sessionInfo;
        private final AtomicReference<Map<String, Object>> connectionInitPayloadRef = new AtomicReference<>();
        private final Map<String, Subscription> subscriptions = new ConcurrentHashMap();
        private final Scheduler scheduler;

        SessionState(String str, WebSocketSessionInfo webSocketSessionInfo) {
            this.sessionInfo = webSocketSessionInfo;
            this.scheduler = Schedulers.newSingle("GraphQL-WsSession-" + str);
        }

        public WebSocketSessionInfo getSessionInfo() {
            return this.sessionInfo;
        }

        @Nullable
        Map<String, Object> getConnectionInitPayload() {
            return this.connectionInitPayloadRef.get();
        }

        boolean setConnectionInitPayload(Map<String, Object> map) {
            return this.connectionInitPayloadRef.compareAndSet(null, map);
        }

        Map<String, Subscription> getSubscriptions() {
            return this.subscriptions;
        }

        void dispose() {
            Iterator<Map.Entry<String, Subscription>> it = this.subscriptions.entrySet().iterator();
            while (it.hasNext()) {
                try {
                    it.next().getValue().cancel();
                } catch (Throwable th) {
                }
            }
            this.subscriptions.clear();
            this.scheduler.dispose();
        }

        Scheduler getScheduler() {
            return this.scheduler;
        }
    }

    /* loaded from: input_file:org/springframework/graphql/server/webmvc/GraphQlWebSocketHandler$SubscriptionExistsException.class */
    /**
     * Marker exception thrown when a subscription id is already registered for
     * the session; it carries no message or cause.
     */
    private static class SubscriptionExistsException extends RuntimeException {

        private SubscriptionExistsException() {
            super();
        }
    }

    /* loaded from: input_file:org/springframework/graphql/server/webmvc/GraphQlWebSocketHandler$WebMvcSessionInfo.class */
    /**
     * {@link WebSocketSessionInfo} view over a {@link WebSocketSession}; every
     * accessor delegates to the wrapped session.
     */
    private static class WebMvcSessionInfo implements WebSocketSessionInfo {
        private final WebSocketSession session;

        private WebMvcSessionInfo(WebSocketSession webSocketSession) {
            this.session = webSocketSession;
        }

        @Override
        public String getId() {
            return this.session.getId();
        }

        @Override
        public Map<String, Object> getAttributes() {
            return this.session.getAttributes();
        }

        /**
         * Returns the session URI.
         *
         * @throws IllegalArgumentException if the session has no URI
         */
        @Override
        public URI getUri() {
            // Read once so the value that is validated is the value that is returned.
            URI uri = this.session.getUri();
            Assert.notNull(uri, "Expected URI");
            return uri;
        }

        @Override
        public HttpHeaders getHeaders() {
            return this.session.getHandshakeHeaders();
        }

        /** Returns the session principal, or an empty {@link Mono} if there is none. */
        @Override
        public Mono<Principal> getPrincipal() {
            return Mono.justOrEmpty(this.session.getPrincipal());
        }

        @Override
        public InetSocketAddress getRemoteAddress() {
            return this.session.getRemoteAddress();
        }
    }

    /**
     * Creates a handler.
     *
     * @param webGraphQlHandler handles the GraphQL requests; also supplies the
     * {@link WebSocketGraphQlInterceptor} used for session-level callbacks
     * @param httpMessageConverter converter used to read and write the JSON messages
     * @param duration how long a new session may stay without a connection-init
     * message before it is closed
     * @throws IllegalArgumentException if any argument is {@code null}
     */
    public GraphQlWebSocketHandler(WebGraphQlHandler webGraphQlHandler, HttpMessageConverter<?> httpMessageConverter, Duration duration) {
        Assert.notNull(webGraphQlHandler, "WebGraphQlHandler is required");
        Assert.notNull(httpMessageConverter, "HttpMessageConverter for JSON is required");
        // Fail fast: the timeout is used for every established connection, so a null
        // value would otherwise only surface when the first client connects.
        Assert.notNull(duration, "Connection init timeout is required");
        this.graphQlHandler = webGraphQlHandler;
        this.contextHandshakeInterceptor = new ContextHandshakeInterceptor();
        this.webSocketGraphQlInterceptor = this.graphQlHandler.getWebSocketInterceptor();
        this.initTimeoutDuration = duration;
        this.converter = httpMessageConverter;
    }

    /**
     * Returns the sub-protocol names this handler reports.
     *
     * @return the shared list of sub-protocol names
     */
    public List<String> getSubProtocols() {
        return SUB_PROTOCOL_LIST;
    }

    /**
     * Creates a {@link WebSocketHttpRequestHandler} that routes to this handler
     * and has this handler's context handshake interceptor as its only
     * interceptor.
     *
     * @param handshakeHandler the handshake handler to pass to the request handler
     * @return the newly created request handler
     */
    public WebSocketHttpRequestHandler initWebSocketHttpRequestHandler(HandshakeHandler handshakeHandler) {
        WebSocketHttpRequestHandler requestHandler = new WebSocketHttpRequestHandler(this, handshakeHandler);
        List<HandshakeInterceptor> interceptors = Collections.singletonList(this.contextHandshakeInterceptor);
        requestHandler.setHandshakeInterceptors(interceptors);
        return requestHandler;
    }

    /**
     * Delegates to {@link #initWebSocketHttpRequestHandler(HandshakeHandler)}.
     *
     * @param handshakeHandler the handshake handler to pass along
     * @return the request handler created by the delegate
     * @deprecated use {@link #initWebSocketHttpRequestHandler(HandshakeHandler)} directly
     */
    @Deprecated
    public WebSocketHttpRequestHandler asWebSocketHttpRequestHandler(HandshakeHandler handshakeHandler) {
        WebSocketHttpRequestHandler requestHandler = initWebSocketHttpRequestHandler(handshakeHandler);
        return requestHandler;
    }

    /**
     * Registers state for a newly connected session and starts the
     * connection-init timeout. Sessions that negotiated the "graphql-ws"
     * sub-protocol are closed immediately with the "invalid message" status.
     *
     * @param webSocketSession the session that was just established
     */
    public void afterConnectionEstablished(WebSocketSession webSocketSession) {
        if ("graphql-ws".equalsIgnoreCase(webSocketSession.getAcceptedProtocol())) {
            if (logger.isDebugEnabled()) {
                logger.debug("apollographql/subscriptions-transport-ws is not supported, nor maintained. Please, use https://github.com/enisdenjo/graphql-ws.");
            }
            GraphQlStatus.closeSession(webSocketSession, GraphQlStatus.INVALID_MESSAGE_STATUS);
            return;
        }

        SessionState state = new SessionState(webSocketSession.getId(), new WebMvcSessionInfo(webSocketSession));
        this.sessionInfoMap.put(webSocketSession.getId(), state);

        // When the timeout elapses, try to claim the init payload slot with an empty
        // map. That only succeeds if no payload was stored in the meantime, in which
        // case the session is closed with the init-timeout status.
        Runnable closeIfNotInitialized = () -> {
            if (state.setConnectionInitPayload(Collections.emptyMap())) {
                GraphQlStatus.closeSession(webSocketSession, GraphQlStatus.INIT_TIMEOUT_STATUS);
            }
        };
        Mono.delay(this.initTimeoutDuration).then(Mono.fromRunnable(closeIfNotInitialized)).subscribe();
    }

    /**
     * Handles an incoming text message with the context captured at handshake
     * time applied to the current thread for the duration of the call.
     *
     * <p>The handle returned by
     * {@link ContextHandshakeInterceptor#setThreadLocals(WebSocketSession)} is
     * managed with try-with-resources, so it is closed exactly once whether or
     * not handling fails, and a failure of {@code close()} after a handling
     * failure is attached as a suppressed exception.
     *
     * @param webSocketSession the session the message arrived on
     * @param textMessage the message to handle
     * @throws Exception any failure from message handling or from closing the handle
     */
    protected void handleTextMessage(WebSocketSession webSocketSession, TextMessage textMessage) throws Exception {
        try (AutoCloseable restoredContext = ContextHandshakeInterceptor.setThreadLocals(webSocketSession)) {
            handleInternal(webSocketSession, textMessage);
        }
    }

    /**
     * Decodes one incoming text message and acts on its resolved type:
     * subscribe, ping, complete or connection-init. Any other type closes the
     * session with the "invalid message" status.
     *
     * @param webSocketSession the session the message arrived on
     * @param textMessage the raw message to decode
     * @throws IOException if decoding the message or sending the pong fails
     */
    private void handleInternal(WebSocketSession webSocketSession, TextMessage textMessage) throws IOException {
        GraphQlWebSocketMessage decode = decode(textMessage);
        String id = decode.getId();
        // NOTE(review): raw cast left by the decompiler; the payload is treated as a String-keyed map.
        Map<String, Object> map = (Map) decode.getPayload();
        SessionState sessionInfo = getSessionInfo(webSocketSession);
        switch (decode.resolvedType()) {
            case SUBSCRIBE:
                // A subscribe is only accepted once a connection-init payload has been stored.
                if (sessionInfo.getConnectionInitPayload() == null) {
                    GraphQlStatus.closeSession(webSocketSession, GraphQlStatus.UNAUTHORIZED_STATUS);
                    return;
                }
                // The id is needed to correlate the response messages with the request.
                if (id == null) {
                    GraphQlStatus.closeSession(webSocketSession, GraphQlStatus.INVALID_MESSAGE_STATUS);
                    return;
                }
                URI uri = webSocketSession.getUri();
                Assert.notNull(uri, "Expected handshake url");
                WebSocketGraphQlRequest webSocketGraphQlRequest = new WebSocketGraphQlRequest(uri, webSocketSession.getHandshakeHeaders(), map, id, null, sessionInfo.getSessionInfo());
                if (logger.isDebugEnabled()) {
                    logger.debug("Executing: " + webSocketGraphQlRequest);
                }
                // Execute the request, turn the response into outgoing messages, and send
                // them from the session's own scheduler via SendMessageSubscriber.
                this.graphQlHandler.handleRequest(webSocketGraphQlRequest).flatMapMany(webGraphQlResponse -> {
                    return handleResponse(webSocketSession, webSocketGraphQlRequest.getId(), webGraphQlResponse);
                }).publishOn(sessionInfo.getScheduler()).subscribe(new SendMessageSubscriber(id, webSocketSession, sessionInfo));
                return;
            case PING:
                // Reply with a pong that carries no payload.
                webSocketSession.sendMessage(encode(GraphQlWebSocketMessage.pong(null)));
                return;
            case COMPLETE:
                // Without an id there is nothing to cancel; the message is ignored.
                if (id != null) {
                    Subscription remove = sessionInfo.getSubscriptions().remove(id);
                    if (remove != null) {
                        remove.cancel();
                    }
                    // Blocks the calling thread for up to 10 seconds while the interceptor is notified.
                    this.webSocketGraphQlInterceptor.handleCancelledSubscription(sessionInfo.getSessionInfo(), id).block(Duration.ofSeconds(10L));
                    return;
                }
                return;
            case CONNECTION_INIT:
                // Only the first payload is stored; if one is already present the session is closed.
                if (sessionInfo.setConnectionInitPayload(map)) {
                    // Let the interceptor handle the init payload, then send the connection ack
                    // (an empty map when the interceptor returns nothing) from the session
                    // scheduler. Any error closes the session as unauthorized. Blocks the calling
                    // thread for up to 10 seconds.
                    this.webSocketGraphQlInterceptor.handleConnectionInitialization(sessionInfo.getSessionInfo(), map).defaultIfEmpty(Collections.emptyMap()).publishOn(sessionInfo.getScheduler()).doOnNext(obj -> {
                        try {
                            webSocketSession.sendMessage(encode(GraphQlWebSocketMessage.connectionAck(obj)));
                        } catch (IOException e) {
                            throw new IllegalStateException(e);
                        }
                    }).onErrorResume(th -> {
                        GraphQlStatus.closeSession(webSocketSession, GraphQlStatus.UNAUTHORIZED_STATUS);
                        return Mono.empty();
                    }).block(Duration.ofSeconds(10L));
                    return;
                } else {
                    GraphQlStatus.closeSession(webSocketSession, GraphQlStatus.TOO_MANY_INIT_REQUESTS_STATUS);
                    return;
                }
            default:
                GraphQlStatus.closeSession(webSocketSession, GraphQlStatus.INVALID_MESSAGE_STATUS);
                return;
        }
    }

    /**
     * Reads the bytes of the given text message through the configured converter
     * and returns the result as a {@link GraphQlWebSocketMessage}.
     *
     * @param textMessage the message to decode
     * @return the decoded message
     * @throws IOException if the converter fails to read the message
     */
    private GraphQlWebSocketMessage decode(TextMessage textMessage) throws IOException {
        // NOTE(review): the raw (Class) null argument and the outer cast come from the decompiler;
        // confirm the intended converter read(...) overload against the original source.
        return (GraphQlWebSocketMessage) this.converter.read(GraphQlWebSocketMessage.class, (Class) null, new HttpInputMessageAdapter(textMessage));
    }

    /**
     * Looks up the state registered for the given session.
     *
     * @param webSocketSession the session whose state is wanted
     * @return the state stored under the session's id
     * @throws IllegalArgumentException if no state is registered for the session
     */
    private SessionState getSessionInfo(WebSocketSession webSocketSession) {
        String sessionId = webSocketSession.getId();
        SessionState state = this.sessionInfoMap.get(sessionId);
        Assert.notNull(state, "No SessionInfo for " + webSocketSession);
        return state;
    }

    /**
     * Turns a GraphQL response into the stream of outgoing messages for one
     * request id: a "next" message per result, followed by a "complete" message,
     * or an "error" message if the stream fails.
     *
     * @param webSocketSession the session the request came from
     * @param str the request id used in every outgoing message
     * @param webGraphQlResponse the response to convert
     * @return the encoded messages to send
     */
    private Flux<TextMessage> handleResponse(WebSocketSession webSocketSession, String str, WebGraphQlResponse webGraphQlResponse) {
        if (logger.isDebugEnabled()) {
            logger.debug("Execution result ready" + (!CollectionUtils.isEmpty(webGraphQlResponse.getErrors()) ? " with errors: " + webGraphQlResponse.getErrors() : "") + ".");
        }
        // Publisher data: each emitted item is mapped through toSpecification(), and the
        // upstream Subscription is registered under the request id when subscribed; an
        // id that is already registered raises SubscriptionExistsException.
        // Any other data: the response map is emitted as a single item.
        // NOTE(review): the raw (Publisher) cast comes from the decompiler; the element type
        // is presumably graphql.ExecutionResult - confirm against the original source.
        return (webGraphQlResponse.getData() instanceof Publisher ? Flux.from((Publisher) webGraphQlResponse.getData()).map((v0) -> {
            return v0.toSpecification();
        }).doOnSubscribe(subscription -> {
            if (getSessionInfo(webSocketSession).getSubscriptions().putIfAbsent(str, subscription) != null) {
                throw new SubscriptionExistsException();
            }
        }) : Flux.just(webGraphQlResponse.toMap())).map(map -> {
            return encode(GraphQlWebSocketMessage.next(str, map));
        }).concatWith(Mono.fromCallable(() -> {
            return encode(GraphQlWebSocketMessage.complete(str));
        })).onErrorResume(th -> {
            // Errors other than a duplicate id become a single "error" message: the errors
            // carried by a SubscriptionPublisherException, or else one generic internal error.
            if (!(th instanceof SubscriptionExistsException)) {
                return Mono.just(encode(GraphQlWebSocketMessage.error(str, th instanceof SubscriptionPublisherException ? ((SubscriptionPublisherException) th).getErrors() : Collections.singletonList(GraphqlErrorBuilder.newError().message("Subscription error", new Object[0]).errorType(ErrorType.INTERNAL_ERROR).build()))));
            }
            // Duplicate id: close the session with code 4409 and send nothing for this request.
            GraphQlStatus.closeSession(webSocketSession, new CloseStatus(4409, "Subscriber for " + str + " already exists"));
            return Flux.empty();
        });
    }

    /**
     * Writes the given message through the configured converter into an
     * in-memory buffer and wraps the resulting bytes in a {@link TextMessage}.
     *
     * @param graphQlWebSocketMessage the message to encode
     * @return the encoded text message
     * @throws IllegalStateException if the converter fails with an {@link IOException}
     */
    // NOTE(review): the type parameter T is not used in this decompiled form; it presumably
    // backed a generic cast of the converter in the original source - confirm before removing.
    private <T> TextMessage encode(GraphQlWebSocketMessage graphQlWebSocketMessage) {
        try {
            HttpOutputMessageAdapter httpOutputMessageAdapter = new HttpOutputMessageAdapter();
            // A null media type is passed, leaving the choice of content type to the converter.
            this.converter.write(graphQlWebSocketMessage, (MediaType) null, httpOutputMessageAdapter);
            return new TextMessage(httpOutputMessageAdapter.toByteArray());
        } catch (IOException e) {
            throw new IllegalStateException("Failed to write " + graphQlWebSocketMessage + " as JSON", e);
        }
    }

    /**
     * Drops and disposes the state of a session after a transport error. Does
     * nothing if no state is registered for the session.
     *
     * @param webSocketSession the session on which the error occurred
     * @param th the transport error (not inspected)
     */
    public void handleTransportError(WebSocketSession webSocketSession, Throwable th) {
        SessionState state = this.sessionInfoMap.remove(webSocketSession.getId());
        if (state == null) {
            return;
        }
        state.dispose();
    }

    /**
     * Drops and disposes the state of a closed session. If a connection-init
     * payload had been stored for it, the interceptor is then notified of the
     * closure with the close code and that payload.
     *
     * @param webSocketSession the session that was closed
     * @param closeStatus the status the session was closed with
     */
    public void afterConnectionClosed(WebSocketSession webSocketSession, CloseStatus closeStatus) {
        SessionState state = this.sessionInfoMap.remove(webSocketSession.getId());
        if (state == null) {
            return;
        }
        state.dispose();

        Map<String, Object> initPayload = state.getConnectionInitPayload();
        if (initPayload == null) {
            return;
        }
        this.webSocketGraphQlInterceptor.handleConnectionClosed(state.getSessionInfo(), closeStatus.getCode(), initPayload);
    }

    /**
     * Reports that this handler does not support partial messages.
     *
     * @return always {@code false}
     */
    public boolean supportsPartialMessages() {
        return false;
    }
}
