package org.apache.flink.runtime.checkpoint.channel;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.function.Function;
import org.apache.flink.api.common.JobID;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.checkpoint.CheckpointFailureReason;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
import org.apache.flink.runtime.state.storage.JobManagerCheckpointStorage;
import org.apache.flink.util.CloseableIterator;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/** Tests for {@link ChannelStateWriteRequestDispatcherImpl}. */
class ChannelStateWriteRequestDispatcherImplTest {
    private static final JobID JOB_ID = new JobID();
    private static final JobVertexID JOB_VERTEX_ID = new JobVertexID();
    private static final int SUBTASK_INDEX = 0;

    @Test
    void testPartialInputChannelStateWrite() throws Exception {
        testBuffersRecycled(
                buffers ->
                        ChannelStateWriteRequest.write(
                                JOB_VERTEX_ID,
                                SUBTASK_INDEX,
                                1L,
                                new InputChannelInfo(1, 2),
                                // Buffers the iterator never hands out are recycled on close.
                                CloseableIterator.ofElements(
                                        unconsumed -> unconsumed.recycleBuffer(), buffers)));
    }

    @Test
    void testPartialResultSubpartitionStateWrite() throws Exception {
        testBuffersRecycled(
                buffers ->
                        ChannelStateWriteRequest.write(
                                JOB_VERTEX_ID,
                                SUBTASK_INDEX,
                                1L,
                                new ResultSubpartitionInfo(1, 2),
                                buffers));
    }

    /**
     * Starts checkpoint 1, fails both of its result futures, then dispatches the write request
     * produced by {@code requestBuilder} and asserts that every buffer passed to it ends up
     * recycled.
     *
     * @param requestBuilder builds the write request from the buffers under test
     */
    private void testBuffersRecycled(
            Function<NetworkBuffer[], ChannelStateWriteRequest> requestBuilder) throws Exception {
        ChannelStateWriteRequestDispatcherImpl dispatcher = newDispatcher();
        ChannelStateWriter.ChannelStateWriteResult result =
                new ChannelStateWriter.ChannelStateWriteResult();
        dispatcher.dispatch(ChannelStateWriteRequest.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX));
        dispatcher.dispatch(
                ChannelStateWriteRequest.start(
                        JOB_VERTEX_ID,
                        SUBTASK_INDEX,
                        1L,
                        result,
                        CheckpointStorageLocationReference.getDefault()));
        // Complete both futures before writing, so the write targets an already-failed result.
        result.getResultSubpartitionStateHandles().completeExceptionally(new TestException());
        result.getInputChannelStateHandles().completeExceptionally(new TestException());

        NetworkBuffer[] buffers = {buffer(), buffer()};
        dispatcher.dispatch(requestBuilder.apply(buffers));

        for (NetworkBuffer buffer : buffers) {
            Assertions.assertThat(buffer.isRecycled()).isTrue();
        }
    }

    @Test
    void testStartNewCheckpointForSameSubtask() throws Exception {
        testStartNewCheckpointAndCheckOldCheckpointResult(false);
    }

    @Test
    void testStartNewCheckpointForDifferentSubtask() throws Exception {
        testStartNewCheckpointAndCheckOldCheckpointResult(true);
    }

    /**
     * Starts checkpoint 1 on {@link #JOB_VERTEX_ID}, then starts checkpoint 2 and asserts that the
     * still-pending checkpoint 1 result is failed with {@link
     * CheckpointFailureReason#CHECKPOINT_DECLINED_SUBSUMED}.
     *
     * @param isDifferentSubtask whether checkpoint 2 is started on a separately registered vertex
     *     instead of {@link #JOB_VERTEX_ID}
     */
    private void testStartNewCheckpointAndCheckOldCheckpointResult(boolean isDifferentSubtask)
            throws Exception {
        ChannelStateWriteRequestDispatcherImpl dispatcher = newDispatcher();
        ChannelStateWriter.ChannelStateWriteResult oldResult =
                new ChannelStateWriter.ChannelStateWriteResult();
        dispatcher.dispatch(ChannelStateWriteRequest.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX));
        JobVertexID secondVertex = vertexForSecondRequest(dispatcher, isDifferentSubtask);

        dispatcher.dispatch(
                ChannelStateWriteRequest.start(
                        JOB_VERTEX_ID,
                        SUBTASK_INDEX,
                        1L,
                        oldResult,
                        CheckpointStorageLocationReference.getDefault()));
        Assertions.assertThat(oldResult.isDone()).isFalse();

        dispatcher.dispatch(
                ChannelStateWriteRequest.start(
                        secondVertex,
                        SUBTASK_INDEX,
                        2L,
                        new ChannelStateWriter.ChannelStateWriteResult(),
                        CheckpointStorageLocationReference.getDefault()));
        ChannelStateWriteResultUtil.assertCheckpointFailureReason(
                oldResult, CheckpointFailureReason.CHECKPOINT_DECLINED_SUBSUMED);
    }

    @Test
    void testStartOldCheckpointForSameSubtask() throws Exception {
        testStartOldCheckpointAfterNewCheckpointAborted(false);
    }

    @Test
    void testStartOldCheckpointForDifferentSubtask() throws Exception {
        testStartOldCheckpointAfterNewCheckpointAborted(true);
    }

    /**
     * Aborts checkpoint 2 on {@link #JOB_VERTEX_ID}, then starts the older checkpoint 1 and
     * asserts that its result is failed with {@link
     * CheckpointFailureReason#CHECKPOINT_DECLINED_SUBSUMED}.
     *
     * @param isDifferentSubtask whether checkpoint 1 is started on a separately registered vertex
     *     instead of {@link #JOB_VERTEX_ID}
     */
    private void testStartOldCheckpointAfterNewCheckpointAborted(boolean isDifferentSubtask)
            throws Exception {
        ChannelStateWriteRequestDispatcherImpl dispatcher = newDispatcher();
        dispatcher.dispatch(ChannelStateWriteRequest.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX));
        JobVertexID secondVertex = vertexForSecondRequest(dispatcher, isDifferentSubtask);

        dispatcher.dispatch(
                ChannelStateWriteRequest.abort(
                        JOB_VERTEX_ID, SUBTASK_INDEX, 2L, new TestException()));

        ChannelStateWriter.ChannelStateWriteResult result =
                new ChannelStateWriter.ChannelStateWriteResult();
        dispatcher.dispatch(
                ChannelStateWriteRequest.start(
                        secondVertex,
                        SUBTASK_INDEX,
                        1L,
                        result,
                        CheckpointStorageLocationReference.getDefault()));
        ChannelStateWriteResultUtil.assertCheckpointFailureReason(
                result, CheckpointFailureReason.CHECKPOINT_DECLINED_SUBSUMED);
    }

    @Test
    void testAbortCheckpointAndCheckAllException() throws Exception {
        testAbortCheckpointAndCheckAllException(1);
        testAbortCheckpointAndCheckAllException(2);
        testAbortCheckpointAndCheckAllException(3);
        testAbortCheckpointAndCheckAllException(5);
        testAbortCheckpointAndCheckAllException(10);
    }

    /**
     * Registers {@code subtaskNum} subtasks of {@link #JOB_VERTEX_ID}, aborts checkpoint 1 on one
     * randomly chosen subtask, and then starts checkpoint 1 on all of them. It asserts that every
     * result is done. The aborted subtask's result must carry the {@link TestException}; all
     * others must fail with {@link CheckpointFailureReason#CHANNEL_STATE_SHARED_STREAM_EXCEPTION}.
     *
     * @param subtaskNum number of subtasks to register; must be positive for {@link
     *     Random#nextInt(int)}
     */
    private void testAbortCheckpointAndCheckAllException(int subtaskNum) throws Exception {
        ChannelStateWriteRequestDispatcherImpl dispatcher = newDispatcher();
        List<ChannelStateWriter.ChannelStateWriteResult> results = new ArrayList<>(subtaskNum);
        for (int subtask = 0; subtask < subtaskNum; subtask++) {
            dispatcher.dispatch(ChannelStateWriteRequest.registerSubtask(JOB_VERTEX_ID, subtask));
        }

        int abortedSubtask = new Random().nextInt(subtaskNum);
        dispatcher.dispatch(
                ChannelStateWriteRequest.abort(
                        JOB_VERTEX_ID, abortedSubtask, 1L, new TestException()));

        for (int subtask = 0; subtask < subtaskNum; subtask++) {
            ChannelStateWriter.ChannelStateWriteResult result =
                    new ChannelStateWriter.ChannelStateWriteResult();
            results.add(result);
            dispatcher.dispatch(
                    ChannelStateWriteRequest.start(
                            JOB_VERTEX_ID,
                            subtask,
                            1L,
                            result,
                            CheckpointStorageLocationReference.getDefault()));
        }

        Assertions.assertThat(results).allMatch(ChannelStateWriter.ChannelStateWriteResult::isDone);
        for (int subtask = 0; subtask < subtaskNum; subtask++) {
            ChannelStateWriter.ChannelStateWriteResult result = results.get(subtask);
            if (subtask == abortedSubtask) {
                ChannelStateWriteResultUtil.assertHasSpecialCause(result, TestException.class);
            } else {
                ChannelStateWriteResultUtil.assertCheckpointFailureReason(
                        result, CheckpointFailureReason.CHANNEL_STATE_SHARED_STREAM_EXCEPTION);
            }
        }
    }

    /** Creates a fresh dispatcher backed by {@link JobManagerCheckpointStorage}. */
    private static ChannelStateWriteRequestDispatcherImpl newDispatcher() {
        return new ChannelStateWriteRequestDispatcherImpl(
                new JobManagerCheckpointStorage(), JOB_ID, new ChannelStateSerializerImpl());
    }

    /**
     * Returns the vertex the second request of a test should target.
     *
     * <p>When {@code isDifferentSubtask} is false this is {@link #JOB_VERTEX_ID}. Otherwise a new
     * vertex is registered with {@code dispatcher} (at {@link #SUBTASK_INDEX}) and returned.
     */
    private static JobVertexID vertexForSecondRequest(
            ChannelStateWriteRequestDispatcherImpl dispatcher, boolean isDifferentSubtask)
            throws Exception {
        if (!isDifferentSubtask) {
            return JOB_VERTEX_ID;
        }
        JobVertexID otherVertex = new JobVertexID();
        dispatcher.dispatch(ChannelStateWriteRequest.registerSubtask(otherVertex, SUBTASK_INDEX));
        return otherVertex;
    }

    /** Creates an unpooled 10-byte network buffer that frees its segment when recycled. */
    private NetworkBuffer buffer() {
        return new NetworkBuffer(
                MemorySegmentFactory.allocateUnpooledSegment(10), FreeingBufferRecycler.INSTANCE);
    }
}
