package io.trino.gateway.ha.security;

import com.auth0.jwk.UrlJwkProvider;
import com.auth0.jwt.JWT;
import com.auth0.jwt.interfaces.Claim;
import com.auth0.jwt.interfaces.DecodedJWT;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.trino.gateway.ha.config.OAuthConfiguration;
import jakarta.ws.rs.client.Client;
import jakarta.ws.rs.client.ClientBuilder;
import jakarta.ws.rs.client.Entity;
import jakarta.ws.rs.core.Form;
import jakarta.ws.rs.core.NewCookie;
import jakarta.ws.rs.core.Response;
import java.net.URI;
import java.net.URL;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.security.interfaces.RSAPublicKey;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Gateway side of the OAuth 2.0 authorization-code flow: redirects users to the configured
 * authorization endpoint, exchanges the returned code for tokens at the token endpoint, and
 * validates ID tokens using keys fetched from the configured JWK endpoint.
 */
public class LbOAuthManager {
    private static final Logger log = LoggerFactory.getLogger(LbOAuthManager.class);
    private final OAuthConfiguration oauthConfig;

    /**
     * Body of a successful token-endpoint response. Unknown JSON properties are ignored; only the
     * ID token is consumed by {@link LbOAuthManager}.
     */
    @JsonIgnoreProperties(ignoreUnknown = true)
    static final class OidcTokens {

        @JsonProperty
        private final String accessToken;

        @JsonProperty
        private final String idToken;

        @JsonProperty
        private final String scope;

        @JsonProperty
        private final String refreshToken;

        @JsonProperty
        private final String tokenType;

        @JsonProperty
        private final String expiresIn;

        /** Jackson creator mapping the snake_case token-response properties to fields. */
        @JsonCreator
        public OidcTokens(
                @JsonProperty("id_token") String idToken,
                @JsonProperty("access_token") String accessToken,
                @JsonProperty("refresh_token") String refreshToken,
                @JsonProperty("token_type") String tokenType,
                @JsonProperty("expires_in") String expiresIn,
                @JsonProperty("scope") String scope) {
            this.accessToken = accessToken;
            this.idToken = idToken;
            this.tokenType = tokenType;
            this.expiresIn = expiresIn;
            this.scope = scope;
            this.refreshToken = refreshToken;
        }

        public String getAccessToken() {
            return this.accessToken;
        }

        public String getIdToken() {
            return this.idToken;
        }

        public String getScope() {
            return this.scope;
        }

        public String getRefreshToken() {
            return this.refreshToken;
        }

        public String getTokenType() {
            return this.tokenType;
        }

        public String getExpiresIn() {
            return this.expiresIn;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof OidcTokens)) {
                return false;
            }
            OidcTokens other = (OidcTokens) obj;
            return Objects.equals(accessToken, other.accessToken)
                    && Objects.equals(idToken, other.idToken)
                    && Objects.equals(scope, other.scope)
                    && Objects.equals(refreshToken, other.refreshToken)
                    && Objects.equals(tokenType, other.tokenType)
                    && Objects.equals(expiresIn, other.expiresIn);
        }

        @Override
        public int hashCode() {
            return Objects.hash(accessToken, idToken, scope, refreshToken, tokenType, expiresIn);
        }

        /**
         * Describes this token set without exposing credentials: token values are redacted so
         * that logging this object never leaks access, ID, or refresh tokens.
         */
        @Override
        public String toString() {
            return "LbOAuthManager.OidcTokens(accessToken=" + redact(accessToken)
                    + ", idToken=" + redact(idToken)
                    + ", scope=" + scope
                    + ", refreshToken=" + redact(refreshToken)
                    + ", tokenType=" + tokenType
                    + ", expiresIn=" + expiresIn + ")";
        }

        /** Returns {@code "null"} for an absent value, otherwise a fixed placeholder. */
        private static String redact(String secret) {
            return secret == null ? "null" : "<redacted>";
        }
    }

    public LbOAuthManager(OAuthConfiguration oAuthConfiguration) {
        this.oauthConfig = oAuthConfiguration;
    }

    /** Returns the configured name of the claim used as the user id. */
    public String getUserIdField() {
        return this.oauthConfig.getUserIdField();
    }

    /**
     * Exchanges an authorization code for tokens at the configured token endpoint.
     *
     * @param code the authorization code returned by the authorization server
     * @param redirectLocation where to send the user once the exchange succeeds
     * @return a 302 redirect to {@code redirectLocation} carrying the session token cookie built
     *     from the ID token; or a 500 response describing the failure when the token endpoint
     *     answers with a non-2xx status or its response contains no ID token
     */
    public Response exchangeCodeForToken(String code, String redirectLocation) {
        Form form = new Form()
                .param("grant_type", "authorization_code")
                .param("client_id", this.oauthConfig.getClientId())
                .param("client_secret", this.oauthConfig.getClientSecret())
                .param("code", code)
                .param("redirect_uri", this.oauthConfig.getRedirectUrl());

        // A fresh client per exchange; it must be closed or its connections/resources leak.
        // Closing the client also releases the token response.
        Client client = ClientBuilder.newBuilder().build();
        try {
            Response tokenResponse = client.target(this.oauthConfig.getTokenEndpoint())
                    .request()
                    .post(Entity.form(form));

            if (tokenResponse.getStatusInfo().getFamily() != Response.Status.Family.SUCCESSFUL) {
                String message = String.format("token response failed with code %d - %s",
                        tokenResponse.getStatus(), tokenResponse.readEntity(String.class));
                log.error(message);
                return Response.status(500).entity(message).build();
            }

            String idToken = tokenResponse.readEntity(OidcTokens.class).getIdToken();
            if (idToken == null) {
                // Without an ID token there is nothing to put in the session cookie.
                String message = "token response did not contain an id_token";
                log.error(message);
                return Response.status(500).entity(message).build();
            }

            NewCookie tokenCookie = SessionCookie.getTokenCookie(idToken);
            return Response.status(302)
                    .location(URI.create(redirectLocation))
                    .cookie(tokenCookie)
                    .build();
        } finally {
            client.close();
        }
    }

    /**
     * Builds the redirect that starts the authorization-code flow.
     *
     * @return a 302 response pointing at the configured authorization endpoint with
     *     {@code client_id}, {@code response_type=code}, {@code redirect_uri} and {@code scope}
     *     query parameters
     */
    public Response getAuthorizationCode() {
        // Query values are form-URL-encoded so a redirect URL containing '?', '&', spaces, etc.
        // does not corrupt the query string. Scopes are space-separated, which encodes to '+'.
        String url = String.format(
                "%s?client_id=%s&response_type=code&redirect_uri=%s&scope=%s",
                this.oauthConfig.getAuthorizationEndpoint(),
                urlEncode(this.oauthConfig.getClientId()),
                urlEncode(this.oauthConfig.getRedirectUrl()),
                urlEncode(String.join(" ", this.oauthConfig.getScopes())));
        return Response.status(302).location(URI.create(url)).build();
    }

    /** Form-URL-encodes {@code value} as UTF-8; a {@code null} value encodes to an empty string. */
    private static String urlEncode(String value) {
        return value == null ? "" : URLEncoder.encode(value, StandardCharsets.UTF_8);
    }

    /**
     * Decodes and validates an ID token and returns its claims.
     *
     * @param idToken the serialized JWT
     * @return the token's claims if it decodes and {@code LbTokenUtil.validateToken} accepts it
     *     with the RSA key (looked up by the token's key id) from the configured JWK endpoint;
     *     otherwise empty. Failures are logged, not thrown.
     */
    public Optional<Map<String, Claim>> getClaimsFromIdToken(String idToken) {
        if (idToken == null || idToken.isBlank()) {
            // Nothing to validate; skip decoding rather than logging a decode failure.
            return Optional.empty();
        }
        try {
            DecodedJWT decoded = JWT.decode(idToken);
            RSAPublicKey publicKey = (RSAPublicKey) new UrlJwkProvider(new URL(this.oauthConfig.getJwkEndpoint()))
                    .get(decoded.getKeyId())
                    .getPublicKey();
            // NOTE(review): the third argument (presumably the expected issuer) is read from the
            // token itself, so it cannot reject tokens from an unexpected issuer. Consider
            // validating against a configured issuer; confirm what LbTokenUtil checks.
            if (LbTokenUtil.validateToken(idToken, publicKey, decoded.getIssuer())) {
                return Optional.of(decoded.getClaims());
            }
        } catch (Exception e) {
            log.error("Could not validate token or get claims from it.", e);
        }
        return Optional.empty();
    }
}
