/*
 * Decompiled with CFR 0.152.
 */
package org.apache.shenyu.plugin.websocket;

import java.net.URI;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.shenyu.common.dto.RuleData;
import org.apache.shenyu.common.dto.SelectorData;
import org.apache.shenyu.common.dto.convert.rule.impl.DivideRuleHandle;
import org.apache.shenyu.common.enums.PluginEnum;
import org.apache.shenyu.common.enums.RpcTypeEnum;
import org.apache.shenyu.common.utils.GsonUtils;
import org.apache.shenyu.loadbalancer.cache.UpstreamCacheManager;
import org.apache.shenyu.loadbalancer.entity.Upstream;
import org.apache.shenyu.loadbalancer.factory.LoadBalancerFactory;
import org.apache.shenyu.plugin.api.ShenyuPluginChain;
import org.apache.shenyu.plugin.api.context.ShenyuContext;
import org.apache.shenyu.plugin.api.result.ShenyuResultEnum;
import org.apache.shenyu.plugin.api.result.ShenyuResultWrap;
import org.apache.shenyu.plugin.api.utils.RequestQueryCodecUtil;
import org.apache.shenyu.plugin.api.utils.WebFluxResultUtils;
import org.apache.shenyu.plugin.base.AbstractShenyuPlugin;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.lang.NonNull;
import org.springframework.util.StringUtils;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketMessage;
import org.springframework.web.reactive.socket.WebSocketSession;
import org.springframework.web.reactive.socket.client.WebSocketClient;
import org.springframework.web.reactive.socket.server.WebSocketService;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.util.UriComponentsBuilder;
import reactor.core.publisher.Mono;

/**
 * Shenyu plugin that proxies WebSocket connections to a load-balanced upstream.
 *
 * <p>For each matching request it picks an upstream registered for the selector, builds the
 * upstream WebSocket URL, upgrades the inbound connection through the {@link WebSocketService}
 * and relays frames in both directions over a connection opened with the {@link WebSocketClient}.
 */
public class WebSocketPlugin extends AbstractShenyuPlugin {

    private static final Logger LOG = LoggerFactory.getLogger(WebSocketPlugin.class);

    private static final String SEC_WEB_SOCKET_PROTOCOL = "Sec-WebSocket-Protocol";

    /** Name prefix (compared case-insensitively) of handshake headers that are not copied upstream. */
    private static final String SEC_WEB_SOCKET_HEADER_PREFIX = "sec-websocket";

    private final WebSocketClient webSocketClient;

    private final WebSocketService webSocketService;

    /**
     * Creates the plugin.
     *
     * @param webSocketClient  client used to open the upstream WebSocket connection
     * @param webSocketService service used to upgrade the inbound HTTP request to a WebSocket session
     */
    public WebSocketPlugin(final WebSocketClient webSocketClient, final WebSocketService webSocketService) {
        this.webSocketClient = webSocketClient;
        this.webSocketService = webSocketService;
    }

    /**
     * Selects an upstream and upgrades the request into a proxied WebSocket session.
     *
     * <p>When no upstream is registered, the Shenyu context is missing, or the rule handle cannot be
     * read, the error is logged and the exchange is passed on to the rest of the chain. When the load
     * balancer finds no upstream, a {@code CANNOT_FIND_HEALTHY_UPSTREAM_URL} result is written.
     *
     * @param exchange the current exchange
     * @param chain    the plugin chain
     * @param selector the matched selector; its id keys the upstream cache
     * @param rule     the matched rule; its handle JSON is read as a {@link DivideRuleHandle}
     * @return completion of the proxied session, or of the fallback/error handling
     */
    @Override
    protected Mono<Void> doExecute(final ServerWebExchange exchange, final ShenyuPluginChain chain,
                                   final SelectorData selector, final RuleData rule) {
        final List<Upstream> upstreamList = UpstreamCacheManager.getInstance().findUpstreamListBySelectorId(selector.getId());
        final ShenyuContext shenyuContext = exchange.getAttribute("context");
        if (CollectionUtils.isEmpty(upstreamList) || Objects.isNull(shenyuContext)) {
            LOG.error("websocket upstream configuration error\uff1a{}", rule);
            return chain.execute(exchange);
        }
        // A rule without handle JSON can deserialize to null; guard it instead of failing with an NPE below.
        final DivideRuleHandle ruleHandle = GsonUtils.getInstance().fromJson(rule.getHandle(), DivideRuleHandle.class);
        if (Objects.isNull(ruleHandle)) {
            LOG.error("websocket rule handle configuration error\uff1a{}", rule);
            return chain.execute(exchange);
        }
        final String ip = Objects.requireNonNull(exchange.getRequest().getRemoteAddress()).getAddress().getHostAddress();
        final Upstream upstream = LoadBalancerFactory.selector(upstreamList, ruleHandle.getLoadBalance(), ip);
        if (Objects.isNull(upstream)) {
            LOG.error("websocket has no upstream");
            final Object error = ShenyuResultWrap.error(exchange, ShenyuResultEnum.CANNOT_FIND_HEALTHY_UPSTREAM_URL);
            return WebFluxResultUtils.result(exchange, error);
        }
        final URI wsRequestUrl = UriComponentsBuilder.fromUri(URI.create(buildWsRealPath(exchange, upstream, shenyuContext))).build().toUri();
        LOG.info("you websocket urlPath is :{}", wsRequestUrl.toASCIIString());
        final HttpHeaders headers = exchange.getRequest().getHeaders();
        return webSocketService.handleRequest(exchange,
                new ShenyuWebSocketHandler(wsRequestUrl, webSocketClient, filterHeaders(headers), buildWsProtocols(headers)));
    }

    /**
     * Builds the upstream URL as {@code protocol + upstream url + path[?query]}.
     *
     * <p>Falls back to {@code ws://} when the upstream declares no protocol. The path is the context's
     * real URL when present, otherwise the context's method. The query is appended only when the
     * inbound request has one.
     */
    private String buildWsRealPath(final ServerWebExchange exchange, final Upstream upstream, final ShenyuContext shenyuContext) {
        final String protocol = StringUtils.hasLength(upstream.getProtocol()) ? upstream.getProtocol() : "ws://";
        String path = StringUtils.hasLength(shenyuContext.getRealUrl()) ? shenyuContext.getRealUrl() : shenyuContext.getMethod();
        if (StringUtils.hasText(exchange.getRequest().getURI().getQuery())) {
            path = String.join("?", path, RequestQueryCodecUtil.getCodecQuery(exchange));
        }
        return protocol + upstream.getUrl() + path;
    }

    /**
     * Extracts the sub-protocols requested by the client.
     *
     * <p>Every {@code Sec-WebSocket-Protocol} header value is split on commas. Tokens are trimmed and
     * empty tokens are dropped, because RFC 6455 sub-protocol names are non-empty tokens.
     *
     * @param headers inbound request headers
     * @return the requested sub-protocols; never {@code null}, empty when none were requested
     */
    private List<String> buildWsProtocols(final HttpHeaders headers) {
        final List<String> protocols = headers.get(SEC_WEB_SOCKET_PROTOCOL);
        if (CollectionUtils.isEmpty(protocols)) {
            return Collections.emptyList();
        }
        return protocols.stream()
                // tokenizeToStringArray trims each token and ignores empty ones
                .flatMap(header -> Arrays.stream(StringUtils.tokenizeToStringArray(header, ",")))
                .collect(Collectors.toList());
    }

    /**
     * Copies every inbound header except the {@code Sec-WebSocket-*} handshake headers.
     *
     * <p>The handshake headers are excluded because the upstream connection performs its own handshake.
     * Header names are matched case-insensitively without depending on the default locale.
     *
     * @param headers inbound request headers
     * @return a new header set to send to the upstream
     */
    private HttpHeaders filterHeaders(final HttpHeaders headers) {
        final HttpHeaders filtered = new HttpHeaders();
        headers.forEach((name, values) -> {
            if (!StringUtils.startsWithIgnoreCase(name, SEC_WEB_SOCKET_HEADER_PREFIX)) {
                filtered.addAll(name, values);
            }
        });
        return filtered;
    }

    @Override
    public String named() {
        return PluginEnum.WEB_SOCKET.getName();
    }

    /**
     * Skips every request whose RPC type is not {@link RpcTypeEnum#WEB_SOCKET}.
     */
    @Override
    public boolean skip(final ServerWebExchange exchange) {
        return skipExcept(exchange, RpcTypeEnum.WEB_SOCKET);
    }

    @Override
    protected Mono<Void> handleSelectorIfNull(final String pluginName, final ServerWebExchange exchange, final ShenyuPluginChain chain) {
        return WebFluxResultUtils.noSelectorResult(pluginName, exchange);
    }

    @Override
    protected Mono<Void> handleRuleIfNull(final String pluginName, final ServerWebExchange exchange, final ShenyuPluginChain chain) {
        return WebFluxResultUtils.noRuleResult(pluginName, exchange);
    }

    @Override
    public int getOrder() {
        return PluginEnum.WEB_SOCKET.getCode();
    }

    /**
     * Server-side handler that opens the upstream connection and relays frames in both directions.
     */
    private static class ShenyuWebSocketHandler implements WebSocketHandler {

        private final WebSocketClient client;

        private final URI url;

        private final HttpHeaders headers;

        private final List<String> subProtocols;

        /**
         * @param url       upstream WebSocket URL
         * @param client    client used to connect to {@code url}
         * @param headers   headers sent on the upstream handshake
         * @param protocols sub-protocols to advertise; {@code null} is treated as none
         */
        ShenyuWebSocketHandler(final URI url, final WebSocketClient client, final HttpHeaders headers, final List<String> protocols) {
            this.client = client;
            this.url = url;
            this.headers = headers;
            this.subProtocols = ObjectUtils.defaultIfNull(protocols, Collections.emptyList());
        }

        @NonNull
        @Override
        public List<String> getSubProtocols() {
            return subProtocols;
        }

        /**
         * Connects to the upstream and pipes messages between the client session and the upstream session.
         */
        @NonNull
        @Override
        public Mono<Void> handle(@NonNull final WebSocketSession session) {
            return client.execute(url, headers, new WebSocketHandler() {

                @NonNull
                @Override
                public Mono<Void> handle(@NonNull final WebSocketSession proxySession) {
                    // retain() keeps each payload buffer valid while it is forwarded to the other session.
                    final Mono<Void> toUpstream = proxySession.send(session.receive().doOnNext(WebSocketMessage::retain));
                    final Mono<Void> toClient = session.send(proxySession.receive().doOnNext(WebSocketMessage::retain));
                    // Both sources are empty Monos. zip completes (cancelling the other side) as soon as
                    // either direction completes, so closing either end tears down the relay.
                    return Mono.zip(toUpstream, toClient).then();
                }

                @NonNull
                @Override
                public List<String> getSubProtocols() {
                    return subProtocols;
                }
            });
        }
    }
}

