/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.graphql.server.webflux;

import graphql.ErrorClassification;
import graphql.ExecutionResult;
import graphql.GraphQLError;
import graphql.GraphqlErrorBuilder;
import java.net.InetSocketAddress;
import java.net.URI;
import java.security.Principal;
import java.time.Duration;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscription;
import org.springframework.aot.hint.annotation.RegisterReflectionForBinding;
import org.springframework.graphql.execution.ErrorType;
import org.springframework.graphql.execution.SubscriptionPublisherException;
import org.springframework.graphql.server.WebGraphQlHandler;
import org.springframework.graphql.server.WebGraphQlResponse;
import org.springframework.graphql.server.WebSocketGraphQlInterceptor;
import org.springframework.graphql.server.WebSocketGraphQlRequest;
import org.springframework.graphql.server.WebSocketSessionInfo;
import org.springframework.graphql.server.support.GraphQlWebSocketMessage;
import org.springframework.graphql.server.webflux.WebSocketCodecDelegate;
import org.springframework.http.HttpCookie;
import org.springframework.http.HttpHeaders;
import org.springframework.http.codec.CodecConfigurer;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.MultiValueMap;
import org.springframework.web.reactive.socket.CloseStatus;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketMessage;
import org.springframework.web.reactive.socket.WebSocketSession;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/**
 * {@link WebSocketHandler} for GraphQL over WebSocket using the
 * {@code graphql-transport-ws} sub-protocol (https://github.com/enisdenjo/graphql-ws).
 *
 * <p>The legacy {@code graphql-ws} sub-protocol (apollographql/subscriptions-transport-ws) is
 * advertised by {@link #getSubProtocols()}, but sessions that negotiate it are closed
 * immediately with status 4400.
 */
@RegisterReflectionForBinding(value = {GraphQlWebSocketMessage.class})
public class GraphQlWebSocketHandler implements WebSocketHandler {

    private static final Log logger = LogFactory.getLog(GraphQlWebSocketHandler.class);

    // Unmodifiable because the same instance is handed out by getSubProtocols() to every caller.
    private static final List<String> SUB_PROTOCOL_LIST =
            Collections.unmodifiableList(Arrays.asList("graphql-transport-ws", "graphql-ws"));

    private final WebGraphQlHandler graphQlHandler;

    private final WebSocketGraphQlInterceptor webSocketInterceptor;

    // NOTE(review): a single delegate is shared by every session handled by this instance; the
    // keep-alive logic below calls checkMessagesEncodedAndClear() on it, so if that method tracks
    // state on the delegate, traffic on one session may suppress pings on another — confirm.
    private final WebSocketCodecDelegate codecDelegate;

    private final Duration initTimeoutDuration;

    private final @Nullable Duration keepAliveDuration;

    /**
     * Create a handler that does not send keep-alive pings.
     * @param graphQlHandler common handler for GraphQL over WebSocket requests
     * @param codecConfigurer codec configurer used to build the message codec delegate
     * @param connectionInitTimeout how long a client has to send {@code connection_init}
     * before the session is closed with status 4408
     */
    public GraphQlWebSocketHandler(WebGraphQlHandler graphQlHandler, CodecConfigurer codecConfigurer,
            Duration connectionInitTimeout) {

        this(graphQlHandler, codecConfigurer, connectionInitTimeout, null);
    }

    /**
     * Create a handler.
     * @param graphQlHandler common handler for GraphQL over WebSocket requests
     * @param codecConfigurer codec configurer used to build the message codec delegate
     * @param connectionInitTimeout how long a client has to send {@code connection_init}
     * before the session is closed with status 4408
     * @param keepAliveDuration if non-null, the interval at which server {@code ping} messages
     * are emitted once a {@code connection_init} has been accepted; a ping is skipped when the
     * codec delegate reports that messages were encoded since the previous check
     * @throws IllegalArgumentException if {@code graphQlHandler} or
     * {@code connectionInitTimeout} is {@code null}
     */
    public GraphQlWebSocketHandler(WebGraphQlHandler graphQlHandler, CodecConfigurer codecConfigurer,
            Duration connectionInitTimeout, @Nullable Duration keepAliveDuration) {

        Assert.notNull(graphQlHandler, "WebGraphQlHandler is required");
        // Fail fast here rather than at the first handshake inside Mono.delay(...).
        Assert.notNull(connectionInitTimeout, "connectionInitTimeout is required");
        this.graphQlHandler = graphQlHandler;
        this.webSocketInterceptor = this.graphQlHandler.getWebSocketInterceptor();
        this.codecDelegate = new WebSocketCodecDelegate(codecConfigurer);
        this.initTimeoutDuration = connectionInitTimeout;
        this.keepAliveDuration = keepAliveDuration;
    }

    @Override
    public List<String> getSubProtocols() {
        return SUB_PROTOCOL_LIST;
    }

    /**
     * Handle a WebSocket session: enforce the connection-init timeout, notify the interceptor
     * when an initialised connection closes, and dispatch every inbound protocol message.
     * @param session the WebSocket session
     * @return completion of the outbound message stream
     */
    @Override
    public Mono<Void> handle(WebSocketSession session) {
        HandshakeInfo handshakeInfo = session.getHandshakeInfo();
        if ("graphql-ws".equalsIgnoreCase(handshakeInfo.getSubProtocol())) {
            if (logger.isDebugEnabled()) {
                logger.debug("apollographql/subscriptions-transport-ws is not supported, nor maintained. Please, use https://github.com/enisdenjo/graphql-ws.");
            }
            return session.close(GraphQlStatus.INVALID_MESSAGE_STATUS);
        }

        WebSocketSessionInfo sessionInfo = new WebFluxSessionInfo(session);

        // Per-session state. The payload reference stays null until a connection_init is
        // accepted or the init timeout fires; active subscriptions are keyed by operation id.
        AtomicReference<@Nullable Map<String, Object>> connectionInitPayloadRef = new AtomicReference<>();
        Map<String, Subscription> subscriptions = new ConcurrentHashMap<>();

        // If no connection_init arrives in time, claim the slot with an empty map (so a late
        // connection_init is rejected as a duplicate) and close with 4408.
        // NOTE(review): this timer is not cancelled when the session closes first.
        Mono.delay(this.initTimeoutDuration)
                .then(Mono.defer(() -> connectionInitPayloadRef.compareAndSet(null, Collections.emptyMap()) ?
                        session.close(GraphQlStatus.INIT_TIMEOUT_STATUS) : Mono.<Void>empty()))
                .subscribe();

        // Notify the interceptor on close, but only if the payload slot was claimed (by an
        // accepted connection_init or, with an empty map, by the init timeout above).
        session.closeStatus()
                .doOnSuccess(closeStatus -> {
                    Map<String, Object> connectionInitPayload = connectionInitPayloadRef.get();
                    if (connectionInitPayload == null) {
                        return;
                    }
                    // 1005 is the RFC 6455 "no status received" code.
                    int statusCode = (closeStatus != null ? closeStatus.getCode() : 1005);
                    this.webSocketInterceptor.handleConnectionClosed(sessionInfo, statusCode, connectionInitPayload);
                })
                .subscribe();

        return session.send(session.receive().flatMap(webSocketMessage -> {
            GraphQlWebSocketMessage message = this.codecDelegate.decode(webSocketMessage);
            if (message == null) {
                return GraphQlStatus.close(session, GraphQlStatus.INVALID_MESSAGE_STATUS);
            }
            String id = message.getId();
            Map<String, Object> payload = getPayload(message);
            switch (message.resolvedType()) {
                case SUBSCRIBE: {
                    // NOTE(review): this only checks that the payload slot is claimed, which happens
                    // before the interceptor acknowledges connection_init (and also on init timeout);
                    // a SUBSCRIBE pipelined ahead of the ack is therefore executed — confirm intended.
                    if (connectionInitPayloadRef.get() == null) {
                        return GraphQlStatus.close(session, GraphQlStatus.UNAUTHORIZED_STATUS);
                    }
                    // The protocol requires both an id and a payload on subscribe messages.
                    if (id == null || payload == null) {
                        return GraphQlStatus.close(session, GraphQlStatus.INVALID_MESSAGE_STATUS);
                    }
                    WebSocketGraphQlRequest request = new WebSocketGraphQlRequest(
                            handshakeInfo.getUri(), handshakeInfo.getHeaders(), handshakeInfo.getCookies(),
                            handshakeInfo.getRemoteAddress(), handshakeInfo.getAttributes(),
                            payload, id, null, sessionInfo);
                    if (logger.isDebugEnabled()) {
                        logger.debug("Executing: " + request);
                    }
                    return this.graphQlHandler.handleRequest(request)
                            .flatMapMany(response -> handleResponse(session, id, subscriptions, response))
                            .doOnTerminate(() -> subscriptions.remove(id));
                }
                case PING: {
                    return Flux.just(this.codecDelegate.encode(session, GraphQlWebSocketMessage.pong(null)));
                }
                case PONG: {
                    return Flux.empty();
                }
                case COMPLETE: {
                    if (id == null) {
                        return Flux.empty();
                    }
                    // Client-initiated completion: cancel the upstream subscription, if still active.
                    Subscription subscription = subscriptions.remove(id);
                    if (subscription != null) {
                        subscription.cancel();
                    }
                    return this.webSocketInterceptor.handleCancelledSubscription(sessionInfo, id)
                            .thenMany(Flux.empty());
                }
                case CONNECTION_INIT: {
                    // The init payload is optional in graphql-transport-ws. Normalise a missing one to
                    // an empty map: compareAndSet(null, null) would report success while leaving the
                    // session uninitialised (later subscribes rejected, timeout still firing).
                    Map<String, Object> initPayload = (payload != null ? payload : Collections.emptyMap());
                    if (!connectionInitPayloadRef.compareAndSet(null, initPayload)) {
                        return GraphQlStatus.close(session, GraphQlStatus.TOO_MANY_INIT_REQUESTS_STATUS);
                    }
                    Flux<WebSocketMessage> flux = this.webSocketInterceptor
                            .handleConnectionInitialization(sessionInfo, initPayload)
                            .defaultIfEmpty(Collections.emptyMap())
                            .map(ackPayload -> this.codecDelegate.encodeConnectionAck(session, ackPayload))
                            .flux();
                    Duration keepAlive = this.keepAliveDuration;
                    if (keepAlive != null) {
                        // Periodic pings, skipped when the delegate reports messages were encoded.
                        flux = flux.mergeWith(Flux.interval(keepAlive, keepAlive)
                                .filter(tick -> !this.codecDelegate.checkMessagesEncodedAndClear())
                                .map(tick -> this.codecDelegate.encode(session, GraphQlWebSocketMessage.ping(null))));
                    }
                    // Any interceptor error rejects the connection.
                    return flux.onErrorResume(ex -> GraphQlStatus.close(session, GraphQlStatus.UNAUTHORIZED_STATUS));
                }
                default: {
                    // Server-to-client message types (or anything else) are not valid from a client.
                    return GraphQlStatus.close(session, GraphQlStatus.INVALID_MESSAGE_STATUS);
                }
            }
        }));
    }

    /**
     * Return the message payload as a map, or {@code null} if the message carries none.
     */
    @SuppressWarnings("unchecked")
    private static @Nullable Map<String, Object> getPayload(GraphQlWebSocketMessage message) {
        return (Map<String, Object>) message.getPayload();
    }

    /**
     * Encode the outcome of a request as outbound messages.
     * <p>If the response data is a {@link Publisher}, each item becomes a {@code next} message and
     * its subscription is registered under {@code id}. Otherwise the whole response is sent as a
     * single {@code next}. A {@code complete} message follows on success. On error, a duplicate id
     * closes the session with 4409, and any other failure is sent as an {@code error} message.
     */
    private Flux<WebSocketMessage> handleResponse(WebSocketSession session, String id,
            Map<String, Subscription> subscriptions, WebGraphQlResponse response) {

        if (logger.isDebugEnabled()) {
            logger.debug("Execution result ready" +
                    (!CollectionUtils.isEmpty(response.getErrors()) ? " with errors: " + response.getErrors() : "") + ".");
        }

        Flux<Map<String, Object>> responseFlux;
        Object data = response.getData();
        if (data instanceof Publisher) {
            @SuppressWarnings("unchecked")
            Publisher<ExecutionResult> resultPublisher = (Publisher<ExecutionResult>) data;
            responseFlux = Flux.from(resultPublisher)
                    .map(ExecutionResult::toSpecification)
                    .doOnSubscribe(subscription -> {
                        // Register for client-side COMPLETE; an id already in use is a protocol error.
                        Subscription previous = subscriptions.putIfAbsent(id, subscription);
                        if (previous != null) {
                            throw new SubscriptionExistsException();
                        }
                    });
        }
        else {
            responseFlux = Flux.just(response.toMap());
        }

        return responseFlux
                .map(responseMap -> this.codecDelegate.encodeNext(session, id, responseMap))
                .concatWith(Mono.fromCallable(() -> this.codecDelegate.encodeComplete(session, id)))
                .onErrorResume(ex -> {
                    if (ex instanceof SubscriptionExistsException) {
                        CloseStatus status = new CloseStatus(4409, "Subscriber for " + id + " already exists");
                        return GraphQlStatus.close(session, status);
                    }
                    List<GraphQLError> errors = resolveErrors(id, ex);
                    return Mono.fromCallable(() -> this.codecDelegate.encodeError(session, id, errors));
                });
    }

    /**
     * Map a stream failure to the GraphQL errors sent to the client. Errors carried by a
     * {@link SubscriptionPublisherException} are used as-is; anything else is logged and replaced
     * by a generic internal error.
     */
    private static List<GraphQLError> resolveErrors(String id, Throwable ex) {
        if (ex instanceof SubscriptionPublisherException) {
            return ((SubscriptionPublisherException) ex).getErrors();
        }
        if (logger.isErrorEnabled()) {
            logger.error("Unresolved " + ex.getClass().getSimpleName() + " for request id " + id, ex);
        }
        return Collections.singletonList(GraphqlErrorBuilder.newError()
                .message("Subscription error")
                .errorType(ErrorType.INTERNAL_ERROR)
                .build());
    }

    /**
     * Close statuses used by this handler, plus a helper to close a session from a message stream.
     */
    private static final class GraphQlStatus {

        static final CloseStatus INVALID_MESSAGE_STATUS = new CloseStatus(4400, "Invalid message");

        static final CloseStatus UNAUTHORIZED_STATUS = new CloseStatus(4401, "Unauthorized");

        static final CloseStatus INIT_TIMEOUT_STATUS = new CloseStatus(4408, "Connection initialisation timeout");

        static final CloseStatus TOO_MANY_INIT_REQUESTS_STATUS = new CloseStatus(4429, "Too many initialisation requests");

        private GraphQlStatus() {
        }

        /**
         * Close the session with the given status, emitting no messages.
         */
        static <V> Flux<V> close(WebSocketSession session, CloseStatus status) {
            return session.close(status).thenMany(Flux.empty());
        }
    }

    /**
     * {@link WebSocketSessionInfo} view over a WebFlux {@link WebSocketSession}; every accessor
     * delegates to the session or its handshake info.
     */
    private static final class WebFluxSessionInfo implements WebSocketSessionInfo {

        private final WebSocketSession session;

        private WebFluxSessionInfo(WebSocketSession session) {
            this.session = session;
        }

        @Override
        public String getId() {
            return this.session.getId();
        }

        @Override
        public Map<String, Object> getAttributes() {
            return this.session.getAttributes();
        }

        @Override
        public URI getUri() {
            return this.session.getHandshakeInfo().getUri();
        }

        @Override
        public HttpHeaders getHeaders() {
            return this.session.getHandshakeInfo().getHeaders();
        }

        @Override
        public Mono<Principal> getPrincipal() {
            return this.session.getHandshakeInfo().getPrincipal();
        }

        @Override
        public @Nullable InetSocketAddress getRemoteAddress() {
            return this.session.getHandshakeInfo().getRemoteAddress();
        }
    }

    /**
     * Signals that a subscription was started with an operation id that is already in use;
     * translated to close status 4409 in {@code handleResponse}.
     */
    private static final class SubscriptionExistsException extends RuntimeException {

        private SubscriptionExistsException() {
        }
    }
}

