/*
 * Decompiled with CFR 0.152.
 */
package com.couchbase.client.core.io.netty.kv;

import com.couchbase.client.core.cnc.events.io.SaslAuthenticationCompletedEvent;
import com.couchbase.client.core.cnc.events.io.SaslAuthenticationFailedEvent;
import com.couchbase.client.core.cnc.events.io.SaslMechanismsSelectedEvent;
import com.couchbase.client.core.deps.com.fasterxml.jackson.core.type.TypeReference;
import com.couchbase.client.core.deps.io.netty.buffer.ByteBuf;
import com.couchbase.client.core.deps.io.netty.buffer.Unpooled;
import com.couchbase.client.core.deps.io.netty.channel.ChannelDuplexHandler;
import com.couchbase.client.core.deps.io.netty.channel.ChannelHandlerContext;
import com.couchbase.client.core.deps.io.netty.channel.ChannelPromise;
import com.couchbase.client.core.deps.io.netty.util.ReferenceCountUtil;
import com.couchbase.client.core.deps.io.netty.util.concurrent.Future;
import com.couchbase.client.core.deps.io.netty.util.concurrent.GenericFutureListener;
import com.couchbase.client.core.endpoint.EndpointContext;
import com.couchbase.client.core.env.SaslMechanism;
import com.couchbase.client.core.error.AuthenticationFailureException;
import com.couchbase.client.core.error.context.KeyValueIoErrorContext;
import com.couchbase.client.core.io.IoContext;
import com.couchbase.client.core.io.netty.kv.ConnectTimings;
import com.couchbase.client.core.io.netty.kv.MemcacheProtocol;
import com.couchbase.client.core.io.netty.kv.sasl.CouchbaseSaslClientFactory;
import com.couchbase.client.core.json.Mapper;
import com.couchbase.client.core.msg.kv.BaseKeyValueRequest;
import com.couchbase.client.core.util.Bytes;
import java.net.SocketAddress;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Arrays;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;

/**
 * Performs KV SASL authentication on a freshly connected channel before the rest of the
 * pipeline observes it as active.
 *
 * <p>The handler intercepts the outbound connect promise and only completes it once the SASL
 * exchange has finished. On {@link #channelActive} it sends {@code SASL_LIST_MECHS}, picks the
 * mechanisms that are both offered by the server and allowed by configuration, sends
 * {@code SASL_AUTH} and, while the SASL client is not yet complete, a {@code SASL_STEP}. On
 * success the intercepted promise is completed, this handler removes itself from the pipeline
 * and {@code channelActive} is propagated. On failure (including the connect timeout taken
 * from the environment's timeout config) the intercepted promise is failed.
 *
 * <p>The handler also serves as the {@link CallbackHandler} for the created {@link SaslClient},
 * supplying the configured username and password.
 */
public class SaslAuthenticationHandler
extends ChannelDuplexHandler
implements CallbackHandler {
    /** Response status the server uses to reject the credentials. */
    private static final short STATUS_AUTH_ERROR = 32;
    /** Response status signalling that the SASL exchange needs another round trip. */
    private static final short STATUS_AUTH_CONTINUE = 33;
    private final Duration timeout;
    private final String username;
    private final String password;
    private final Set<SaslMechanism> allowedMechanisms;
    private final EndpointContext endpointContext;
    private IoContext ioContext;
    private SaslClient saslClient;
    /** The caller's connect promise; completed only after SASL has finished (or failed). */
    private ChannelPromise interceptedConnectPromise;

    /**
     * Creates a new handler.
     *
     * @param endpointContext the endpoint context; its environment supplies the connect timeout and event bus.
     * @param username the username handed to the SASL client via {@link NameCallback}.
     * @param password the password handed to the SASL client via {@link PasswordCallback}.
     * @param allowedSaslMechanisms the mechanisms this client is willing to negotiate.
     */
    public SaslAuthenticationHandler(EndpointContext endpointContext, String username, String password, Set<SaslMechanism> allowedSaslMechanisms) {
        this.endpointContext = endpointContext;
        this.username = username;
        this.password = password;
        this.allowedMechanisms = allowedSaslMechanisms;
        this.timeout = endpointContext.environment().timeoutConfig().connectTimeout();
    }

    /**
     * Intercepts the connect promise so it is only completed after authentication.
     * A fresh promise is passed downstream; only a failed TCP connect is forwarded to the
     * intercepted promise from here.
     */
    @Override
    public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
        this.interceptedConnectPromise = promise;
        ChannelPromise downstream = ctx.newPromise();
        downstream.addListener(f -> {
            if (!f.isSuccess() && !this.interceptedConnectPromise.isDone()) {
                ConnectTimings.record(ctx.channel(), this.getClass());
                this.interceptedConnectPromise.tryFailure(f.cause());
            }
        });
        ctx.connect(remoteAddress, localAddress, downstream);
    }

    /**
     * Starts the SASL negotiation once the socket is connected. It also schedules a timeout
     * that fails the intercepted promise if negotiation has not finished by then.
     */
    @Override
    public void channelActive(ChannelHandlerContext ctx) {
        this.ioContext = new IoContext(this.endpointContext, ctx.channel().localAddress(), ctx.channel().remoteAddress(), this.endpointContext.bucket());
        ctx.executor().schedule(() -> {
            // No-op if negotiation already succeeded or failed.
            if (!this.interceptedConnectPromise.isDone()) {
                ConnectTimings.stop(ctx.channel(), this.getClass(), true);
                this.interceptedConnectPromise.tryFailure(new TimeoutException("KV SASL Negotiation timed out after " + this.timeout.toMillis() + "ms"));
            }
        }, this.timeout.toNanos(), TimeUnit.NANOSECONDS);
        ConnectTimings.start(ctx.channel(), this.getClass());
        ctx.writeAndFlush(this.buildListMechanismsRequest(ctx));
    }

    /** Builds the {@code SASL_LIST_MECHS} request that opens the negotiation. */
    private ByteBuf buildListMechanismsRequest(ChannelHandlerContext ctx) {
        return MemcacheProtocol.request(ctx.alloc(), MemcacheProtocol.Opcode.SASL_LIST_MECHS, MemcacheProtocol.noDatatype(), MemcacheProtocol.noPartition(), BaseKeyValueRequest.nextOpaque(), MemcacheProtocol.noCas(), MemcacheProtocol.noExtras(), MemcacheProtocol.noKey(), MemcacheProtocol.noBody());
    }

    /**
     * Dispatches a server response to the matching SASL step, or fails the connect.
     * The inbound message is always released, even if failure handling throws.
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) {
        try {
            if (!(msg instanceof ByteBuf)) {
                this.failConnect(ctx, "Unexpected response type on channel read, this is a bug - please report. " + msg, null, null, (short)0);
                return;
            }
            ByteBuf response = (ByteBuf)msg;
            short status = MemcacheProtocol.status(response);
            if (MemcacheProtocol.successful(response) || status == STATUS_AUTH_CONTINUE) {
                byte opcode = MemcacheProtocol.opcode(response);
                try {
                    if (MemcacheProtocol.Opcode.SASL_LIST_MECHS.opcode() == opcode) {
                        this.handleListMechsResponse(ctx, response);
                    } else if (MemcacheProtocol.Opcode.SASL_AUTH.opcode() == opcode) {
                        this.handleAuthResponse(ctx, response);
                    } else if (MemcacheProtocol.Opcode.SASL_STEP.opcode() == opcode) {
                        this.completeAuth(ctx);
                    }
                    // Responses with any other opcode are ignored.
                }
                catch (Exception ex) {
                    this.failConnect(ctx, "Unexpected error during SASL auth", response, ex, status);
                }
            } else if (status == STATUS_AUTH_ERROR) {
                this.failConnect(ctx, "Authentication Failure", response, null, status);
            } else {
                this.failConnect(ctx, "Unexpected Status 0x" + Integer.toHexString(status) + " during SASL auth", response, null, status);
            }
        }
        finally {
            ReferenceCountUtil.release(msg);
        }
    }

    /**
     * Selects the mechanisms offered by the server (space-separated body) that are also
     * allowed, creates the SASL client and sends the {@code SASL_AUTH} request.
     */
    private void handleListMechsResponse(ChannelHandlerContext ctx, ByteBuf response) {
        String[] serverMechanisms = MemcacheProtocol.body(response).orElse(Unpooled.EMPTY_BUFFER).toString(StandardCharsets.UTF_8).split(" ");
        Set<String> offeredByServer = Stream.of(serverMechanisms).collect(Collectors.toSet());
        Set<SaslMechanism> usedMechs = this.allowedMechanisms.stream()
            .filter(m -> offeredByServer.contains(m.mech()))
            .collect(Collectors.toSet());
        if (usedMechs.isEmpty()) {
            this.failConnect(ctx, "Could not negotiate SASL mechanism with server. If you are using LDAP you must either connect via TLS (recommended), or enable PLAIN to the allowed SASL mechanism list on the PasswordAuthenticator (this is insecure and will present the user credentials in plain-text over the wire).", response, null, MemcacheProtocol.status(response));
            return;
        }
        try {
            this.saslClient = this.createSaslClient(usedMechs);
            if (this.saslClient == null) {
                // A SaslClientFactory may return null when it cannot serve any requested mechanism.
                throw new SaslException("No SASL client available for mechanisms " + usedMechs);
            }
            this.endpointContext.environment().eventBus().publish(new SaslMechanismsSelectedEvent(this.ioContext, Stream.of(serverMechanisms).map(SaslMechanism::from).collect(Collectors.toSet()), this.allowedMechanisms, SaslMechanism.from(this.saslClient.getMechanismName())));
            ctx.writeAndFlush(this.buildAuthRequest(ctx));
        }
        catch (SaslException e) {
            this.failConnect(ctx, "SASL Client could not be constructed. Server Mechanisms: " + Arrays.toString(serverMechanisms), response, e, MemcacheProtocol.status(response));
        }
    }

    /**
     * Builds the {@code SASL_AUTH} request: the key is the mechanism name; the body is the
     * client's initial response if it has one, otherwise empty.
     */
    private ByteBuf buildAuthRequest(ChannelHandlerContext ctx) throws SaslException {
        byte[] payload = this.saslClient.hasInitialResponse() ? this.saslClient.evaluateChallenge(Bytes.EMPTY_BYTE_ARRAY) : null;
        ByteBuf body = payload != null ? ctx.alloc().buffer().writeBytes(payload) : Unpooled.EMPTY_BUFFER;
        ByteBuf key = Unpooled.copiedBuffer(this.saslClient.getMechanismName(), StandardCharsets.UTF_8);
        ByteBuf request = MemcacheProtocol.request(ctx.alloc(), MemcacheProtocol.Opcode.SASL_AUTH, MemcacheProtocol.noDatatype(), MemcacheProtocol.noPartition(), BaseKeyValueRequest.nextOpaque(), MemcacheProtocol.noCas(), MemcacheProtocol.noExtras(), key, body);
        key.release();
        body.release();
        return request;
    }

    /** Creates a SASL client for the given mechanisms, using this handler as callback handler. */
    private SaslClient createSaslClient(Set<SaslMechanism> selected) throws SaslException {
        String[] mechanisms = selected.stream().map(SaslMechanism::mech).toArray(String[]::new);
        return new CouchbaseSaslClientFactory().createSaslClient(mechanisms, null, "couchbase", this.ioContext.remoteSocket().toString(), null, this);
    }

    /**
     * Handles the {@code SASL_AUTH} response. Completes if the client is already done;
     * otherwise evaluates the server challenge and sends a {@code SASL_STEP}.
     */
    private void handleAuthResponse(ChannelHandlerContext ctx, ByteBuf response) {
        if (this.saslClient.isComplete()) {
            this.completeAuth(ctx);
            return;
        }
        ByteBuf responseBody = MemcacheProtocol.body(response).orElse(Unpooled.EMPTY_BUFFER);
        byte[] payload = new byte[responseBody.readableBytes()];
        responseBody.readBytes(payload);
        try {
            byte[] evaluatedBytes = this.saslClient.evaluateChallenge(payload);
            if (evaluatedBytes == null || evaluatedBytes.length <= 0) {
                throw new SaslException("Evaluation returned empty payload, this is unexpected!");
            }
            ctx.writeAndFlush(this.buildStepRequest(ctx, evaluatedBytes));
        }
        catch (SaslException e) {
            this.failConnect(ctx, "Failure while evaluating SASL Auth Response.", response, e, MemcacheProtocol.status(response));
        }
    }

    /**
     * Builds the {@code SASL_STEP} request. For PLAIN the body is
     * {@code username + NUL + second space-separated token of the evaluated bytes};
     * otherwise the evaluated bytes are sent as-is.
     */
    private ByteBuf buildStepRequest(ChannelHandlerContext ctx, byte[] evaluatedBytes) {
        ByteBuf body;
        String mech = this.saslClient.getMechanismName();
        if (mech.equalsIgnoreCase(SaslMechanism.PLAIN.mech())) {
            String[] evaluated = new String(evaluatedBytes, StandardCharsets.UTF_8).split(" ");
            body = Unpooled.copiedBuffer(this.username + "\u0000" + evaluated[1], StandardCharsets.UTF_8);
        } else {
            body = Unpooled.wrappedBuffer(evaluatedBytes);
        }
        ByteBuf key = Unpooled.copiedBuffer(mech, StandardCharsets.UTF_8);
        ByteBuf request = MemcacheProtocol.request(ctx.alloc(), MemcacheProtocol.Opcode.SASL_STEP, MemcacheProtocol.noDatatype(), MemcacheProtocol.noPartition(), BaseKeyValueRequest.nextOpaque(), MemcacheProtocol.noCas(), MemcacheProtocol.noExtras(), key, body);
        key.release();
        body.release();
        return request;
    }

    /**
     * Marks authentication as successful. Publishes the completion event, completes the
     * intercepted promise, removes this handler and propagates {@code channelActive}.
     */
    private void completeAuth(ChannelHandlerContext ctx) {
        Optional<Duration> latency = ConnectTimings.stop(ctx.channel(), this.getClass(), false);
        this.endpointContext.environment().eventBus().publish(new SaslAuthenticationCompletedEvent(latency.orElse(Duration.ZERO), this.ioContext));
        this.interceptedConnectPromise.trySuccess();
        ctx.pipeline().remove(this);
        ctx.fireChannelActive();
    }

    /**
     * Fails the intercepted connect promise with an {@link AuthenticationFailureException}
     * and publishes a {@link SaslAuthenticationFailedEvent}.
     *
     * @param lastPacket the offending response, or null. If it is a valid response, its body
     *     is parsed (best-effort) as a JSON server context; otherwise its raw bytes are
     *     attached to the event.
     */
    private void failConnect(ChannelHandlerContext ctx, String message, ByteBuf lastPacket, Throwable cause, short status) {
        Optional<Duration> latency = ConnectTimings.stop(ctx.channel(), this.getClass(), false);
        byte[] packetCopy = Bytes.EMPTY_BYTE_ARRAY;
        Map<String, Object> serverContext = null;
        if (lastPacket != null) {
            if (MemcacheProtocol.verifyResponse(lastPacket)) {
                Optional<ByteBuf> body = MemcacheProtocol.body(lastPacket);
                if (body.isPresent()) {
                    byte[] content = new byte[body.get().readableBytes()];
                    body.get().readBytes(content);
                    try {
                        serverContext = Mapper.decodeInto(content, new TypeReference<Map<String, Object>>(){});
                    }
                    catch (Exception ignored) {
                        // Best-effort: the body is not JSON, so there is no server context to attach.
                    }
                }
            } else {
                // Malformed response: copy its readable bytes without moving the buffer's indices.
                packetCopy = new byte[lastPacket.readableBytes()];
                lastPacket.getBytes(lastPacket.readerIndex(), packetCopy);
            }
        }
        KeyValueIoErrorContext errorContext = new KeyValueIoErrorContext(MemcacheProtocol.decodeStatus(status), this.endpointContext, serverContext);
        this.endpointContext.environment().eventBus().publish(new SaslAuthenticationFailedEvent(latency.orElse(Duration.ZERO), errorContext, message, packetCopy));
        this.interceptedConnectPromise.tryFailure(new AuthenticationFailureException(message, errorContext, cause));
    }

    /**
     * Supplies credentials to the SASL client.
     *
     * @throws UnsupportedCallbackException for any callback other than name or password.
     */
    @Override
    public void handle(Callback[] callbacks) throws UnsupportedCallbackException {
        for (Callback callback : callbacks) {
            if (callback instanceof NameCallback) {
                ((NameCallback)callback).setName(this.username);
            } else if (callback instanceof PasswordCallback) {
                ((PasswordCallback)callback).setPassword(this.password.toCharArray());
            } else {
                throw new UnsupportedCallbackException(callback, "Unexpected/Unsupported Callback");
            }
        }
    }
}

