/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.security;

import java.io.BufferedOutputStream;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.FilterInputStream;
import java.io.FilterOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.callback.TextInputCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.auth.kerberos.KerberosPrincipal;
import javax.security.sasl.RealmCallback;
import javax.security.sasl.RealmChoiceCallback;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.flink.hadoop2.shaded.com.google.common.annotations.VisibleForTesting;
import org.apache.flink.hadoop2.shaded.com.google.protobuf.ByteString;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.GlobPattern;
import org.apache.hadoop.ipc.Client;
import org.apache.hadoop.ipc.RPC;
import org.apache.hadoop.ipc.RemoteException;
import org.apache.hadoop.ipc.ResponseBuffer;
import org.apache.hadoop.ipc.RpcConstants;
import org.apache.hadoop.ipc.RpcWritable;
import org.apache.hadoop.ipc.Server;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos;
import org.apache.hadoop.security.AccessControlException;
import org.apache.hadoop.security.KerberosInfo;
import org.apache.hadoop.security.SaslPropertiesResolver;
import org.apache.hadoop.security.SaslRpcServer;
import org.apache.hadoop.security.SecurityUtil;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.authentication.util.KerberosName;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.security.token.TokenInfo;
import org.apache.hadoop.security.token.TokenSelector;
import org.apache.hadoop.util.ProtoUtil;

/**
 * Client side of the SASL handshake performed on an RPC connection.
 *
 * <p>{@link #saslConnect} sends a NEGOTIATE request, selects an
 * authentication method from the list offered by the server, and drives the
 * SASL challenge/response exchange until the server reports SUCCESS (or the
 * client settles on SIMPLE). When the negotiated QOP is anything other than
 * {@code auth}, {@link #getInputStream} and {@link #getOutputStream} wrap the
 * raw connection streams so every RPC packet is SASL wrapped/unwrapped.
 */
@InterfaceAudience.LimitedPrivate(value={"HDFS", "MapReduce"})
@InterfaceStability.Evolving
public class SaslRpcClient {
    public static final Log LOG = LogFactory.getLog(SaslRpcClient.class);
    private final UserGroupInformation ugi;
    private final Class<?> protocol;
    private final InetSocketAddress serverAddr;
    private final Configuration conf;
    private SaslClient saslClient;
    private final SaslPropertiesResolver saslPropsResolver;
    private SaslRpcServer.AuthMethod authMethod;
    /** RPC request header attached to every SASL packet sent by this client. */
    private static final RpcHeaderProtos.RpcRequestHeaderProto saslHeader = ProtoUtil.makeRpcRequestHeader(RPC.RpcKind.RPC_PROTOCOL_BUFFER, RpcHeaderProtos.RpcRequestHeaderProto.OperationProto.RPC_FINAL_PACKET, Server.AuthProtocol.SASL.callId, -1, RpcConstants.DUMMY_CLIENT_ID);
    /** First message of the handshake; asks the server for its supported auth types. */
    private static final RpcHeaderProtos.RpcSaslProto negotiateRequest = RpcHeaderProtos.RpcSaslProto.newBuilder().setState(RpcHeaderProtos.RpcSaslProto.SaslState.NEGOTIATE).build();

    /**
     * Creates a SASL RPC client.
     *
     * @param ugi the user whose credentials (tokens / Kerberos login) are used
     * @param protocol the RPC protocol class, used to look up token and Kerberos info
     * @param serverAddr the server being connected to
     * @param conf configuration used for SASL properties and principal lookup
     */
    public SaslRpcClient(UserGroupInformation ugi, Class<?> protocol, InetSocketAddress serverAddr, Configuration conf) {
        this.ugi = ugi;
        this.protocol = protocol;
        this.serverAddr = serverAddr;
        this.conf = conf;
        this.saslPropsResolver = SaslPropertiesResolver.getInstance(conf);
    }

    /**
     * Returns a negotiated SASL property, or null if there is no SASL client.
     *
     * @param key the SASL property name
     * @return the property value, or null
     */
    @InterfaceAudience.Private
    @VisibleForTesting
    public Object getNegotiatedProperty(String key) {
        return this.saslClient != null ? this.saslClient.getNegotiatedProperty(key) : null;
    }

    /**
     * @return the auth method chosen by the last {@link #saslConnect} call
     */
    @InterfaceAudience.Private
    public SaslRpcServer.AuthMethod getAuthMethod() {
        return this.authMethod;
    }

    /**
     * Selects the first server-offered auth type this client can use.
     *
     * <p>SIMPLE is accepted as-is. For any other method a SASL client is
     * created; if that is not possible (e.g. no matching token or principal),
     * the next offered type is tried.
     *
     * @param authTypes the auth types offered by the server, in server order
     * @return the selected auth type (never null)
     * @throws AccessControlException if none of the offered types is usable
     */
    private RpcHeaderProtos.RpcSaslProto.SaslAuth selectSaslClient(List<RpcHeaderProtos.RpcSaslProto.SaslAuth> authTypes) throws SaslException, AccessControlException, IOException {
        RpcHeaderProtos.RpcSaslProto.SaslAuth selectedAuthType = null;
        for (RpcHeaderProtos.RpcSaslProto.SaslAuth authType : authTypes) {
            if (!this.isValidAuthType(authType)) {
                continue;
            }
            SaslRpcServer.AuthMethod method = SaslRpcServer.AuthMethod.valueOf(authType.getMethod());
            if (method != SaslRpcServer.AuthMethod.SIMPLE) {
                this.saslClient = this.createSaslClient(authType);
                if (this.saslClient == null) {
                    // this client cannot use the method; try the next one offered
                    continue;
                }
            }
            selectedAuthType = authType;
            break;
        }
        // Decide on what was actually selected rather than on saslClient, which
        // could still hold a client left over from an earlier handshake.
        if (selectedAuthType == null) {
            ArrayList<String> serverAuthMethods = new ArrayList<String>();
            for (RpcHeaderProtos.RpcSaslProto.SaslAuth authType : authTypes) {
                serverAuthMethods.add(authType.getMethod());
            }
            throw new AccessControlException("Client cannot authenticate via:" + serverAuthMethods);
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug("Use " + selectedAuthType.getMethod() + " authentication for protocol " + this.protocol.getSimpleName());
        }
        return selectedAuthType;
    }

    /**
     * Checks that the auth type names a known method whose mechanism matches.
     *
     * @param authType a server-offered auth type
     * @return true if the method is known and its mechanism name matches
     */
    private boolean isValidAuthType(RpcHeaderProtos.RpcSaslProto.SaslAuth authType) {
        SaslRpcServer.AuthMethod method;
        try {
            method = SaslRpcServer.AuthMethod.valueOf(authType.getMethod());
        }
        catch (IllegalArgumentException iae) {
            // unknown method name offered by the server
            method = null;
        }
        return method != null && method.getMechanismName().equals(authType.getMechanism());
    }

    /**
     * Creates a SASL client for a TOKEN or KERBEROS auth type.
     *
     * @param authType the server-offered auth type
     * @return the SASL client, or null when this client lacks the credentials
     *         needed for the method (no token, not a Kerberos login, or no
     *         Kerberos info for the protocol)
     * @throws IOException if the method is neither TOKEN nor KERBEROS
     */
    private SaslClient createSaslClient(RpcHeaderProtos.RpcSaslProto.SaslAuth authType) throws SaslException, IOException {
        String saslProtocol = authType.getProtocol();
        String saslServerName = authType.getServerId();
        Map<String, String> saslProperties = this.saslPropsResolver.getClientProperties(this.serverAddr.getAddress());
        SaslClientCallbackHandler saslCallback = null;
        SaslRpcServer.AuthMethod method = SaslRpcServer.AuthMethod.valueOf(authType.getMethod());
        switch (method) {
            case TOKEN: {
                Token<?> token = this.getServerToken(authType);
                if (token == null) {
                    return null;
                }
                saslCallback = new SaslClientCallbackHandler(token);
                break;
            }
            case KERBEROS: {
                if (this.ugi.getRealAuthenticationMethod().getAuthMethod() != SaslRpcServer.AuthMethod.KERBEROS) {
                    return null;
                }
                String serverPrincipal = this.getServerPrincipal(authType);
                if (serverPrincipal == null) {
                    return null;
                }
                if (LOG.isDebugEnabled()) {
                    LOG.debug("RPC Server's Kerberos principal name for protocol=" + this.protocol.getCanonicalName() + " is " + serverPrincipal);
                }
                break;
            }
            default: {
                throw new IOException("Unknown authentication method " + method);
            }
        }
        String mechanism = method.getMechanismName();
        if (LOG.isDebugEnabled()) {
            LOG.debug("Creating SASL " + mechanism + "(" + method + ") " + " client to authenticate to service at " + saslServerName);
        }
        // null authorization id: no separate authorization identity is requested
        return Sasl.createSaslClient(new String[]{mechanism}, null, saslProtocol, saslServerName, saslProperties, saslCallback);
    }

    /**
     * Looks up a token for the server using the protocol's {@link TokenInfo}.
     *
     * @param authType the server-offered auth type (currently unused)
     * @return the selected token, or null if the protocol has no token info
     *         or no matching token exists
     * @throws IOException if the token selector cannot be instantiated
     */
    private Token<?> getServerToken(RpcHeaderProtos.RpcSaslProto.SaslAuth authType) throws IOException {
        TokenInfo tokenInfo = SecurityUtil.getTokenInfo(this.protocol, this.conf);
        LOG.debug("Get token info proto:" + this.protocol + " info:" + tokenInfo);
        if (tokenInfo == null) {
            return null;
        }
        TokenSelector<? extends TokenIdentifier> tokenSelector;
        try {
            tokenSelector = tokenInfo.value().newInstance();
        }
        catch (InstantiationException | IllegalAccessException e) {
            // keep the original failure as the cause for diagnosis
            throw new IOException(e.toString(), e);
        }
        return tokenSelector.selectToken(SecurityUtil.buildTokenService(this.serverAddr), this.ugi.getTokens());
    }

    /**
     * Determines and validates the server's Kerberos principal.
     *
     * <p>The principal is built from the auth type's protocol and server id.
     * If {@code <serverKey>.pattern} is configured, the principal must match
     * that glob; otherwise it must equal the principal configured under the
     * protocol's server key (after host substitution).
     *
     * @param authType the server-offered auth type
     * @return the server principal, or null if the protocol has no Kerberos info
     * @throws IllegalArgumentException if the principal cannot be validated
     */
    @VisibleForTesting
    String getServerPrincipal(RpcHeaderProtos.RpcSaslProto.SaslAuth authType) throws IOException {
        KerberosInfo krbInfo = SecurityUtil.getKerberosInfo(this.protocol, this.conf);
        LOG.debug("Get kerberos info proto:" + this.protocol + " info:" + krbInfo);
        if (krbInfo == null) {
            return null;
        }
        String serverKey = krbInfo.serverPrincipal();
        if (serverKey == null) {
            throw new IllegalArgumentException("Can't obtain server Kerberos config key from protocol=" + this.protocol.getCanonicalName());
        }
        String serverPrincipal = new KerberosPrincipal(authType.getProtocol() + "/" + authType.getServerId(), KerberosPrincipal.KRB_NT_SRV_HST).getName();
        String serverKeyPattern = this.conf.get(serverKey + ".pattern");
        if (serverKeyPattern != null && !serverKeyPattern.isEmpty()) {
            Pattern pattern = GlobPattern.compile(serverKeyPattern);
            if (!pattern.matcher(serverPrincipal).matches()) {
                throw new IllegalArgumentException(String.format("Server has invalid Kerberos principal: %s, doesn't match the pattern: %s", serverPrincipal, serverKeyPattern));
            }
        } else {
            String confValue = this.conf.get(serverKey);
            String confPrincipal = SecurityUtil.getServerPrincipal(confValue, this.serverAddr.getAddress());
            if (LOG.isDebugEnabled()) {
                LOG.debug("getting serverKey: " + serverKey + " conf value: " + confValue + " principal: " + confPrincipal);
            }
            if (confPrincipal == null || confPrincipal.isEmpty()) {
                throw new IllegalArgumentException("Failed to specify server's Kerberos principal name");
            }
            KerberosName name = new KerberosName(confPrincipal);
            if (name.getHostName() == null) {
                throw new IllegalArgumentException("Kerberos principal name does NOT have the expected hostname part: " + confPrincipal);
            }
            if (!serverPrincipal.equals(confPrincipal)) {
                throw new IllegalArgumentException(String.format("Server has invalid Kerberos principal: %s, expecting: %s", serverPrincipal, confPrincipal));
            }
        }
        return serverPrincipal;
    }

    /**
     * Performs the SASL handshake over the given IPC streams.
     *
     * <p>Sends NEGOTIATE, then processes server replies: NEGOTIATE selects an
     * auth method and answers with INITIATE, CHALLENGE is answered with
     * RESPONSE, and SUCCESS ends the exchange. Selecting SIMPLE during the
     * NEGOTIATE step also ends the exchange.
     *
     * @param ipcStreams the connection's IPC streams
     * @return the negotiated authentication method
     * @throws RemoteException if the server replies with ERROR or FATAL status
     * @throws SaslException on a protocol violation or SASL failure
     * @throws AccessControlException if no server-offered method is usable
     * @throws IOException on I/O failure
     */
    public SaslRpcServer.AuthMethod saslConnect(Client.IpcStreams ipcStreams) throws IOException {
        this.authMethod = SaslRpcServer.AuthMethod.SIMPLE;
        this.sendSaslMessage(ipcStreams.out, negotiateRequest);
        boolean done = false;
        do {
            ByteBuffer bb = ipcStreams.readResponse();
            RpcWritable.Buffer saslPacket = RpcWritable.Buffer.wrap(bb);
            RpcHeaderProtos.RpcResponseHeaderProto header = saslPacket.getValue(RpcHeaderProtos.RpcResponseHeaderProto.getDefaultInstance());
            switch (header.getStatus()) {
                case ERROR:
                case FATAL: {
                    throw new RemoteException(header.getExceptionClassName(), header.getErrorMsg());
                }
                default:
                    break;
            }
            if (header.getCallId() != Server.AuthProtocol.SASL.callId) {
                throw new SaslException("Non-SASL response during negotiation");
            }
            RpcHeaderProtos.RpcSaslProto saslMessage = saslPacket.getValue(RpcHeaderProtos.RpcSaslProto.getDefaultInstance());
            if (saslPacket.remaining() > 0) {
                throw new SaslException("Received malformed response length");
            }
            RpcHeaderProtos.RpcSaslProto.Builder response = null;
            switch (saslMessage.getState()) {
                case NEGOTIATE: {
                    RpcHeaderProtos.RpcSaslProto.SaslAuth saslAuthType = this.selectSaslClient(saslMessage.getAuthsList());
                    this.authMethod = SaslRpcServer.AuthMethod.valueOf(saslAuthType.getMethod());
                    byte[] responseToken = null;
                    if (this.authMethod == SaslRpcServer.AuthMethod.SIMPLE) {
                        done = true;
                    } else {
                        byte[] challengeToken = null;
                        if (saslAuthType.hasChallenge()) {
                            // the server piggy-backed its first challenge on the
                            // negotiate; consume it and don't echo it back
                            challengeToken = saslAuthType.getChallenge().toByteArray();
                            saslAuthType = RpcHeaderProtos.RpcSaslProto.SaslAuth.newBuilder(saslAuthType).clearChallenge().build();
                        } else if (this.saslClient.hasInitialResponse()) {
                            challengeToken = new byte[]{};
                        }
                        responseToken = challengeToken != null ? this.saslClient.evaluateChallenge(challengeToken) : new byte[]{};
                    }
                    response = this.createSaslReply(RpcHeaderProtos.RpcSaslProto.SaslState.INITIATE, responseToken);
                    response.addAuths(saslAuthType);
                    break;
                }
                case CHALLENGE: {
                    if (this.saslClient == null) {
                        throw new SaslException("Server sent unsolicited challenge");
                    }
                    byte[] responseToken = this.saslEvaluateToken(saslMessage, false);
                    response = this.createSaslReply(RpcHeaderProtos.RpcSaslProto.SaslState.RESPONSE, responseToken);
                    break;
                }
                case SUCCESS: {
                    if (this.saslClient == null) {
                        this.authMethod = SaslRpcServer.AuthMethod.SIMPLE;
                    } else {
                        this.saslEvaluateToken(saslMessage, true);
                    }
                    done = true;
                    break;
                }
                default: {
                    throw new SaslException("RPC client doesn't support SASL " + saslMessage.getState());
                }
            }
            if (response != null) {
                this.sendSaslMessage(ipcStreams.out, response.build());
            }
        } while (!done);
        return this.authMethod;
    }

    /**
     * Serializes the SASL header plus {@code message} and writes them to
     * {@code out} as a single flushed packet.
     *
     * @param out the stream to write to
     * @param message the SASL message to send
     * @throws IOException on write failure
     */
    private void sendSaslMessage(OutputStream out, RpcHeaderProtos.RpcSaslProto message) throws IOException {
        if (LOG.isDebugEnabled()) {
            LOG.debug("Sending sasl message " + message);
        }
        ResponseBuffer buf = new ResponseBuffer();
        saslHeader.writeDelimitedTo(buf);
        message.writeDelimitedTo(buf);
        // Build the packet first, then write it while holding the stream's
        // monitor so it is not interleaved with other writers that also lock out.
        synchronized (out) {
            buf.writeTo(out);
            out.flush();
        }
    }

    /**
     * Feeds a server token to the SASL client.
     *
     * @param saslResponse the server message
     * @param serverIsDone true if the server reported SUCCESS
     * @return the client's response token, or null if none
     * @throws SaslException if a challenge has no token, or if the server is
     *         done but the client is incomplete or still produced a response
     */
    private byte[] saslEvaluateToken(RpcHeaderProtos.RpcSaslProto saslResponse, boolean serverIsDone) throws SaslException {
        byte[] saslToken = null;
        if (saslResponse.hasToken()) {
            saslToken = this.saslClient.evaluateChallenge(saslResponse.getToken().toByteArray());
        } else if (!serverIsDone) {
            throw new SaslException("Server challenge contains no token");
        }
        if (serverIsDone) {
            if (!this.saslClient.isComplete()) {
                throw new SaslException("Client is out of sync with server");
            }
            if (saslToken != null) {
                throw new SaslException("Client generated spurious response");
            }
        }
        return saslToken;
    }

    /**
     * Builds a reply with the given state and optional token.
     *
     * @param state the SASL state of the reply
     * @param responseToken the token to include, or null for none
     * @return a builder for the reply
     */
    private RpcHeaderProtos.RpcSaslProto.Builder createSaslReply(RpcHeaderProtos.RpcSaslProto.SaslState state, byte[] responseToken) {
        RpcHeaderProtos.RpcSaslProto.Builder response = RpcHeaderProtos.RpcSaslProto.newBuilder();
        response.setState(state);
        if (responseToken != null) {
            response.setToken(ByteString.copyFrom(responseToken));
        }
        return response;
    }

    /**
     * @return true if a SASL client exists and its negotiated QOP requires
     *         wrapping (i.e. it is set and is not {@code auth})
     */
    private boolean useWrap() {
        if (this.saslClient == null) {
            // SIMPLE auth or already disposed: nothing to wrap with
            return false;
        }
        String qop = (String) this.saslClient.getNegotiatedProperty(Sasl.QOP);
        return qop != null && !"auth".equalsIgnoreCase(qop);
    }

    /**
     * Wraps the connection's input stream if the negotiated QOP requires it.
     *
     * @param in the raw input stream
     * @return an unwrapping stream, or {@code in} unchanged
     */
    public InputStream getInputStream(InputStream in) throws IOException {
        if (this.useWrap()) {
            in = new WrappedInputStream(in);
        }
        return in;
    }

    /**
     * Wraps the connection's output stream if the negotiated QOP requires it.
     * The wrapping stream is buffered to the negotiated raw send size.
     *
     * @param out the raw output stream
     * @return a wrapping stream, or {@code out} unchanged
     */
    public OutputStream getOutputStream(OutputStream out) throws IOException {
        if (this.useWrap()) {
            String maxBuf = (String) this.saslClient.getNegotiatedProperty(Sasl.RAW_SEND_SIZE);
            out = new BufferedOutputStream(new WrappedOutputStream(out), Integer.parseInt(maxBuf));
        }
        return out;
    }

    /**
     * Disposes the SASL client, if any. Safe to call more than once.
     */
    public void dispose() throws SaslException {
        if (this.saslClient != null) {
            this.saslClient.dispose();
            this.saslClient = null;
        }
    }

    /**
     * Supplies the token identifier and password to DIGEST-style SASL
     * mechanisms, and accepts the default realm.
     */
    private static class SaslClientCallbackHandler
    implements CallbackHandler {
        private final String userName;
        private final char[] userPassword;

        public SaslClientCallbackHandler(Token<? extends TokenIdentifier> token) {
            this.userName = SaslRpcServer.encodeIdentifier(token.getIdentifier());
            this.userPassword = SaslRpcServer.encodePassword(token.getPassword());
        }

        @Override
        public void handle(Callback[] callbacks) throws UnsupportedCallbackException {
            NameCallback nc = null;
            PasswordCallback pc = null;
            RealmCallback rc = null;
            for (Callback callback : callbacks) {
                if (callback instanceof RealmChoiceCallback) {
                    // realm choice is ignored
                    continue;
                }
                if (callback instanceof NameCallback) {
                    nc = (NameCallback) callback;
                } else if (callback instanceof PasswordCallback) {
                    pc = (PasswordCallback) callback;
                } else if (callback instanceof RealmCallback) {
                    rc = (RealmCallback) callback;
                } else {
                    throw new UnsupportedCallbackException(callback, "Unrecognized SASL client callback");
                }
            }
            if (nc != null) {
                if (LOG.isDebugEnabled()) {
                    LOG.debug("SASL client callback: setting username: " + this.userName);
                }
                nc.setName(this.userName);
            }
            if (pc != null) {
                if (LOG.isDebugEnabled()) {
                    LOG.debug("SASL client callback: setting userPassword");
                }
                pc.setPassword(this.userPassword);
            }
            if (rc != null) {
                if (LOG.isDebugEnabled()) {
                    LOG.debug("SASL client callback: setting realm: " + rc.getDefaultText());
                }
                rc.setText(rc.getDefaultText());
            }
        }
    }

    /**
     * Output stream that SASL-wraps each write and sends it as a WRAP message.
     */
    class WrappedOutputStream
    extends FilterOutputStream {
        public WrappedOutputStream(OutputStream out) throws IOException {
            super(out);
        }

        /**
         * Wraps a single byte. Without this override FilterOutputStream would
         * pass the byte to the underlying stream unwrapped.
         */
        @Override
        public void write(int b) throws IOException {
            this.write(new byte[]{(byte) b}, 0, 1);
        }

        @Override
        public void write(byte[] buf, int off, int len) throws IOException {
            if (LOG.isDebugEnabled()) {
                LOG.debug("wrapping token of length:" + len);
            }
            byte[] wrapped = SaslRpcClient.this.saslClient.wrap(buf, off, len);
            RpcHeaderProtos.RpcSaslProto saslMessage = RpcHeaderProtos.RpcSaslProto.newBuilder().setState(RpcHeaderProtos.RpcSaslProto.SaslState.WRAP).setToken(ByteString.copyFrom(wrapped, 0, wrapped.length)).build();
            SaslRpcClient.this.sendSaslMessage(this.out, saslMessage);
        }
    }

    /**
     * Input stream that reads length-prefixed RPC packets carrying SASL WRAP
     * messages and returns the unwrapped bytes.
     */
    class WrappedInputStream
    extends FilterInputStream {
        private ByteBuffer unwrappedRpcBuffer;

        public WrappedInputStream(InputStream in) throws IOException {
            super(in);
            this.unwrappedRpcBuffer = ByteBuffer.allocate(0);
        }

        @Override
        public int read() throws IOException {
            byte[] b = new byte[1];
            int n = this.read(b, 0, 1);
            // mask so data byte 0xFF is 255 rather than the EOF marker -1
            return n != -1 ? b[0] & 0xFF : -1;
        }

        @Override
        public int read(byte[] b) throws IOException {
            return this.read(b, 0, b.length);
        }

        @Override
        public synchronized int read(byte[] buf, int off, int len) throws IOException {
            if (len == 0) {
                return 0;
            }
            // skip empty unwrapped packets so a non-zero read never returns 0
            while (this.unwrappedRpcBuffer.remaining() == 0) {
                this.readNextRpcPacket();
            }
            int readLen = Math.min(len, this.unwrappedRpcBuffer.remaining());
            this.unwrappedRpcBuffer.get(buf, off, readLen);
            return readLen;
        }

        /**
         * Reads one length-prefixed packet and unwraps its SASL WRAP token
         * into {@code unwrappedRpcBuffer}.
         *
         * @throws SaslException if the length is negative or the packet is
         *         not a SASL WRAP message
         */
        private void readNextRpcPacket() throws IOException {
            LOG.debug("reading next wrapped RPC packet");
            DataInputStream dis = new DataInputStream(this.in);
            int rpcLen = dis.readInt();
            if (rpcLen < 0) {
                throw new SaslException("Invalid wrapped RPC packet length: " + rpcLen);
            }
            byte[] rpcBuf = new byte[rpcLen];
            dis.readFully(rpcBuf);
            ByteArrayInputStream bis = new ByteArrayInputStream(rpcBuf);
            RpcHeaderProtos.RpcResponseHeaderProto.Builder headerBuilder = RpcHeaderProtos.RpcResponseHeaderProto.newBuilder();
            headerBuilder.mergeDelimitedFrom(bis);
            boolean isWrapped = false;
            if (headerBuilder.getCallId() == Server.AuthProtocol.SASL.callId) {
                RpcHeaderProtos.RpcSaslProto.Builder saslMessage = RpcHeaderProtos.RpcSaslProto.newBuilder();
                saslMessage.mergeDelimitedFrom(bis);
                if (saslMessage.getState() == RpcHeaderProtos.RpcSaslProto.SaslState.WRAP) {
                    isWrapped = true;
                    byte[] token = saslMessage.getToken().toByteArray();
                    if (LOG.isDebugEnabled()) {
                        LOG.debug("unwrapping token of length:" + token.length);
                    }
                    token = SaslRpcClient.this.saslClient.unwrap(token, 0, token.length);
                    this.unwrappedRpcBuffer = ByteBuffer.wrap(token);
                }
            }
            if (!isWrapped) {
                throw new SaslException("Server sent non-wrapped response");
            }
        }
    }
}

