/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.graphql.server.webmvc;

import graphql.ErrorClassification;
import graphql.ExecutionResult;
import graphql.GraphQLError;
import graphql.GraphqlErrorBuilder;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.URI;
import java.security.Principal;
import java.time.Duration;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscription;
import org.springframework.graphql.execution.ErrorType;
import org.springframework.graphql.execution.SubscriptionPublisherException;
import org.springframework.graphql.execution.ThreadLocalAccessor;
import org.springframework.graphql.server.WebGraphQlHandler;
import org.springframework.graphql.server.WebGraphQlResponse;
import org.springframework.graphql.server.WebSocketGraphQlInterceptor;
import org.springframework.graphql.server.WebSocketGraphQlRequest;
import org.springframework.graphql.server.WebSocketSessionInfo;
import org.springframework.graphql.server.support.GraphQlWebSocketMessage;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpInputMessage;
import org.springframework.http.HttpOutputMessage;
import org.springframework.http.converter.GenericHttpMessageConverter;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.SubProtocolCapable;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.ExceptionWebSocketHandlerDecorator;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
import reactor.core.CoreSubscriber;
import reactor.core.publisher.BaseSubscriber;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;

/**
 * Spring MVC {@link WebSocketHandler} that serves GraphQL requests over WebSocket
 * using the {@code "graphql-transport-ws"} sub-protocol.
 *
 * <p>The legacy {@code "graphql-ws"} sub-protocol is advertised in
 * {@link #getSubProtocols()}, but sessions that negotiate it are closed immediately
 * with status 4400, because that protocol is not supported.
 *
 * <p>Per-session state (connection-init payload, active subscriptions and a
 * single-threaded scheduler used to serialize outbound messages) is kept in
 * {@link #sessionInfoMap}, keyed by WebSocket session id.
 */
public class GraphQlWebSocketHandler extends TextWebSocketHandler implements SubProtocolCapable {

    private static final Log logger = LogFactory.getLog(GraphQlWebSocketHandler.class);

    /** Advertised sub-protocols; unmodifiable so callers of getSubProtocols() cannot alter it. */
    private static final List<String> SUB_PROTOCOL_LIST =
            Collections.unmodifiableList(Arrays.asList("graphql-transport-ws", "graphql-ws"));

    private final WebGraphQlHandler graphQlHandler;

    private final ContextHandshakeInterceptor contextHandshakeInterceptor;

    private final WebSocketGraphQlInterceptor webSocketGraphQlInterceptor;

    private final Duration initTimeoutDuration;

    private final HttpMessageConverter<?> converter;

    private final Map<String, SessionState> sessionInfoMap = new ConcurrentHashMap<>();

    /**
     * Create a handler.
     *
     * @param graphQlHandler common handler for GraphQL requests; also supplies the
     *        ThreadLocalAccessor and the WebSocket interceptor
     * @param converter converter used to read and write messages as JSON
     * @param connectionInitTimeout how long to wait for a {@code connection_init}
     *        message before closing the session with status 4408
     * @throws IllegalArgumentException if any argument is {@code null}
     */
    public GraphQlWebSocketHandler(
            WebGraphQlHandler graphQlHandler, HttpMessageConverter<?> converter, Duration connectionInitTimeout) {

        Assert.notNull(graphQlHandler, "WebGraphQlHandler is required");
        Assert.notNull(converter, "HttpMessageConverter for JSON is required");
        // Fail fast here rather than with an NPE from Mono.delay on every new connection.
        Assert.notNull(connectionInitTimeout, "connectionInitTimeout is required");
        this.graphQlHandler = graphQlHandler;
        this.contextHandshakeInterceptor = new ContextHandshakeInterceptor(graphQlHandler.getThreadLocalAccessor());
        this.webSocketGraphQlInterceptor = graphQlHandler.getWebSocketInterceptor();
        this.initTimeoutDuration = connectionInitTimeout;
        this.converter = converter;
    }

    @Override
    public List<String> getSubProtocols() {
        return SUB_PROTOCOL_LIST;
    }

    /**
     * Wrap this handler in a {@link WebSocketHttpRequestHandler} that also installs
     * the interceptor capturing ThreadLocal context at handshake time.
     *
     * @param handshakeHandler the handshake handler to use
     * @return the request handler to map to the GraphQL WebSocket endpoint
     */
    public WebSocketHttpRequestHandler asWebSocketHttpRequestHandler(HandshakeHandler handshakeHandler) {
        WebSocketHttpRequestHandler handler = new WebSocketHttpRequestHandler(this, handshakeHandler);
        handler.setHandshakeInterceptors(Collections.singletonList(this.contextHandshakeInterceptor));
        return handler;
    }

    /**
     * Register state for a new session and schedule the connection-init timeout.
     * Sessions that negotiated the unsupported {@code "graphql-ws"} protocol are
     * closed right away with status 4400.
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) {
        if ("graphql-ws".equalsIgnoreCase(session.getAcceptedProtocol())) {
            if (logger.isDebugEnabled()) {
                logger.debug("apollographql/subscriptions-transport-ws is not supported, nor maintained. Please, use https://github.com/enisdenjo/graphql-ws.");
            }
            GraphQlStatus.closeSession(session, GraphQlStatus.INVALID_MESSAGE_STATUS);
            return;
        }
        SessionState sessionState = new SessionState(session.getId(), new WebMvcSessionInfo(session));
        this.sessionInfoMap.put(session.getId(), sessionState);
        // If no connection_init was recorded when the timeout elapses, record an empty
        // payload (so any late connection_init is rejected) and close with 4408.
        // NOTE(review): this timer is not cancelled when the session closes earlier;
        // it then fires against the already-removed state and closes a closed session.
        Mono.delay(this.initTimeoutDuration)
                .then(Mono.fromRunnable(() -> {
                    if (sessionState.setConnectionInitPayload(Collections.emptyMap())) {
                        GraphQlStatus.closeSession(session, GraphQlStatus.INIT_TIMEOUT_STATUS);
                    }
                }))
                .subscribe();
    }

    /**
     * Handle an inbound text frame, restoring any ThreadLocal values captured during
     * the handshake for the duration of the call.
     */
    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage webSocketMessage) throws Exception {
        try (Closeable ignored = this.contextHandshakeInterceptor.restoreThreadLocalValue(session)) {
            handleInternal(session, webSocketMessage);
        }
    }

    /** Decode one protocol message and dispatch it by type; unknown types close the session with 4400. */
    @SuppressWarnings("unchecked")
    private void handleInternal(WebSocketSession session, TextMessage webSocketMessage) throws IOException {
        GraphQlWebSocketMessage message = decode(webSocketMessage);
        String id = message.getId();
        Map<String, Object> payload = (Map<String, Object>) message.getPayload();
        SessionState state = getSessionInfo(session);
        switch (message.resolvedType()) {
            case SUBSCRIBE:
                handleSubscribe(session, id, payload, state);
                return;
            case PING:
                session.sendMessage(encode(GraphQlWebSocketMessage.pong(null)));
                return;
            case COMPLETE:
                handleComplete(id, state);
                return;
            case CONNECTION_INIT:
                handleConnectionInit(session, payload, state);
                return;
            default:
                GraphQlStatus.closeSession(session, GraphQlStatus.INVALID_MESSAGE_STATUS);
        }
    }

    /**
     * Execute a {@code subscribe} request and stream the results back. Requires a
     * prior connection_init (else 4401) and a message id (else 4400).
     */
    private void handleSubscribe(
            WebSocketSession session, @Nullable String id, @Nullable Map<String, Object> payload, SessionState state) {

        if (state.getConnectionInitPayload() == null) {
            GraphQlStatus.closeSession(session, GraphQlStatus.UNAUTHORIZED_STATUS);
            return;
        }
        if (id == null) {
            GraphQlStatus.closeSession(session, GraphQlStatus.INVALID_MESSAGE_STATUS);
            return;
        }
        URI uri = session.getUri();
        Assert.notNull(uri, "Expected handshake url");
        HttpHeaders headers = session.getHandshakeHeaders();
        WebSocketGraphQlRequest request =
                new WebSocketGraphQlRequest(uri, headers, payload, id, null, state.getSessionInfo());
        if (logger.isDebugEnabled()) {
            logger.debug("Executing: " + request);
        }
        // Outbound messages are published on the session's single-threaded scheduler
        // and sent one at a time by SendMessageSubscriber.
        this.graphQlHandler.handleRequest(request)
                .flatMapMany(response -> handleResponse(session, request.getId(), response))
                .publishOn(state.getScheduler())
                .subscribe(new SendMessageSubscriber(id, session, state));
    }

    /**
     * Handle a client {@code complete}: cancel the matching subscription, if any, and
     * notify the interceptor (waiting up to 10 seconds). Messages without id are ignored.
     */
    private void handleComplete(@Nullable String id, SessionState state) {
        if (id == null) {
            return;
        }
        Subscription subscription = state.getSubscriptions().remove(id);
        if (subscription != null) {
            subscription.cancel();
        }
        this.webSocketGraphQlInterceptor.handleCancelledSubscription(state.getSessionInfo(), id)
                .block(Duration.ofSeconds(10L));
    }

    /**
     * Handle {@code connection_init}: record the payload (a second init closes with
     * 4429), let the interceptor produce the ack payload, and send {@code connection_ack}.
     * Any interceptor or send error closes the session with 4401. Waits up to 10 seconds.
     */
    private void handleConnectionInit(
            WebSocketSession session, @Nullable Map<String, Object> payload, SessionState state) {

        // The payload of connection_init is optional. Recording null would leave the
        // session looking uninitialized (compareAndSet(null, null) stores nothing), so
        // subscribes would be rejected, a repeat init accepted, and the init timeout
        // would later close an already-acknowledged connection.
        Map<String, Object> initPayload = (payload != null ? payload : Collections.<String, Object>emptyMap());
        if (!state.setConnectionInitPayload(initPayload)) {
            GraphQlStatus.closeSession(session, GraphQlStatus.TOO_MANY_INIT_REQUESTS_STATUS);
            return;
        }
        this.webSocketGraphQlInterceptor.handleConnectionInitialization(state.getSessionInfo(), initPayload)
                .defaultIfEmpty(Collections.emptyMap())
                .publishOn(state.getScheduler())
                .doOnNext(ackPayload -> {
                    TextMessage outputMessage = encode(GraphQlWebSocketMessage.connectionAck(ackPayload));
                    try {
                        session.sendMessage(outputMessage);
                    }
                    catch (IOException ex) {
                        throw new IllegalStateException(ex);
                    }
                })
                .onErrorResume(ex -> {
                    GraphQlStatus.closeSession(session, GraphQlStatus.UNAUTHORIZED_STATUS);
                    return Mono.empty();
                })
                .block(Duration.ofSeconds(10L));
    }

    /**
     * Read a protocol message from a text frame. Uses the generic read path when the
     * converter supports it, otherwise the plain class-based read.
     */
    @SuppressWarnings("unchecked")
    private GraphQlWebSocketMessage decode(TextMessage message) throws IOException {
        HttpInputMessageAdapter inputMessage = new HttpInputMessageAdapter(message);
        if (this.converter instanceof GenericHttpMessageConverter) {
            GenericHttpMessageConverter<Object> genericConverter = (GenericHttpMessageConverter<Object>) this.converter;
            return (GraphQlWebSocketMessage) genericConverter.read(GraphQlWebSocketMessage.class, null, inputMessage);
        }
        return ((HttpMessageConverter<GraphQlWebSocketMessage>) this.converter)
                .read(GraphQlWebSocketMessage.class, inputMessage);
    }

    /** Look up the state registered for the session; fails if none is registered. */
    private SessionState getSessionInfo(WebSocketSession session) {
        SessionState info = this.sessionInfoMap.get(session.getId());
        Assert.notNull(info, "No SessionInfo for " + session);
        return info;
    }

    /**
     * Turn a GraphQL response into outbound frames: one {@code next} per result (a
     * single one unless the data is a Publisher), then {@code complete}. A duplicate
     * subscription id closes the session with 4409; any other error becomes an
     * {@code error} message.
     */
    private Flux<TextMessage> handleResponse(WebSocketSession session, String id, WebGraphQlResponse response) {
        if (logger.isDebugEnabled()) {
            logger.debug("Execution result ready" +
                    (!CollectionUtils.isEmpty(response.getErrors()) ? " with errors: " + response.getErrors() : "") + ".");
        }
        Flux<Map<String, Object>> responseFlux;
        if (response.getData() instanceof Publisher) {
            Publisher<ExecutionResult> publisher = response.getData();
            responseFlux = Flux.from(publisher)
                    .map(ExecutionResult::toSpecification)
                    .doOnSubscribe(subscription -> {
                        // Track the subscription so a client "complete" can cancel it.
                        Subscription previous = getSessionInfo(session).getSubscriptions().putIfAbsent(id, subscription);
                        if (previous != null) {
                            throw new SubscriptionExistsException();
                        }
                    });
        }
        else {
            responseFlux = Flux.just(response.toMap());
        }
        return responseFlux
                .map(responseMap -> encode(GraphQlWebSocketMessage.next(id, responseMap)))
                .concatWith(Mono.fromCallable(() -> encode(GraphQlWebSocketMessage.complete(id))))
                .onErrorResume(ex -> {
                    if (ex instanceof SubscriptionExistsException) {
                        CloseStatus status = new CloseStatus(4409, "Subscriber for " + id + " already exists");
                        GraphQlStatus.closeSession(session, status);
                        return Flux.empty();
                    }
                    List<GraphQLError> errors;
                    if (ex instanceof SubscriptionPublisherException) {
                        errors = ((SubscriptionPublisherException) ex).getErrors();
                    }
                    else {
                        errors = Collections.singletonList(GraphqlErrorBuilder.newError()
                                .message("Subscription error")
                                .errorType(ErrorType.INTERNAL_ERROR)
                                .build());
                    }
                    return Mono.just(encode(GraphQlWebSocketMessage.error(id, errors)));
                });
    }

    /**
     * Serialize a protocol message into a text frame.
     *
     * @throws IllegalStateException if the converter fails to write the message
     */
    @SuppressWarnings("unchecked")
    private TextMessage encode(GraphQlWebSocketMessage message) {
        HttpOutputMessageAdapter outputMessage = new HttpOutputMessageAdapter();
        try {
            // The converter is typed with a wildcard; cast so it accepts the message.
            ((HttpMessageConverter<Object>) this.converter).write(message, null, outputMessage);
        }
        catch (IOException ex) {
            throw new IllegalStateException("Failed to write " + message + " as JSON", ex);
        }
        return new TextMessage(outputMessage.toByteArray());
    }

    /**
     * Drop and dispose the session state on a transport error.
     * NOTE(review): because the state is removed here, a subsequent
     * afterConnectionClosed for this session finds no state and does not invoke
     * handleConnectionClosed on the interceptor — confirm this is intended.
     */
    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) {
        SessionState info = this.sessionInfoMap.remove(session.getId());
        if (info != null) {
            info.dispose();
        }
    }

    /**
     * Drop and dispose the session state and, if a connection-init payload was
     * recorded, notify the interceptor with the close code and that payload.
     */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) {
        String id = session.getId();
        SessionState state = this.sessionInfoMap.remove(id);
        if (state != null) {
            state.dispose();
            Map<String, Object> connectionInitPayload = state.getConnectionInitPayload();
            if (connectionInitPayload != null) {
                this.webSocketGraphQlInterceptor.handleConnectionClosed(
                        state.getSessionInfo(), closeStatus.getCode(), connectionInitPayload);
            }
        }
    }

    @Override
    public boolean supportsPartialMessages() {
        return false;
    }


    /** Signals that a subscription with the same id is already active in the session. */
    private static class SubscriptionExistsException extends RuntimeException {

        private SubscriptionExistsException() {
        }
    }


    /**
     * Sends messages to the session one at a time, requesting the next only after the
     * previous one was sent. Send or stream errors close the session.
     */
    private static class SendMessageSubscriber extends BaseSubscriber<TextMessage> {

        private final String subscriptionId;

        private final WebSocketSession session;

        private final SessionState sessionState;

        SendMessageSubscriber(String subscriptionId, WebSocketSession session, SessionState sessionState) {
            this.subscriptionId = subscriptionId;
            this.session = session;
            this.sessionState = sessionState;
        }

        @Override
        protected void hookOnSubscribe(Subscription subscription) {
            subscription.request(1L);
        }

        @Override
        protected void hookOnNext(TextMessage nextMessage) {
            try {
                this.session.sendMessage(nextMessage);
                request(1L);
            }
            catch (IOException ex) {
                ExceptionWebSocketHandlerDecorator.tryCloseWithError(this.session, ex, logger);
            }
        }

        @Override
        protected void hookOnError(Throwable ex) {
            ExceptionWebSocketHandlerDecorator.tryCloseWithError(this.session, ex, logger);
        }

        @Override
        protected void hookOnComplete() {
            this.sessionState.getSubscriptions().remove(this.subscriptionId);
        }
    }


    /** {@link WebSocketSessionInfo} view over a Spring MVC {@link WebSocketSession}. */
    private static class WebMvcSessionInfo implements WebSocketSessionInfo {

        private final WebSocketSession session;

        private WebMvcSessionInfo(WebSocketSession session) {
            this.session = session;
        }

        @Override
        public String getId() {
            return this.session.getId();
        }

        @Override
        public Map<String, Object> getAttributes() {
            return this.session.getAttributes();
        }

        @Override
        public URI getUri() {
            URI uri = this.session.getUri();
            Assert.notNull(uri, "Expected URI");
            return uri;
        }

        @Override
        public HttpHeaders getHeaders() {
            return this.session.getHandshakeHeaders();
        }

        @Override
        public Mono<Principal> getPrincipal() {
            return Mono.justOrEmpty(this.session.getPrincipal());
        }

        @Override
        public InetSocketAddress getRemoteAddress() {
            return this.session.getRemoteAddress();
        }
    }


    /**
     * Per-session state: the connection-init payload (set at most once), the active
     * subscriptions by id, and a dedicated single-threaded scheduler.
     */
    private static class SessionState {

        private final WebSocketSessionInfo sessionInfo;

        private final AtomicReference<Map<String, Object>> connectionInitPayloadRef = new AtomicReference<>();

        private final Map<String, Subscription> subscriptions = new ConcurrentHashMap<>();

        private final Scheduler scheduler;

        SessionState(String graphQlSessionId, WebSocketSessionInfo sessionInfo) {
            this.sessionInfo = sessionInfo;
            this.scheduler = Schedulers.newSingle("GraphQL-WsSession-" + graphQlSessionId);
        }

        public WebSocketSessionInfo getSessionInfo() {
            return this.sessionInfo;
        }

        /** The recorded connection-init payload, or {@code null} if none was recorded yet. */
        @Nullable
        Map<String, Object> getConnectionInitPayload() {
            return this.connectionInitPayloadRef.get();
        }

        /**
         * Record the connection-init payload if none is recorded yet.
         *
         * @return {@code true} if recorded, {@code false} if a payload was already present
         */
        boolean setConnectionInitPayload(Map<String, Object> payload) {
            return this.connectionInitPayloadRef.compareAndSet(null, payload);
        }

        Map<String, Subscription> getSubscriptions() {
            return this.subscriptions;
        }

        /** Cancel all active subscriptions (best-effort) and dispose the scheduler. */
        void dispose() {
            for (Subscription subscription : this.subscriptions.values()) {
                try {
                    subscription.cancel();
                }
                catch (Throwable ignored) {
                    // Best-effort: a failing cancel must not prevent cleanup of the rest.
                }
            }
            this.subscriptions.clear();
            this.scheduler.dispose();
        }

        Scheduler getScheduler() {
            return this.scheduler;
        }
    }


    /** In-memory {@link HttpOutputMessage} that collects the converter output as bytes. */
    private static class HttpOutputMessageAdapter extends ByteArrayOutputStream implements HttpOutputMessage {

        // Per-instance: a converter may write headers (e.g. Content-Type) into the
        // message, and a single shared mutable instance would be modified concurrently
        // by writes on different session schedulers.
        private final HttpHeaders headers = new HttpHeaders();

        private HttpOutputMessageAdapter() {
        }

        @Override
        public OutputStream getBody() {
            return this;
        }

        @Override
        public HttpHeaders getHeaders() {
            return this.headers;
        }
    }


    /** {@link HttpInputMessage} over the bytes of a {@link TextMessage}, with no headers. */
    private static class HttpInputMessageAdapter extends ByteArrayInputStream implements HttpInputMessage {

        HttpInputMessageAdapter(TextMessage message) {
            super(message.asBytes());
        }

        @Override
        public InputStream getBody() {
            return this;
        }

        @Override
        public HttpHeaders getHeaders() {
            return HttpHeaders.EMPTY;
        }
    }


    /** Close statuses used by this handler and a helper to close sessions quietly. */
    private static class GraphQlStatus {

        private static final CloseStatus INVALID_MESSAGE_STATUS = new CloseStatus(4400, "Invalid message");

        private static final CloseStatus UNAUTHORIZED_STATUS = new CloseStatus(4401, "Unauthorized");

        private static final CloseStatus INIT_TIMEOUT_STATUS = new CloseStatus(4408, "Connection initialisation timeout");

        private static final CloseStatus TOO_MANY_INIT_REQUESTS_STATUS = new CloseStatus(4429, "Too many initialisation requests");

        private GraphQlStatus() {
        }

        /** Close the session with the given status; an IOException is only logged at debug level. */
        static void closeSession(WebSocketSession session, CloseStatus status) {
            try {
                session.close(status);
            }
            catch (IOException ex) {
                if (logger.isDebugEnabled()) {
                    logger.debug("Error while closing session with status: " + status, ex);
                }
            }
        }
    }


    /**
     * Captures ThreadLocal values into the session attributes at handshake time and
     * restores them around message handling. Does nothing if no accessor is configured.
     */
    private static class ContextHandshakeInterceptor implements HandshakeInterceptor {

        private static final String SAVED_CONTEXT_KEY = ContextHandshakeInterceptor.class.getName();

        @Nullable
        private final ThreadLocalAccessor accessor;

        ContextHandshakeInterceptor(@Nullable ThreadLocalAccessor accessor) {
            this.accessor = accessor;
        }

        @Override
        public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
                WebSocketHandler wsHandler, Map<String, Object> attributes) {

            if (this.accessor != null) {
                Map<String, Object> valuesMap = new LinkedHashMap<>();
                this.accessor.extractValues(valuesMap);
                attributes.put(SAVED_CONTEXT_KEY, valuesMap);
            }
            return true;
        }

        @Override
        public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response,
                WebSocketHandler wsHandler, @Nullable Exception exception) {
        }

        /**
         * Restore the values saved at handshake time onto the current thread.
         *
         * @return a handle whose close() resets those values (a no-op without accessor)
         * @throws IllegalStateException if an accessor is configured but no values were saved
         */
        @SuppressWarnings("unchecked")
        public Closeable restoreThreadLocalValue(WebSocketSession session) {
            ThreadLocalAccessor localAccessor = this.accessor;
            if (localAccessor == null) {
                return () -> {};
            }
            Map<String, Object> valuesMap = (Map<String, Object>) session.getAttributes().get(SAVED_CONTEXT_KEY);
            Assert.state(valuesMap != null, "No ThreadLocal context in WebSocketSession attributes");
            localAccessor.restoreValues(valuesMap);
            return () -> localAccessor.resetValues(valuesMap);
        }
    }
}

