package io.trino.server.security.oauth2;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Ordering;
import com.google.common.hash.Hashing;
import com.google.inject.Inject;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
import com.nimbusds.jose.proc.BadJOSEException;
import com.nimbusds.jose.proc.JWEKeySelector;
import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier;
import com.nimbusds.jwt.proc.DefaultJWTProcessor;
import com.nimbusds.jwt.proc.JWTProcessor;
import com.nimbusds.oauth2.sdk.AccessTokenResponse;
import com.nimbusds.oauth2.sdk.AuthorizationCode;
import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant;
import com.nimbusds.oauth2.sdk.AuthorizationRequest;
import com.nimbusds.oauth2.sdk.ParseException;
import com.nimbusds.oauth2.sdk.RefreshTokenGrant;
import com.nimbusds.oauth2.sdk.ResponseType;
import com.nimbusds.oauth2.sdk.Scope;
import com.nimbusds.oauth2.sdk.TokenRequest;
import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic;
import com.nimbusds.oauth2.sdk.auth.Secret;
import com.nimbusds.oauth2.sdk.id.ClientID;
import com.nimbusds.oauth2.sdk.id.Issuer;
import com.nimbusds.oauth2.sdk.id.State;
import com.nimbusds.oauth2.sdk.token.AccessToken;
import com.nimbusds.oauth2.sdk.token.BearerAccessToken;
import com.nimbusds.oauth2.sdk.token.RefreshToken;
import com.nimbusds.oauth2.sdk.token.Tokens;
import com.nimbusds.openid.connect.sdk.AuthenticationRequest;
import com.nimbusds.openid.connect.sdk.Nonce;
import com.nimbusds.openid.connect.sdk.OIDCScopeValue;
import com.nimbusds.openid.connect.sdk.OIDCTokenResponse;
import com.nimbusds.openid.connect.sdk.UserInfoRequest;
import com.nimbusds.openid.connect.sdk.UserInfoResponse;
import com.nimbusds.openid.connect.sdk.claims.AccessTokenHash;
import com.nimbusds.openid.connect.sdk.token.OIDCTokens;
import com.nimbusds.openid.connect.sdk.validators.AccessTokenValidator;
import com.nimbusds.openid.connect.sdk.validators.IDTokenValidator;
import com.nimbusds.openid.connect.sdk.validators.InvalidHashException;
import io.airlift.log.Logger;
import io.airlift.units.Duration;
import io.trino.server.security.oauth2.NimbusHttpClient;
import io.trino.server.security.oauth2.OAuth2Client;
import io.trino.server.security.oauth2.OAuth2ServerConfigProvider;
import java.net.MalformedURLException;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.Date;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;

/**
 * {@link OAuth2Client} implemented on top of the Nimbus OAuth 2.0 / OpenID Connect SDK.
 *
 * <p>Server endpoints and the JWKS-backed key selector are resolved by {@link #load()}. Every other
 * {@code OAuth2Client} method fails with {@link IllegalStateException} until {@code load()} has completed.
 * When the configured scopes contain {@code openid}, the OpenID Connect variant of the authorization code
 * flow is used (nonce + ID token validation); otherwise the plain OAuth 2.0 authorization code flow is used.
 */
public class NimbusOAuth2Client implements OAuth2Client {
    private static final Logger LOG = Logger.get(NimbusOAuth2Client.class);

    private final Issuer issuer;
    private final ClientID clientId;
    private final ClientSecretBasic clientAuth;
    private final Scope scope;
    private final String principalField;
    // Audiences accepted in access tokens; intentionally contains a null entry (see the constructor).
    private final Set<String> accessTokenAudiences;
    private final Duration maxClockSkew;
    private final NimbusHttpClient httpClient;
    private final OAuth2ServerConfigProvider serverConfigurationProvider;

    // Written last by load(). The volatile write/read pair publishes the non-final fields below
    // to any thread that observes loaded == true.
    private volatile boolean loaded;
    private URI authUrl;
    private URI tokenUrl;
    private Optional<URI> userinfoUrl;
    private JWSKeySelector<SecurityContext> jwsKeySelector;
    private JWTProcessor<SecurityContext> accessTokenProcessor;
    private AuthorizationCodeFlow flow;

    /**
     * Strategy for the authorization code flow; selected in {@link #load()} depending on whether
     * the {@code openid} scope is requested.
     */
    private interface AuthorizationCodeFlow {
        OAuth2Client.Request createAuthorizationRequest(String state, URI callbackUri);

        OAuth2Client.Response getOAuth2Response(String code, URI callbackUri, Optional<String> nonce) throws ChallengeFailedException;

        OAuth2Client.Response refreshTokens(String refreshToken) throws ChallengeFailedException;
    }

    /**
     * Plain OAuth 2.0 authorization code flow: no nonce and no ID token.
     */
    private class OAuth2AuthorizationCodeFlow implements AuthorizationCodeFlow {
        @Override
        public OAuth2Client.Request createAuthorizationRequest(String state, URI callbackUri) {
            URI authorizationUri = new AuthorizationRequest.Builder(ResponseType.CODE, clientId)
                    .redirectionURI(callbackUri)
                    .scope(scope)
                    .endpointURI(authUrl)
                    .state(new State(state))
                    .build()
                    .toURI();
            return new OAuth2Client.Request(authorizationUri, Optional.empty());
        }

        @Override
        public OAuth2Client.Response getOAuth2Response(String code, URI callbackUri, Optional<String> nonce) throws ChallengeFailedException {
            Preconditions.checkArgument(nonce.isEmpty(), "Unexpected nonce provided");
            Tokens tokens = getTokenResponse(code, callbackUri, AccessTokenResponse::parse).toSuccessResponse().getTokens();
            return buildResponse(tokens.getAccessToken(), tokens.getRefreshToken(), Optional.empty(), Optional.empty());
        }

        @Override
        public OAuth2Client.Response refreshTokens(String refreshToken) throws ChallengeFailedException {
            Objects.requireNonNull(refreshToken, "refreshToken is null");
            Tokens tokens = getTokenResponse(refreshToken, AccessTokenResponse::parse).toSuccessResponse().getTokens();
            // Keep the caller's refresh token if the server does not rotate it
            return buildResponse(tokens.getAccessToken(), tokens.getRefreshToken(), Optional.empty(), Optional.of(refreshToken));
        }
    }

    /**
     * OAuth 2.0 authorization code flow with OpenID Connect extensions: a nonce is sent with the
     * authorization request and the returned ID token is validated.
     */
    private class OAuth2WithOidcExtensionsCodeFlow implements AuthorizationCodeFlow {
        private final IDTokenValidator idTokenValidator;

        private OAuth2WithOidcExtensionsCodeFlow() {
            // No JWE key selector: encrypted ID tokens are not configured for decryption
            idTokenValidator = new IDTokenValidator(issuer, clientId, jwsKeySelector, (JWEKeySelector<SecurityContext>) null);
            idTokenValidator.setMaxClockSkew(maxClockSkewSeconds());
        }

        @Override
        public OAuth2Client.Request createAuthorizationRequest(String state, URI callbackUri) {
            // The raw nonce is returned to the caller; only its SHA-256 hash is sent to the server
            String nonce = new Nonce().getValue();
            URI authorizationUri = new AuthenticationRequest.Builder(ResponseType.CODE, scope, clientId, callbackUri)
                    .endpointURI(authUrl)
                    .state(new State(state))
                    .nonce(new Nonce(hashNonce(nonce)))
                    .build()
                    .toURI();
            return new OAuth2Client.Request(authorizationUri, Optional.of(nonce));
        }

        @Override
        public OAuth2Client.Response getOAuth2Response(String code, URI callbackUri, Optional<String> nonce) throws ChallengeFailedException {
            if (nonce.isEmpty()) {
                throw new ChallengeFailedException("Missing nonce");
            }
            OIDCTokens tokens = getTokenResponse(code, callbackUri, OIDCTokenResponse::parse).getOIDCTokens();
            validateTokens(tokens, nonce);
            return buildResponse(tokens.getAccessToken(), tokens.getRefreshToken(), Optional.ofNullable(tokens.getIDTokenString()), Optional.empty());
        }

        @Override
        public OAuth2Client.Response refreshTokens(String refreshToken) throws ChallengeFailedException {
            Objects.requireNonNull(refreshToken, "refreshToken is null");
            OIDCTokens tokens = getTokenResponse(refreshToken, OIDCTokenResponse::parse).getOIDCTokens();
            // No nonce is available when refreshing
            validateTokens(tokens, Optional.empty());
            return buildResponse(tokens.getAccessToken(), tokens.getRefreshToken(), Optional.ofNullable(tokens.getIDTokenString()), Optional.of(refreshToken));
        }

        /**
         * Validates the ID token (and, if it carries an access-token hash, binds the access token to it).
         *
         * @param tokens the tokens returned by the token endpoint
         * @param nonce the raw nonce issued with the authorization request, if any
         * @throws ChallengeFailedException if any validation step fails
         */
        private void validateTokens(OIDCTokens tokens, Optional<String> nonce) throws ChallengeFailedException {
            Nonce expectedNonce = nonce.map(this::hashNonce).map(Nonce::new).orElse(null);
            try {
                AccessTokenHash accessTokenHash = idTokenValidator.validate(tokens.getIDToken(), expectedNonce).getAccessTokenHash();
                if (accessTokenHash == null) {
                    // The ID token carries no access-token hash, so there is nothing to bind
                    return;
                }
                // Header.getAlgorithm() returns the generic Algorithm; the hash check needs the JWS algorithm
                if (!(tokens.getIDToken().getHeader() instanceof JWSHeader)) {
                    throw new ChallengeFailedException("Cannot validate tokens: ID token is not JWS-signed");
                }
                JWSAlgorithm signingAlgorithm = ((JWSHeader) tokens.getIDToken().getHeader()).getAlgorithm();
                AccessTokenValidator.validate(tokens.getAccessToken(), signingAlgorithm, accessTokenHash);
            }
            catch (BadJOSEException | JOSEException | InvalidHashException e) {
                throw new ChallengeFailedException("Cannot validate tokens", e);
            }
        }

        /**
         * Returns the lowercase hex SHA-256 of the UTF-8 encoded nonce.
         */
        private String hashNonce(String nonce) {
            return Hashing.sha256().hashString(nonce, StandardCharsets.UTF_8).toString();
        }
    }

    @Inject
    public NimbusOAuth2Client(OAuth2Config oAuth2Config, OAuth2ServerConfigProvider oAuth2ServerConfigProvider, NimbusHttpClient nimbusHttpClient) {
        this.issuer = new Issuer(oAuth2Config.getIssuer());
        this.clientId = new ClientID(oAuth2Config.getClientId());
        this.clientAuth = new ClientSecretBasic(clientId, new Secret(oAuth2Config.getClientSecret()));
        this.scope = Scope.parse(oAuth2Config.getScopes());
        this.principalField = oAuth2Config.getPrincipalField();
        this.maxClockSkew = oAuth2Config.getMaxClockSkew();

        Set<String> audiences = new HashSet<>(oAuth2Config.getAdditionalAudiences());
        audiences.add(clientId.getValue());
        // Deliberate: per the Nimbus DefaultJWTClaimsVerifier contract, a null entry in the accepted
        // audience set lets access tokens without an "aud" claim pass verification.
        audiences.add(null);
        this.accessTokenAudiences = audiences;

        this.serverConfigurationProvider = Objects.requireNonNull(oAuth2ServerConfigProvider, "serverConfigurationProvider is null");
        this.httpClient = Objects.requireNonNull(nimbusHttpClient, "httpClient is null");
    }

    /**
     * Fetches the server configuration and builds the key selector, access token processor and flow.
     *
     * @throws RuntimeException if the configured JWKS URI cannot be converted to a URL
     */
    @Override
    public void load() {
        OAuth2ServerConfigProvider.OAuth2ServerConfig serverConfig = serverConfigurationProvider.get();
        this.authUrl = serverConfig.getAuthUrl();
        this.tokenUrl = serverConfig.getTokenUrl();
        this.userinfoUrl = serverConfig.getUserinfoUrl();
        this.jwsKeySelector = createJwsKeySelector(serverConfig.getJwksUrl());
        this.accessTokenProcessor = createAccessTokenProcessor(
                jwsKeySelector,
                serverConfig.getAccessTokenIssuer().orElse(issuer.getValue()));
        // The OIDC flow reads jwsKeySelector in its constructor, so it must be created after it
        this.flow = scope.contains(OIDCScopeValue.OPENID) ? new OAuth2WithOidcExtensionsCodeFlow() : new OAuth2AuthorizationCodeFlow();
        // Must stay last: publishes every field assigned above
        this.loaded = true;
    }

    /**
     * Creates a key selector accepting RSA and EC signatures with keys fetched from the given JWKS endpoint.
     */
    private JWSKeySelector<SecurityContext> createJwsKeySelector(URI jwksUrl) {
        Set<JWSAlgorithm> algorithms = Stream.concat(JWSAlgorithm.Family.RSA.stream(), JWSAlgorithm.Family.EC.stream())
                .collect(ImmutableSet.toImmutableSet());
        try {
            return new JWSVerificationKeySelector<>(algorithms, new RemoteJWKSet<>(jwksUrl.toURL(), httpClient));
        }
        catch (MalformedURLException e) {
            throw new RuntimeException("Invalid JWKS URL: " + jwksUrl, e);
        }
    }

    /**
     * Creates the processor used to verify JWT access tokens.
     *
     * <p>The processor checks the signature and the accepted audiences, requires the issuer to equal
     * {@code expectedIssuer}, and requires the principal field claim to be present.
     */
    private JWTProcessor<SecurityContext> createAccessTokenProcessor(JWSKeySelector<SecurityContext> keySelector, String expectedIssuer) {
        DefaultJWTClaimsVerifier<SecurityContext> claimsVerifier = new DefaultJWTClaimsVerifier<>(
                accessTokenAudiences,
                new JWTClaimsSet.Builder().issuer(expectedIssuer).build(),
                ImmutableSet.of(principalField),
                ImmutableSet.of());
        claimsVerifier.setMaxClockSkew(maxClockSkewSeconds());

        DefaultJWTProcessor<SecurityContext> processor = new DefaultJWTProcessor<>();
        processor.setJWSKeySelector(keySelector);
        processor.setJWTClaimsSetVerifier(claimsVerifier);
        return processor;
    }

    @Override
    public OAuth2Client.Request createAuthorizationRequest(String state, URI callbackUri) {
        Preconditions.checkState(loaded, "OAuth2 client not initialized");
        return flow.createAuthorizationRequest(state, callbackUri);
    }

    @Override
    public OAuth2Client.Response getOAuth2Response(String code, URI callbackUri, Optional<String> nonce) throws ChallengeFailedException {
        Preconditions.checkState(loaded, "OAuth2 client not initialized");
        return flow.getOAuth2Response(code, callbackUri, nonce);
    }

    /**
     * Returns the claims for the access token: from the userinfo endpoint when one is configured,
     * otherwise from verifying the token as a JWT. Empty if the lookup or verification fails.
     */
    @Override
    public Optional<Map<String, Object>> getClaims(String accessToken) {
        Preconditions.checkState(loaded, "OAuth2 client not initialized");
        return getJWTClaimsSet(accessToken).map(JWTClaimsSet::getClaims);
    }

    @Override
    public OAuth2Client.Response refreshTokens(String refreshToken) throws ChallengeFailedException {
        Preconditions.checkState(loaded, "OAuth2 client not initialized");
        return flow.refreshTokens(refreshToken);
    }

    /**
     * Converts tokens returned by the token endpoint into an {@link OAuth2Client.Response}.
     *
     * @param accessToken the returned access token
     * @param refreshToken the returned refresh token, or {@code null} if none was issued
     * @param idToken the serialized ID token, if any
     * @param previousRefreshToken refresh token to keep when the server did not issue a new one
     * @throws ChallengeFailedException if no claims can be obtained for the access token, or no expiration can be determined
     */
    private OAuth2Client.Response buildResponse(AccessToken accessToken, RefreshToken refreshToken, Optional<String> idToken, Optional<String> previousRefreshToken)
            throws ChallengeFailedException {
        // Evaluated before the claims lookup, which may call the userinfo endpoint
        Optional<Instant> lifetimeExpiration = getExpiration(accessToken);
        JWTClaimsSet claims = getJWTClaimsSet(accessToken.getValue())
                .orElseThrow(() -> new ChallengeFailedException("invalid access token"));
        Instant expiration = determineExpiration(lifetimeExpiration, claims.getExpirationTime());
        Optional<String> refreshTokenValue = Optional.ofNullable(refreshToken)
                .map(RefreshToken::getValue)
                .or(() -> previousRefreshToken);
        return new OAuth2Client.Response(accessToken.getValue(), expiration, idToken, refreshTokenValue);
    }

    /**
     * Exchanges an authorization code for tokens.
     */
    private <T extends AccessTokenResponse> T getTokenResponse(String code, URI callbackUri, NimbusHttpClient.Parser<T> parser) throws ChallengeFailedException {
        return getTokenResponse(new TokenRequest(tokenUrl, clientAuth, new AuthorizationCodeGrant(new AuthorizationCode(code), callbackUri)), parser);
    }

    /**
     * Exchanges a refresh token for new tokens, requesting the configured scope.
     */
    private <T extends AccessTokenResponse> T getTokenResponse(String refreshToken, NimbusHttpClient.Parser<T> parser) throws ChallengeFailedException {
        return getTokenResponse(new TokenRequest(tokenUrl, clientAuth, new RefreshTokenGrant(new RefreshToken(refreshToken)), scope), parser);
    }

    /**
     * Sends the token request and returns the parsed response.
     *
     * @throws ChallengeFailedException if the response does not indicate success
     */
    private <T extends AccessTokenResponse> T getTokenResponse(TokenRequest tokenRequest, NimbusHttpClient.Parser<T> parser) throws ChallengeFailedException {
        T response = httpClient.execute(tokenRequest, parser);
        if (!response.indicatesSuccess()) {
            throw new ChallengeFailedException("Error while fetching access token: " + response.toErrorResponse().toHTTPResponse().getContent());
        }
        return response;
    }

    private Optional<JWTClaimsSet> getJWTClaimsSet(String accessToken) {
        return userinfoUrl.isPresent() ? queryUserInfo(accessToken) : parseAccessToken(accessToken);
    }

    /**
     * Fetches claims from the userinfo endpoint. Failures are logged and reported as empty.
     */
    private Optional<JWTClaimsSet> queryUserInfo(String accessToken) {
        try {
            UserInfoResponse response = httpClient.execute(new UserInfoRequest(userinfoUrl.orElseThrow(), new BearerAccessToken(accessToken)), UserInfoResponse::parse);
            if (!response.indicatesSuccess()) {
                LOG.error("Received bad response from userinfo endpoint: %s", response.toErrorResponse().getErrorObject());
                return Optional.empty();
            }
            return Optional.of(response.toSuccessResponse().getUserInfo().toJWTClaimsSet());
        }
        catch (ParseException | RuntimeException e) {
            // Best effort: a failed lookup is treated as "no claims" rather than propagated
            LOG.error(e, "Received bad response from userinfo endpoint");
            return Optional.empty();
        }
    }

    /**
     * Verifies the access token as a JWT and returns its claims. Failures are logged and reported as empty.
     */
    private Optional<JWTClaimsSet> parseAccessToken(String accessToken) {
        try {
            return Optional.of(accessTokenProcessor.process(accessToken, null));
        }
        catch (java.text.ParseException | BadJOSEException | JOSEException e) {
            LOG.error(e, "Failed to parse JWT access token");
            return Optional.empty();
        }
    }

    /**
     * Picks the earlier of the token-lifetime expiration and the JWT {@code exp} claim, using whichever is present.
     *
     * @param lifetimeExpiration expiration derived from the token response lifetime, if any
     * @param claimExpiration the {@code exp} claim, or {@code null}
     * @throws ChallengeFailedException if neither is available
     */
    private static Instant determineExpiration(Optional<Instant> lifetimeExpiration, Date claimExpiration) throws ChallengeFailedException {
        if (lifetimeExpiration.isPresent()) {
            if (claimExpiration == null) {
                return lifetimeExpiration.get();
            }
            return Ordering.<Instant>natural().min(lifetimeExpiration.get(), claimExpiration.toInstant());
        }
        if (claimExpiration != null) {
            return claimExpiration.toInstant();
        }
        throw new ChallengeFailedException("no valid expiration date");
    }

    /**
     * Returns now plus the token response lifetime, or empty when the lifetime is 0.
     */
    private static Optional<Instant> getExpiration(AccessToken accessToken) {
        return accessToken.getLifetime() != 0 ? Optional.of(Instant.now().plusSeconds(accessToken.getLifetime())) : Optional.empty();
    }

    private int maxClockSkewSeconds() {
        return (int) maxClockSkew.roundTo(TimeUnit.SECONDS);
    }
}
