package io.trino.server.security.jwt;

import com.google.inject.Inject;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.JwtParser;
import io.jsonwebtoken.JwtParserBuilder;
import io.jsonwebtoken.SigningKeyResolver;
import io.trino.server.security.AbstractBearerAuthenticator;
import io.trino.server.security.AuthenticationException;
import io.trino.server.security.UserMapping;
import io.trino.server.security.UserMappingException;
import io.trino.spi.security.BasicPrincipal;
import io.trino.spi.security.Identity;
import jakarta.ws.rs.container.ContainerRequestContext;
import java.util.Collection;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Stream;

/**
 * Authenticates requests carrying a JWT bearer token.
 *
 * <p>Each token's signature is verified through the injected {@link SigningKeyResolver}. When
 * configured, the issuer ({@code iss}) and audience ({@code aud}) claims are also checked. The
 * principal is read from the configured principal claim and mapped to a Trino user name through
 * {@link UserMapping}.
 */
public class JwtAuthenticator extends AbstractBearerAuthenticator {
    /** Name of the registered JWT audience claim. */
    private static final String AUDIENCE_CLAIM = "aud";

    private final JwtParser jwtParser;
    private final String principalField;
    private final UserMapping userMapping;
    private final Optional<String> requiredAudience;

    /**
     * Builds the authenticator from its configuration.
     *
     * @param jwtAuthenticatorConfig source of the principal claim name, the optional required
     *        issuer and audience, and the user-mapping pattern/file
     * @param signingKeyResolver resolves the key used to verify each token's signature
     */
    @Inject
    public JwtAuthenticator(JwtAuthenticatorConfig jwtAuthenticatorConfig, @ForJwt SigningKeyResolver signingKeyResolver) {
        Objects.requireNonNull(jwtAuthenticatorConfig, "jwtAuthenticatorConfig is null");
        this.principalField = jwtAuthenticatorConfig.getPrincipalField();
        // A null configured audience disables audience validation entirely.
        this.requiredAudience = Optional.ofNullable(jwtAuthenticatorConfig.getRequiredAudience());

        JwtParserBuilder parserBuilder = JwtUtil.newJwtParserBuilder().setSigningKeyResolver(signingKeyResolver);
        String requiredIssuer = jwtAuthenticatorConfig.getRequiredIssuer();
        if (requiredIssuer != null) {
            // Use the returned builder rather than relying on in-place mutation.
            parserBuilder = parserBuilder.requireIssuer(requiredIssuer);
        }
        this.jwtParser = parserBuilder.build();
        this.userMapping = UserMapping.createUserMapping(jwtAuthenticatorConfig.getUserMappingPattern(), jwtAuthenticatorConfig.getUserMappingFile());
    }

    /**
     * Verifies {@code token}, validates its audience, and derives the Trino identity it names.
     *
     * <p>Signature/issuer failures surface as whatever unchecked exceptions the JWT parser throws
     * (presumably jjwt {@code JwtException} subtypes — confirm against the jjwt version in use).
     *
     * @param token the raw bearer token
     * @return the mapped identity, or empty when the token carries no principal claim
     * @throws UserMappingException if the principal cannot be mapped to a Trino user
     */
    @Override // io.trino.server.security.AbstractBearerAuthenticator
    protected Optional<Identity> createIdentity(String token) throws UserMappingException {
        Claims claims = jwtParser.parseClaimsJws(token).getBody();
        validateAudience(claims);

        String principal = claims.get(principalField, String.class);
        if (principal == null) {
            return Optional.empty();
        }
        // The raw claim value is kept as the principal; the mapped value becomes the user name.
        return Optional.of(Identity.forUser(userMapping.mapUser(principal))
                .withPrincipal(new BasicPrincipal(principal))
                .build());
    }

    /**
     * Checks that the token's {@code aud} claim contains the configured audience, if one is set.
     *
     * <p>The claim may be a single string or a collection; any other type is rejected.
     *
     * @param claims the verified claims of the token
     * @throws InvalidClaimException if the claim is missing, of an unsupported type, or does not
     *         contain the required audience
     */
    private void validateAudience(Claims claims) {
        if (requiredAudience.isEmpty()) {
            return;
        }
        String expected = requiredAudience.get();

        Object tokenAudience = claims.get(AUDIENCE_CLAIM);
        if (tokenAudience == null) {
            throw new InvalidClaimException(String.format("Expected %s claim to be: %s, but was not present in the JWT claims.", AUDIENCE_CLAIM, expected));
        }

        if (tokenAudience instanceof String) {
            if (!expected.equals(tokenAudience)) {
                throw new InvalidClaimException(String.format("Invalid Audience: %s. Allowed audiences: %s", tokenAudience, expected));
            }
            return;
        }

        if (!(tokenAudience instanceof Collection)) {
            throw new InvalidClaimException(String.format("Invalid Audience: %s", tokenAudience));
        }

        // Compare via String.equals(Object) instead of casting each element to String: a
        // non-String entry (e.g. a number) simply fails to match and the token is rejected with
        // InvalidClaimException, rather than escaping as a ClassCastException.
        Stream<?> audiences = ((Collection<?>) tokenAudience).stream();
        if (audiences.noneMatch(expected::equals)) {
            throw new InvalidClaimException(String.format("Invalid Audience: %s. Allowed audiences: %s", tokenAudience, expected));
        }
    }

    /**
     * Builds the authentication challenge returned to the client.
     *
     * @param containerRequestContext the current request (unused)
     * @param optional the presented token, if any (unused)
     * @param str the failure message
     * @return an exception carrying {@code str} and a Bearer/JWT {@code WWW-Authenticate} challenge
     */
    @Override // io.trino.server.security.AbstractBearerAuthenticator
    protected AuthenticationException needAuthentication(ContainerRequestContext containerRequestContext, Optional<String> optional, String str) {
        return new AuthenticationException(str, "Bearer realm=\"Trino\", token_type=\"JWT\"");
    }
}
