package org.apache.flink.runtime.state;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.state.BroadcastState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.core.fs.CloseableRegistry;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.checkpoint.RoundRobinOperatorStateRepartitioner;
import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.testcontainers.utility.ThrowingFunction;

/* loaded from: input_file:org/apache/flink/runtime/state/OperatorStateRestoreOperationTest.class */
/**
 * Tests that operator list state and broadcast state written by a backend built with {@link
 * DefaultOperatorStateBackendBuilder} can be snapshotted and restored. This covers restoring a
 * single handle, merging several handles, and restoring round-robin repartitioned handles.
 *
 * <p>Every test runs once with snapshot compression enabled and once with it disabled.
 */
public class OperatorStateRestoreOperationTest {

    /**
     * Returns a factory that builds a new operator state backend restored from the given state
     * handles.
     *
     * <p>NOTE(review): the literal {@code false} passed to the builder presumably disables
     * asynchronous snapshots. That would explain why {@link #createOperatorStateHandle} can call
     * {@code get()} on the snapshot future directly. Confirm against {@link
     * DefaultOperatorStateBackendBuilder}.
     *
     * @param executionConfig configuration (e.g. snapshot compression) handed to every backend
     * @param closeableRegistry registry handed to every builder
     * @param classLoader class loader handed to every builder
     * @return a function mapping a collection of state handles to a newly built backend
     */
    private static ThrowingFunction<Collection<OperatorStateHandle>, OperatorStateBackend>
            createOperatorStateBackendFactory(
                    ExecutionConfig executionConfig,
                    CloseableRegistry closeableRegistry,
                    ClassLoader classLoader) {
        return stateHandles ->
                new DefaultOperatorStateBackendBuilder(
                                classLoader, executionConfig, false, stateHandles, closeableRegistry)
                        .build();
    }

    /**
     * Creates a backend factory whose backends use the given snapshot-compression setting.
     *
     * @param snapshotCompression value passed to {@link ExecutionConfig#setUseSnapshotCompression}
     * @return factory as returned by {@link #createOperatorStateBackendFactory}
     */
    private ThrowingFunction<Collection<OperatorStateHandle>, OperatorStateBackend>
            createBackendFactory(boolean snapshotCompression) {
        ExecutionConfig executionConfig = new ExecutionConfig();
        executionConfig.setUseSnapshotCompression(snapshotCompression);
        return createOperatorStateBackendFactory(
                executionConfig, new CloseableRegistry(), getClass().getClassLoader());
    }

    /**
     * Writes the given states into a backend built from no handles, then snapshots it. The
     * resulting job-manager-owned state handle is returned. The backend is always closed before
     * this method returns or throws.
     *
     * @param backendFactory factory used to build the backend
     * @param listStates list states to write, keyed by state name
     * @param broadcastStates broadcast states to write, keyed by state name
     * @return the job-manager-owned handle of the snapshot (checkpoint id 1, timestamp 1)
     * @throws Exception if building the backend, writing state, or snapshotting fails
     * @throws NullPointerException if the snapshot has no job-manager-owned handle
     */
    private static OperatorStateHandle createOperatorStateHandle(
            ThrowingFunction<Collection<OperatorStateHandle>, OperatorStateBackend> backendFactory,
            Map<String, List<String>> listStates,
            Map<String, Map<String, String>> broadcastStates)
            throws Exception {
        try (OperatorStateBackend backend = backendFactory.apply(Collections.emptyList())) {
            for (Map.Entry<String, List<String>> entry : listStates.entrySet()) {
                backend.getListState(new ListStateDescriptor<>(entry.getKey(), String.class))
                        .addAll(entry.getValue());
            }
            for (Map.Entry<String, Map<String, String>> entry : broadcastStates.entrySet()) {
                backend.getBroadcastState(
                                new MapStateDescriptor<>(entry.getKey(), String.class, String.class))
                        .putAll(entry.getValue());
            }
            return Objects.requireNonNull(
                    backend.snapshot(
                                    1L,
                                    1L,
                                    new MemCheckpointStreamFactory(4096),
                                    CheckpointOptions.forCheckpointWithDefaultLocation())
                            .get()
                            .getJobManagerOwnedSnapshot(),
                    "snapshot did not produce a job-manager-owned state handle");
        }
    }

    /**
     * Restores a backend from the given handles and asserts its contents match the expected
     * states. The backend is always closed before this method returns or throws.
     *
     * <p>Each expected list state must contain exactly the expected elements, in order. Each
     * expected broadcast state must equal the expected map exactly, with no missing or extra
     * entries.
     *
     * @param backendFactory factory used to build the restored backend
     * @param stateHandles handles to restore from
     * @param expectedListStates expected list states, keyed by state name
     * @param expectedBroadcastStates expected broadcast states, keyed by state name
     * @throws Exception if restoring or reading state fails
     */
    private static void verifyOperatorStateHandle(
            ThrowingFunction<Collection<OperatorStateHandle>, OperatorStateBackend> backendFactory,
            Collection<OperatorStateHandle> stateHandles,
            Map<String, List<String>> expectedListStates,
            Map<String, Map<String, String>> expectedBroadcastStates)
            throws Exception {
        try (OperatorStateBackend backend = backendFactory.apply(stateHandles)) {
            for (Map.Entry<String, List<String>> entry : expectedListStates.entrySet()) {
                Assertions.assertThat(
                                backend.getListState(
                                                new ListStateDescriptor<>(
                                                        entry.getKey(), String.class))
                                        .get())
                        .containsExactlyElementsOf(entry.getValue());
            }
            for (Map.Entry<String, Map<String, String>> entry :
                    expectedBroadcastStates.entrySet()) {
                BroadcastState<String, String> broadcastState =
                        backend.getBroadcastState(
                                new MapStateDescriptor<>(
                                        entry.getKey(), String.class, String.class));
                Map<String, String> restored = new HashMap<>();
                broadcastState
                        .iterator()
                        .forEachRemaining(
                                restoredEntry ->
                                        restored.put(
                                                restoredEntry.getKey(), restoredEntry.getValue()));
                // Exact equality: a subset check would pass trivially for an empty expected map.
                Assertions.assertThat(restored).isEqualTo(entry.getValue());
            }
        }
    }

    /** Restores a single handle that holds both list state and broadcast state. */
    @ValueSource(booleans = {true, false})
    @ParameterizedTest
    void testRestoringMixedOperatorState(boolean snapshotCompression) throws Exception {
        ThrowingFunction<Collection<OperatorStateHandle>, OperatorStateBackend> backendFactory =
                createBackendFactory(snapshotCompression);

        Map<String, List<String>> listStates = new HashMap<>();
        listStates.put("s1", Arrays.asList("foo1", "foo2", "foo3"));
        listStates.put("s2", Arrays.asList("bar1", "bar2", "bar3"));

        Map<String, Map<String, String>> broadcastStates = new HashMap<>();
        broadcastStates.put("a1", Collections.singletonMap("foo", "bar"));
        broadcastStates.put("a2", Collections.singletonMap("bar", "foo"));

        OperatorStateHandle stateHandle =
                createOperatorStateHandle(backendFactory, listStates, broadcastStates);
        verifyOperatorStateHandle(
                backendFactory,
                Collections.singletonList(stateHandle),
                listStates,
                broadcastStates);
    }

    /**
     * Restores from two handles that use the same state names but hold different contents. The
     * test then checks that each list state holds the first handle's elements followed by the
     * second handle's.
     */
    @ValueSource(booleans = {true, false})
    @ParameterizedTest
    void testMergeOperatorState(boolean snapshotCompression) throws Exception {
        ThrowingFunction<Collection<OperatorStateHandle>, OperatorStateBackend> backendFactory =
                createBackendFactory(snapshotCompression);

        Map<String, List<String>> firstListStates = new HashMap<>();
        firstListStates.put("s1", Arrays.asList("foo1", "foo2", "foo3"));
        firstListStates.put("s2", Arrays.asList("bar1", "bar2", "bar3"));

        Map<String, List<String>> secondListStates = new HashMap<>();
        secondListStates.put("s1", Arrays.asList("foo4", "foo5", "foo6"));
        secondListStates.put("s2", Arrays.asList("bar1", "bar2", "bar3"));

        OperatorStateHandle firstHandle =
                createOperatorStateHandle(backendFactory, firstListStates, Collections.emptyMap());
        OperatorStateHandle secondHandle =
                createOperatorStateHandle(
                        backendFactory, secondListStates, Collections.emptyMap());

        // Expected result: per state name, first handle's elements followed by the second's.
        Map<String, List<String>> mergedListStates = new HashMap<>();
        for (Map<String, List<String>> listStates :
                Arrays.asList(firstListStates, secondListStates)) {
            listStates.forEach(
                    (name, values) ->
                            mergedListStates
                                    .computeIfAbsent(name, k -> new ArrayList<>())
                                    .addAll(values));
        }

        verifyOperatorStateHandle(
                backendFactory,
                Arrays.asList(firstHandle, secondHandle),
                mergedListStates,
                Collections.emptyMap());
    }

    /** Restores a handle containing an empty list state and an empty broadcast state. */
    @ValueSource(booleans = {true, false})
    @ParameterizedTest
    void testEmptyPartitionedOperatorState(boolean snapshotCompression) throws Exception {
        ThrowingFunction<Collection<OperatorStateHandle>, OperatorStateBackend> backendFactory =
                createBackendFactory(snapshotCompression);

        Map<String, List<String>> listStates = new HashMap<>();
        listStates.put("bufferState", Collections.emptyList());
        listStates.put("offsetState", Collections.singletonList("foo"));

        Map<String, Map<String, String>> broadcastStates = new HashMap<>();
        broadcastStates.put("whateverState", Collections.emptyMap());

        OperatorStateHandle stateHandle =
                createOperatorStateHandle(backendFactory, listStates, broadcastStates);
        verifyOperatorStateHandle(
                backendFactory,
                Collections.singletonList(stateHandle),
                listStates,
                broadcastStates);
    }

    /**
     * Repartitions a single handle (old parallelism 1) with {@link
     * RoundRobinOperatorStateRepartitioner} to parallelisms 1, 2, 5 and 10. It then verifies that
     * each subtask restores its expected slice of every list state.
     */
    @ValueSource(booleans = {true, false})
    @ParameterizedTest
    void testRepartitionOperatorState(boolean snapshotCompression) throws Exception {
        ThrowingFunction<Collection<OperatorStateHandle>, OperatorStateBackend> backendFactory =
                createBackendFactory(snapshotCompression);

        Map<String, List<String>> listStates = new HashMap<>();
        listStates.put(
                "bufferState",
                IntStream.range(0, 10).mapToObj(i -> "foo" + i).collect(Collectors.toList()));
        listStates.put(
                "offsetState",
                IntStream.range(0, 10).mapToObj(i -> "bar" + i).collect(Collectors.toList()));

        OperatorStateHandle stateHandle =
                createOperatorStateHandle(backendFactory, listStates, Collections.emptyMap());

        RoundRobinOperatorStateRepartitioner repartitioner =
                new RoundRobinOperatorStateRepartitioner();
        for (int newParallelism : Arrays.asList(1, 2, 5, 10)) {
            List<List<OperatorStateHandle>> repartitioned =
                    repartitioner.repartitionState(
                            Collections.singletonList(Collections.singletonList(stateHandle)),
                            1,
                            newParallelism);
            for (int subtask = 0; subtask < newParallelism; subtask++) {
                verifyOperatorStateHandle(
                        backendFactory,
                        repartitioned.get(subtask),
                        getExpectedSplit(listStates, newParallelism, subtask),
                        Collections.emptyMap());
            }
        }
    }

    /**
     * Computes the contiguous slice of every list state expected at one subtask. The slice is
     * {@code [subtask * size / parallelism, (subtask + 1) * size / parallelism)}, using integer
     * division.
     *
     * <p>NOTE(review): the repartition test only uses parallelisms that divide the list size
     * (10), so every slice is equal-sized. Whether this formula matches the round-robin
     * distribution for non-divisors is not established here; confirm before reusing it.
     *
     * @param listStates full list states, keyed by state name
     * @param parallelism number of subtasks the state is split across
     * @param subtask index of the subtask whose slice is computed
     * @return a new map from state name to that subtask's sub-list view
     */
    private static Map<String, List<String>> getExpectedSplit(
            Map<String, List<String>> listStates, int parallelism, int subtask) {
        Map<String, List<String>> split = new HashMap<>();
        for (Map.Entry<String, List<String>> entry : listStates.entrySet()) {
            List<String> values = entry.getValue();
            int size = values.size();
            split.put(
                    entry.getKey(),
                    values.subList(subtask * size / parallelism, (subtask + 1) * size / parallelism));
        }
        return split;
    }
}
