package io.trino.gateway.ha.handler;

import com.codahale.metrics.Meter;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Splitter;
import com.google.common.base.Strings;
import com.google.common.io.CharStreams;
import io.trino.gateway.ha.router.QueryHistoryManager;
import io.trino.gateway.ha.router.RoutingGroupSelector;
import io.trino.gateway.ha.router.RoutingManager;
import io.trino.gateway.proxyserver.ProxyHandler;
import io.trino.gateway.proxyserver.wrapper.MultiReadHttpServletRequest;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.eclipse.jetty.client.api.Request;
import org.eclipse.jetty.util.Callback;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:io/trino/gateway/ha/handler/QueryIdCachingProxyHandler.class */
/**
 * Proxy handler that routes Trino client traffic to a backend and pins follow-up requests for a
 * query to the backend that accepted it.
 *
 * <p>For whitelisted paths, {@link #rewriteTarget} looks for a query id in the request. If one is
 * found, the backend previously recorded for it is used ({@code RoutingManager.findBackendForQueryId}).
 * Otherwise a routing group is chosen via the {@code RoutingGroupSelector} and the
 * {@code RoutingManager} provides a backend. After a successful {@code POST /v1/statement},
 * {@link #postConnectionHook} records the returned query id against that backend and submits a
 * {@code QueryHistoryManager.QueryDetail}.
 */
public class QueryIdCachingProxyHandler extends ProxyHandler {
    /** Internal request header carrying the backend URL chosen in {@link #rewriteTarget}. */
    public static final String PROXY_TARGET_HEADER = "proxytarget";
    public static final String V1_STATEMENT_PATH = "/v1/statement";
    public static final String V1_QUERY_PATH = "/v1/query";
    public static final String V1_INFO_PATH = "/v1/info";
    public static final String V1_NODE_PATH = "/v1/node";
    public static final String UI_API_STATS_PATH = "/ui/api/stats";
    public static final String UI_LOGIN_PATH = "/ui/login";
    public static final String UI_API_QUEUED_LIST_PATH = "/ui/api/query?state=QUEUED";
    public static final String TRINO_UI_PATH = "/ui";
    public static final String OAUTH_PATH = "/oauth2";
    public static final String AUTHORIZATION = "Authorization";
    public static final String USER_HEADER = "X-Trino-User";
    public static final String SOURCE_HEADER = "X-Trino-Source";
    public static final String HOST_HEADER = "Host";
    /** Maximum number of query-text characters stored in a query-history entry. */
    private static final int QUERY_TEXT_LENGTH_FOR_HISTORY = 200;
    private final RoutingManager routingManager;
    private final RoutingGroupSelector routingGroupSelector;
    private final QueryHistoryManager queryHistoryManager;
    private final Meter requestMeter;
    private final int serverApplicationPort;
    private final List<String> extraWhitelistPaths;
    private static final Logger log = LoggerFactory.getLogger(QueryIdCachingProxyHandler.class);
    private static final Pattern QUERY_ID_PATTERN = Pattern.compile(".*[/=?](\\d+_\\d+_\\d+_\\w+).*");
    private static final Pattern EXTRACT_BETWEEN_SINGLE_QUOTES = Pattern.compile("'([^\\s']+)'");
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    /**
     * Creates the handler.
     *
     * @param queryHistoryManager receives a {@code QueryDetail} for every proxied {@code POST /v1/statement}
     * @param routingManager picks backends and stores the query-id-to-backend mapping
     * @param routingGroupSelector chooses a routing group for requests without a known query id
     * @param serverApplicationPort port used for the {@code http://localhost:<port>} fallback target
     * @param requestMeter marked once per {@code POST /v1/statement}
     * @param extraWhitelistPaths additional path prefixes to route like the built-in Trino paths;
     *     {@code null} is treated as empty
     */
    public QueryIdCachingProxyHandler(
            QueryHistoryManager queryHistoryManager,
            RoutingManager routingManager,
            RoutingGroupSelector routingGroupSelector,
            int serverApplicationPort,
            Meter requestMeter,
            List<String> extraWhitelistPaths) {
        this.queryHistoryManager = queryHistoryManager;
        this.routingManager = routingManager;
        this.routingGroupSelector = routingGroupSelector;
        this.serverApplicationPort = serverApplicationPort;
        this.requestMeter = requestMeter;
        // A missing list means "no extra paths", instead of an NPE on every non-builtin request.
        this.extraWhitelistPaths = extraWhitelistPaths == null ? Collections.<String>emptyList() : extraWhitelistPaths;
    }

    /**
     * Extracts a Trino query id from a request path.
     *
     * <p>For {@code /v1/statement} and {@code /v1/query} paths the id is the path segment after the
     * endpoint. When the path also contains a state word ({@code queued}, {@code scheduled},
     * {@code executing}, {@code partialCancel}), it is the segment after the state. For
     * {@code /ui} paths the id is matched by {@code QUERY_ID_PATTERN}.
     *
     * @param path the request URI; {@code null} yields {@code null}
     * @param queryString the raw query string (only logged)
     * @return the query id, or {@code null} if none could be found
     */
    protected static String extractQueryIdIfPresent(String path, String queryString) {
        if (path == null) {
            return null;
        }
        String queryId = null;
        log.debug("trying to extract query id from  path [{}] or queryString [{}]", path, queryString);
        if (path.startsWith(V1_STATEMENT_PATH) || path.startsWith(V1_QUERY_PATH)) {
            // split("/") on a path with a leading '/' yields an empty first token,
            // so "/v1/statement/<id>" puts the id at index 3 and "/v1/statement/<state>/<id>" at index 4.
            String[] tokens = path.split("/");
            boolean stateInPath = path.contains("queued")
                    || path.contains("scheduled")
                    || path.contains("executing")
                    || path.contains("partialCancel");
            int idIndex = stateInPath ? 4 : 3;
            // Bounds check against the index actually read (e.g. "/v1/statement/queued" has no id).
            if (tokens.length > idIndex) {
                queryId = tokens[idIndex];
            }
        } else if (path.startsWith(TRINO_UI_PATH)) {
            Matcher matcher = QUERY_ID_PATTERN.matcher(path);
            if (matcher.matches()) {
                queryId = matcher.group(1);
            }
        }
        log.debug("query id in url [{}]", queryId);
        return queryId;
    }

    /**
     * Extracts a query id from a request.
     *
     * <p>First checks the body for a {@code system.runtime.kill_query} call and returns the first
     * single-quoted token in a comma-separated fragment that mentions {@code query_id}. Otherwise it
     * falls back to {@link #extractQueryIdIfPresent(String, String)} on the URI.
     *
     * @param request the incoming request. Its body is read here, so it is presumably a re-readable
     *     wrapper (TODO confirm with callers).
     * @return the query id, or {@code null} if none could be found
     */
    protected String extractQueryIdIfPresent(HttpServletRequest request) {
        String requestPath = request.getRequestURI();
        String queryString = request.getQueryString();
        try {
            String body = CharStreams.toString(request.getReader());
            // Locale.ROOT keeps the match stable under locales with special casing rules (e.g. Turkish 'I').
            if (!Strings.isNullOrEmpty(body) && body.toLowerCase(Locale.ROOT).contains("system.runtime.kill_query")) {
                for (String fragment : body.split(",")) {
                    if (!fragment.contains("query_id")) {
                        continue;
                    }
                    Matcher matcher = EXTRACT_BETWEEN_SINGLE_QUOTES.matcher(fragment);
                    if (matcher.find()) {
                        // group(1) is the text between the quotes; the pattern guarantees it is non-empty.
                        return matcher.group(1);
                    }
                }
            }
        } catch (Exception e) {
            // Best effort: failing to read the body must not prevent URI-based routing.
            log.error("Error extracting query payload from request", e);
        }
        return extractQueryIdIfPresent(requestPath, queryString);
    }

    /**
     * Sets the {@code Host} header of the outgoing proxy request to {@code host[:port]} of the
     * backend named in {@link #PROXY_TARGET_HEADER}.
     *
     * <p>The header is left unchanged (with a warning) when the target is absent, cannot be
     * parsed as a URI, or has no host component.
     */
    static void setForwardedHostHeaderOnProxyRequest(HttpServletRequest request, Request proxyRequest) {
        String proxyTarget = request.getHeader(PROXY_TARGET_HEADER);
        if (proxyTarget == null) {
            log.warn("Proxy Target not set on request, unable to decipher HOST header");
            return;
        }
        try {
            URI targetUri = new URI(proxyTarget);
            String host = targetUri.getHost();
            if (host == null) {
                // e.g. a target without a scheme, or a registry-based authority; avoid sending "Host: null".
                log.warn("Proxy Target [{}] has no host component, unable to decipher HOST header", proxyTarget);
                return;
            }
            int port = targetUri.getPort();
            String hostHeader = port == -1 ? host : host + ":" + port;
            log.debug("Incoming Request Host header : [{}], proxy request host header : [{}]", request.getHeader(HOST_HEADER), hostHeader);
            proxyRequest.header(HOST_HEADER, hostHeader);
        } catch (URISyntaxException e) {
            log.warn("Unable to parse Proxy Target [{}]", proxyTarget, e);
        }
    }

    /**
     * Determines the user of a request.
     *
     * @return the {@code X-Trino-User} header if non-empty; otherwise the user part of a
     *     {@code Basic} {@code Authorization} header; otherwise {@code ""}
     */
    static String getQueryUser(HttpServletRequest request) {
        String user = request.getHeader(USER_HEADER);
        if (!Strings.isNullOrEmpty(user)) {
            log.info("user from {}", USER_HEADER);
            return user;
        }
        log.info("user from basic auth");
        String authorization = request.getHeader(AUTHORIZATION);
        if (authorization == null) {
            log.error("didn't find any basic auth header");
            return "";
        }
        int space = authorization.indexOf(' ');
        if (space < 0 || !authorization.substring(0, space).equalsIgnoreCase("basic")) {
            log.error("basic auth format is incorrect");
            return "";
        }
        String encoded = authorization.substring(space + 1).trim();
        if (Strings.isNullOrEmpty(encoded)) {
            log.error("The encoded value of basic auth doesn't exist");
            return "";
        }
        String decoded;
        try {
            decoded = new String(Base64.getDecoder().decode(encoded), StandardCharsets.UTF_8);
        } catch (IllegalArgumentException e) {
            // Malformed Base64 is a client error like the other malformed-header cases: no user.
            log.error("The encoded value of basic auth is not valid Base64", e);
            return "";
        }
        // Credentials are "user:password"; everything before the first ':' is the user.
        List<String> credentials = Splitter.on(':').limit(2).splitToList(decoded);
        if (!credentials.isEmpty()) {
            return credentials.get(0);
        }
        log.error("no user inside the basic auth text");
        return "";
    }

    /**
     * Runs before the request is proxied.
     *
     * <p>For {@code POST /v1/statement} it marks the request meter and logs the endpoint, the
     * payload and the headers. For whitelisted paths it rewrites the outgoing {@code Host} header
     * to match the chosen backend.
     */
    public void preConnectionHook(HttpServletRequest request, Request proxyRequest) {
        String requestPath = request.getRequestURI();
        if ("POST".equals(request.getMethod()) && requestPath.startsWith(V1_STATEMENT_PATH)) {
            this.requestMeter.mark();
            try {
                log.info("Processing request endpoint: [{}], payload: [{}]", requestPath, CharStreams.toString(request.getReader()));
                debugLogHeaders(request);
            } catch (Exception e) {
                log.warn("Error fetching the request payload", e);
            }
        }
        if (isPathWhiteListed(requestPath)) {
            setForwardedHostHeaderOnProxyRequest(request, proxyRequest);
        }
    }

    /** Returns whether {@code path} starts with a built-in Trino path or a configured extra prefix. */
    private boolean isPathWhiteListed(String path) {
        return path.startsWith(V1_STATEMENT_PATH)
                || path.startsWith(V1_QUERY_PATH)
                || path.startsWith(TRINO_UI_PATH)
                || path.startsWith(V1_INFO_PATH)
                || path.startsWith(V1_NODE_PATH)
                || path.startsWith(UI_API_STATS_PATH)
                || path.startsWith(OAUTH_PATH)
                || this.extraWhitelistPaths.stream().anyMatch(path::startsWith);
    }

    /** Authentication hook; always {@code false} in this handler. */
    public boolean isAuthEnabled() {
        return false;
    }

    /** Authentication hook; accepts every request in this handler. */
    public boolean handleAuthRequest(HttpServletRequest httpServletRequest) {
        return true;
    }

    /**
     * Computes the URL the request is proxied to.
     *
     * <p>Non-whitelisted paths go to {@code http://localhost:<serverApplicationPort>}. Whitelisted
     * paths are routed as follows and the chosen backend is recorded in {@link #PROXY_TARGET_HEADER}:
     * <ul>
     *   <li>to the backend remembered for their query id, if one is found;</li>
     *   <li>otherwise via the selected routing group;</li>
     *   <li>otherwise to an ad-hoc backend.</li>
     * </ul>
     *
     * @return the target URL including path and query, or {@code null} if authentication is enabled
     *     and rejects the request
     */
    public String rewriteTarget(HttpServletRequest request) {
        String backendAddress = "http://localhost:" + this.serverApplicationPort;
        if (isPathWhiteListed(request.getRequestURI())) {
            String queryId = extractQueryIdIfPresent(request);
            if (Strings.isNullOrEmpty(queryId)) {
                String routingGroup = this.routingGroupSelector.findRoutingGroup(request);
                String user = request.getHeader(USER_HEADER);
                backendAddress = Strings.isNullOrEmpty(routingGroup)
                        ? this.routingManager.provideAdhocBackend(user)
                        : this.routingManager.provideBackendForRoutingGroup(routingGroup, user);
            } else {
                backendAddress = this.routingManager.findBackendForQueryId(queryId);
            }
            // NOTE(review): assumes every request reaching this handler is wrapped in MultiReadHttpServletRequest.
            ((MultiReadHttpServletRequest) request).addHeader(PROXY_TARGET_HEADER, backendAddress);
        }
        if (isAuthEnabled() && request.getHeader(AUTHORIZATION) != null && !handleAuthRequest(request)) {
            log.info("Could not authenticate Request: {}", request);
            return null;
        }
        String pathAndQuery = pathWithQuery(request);
        String targetUrl = backendAddress + pathAndQuery;
        String incomingUrl = request.getScheme() + "://" + request.getRemoteHost() + ":" + request.getServerPort() + pathAndQuery;
        log.info("Rerouting [{}]--> [{}]", incomingUrl, targetUrl);
        return targetUrl;
    }

    /** Returns the request URI followed by {@code ?<query>} when the request has a query string. */
    private static String pathWithQuery(HttpServletRequest request) {
        String query = request.getQueryString();
        return query == null ? request.getRequestURI() : request.getRequestURI() + "?" + query;
    }

    /**
     * Runs after the backend responds.
     *
     * <p>For {@code POST /v1/statement}, it decodes the response bytes
     * {@code buffer[offset, offset + length)} (gunzipping if the response is gzip-encoded). On
     * HTTP 200 it stores the mapping from the returned {@code id} to the backend. In every case it
     * submits a query-history entry. Errors are logged and never prevent the call to
     * {@code super.postConnectionHook}.
     */
    protected void postConnectionHook(HttpServletRequest request, HttpServletResponse response, byte[] buffer, int offset, int length, Callback callback) {
        try {
            String requestPath = request.getRequestURI();
            if (requestPath.startsWith(V1_STATEMENT_PATH) && "POST".equals(request.getMethod())) {
                // Only [offset, offset + length) is response content; the array itself may be larger.
                byte[] content = Arrays.copyOfRange(buffer, offset, offset + length);
                String output = isGZipEncoding(response) ? plainTextFromGz(content) : new String(content, StandardCharsets.UTF_8);
                log.debug("For Request [{}] got Response output [{}]", requestPath, output);
                QueryHistoryManager.QueryDetail queryDetail = getQueryDetailsFromRequest(request);
                log.debug("Extracting Proxy destination : [{}] for request : [{}]", queryDetail.getBackendUrl(), requestPath);
                if (response.getStatus() == HttpServletResponse.SC_OK) {
                    Map<?, ?> responseBody = OBJECT_MAPPER.readValue(output, Map.class);
                    queryDetail.setQueryId((String) responseBody.get("id"));
                    if (Strings.isNullOrEmpty(queryDetail.getQueryId())) {
                        log.debug("QueryId [{}] could not be cached", queryDetail.getQueryId());
                    } else {
                        this.routingManager.setBackendForQueryId(queryDetail.getQueryId(), queryDetail.getBackendUrl());
                        log.debug("QueryId [{}] mapped with proxy [{}]", queryDetail.getQueryId(), queryDetail.getBackendUrl());
                    }
                } else {
                    log.error("Non OK HTTP Status code with response [{}] , Status code [{}]", output, response.getStatus());
                }
                this.queryHistoryManager.submitQueryDetail(queryDetail);
            } else {
                log.debug("SKIPPING For {}", requestPath);
            }
        } catch (Exception e) {
            log.error("Error in proxying falling back to super call", e);
        }
        super.postConnectionHook(request, response, buffer, offset, length, callback);
    }

    /**
     * Builds a query-history entry from a request.
     *
     * <p>The entry holds the backend URL, the capture time, the user, the source header and the
     * query text truncated to {@value #QUERY_TEXT_LENGTH_FOR_HISTORY} characters plus {@code "..."}.
     */
    private QueryHistoryManager.QueryDetail getQueryDetailsFromRequest(HttpServletRequest request) throws IOException {
        QueryHistoryManager.QueryDetail queryDetail = new QueryHistoryManager.QueryDetail();
        queryDetail.setBackendUrl(request.getHeader(PROXY_TARGET_HEADER));
        queryDetail.setCaptureTime(System.currentTimeMillis());
        queryDetail.setUser(getQueryUser(request));
        queryDetail.setSource(request.getHeader(SOURCE_HEADER));
        String queryText = CharStreams.toString(request.getReader());
        queryDetail.setQueryText(queryText.length() > QUERY_TEXT_LENGTH_FOR_HISTORY
                ? queryText.substring(0, QUERY_TEXT_LENGTH_FOR_HISTORY) + "..."
                : queryText);
        return queryDetail;
    }
}
