package org.apache.flink.runtime.io.network.netty;

import java.io.IOException;
import java.util.Random;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.io.network.TestingPartitionRequestClient;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.BufferCompressor;
import org.apache.flink.runtime.io.network.buffer.BufferDecompressor;
import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
import org.apache.flink.runtime.io.network.buffer.ReadOnlySlicedNetworkBuffer;
import org.apache.flink.runtime.io.network.netty.NettyMessage;
import org.apache.flink.runtime.io.network.partition.InputChannelTestUtils;
import org.apache.flink.runtime.io.network.partition.consumer.InputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID;
import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandler;
import org.apache.flink.shaded.netty4.io.netty.channel.embedded.EmbeddedChannel;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.TestLoggerExtension;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

/**
 * Round-trip tests for client-side {@link NettyMessage}s.
 *
 * <p>Each message is written through a {@link NettyMessage.NettyMessageEncoder} and read back
 * through a {@link NettyMessageClientDecoderDelegate} on an {@link EmbeddedChannel}. The decoder
 * delivers buffer payloads into a remote input channel registered with a {@link
 * CreditBasedPartitionRequestClientHandler}.
 */
@ExtendWith({TestLoggerExtension.class})
class NettyMessageClientSideSerializationTest {

    /** Size in bytes of every memory segment / network buffer used by these tests. */
    private static final int BUFFER_SIZE = 1024;

    private final Random random = new Random();

    // Instance-scoped (not static): each parameterized invocation creates its own codec pair,
    // so invocations never observe each other's compressor/decompressor.
    private BufferCompressor compressor;
    private BufferDecompressor decompressor;

    private EmbeddedChannel channel;
    private NetworkBufferPool networkBufferPool;
    private SingleInputGate inputGate;
    private InputChannelID inputChannelId;

    /**
     * Builds a single input gate with one remote input channel that has requested its subpartition.
     * Wires that channel into the client handler and creates the encoder/decoder channel.
     */
    @BeforeEach
    void setup() throws IOException, InterruptedException {
        networkBufferPool = new NetworkBufferPool(8, BUFFER_SIZE);
        inputGate = InputChannelTestUtils.createSingleInputGate(1, networkBufferPool);

        InputChannel inputChannel =
                InputChannelTestUtils.createRemoteInputChannel(
                        inputGate, new TestingPartitionRequestClient());
        inputChannel.requestSubpartition();
        inputGate.setInputChannels(new InputChannel[] {inputChannel});
        inputGate.setup();

        CreditBasedPartitionRequestClientHandler handler =
                new CreditBasedPartitionRequestClientHandler();
        handler.addInputChannel(inputChannel);

        channel =
                new EmbeddedChannel(
                        new NettyMessage.NettyMessageEncoder(),
                        new NettyMessageClientDecoderDelegate(handler));
        inputChannelId = inputChannel.getInputChannelId();
    }

    /** Releases the gate, the buffer pool and the embedded channel, tolerating a failed setup. */
    @AfterEach
    void tearDown() throws IOException {
        if (inputGate != null) {
            inputGate.close();
        }
        if (networkBufferPool != null) {
            networkBufferPool.destroyAllBufferPools();
            networkBufferPool.destroy();
        }
        if (channel != null) {
            channel.close();
        }
    }

    @Test
    void testErrorResponseWithoutErrorMessage() {
        testErrorResponse(
                new NettyMessage.ErrorResponse(new IllegalStateException(), inputChannelId));
    }

    @Test
    void testErrorResponseWithErrorMessage() {
        testErrorResponse(
                new NettyMessage.ErrorResponse(
                        new IllegalStateException("Illegal illegal illegal"), inputChannelId));
    }

    @Test
    void testErrorResponseWithFatalError() {
        testErrorResponse(
                new NettyMessage.ErrorResponse(
                        new IllegalStateException("Illegal illegal illegal")));
    }

    @Test
    void testOrdinaryBufferResponse() {
        testBufferResponse(false, false);
    }

    @Test
    void testBufferResponseWithReadOnlySlice() {
        testBufferResponse(true, false);
    }

    /**
     * Round-trips a compressed buffer response for each supported codec.
     *
     * @param codecFactoryName name of the compression codec passed to the compressor/decompressor
     */
    @ValueSource(strings = {"LZ4", "LZO", "ZSTD"})
    @ParameterizedTest
    void testCompressedBufferResponse(String codecFactoryName) {
        compressor = new BufferCompressor(BUFFER_SIZE, codecFactoryName);
        decompressor = new BufferDecompressor(BUFFER_SIZE, codecFactoryName);
        testBufferResponse(false, true);
    }

    @Test
    void testBacklogAnnouncement() {
        // An arbitrary backlog count; unrelated to the buffer size.
        final int backlog = 1024;
        NettyMessage.BacklogAnnouncement expected =
                new NettyMessage.BacklogAnnouncement(backlog, inputChannelId);
        NettyMessage.BacklogAnnouncement actual = NettyTestUtil.encodeAndDecode(expected, channel);

        Assertions.assertThat(actual.backlog).isEqualTo(expected.backlog);
        Assertions.assertThat(actual.receiverId).isEqualTo(expected.receiverId);
    }

    /** Encodes and decodes {@code expected} and verifies the decoded error response matches it. */
    private void testErrorResponse(NettyMessage.ErrorResponse expected) {
        NettyTestUtil.verifyErrorResponse(
                expected, NettyTestUtil.encodeAndDecode(expected, channel));
    }

    /**
     * Round-trips a {@link NettyMessage.BufferResponse} whose payload is a sequence of longs
     * {@code 0, 8, 16, ...} and verifies header, payload and buffer recycling.
     *
     * @param testReadOnlyBuffer send a read-only slice of the source buffer instead of the buffer
     * @param testCompressedBuffer send a compressed copy of the source buffer and decompress the
     *     received payload before checking it
     * @throws IllegalArgumentException if both flags are set
     */
    private void testBufferResponse(boolean testReadOnlyBuffer, boolean testCompressedBuffer) {
        Preconditions.checkArgument(
                !(testReadOnlyBuffer && testCompressedBuffer),
                "There are no cases with both readonly slice and compression.");

        NetworkBuffer sourceBuffer =
                new NetworkBuffer(
                        MemorySegmentFactory.allocateUnpooledSegment(BUFFER_SIZE),
                        FreeingBufferRecycler.INSTANCE);
        for (int i = 0; i < BUFFER_SIZE; i += Long.BYTES) {
            sourceBuffer.writeLong(i);
        }

        // The payload actually sent; declared as Buffer because it can be the source buffer,
        // a read-only slice of it, or the result of compressToOriginalBuffer.
        Buffer payload = sourceBuffer;
        if (testReadOnlyBuffer) {
            payload = sourceBuffer.readOnlySlice();
        } else if (testCompressedBuffer) {
            payload = compressor.compressToOriginalBuffer(sourceBuffer);
        }

        NettyMessage.BufferResponse expected =
                new NettyMessage.BufferResponse(
                        payload,
                        random.nextInt(Integer.MAX_VALUE),
                        inputChannelId,
                        random.nextInt(Integer.MAX_VALUE));
        NettyMessage.BufferResponse actual = NettyTestUtil.encodeAndDecode(expected, channel);

        // After the round trip, both the source buffer and the sent payload are expected to be
        // recycled.
        Assertions.assertThat(sourceBuffer.isRecycled()).isTrue();
        Assertions.assertThat(payload.isRecycled()).isTrue();
        Assertions.assertThat(actual.getBuffer())
                .as("The request input channel should always have available buffers in this test.")
                .isNotNull();

        Buffer received = actual.getBuffer();
        if (testCompressedBuffer) {
            Assertions.assertThat(actual.isCompressed).isTrue();
            received = decompress(received);
        }

        NettyTestUtil.verifyBufferResponseHeader(expected, actual);
        Assertions.assertThat(received.readableBytes()).isEqualTo(BUFFER_SIZE);
        for (int i = 0; i < BUFFER_SIZE; i += Long.BYTES) {
            Assertions.assertThat(received.asByteBuf().readLong()).isEqualTo(i);
        }

        actual.releaseBuffer();
        if (testCompressedBuffer) {
            // decompress() returned a separately allocated buffer; release it as well.
            received.recycleBuffer();
        }
        Assertions.assertThat(actual.getBuffer().isRecycled()).isTrue();
    }

    /**
     * Decompresses a received buffer.
     *
     * <p>Copies the readable bytes of {@code buffer} into a freshly allocated unpooled network
     * buffer, which advances {@code buffer}'s reader index. It then flags the copy as compressed
     * and hands it to {@link BufferDecompressor#decompressToOriginalBuffer}.
     *
     * @param buffer the decoded, still-compressed payload
     * @return the decompressed buffer as returned by the decompressor
     */
    private Buffer decompress(Buffer buffer) {
        NetworkBuffer compressedCopy =
                new NetworkBuffer(
                        MemorySegmentFactory.allocateUnpooledSegment(BUFFER_SIZE),
                        FreeingBufferRecycler.INSTANCE);
        buffer.asByteBuf().readBytes(compressedCopy.asByteBuf(), buffer.readableBytes());
        compressedCopy.setCompressed(true);
        return decompressor.decompressToOriginalBuffer(compressedCopy);
    }
}
