package io.trino.server.security.oauth2;

import com.nimbusds.jose.EncryptionMethod;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWEAlgorithm;
import com.nimbusds.jose.JWEHeader;
import com.nimbusds.jose.JWEObject;
import com.nimbusds.jose.KeyLengthException;
import com.nimbusds.jose.Payload;
import com.nimbusds.jose.crypto.AESDecrypter;
import com.nimbusds.jose.crypto.AESEncrypter;
import io.airlift.log.Logger;
import io.airlift.units.Duration;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.JwtBuilder;
import io.jsonwebtoken.JwtParser;
import io.jsonwebtoken.JwtParserBuilder;
import io.jsonwebtoken.io.CompressionAlgorithm;
import io.trino.server.security.jwt.JwtUtil;
import io.trino.server.security.oauth2.TokenPairSerializer;
import java.security.NoSuchAlgorithmException;
import java.text.ParseException;
import java.time.Clock;
import java.util.Date;
import java.util.Map;
import java.util.Objects;
import javax.crypto.KeyGenerator;
import javax.crypto.SecretKey;

/**
 * {@link TokenPairSerializer} that stores an OAuth2 token pair inside an encrypted token.
 *
 * <p>The access token, its expiration time and the optional refresh token are written as claims of
 * an unsigned, compressed JWT. That JWT also carries the issuer, the audience, its own expiration
 * and the configured principal claim. The JWT is then encrypted as a compact JWE using AES-GCM key
 * wrapping and AES-GCM content encryption.
 */
public class JweTokenSerializer implements TokenPairSerializer {
    private static final CompressionAlgorithm COMPRESSION_ALGORITHM = new ZstdCodec();
    private static final Logger LOG = Logger.get(JweTokenSerializer.class);
    private static final String ACCESS_TOKEN_KEY = "access_token";
    private static final String EXPIRATION_TIME_KEY = "expiration_time";
    private static final String REFRESH_TOKEN_KEY = "refresh_token";
    private final JweEncryptedSerializer jweSerializer;
    private final OAuth2Client client;
    private final Clock clock;
    private final String issuer;
    private final String audience;
    private final Duration tokenExpiration;
    private final JwtParser parser;
    private final String principalField;

    /**
     * Encrypts and decrypts strings as compact-serialized JWE objects using a single AES key.
     * The JWE key-wrap algorithm and content encryption method are chosen from the key length.
     */
    private static final class JweEncryptedSerializer {
        private final AESEncrypter jweEncrypter;
        private final AESDecrypter jweDecrypter;
        private final JWEHeader encryptionHeader;

        /**
         * @param secretKey AES key of 16, 24 or 32 bytes
         * @throws IllegalArgumentException if the key has an unsupported length
         */
        private JweEncryptedSerializer(SecretKey secretKey) {
            // Validates the key length (and picks matching algorithms) before the crypters are built.
            this.encryptionHeader = createEncryptionHeader(secretKey);
            try {
                this.jweEncrypter = new AESEncrypter(secretKey);
                this.jweDecrypter = new AESDecrypter(secretKey);
            } catch (KeyLengthException e) {
                throw new IllegalArgumentException("Invalid secret key length", e);
            }
        }

        /** Maps the key length in bytes to the matching AES-GCM key-wrap / AES-GCM header. */
        private static JWEHeader createEncryptionHeader(SecretKey secretKey) {
            int length = secretKey.getEncoded().length;
            return switch (length) {
                case 16 -> new JWEHeader(JWEAlgorithm.A128GCMKW, EncryptionMethod.A128GCM);
                case 24 -> new JWEHeader(JWEAlgorithm.A192GCMKW, EncryptionMethod.A192GCM);
                case 32 -> new JWEHeader(JWEAlgorithm.A256GCMKW, EncryptionMethod.A256GCM);
                default -> throw new IllegalArgumentException(
                        "Secret key size must be either 16, 24 or 32 bytes but was %d".formatted(length));
            };
        }

        /**
         * Encrypts {@code plaintext} and returns the compact JWE serialization.
         *
         * @throws IllegalStateException if encryption fails
         */
        private String serialize(String plaintext) {
            try {
                JWEObject jwe = new JWEObject(this.encryptionHeader, new Payload(plaintext));
                jwe.encrypt(this.jweEncrypter);
                return jwe.serialize();
            } catch (JOSEException e) {
                throw new IllegalStateException("Failed to encrypt token", e);
            }
        }

        /**
         * Parses and decrypts a compact JWE serialization, returning its payload as a string.
         *
         * @throws ParseException if {@code serialized} is not a JWE object
         * @throws IllegalArgumentException if decryption fails (for example, wrong key or tampered input)
         */
        private String deserialize(String serialized) throws ParseException {
            JWEObject jwe = JWEObject.parse(serialized);
            try {
                jwe.decrypt(this.jweDecrypter);
            } catch (JOSEException e) {
                throw new IllegalArgumentException("Failed to decrypt token", e);
            }
            return jwe.getPayload().toString();
        }
    }

    /**
     * Creates a serializer.
     *
     * <p>If {@code refreshTokensConfig} provides no secret key, a random 256-bit AES key is generated
     * here. This class does not persist it, so tokens produced with it can only be decrypted by this
     * instance.
     *
     * @param refreshTokensConfig source of the optional AES secret key
     * @param client used to read the claims of the access token when serializing
     * @param issuer issuer written to, and required in, the inner JWT
     * @param audience audience written to, and required in, the inner JWT
     * @param principalField claim copied from the access token's claims into the inner JWT
     * @param clock time source for expiration and validation
     * @param tokenExpiration lifetime of the serialized token
     */
    public JweTokenSerializer(
            RefreshTokensConfig refreshTokensConfig,
            OAuth2Client client,
            String issuer,
            String audience,
            String principalField,
            Clock clock,
            Duration tokenExpiration) {
        Objects.requireNonNull(refreshTokensConfig, "refreshTokensConfig is null");
        this.jweSerializer = new JweEncryptedSerializer(getOrGenerateKey(refreshTokensConfig));
        this.client = Objects.requireNonNull(client, "client is null");
        this.issuer = Objects.requireNonNull(issuer, "issuer is null");
        this.principalField = Objects.requireNonNull(principalField, "principalField is null");
        this.audience = Objects.requireNonNull(audience, "audience is null");
        this.clock = Objects.requireNonNull(clock, "clock is null");
        this.tokenExpiration = Objects.requireNonNull(tokenExpiration, "tokenExpiration is null");
        this.parser = JwtUtil.newJwtParserBuilder()
                .clock(() -> Date.from(clock.instant()))
                .requireIssuer(this.issuer)
                .requireAudience(this.audience)
                .zip().add(COMPRESSION_ALGORITHM).and()
                // The inner JWT is unsigned and compressed; it is only ever read after successful
                // AES-GCM (authenticated) decryption of the surrounding JWE.
                .unsecuredDecompression()
                .unsecured()
                .build();
    }

    /**
     * Decrypts and parses a token produced by {@link #serialize}.
     *
     * <p>If {@code token} cannot be parsed as a JWE object, it is returned unchanged as a bare
     * access token.
     *
     * @throws IllegalArgumentException if the token is a JWE that fails to decrypt
     */
    @Override
    public TokenPairSerializer.TokenPair deserialize(String token) {
        Objects.requireNonNull(token, "token is null");
        String jwt;
        try {
            jwt = this.jweSerializer.deserialize(token);
        } catch (ParseException e) {
            // Not produced by this serializer: treat the input as a plain access token.
            return TokenPairSerializer.TokenPair.withAccessToken(token);
        }
        // Issuer, audience and expiration are enforced by the parser built in the constructor.
        Claims claims = this.parser.parseUnsecuredClaims(jwt).getPayload();
        return TokenPairSerializer.TokenPair.withAccessAndRefreshTokens(
                claims.get(ACCESS_TOKEN_KEY, String.class),
                claims.get(EXPIRATION_TIME_KEY, Date.class),
                claims.get(REFRESH_TOKEN_KEY, String.class));
    }

    /**
     * Builds the inner JWT for {@code tokenPair} and returns it encrypted as a compact JWE.
     *
     * @throws IllegalArgumentException if the access token's claims are unavailable or lack a
     *     non-null principal claim
     */
    @Override
    public String serialize(TokenPairSerializer.TokenPair tokenPair) {
        Objects.requireNonNull(tokenPair, "tokenPair is null");
        Map<String, Object> claims = this.client.getClaims(tokenPair.accessToken())
                .orElseThrow(() -> new IllegalArgumentException("Claims are missing"));
        // A key mapped to null counts as missing; containsKey alone would lead to an NPE below.
        Object principal = claims.get(this.principalField);
        if (principal == null) {
            throw new IllegalArgumentException(String.format("%s field is missing", this.principalField));
        }

        JwtBuilder jwtBuilder = JwtUtil.newJwtBuilder()
                .expiration(Date.from(this.clock.instant().plusMillis(this.tokenExpiration.toMillis())))
                .claim(this.principalField, principal.toString())
                .audience().add(this.audience).and()
                .issuer(this.issuer)
                .claim(ACCESS_TOKEN_KEY, tokenPair.accessToken())
                .claim(EXPIRATION_TIME_KEY, tokenPair.expiration())
                .compressWith(COMPRESSION_ALGORITHM);
        tokenPair.refreshToken().ifPresentOrElse(
                refreshToken -> jwtBuilder.claim(REFRESH_TOKEN_KEY, refreshToken),
                () -> LOG.info("No refresh token has been issued, although coordinator expects one. Please check your IdP whether that is correct behaviour"));
        return this.jweSerializer.serialize(jwtBuilder.compact());
    }

    /** Returns the configured secret key, or a freshly generated 256-bit AES key if none is set. */
    private static SecretKey getOrGenerateKey(RefreshTokensConfig refreshTokensConfig) {
        SecretKey secretKey = refreshTokensConfig.getSecretKey();
        if (secretKey != null) {
            return secretKey;
        }
        try {
            KeyGenerator keyGenerator = KeyGenerator.getInstance("AES");
            keyGenerator.init(256);
            return keyGenerator.generateKey();
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException("AES key generation is not available", e);
        }
    }
}
