package org.ldaptive.transport;

import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import org.apache.lucene.analysis.wikipedia.WikipediaTokenizer;
import org.ldaptive.BindResponse;
import org.ldaptive.LdapException;
import org.ldaptive.LdapUtils;
import org.ldaptive.ResultCode;
import org.ldaptive.sasl.Mechanism;
import org.ldaptive.sasl.SaslBindRequest;
import org.ldaptive.sasl.SaslClient;
import org.ldaptive.sasl.ScramBindRequest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * SASL client that performs a SCRAM bind over a {@link TransportConnection}, using the message
 * format of RFC 5802 without channel binding (GS2 header {@code n,,}).
 *
 * <p>The exchange takes two SASL bind round trips:
 * <ol>
 *   <li>the client-first-message, expected to be answered with {@code SASL_BIND_IN_PROGRESS} and the
 *       server-first-message (combined nonce, salt, iteration count);</li>
 *   <li>the client-final-message carrying the client proof, answered with the final bind result and
 *       (normally) the server-final-message carrying the server signature or an error.</li>
 * </ol>
 */
public class ScramSaslClient implements SaslClient<ScramBindRequest> {

    /** Logger for this class. */
    private static final Logger LOGGER = LoggerFactory.getLogger(ScramSaslClient.class);

    /**
     * SCRAM client-final-message. Derives the salted password from the user password and the
     * server-supplied salt and iteration count, and builds the AuthMessage signed by both proofs.
     */
    public static class ClientFinalMessage {

        /** Base64 encoding of the GS2 header {@code n,,}, sent as the {@code c=} attribute. */
        private static final String GS2_NO_CHANNEL_BINDING = LdapUtils.base64Encode("n,,");

        /** 32-bit big-endian integer 1, appended to the salt to form the first Hi() input. */
        private static final byte[] INTEGER_ONE = {0, 0, 0, 1};

        /** HMAC input that derives the client key from the salted password. */
        private static final byte[] CLIENT_KEY_INIT = "Client Key".getBytes(StandardCharsets.UTF_8);

        /** SCRAM mechanism supplying the digest ({@code properties()[0]}) and HMAC ({@code properties()[1]}) algorithms. */
        private final Mechanism mechanism;

        /** client-final-message-without-proof: {@code c=<base64 gs2 header>,r=<combined nonce>}. */
        private final String withoutProof;

        /** AuthMessage: client-first-message-bare, server-first-message and {@link #withoutProof}, comma separated. */
        private final String message;

        /** Hi(password, salt, iterations). */
        private final byte[] saltedPassword;

        /**
         * Creates a new client final message.
         *
         * @param mechanism SCRAM mechanism; {@code properties()[0]} is used as the digest algorithm and
         *     {@code properties()[1]} as the HMAC algorithm
         * @param password user password, used (UTF-8 encoded) as the HMAC key for Hi()
         * @param clientFirstMessage message previously sent to the server
         * @param serverFirstMessage validated server response supplying salt, iterations and combined nonce
         */
        ClientFinalMessage(
            final Mechanism mechanism,
            final String password,
            final ClientFirstMessage clientFirstMessage,
            final ServerFirstMessage serverFirstMessage) {
            this.mechanism = mechanism;
            // NOTE(review): RFC 5802 asks for SASLprep normalization of the password; it is used as-is here.
            this.saltedPassword = createSaltedPassword(
                mechanism.properties()[1],
                password,
                serverFirstMessage.getSalt(),
                serverFirstMessage.getIterations());
            this.withoutProof = "c=" + GS2_NO_CHANNEL_BINDING + ",r=" + serverFirstMessage.getCombinedNonce();
            this.message = clientFirstMessage.getMessage() + "," + serverFirstMessage.getMessage() + "," + this.withoutProof;
        }

        /**
         * Returns the salted password. The internal array is returned, not a copy.
         *
         * @return Hi(password, salt, iterations)
         */
        public byte[] getSaltedPassword() {
            return this.saltedPassword;
        }

        /**
         * Returns the AuthMessage used for both the client proof and the server signature.
         *
         * @return auth message
         */
        public String getMessage() {
            return this.message;
        }

        /**
         * Encodes the client-final-message including the client proof, where
         * ClientKey = HMAC(SaltedPassword, "Client Key"), StoredKey = H(ClientKey),
         * ClientSignature = HMAC(StoredKey, AuthMessage) and ClientProof = ClientKey XOR ClientSignature.
         *
         * @return {@code c=...,r=...,p=<base64 client proof>}
         */
        public String encode() {
            final String macAlgorithm = this.mechanism.properties()[1];
            final byte[] clientKey = ScramSaslClient.createMac(macAlgorithm, this.saltedPassword).doFinal(CLIENT_KEY_INIT);
            final byte[] storedKey = ScramSaslClient.createDigest(this.mechanism.properties()[0], clientKey);
            final byte[] clientSignature = ScramSaslClient.createMac(macAlgorithm, storedKey)
                .doFinal(this.message.getBytes(StandardCharsets.UTF_8));
            final byte[] clientProof = new byte[clientKey.length];
            for (int j = 0; j < clientProof.length; j++) {
                clientProof[j] = (byte) (clientKey[j] ^ clientSignature[j]);
            }
            return this.withoutProof + ",p=" + LdapUtils.base64Encode(clientProof);
        }

        /**
         * Computes Hi(password, salt, iterations), i.e. PBKDF2 with a single output block:
         * U1 = HMAC(password, salt || INT(1)), Un = HMAC(password, Un-1), result = U1 XOR ... XOR Ui.
         *
         * @param algorithm HMAC algorithm name
         * @param password password used as the HMAC key
         * @param salt server-supplied salt
         * @param iterations iteration count (at least 1)
         * @return salted password
         */
        private static byte[] createSaltedPassword(
            final String algorithm, final String password, final byte[] salt, final int iterations) {
            final Mac mac = ScramSaslClient.createMac(algorithm, password.getBytes(StandardCharsets.UTF_8));
            final byte[] saltAndIndex = Arrays.copyOf(salt, salt.length + INTEGER_ONE.length);
            System.arraycopy(INTEGER_ONE, 0, saltAndIndex, salt.length, INTEGER_ONE.length);
            byte[] previous = mac.doFinal(saltAndIndex);
            // The accumulator must be a distinct array from 'previous': 'previous' is replaced by the
            // newest U value each round, while 'result' keeps the running XOR of all of them.
            final byte[] result = previous.clone();
            for (int round = 1; round < iterations; round++) {
                final byte[] current = mac.doFinal(previous);
                for (int j = 0; j < result.length; j++) {
                    result[j] ^= current[j];
                }
                previous = current;
            }
            return result;
        }
    }

    /**
     * SCRAM client-first-message: username and client nonce, prefixed by the GS2 header when encoded.
     */
    public static class ClientFirstMessage {

        /** GS2 header indicating that channel binding is not supported. */
        private static final String GS2_NO_CHANNEL_BINDING = "n,,";

        /** Number of random bytes used for a generated client nonce. */
        private static final int DEFAULT_NONCE_SIZE = 16;

        /** Username as supplied by the caller (unescaped). */
        private final String clientUsername;

        /** Base64-encoded client nonce. */
        private final String clientNonce;

        /** client-first-message-bare: {@code n=<escaped username>,r=<client nonce>}. */
        private final String message;

        /**
         * Creates a new client first message.
         *
         * @param username username to authenticate
         * @param nonce client nonce bytes, or null to generate {@value #DEFAULT_NONCE_SIZE} random bytes
         */
        ClientFirstMessage(final String username, final byte[] nonce) {
            this.clientUsername = username;
            final byte[] nonceBytes;
            if (nonce == null) {
                nonceBytes = new byte[DEFAULT_NONCE_SIZE];
                new SecureRandom().nextBytes(nonceBytes);
            } else {
                nonceBytes = nonce;
            }
            this.clientNonce = LdapUtils.base64Encode(nonceBytes);
            // NOTE(review): RFC 5802 also asks for SASLprep of the username; only the mandatory escaping is applied.
            this.message = "n=" + escapeUsername(this.clientUsername) + ",r=" + this.clientNonce;
        }

        /**
         * Returns the base64-encoded client nonce.
         *
         * @return client nonce
         */
        public String getNonce() {
            return this.clientNonce;
        }

        /**
         * Returns the client-first-message-bare (without GS2 header).
         *
         * @return bare message
         */
        public String getMessage() {
            return this.message;
        }

        /**
         * Encodes the full client-first-message, including the GS2 header.
         *
         * @return encoded message
         */
        public String encode() {
            return GS2_NO_CHANNEL_BINDING + this.message;
        }

        /**
         * Escapes a username for the {@code n=} attribute: {@code =} becomes {@code =3D} and {@code ,}
         * becomes {@code =2C}. '=' is escaped first so the escape sequences themselves are not re-escaped.
         *
         * @param username raw username
         * @return escaped username
         */
        private static String escapeUsername(final String username) {
            return username.replace("=", "=3D").replace(",", "=2C");
        }
    }

    /**
     * SCRAM server-final-message. Verifies the server signature ({@code v=}) when the bind succeeded
     * and logs any server error ({@code e=}).
     */
    public static class ServerFinalMessage {

        /** HMAC input that derives the server key from the salted password. */
        private static final byte[] SERVER_KEY_INIT = "Server Key".getBytes(StandardCharsets.UTF_8);

        /** Decoded server-final-message, or null when a failed bind carried no SASL credentials. */
        private final String message;

        /** Whether the bind succeeded and the server signature matched the expected value. */
        private final boolean verified;

        /**
         * Parses and, on a successful bind, verifies the server-final-message.
         *
         * @param mechanism SCRAM mechanism; {@code properties()[1]} is used as the HMAC algorithm
         * @param clientFinalMessage client final message supplying the salted password and AuthMessage
         * @param bindResponse response to the client-final-message
         * @throws IllegalArgumentException if a successful response lacks credentials or a valid server
         *     signature, or if the credentials are malformed
         */
        ServerFinalMessage(
            final Mechanism mechanism, final ClientFinalMessage clientFinalMessage, final BindResponse bindResponse) {
            final byte[] serverCreds = bindResponse.getServerSaslCreds();
            final boolean success = bindResponse.getResultCode() == ResultCode.SUCCESS;
            if (serverCreds == null || serverCreds.length == 0) {
                if (!success) {
                    // A failed bind may carry no server-final-message; report it as unverified so the
                    // caller receives the actual failure response rather than an exception.
                    this.message = null;
                    this.verified = false;
                    return;
                }
                throw new IllegalArgumentException("Bind response missing server SASL credentials");
            }
            this.message = new String(serverCreds, StandardCharsets.UTF_8);
            final Map<String, String> attributes = ScramSaslClient.parseAttributes(this.message);
            final String error = attributes.get("e");
            if (error != null) {
                ScramSaslClient.LOGGER.warn("SASL bind server final message included error: {}", error);
            }
            if (!success) {
                this.verified = false;
                return;
            }
            final String serverSignature = attributes.get("v");
            if (serverSignature == null) {
                throw new IllegalArgumentException("Invalid SASL credentials, missing server verification");
            }
            // ServerKey = HMAC(SaltedPassword, "Server Key"); ServerSignature = HMAC(ServerKey, AuthMessage)
            final String macAlgorithm = mechanism.properties()[1];
            final byte[] serverKey = ScramSaslClient.createMac(macAlgorithm, clientFinalMessage.getSaltedPassword())
                .doFinal(SERVER_KEY_INIT);
            final String expectedSignature = LdapUtils.base64Encode(
                ScramSaslClient.createMac(macAlgorithm, serverKey)
                    .doFinal(clientFinalMessage.getMessage().getBytes(StandardCharsets.UTF_8)));
            // Constant-time comparison of the encoded signatures.
            if (!MessageDigest.isEqual(
                expectedSignature.getBytes(StandardCharsets.UTF_8),
                serverSignature.getBytes(StandardCharsets.UTF_8))) {
                throw new IllegalArgumentException("Invalid SASL credentials, incorrect server verification");
            }
            this.verified = true;
        }

        /**
         * Returns whether the server signature was verified.
         *
         * @return true only for a successful bind whose server signature matched
         */
        public boolean isVerified() {
            return this.verified;
        }
    }

    /**
     * SCRAM server-first-message: combined nonce, salt and iteration count.
     */
    public static class ServerFirstMessage {

        /** Smallest iteration count accepted from the server. */
        private static final int MINIMUM_ITERATION_COUNT = 4096;

        /** Decoded server-first-message, as received. */
        private final String message;

        /** Client nonce followed by the server's contribution. */
        private final String combinedNonce;

        /** Decoded salt. */
        private final byte[] salt;

        /** Iteration count for Hi(). */
        private final int iterations;

        /**
         * Parses and validates the server-first-message.
         *
         * @param clientFirstMessage message whose nonce must prefix the server nonce
         * @param bindResponse response to the client-first-message
         * @throws IllegalArgumentException if the credentials are missing, malformed, or fail validation
         */
        ServerFirstMessage(final ClientFirstMessage clientFirstMessage, final BindResponse bindResponse) {
            final byte[] serverCreds = bindResponse.getServerSaslCreds();
            if (serverCreds == null || serverCreds.length == 0) {
                throw new IllegalArgumentException("Bind response missing server SASL credentials");
            }
            this.message = new String(serverCreds, StandardCharsets.UTF_8);
            final Map<String, String> attributes = ScramSaslClient.parseAttributes(this.message);
            // RFC 5802: the presence of the reserved 'm' attribute MUST cause authentication failure.
            if (attributes.containsKey("m")) {
                throw new IllegalArgumentException("Invalid SASL credentials, unsupported mandatory extension");
            }
            final String nonce = attributes.get("r");
            if (nonce == null) {
                throw new IllegalArgumentException("Invalid SASL credentials, missing server nonce");
            }
            if (!nonce.startsWith(clientFirstMessage.getNonce())) {
                throw new IllegalArgumentException("Invalid SASL credentials, missing client nonce");
            }
            this.combinedNonce = nonce;
            final String encodedSalt = attributes.get("s");
            if (encodedSalt == null) {
                throw new IllegalArgumentException("Invalid SASL credentials, missing server salt");
            }
            this.salt = LdapUtils.base64Decode(encodedSalt);
            this.iterations = parseIterations(attributes.get("i"));
        }

        /**
         * Returns the server-first-message as received.
         *
         * @return message
         */
        public String getMessage() {
            return this.message;
        }

        /**
         * Returns the combined client and server nonce.
         *
         * @return combined nonce
         */
        public String getCombinedNonce() {
            return this.combinedNonce;
        }

        /**
         * Returns the decoded salt. The internal array is returned, not a copy.
         *
         * @return salt
         */
        public byte[] getSalt() {
            return this.salt;
        }

        /**
         * Returns the iteration count.
         *
         * @return iterations, at least {@value #MINIMUM_ITERATION_COUNT}
         */
        public int getIterations() {
            return this.iterations;
        }

        /**
         * Parses and validates the {@code i=} attribute value.
         *
         * @param value attribute value, possibly null
         * @return iteration count
         * @throws IllegalArgumentException if missing, not an integer, or below the minimum
         */
        private static int parseIterations(final String value) {
            if (value == null) {
                throw new IllegalArgumentException("Invalid SASL credentials, missing server iteration count");
            }
            final int count;
            try {
                count = Integer.parseInt(value);
            } catch (NumberFormatException e) {
                throw new IllegalArgumentException("Invalid SASL credentials, malformed iteration count", e);
            }
            if (count < MINIMUM_ITERATION_COUNT) {
                throw new IllegalArgumentException(
                    "Invalid SASL credentials, iterations minimum value is " + MINIMUM_ITERATION_COUNT);
            }
            return count;
        }
    }

    /**
     * Performs the two-round-trip SCRAM SASL bind.
     *
     * @param transportConnection connection used to execute both SASL bind requests
     * @param scramBindRequest supplies the username, password, optional client nonce and mechanism
     * @return the final bind response; or the first response when it is neither
     *     {@code SASL_BIND_IN_PROGRESS} nor a success
     * @throws LdapException propagated from executing the bind operations
     * @throws IllegalStateException if the first round trip reports success, or the final result and
     *     the server signature verification disagree
     * @throws IllegalArgumentException if a server message is malformed or fails validation
     */
    @Override
    public BindResponse bind(final TransportConnection transportConnection, final ScramBindRequest scramBindRequest)
        throws LdapException {
        final Mechanism mechanism = scramBindRequest.getMechanism();
        final ClientFirstMessage clientFirstMessage =
            new ClientFirstMessage(scramBindRequest.getUsername(), scramBindRequest.getNonce());
        final BindResponse firstResponse = transportConnection.operation(
            new SaslBindRequest(mechanism.mechanism(), clientFirstMessage.encode().getBytes(StandardCharsets.UTF_8)))
            .execute();
        if (firstResponse.getResultCode() != ResultCode.SASL_BIND_IN_PROGRESS) {
            // SCRAM can never complete in a single round trip.
            if (firstResponse.isSuccess()) {
                throw new IllegalStateException(
                    "Unexpected success result from SCRAM SASL bind: " + firstResponse.getResultCode());
            }
            LOGGER.warn("Unexpected server result {}", firstResponse);
            return firstResponse;
        }
        final ServerFirstMessage serverFirstMessage = new ServerFirstMessage(clientFirstMessage, firstResponse);
        final ClientFinalMessage clientFinalMessage = new ClientFinalMessage(
            mechanism, scramBindRequest.getPassword(), clientFirstMessage, serverFirstMessage);
        final BindResponse finalResponse = transportConnection.operation(
            new SaslBindRequest(mechanism.mechanism(), clientFinalMessage.encode().getBytes(StandardCharsets.UTF_8)))
            .execute();
        final ServerFinalMessage serverFinalMessage =
            new ServerFinalMessage(mechanism, clientFinalMessage, finalResponse);
        if (finalResponse.isSuccess() && !serverFinalMessage.isVerified()) {
            throw new IllegalStateException("Received success from server but message could not be verified");
        }
        if (!finalResponse.isSuccess() && serverFinalMessage.isVerified()) {
            throw new IllegalStateException("Verified server message but result was not a success");
        }
        return finalResponse;
    }

    /**
     * Splits a SCRAM message into its comma-separated {@code name=value} attributes.
     *
     * @param message decoded SCRAM message
     * @return attribute name to value
     * @throws IllegalArgumentException if an attribute lacks '=' or a name, or a name is repeated
     */
    private static Map<String, String> parseAttributes(final String message) {
        return Stream.of(message.split(","))
            .map(attribute -> {
                final String[] pair = attribute.split("=", 2);
                if (pair.length != 2 || pair[0].isEmpty()) {
                    throw new IllegalArgumentException("Invalid SASL credentials, malformed attribute");
                }
                return pair;
            })
            .collect(Collectors.toMap(
                pair -> pair[0],
                pair -> pair[1],
                (first, second) -> {
                    throw new IllegalArgumentException("Invalid SASL credentials, duplicate attribute");
                }));
    }

    /**
     * Creates a MAC initialized with the supplied key.
     *
     * @param algorithm MAC algorithm name, also used as the key algorithm
     * @param key key bytes
     * @return initialized MAC
     * @throws IllegalStateException wrapping any failure to obtain or initialize the MAC
     */
    private static Mac createMac(final String algorithm, final byte[] key) {
        try {
            final Mac mac = Mac.getInstance(algorithm);
            mac.init(new SecretKeySpec(key, algorithm));
            return mac;
        } catch (Exception e) {
            throw new IllegalStateException("Could not create MAC", e);
        }
    }

    /**
     * Computes a message digest of the supplied bytes.
     *
     * @param algorithm digest algorithm name
     * @param data bytes to digest
     * @return digest
     * @throws IllegalStateException wrapping any failure to obtain the digest
     */
    private static byte[] createDigest(final String algorithm, final byte[] data) {
        try {
            return MessageDigest.getInstance(algorithm).digest(data);
        } catch (Exception e) {
            throw new IllegalStateException("Could not create digest", e);
        }
    }
}
