/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.security.oauth2.jwt;

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JOSEObjectType;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.jwk.ECKey;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jose.proc.DefaultJOSEObjectTypeVerifier;
import com.nimbusds.jose.proc.JOSEObjectTypeVerifier;
import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jwt.proc.DefaultJWTProcessor;
import com.nimbusds.jwt.proc.JWTProcessor;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Base64;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Function;
import org.springframework.security.oauth2.core.DelegatingOAuth2TokenValidator;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
import org.springframework.security.oauth2.jwt.BadJwtException;
import org.springframework.security.oauth2.jwt.DPoPProofContext;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtClaimValidator;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
import org.springframework.security.oauth2.jwt.JwtIssuedAtValidator;
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
 * A {@link JwtDecoderFactory} that creates a {@link JwtDecoder} for a DPoP proof JWT
 * described by a {@link DPoPProofContext}.
 *
 * <p>Every decoder created by this factory:
 * <ul>
 * <li>accepts only the {@code dpop+jwt} JOSE object type,</li>
 * <li>selects the signature verification key from the public {@code jwk} parameter of
 * the JWS header (RSA and EC algorithm families only), and</li>
 * <li>validates claims with the {@link OAuth2TokenValidator} produced by the configured
 * validator factory.</li>
 * </ul>
 */
public final class DPoPProofJwtDecoderFactory implements JwtDecoderFactory<DPoPProofContext> {

    /**
     * The default validator factory. It checks the {@code htm} and {@code htu} claims
     * against the {@link DPoPProofContext}, rejects a missing or previously seen
     * {@code jti}, and applies a {@link JwtIssuedAtValidator}.
     */
    public static final Function<DPoPProofContext, OAuth2TokenValidator<Jwt>> DEFAULT_JWT_VALIDATOR_FACTORY = defaultJwtValidatorFactory();

    private static final JOSEObjectTypeVerifier<SecurityContext> DPOP_TYPE_VERIFIER = new DefaultJOSEObjectTypeVerifier<>(
            new JOSEObjectType("dpop+jwt"));

    private Function<DPoPProofContext, OAuth2TokenValidator<Jwt>> jwtValidatorFactory = DEFAULT_JWT_VALIDATOR_FACTORY;

    /**
     * Creates a {@link JwtDecoder} whose claim validator is built from the supplied
     * context.
     * @param dPoPProofContext the context describing the DPoP proof, never {@code null}
     * @return a new decoder configured for the supplied context
     * @throws IllegalArgumentException if {@code dPoPProofContext} is {@code null}
     */
    @Override
    public JwtDecoder createDecoder(DPoPProofContext dPoPProofContext) {
        Assert.notNull(dPoPProofContext, "dPoPProofContext cannot be null");
        NimbusJwtDecoder jwtDecoder = buildDecoder();
        jwtDecoder.setJwtValidator(this.jwtValidatorFactory.apply(dPoPProofContext));
        return jwtDecoder;
    }

    /**
     * Sets the factory that supplies the {@link OAuth2TokenValidator} for each created
     * decoder. Defaults to {@link #DEFAULT_JWT_VALIDATOR_FACTORY}.
     * @param jwtValidatorFactory the validator factory, never {@code null}
     * @throws IllegalArgumentException if {@code jwtValidatorFactory} is {@code null}
     */
    public void setJwtValidatorFactory(Function<DPoPProofContext, OAuth2TokenValidator<Jwt>> jwtValidatorFactory) {
        Assert.notNull(jwtValidatorFactory, "jwtValidatorFactory cannot be null");
        this.jwtValidatorFactory = jwtValidatorFactory;
    }

    /**
     * Builds a decoder backed by a Nimbus processor that enforces the DPoP object type
     * and resolves the verification key from the JWS header.
     */
    private static NimbusJwtDecoder buildDecoder() {
        DefaultJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
        jwtProcessor.setJWSTypeVerifier(DPOP_TYPE_VERIFIER);
        jwtProcessor.setJWSKeySelector(jwsKeySelector());
        // Replace the processor's claims verifier with a no-op: claims are checked by the
        // OAuth2TokenValidator that createDecoder(..) installs on the decoder.
        jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {
        });
        return new NimbusJwtDecoder(jwtProcessor);
    }

    /**
     * Returns a key selector that extracts the public key from the {@code jwk} header
     * parameter.
     *
     * <p>The selector throws {@link BadJwtException} when the algorithm is neither RSA
     * nor EC, when the {@code jwk} is missing or contains private key material, when the
     * key type does not match the algorithm family, or when the JWK cannot be converted
     * to a public key.
     */
    private static JWSKeySelector<SecurityContext> jwsKeySelector() {
        return (header, context) -> {
            JWSAlgorithm algorithm = header.getAlgorithm();
            boolean rsaAlgorithm = JWSAlgorithm.Family.RSA.contains(algorithm);
            boolean ecAlgorithm = JWSAlgorithm.Family.EC.contains(algorithm);
            if (!rsaAlgorithm && !ecAlgorithm) {
                throw new BadJwtException("Unsupported alg parameter in JWS Header: " + algorithm.getName());
            }

            JWK jwk = header.getJWK();
            if (jwk == null) {
                throw new BadJwtException("Missing jwk parameter in JWS Header.");
            }
            if (jwk.isPrivate()) {
                // A proof must carry only the public key.
                throw new BadJwtException("Invalid jwk parameter in JWS Header.");
            }

            try {
                if (rsaAlgorithm && jwk instanceof RSAKey) {
                    return Collections.singletonList(((RSAKey) jwk).toRSAPublicKey());
                }
                if (ecAlgorithm && jwk instanceof ECKey) {
                    return Collections.singletonList(((ECKey) jwk).toECPublicKey());
                }
            }
            catch (JOSEException ex) {
                // Keep the conversion failure as the cause so it is not lost for diagnostics.
                throw new BadJwtException("Invalid jwk parameter in JWS Header.", ex);
            }

            // The algorithm family and the key type disagree (e.g. RSA alg with an EC key).
            throw new BadJwtException("Invalid alg / jwk parameter in JWS Header: alg=" + algorithm.getName()
                    + ", jwk.kty=" + jwk.getKeyType().getValue());
        };
    }

    /**
     * Builds the default validator factory: {@code htm} must equal the context's method,
     * {@code htu} must equal the context's target URI, {@code jti} must be present and
     * not already cached, and the issued-at claim is checked by
     * {@link JwtIssuedAtValidator}.
     */
    private static Function<DPoPProofContext, OAuth2TokenValidator<Jwt>> defaultJwtValidatorFactory() {
        return context -> new DelegatingOAuth2TokenValidator<Jwt>(
                new JwtClaimValidator<Object>("htm", context.getMethod()::equals),
                new JwtClaimValidator<Object>("htu", context.getTargetUri()::equals),
                new JtiClaimValidator(),
                new JwtIssuedAtValidator(true));
    }

    /**
     * Validates the {@code jti} claim of a DPoP proof: it must contain text and must not
     * already be present in the shared cache.
     *
     * <p>The cache is static, so it is shared by every validator instance, and it is
     * wrapped with {@link Collections#synchronizedMap(Map)}.
     */
    private static final class JtiClaimValidator implements OAuth2TokenValidator<Jwt> {

        private static final Map<String, Long> JTI_CACHE = Collections.synchronizedMap(new JtiCache());

        /**
         * Fails when the {@code jti} is blank, cannot be hashed, or has already been
         * recorded; otherwise records it with a one hour expiry and succeeds.
         * @param jwt the DPoP proof, never {@code null}
         * @return the validation result
         */
        @Override
        public OAuth2TokenValidatorResult validate(Jwt jwt) {
            Assert.notNull(jwt, "DPoP proof jwt cannot be null");
            String jti = jwt.getId();
            if (!StringUtils.hasText(jti)) {
                return failure("jti claim is required.");
            }

            // The cache is keyed by the SHA-256 hash of the jti, not the raw value.
            String jtiHash;
            try {
                jtiHash = computeSHA256(jti);
            }
            catch (NoSuchAlgorithmException ex) {
                return failure("jti claim is invalid.");
            }

            Instant expiry = Instant.now().plus(1L, ChronoUnit.HOURS);
            // A non-null previous value means this jti was already seen.
            if (JTI_CACHE.putIfAbsent(jtiHash, expiry.toEpochMilli()) != null) {
                return failure("jti claim is invalid.");
            }
            return OAuth2TokenValidatorResult.success();
        }

        /** Wraps the reason in a failed result carrying an {@code invalid_dpop_proof} error. */
        private static OAuth2TokenValidatorResult failure(String reason) {
            return OAuth2TokenValidatorResult.failure(createOAuth2Error(reason));
        }

        private static OAuth2Error createOAuth2Error(String reason) {
            return new OAuth2Error("invalid_dpop_proof", reason, null);
        }

        /**
         * Returns the unpadded, URL-safe Base64 encoding of the SHA-256 digest of the
         * UTF-8 bytes of {@code value}.
         * @throws NoSuchAlgorithmException if the runtime provides no SHA-256 digest
         */
        private static String computeSHA256(String value) throws NoSuchAlgorithmException {
            MessageDigest md = MessageDigest.getInstance("SHA-256");
            byte[] digest = md.digest(value.getBytes(StandardCharsets.UTF_8));
            return Base64.getUrlEncoder().withoutPadding().encodeToString(digest);
        }

        /**
         * Insertion-ordered map of jti hash to expiry (epoch milliseconds). On each
         * insertion the eldest entry is dropped when the map exceeds {@link #MAX_SIZE}
         * or when that entry has expired.
         *
         * <p>NOTE(review): because of the size cap, an entry can be evicted before its
         * expiry once more than {@code MAX_SIZE} newer entries have been added, after
         * which the same jti would be accepted again.
         */
        private static final class JtiCache extends LinkedHashMap<String, Long> {

            private static final int MAX_SIZE = 1000;

            @Override
            protected boolean removeEldestEntry(Map.Entry<String, Long> eldest) {
                if (size() > MAX_SIZE) {
                    return true;
                }
                Instant expiry = Instant.ofEpochMilli(eldest.getValue());
                return Instant.now().isAfter(expiry);
            }

        }

    }

}

