/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.asyncprocessing;

import java.util.ArrayList;
import java.util.Collection;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.flink.api.common.operators.MailboxExecutor;
import org.apache.flink.api.common.state.v2.State;
import org.apache.flink.api.common.state.v2.StateFuture;
import org.apache.flink.api.common.state.v2.StateIterator;
import org.apache.flink.core.state.StateFutureUtils;
import org.apache.flink.runtime.asyncprocessing.AbstractStateIterator;
import org.apache.flink.runtime.asyncprocessing.AsyncExecutor;
import org.apache.flink.runtime.asyncprocessing.AsyncRequestContainer;
import org.apache.flink.runtime.asyncprocessing.EpochManager;
import org.apache.flink.runtime.asyncprocessing.MockAsyncRequestContainer;
import org.apache.flink.runtime.asyncprocessing.RecordContext;
import org.apache.flink.runtime.asyncprocessing.StateExecutionController;
import org.apache.flink.runtime.asyncprocessing.StateExecutor;
import org.apache.flink.runtime.asyncprocessing.StateRequest;
import org.apache.flink.runtime.asyncprocessing.StateRequestHandler;
import org.apache.flink.runtime.asyncprocessing.StateRequestType;
import org.apache.flink.runtime.asyncprocessing.declare.DeclarationManager;
import org.apache.flink.runtime.mailbox.SyncMailboxExecutor;
import org.apache.flink.util.Preconditions;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link AbstractStateIterator}: a {@code MAP_ITER} request is answered in partial
 * chunks by {@link TestIteratorStateExecutor}, and iterating must transparently issue follow-up
 * {@code ITERATOR_LOADING} requests until every element has been visited.
 */
public class AbstractStateIteratorTest {

    /** Total number of elements the test executor serves ({@code 0 .. LIMIT - 1}). */
    private static final int LIMIT = 100;

    /** Maximum number of elements the test executor returns per request. */
    private static final int STEP = 3;

    @Test
    public void testPartialLoading() {
        TestIteratorStateExecutor stateExecutor = new TestIteratorStateExecutor(LIMIT, STEP);
        StateExecutionController aec = createControllerWithContext(stateExecutor);
        AtomicInteger processed = new AtomicInteger();
        // Incremented only after every assertion in the final callback has passed.
        AtomicInteger finished = new AtomicInteger();
        StateFuture<StateIterator<Integer>> iteratorFuture =
                aec.handleRequest(null, StateRequestType.MAP_ITER, null);
        iteratorFuture.thenAccept(
                iter -> {
                    Assertions.assertThat(iter).isInstanceOf(StateIterator.class);
                    // Block lambda on purpose: onNext is used both with a consumer (here) and
                    // with a value-returning function (see the next test); an expression lambda
                    // whose body is a method call could match either.
                    iter.onNext(
                                    item -> {
                                        Assertions.assertThat(item)
                                                .isEqualTo(processed.getAndIncrement());
                                    })
                            .thenAccept(
                                    ignored -> {
                                        Assertions.assertThat(processed.get()).isEqualTo(LIMIT);
                                        finished.incrementAndGet();
                                    });
                });
        aec.drainInflightRecords(0);
        assertIterationCompleted(processed, finished);
    }

    @Test
    public void testPartialLoadingWithReturnValue() {
        TestIteratorStateExecutor stateExecutor = new TestIteratorStateExecutor(LIMIT, STEP);
        StateExecutionController aec = createControllerWithContext(stateExecutor);
        AtomicInteger processed = new AtomicInteger();
        // Incremented only after every assertion in the final callback has passed.
        AtomicInteger finished = new AtomicInteger();
        StateFuture<StateIterator<Integer>> iteratorFuture =
                aec.handleRequest(null, StateRequestType.MAP_ITER, null);
        iteratorFuture.thenAccept(
                iter -> {
                    Assertions.assertThat(iter).isInstanceOf(StateIterator.class);
                    iter.<String>onNext(
                                    item -> {
                                        Assertions.assertThat(item)
                                                .isEqualTo(processed.getAndIncrement());
                                        return StateFutureUtils.completedFuture(
                                                String.valueOf(item));
                                    })
                            .thenAccept(
                                    strings -> {
                                        Assertions.assertThat(processed.get()).isEqualTo(LIMIT);
                                        // Without the size check an empty result would pass the
                                        // element-by-element loop below vacuously.
                                        Assertions.assertThat(strings).hasSize(LIMIT);
                                        int expected = 0;
                                        for (String value : strings) {
                                            Assertions.assertThat(value)
                                                    .isEqualTo(String.valueOf(expected++));
                                        }
                                        finished.incrementAndGet();
                                    });
                });
        aec.drainInflightRecords(0);
        assertIterationCompleted(processed, finished);
    }

    @Test
    public void testPartialLoadingWithConversionToIterable() {
        TestIteratorStateExecutor stateExecutor = new TestIteratorStateExecutor(LIMIT, STEP);
        StateExecutionController aec = createControllerWithContext(stateExecutor);
        AtomicInteger processed = new AtomicInteger();
        // Incremented only after every assertion in the final callback has passed.
        AtomicInteger finished = new AtomicInteger();
        StateFuture<StateIterator<Integer>> iteratorFuture =
                aec.handleRequest(null, StateRequestType.MAP_ITER, null);
        StateFutureUtils.toIterable(iteratorFuture)
                .thenAccept(
                        iterable -> {
                            // A bare assertThat(boolean) asserts nothing; check the type
                            // explicitly.
                            Assertions.assertThat(iterable).isInstanceOf(Iterable.class);
                            iterable.forEach(
                                    item ->
                                            Assertions.assertThat(item)
                                                    .isEqualTo(processed.getAndIncrement()));
                            Assertions.assertThat(processed.get()).isEqualTo(LIMIT);
                            finished.incrementAndGet();
                        });
        aec.drainInflightRecords(0);
        assertIterationCompleted(processed, finished);
    }

    /**
     * Builds a controller on top of {@code stateExecutor}, binds the executor back to it and
     * installs a fresh record context so that state requests can be issued right away.
     *
     * <p>NOTE: the exception handler passed here is a no-op, so a failure raised inside a
     * state-future callback may never reach the test thread. Each test therefore records that its
     * final callback ran and verifies that after draining (see {@link #assertIterationCompleted}).
     *
     * @param stateExecutor the executor that will serve the iterator requests
     * @return the controller, with its current context already set
     */
    private static StateExecutionController createControllerWithContext(
            TestIteratorStateExecutor stateExecutor) {
        StateExecutionController aec =
                new StateExecutionController(
                        new SyncMailboxExecutor(),
                        (a, b) -> {},
                        stateExecutor,
                        new DeclarationManager(),
                        EpochManager.ParallelMode.SERIAL_BETWEEN_EPOCH,
                        1,
                        100,
                        1000L,
                        1,
                        null,
                        null);
        stateExecutor.bindAec(aec);
        RecordContext recordContext = aec.buildContext("1", "key1");
        aec.setCurrentContext(recordContext);
        return aec;
    }

    /**
     * Verifies on the test thread that all {@link #LIMIT} elements were visited and that the
     * final callback of the test ran to completion exactly once.
     */
    private static void assertIterationCompleted(AtomicInteger processed, AtomicInteger finished) {
        Assertions.assertThat(processed.get()).isEqualTo(LIMIT);
        Assertions.assertThat(finished.get()).isEqualTo(1);
    }

    /**
     * A {@link StateExecutor} that answers {@code MAP_ITER} and {@code ITERATOR_LOADING} requests
     * with consecutive integers starting at 0, handing out at most {@code step} of them per
     * request and never reaching {@code limit}.
     */
    static class TestIteratorStateExecutor
    implements StateExecutor {
        final int limit;
        final int step;
        StateExecutionController aec;
        /** Next integer to hand out. */
        int current = 0;
        /** Number of requests handled successfully by {@link #executeRequestSync}. */
        AtomicInteger processedCount = new AtomicInteger(0);

        public TestIteratorStateExecutor(int limit, int step) {
            this.limit = limit;
            this.step = step;
        }

        /** Sets the controller that newly created {@link TestIterator}s will send requests to. */
        public void bindAec(StateExecutionController aec) {
            this.aec = aec;
        }

        /**
         * Executes every request in the (mock) container synchronously, then returns an already
         * completed future.
         */
        public CompletableFuture<Void> executeBatchRequests(AsyncRequestContainer<StateRequest<?, ?, ?, ?>> asyncRequestContainer) {
            Preconditions.checkArgument(asyncRequestContainer instanceof MockAsyncRequestContainer);
            for (StateRequest request : ((MockAsyncRequestContainer) asyncRequestContainer).getStateRequestList()) {
                this.executeRequestSync(request);
            }
            return CompletableFuture.completedFuture(null);
        }

        public AsyncRequestContainer<StateRequest<?, ?, ?, ?>> createRequestContainer() {
            return new MockAsyncRequestContainer();
        }

        /**
         * Completes {@code request} with a {@link TestIterator} holding the next batch.
         *
         * <p>For {@code ITERATOR_LOADING} the payload must be the previous {@link TestIterator},
         * whose recorded position has to match this executor's position. Any other request type
         * fails the test.
         */
        public void executeRequestSync(StateRequest<?, ?, ?, ?> request) {
            if (request.getRequestType() == StateRequestType.MAP_ITER) {
                // Load first: the new iterator records the position *after* this batch.
                Collection<Integer> batch = this.loadNextBatch();
                request.getFuture().complete((Object) new TestIterator(request.getState(), request.getRequestType(), this.aec, batch, this.current, this.limit));
            } else if (request.getRequestType() == StateRequestType.ITERATOR_LOADING) {
                Assertions.assertThat((Object) request.getPayload()).isInstanceOf(TestIterator.class);
                TestIterator previous = (TestIterator) request.getPayload();
                // The iterator asking for more must have stopped exactly where this executor is.
                Assertions.assertThat(previous.current).isEqualTo(this.current);
                Collection<Integer> batch = this.loadNextBatch();
                request.getFuture().complete((Object) new TestIterator(request.getState(), previous.getRequestType(), this.aec, batch, this.current, this.limit));
            } else {
                org.junit.jupiter.api.Assertions.fail("Unsupported request type " + request.getRequestType());
            }
            this.processedCount.incrementAndGet();
        }

        /** Hands out up to {@code step} consecutive integers below {@code limit}, advancing {@code current}. */
        private Collection<Integer> loadNextBatch() {
            ArrayList<Integer> batch = new ArrayList<>(this.step);
            for (int i = 0; this.current < this.limit && i < this.step; ++i) {
                batch.add(this.current++);
            }
            return batch;
        }

        public boolean fullyLoaded() {
            return false;
        }

        public void shutdown() {
        }

        /**
         * Partial iterator produced by {@link TestIteratorStateExecutor}. It remembers the
         * executor position right after its batch was loaded, and it passes itself as the payload
         * of the follow-up loading request.
         */
        static class TestIterator
        extends AbstractStateIterator<Integer> {
            /** Executor position right after this iterator's batch was produced. */
            final int current;
            /** Exclusive upper bound of the served integers. */
            final int limit;

            public TestIterator(State originalState, StateRequestType requestType, StateExecutionController aec, Collection<Integer> partialResult, int current, int limit) {
                super(originalState, requestType, aec, partialResult);
                this.current = current;
                this.limit = limit;
            }

            /** More elements remain while the recorded position is below {@code limit}. */
            public boolean hasNextLoading() {
                return this.current < this.limit;
            }

            /** The iterator itself is the payload; the executor reads {@code current} from it. */
            protected Object nextPayloadForContinuousLoading() {
                return this;
            }
        }
    }
}

