package org.eclipse.milo.opcua.stack.server.transport.uasc;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.util.AttributeKey;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.Timeout;
import java.io.IOException;
import java.security.KeyPair;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import org.eclipse.milo.opcua.stack.core.Stack;
import org.eclipse.milo.opcua.stack.core.UaException;
import org.eclipse.milo.opcua.stack.core.UaSerializationException;
import org.eclipse.milo.opcua.stack.core.channel.ChannelSecurity;
import org.eclipse.milo.opcua.stack.core.channel.ChunkDecoder;
import org.eclipse.milo.opcua.stack.core.channel.ChunkEncoder;
import org.eclipse.milo.opcua.stack.core.channel.ExceptionHandler;
import org.eclipse.milo.opcua.stack.core.channel.MessageAbortException;
import org.eclipse.milo.opcua.stack.core.channel.MessageDecodeException;
import org.eclipse.milo.opcua.stack.core.channel.MessageEncodeException;
import org.eclipse.milo.opcua.stack.core.channel.SerializationQueue;
import org.eclipse.milo.opcua.stack.core.channel.ServerSecureChannel;
import org.eclipse.milo.opcua.stack.core.channel.headers.AsymmetricSecurityHeader;
import org.eclipse.milo.opcua.stack.core.channel.headers.HeaderDecoder;
import org.eclipse.milo.opcua.stack.core.channel.messages.ErrorMessage;
import org.eclipse.milo.opcua.stack.core.channel.messages.MessageType;
import org.eclipse.milo.opcua.stack.core.security.CertificateManager;
import org.eclipse.milo.opcua.stack.core.security.SecurityPolicy;
import org.eclipse.milo.opcua.stack.core.transport.TransportProfile;
import org.eclipse.milo.opcua.stack.core.types.builtin.ByteString;
import org.eclipse.milo.opcua.stack.core.types.builtin.DateTime;
import org.eclipse.milo.opcua.stack.core.types.builtin.DiagnosticInfo;
import org.eclipse.milo.opcua.stack.core.types.builtin.ExtensionObject;
import org.eclipse.milo.opcua.stack.core.types.builtin.StatusCode;
import org.eclipse.milo.opcua.stack.core.types.builtin.unsigned.Unsigned;
import org.eclipse.milo.opcua.stack.core.types.enumerated.SecurityTokenRequestType;
import org.eclipse.milo.opcua.stack.core.types.structured.ChannelSecurityToken;
import org.eclipse.milo.opcua.stack.core.types.structured.EndpointDescription;
import org.eclipse.milo.opcua.stack.core.types.structured.OpenSecureChannelRequest;
import org.eclipse.milo.opcua.stack.core.types.structured.OpenSecureChannelResponse;
import org.eclipse.milo.opcua.stack.core.types.structured.ResponseHeader;
import org.eclipse.milo.opcua.stack.core.util.BufferUtil;
import org.eclipse.milo.opcua.stack.core.util.EndpointUtil;
import org.eclipse.milo.opcua.stack.core.util.NonceUtil;
import org.eclipse.milo.opcua.stack.server.UaStackServer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Server-side Netty handler for the asymmetric part of UA Secure Conversation.
 *
 * <p>It accepts only OpenSecureChannel and CloseSecureChannel messages; any other message type is
 * rejected with a {@link UaException}. OpenSecureChannel messages may span several chunks; the
 * chunks are retained until the final chunk arrives. The chunks are then handed to the
 * {@link SerializationQueue} for decoding, and an {@link OpenSecureChannelResponse} is sent back.
 *
 * <p>After the first successful response, a {@link UascServerSymmetricHandler} is installed
 * in the pipeline immediately before this handler. Each issued or renewed token schedules a
 * timeout that closes the channel unless it is cancelled by a later renewal, a
 * CloseSecureChannel message, or channel inactivity.
 */
public class UascServerAsymmetricHandler extends ByteToMessageDecoder implements HeaderDecoder {

    /** Channel attribute holding the endpoint matched when a secure channel is issued. */
    static final AttributeKey<EndpointDescription> ENDPOINT_KEY = AttributeKey.valueOf("endpoint");

    /** Message type (3 bytes) + chunk type (1 byte) + message size (4 bytes). */
    private static final int HEADER_LENGTH = 8;

    /** OPC UA status code Bad_TcpMessageTypeInvalid. */
    private static final long BAD_TCP_MESSAGE_TYPE_INVALID = 0x807E0000L;

    /** OPC UA status code Bad_SecurityChecksFailed. */
    private static final long BAD_SECURITY_CHECKS_FAILED = 0x80130000L;

    /** OPC UA status code Bad_TcpSecureChannelUnknown. */
    private static final long BAD_TCP_SECURE_CHANNEL_UNKNOWN = 0x807F0000L;

    /** OPC UA status code Bad_TcpMessageTooLarge. */
    private static final long BAD_TCP_MESSAGE_TOO_LARGE = 0x80800000L;

    /** OPC UA status code Bad_ResponseTooLarge. */
    private static final long BAD_RESPONSE_TOO_LARGE = 0x80B90000L;

    private ServerSecureChannel secureChannel;
    private Timeout secureChannelTimeout;
    private final int maxChunkCount;
    private final int maxChunkSize;
    private final UaStackServer stackServer;
    private final TransportProfile transportProfile;
    private final SerializationQueue serializationQueue;
    private final Logger logger = LoggerFactory.getLogger(getClass());
    private boolean symmetricHandlerAdded = false;

    /** Retained chunks of the OpenSecureChannel message currently being reassembled. */
    private List<ByteBuf> chunkBuffers = new ArrayList<>();

    /** Security header of the first chunk; every later chunk of the same message must match it. */
    private final AtomicReference<AsymmetricSecurityHeader> headerRef = new AtomicReference<>();

    /**
     * @param uaStackServer      the owning stack server (configuration, endpoints, id sources).
     * @param transportProfile   the transport profile this handler serves; used for endpoint matching.
     * @param serializationQueue queue used to decode requests and encode responses; its parameters
     *                           supply the local max chunk count and receive buffer size.
     */
    public UascServerAsymmetricHandler(
        UaStackServer uaStackServer,
        TransportProfile transportProfile,
        SerializationQueue serializationQueue) {

        this.stackServer = uaStackServer;
        this.transportProfile = transportProfile;
        this.serializationQueue = serializationQueue;
        this.maxChunkCount = serializationQueue.getParameters().getLocalMaxChunkCount();
        this.maxChunkSize = serializationQueue.getParameters().getLocalReceiveBufferSize();
    }

    /**
     * Waits for one complete chunk, then dispatches it by message type.
     *
     * @throws UaException if the message type is neither OpenSecureChannel nor CloseSecureChannel.
     */
    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf buffer, List<Object> out) throws Exception {
        if (buffer.readableBytes() < HEADER_LENGTH) {
            return;
        }

        int messageLength = getMessageLength(buffer, maxChunkSize);
        if (buffer.readableBytes() < messageLength) {
            return;
        }

        MessageType messageType = MessageType.fromMediumInt(buffer.getMediumLE(buffer.readerIndex()));

        switch (messageType) {
            case OpenSecureChannel:
                onOpenSecureChannel(ctx, buffer.readSlice(messageLength));
                break;

            case CloseSecureChannel:
                logger.debug("Received CloseSecureChannelRequest");
                buffer.skipBytes(messageLength);
                cancelSecureChannelTimeout();
                ctx.close();
                break;

            default:
                throw new UaException(BAD_TCP_MESSAGE_TYPE_INVALID, "unexpected MessageType: " + messageType);
        }
    }

    /**
     * Handles one OpenSecureChannel chunk.
     *
     * <p>An abort chunk ('A') discards the partially received message. Any other chunk is
     * validated and retained. A final chunk ('F') triggers decoding of the whole message.
     *
     * @param chunk a slice containing exactly one chunk, header included.
     * @throws UaException if the security header changes between chunks, the secure channel id is
     *                     unknown, certificate checks fail, or chunk size/count limits are exceeded.
     */
    private void onOpenSecureChannel(ChannelHandlerContext ctx, ByteBuf chunk) throws UaException {
        chunk.skipBytes(3); // message type, already inspected by decode()
        char chunkType = (char) chunk.readByte();

        if (chunkType == 'A') {
            discardChunkBuffers();
            return;
        }

        chunk.skipBytes(4); // message size, already validated by getMessageLength()
        long secureChannelId = chunk.readUnsignedIntLE();

        AsymmetricSecurityHeader header =
            AsymmetricSecurityHeader.decode(chunk, stackServer.getConfig().getEncodingLimits());

        if (!headerRef.compareAndSet(null, header) && !header.equals(headerRef.get())) {
            throw new UaException(BAD_SECURITY_CHECKS_FAILED, "subsequent AsymmetricSecurityHeader did not match");
        }

        // An id of 0 requests a new channel; any other id must name the channel already open here.
        if (secureChannelId != 0
            && (secureChannel == null || secureChannelId != secureChannel.getChannelId())) {

            throw new UaException(BAD_TCP_SECURE_CHANNEL_UNKNOWN, "unknown secure channel id: " + secureChannelId);
        }

        if (secureChannel == null) {
            secureChannel = createSecureChannel(header);
        }

        // Rewind so the retained chunk includes its header; the chunk decoder needs the full chunk.
        if (chunk.readerIndex(0).readableBytes() > maxChunkSize) {
            throw new UaException(
                BAD_TCP_MESSAGE_TOO_LARGE,
                String.format("max chunk size exceeded (%s)", maxChunkSize));
        }

        chunkBuffers.add(chunk.retain());

        if (maxChunkCount > 0 && chunkBuffers.size() > maxChunkCount) {
            throw new UaException(
                BAD_TCP_MESSAGE_TOO_LARGE,
                String.format("max chunk count exceeded (%s)", maxChunkCount));
        }

        if (chunkType == 'F') {
            // Hand the completed chunk list to the queue and start a fresh one for the next message.
            List<ByteBuf> buffersToDecode = chunkBuffers;
            chunkBuffers = new ArrayList<>();
            headerRef.set(null);

            serializationQueue.decode((binaryDecoder, chunkDecoder) -> {
                ChunkDecoder.DecodedMessage decodedMessage;

                try {
                    decodedMessage = chunkDecoder.decodeAsymmetric(secureChannel, buffersToDecode);
                } catch (MessageAbortException e) {
                    logger.warn(
                        "Received message abort chunk; error={}, reason={}",
                        e.getStatusCode(), e.getMessage());
                    return;
                } catch (MessageDecodeException e) {
                    logger.error("Error decoding asymmetric message", e);
                    ctx.close();
                    return;
                }

                ByteBuf message = decodedMessage.getMessage();
                long requestId = decodedMessage.getRequestId();

                try {
                    OpenSecureChannelRequest request = (OpenSecureChannelRequest) binaryDecoder
                        .setBuffer(message)
                        .readMessage((String) null);

                    logger.debug(
                        "Received OpenSecureChannelRequest ({}, id={}).",
                        request.getRequestType(), secureChannelId);

                    sendOpenSecureChannelResponse(ctx, requestId, request);
                } catch (Throwable t) {
                    logger.error("Error decoding OpenSecureChannelRequest", t);
                    ctx.close();
                } finally {
                    message.release();
                    buffersToDecode.clear();
                }
            });
        }
    }

    /**
     * Builds a new server secure channel from the first chunk's security header.
     *
     * <p>The channel is returned only when fully initialised. On failure the handler's
     * {@code secureChannel} field is left untouched.
     *
     * @throws UaException if the security policy is unknown, the client certificate chain fails
     *                     validation, or no local certificate/key pair matches the receiver thumbprint.
     */
    private ServerSecureChannel createSecureChannel(AsymmetricSecurityHeader header) throws UaException {
        ServerSecureChannel channel = new ServerSecureChannel();
        channel.setChannelId(stackServer.getNextChannelId());

        SecurityPolicy securityPolicy = SecurityPolicy.fromUri(header.getSecurityPolicyUri());
        channel.setSecurityPolicy(securityPolicy);

        if (securityPolicy != SecurityPolicy.None) {
            channel.setRemoteCertificate(header.getSenderCertificate().bytesOrEmpty());

            stackServer.getConfig().getCertificateValidator()
                .validateCertificateChain(channel.getRemoteCertificateChain());

            CertificateManager certificateManager = stackServer.getConfig().getCertificateManager();
            ByteString thumbprint = header.getReceiverThumbprint();

            Optional<X509Certificate[]> localChain = certificateManager.getCertificateChain(thumbprint);
            Optional<KeyPair> keyPair = certificateManager.getKeyPair(thumbprint);

            if (!localChain.isPresent() || !keyPair.isPresent() || localChain.get().length == 0) {
                throw new UaException(BAD_SECURITY_CHECKS_FAILED, "no certificate for provided thumbprint");
            }

            X509Certificate[] chain = localChain.get();
            channel.setLocalCertificate(chain[0]);
            channel.setLocalCertificateChain(chain);
            channel.setKeyPair(keyPair.get());
        }

        return channel;
    }

    /**
     * Encodes and writes the OpenSecureChannelResponse for {@code request} via the serialization queue.
     *
     * <p>The first successful response installs a {@link UascServerSymmetricHandler} ahead of this
     * handler. Encoding and serialization failures are fired down the pipeline. Token installation
     * failures close the channel. The temporary message buffer is always released.
     */
    private void sendOpenSecureChannelResponse(
        ChannelHandlerContext ctx,
        long requestId,
        OpenSecureChannelRequest request) {

        serializationQueue.encode((binaryEncoder, chunkEncoder) -> {
            ByteBuf messageBuffer = BufferUtil.pooledBuffer();

            try {
                OpenSecureChannelResponse response = openSecureChannel(ctx, request);

                binaryEncoder.setBuffer(messageBuffer);
                binaryEncoder.writeMessage((String) null, response);

                checkMessageSize(messageBuffer);

                ChunkEncoder.EncodedMessage encodedMessage = chunkEncoder.encodeAsymmetric(
                    secureChannel,
                    requestId,
                    messageBuffer,
                    MessageType.OpenSecureChannel
                );

                if (!symmetricHandlerAdded) {
                    UascServerSymmetricHandler symmetricHandler =
                        new UascServerSymmetricHandler(stackServer, serializationQueue, secureChannel);

                    ctx.pipeline().addBefore(ctx.name(), null, symmetricHandler);
                    symmetricHandlerAdded = true;
                }

                CompositeByteBuf chunkComposite = BufferUtil.compositeBuffer();

                for (ByteBuf messageChunk : encodedMessage.getMessageChunks()) {
                    chunkComposite.addComponent(messageChunk);
                    chunkComposite.writerIndex(chunkComposite.writerIndex() + messageChunk.readableBytes());
                }

                ctx.writeAndFlush(chunkComposite, ctx.voidPromise());

                logger.debug("Sent OpenSecureChannelResponse.");
            } catch (MessageEncodeException e) {
                // Listed before UaException so it is not swallowed by the broader handler.
                logger.error("Error encoding OpenSecureChannelResponse: {}", e.getMessage(), e);
                ctx.fireExceptionCaught(e);
            } catch (UaSerializationException e) {
                logger.error("Error serializing OpenSecureChannelResponse: {}", e.getMessage(), e);
                ctx.fireExceptionCaught(e);
            } catch (UaException e) {
                logger.error("Error installing security token: {}", e.getStatusCode(), e);
                ctx.close();
            } finally {
                messageBuffer.release();
            }
        });
    }

    /**
     * Issues or renews the channel's security token and builds the response.
     *
     * <p>On Issue, it records the requested security mode and resolves the matching endpoint into
     * {@link #ENDPOINT_KEY}. On Renew, it rejects a change of security mode. The granted lifetime
     * is the requested lifetime clamped to the configured [minimum, maximum] range. The previous
     * keys and token are kept alongside the new ones.
     *
     * @throws UaException if no endpoint matches, a renewal changes the security mode, or nonce
     *                     validation / key generation fails.
     */
    private OpenSecureChannelResponse openSecureChannel(
        ChannelHandlerContext ctx,
        OpenSecureChannelRequest request) throws UaException {

        SecurityTokenRequestType requestType = request.getRequestType();

        if (requestType == SecurityTokenRequestType.Issue) {
            secureChannel.setMessageSecurityMode(request.getSecurityMode());

            String endpointUrl = (String) ctx.channel().attr(UascServerHelloHandler.ENDPOINT_URL_KEY).get();

            EndpointDescription endpoint = stackServer.getEndpointDescriptions()
                .stream()
                .filter(e -> endpointMatches(e, endpointUrl, request))
                .findFirst()
                .orElseThrow(() -> new UaException(
                    BAD_SECURITY_CHECKS_FAILED,
                    String.format(
                        "no matching endpoint found: transportProfile=%s, "
                            + "endpointUrl=%s, securityPolicy=%s, securityMode=%s",
                        transportProfile,
                        endpointUrl,
                        secureChannel.getSecurityPolicy(),
                        request.getSecurityMode())));

            ctx.channel().attr(ENDPOINT_KEY).set(endpoint);
        }

        if (requestType == SecurityTokenRequestType.Renew
            && secureChannel.getMessageSecurityMode() != request.getSecurityMode()) {

            throw new UaException(
                BAD_SECURITY_CHECKS_FAILED,
                "secure channel renewal requested a different MessageSecurityMode.");
        }

        long requestedLifetime = request.getRequestedLifetime().longValue();
        long maximumLifetime = stackServer.getConfig().getMaximumSecureChannelLifetime().longValue();
        long minimumLifetime = stackServer.getConfig().getMinimumSecureChannelLifetime().longValue();
        long lifetime = Math.max(Math.min(requestedLifetime, maximumLifetime), minimumLifetime);

        ChannelSecurityToken newToken = new ChannelSecurityToken(
            Unsigned.uint(secureChannel.getChannelId()),
            Unsigned.uint(stackServer.getNextTokenId()),
            DateTime.now(),
            Unsigned.uint(lifetime)
        );

        ChannelSecurity.SecurityKeys newKeys = null;

        if (secureChannel.isSymmetricSigningEnabled()) {
            ByteString clientNonce = request.getClientNonce();
            NonceUtil.validateNonce(clientNonce, secureChannel.getSecurityPolicy());

            secureChannel.setLocalNonce(NonceUtil.generateNonce(secureChannel.getSecurityPolicy()));
            secureChannel.setRemoteNonce(clientNonce);

            newKeys = ChannelSecurity.generateKeyPair(
                secureChannel,
                secureChannel.getRemoteNonce(),
                secureChannel.getLocalNonce()
            );
        }

        ChannelSecurity previous = secureChannel.getChannelSecurity();
        ChannelSecurity.SecurityKeys previousKeys = previous != null ? previous.getCurrentKeys() : null;
        ChannelSecurityToken previousToken = previous != null ? previous.getCurrentToken() : null;

        secureChannel.setChannelSecurity(new ChannelSecurity(newKeys, newToken, previousKeys, previousToken));

        // Restart the lifetime timer. If cancel() fails, the old timeout has already fired and is
        // closing the channel, so no new timeout is scheduled.
        if (secureChannelTimeout == null || secureChannelTimeout.cancel()) {
            secureChannelTimeout = Stack.sharedWheelTimer().newTimeout(timeout -> {
                logger.debug(
                    "SecureChannel renewal timed out after {}ms. id={}, channel={}",
                    lifetime, secureChannel.getChannelId(), ctx.channel());
                ctx.close();
            }, lifetime, TimeUnit.MILLISECONDS);
        }

        ResponseHeader responseHeader = new ResponseHeader(
            DateTime.now(),
            request.getRequestHeader().getRequestHandle(),
            StatusCode.GOOD,
            (DiagnosticInfo) null,
            (String[]) null,
            (ExtensionObject) null
        );

        return new OpenSecureChannelResponse(
            responseHeader,
            Unsigned.uint(0L),
            newToken,
            secureChannel.getLocalNonce()
        );
    }

    /**
     * Returns true when {@code endpoint} matches this handler's transport profile and the request.
     *
     * <p>The endpoint URL path, the secure channel's security policy and the requested security
     * mode must also match.
     */
    private boolean endpointMatches(
        EndpointDescription endpoint,
        String endpointUrl,
        OpenSecureChannelRequest request) {

        boolean transportMatch =
            Objects.equals(endpoint.getTransportProfileUri(), transportProfile.getUri());

        boolean pathMatch =
            Objects.equals(EndpointUtil.getPath(endpoint.getEndpointUrl()), EndpointUtil.getPath(endpointUrl));

        boolean securityPolicyMatch =
            Objects.equals(endpoint.getSecurityPolicyUri(), secureChannel.getSecurityPolicy().getUri());

        boolean securityModeMatch =
            Objects.equals(endpoint.getSecurityMode(), request.getSecurityMode());

        return transportMatch && pathMatch && securityPolicyMatch && securityModeMatch;
    }

    /**
     * Rejects an encoded response larger than the remote peer's max message size.
     *
     * <p>A remote max message size of 0 or less disables the check.
     *
     * @throws UaSerializationException if the encoded response is too large.
     */
    private void checkMessageSize(ByteBuf messageBuffer) throws UaSerializationException {
        int messageSize = messageBuffer.readableBytes();
        int remoteMaxMessageSize = serializationQueue.getParameters().getRemoteMaxMessageSize();

        if (remoteMaxMessageSize > 0 && messageSize > remoteMaxMessageSize) {
            throw new UaSerializationException(
                BAD_RESPONSE_TOO_LARGE,
                "response exceeds remote max message size: " + messageSize + " > " + remoteMaxMessageSize);
        }
    }

    /**
     * Discards any partially received message, then reports the failure.
     *
     * <p>An {@link IOException} simply closes the channel. Any other exception is reported to the
     * peer via {@link ExceptionHandler#sendErrorMessage}.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        discardChunkBuffers();

        if (cause instanceof IOException) {
            ctx.close();
            logger.debug("[remote={}] IOException caught; channel closed", ctx.channel().remoteAddress(), cause);
            return;
        }

        ErrorMessage errorMessage = ExceptionHandler.sendErrorMessage(ctx, cause);

        if (cause instanceof UaException) {
            logger.debug(
                "[remote={}] UaException caught; sent {}",
                ctx.channel().remoteAddress(), errorMessage, cause);
        } else {
            logger.error(
                "[remote={}] Exception caught; sent {}",
                ctx.channel().remoteAddress(), errorMessage, cause);
        }
    }

    /**
     * Cancels the lifetime timer and releases chunks of any OpenSecureChannel message that never
     * received its final chunk.
     *
     * <p>Without this, those retained buffers would leak when the peer disconnects mid-message.
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        cancelSecureChannelTimeout();

        try {
            super.channelInactive(ctx);
        } finally {
            discardChunkBuffers();
        }
    }

    /** Cancels and forgets the pending secure channel lifetime timeout, if any. */
    private void cancelSecureChannelTimeout() {
        if (secureChannelTimeout != null) {
            secureChannelTimeout.cancel();
            secureChannelTimeout = null;
        }
    }

    /** Releases all retained chunks of the in-progress message and resets the expected security header. */
    private void discardChunkBuffers() {
        chunkBuffers.forEach(ReferenceCountUtil::safeRelease);
        chunkBuffers.clear();
        headerRef.set(null);
    }
}
