/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.state.ttl.mock;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.FutureTask;
import java.util.concurrent.RunnableFuture;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nonnull;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.state.AggregatingStateDescriptor;
import org.apache.flink.api.common.state.FoldingStateDescriptor;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.State;
import org.apache.flink.api.common.state.StateDescriptor;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.core.fs.CloseableRegistry;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.query.TaskKvStateRegistry;
import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.KeyExtractorFunction;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.PriorityComparable;
import org.apache.flink.runtime.state.PriorityComparator;
import org.apache.flink.runtime.state.SharedStateRegistry;
import org.apache.flink.runtime.state.SnapshotResult;
import org.apache.flink.runtime.state.StateSnapshotTransformer;
import org.apache.flink.runtime.state.StateSnapshotTransformers;
import org.apache.flink.runtime.state.heap.HeapPriorityQueueElement;
import org.apache.flink.runtime.state.heap.HeapPriorityQueueSet;
import org.apache.flink.runtime.state.heap.InternalKeyContext;
import org.apache.flink.runtime.state.ttl.TtlStateFactory;
import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
import org.apache.flink.runtime.state.ttl.mock.MockInternalAggregatingState;
import org.apache.flink.runtime.state.ttl.mock.MockInternalFoldingState;
import org.apache.flink.runtime.state.ttl.mock.MockInternalKvState;
import org.apache.flink.runtime.state.ttl.mock.MockInternalListState;
import org.apache.flink.runtime.state.ttl.mock.MockInternalMapState;
import org.apache.flink.runtime.state.ttl.mock.MockInternalReducingState;
import org.apache.flink.runtime.state.ttl.mock.MockInternalValueState;
import org.apache.flink.util.FlinkRuntimeException;

public class MockKeyedStateBackend<K>
extends AbstractKeyedStateBackend<K> {
    private static final Map<Class<? extends StateDescriptor>, StateFactory> STATE_FACTORIES = Stream.of(Tuple2.of(ValueStateDescriptor.class, MockInternalValueState::createState), Tuple2.of(ListStateDescriptor.class, MockInternalListState::createState), Tuple2.of(MapStateDescriptor.class, MockInternalMapState::createState), Tuple2.of(ReducingStateDescriptor.class, MockInternalReducingState::createState), Tuple2.of(AggregatingStateDescriptor.class, MockInternalAggregatingState::createState), Tuple2.of(FoldingStateDescriptor.class, MockInternalFoldingState::createState)).collect(Collectors.toMap(t -> (Class)t.f0, t -> (StateFactory)t.f1));
    private final Map<String, Map<K, Map<Object, Object>>> stateValues;
    private final Map<String, StateSnapshotTransformer<Object>> stateSnapshotFilters;

    MockKeyedStateBackend(TaskKvStateRegistry kvStateRegistry, TypeSerializer<K> keySerializer, ClassLoader userCodeClassLoader, ExecutionConfig executionConfig, TtlTimeProvider ttlTimeProvider, Map<String, Map<K, Map<Object, Object>>> stateValues, Map<String, StateSnapshotTransformer<Object>> stateSnapshotFilters, CloseableRegistry cancelStreamRegistry, InternalKeyContext<K> keyContext) {
        super(kvStateRegistry, keySerializer, userCodeClassLoader, executionConfig, ttlTimeProvider, cancelStreamRegistry, keyContext);
        this.stateValues = stateValues;
        this.stateSnapshotFilters = stateSnapshotFilters;
    }

    @Nonnull
    public <N, SV, SEV, S extends State, IS extends S> IS createInternalState(@Nonnull TypeSerializer<N> namespaceSerializer, @Nonnull StateDescriptor<S, SV> stateDesc, @Nonnull StateSnapshotTransformer.StateSnapshotTransformFactory<SEV> snapshotTransformFactory) throws Exception {
        StateFactory stateFactory = STATE_FACTORIES.get(stateDesc.getClass());
        if (stateFactory == null) {
            String message = String.format("State %s is not supported by %s", stateDesc.getClass(), TtlStateFactory.class);
            throw new FlinkRuntimeException(message);
        }
        Object state = stateFactory.createInternalState(namespaceSerializer, stateDesc);
        this.stateSnapshotFilters.put(stateDesc.getName(), this.getStateSnapshotTransformer(stateDesc, snapshotTransformFactory));
        ((MockInternalKvState)state).values = () -> this.stateValues.computeIfAbsent(stateDesc.getName(), n -> new HashMap()).computeIfAbsent(this.getCurrentKey(), k -> new HashMap());
        return state;
    }

    private <SV, SEV> StateSnapshotTransformer<SV> getStateSnapshotTransformer(StateDescriptor<?, SV> stateDesc, StateSnapshotTransformer.StateSnapshotTransformFactory<SEV> snapshotTransformFactory) {
        Optional original = snapshotTransformFactory.createForDeserializedState();
        if (original.isPresent()) {
            if (stateDesc instanceof ListStateDescriptor) {
                return new StateSnapshotTransformers.ListStateSnapshotTransformer((StateSnapshotTransformer)original.get());
            }
            if (stateDesc instanceof MapStateDescriptor) {
                return new StateSnapshotTransformers.MapStateSnapshotTransformer((StateSnapshotTransformer)original.get());
            }
            return (StateSnapshotTransformer)original.get();
        }
        return null;
    }

    public int numKeyValueStateEntries() {
        int count = 0;
        for (String state : this.stateValues.keySet()) {
            for (K key : this.stateValues.get(state).keySet()) {
                count += this.stateValues.get(state).get(key).size();
            }
        }
        return count;
    }

    public boolean requiresLegacySynchronousTimerSnapshots() {
        return false;
    }

    public void notifyCheckpointComplete(long checkpointId) {
    }

    public void notifyCheckpointAborted(long checkpointId) {
    }

    public <N> Stream<K> getKeys(String state, N namespace) {
        return this.stateValues.get(state).entrySet().stream().filter(e -> ((Map)e.getValue()).containsKey(namespace)).map(Map.Entry::getKey);
    }

    @Nonnull
    public RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshot(long checkpointId, long timestamp, @Nonnull CheckpointStreamFactory streamFactory, @Nonnull CheckpointOptions checkpointOptions) {
        return new FutureTask<SnapshotResult<KeyedStateHandle>>(() -> SnapshotResult.of(new MockKeyedStateHandle<K>(MockKeyedStateBackend.copy(this.stateValues, this.stateSnapshotFilters))));
    }

    static <K> Map<String, Map<K, Map<Object, Object>>> copy(Map<String, Map<K, Map<Object, Object>>> stateValues, Map<String, StateSnapshotTransformer<Object>> stateSnapshotFilters) {
        HashMap<String, Map<String, Map<Object, Object>>> snapshotStates = new HashMap<String, Map<String, Map<Object, Object>>>();
        for (String stateName : stateValues.keySet()) {
            StateSnapshotTransformer stateSnapshotTransformer = stateSnapshotFilters.getOrDefault(stateName, null);
            Map keyedValues = snapshotStates.computeIfAbsent(stateName, s -> new HashMap());
            for (K key : stateValues.get(stateName).keySet()) {
                Map snapshotedValues = keyedValues.computeIfAbsent(key, s -> new HashMap());
                for (Object namespace : stateValues.get(stateName).get(key).keySet()) {
                    MockKeyedStateBackend.copyEntry(stateValues, snapshotedValues, stateName, key, namespace, (StateSnapshotTransformer<Object>)stateSnapshotTransformer);
                }
            }
        }
        return snapshotStates;
    }

    private static <K> void copyEntry(Map<String, Map<K, Map<Object, Object>>> stateValues, Map<Object, Object> snapshotedValues, String stateName, K key, Object namespace, StateSnapshotTransformer<Object> stateSnapshotTransformer) {
        Object filteredValue;
        Object value = stateValues.get(stateName).get(key).get(namespace);
        value = value instanceof List ? new ArrayList((List)value) : value;
        value = value instanceof Map ? new HashMap((Map)value) : value;
        Object object = filteredValue = stateSnapshotTransformer == null ? value : stateSnapshotTransformer.filterOrTransform(value);
        if (filteredValue != null) {
            snapshotedValues.put(namespace, filteredValue);
        }
    }

    @Nonnull
    public <T extends HeapPriorityQueueElement & PriorityComparable> KeyGroupedInternalPriorityQueue<T> create(@Nonnull String stateName, @Nonnull TypeSerializer<T> byteOrderedElementSerializer) {
        return new HeapPriorityQueueSet(PriorityComparator.forPriorityComparableObjects(), KeyExtractorFunction.forKeyedObjects(), 0, this.keyGroupRange, 0);
    }

    static class MockKeyedStateHandle<K>
    implements KeyedStateHandle {
        private static final long serialVersionUID = 1L;
        final Map<String, Map<K, Map<Object, Object>>> snapshotStates;

        MockKeyedStateHandle(Map<String, Map<K, Map<Object, Object>>> snapshotStates) {
            this.snapshotStates = snapshotStates;
        }

        public void discardState() {
            this.snapshotStates.clear();
        }

        public long getStateSize() {
            throw new UnsupportedOperationException();
        }

        public void registerSharedStates(SharedStateRegistry stateRegistry) {
        }

        public KeyGroupRange getKeyGroupRange() {
            throw new UnsupportedOperationException();
        }

        public KeyedStateHandle getIntersection(KeyGroupRange keyGroupRange) {
            throw new UnsupportedOperationException();
        }
    }

    private static interface StateFactory {
        public <N, SV, S extends State, IS extends S> IS createInternalState(TypeSerializer<N> var1, StateDescriptor<S, SV> var2) throws Exception;
    }
}

