/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.state.v2;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.concurrent.CompletableFuture;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.operators.MailboxExecutor;
import org.apache.flink.api.common.state.v2.AggregatingStateDescriptor;
import org.apache.flink.api.common.state.v2.State;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.asyncprocessing.AsyncExecutor;
import org.apache.flink.runtime.asyncprocessing.AsyncRequestContainer;
import org.apache.flink.runtime.asyncprocessing.EpochManager;
import org.apache.flink.runtime.asyncprocessing.MockAsyncRequestContainer;
import org.apache.flink.runtime.asyncprocessing.StateExecutionController;
import org.apache.flink.runtime.asyncprocessing.StateExecutor;
import org.apache.flink.runtime.asyncprocessing.StateRequest;
import org.apache.flink.runtime.asyncprocessing.StateRequestHandler;
import org.apache.flink.runtime.asyncprocessing.StateRequestType;
import org.apache.flink.runtime.asyncprocessing.declare.DeclarationManager;
import org.apache.flink.runtime.mailbox.SyncMailboxExecutor;
import org.apache.flink.runtime.state.v2.AbstractAggregatingState;
import org.apache.flink.runtime.state.v2.AbstractKeyedStateTestBase;
import org.apache.flink.runtime.state.v2.internal.InternalPartitionedState;
import org.assertj.core.api.AssertionsForClassTypes;
import org.junit.jupiter.api.Test;

class AbstractAggregatingStateTest
extends AbstractKeyedStateTestBase {
    // Package-private no-argument constructor; it performs no work of its own.
    AbstractAggregatingStateTest() {
    }

    /**
     * Checks which state requests an aggregating state issues for clear, get and add.
     *
     * <p>The aggregator starts its accumulator at 1. Each {@code asyncAdd} is expected to produce
     * an {@code AGGREGATING_GET} followed by an {@code AGGREGATING_ADD}; the expected ADD payloads
     * are 2 (1 + 1) and 6 (1 + 5). NOTE(review): both payloads start from the initial accumulator,
     * presumably because the executor supplied by the base class never returns a stored value —
     * confirm against {@code AbstractKeyedStateTestBase}. The meaning of the trailing int passed
     * to {@code validateRequestRun} is likewise defined in the base class.
     */
    @Test
    public void testAggregating() {
        AggregateFunction sumFromOne = new SumAggregator(1);
        TypeInformation intType = BasicTypeInfo.INT_TYPE_INFO;
        AggregatingStateDescriptor stateDescriptor = new AggregatingStateDescriptor("testAggState", sumFromOne, intType);
        stateDescriptor.initializeSerializerUnlessSet(new ExecutionConfig());

        AbstractAggregatingState aggState = new AbstractAggregatingState(
                (StateRequestHandler)this.aec,
                stateDescriptor.getAggregateFunction(),
                stateDescriptor.getSerializer().duplicate());
        this.aec.setCurrentContext(this.aec.buildContext("test", "test"));

        // Clearing issues a single CLEAR request.
        aggState.asyncClear();
        this.validateRequestRun((State)aggState, StateRequestType.CLEAR, null, 0);

        // Reading issues a single AGGREGATING_GET request.
        aggState.asyncGet();
        this.validateRequestRun((State)aggState, StateRequestType.AGGREGATING_GET, null, 0);

        // Adding 1: a GET first, then an ADD whose payload is 2.
        aggState.asyncAdd(1);
        this.validateRequestRun((State)aggState, StateRequestType.AGGREGATING_GET, null, 1);
        this.validateRequestRun((State)aggState, StateRequestType.AGGREGATING_ADD, 2, 0);

        // Adding 5: a GET first, then an ADD whose payload is 6.
        aggState.asyncAdd(5);
        this.validateRequestRun((State)aggState, StateRequestType.AGGREGATING_GET, null, 1);
        this.validateRequestRun((State)aggState, StateRequestType.AGGREGATING_ADD, 6, 0);
    }

    /**
     * Exercises the synchronous {@code add} API against the in-memory
     * {@link AggregatingStateExecutor}, one namespace at a time.
     *
     * <p>The aggregator starts its accumulator at 1, so after adding {@code v} to a fresh
     * namespace the stored value is expected to be {@code v + 1}. Values written under earlier
     * namespaces must stay untouched.
     */
    @Test
    public void testAggregatingStateAddWithSyncAPI() {
        AggregateFunction sumFromOne = new SumAggregator(1);
        TypeInformation intType = BasicTypeInfo.INT_TYPE_INFO;
        AggregatingStateDescriptor stateDescriptor = new AggregatingStateDescriptor("testState", sumFromOne, intType);
        stateDescriptor.initializeSerializerUnlessSet(new ExecutionConfig());

        AggregatingStateExecutor executor = new AggregatingStateExecutor();
        // The executor always hands back the same map instance, so it can be looked up once.
        HashMap<Tuple2<String, String>, Integer> results = executor.getResultMap();
        StateExecutionController controller = new StateExecutionController(
                (MailboxExecutor)new SyncMailboxExecutor(),
                (left, right) -> {},
                (AsyncExecutor)executor,
                new DeclarationManager(),
                EpochManager.ParallelMode.SERIAL_BETWEEN_EPOCH,
                1, 100, 10000L, 1, null, null);
        AbstractAggregatingState syncState = new AbstractAggregatingState(
                (StateRequestHandler)controller,
                stateDescriptor.getAggregateFunction(),
                stateDescriptor.getSerializer());
        controller.setCurrentContext(controller.buildContext("test", "test"));

        // Round i switches to namespace i+1 and adds the value i+1 there.
        String[] namespaces = {"1", "2"};
        for (int round = 0; round < namespaces.length; round++) {
            controller.setCurrentNamespaceForState((InternalPartitionedState)syncState, namespaces[round]);
            syncState.add(round + 1);

            AssertionsForClassTypes.assertThat(results.size()).isEqualTo(round + 1);
            // Every namespace written so far holds (added value + initial accumulator 1).
            for (int seen = 0; seen <= round; seen++) {
                AssertionsForClassTypes.assertThat(results.get(Tuple2.of("test", namespaces[seen]))).isEqualTo(seen + 2);
            }
        }
    }

    /**
     * Verifies {@code asyncMergeNamespaces} against the in-memory {@link AggregatingStateExecutor}.
     *
     * <p>Steps, with an aggregator whose accumulator starts at 0:
     * <ol>
     *   <li>add 1, 2 and 3 under namespaces "1", "2" and "3" and check each stored value;</li>
     *   <li>merge those three namespaces into "0" and expect a single entry holding 6;</li>
     *   <li>add 4 under namespace "4", then merge it into "0" and expect a single entry
     *       holding 10, with every source namespace absent.</li>
     * </ol>
     *
     * @throws Exception declared by the test signature; draining in-flight records is
     *     presumably the call that may throw
     */
    @Test
    public void testMergeNamespace() throws Exception {
        AggregateFunction sumFromZero = new SumAggregator(0);
        TypeInformation intType = BasicTypeInfo.INT_TYPE_INFO;
        AggregatingStateDescriptor stateDescriptor = new AggregatingStateDescriptor("testState", sumFromZero, intType);
        stateDescriptor.initializeSerializerUnlessSet(new ExecutionConfig());

        AggregatingStateExecutor executor = new AggregatingStateExecutor();
        // The executor always hands back the same map instance, so it can be looked up once.
        HashMap<Tuple2<String, String>, Integer> results = executor.getResultMap();
        StateExecutionController controller = new StateExecutionController(
                (MailboxExecutor)new SyncMailboxExecutor(),
                (left, right) -> {},
                (AsyncExecutor)executor,
                new DeclarationManager(),
                EpochManager.ParallelMode.SERIAL_BETWEEN_EPOCH,
                1, 100, 10000L, 1, null, null);
        AbstractAggregatingState mergingState = new AbstractAggregatingState(
                (StateRequestHandler)controller,
                stateDescriptor.getAggregateFunction(),
                stateDescriptor.getSerializer());
        controller.setCurrentContext(controller.buildContext("test", "test"));

        // Step 1: namespace "n" receives the value n; earlier namespaces must be unchanged.
        String[] addedNamespaces = {"1", "2", "3"};
        for (int round = 0; round < addedNamespaces.length; round++) {
            controller.setCurrentNamespaceForState((InternalPartitionedState)mergingState, addedNamespaces[round]);
            mergingState.asyncAdd(round + 1);
            controller.drainInflightRecords(0);

            AssertionsForClassTypes.assertThat(results.size()).isEqualTo(round + 1);
            for (int seen = 0; seen <= round; seen++) {
                AssertionsForClassTypes.assertThat(results.get(Tuple2.of("test", addedNamespaces[seen]))).isEqualTo(seen + 1);
            }
        }

        // Step 2: merge "1", "2", "3" into "0" -> only "0" remains, holding 1 + 2 + 3.
        ArrayList<String> firstSources = new ArrayList<String>(Arrays.asList(addedNamespaces));
        mergingState.asyncMergeNamespaces("0", firstSources);
        controller.drainInflightRecords(0);
        AssertionsForClassTypes.assertThat(results.size()).isEqualTo(1);
        AssertionsForClassTypes.assertThat(results.get(Tuple2.of("test", "0"))).isEqualTo(6);
        for (String merged : addedNamespaces) {
            AssertionsForClassTypes.assertThat(results.get(Tuple2.of("test", merged))).isNull();
        }

        // Step 3a: a further add under "4" leaves the merged total in "0" untouched.
        controller.setCurrentNamespaceForState((InternalPartitionedState)mergingState, "4");
        mergingState.asyncAdd(4);
        controller.drainInflightRecords(0);
        AssertionsForClassTypes.assertThat(results.size()).isEqualTo(2);
        AssertionsForClassTypes.assertThat(results.get(Tuple2.of("test", "0"))).isEqualTo(6);
        AssertionsForClassTypes.assertThat(results.get(Tuple2.of("test", "4"))).isEqualTo(4);

        // Step 3b: merging "4" into "0" -> only "0" remains, holding 6 + 4.
        ArrayList<String> secondSources = new ArrayList<String>(Arrays.asList("4"));
        mergingState.asyncMergeNamespaces("0", secondSources);
        controller.drainInflightRecords(0);
        AssertionsForClassTypes.assertThat(results.size()).isEqualTo(1);
        AssertionsForClassTypes.assertThat(results.get(Tuple2.of("test", "0"))).isEqualTo(10);
        String[] allSources = {"1", "2", "3", "4"};
        for (String merged : allSources) {
            AssertionsForClassTypes.assertThat(results.get(Tuple2.of("test", merged))).isNull();
        }
    }

    /**
     * In-memory {@link StateExecutor} used by the tests in this class.
     *
     * <p>Values are kept in a {@link HashMap} keyed by (record key, namespace). Only
     * {@code AGGREGATING_ADD} and {@code AGGREGATING_GET} requests are handled:
     * <ul>
     *   <li>ADD with a non-null payload stores that payload under the key, replacing any
     *       previous value; the executor itself performs no aggregation;</li>
     *   <li>ADD with a {@code null} payload removes the entry;</li>
     *   <li>GET completes the request's future with the stored value, or {@code null} when
     *       nothing is stored.</li>
     * </ul>
     * Every request is executed immediately on the calling thread.
     */
    static class AggregatingStateExecutor
    implements StateExecutor {
        /** Stored values, keyed by (record key, namespace). */
        private final HashMap<Tuple2<String, String>, Integer> store;

        AggregatingStateExecutor() {
            this.store = new HashMap<>();
        }

        /**
         * Runs every request of the given container synchronously, in list order.
         *
         * @param asyncRequestContainer must be a {@link MockAsyncRequestContainer}; anything else
         *     fails the cast below
         * @return a future that is already completed
         */
        public CompletableFuture<Void> executeBatchRequests(AsyncRequestContainer asyncRequestContainer) {
            MockAsyncRequestContainer batch = (MockAsyncRequestContainer)asyncRequestContainer;
            for (StateRequest pending : batch.getStateRequestList()) {
                this.executeRequestSync(pending);
            }
            return CompletableFuture.completedFuture(null);
        }

        /** Creates a fresh {@link MockAsyncRequestContainer}. */
        public AsyncRequestContainer<StateRequest<?, ?, ?, ?>> createRequestContainer() {
            return new MockAsyncRequestContainer();
        }

        /**
         * Applies a single request to the in-memory map and completes its future.
         *
         * @param stateRequest request whose record key and namespace are both strings
         * @throws UnsupportedOperationException for any request type other than
         *     {@code AGGREGATING_ADD} and {@code AGGREGATING_GET}
         */
        public void executeRequestSync(StateRequest<?, ?, ?, ?> stateRequest) {
            String key = (String)stateRequest.getRecordContext().getKey();
            String namespace = (String)stateRequest.getNamespace();
            Tuple2<String, String> slot = Tuple2.of(key, namespace);
            StateRequestType type = stateRequest.getRequestType();

            if (type == StateRequestType.AGGREGATING_GET) {
                Integer stored = this.store.get(slot);
                stateRequest.getFuture().complete((Object)stored);
                return;
            }
            if (type != StateRequestType.AGGREGATING_ADD) {
                throw new UnsupportedOperationException("Unsupported type");
            }

            // ADD: a null payload deletes the entry, anything else overwrites it.
            Object payload = stateRequest.getPayload();
            if (payload == null) {
                this.store.remove(slot);
            } else {
                this.store.put(slot, (Integer)payload);
            }
            stateRequest.getFuture().complete(null);
        }

        /** Returns the live backing map (not a copy) so tests can inspect stored values. */
        public HashMap<Tuple2<String, String>, Integer> getResultMap() {
            return this.store;
        }

        /** Always reports {@code false}. */
        public boolean fullyLoaded() {
            return false;
        }

        /** Nothing to release; intentionally empty. */
        public void shutdown() {
        }
    }

    /**
     * {@link AggregateFunction} that sums integers, with a configurable starting accumulator.
     */
    static class SumAggregator
    implements AggregateFunction<Integer, Integer, Integer> {
        /** Value every newly created accumulator starts from. */
        private final int seed;

        /**
         * @param init value returned by {@link #createAccumulator()}
         */
        public SumAggregator(int init) {
            seed = init;
        }

        /** Returns the configured starting value. */
        public Integer createAccumulator() {
            return Integer.valueOf(seed);
        }

        /** Returns {@code accumulator + value}. */
        public Integer add(Integer value, Integer accumulator) {
            return Integer.sum(accumulator, value);
        }

        /** The accumulator is the result; it is returned unchanged. */
        public Integer getResult(Integer accumulator) {
            return accumulator;
        }

        /** Returns {@code a + b}. */
        public Integer merge(Integer a, Integer b) {
            return Integer.sum(a, b);
        }
    }
}

