/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.checkpoint.channel;

import java.io.IOException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import org.apache.flink.api.common.JobID;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequest;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutor;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriterImpl;
import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
import org.apache.flink.runtime.checkpoint.channel.SyncChannelStateWriteRequestExecutor;
import org.apache.flink.runtime.checkpoint.channel.TestException;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.state.CheckpointStorage;
import org.apache.flink.runtime.state.storage.JobManagerCheckpointStorage;
import org.apache.flink.util.CloseableIterator;
import org.apache.flink.util.function.BiConsumerWithException;
import org.assertj.core.api.AbstractThrowableAssert;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link ChannelStateWriterImpl}.
 *
 * <p>Most tests drive the writer through a {@link SyncChannelStateWriteRequestExecutor}, so that
 * queued write requests are only processed when the test explicitly calls {@code
 * processAllRequests()}. The helpers at the bottom of the class wrap the writer life cycle
 * (create, register subtask, run the test body, close, release subtask).
 */
class ChannelStateWriterImplTest {
    /** Checkpoint id used by the {@code call*} helpers and by most tests. */
    private static final long CHECKPOINT_ID = 42L;

    private static final String TASK_NAME = "test";
    private static final JobID JOB_ID = new JobID();
    private static final JobVertexID JOB_VERTEX_ID = new JobVertexID();
    private static final int SUBTASK_INDEX = 0;
    private static final CheckpointStorage CHECKPOINT_STORAGE = new JobManagerCheckpointStorage();

    /**
     * Adding input data that contains an event buffer must make request processing fail with an
     * {@link IllegalArgumentException}; the accompanying data buffer must be recycled by the time
     * the writer has been closed.
     */
    @Test
    void testAddEventBuffer() throws Exception {
        NetworkBuffer dataBuf = getBuffer();
        NetworkBuffer eventBuf = getBuffer();
        eventBuf.setDataType(Buffer.DataType.EVENT_BUFFER);
        executeCallbackWithSyncWorker(
                (writer, worker) -> {
                    callStart(writer);
                    callAddInputData(writer, eventBuf, dataBuf);
                    Assertions.assertThatThrownBy(worker::processAllRequests)
                            .isInstanceOf(IllegalArgumentException.class);
                });
        Assertions.assertThat(dataBuf.isRecycled()).isTrue();
    }

    /**
     * The futures of a write result are not done while the writer is open and are done once the
     * writer has been closed.
     */
    @Test
    void testResultCompletion() throws IOException {
        ChannelStateWriter.ChannelStateWriteResult result;
        try (ChannelStateWriterImpl writer = openWriter()) {
            callStart(writer);
            result = writer.getAndRemoveWriteResult(CHECKPOINT_ID);
            Assertions.assertThat(result.resultSubpartitionStateHandles).isNotDone();
            Assertions.assertThat(result.inputChannelStateHandles).isNotDone();
        }
        // The try-with-resources block above has closed the writer at this point.
        Assertions.assertThat(result.inputChannelStateHandles).isDone();
        Assertions.assertThat(result.resultSubpartitionStateHandles).isDone();
    }

    /** After an abort has been processed, the result is done and the added buffer is recycled. */
    @Test
    void testAbort() throws Exception {
        NetworkBuffer buffer = getBuffer();
        executeCallbackWithSyncWorker(
                (writer, worker) -> {
                    callStart(writer);
                    ChannelStateWriter.ChannelStateWriteResult result =
                            writer.getAndRemoveWriteResult(CHECKPOINT_ID);
                    callAddInputData(writer, buffer);
                    callAbort(writer);
                    worker.processAllRequests();
                    Assertions.assertThat(result.isDone()).isTrue();
                    Assertions.assertThat(buffer.isRecycled()).isTrue();
                });
    }

    /**
     * When {@code abort} is called with its boolean flag set to {@code true}, the result of the
     * aborted checkpoint can no longer be fetched: {@code getAndRemoveWriteResult} throws an
     * {@link IllegalArgumentException}.
     */
    @Test
    void testAbortClearsResults() throws Exception {
        executeCallbackWithSyncWorker(
                (writer, worker) -> {
                    callStart(writer);
                    writer.abort(CHECKPOINT_ID, new TestException(), true);
                    Assertions.assertThatThrownBy(
                                    () -> writer.getAndRemoveWriteResult(CHECKPOINT_ID))
                            .isInstanceOf(IllegalArgumentException.class);
                });
    }

    /**
     * When {@code abort} is called with its boolean flag set to {@code false} (see {@link
     * #callAbort}), the result can still be fetched afterwards without an exception.
     */
    @Test
    void testAbortDoesNotClearsResults() throws Exception {
        executeCallbackWithSyncWorker(
                (writer, worker) -> {
                    callStart(writer);
                    callAbort(writer);
                    worker.processAllRequests();
                    // Must not throw: the result is expected to still be registered.
                    writer.getAndRemoveWriteResult(CHECKPOINT_ID);
                });
    }

    /** Aborting a checkpoint that was never started must not fail. */
    @Test
    void testAbortIgnoresMissing() throws Exception {
        executeCallbackAndProcessWithSyncWorker(this::callAbort);
    }

    /**
     * Aborting one checkpoint and starting the next one before any request is processed: the
     * aborted checkpoint's result is done and failed with the abort cause, while the new
     * checkpoint's result is still pending.
     */
    @Test
    void testAbortOldAndStartNewCheckpoint() throws Exception {
        executeCallbackWithSyncWorker(
                (writer, worker) -> {
                    long checkpoint42 = 42L;
                    long checkpoint43 = 43L;
                    writer.start(checkpoint42, CheckpointOptions.forCheckpointWithDefaultLocation());
                    writer.abort(checkpoint42, new TestException(), false);
                    writer.start(checkpoint43, CheckpointOptions.forCheckpointWithDefaultLocation());
                    worker.processAllRequests();

                    ChannelStateWriter.ChannelStateWriteResult result42 =
                            writer.getAndRemoveWriteResult(checkpoint42);
                    Assertions.assertThat(result42.isDone()).isTrue();
                    Assertions.assertThatThrownBy(
                                    () -> result42.getInputChannelStateHandles().get())
                            .as("The result should have failed.")
                            .hasCauseInstanceOf(TestException.class);

                    ChannelStateWriter.ChannelStateWriteResult result43 =
                            writer.getAndRemoveWriteResult(checkpoint43);
                    Assertions.assertThat(result43.isDone()).isFalse();
                });
    }

    /**
     * If the executor rejects a request by throwing, {@code addInputData} surfaces a {@link
     * RuntimeException} caused by that failure and the passed buffer is recycled.
     */
    @Test
    void testBuffersRecycledOnError() {
        NetworkBuffer buffer = getBuffer();
        ChannelStateWriterImpl writer =
                new ChannelStateWriterImpl(
                        JOB_VERTEX_ID,
                        TASK_NAME,
                        SUBTASK_INDEX,
                        new ConcurrentHashMap<>(),
                        failingWorker(),
                        5);
        Assertions.assertThatThrownBy(() -> callAddInputData(writer, buffer))
                .isInstanceOf(RuntimeException.class)
                .hasCauseInstanceOf(TestException.class);
        Assertions.assertThat(buffer.isRecycled()).isTrue();
    }

    /**
     * A buffer handed to {@code addInputData} is not recycled immediately, but is recycled once
     * the requests have been processed and the writer has been closed.
     */
    @Test
    void testBuffersRecycledOnClose() throws Exception {
        NetworkBuffer buffer = getBuffer();
        executeCallbackAndProcessWithSyncWorker(
                writer -> {
                    callStart(writer);
                    callAddInputData(writer, buffer);
                    Assertions.assertThat(buffer.isRecycled()).isFalse();
                });
        Assertions.assertThat(buffer.isRecycled()).isTrue();
    }

    /**
     * Adding input data after input and output have been finished makes request processing fail
     * with an {@link IllegalArgumentException}.
     */
    @Test
    void testNoAddDataAfterFinished() throws Exception {
        executeCallbackWithSyncWorker(
                (writer, worker) -> {
                    callStart(writer);
                    callFinish(writer);
                    worker.processAllRequests();
                    callAddInputData(writer);
                    Assertions.assertThatThrownBy(worker::processAllRequests)
                            .isInstanceOf(IllegalArgumentException.class);
                });
    }

    /** Adding input data for a checkpoint that was never started is rejected. */
    @Test
    void testAddDataNotStarted() {
        Assertions.assertThatThrownBy(
                        () ->
                                executeCallbackAndProcessWithSyncWorker(
                                        writer -> callAddInputData(writer)))
                .isInstanceOf(IllegalArgumentException.class);
    }

    /** Finishing input/output of a checkpoint that was never started is rejected. */
    @Test
    void testFinishNotStarted() {
        Assertions.assertThatThrownBy(() -> executeCallbackAndProcessWithSyncWorker(this::callFinish))
                .isInstanceOf(IllegalArgumentException.class);
    }

    /**
     * Even if the test body swallows an {@link IllegalArgumentException} from finishing a
     * non-started checkpoint, the overall helper invocation (processing the requests and closing
     * the writer) must still fail with an {@link IllegalArgumentException}.
     */
    @Test
    void testRethrowOnClose() {
        Assertions.assertThatThrownBy(
                        () ->
                                executeCallbackAndProcessWithSyncWorker(
                                        writer -> {
                                            try {
                                                callFinish(writer);
                                            } catch (IllegalArgumentException ignored) {
                                                // Deliberately ignored here: the failure is
                                                // expected to resurface later and is verified
                                                // by the surrounding assertion.
                                            }
                                        }))
                .isInstanceOf(IllegalArgumentException.class);
    }

    /** A failure recorded on the executor is reported (as the cause) by the next writer call. */
    @Test
    void testRethrowOnNextCall() {
        SyncChannelStateWriteRequestExecutor worker =
                new SyncChannelStateWriteRequestExecutor(JOB_ID);
        ChannelStateWriterImpl writer =
                new ChannelStateWriterImpl(
                        JOB_VERTEX_ID,
                        TASK_NAME,
                        SUBTASK_INDEX,
                        new ConcurrentHashMap<>(),
                        worker,
                        5);
        worker.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
        worker.setThrown(new TestException());
        Assertions.assertThatThrownBy(() -> callStart(writer))
                .hasCauseInstanceOf(TestException.class);
    }

    /**
     * Starting one checkpoint more than the limit passed to the writer's constructor fails with
     * an {@link IllegalStateException}.
     */
    @Test
    void testLimit() throws IOException {
        int maxCheckpoints = 3;
        try (ChannelStateWriterImpl writer =
                new ChannelStateWriterImpl(
                        JOB_VERTEX_ID,
                        TASK_NAME,
                        SUBTASK_INDEX,
                        () -> CHECKPOINT_STORAGE.createCheckpointStorage(JOB_ID),
                        maxCheckpoints,
                        new ChannelStateWriteRequestExecutorFactory(JOB_ID),
                        5)) {
            for (int i = 0; i < maxCheckpoints; i++) {
                writer.start(i, CheckpointOptions.forCheckpointWithDefaultLocation());
            }
            Assertions.assertThatThrownBy(
                            () ->
                                    writer.start(
                                            maxCheckpoints,
                                            CheckpointOptions.forCheckpointWithDefaultLocation()))
                    .isInstanceOf(IllegalStateException.class);
        }
    }

    /** Starting a checkpoint on a closed writer fails with an {@link IllegalStateException} cause. */
    @Test
    void testNoStartAfterClose() throws IOException {
        ChannelStateWriterImpl writer = openWriter();
        writer.close();
        Assertions.assertThatThrownBy(
                        () ->
                                writer.start(
                                        CHECKPOINT_ID,
                                        CheckpointOptions.forCheckpointWithDefaultLocation()))
                .hasCauseInstanceOf(IllegalStateException.class);
    }

    /** Adding data on a closed writer fails with an {@link IllegalStateException} cause. */
    @Test
    void testNoAddDataAfterClose() throws IOException {
        ChannelStateWriterImpl writer = openWriter();
        callStart(writer);
        writer.close();
        Assertions.assertThatThrownBy(() -> callAddInputData(writer))
                .hasCauseInstanceOf(IllegalStateException.class);
    }

    /** Creates a fresh buffer backed by an unpooled 123-byte segment without an owner. */
    private NetworkBuffer getBuffer() {
        return new NetworkBuffer(
                MemorySegmentFactory.allocateUnpooledSegment(123, null),
                FreeingBufferRecycler.INSTANCE);
    }

    /**
     * Creates an executor whose {@code submit} and {@code submitPriority} always throw a {@link
     * TestException}; all other operations are no-ops.
     */
    private ChannelStateWriteRequestExecutor failingWorker() {
        return new ChannelStateWriteRequestExecutor() {
            @Override
            public void submit(ChannelStateWriteRequest e) {
                throw new TestException();
            }

            @Override
            public void submitPriority(ChannelStateWriteRequest e) {
                throw new TestException();
            }

            @Override
            public void start() throws IllegalStateException {}

            @Override
            public void registerSubtask(JobVertexID jobVertexID, int subtaskIndex) {}

            @Override
            public void releaseSubtask(JobVertexID jobVertexID, int subtaskIndex) {}
        };
    }

    /**
     * Runs {@code writerConsumer} against a writer backed by a synchronous executor and then
     * processes all requests that were queued by it.
     *
     * @param writerConsumer test body that issues calls on the writer
     * @throws Exception anything thrown by the test body, by request processing, or on close
     */
    private void executeCallbackAndProcessWithSyncWorker(Consumer<ChannelStateWriter> writerConsumer)
            throws Exception {
        executeCallbackWithSyncWorker(
                (writer, worker) -> {
                    writerConsumer.accept(writer);
                    worker.processAllRequests();
                });
    }

    /**
     * Creates a writer on top of a {@link SyncChannelStateWriteRequestExecutor}, registers the
     * test subtask, and hands both to {@code testFn}. The writer is closed when {@code testFn}
     * returns or throws, and the subtask is released afterwards in any case.
     *
     * @param testFn test body receiving the writer and the synchronous executor
     * @throws Exception anything thrown by the test body or while closing the writer
     */
    private void executeCallbackWithSyncWorker(
            BiConsumerWithException<
                            ChannelStateWriter, SyncChannelStateWriteRequestExecutor, Exception>
                    testFn)
            throws Exception {
        SyncChannelStateWriteRequestExecutor worker =
                new SyncChannelStateWriteRequestExecutor(JOB_ID);
        try (ChannelStateWriterImpl writer =
                new ChannelStateWriterImpl(
                        JOB_VERTEX_ID,
                        TASK_NAME,
                        SUBTASK_INDEX,
                        new ConcurrentHashMap<>(),
                        worker,
                        5)) {
            worker.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
            testFn.accept(writer, worker);
        } finally {
            worker.releaseSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
        }
    }

    /** Creates a writer that uses the shared checkpoint storage and a new executor factory. */
    private ChannelStateWriterImpl openWriter() throws IOException {
        return new ChannelStateWriterImpl(
                JOB_VERTEX_ID,
                TASK_NAME,
                SUBTASK_INDEX,
                () -> CHECKPOINT_STORAGE.createCheckpointStorage(JOB_ID),
                new ChannelStateWriteRequestExecutorFactory(JOB_ID),
                5);
    }

    /** Starts checkpoint {@link #CHECKPOINT_ID} with default checkpoint options. */
    private void callStart(ChannelStateWriter writer) {
        writer.start(CHECKPOINT_ID, CheckpointOptions.forCheckpointWithDefaultLocation());
    }

    /**
     * Adds the given buffers as input data of checkpoint {@link #CHECKPOINT_ID}. The iterator
     * passed to the writer recycles every buffer that has not been consumed when it is closed.
     */
    private void callAddInputData(ChannelStateWriter writer, NetworkBuffer... buffer) {
        writer.addInputData(
                CHECKPOINT_ID,
                new InputChannelInfo(1, 1),
                1,
                CloseableIterator.ofElements(Buffer::recycleBuffer, buffer));
    }

    /** Aborts checkpoint {@link #CHECKPOINT_ID} with a {@link TestException} and flag {@code false}. */
    private void callAbort(ChannelStateWriter writer) {
        writer.abort(CHECKPOINT_ID, new TestException(), false);
    }

    /** Finishes both input and output of checkpoint {@link #CHECKPOINT_ID}. */
    private void callFinish(ChannelStateWriter writer) {
        writer.finishInput(CHECKPOINT_ID);
        writer.finishOutput(CHECKPOINT_ID);
    }
}

