/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.checkpoint;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ScheduledExecutorService;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.apache.flink.runtime.JobException;
import org.apache.flink.runtime.OperatorIDPair;
import org.apache.flink.runtime.checkpoint.FullyFinishedOperatorState;
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptorUtil;
import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore;
import org.apache.flink.runtime.checkpoint.OperatorState;
import org.apache.flink.runtime.checkpoint.OperatorStateRepartitioner;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.RescaleMappings;
import org.apache.flink.runtime.checkpoint.RoundRobinOperatorStateRepartitioner;
import org.apache.flink.runtime.checkpoint.StateAssignmentOperation;
import org.apache.flink.runtime.checkpoint.StateHandleDummyUtil;
import org.apache.flink.runtime.checkpoint.StateObjectCollection;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
import org.apache.flink.runtime.client.JobExecutionException;
import org.apache.flink.runtime.executiongraph.DefaultExecutionGraph;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionGraph;
import org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.executiongraph.TestingDefaultExecutionGraphBuilder;
import org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.JobEdge;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobGraphTestUtils;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.OperatorStreamStateHandle;
import org.apache.flink.runtime.state.StateObject;
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
import org.apache.flink.runtime.testtasks.NoOpInvokable;
import org.apache.flink.runtime.util.JobVertexConnectionUtils;
import org.apache.flink.testutils.TestingUtils;
import org.apache.flink.testutils.executor.TestExecutorExtension;
import org.apache.flink.util.Preconditions;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

class StateAssignmentOperationTest {
    /** Executor extension registered with JUnit, obtained from {@code TestingUtils.defaultExecutorExtension()}. */
    @RegisterExtension
    private static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_EXTENSION = TestingUtils.defaultExecutorExtension();
    // NOTE(review): MAX_P is not referenced in the visible part of this file; presumably a
    // maximum parallelism used by helper methods further down - confirm.
    private static final int MAX_P = 256;

    /** Package-private no-argument constructor; it performs no initialization. */
    StateAssignmentOperationTest() {
    }

    /**
     * Rescales an operator state whose two subtasks each carry a single {@code SPLIT_DISTRIBUTE}
     * state: "t-1" on subtask 0 and "t-2" on subtask 1, both with two offsets.
     */
    @Test
    void testRepartitionSplitDistributeStates() {
        OperatorID id = new OperatorID();
        OperatorState state = new OperatorState(null, null, id, 2, 4);

        // Subtask 0: "t-1" with offsets {0, 10}, backed by a 30-byte handle.
        Map<String, OperatorStateHandle.StateMetaInfo> firstMeta = new HashMap<>(1);
        firstMeta.put("t-1", new OperatorStateHandle.StateMetaInfo(new long[]{0L, 10L}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
        StreamStateHandle firstBytes = new ByteStreamStateHandle("test1", new byte[30]);
        OperatorStateHandle firstHandle = new OperatorStreamStateHandle(firstMeta, firstBytes);
        state.putState(0, OperatorSubtaskState.builder().setManagedOperatorState(firstHandle).build());

        // Subtask 1: "t-2" with offsets {0, 15}, backed by a 40-byte handle.
        Map<String, OperatorStateHandle.StateMetaInfo> secondMeta = new HashMap<>(1);
        secondMeta.put("t-2", new OperatorStateHandle.StateMetaInfo(new long[]{0L, 15L}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
        StreamStateHandle secondBytes = new ByteStreamStateHandle("test2", new byte[40]);
        OperatorStateHandle secondHandle = new OperatorStreamStateHandle(secondMeta, secondBytes);
        state.putState(1, OperatorSubtaskState.builder().setManagedOperatorState(secondHandle).build());

        this.verifyOneKindPartitionableStateRescale(state, id);
    }

    /**
     * Rescales an operator state that only holds {@code UNION} states: subtask 0 carries "t-3"
     * (one offset) and "t-4" (two offsets), subtask 1 carries "t-3" (one offset) only.
     */
    @Test
    void testRepartitionUnionState() {
        OperatorID id = new OperatorID();
        OperatorState state = new OperatorState(null, null, id, 2, 4);

        // Subtask 0: "t-3" and "t-4", backed by a 50-byte handle.
        Map<String, OperatorStateHandle.StateMetaInfo> firstMeta = new HashMap<>(2);
        firstMeta.put("t-3", new OperatorStateHandle.StateMetaInfo(new long[]{0L}, OperatorStateHandle.Mode.UNION));
        firstMeta.put("t-4", new OperatorStateHandle.StateMetaInfo(new long[]{22L, 44L}, OperatorStateHandle.Mode.UNION));
        StreamStateHandle firstBytes = new ByteStreamStateHandle("test1", new byte[50]);
        OperatorStateHandle firstHandle = new OperatorStreamStateHandle(firstMeta, firstBytes);
        state.putState(0, OperatorSubtaskState.builder().setManagedOperatorState(firstHandle).build());

        // Subtask 1: only "t-3", backed by a 20-byte handle.
        Map<String, OperatorStateHandle.StateMetaInfo> secondMeta = new HashMap<>(1);
        secondMeta.put("t-3", new OperatorStateHandle.StateMetaInfo(new long[]{0L}, OperatorStateHandle.Mode.UNION));
        StreamStateHandle secondBytes = new ByteStreamStateHandle("test2", new byte[20]);
        OperatorStateHandle secondHandle = new OperatorStreamStateHandle(secondMeta, secondBytes);
        state.putState(1, OperatorSubtaskState.builder().setManagedOperatorState(secondHandle).build());

        this.verifyOneKindPartitionableStateRescale(state, id);
    }

    /**
     * Checks {@code StateEntry.isPartiallyReported()} on an entry created with arguments
     * {@code (0, 5)}: it must be {@code true} after entries 0, 1 and 3 were added, and
     * {@code false} once entries 2 and 4 were added as well.
     */
    @Test
    public void testPartiallyReported() {
        // NOTE(review): the second constructor argument presumably is the number of subtasks
        // expected to report - confirm against StateEntry.
        RoundRobinOperatorStateRepartitioner.StateEntry stateEntry = new RoundRobinOperatorStateRepartitioner.StateEntry(0, 5);
        stateEntry.addEntry(0, null);
        stateEntry.addEntry(1, null);
        stateEntry.addEntry(3, null);
        // AssertJ is used here like in every other test of this class.
        Assertions.assertThat(stateEntry.isPartiallyReported()).isTrue();

        stateEntry.addEntry(2, null);
        stateEntry.addEntry(4, null);
        Assertions.assertThat(stateEntry.isPartiallyReported()).isFalse();
    }

    /**
     * Rescales an operator state in which both subtasks carry the same two {@code BROADCAST}
     * states, "t-5" and "t-6", each with three offsets.
     */
    @Test
    void testRepartitionBroadcastState() {
        OperatorID id = new OperatorID();
        OperatorState state = new OperatorState(null, null, id, 2, 4);

        // Subtask 0: "t-5" and "t-6", backed by a 60-byte handle.
        Map<String, OperatorStateHandle.StateMetaInfo> firstMeta = new HashMap<>(2);
        firstMeta.put("t-5", new OperatorStateHandle.StateMetaInfo(new long[]{0L, 10L, 20L}, OperatorStateHandle.Mode.BROADCAST));
        firstMeta.put("t-6", new OperatorStateHandle.StateMetaInfo(new long[]{30L, 40L, 50L}, OperatorStateHandle.Mode.BROADCAST));
        StreamStateHandle firstBytes = new ByteStreamStateHandle("test1", new byte[60]);
        OperatorStateHandle firstHandle = new OperatorStreamStateHandle(firstMeta, firstBytes);
        state.putState(0, OperatorSubtaskState.builder().setManagedOperatorState(firstHandle).build());

        // Subtask 1: the same two states with the same offsets, backed by its own 60-byte handle.
        Map<String, OperatorStateHandle.StateMetaInfo> secondMeta = new HashMap<>(2);
        secondMeta.put("t-5", new OperatorStateHandle.StateMetaInfo(new long[]{0L, 10L, 20L}, OperatorStateHandle.Mode.BROADCAST));
        secondMeta.put("t-6", new OperatorStateHandle.StateMetaInfo(new long[]{30L, 40L, 50L}, OperatorStateHandle.Mode.BROADCAST));
        StreamStateHandle secondBytes = new ByteStreamStateHandle("test2", new byte[60]);
        OperatorStateHandle secondHandle = new OperatorStreamStateHandle(secondMeta, secondBytes);
        state.putState(1, OperatorSubtaskState.builder().setManagedOperatorState(secondHandle).build());

        this.verifyOneKindPartitionableStateRescale(state, id);
    }

    /**
     * Rescales {@code BROADCAST} states "t-5" and "t-6" when only subtask 0 has a state
     * registered; nothing is put for subtask 1.
     */
    @Test
    void testRepartitionBroadcastStateWithNullSubtaskState() {
        OperatorID id = new OperatorID();
        OperatorState state = new OperatorState(null, null, id, 2, 4);

        // Subtask 0 is the only subtask with state: "t-5" and "t-6" over a 60-byte handle.
        Map<String, OperatorStateHandle.StateMetaInfo> meta = new HashMap<>(2);
        meta.put("t-5", new OperatorStateHandle.StateMetaInfo(new long[]{0L, 10L, 20L}, OperatorStateHandle.Mode.BROADCAST));
        meta.put("t-6", new OperatorStateHandle.StateMetaInfo(new long[]{30L, 40L, 50L}, OperatorStateHandle.Mode.BROADCAST));
        StreamStateHandle bytes = new ByteStreamStateHandle("test1", new byte[60]);
        OperatorStateHandle handle = new OperatorStreamStateHandle(meta, bytes);
        state.putState(0, OperatorSubtaskState.builder().setManagedOperatorState(handle).build());

        this.verifyOneKindPartitionableStateRescale(state, id);
    }

    /**
     * Rescales {@code BROADCAST} states "t-5" and "t-6" when subtask 0 carries them and
     * subtask 1 is registered with a subtask state built without any handles.
     */
    @Test
    void testRepartitionBroadcastStateWithEmptySubtaskState() {
        OperatorID id = new OperatorID();
        OperatorState state = new OperatorState(null, null, id, 2, 4);

        // Subtask 0: "t-5" and "t-6" over a 60-byte handle.
        Map<String, OperatorStateHandle.StateMetaInfo> meta = new HashMap<>(2);
        meta.put("t-5", new OperatorStateHandle.StateMetaInfo(new long[]{0L, 10L, 20L}, OperatorStateHandle.Mode.BROADCAST));
        meta.put("t-6", new OperatorStateHandle.StateMetaInfo(new long[]{30L, 40L, 50L}, OperatorStateHandle.Mode.BROADCAST));
        StreamStateHandle bytes = new ByteStreamStateHandle("test1", new byte[60]);
        OperatorStateHandle handle = new OperatorStreamStateHandle(meta, bytes);
        state.putState(0, OperatorSubtaskState.builder().setManagedOperatorState(handle).build());

        // Subtask 1: present, but built straight from an untouched builder.
        OperatorSubtaskState emptySubtaskState = OperatorSubtaskState.builder().build();
        state.putState(1, emptySubtaskState);

        this.verifyOneKindPartitionableStateRescale(state, id);
    }

    /**
     * Rescales an operator state that mixes {@code UNION}, {@code SPLIT_DISTRIBUTE} and
     * {@code BROADCAST} states ("t-1" to "t-6") and runs the combined verification for new
     * parallelism 3, 1 and 2, in that order.
     */
    @Test
    void testReDistributeCombinedPartitionableStates() {
        OperatorID id = new OperatorID();
        OperatorState state = new OperatorState(null, null, id, 2, 4);

        // Subtask 0 carries all six states over a 130-byte handle.
        Map<String, OperatorStateHandle.StateMetaInfo> firstMeta = new HashMap<>(6);
        firstMeta.put("t-1", new OperatorStateHandle.StateMetaInfo(new long[]{0L}, OperatorStateHandle.Mode.UNION));
        firstMeta.put("t-2", new OperatorStateHandle.StateMetaInfo(new long[]{22L, 44L}, OperatorStateHandle.Mode.UNION));
        firstMeta.put("t-3", new OperatorStateHandle.StateMetaInfo(new long[]{52L, 63L}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
        firstMeta.put("t-4", new OperatorStateHandle.StateMetaInfo(new long[]{67L, 74L, 75L}, OperatorStateHandle.Mode.BROADCAST));
        firstMeta.put("t-5", new OperatorStateHandle.StateMetaInfo(new long[]{77L, 88L, 92L}, OperatorStateHandle.Mode.BROADCAST));
        firstMeta.put("t-6", new OperatorStateHandle.StateMetaInfo(new long[]{101L, 123L, 127L}, OperatorStateHandle.Mode.BROADCAST));
        StreamStateHandle firstBytes = new ByteStreamStateHandle("test1", new byte[130]);
        OperatorStateHandle firstHandle = new OperatorStreamStateHandle(firstMeta, firstBytes);
        state.putState(0, OperatorSubtaskState.builder().setManagedOperatorState(firstHandle).build());

        // Subtask 1 carries "t-1" and the three broadcast states over an 86-byte handle.
        Map<String, OperatorStateHandle.StateMetaInfo> secondMeta = new HashMap<>(3);
        secondMeta.put("t-1", new OperatorStateHandle.StateMetaInfo(new long[]{0L}, OperatorStateHandle.Mode.UNION));
        secondMeta.put("t-4", new OperatorStateHandle.StateMetaInfo(new long[]{20L, 27L, 28L}, OperatorStateHandle.Mode.BROADCAST));
        secondMeta.put("t-5", new OperatorStateHandle.StateMetaInfo(new long[]{30L, 44L, 48L}, OperatorStateHandle.Mode.BROADCAST));
        secondMeta.put("t-6", new OperatorStateHandle.StateMetaInfo(new long[]{57L, 79L, 83L}, OperatorStateHandle.Mode.BROADCAST));
        StreamStateHandle secondBytes = new ByteStreamStateHandle("test2", new byte[86]);
        OperatorStateHandle secondHandle = new OperatorStreamStateHandle(secondMeta, secondBytes);
        state.putState(1, OperatorSubtaskState.builder().setManagedOperatorState(secondHandle).build());

        final int oldParallelism = 2;
        for (int newParallelism : new int[]{3, 1, 2}) {
            this.verifyCombinedPartitionableStateRescale(state, id, oldParallelism, newParallelism);
        }
    }

    /**
     * Redistributes the managed operator state of {@code operatorState} for
     * {@code newParallelism} through {@code RoundRobinOperatorStateRepartitioner.INSTANCE} and
     * checks every resulting list of handles.
     *
     * <p>For each list in the result, offsets are summed per state name and distribution mode
     * and compared with fixed expectations: a {@code SPLIT_DISTRIBUTE} state must total one
     * offset when {@code oldParallelism < newParallelism} and two otherwise, a {@code UNION}
     * state two, and a state of any other mode three.
     *
     * @param operatorState state to redistribute
     * @param operatorID id under which {@code operatorState} is passed to the redistribution
     * @param oldParallelism parallelism before rescaling; only used to pick the expectation
     * @param newParallelism parallelism to redistribute for
     * @param stateInfoCounts output map; the count of a state name is incremented once for every
     *     redistributed handle that contains that name
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private void verifyAndCollectStateInfo(OperatorState operatorState, OperatorID operatorID, int oldParallelism, int newParallelism, Map<String, Integer> stateInfoCounts) {
        // The key type of the result map is not needed by this check, so the map is kept raw for
        // the call and only its values are viewed with their element type afterwards.
        Map newManagedOperatorStates = new HashMap();
        StateAssignmentOperation.reDistributePartitionableStates(Collections.singletonMap(operatorID, operatorState), newParallelism, OperatorSubtaskState::getManagedOperatorState, (OperatorStateRepartitioner)RoundRobinOperatorStateRepartitioner.INSTANCE, newManagedOperatorStates);
        Collection<List<OperatorStateHandle>> redistributedHandles = newManagedOperatorStates.values();

        for (List<OperatorStateHandle> operatorStateHandles : redistributedHandles) {
            // Total number of offsets per state name, grouped by distribution mode.
            Map<OperatorStateHandle.Mode, Map<String, Integer>> offsetsByMode = new EnumMap<>(OperatorStateHandle.Mode.class);
            for (OperatorStateHandle.Mode mode : OperatorStateHandle.Mode.values()) {
                offsetsByMode.put(mode, new HashMap<>());
            }

            for (OperatorStateHandle handle : operatorStateHandles) {
                for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> named : handle.getStateNameToPartitionOffsets().entrySet()) {
                    String stateName = named.getKey();
                    OperatorStateHandle.StateMetaInfo metaInfo = named.getValue();
                    stateInfoCounts.merge(stateName, 1, Integer::sum);
                    offsetsByMode.get(metaInfo.getDistributionMode()).merge(stateName, metaInfo.getOffsets().length, Integer::sum);
                }
            }

            for (Map.Entry<OperatorStateHandle.Mode, Map<String, Integer>> entry : offsetsByMode.entrySet()) {
                OperatorStateHandle.Mode mode = entry.getKey();
                int expectedOffsets;
                if (mode == OperatorStateHandle.Mode.SPLIT_DISTRIBUTE) {
                    expectedOffsets = oldParallelism < newParallelism ? 1 : 2;
                } else if (mode == OperatorStateHandle.Mode.UNION) {
                    expectedOffsets = 2;
                } else {
                    expectedOffsets = 3;
                }
                for (Integer length : entry.getValue().values()) {
                    Assertions.assertThat((int)length).isEqualTo(expectedOffsets);
                }
            }
        }
    }

    /**
     * Runs the single-kind rescale verification starting from parallelism 2 for new parallelism
     * 3, 1 and 2, in that order.
     */
    private void verifyOneKindPartitionableStateRescale(OperatorState operatorState, OperatorID operatorID) {
        final int oldParallelism = 2;
        for (int newParallelism : new int[]{3, 1, 2}) {
            this.verifyOneKindPartitionableStateRescale(operatorState, operatorID, oldParallelism, newParallelism);
        }
    }

    /**
     * Collects per-state-name counts via {@code verifyAndCollectStateInfo} and checks them for a
     * state that contains exactly two names of a single kind.
     *
     * <p>Which expectation applies is decided by the names found: "t-1"/"t-2" must each be
     * counted twice when {@code oldParallelism < newParallelism} and once otherwise; "t-3" must
     * be counted {@code 2 * newParallelism} times and "t-4" {@code newParallelism} times;
     * "t-5" and "t-6" must each be counted {@code newParallelism} times.
     */
    private void verifyOneKindPartitionableStateRescale(OperatorState operatorState, OperatorID operatorID, int oldParallelism, int newParallelism) {
        Map<String, Integer> counts = new HashMap<>();
        this.verifyAndCollectStateInfo(operatorState, operatorID, oldParallelism, newParallelism, counts);
        Assertions.assertThat(counts).hasSize(2);

        if (counts.containsKey("t-1")) {
            int expectedSplitCount = oldParallelism < newParallelism ? 2 : 1;
            Assertions.assertThat((int)counts.get("t-1")).isEqualTo(expectedSplitCount);
            Assertions.assertThat((int)counts.get("t-2")).isEqualTo(expectedSplitCount);
        }

        if (counts.containsKey("t-3")) {
            Assertions.assertThat((int)counts.get("t-3")).isEqualTo(2 * newParallelism);
            Assertions.assertThat((int)counts.get("t-4")).isEqualTo(newParallelism);
        }

        if (counts.containsKey("t-5")) {
            for (String broadcastName : new String[]{"t-5", "t-6"}) {
                Assertions.assertThat((int)counts.get(broadcastName)).isEqualTo(newParallelism);
            }
        }
    }

    /**
     * Collects per-state-name counts via {@code verifyAndCollectStateInfo} and checks them for
     * the six-name state built by the combined test.
     *
     * <p>"t-1" must be counted {@code 2 * newParallelism} times; "t-2", "t-4", "t-5" and "t-6"
     * {@code newParallelism} times; "t-3" twice when {@code oldParallelism < newParallelism}
     * and once otherwise.
     */
    private void verifyCombinedPartitionableStateRescale(OperatorState operatorState, OperatorID operatorID, int oldParallelism, int newParallelism) {
        Map<String, Integer> counts = new HashMap<>();
        this.verifyAndCollectStateInfo(operatorState, operatorID, oldParallelism, newParallelism, counts);

        int distinctNames = counts.size();
        Assertions.assertThat(distinctNames).isEqualTo(6);

        Assertions.assertThat((int)counts.get("t-1")).isEqualTo(2 * newParallelism);
        Assertions.assertThat((int)counts.get("t-2")).isEqualTo(newParallelism);

        int expectedSplitCount = oldParallelism < newParallelism ? 2 : 1;
        Assertions.assertThat((int)counts.get("t-3")).isEqualTo(expectedSplitCount);

        for (String broadcastName : new String[]{"t-4", "t-5", "t-6"}) {
            Assertions.assertThat((int)counts.get(broadcastName)).isEqualTo(newParallelism);
        }
    }

    /**
     * Assigns states for 10 operators with 100 subtasks each, where the stored states were built
     * for the same number of subtasks, and checks that every subtask is assigned a state equal
     * to the stored state at the same subtask index.
     */
    @Test
    void testChannelStateAssignmentStability() throws JobException, JobExecutionException {
        final int operatorCount = 10;
        final int subtaskCount = 100;
        List<OperatorID> ids = this.buildOperatorIds(operatorCount);
        Map<OperatorID, ExecutionJobVertex> jobVertices = this.buildVertices(ids, subtaskCount, SubtaskStateMapper.RANGE, SubtaskStateMapper.ROUND_ROBIN);
        Map<OperatorID, OperatorState> storedStates = this.buildOperatorStates(ids, subtaskCount);

        StateAssignmentOperation assignment = new StateAssignmentOperation(0L, new HashSet<ExecutionJobVertex>(jobVertices.values()), storedStates, false);
        assignment.assignStates();

        for (OperatorID id : ids) {
            for (int subtask = 0; subtask < subtaskCount; ++subtask) {
                OperatorSubtaskState expected = storedStates.get(id).getState(subtask);
                Assertions.assertThat((Object)this.getAssignedState(jobVertices.get(id), id, subtask)).isEqualTo((Object)expected);
            }
        }
    }

    /**
     * Connects two upstream vertices to one downstream vertex (all created with parallelism 2),
     * assigns states built by {@code buildOperatorStatesForTwoGates(operatorIds, 3)}, and checks
     * the input rescaling descriptor of downstream subtask 0: one gate descriptor per upstream
     * connection, both of mapping type {@code RESCALING}.
     */
    @Test
    void testChannelStateAssignmentDownscalingTwoDifferentGates() throws JobException, JobExecutionException {
        JobVertex upstream1 = this.createJobVertex(new OperatorID(), 2);
        JobVertex upstream2 = this.createJobVertex(new OperatorID(), 2);
        JobVertex downstream = this.createJobVertex(new OperatorID(), 2);
        List<OperatorID> operatorIds = Stream.of(upstream1, upstream2, downstream).map(v -> ((OperatorIDPair)v.getOperatorIDs().get(0)).getGeneratedOperatorID()).collect(Collectors.toList());
        Map<OperatorID, OperatorState> states = this.buildOperatorStatesForTwoGates(operatorIds, 3);
        this.connectVertices(upstream1, downstream, SubtaskStateMapper.ARBITRARY, SubtaskStateMapper.RANGE);
        this.connectVertices(upstream2, downstream, SubtaskStateMapper.ROUND_ROBIN, SubtaskStateMapper.ROUND_ROBIN);
        Map<OperatorID, ExecutionJobVertex> vertices = this.toExecutionVertices(upstream1, upstream2, downstream);
        new StateAssignmentOperation(0L, new HashSet<ExecutionJobVertex>(vertices.values()), states, false).assignStates();

        OperatorID downstreamId = operatorIds.get(2);
        InflightDataRescalingDescriptor expectedForSubtask0 = new InflightDataRescalingDescriptor(InflightDataRescalingDescriptorUtil.array(this.gate(InflightDataRescalingDescriptorUtil.to(0, 1), InflightDataRescalingDescriptorUtil.mappings(InflightDataRescalingDescriptorUtil.to(0, 2), InflightDataRescalingDescriptorUtil.to(1)), InflightDataRescalingDescriptorUtil.set(1), InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor.MappingType.RESCALING), this.gate(InflightDataRescalingDescriptorUtil.to(0, 2), InflightDataRescalingDescriptorUtil.mappings(InflightDataRescalingDescriptorUtil.to(0, 2), InflightDataRescalingDescriptorUtil.to(1)), Collections.emptySet(), InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor.MappingType.RESCALING)));

        // The original asserted this exact expectation for subtask 0 twice; the byte-identical
        // duplicate added no coverage and was removed.
        // NOTE(review): the second assertion was probably meant for downstream subtask 1. Its
        // expected descriptor is not established anywhere in this file, so it is not asserted.
        Assertions.assertThat((Object)this.getAssignedState(vertices.get(downstreamId), downstreamId, 0).getInputRescalingDescriptor()).isEqualTo((Object)expectedForSubtask0);
    }

    /**
     * Shorthand for constructing an {@code InflightDataGateOrPartitionRescalingDescriptor}; all
     * four arguments are passed to its constructor unchanged.
     */
    private InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor gate(int[] oldIndices, RescaleMappings rescaleMapping, Set<Integer> ambiguousSubtaskIndexes, InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor.MappingType mappingType) {
        InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor descriptor =
                new InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor(oldIndices, rescaleMapping, ambiguousSubtaskIndexes, mappingType);
        return descriptor;
    }

    /**
     * Connects two upstream vertices (created with parallelism 2) to a downstream vertex
     * (created with parallelism 3), assigns states built by
     * {@code buildOperatorStates(operatorIds, 3)}, and checks that each of the three downstream
     * subtasks ends up with six input channel state entries.
     */
    @Test
    public void testChannelStateAssignmentTwoGatesPartiallyDownscaling() throws JobException, JobExecutionException {
        JobVertex firstUpstream = this.createJobVertex(new OperatorID(), 2);
        JobVertex secondUpstream = this.createJobVertex(new OperatorID(), 2);
        JobVertex downstream = this.createJobVertex(new OperatorID(), 3);
        List<OperatorID> ids = Stream.of(firstUpstream, secondUpstream, downstream)
                .map(vertex -> (OperatorIDPair)vertex.getOperatorIDs().get(0))
                .map(pair -> pair.getGeneratedOperatorID())
                .collect(Collectors.toList());
        Map<OperatorID, OperatorState> storedStates = this.buildOperatorStates(ids, 3);

        this.connectVertices(firstUpstream, downstream, SubtaskStateMapper.ARBITRARY, SubtaskStateMapper.FULL);
        this.connectVertices(secondUpstream, downstream, SubtaskStateMapper.ROUND_ROBIN, SubtaskStateMapper.ROUND_ROBIN);
        Map<OperatorID, ExecutionJobVertex> jobVertices = this.toExecutionVertices(firstUpstream, secondUpstream, downstream);

        StateAssignmentOperation assignment = new StateAssignmentOperation(0L, new HashSet<ExecutionJobVertex>(jobVertices.values()), storedStates, false);
        assignment.assignStates();

        for (int subtask = 0; subtask < 3; ++subtask) {
            int channelStateCount = this.getAssignedState(jobVertices.get(ids.get(2)), ids.get(2), subtask).getInputChannelState().size();
            Assertions.assertThat(channelStateCount).isEqualTo(6);
        }
    }

    /**
     * Assigns states built for 3 subtasks to two operators whose vertices are built with 2
     * subtasks, then checks which stored subtask states each new subtask received and the
     * resulting output (first operator) and input (second operator) rescaling descriptors.
     */
    @Test
    void testChannelStateAssignmentDownscaling() throws JobException, JobExecutionException {
        List<OperatorID> ids = this.buildOperatorIds(2);
        Map<OperatorID, OperatorState> storedStates = this.buildOperatorStates(ids, 3);
        Map<OperatorID, ExecutionJobVertex> jobVertices = this.buildVertices(ids, 2, SubtaskStateMapper.RANGE, SubtaskStateMapper.ROUND_ROBIN);

        StateAssignmentOperation assignment = new StateAssignmentOperation(0L, new HashSet<ExecutionJobVertex>(jobVertices.values()), storedStates, false);
        assignment.assignStates();

        for (OperatorID id : ids) {
            // Trailing ints are the stored subtask indexes expected for the given new subtask.
            this.assertState(jobVertices, id, storedStates, 0, OperatorSubtaskState::getInputChannelState, 0, 1);
            this.assertState(jobVertices, id, storedStates, 1, OperatorSubtaskState::getInputChannelState, 1, 2);
            this.assertState(jobVertices, id, storedStates, 0, OperatorSubtaskState::getResultSubpartitionState, 0, 2);
            this.assertState(jobVertices, id, storedStates, 1, OperatorSubtaskState::getResultSubpartitionState, 1);
        }

        OperatorID firstId = ids.get(0);
        OperatorID secondId = ids.get(1);

        // Output side of the first operator.
        RescaleMappings outputMappings = InflightDataRescalingDescriptorUtil.mappings(InflightDataRescalingDescriptorUtil.to(0, 1), InflightDataRescalingDescriptorUtil.to(1, 2));
        Assertions.assertThat((Object)this.getAssignedState(jobVertices.get(firstId), firstId, 0).getOutputRescalingDescriptor())
                .isEqualTo((Object)InflightDataRescalingDescriptorUtil.rescalingDescriptor(InflightDataRescalingDescriptorUtil.to(0, 2), InflightDataRescalingDescriptorUtil.array(outputMappings), InflightDataRescalingDescriptorUtil.set(new Integer[0])));
        Assertions.assertThat((Object)this.getAssignedState(jobVertices.get(firstId), firstId, 1).getOutputRescalingDescriptor())
                .isEqualTo((Object)InflightDataRescalingDescriptorUtil.rescalingDescriptor(InflightDataRescalingDescriptorUtil.to(1), InflightDataRescalingDescriptorUtil.array(outputMappings), InflightDataRescalingDescriptorUtil.set(new Integer[0])));

        // Input side of the second operator.
        RescaleMappings inputMappings = InflightDataRescalingDescriptorUtil.mappings(InflightDataRescalingDescriptorUtil.to(0, 2), InflightDataRescalingDescriptorUtil.to(1));
        Assertions.assertThat((Object)this.getAssignedState(jobVertices.get(secondId), secondId, 0).getInputRescalingDescriptor())
                .isEqualTo((Object)InflightDataRescalingDescriptorUtil.rescalingDescriptor(InflightDataRescalingDescriptorUtil.to(0, 1), InflightDataRescalingDescriptorUtil.array(inputMappings), InflightDataRescalingDescriptorUtil.set(1)));
        Assertions.assertThat((Object)this.getAssignedState(jobVertices.get(secondId), secondId, 1).getInputRescalingDescriptor())
                .isEqualTo((Object)InflightDataRescalingDescriptorUtil.rescalingDescriptor(InflightDataRescalingDescriptorUtil.to(1, 2), InflightDataRescalingDescriptorUtil.array(inputMappings), InflightDataRescalingDescriptorUtil.set(1)));
    }

    /**
     * Assigns states built for 2 subtasks to two operators whose vertices are also built with 2
     * subtasks. Every new subtask must receive exactly the stored state of the same index, and
     * the checked rescaling descriptors must all be {@code NO_RESCALE}.
     */
    @Test
    void testChannelStateAssignmentNoRescale() throws JobException, JobExecutionException {
        List<OperatorID> ids = this.buildOperatorIds(2);
        Map<OperatorID, OperatorState> storedStates = this.buildOperatorStates(ids, 2);
        Map<OperatorID, ExecutionJobVertex> jobVertices = this.buildVertices(ids, 2, SubtaskStateMapper.RANGE, SubtaskStateMapper.ROUND_ROBIN);

        StateAssignmentOperation assignment = new StateAssignmentOperation(0L, new HashSet<ExecutionJobVertex>(jobVertices.values()), storedStates, false);
        assignment.assignStates();

        for (OperatorID id : ids) {
            for (int subtask = 0; subtask < 2; ++subtask) {
                this.assertState(jobVertices, id, storedStates, subtask, OperatorSubtaskState::getInputChannelState, subtask);
            }
            for (int subtask = 0; subtask < 2; ++subtask) {
                this.assertState(jobVertices, id, storedStates, subtask, OperatorSubtaskState::getResultSubpartitionState, subtask);
            }
        }

        // Output descriptors of the first operator, then input descriptors of the second.
        OperatorID firstId = ids.get(0);
        for (int subtask = 0; subtask < 2; ++subtask) {
            Assertions.assertThat((Object)this.getAssignedState(jobVertices.get(firstId), firstId, subtask).getOutputRescalingDescriptor()).isEqualTo((Object)InflightDataRescalingDescriptor.NO_RESCALE);
        }
        OperatorID secondId = ids.get(1);
        for (int subtask = 0; subtask < 2; ++subtask) {
            Assertions.assertThat((Object)this.getAssignedState(jobVertices.get(secondId), secondId, subtask).getInputRescalingDescriptor()).isEqualTo((Object)InflightDataRescalingDescriptor.NO_RESCALE);
        }
    }

    /**
     * Assigns states built for 2 subtasks to two operators whose vertices are built with 3
     * subtasks, then checks which stored subtask states each new subtask received and the
     * resulting output (first operator) and input (second operator) rescaling descriptors.
     */
    @Test
    void testChannelStateAssignmentUpscaling() throws JobException, JobExecutionException {
        List<OperatorID> ids = this.buildOperatorIds(2);
        Map<OperatorID, OperatorState> storedStates = this.buildOperatorStates(ids, 2);
        Map<OperatorID, ExecutionJobVertex> jobVertices = this.buildVertices(ids, 3, SubtaskStateMapper.RANGE, SubtaskStateMapper.ROUND_ROBIN);

        StateAssignmentOperation assignment = new StateAssignmentOperation(0L, new HashSet<ExecutionJobVertex>(jobVertices.values()), storedStates, false);
        assignment.assignStates();

        for (OperatorID id : ids) {
            // Trailing ints are the stored subtask indexes expected for the given new subtask.
            this.assertState(jobVertices, id, storedStates, 0, OperatorSubtaskState::getInputChannelState, 0);
            this.assertState(jobVertices, id, storedStates, 1, OperatorSubtaskState::getInputChannelState, 0, 1);
            this.assertState(jobVertices, id, storedStates, 2, OperatorSubtaskState::getInputChannelState, 1);
            this.assertState(jobVertices, id, storedStates, 0, OperatorSubtaskState::getResultSubpartitionState, 0);
            this.assertState(jobVertices, id, storedStates, 1, OperatorSubtaskState::getResultSubpartitionState, 1);
            this.assertState(jobVertices, id, storedStates, 2, OperatorSubtaskState::getResultSubpartitionState, new int[0]);
        }

        OperatorID firstId = ids.get(0);
        OperatorID secondId = ids.get(1);

        // Output side of the first operator; the third new subtask expects NO_RESCALE.
        RescaleMappings outputMappings = InflightDataRescalingDescriptorUtil.mappings(InflightDataRescalingDescriptorUtil.to(0), InflightDataRescalingDescriptorUtil.to(0, 1), InflightDataRescalingDescriptorUtil.to(1));
        Assertions.assertThat((Object)this.getAssignedState(jobVertices.get(firstId), firstId, 0).getOutputRescalingDescriptor())
                .isEqualTo((Object)InflightDataRescalingDescriptorUtil.rescalingDescriptor(InflightDataRescalingDescriptorUtil.to(0), InflightDataRescalingDescriptorUtil.array(outputMappings), InflightDataRescalingDescriptorUtil.set(new Integer[0])));
        Assertions.assertThat((Object)this.getAssignedState(jobVertices.get(firstId), firstId, 1).getOutputRescalingDescriptor())
                .isEqualTo((Object)InflightDataRescalingDescriptorUtil.rescalingDescriptor(InflightDataRescalingDescriptorUtil.to(1), InflightDataRescalingDescriptorUtil.array(outputMappings), InflightDataRescalingDescriptorUtil.set(new Integer[0])));
        Assertions.assertThat((Object)this.getAssignedState(jobVertices.get(firstId), firstId, 2).getOutputRescalingDescriptor())
                .isEqualTo((Object)InflightDataRescalingDescriptor.NO_RESCALE);

        // Input side of the second operator.
        RescaleMappings inputMappings = InflightDataRescalingDescriptorUtil.mappings(InflightDataRescalingDescriptorUtil.to(0), InflightDataRescalingDescriptorUtil.to(1), InflightDataRescalingDescriptorUtil.to(new int[0]));
        Assertions.assertThat((Object)this.getAssignedState(jobVertices.get(secondId), secondId, 0).getInputRescalingDescriptor())
                .isEqualTo((Object)InflightDataRescalingDescriptorUtil.rescalingDescriptor(InflightDataRescalingDescriptorUtil.to(0), InflightDataRescalingDescriptorUtil.array(inputMappings), InflightDataRescalingDescriptorUtil.set(0, 1)));
        Assertions.assertThat((Object)this.getAssignedState(jobVertices.get(secondId), secondId, 1).getInputRescalingDescriptor())
                .isEqualTo((Object)InflightDataRescalingDescriptorUtil.rescalingDescriptor(InflightDataRescalingDescriptorUtil.to(0, 1), InflightDataRescalingDescriptorUtil.array(inputMappings), InflightDataRescalingDescriptorUtil.set(0, 1)));
        Assertions.assertThat((Object)this.getAssignedState(jobVertices.get(secondId), secondId, 2).getInputRescalingDescriptor())
                .isEqualTo((Object)InflightDataRescalingDescriptorUtil.rescalingDescriptor(InflightDataRescalingDescriptorUtil.to(1), InflightDataRescalingDescriptorUtil.array(inputMappings), InflightDataRescalingDescriptorUtil.set(0, 1)));
    }

    /**
     * Assigns state when only the first operator has an entry in the state map: its subtask 0
     * holds two result subpartition handles, while the second operator has no recorded state at
     * all. Every task of both vertices must still receive a non-null task restore.
     */
    @Test
    void testOnlyUpstreamChannelStateAssignment() throws JobException, JobExecutionException {
        List<OperatorID> operatorIds = this.buildOperatorIds(2);
        OperatorID upstreamId = operatorIds.get(0);
        Map<OperatorID, OperatorState> states = new HashMap<OperatorID, OperatorState>();
        Random random = new Random();

        // The state is declared with parallelism 2, but only subtask 0 is populated.
        OperatorState upstreamState = new OperatorState(null, null, upstreamId, 2, 256);
        OperatorSubtaskState subtaskZeroState = OperatorSubtaskState.builder().setResultSubpartitionState(new StateObjectCollection(Arrays.asList(StateHandleDummyUtil.createNewResultSubpartitionStateHandle(10, random), StateHandleDummyUtil.createNewResultSubpartitionStateHandle(10, random)))).build();
        upstreamState.putState(0, subtaskZeroState);
        states.put(upstreamId, upstreamState);

        // Restore into vertices with parallelism 3.
        Map<OperatorID, ExecutionJobVertex> vertices = this.buildVertices(operatorIds, 3, SubtaskStateMapper.RANGE, SubtaskStateMapper.ROUND_ROBIN);
        new StateAssignmentOperation(0L, new HashSet<ExecutionJobVertex>(vertices.values()), states, false).assignStates();

        // Upstream first, then downstream: the same order the ids were generated in.
        for (OperatorID operatorId : operatorIds) {
            ExecutionJobVertex jobVertex = vertices.get(operatorId);
            for (ExecutionVertex task : jobVertex.getTaskVertices()) {
                Assertions.assertThat((Object)task.getCurrentExecutionAttempt().getTaskRestore()).isNotNull();
            }
        }
    }

    /**
     * Runs the shared rescaling scenario with state present only on the upstream operator
     * (two result subpartition handles) and no downstream state.
     */
    @Test
    void testOnlyUpstreamChannelRescaleStateAssignment() throws JobException, JobExecutionException {
        // Totals that checkMappings must find on the upstream output / downstream input descriptors.
        int expectedUpstreamCount = 5;
        int expectedDownstreamCount = 7;

        Random random = new Random();
        OperatorSubtaskState upstreamOpState = OperatorSubtaskState.builder().setResultSubpartitionState(new StateObjectCollection(Arrays.asList(StateHandleDummyUtil.createNewResultSubpartitionStateHandle(10, random), StateHandleDummyUtil.createNewResultSubpartitionStateHandle(10, random)))).build();

        this.testOnlyUpstreamOrDownstreamRescalingInternal(upstreamOpState, null, expectedUpstreamCount, expectedDownstreamCount);
    }

    /**
     * Runs the shared rescaling scenario with state present only on the downstream operator
     * (two input channel handles) and no upstream state.
     */
    @Test
    void testOnlyDownstreamChannelRescaleStateAssignment() throws JobException, JobExecutionException {
        // Totals that checkMappings must find on the upstream output / downstream input descriptors.
        int expectedUpstreamCount = 5;
        int expectedDownstreamCount = 5;

        Random random = new Random();
        OperatorSubtaskState downstreamOpState = OperatorSubtaskState.builder().setInputChannelState(new StateObjectCollection(Arrays.asList(StateHandleDummyUtil.createNewInputChannelStateHandle(10, random), StateHandleDummyUtil.createNewInputChannelStateHandle(10, random)))).build();

        this.testOnlyUpstreamOrDownstreamRescalingInternal(null, downstreamOpState, expectedUpstreamCount, expectedDownstreamCount);
    }

    /**
     * Shared scenario for the "only one side has channel state" rescaling tests.
     *
     * <p>Both operators are recorded with parallelism 5. The supplied subtask state is stored
     * at subtask 0 of its operator, and the <em>other</em> operator is restored with parallelism
     * 3 while the stateful one keeps parallelism 5.
     *
     * @param upstreamOpState state for upstream subtask 0, or {@code null}
     * @param downstreamOpState state for downstream subtask 0, or {@code null}
     * @param expectedUpstreamCount expected total for the upstream output rescaling descriptors
     * @param expectedDownstreamCount expected total for the downstream input rescaling descriptors
     */
    private void testOnlyUpstreamOrDownstreamRescalingInternal(@Nullable OperatorSubtaskState upstreamOpState, @Nullable OperatorSubtaskState downstreamOpState, int expectedUpstreamCount, int expectedDownstreamCount) throws JobException, JobExecutionException {
        // Exactly one of the two states may be non-null.
        boolean exactlyOneSideHasState = (upstreamOpState == null) != (downstreamOpState == null);
        Preconditions.checkArgument(exactlyOneSideHasState, (Object)"Either upstream or downstream state must exist, but not both");

        final int recordedParallelism = 5;
        List<OperatorID> operatorIds = this.buildOperatorIds(2);
        OperatorID upstreamId = operatorIds.get(0);
        OperatorID downstreamId = operatorIds.get(1);

        OperatorState upstreamState = new OperatorState(null, null, upstreamId, recordedParallelism, 256);
        OperatorState downstreamState = new OperatorState(null, null, downstreamId, recordedParallelism, 256);
        Map<OperatorID, OperatorState> states = new HashMap<OperatorID, OperatorState>();
        states.put(upstreamId, upstreamState);
        states.put(downstreamId, downstreamState);

        if (upstreamOpState != null) {
            upstreamState.putState(0, upstreamOpState);
        }
        if (downstreamOpState != null) {
            downstreamState.putState(0, downstreamOpState);
        }

        // The side opposite to the one carrying state is the one that gets rescaled (5 -> 3).
        int upstreamParallelism = downstreamOpState != null ? 3 : recordedParallelism;
        int downstreamParallelism = upstreamOpState != null ? 3 : recordedParallelism;

        List<OperatorIdWithParallelism> opIdWithParallelism = new ArrayList<OperatorIdWithParallelism>(2);
        opIdWithParallelism.add(new OperatorIdWithParallelism(upstreamId, upstreamParallelism));
        opIdWithParallelism.add(new OperatorIdWithParallelism(downstreamId, downstreamParallelism));

        Map<OperatorID, ExecutionJobVertex> vertices = this.buildVertices(opIdWithParallelism, SubtaskStateMapper.RANGE, SubtaskStateMapper.ROUND_ROBIN);
        new StateAssignmentOperation(0L, new HashSet<ExecutionJobVertex>(vertices.values()), states, false).assignStates();

        // Collect both snapshot lists before asserting on either of them.
        List<TaskStateSnapshot> upstreamTaskStateSnapshots = this.getTaskStateSnapshotFromVertex(vertices.get(upstreamId));
        List<TaskStateSnapshot> downstreamTaskStateSnapshots = this.getTaskStateSnapshotFromVertex(vertices.get(downstreamId));

        this.checkMappings(upstreamTaskStateSnapshots, TaskStateSnapshot::getOutputRescalingDescriptor, expectedUpstreamCount);
        this.checkMappings(downstreamTaskStateSnapshots, TaskStateSnapshot::getInputRescalingDescriptor, expectedDownstreamCount);
    }

    /**
     * Sums, over all snapshots, the number of old subtask indexes that the extracted rescaling
     * descriptor reports for index 0, and checks the total.
     *
     * <p>Each individual snapshot must report at least one old subtask index.
     *
     * @param taskStateSnapshots snapshots to inspect, one per task
     * @param extractFun selects the rescaling descriptor to inspect from a snapshot
     * @param expectedCount expected sum across all snapshots
     */
    private void checkMappings(List<TaskStateSnapshot> taskStateSnapshots, Function<TaskStateSnapshot, InflightDataRescalingDescriptor> extractFun, int expectedCount) {
        int totalOldSubtasks = 0;
        for (TaskStateSnapshot snapshot : taskStateSnapshots) {
            InflightDataRescalingDescriptor descriptor = extractFun.apply(snapshot);
            int oldSubtaskCount = descriptor.getOldSubtaskIndexes(0).length;
            Assertions.assertThat(oldSubtaskCount).isGreaterThan(0);
            totalOldSubtasks += oldSubtaskCount;
        }
        Assertions.assertThat(totalOldSubtasks).isEqualTo(expectedCount);
    }

    /**
     * Restores a two-operator job in which the first operator is recorded as a
     * {@link FullyFinishedOperatorState} and the second has regular state. Tasks of the first
     * vertex must be marked as deployed-as-finished; tasks of the second must not.
     */
    @Test
    void testStateWithFullyFinishedOperators() throws JobException, JobExecutionException {
        List<OperatorID> operatorIds = this.buildOperatorIds(2);
        OperatorID finishedId = operatorIds.get(0);
        OperatorID runningId = operatorIds.get(1);

        // Regular state for the second operator only; the first gets the fully-finished marker state.
        Map<OperatorID, OperatorState> states = this.buildOperatorStates(Collections.singletonList(runningId), 3);
        OperatorState finishedState = new FullyFinishedOperatorState(null, null, finishedId, 3, 256);
        states.put(finishedId, finishedState);

        Map<OperatorID, ExecutionJobVertex> vertices = this.buildVertices(operatorIds, 2, SubtaskStateMapper.RANGE, SubtaskStateMapper.ROUND_ROBIN);
        new StateAssignmentOperation(0L, new HashSet<ExecutionJobVertex>(vertices.values()), states, false).assignStates();

        ExecutionJobVertex finishedVertex = vertices.get(finishedId);
        for (ExecutionVertex task : finishedVertex.getTaskVertices()) {
            TaskStateSnapshot snapshot = task.getCurrentExecutionAttempt().getTaskRestore().getTaskStateSnapshot();
            Assertions.assertThat(snapshot.isTaskDeployedAsFinished()).isTrue();
        }

        ExecutionJobVertex runningVertex = vertices.get(runningId);
        for (ExecutionVertex task : runningVertex.getTaskVertices()) {
            TaskStateSnapshot snapshot = task.getCurrentExecutionAttempt().getTaskRestore().getTaskStateSnapshot();
            Assertions.assertThat(snapshot.isTaskDeployedAsFinished()).isFalse();
        }
    }

    /**
     * Asserts that the state assigned to a new subtask contains every handle that the given old
     * subtasks held for the same operator.
     *
     * <p>This is a containment check only: the assigned state may hold additional handles, and
     * an empty {@code oldSubtaskIndexes} makes the check trivially pass.
     *
     * @param vertices execution vertices keyed by operator id
     * @param operatorId operator whose state is checked
     * @param states the recorded (old) operator states
     * @param newSubtaskIndex index of the restored subtask to inspect
     * @param extractor selects which state collection to compare
     * @param oldSubtaskIndexes old subtasks whose handles must all be present
     */
    private void assertState(Map<OperatorID, ExecutionJobVertex> vertices, OperatorID operatorId, Map<OperatorID, OperatorState> states, int newSubtaskIndex, Function<OperatorSubtaskState, StateObjectCollection<?>> extractor, int ... oldSubtaskIndexes) {
        OperatorSubtaskState assigned = this.getAssignedState(vertices.get(operatorId), operatorId, newSubtaskIndex);
        StateObjectCollection<?> actualHandles = extractor.apply(assigned);

        // Gather the handles of all listed old subtasks, in index order.
        List<Object> expectedHandles = new ArrayList<Object>();
        for (int oldIndex : oldSubtaskIndexes) {
            OperatorSubtaskState oldSubtaskState = states.get(operatorId).getState(oldIndex);
            expectedHandles.addAll(extractor.apply(oldSubtaskState));
        }

        Assertions.assertThat(actualHandles.containsAll(expectedHandles)).isTrue();
    }

    /**
     * Records state under the user-defined operator id only and checks that the restored
     * subtask, looked up by the generated operator id, receives exactly that state.
     */
    @Test
    void assigningStatesShouldWorkWithUserDefinedOperatorIdsAsWell() {
        final int numSubTasks = 1;
        OperatorID generatedId = new OperatorID();
        OperatorID userDefinedId = new OperatorID();

        ExecutionJobVertex jobVertex = this.buildExecutionJobVertex(generatedId, userDefinedId, numSubTasks);
        // State is keyed by the user-defined id, not the generated one.
        Map<OperatorID, OperatorState> states = this.buildOperatorStates(Collections.singletonList(userDefinedId), numSubTasks);

        new StateAssignmentOperation(0L, Collections.singleton(jobVertex), states, false).assignStates();

        OperatorSubtaskState restored = this.getAssignedState(jobVertex, generatedId, 0);
        Assertions.assertThat((Object)restored).isEqualTo((Object)states.get(userDefinedId).getState(0));
    }

    /**
     * For an operator state without any subtask state, both keyed-state lookups must return an
     * empty list.
     */
    @Test
    void assigningStateHandlesCanNotBeNull() {
        OperatorState emptyState = new OperatorState(null, null, new OperatorID(), 1, 256);

        List managedHandles = StateAssignmentOperation.getManagedKeyedStateHandles(emptyState, KeyGroupRange.of(0, 1));
        List rawHandles = StateAssignmentOperation.getRawKeyedStateHandles(emptyState, KeyGroupRange.of(0, 1));

        Assertions.assertThat((List)managedHandles).isEmpty();
        Assertions.assertThat((List)rawHandles).isEmpty();
    }

    /**
     * Creates a list of freshly constructed operator ids.
     *
     * @param numOperators number of ids to create; a non-positive value yields an empty list
     * @return the new ids in creation order
     */
    private List<OperatorID> buildOperatorIds(int numOperators) {
        List<OperatorID> ids = new ArrayList<OperatorID>();
        for (int created = 0; created < numOperators; ++created) {
            ids.add(new OperatorID());
        }
        return ids;
    }

    /**
     * Builds a fully populated {@link OperatorState} for every operator id.
     *
     * <p>Each subtask gets two managed and two raw operator state handles plus one managed and one
     * raw keyed state handle for the key group range {@code [i, i]}. The first operator in the
     * list gets empty input channel state and the last gets empty result subpartition state; all
     * others get two handles each.
     *
     * @param operatorIDs operators to build state for; must not be empty
     * @param numSubTasks number of subtasks to populate per operator
     * @return the operator states keyed by operator id
     */
    private Map<OperatorID, OperatorState> buildOperatorStates(List<OperatorID> operatorIDs, int numSubTasks) {
        Random random = new Random();
        OperatorID lastId = operatorIDs.get(operatorIDs.size() - 1);

        Function<OperatorID, OperatorState> stateFactory = operatorID -> {
            OperatorState operatorState = new OperatorState("", "", operatorID, numSubTasks, 256);
            for (int subtask = 0; subtask < numSubTasks; ++subtask) {
                OperatorSubtaskState.Builder builder = OperatorSubtaskState.builder().setManagedOperatorState(new StateObjectCollection(Arrays.asList(StateHandleDummyUtil.createNewOperatorStateHandle(10, random), StateHandleDummyUtil.createNewOperatorStateHandle(10, random)))).setRawOperatorState(new StateObjectCollection(Arrays.asList(StateHandleDummyUtil.createNewOperatorStateHandle(10, random), StateHandleDummyUtil.createNewOperatorStateHandle(10, random)))).setManagedKeyedState(StateObjectCollection.singleton((StateObject)StateHandleDummyUtil.createNewKeyedStateHandle(KeyGroupRange.of(subtask, subtask)))).setRawKeyedState(StateObjectCollection.singleton((StateObject)StateHandleDummyUtil.createNewKeyedStateHandle(KeyGroupRange.of(subtask, subtask))));

                // Identity comparison on purpose: "is this the first / last list entry".
                if (operatorID == operatorIDs.get(0)) {
                    builder.setInputChannelState(StateObjectCollection.empty());
                } else {
                    builder.setInputChannelState(new StateObjectCollection(Arrays.asList(StateHandleDummyUtil.createNewInputChannelStateHandle(10, random), StateHandleDummyUtil.createNewInputChannelStateHandle(10, random))));
                }
                if (operatorID == lastId) {
                    builder.setResultSubpartitionState(StateObjectCollection.empty());
                } else {
                    builder.setResultSubpartitionState(new StateObjectCollection(Arrays.asList(StateHandleDummyUtil.createNewResultSubpartitionStateHandle(10, random), StateHandleDummyUtil.createNewResultSubpartitionStateHandle(10, random))));
                }

                operatorState.putState(subtask, builder.build());
            }
            return operatorState;
        };

        return operatorIDs.stream().collect(Collectors.toMap(Function.identity(), stateFactory));
    }

    /**
     * Builds a populated {@link OperatorState} per operator id for a topology where the operator
     * at list index 2 is the only one carrying input channel state.
     *
     * <p>The operator at index 2 gets two input channel handles (created with second argument 0
     * and 1 respectively) and empty result subpartition state. Every other operator gets empty
     * input channel state and two result subpartition handles.
     *
     * @param operatorIDs operators to build state for; needs at least three entries whenever
     *     {@code numSubTasks} is positive
     * @param numSubTasks number of subtasks to populate per operator
     * @return the operator states keyed by operator id
     */
    private Map<OperatorID, OperatorState> buildOperatorStatesForTwoGates(List<OperatorID> operatorIDs, int numSubTasks) {
        Random random = new Random();

        Function<OperatorID, OperatorState> stateFactory = operatorID -> {
            OperatorState operatorState = new OperatorState("", "", operatorID, numSubTasks, 256);
            for (int subtask = 0; subtask < numSubTasks; ++subtask) {
                OperatorSubtaskState.Builder builder = OperatorSubtaskState.builder().setManagedOperatorState(new StateObjectCollection(Arrays.asList(StateHandleDummyUtil.createNewOperatorStateHandle(10, random), StateHandleDummyUtil.createNewOperatorStateHandle(10, random)))).setRawOperatorState(new StateObjectCollection(Arrays.asList(StateHandleDummyUtil.createNewOperatorStateHandle(5, random), StateHandleDummyUtil.createNewOperatorStateHandle(5, random)))).setManagedKeyedState(StateObjectCollection.singleton((StateObject)StateHandleDummyUtil.createNewKeyedStateHandle(KeyGroupRange.of(subtask, subtask)))).setRawKeyedState(StateObjectCollection.singleton((StateObject)StateHandleDummyUtil.createNewKeyedStateHandle(KeyGroupRange.of(subtask, subtask))));

                // Identity comparison on purpose: "is this the list entry at index 2".
                if (operatorID == operatorIDs.get(2)) {
                    builder.setInputChannelState(new StateObjectCollection(Arrays.asList(StateHandleDummyUtil.createNewInputChannelStateHandle(10, 0, random), StateHandleDummyUtil.createNewInputChannelStateHandle(10, 1, random))));
                    builder.setResultSubpartitionState(StateObjectCollection.empty());
                } else {
                    builder.setInputChannelState(StateObjectCollection.empty());
                    builder.setResultSubpartitionState(new StateObjectCollection(Arrays.asList(StateHandleDummyUtil.createNewResultSubpartitionStateHandle(10, 0, random), StateHandleDummyUtil.createNewResultSubpartitionStateHandle(10, 0, random))));
                }

                operatorState.putState(subtask, builder.build());
            }
            return operatorState;
        };

        return operatorIDs.stream().collect(Collectors.toMap(Function.identity(), stateFactory));
    }

    /**
     * Builds a chain of execution vertices in which every operator uses the same parallelism.
     *
     * @param operatorIds operators, in chain order
     * @param parallelisms parallelism applied to every vertex
     * @param downstreamRescaler mapper set as the downstream subtask state mapper of each edge
     * @param upstreamRescaler mapper set as the upstream subtask state mapper of each edge
     * @return execution vertices keyed by operator id
     */
    private Map<OperatorID, ExecutionJobVertex> buildVertices(List<OperatorID> operatorIds, int parallelisms, SubtaskStateMapper downstreamRescaler, SubtaskStateMapper upstreamRescaler) throws JobException, JobExecutionException {
        List<OperatorIdWithParallelism> opIdsWithParallelism = new ArrayList<OperatorIdWithParallelism>(operatorIds.size());
        for (OperatorID operatorId : operatorIds) {
            opIdsWithParallelism.add(new OperatorIdWithParallelism(operatorId, parallelisms));
        }
        return this.buildVertices(opIdsWithParallelism, downstreamRescaler, upstreamRescaler);
    }

    /**
     * Creates one job vertex per entry, connects each vertex to its successor in list order,
     * and converts the result into execution vertices.
     *
     * @param operatorIdsAndParallelism operator id and parallelism of each vertex, in chain order
     * @param downstreamRescaler mapper set as the downstream subtask state mapper of each edge
     * @param upstreamRescaler mapper set as the upstream subtask state mapper of each edge
     * @return execution vertices keyed by operator id
     */
    private Map<OperatorID, ExecutionJobVertex> buildVertices(List<OperatorIdWithParallelism> operatorIdsAndParallelism, SubtaskStateMapper downstreamRescaler, SubtaskStateMapper upstreamRescaler) throws JobException, JobExecutionException {
        JobVertex[] jobVertices = new JobVertex[operatorIdsAndParallelism.size()];
        int position = 0;
        for (OperatorIdWithParallelism entry : operatorIdsAndParallelism) {
            // Generated and user-defined operator id are the same here.
            jobVertices[position++] = this.createJobVertex(entry.getOperatorID(), entry.getOperatorID(), entry.getParallelism());
        }

        // Link neighbours: vertex i is the upstream of vertex i + 1.
        for (int upstreamIdx = 0; upstreamIdx + 1 < jobVertices.length; ++upstreamIdx) {
            this.connectVertices(jobVertices[upstreamIdx], jobVertices[upstreamIdx + 1], upstreamRescaler, downstreamRescaler);
        }
        return this.toExecutionVertices(jobVertices);
    }

    /**
     * Builds an execution graph from the given job vertices and returns the resulting
     * {@link ExecutionJobVertex} of each, keyed by the generated operator id of the vertex's
     * first operator.
     *
     * @param jobVertices the (already connected) job vertices of the graph
     * @return execution job vertices keyed by generated operator id
     */
    private Map<OperatorID, ExecutionJobVertex> toExecutionVertices(JobVertex ... jobVertices) throws JobException, JobExecutionException {
        JobGraph jobGraph = JobGraphTestUtils.streamingJobGraph(jobVertices);
        DefaultExecutionGraph eg = TestingDefaultExecutionGraphBuilder.newBuilder().setJobGraph(jobGraph).build((ScheduledExecutorService)EXECUTOR_EXTENSION.getExecutor());

        Function<JobVertex, OperatorID> generatedIdOf = jobVertex -> ((OperatorIDPair)jobVertex.getOperatorIDs().get(0)).getGeneratedOperatorID();
        // Lookup goes through the helper that wraps checked exceptions in a RuntimeException.
        Function<JobVertex, ExecutionJobVertex> executionVertexOf = jobVertex -> StateAssignmentOperationTest.lambda$toExecutionVertices$15((ExecutionGraph)eg, jobVertex);

        return Arrays.stream(jobVertices).collect(Collectors.toMap(generatedIdOf, executionVertexOf));
    }

    /**
     * Connects {@code upstream} to {@code downstream} with a pipelined all-to-all edge and
     * installs the given subtask state mappers on that edge.
     *
     * @param upstream producing vertex
     * @param downstream consuming vertex
     * @param upstreamRescaler mapper set as the edge's upstream subtask state mapper
     * @param downstreamRescaler mapper set as the edge's downstream subtask state mapper
     */
    private void connectVertices(JobVertex upstream, JobVertex downstream, SubtaskStateMapper upstreamRescaler, SubtaskStateMapper downstreamRescaler) {
        // Note the argument order of the utility: consumer first, producer second.
        final JobEdge edge = JobVertexConnectionUtils.connectNewDataSetAsInput(downstream, upstream, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
        edge.setDownstreamSubtaskStateMapper(downstreamRescaler);
        edge.setUpstreamSubtaskStateMapper(upstreamRescaler);
    }

    /**
     * Creates a stand-alone {@link ExecutionJobVertex} for a single operator.
     *
     * @param operatorID generated operator id
     * @param userDefinedOperatorId user-defined operator id
     * @param parallelism parallelism of the vertex
     * @return the execution job vertex
     * @throws AssertionError if the vertex cannot be created; the original exception is the cause
     */
    private ExecutionJobVertex buildExecutionJobVertex(OperatorID operatorID, OperatorID userDefinedOperatorId, int parallelism) {
        try {
            return ExecutionGraphTestUtils.getExecutionJobVertex(this.createJobVertex(operatorID, userDefinedOperatorId, parallelism));
        }
        catch (Exception cause) {
            throw new AssertionError("Cannot create ExecutionJobVertex", cause);
        }
    }

    /**
     * Creates a job vertex whose user-defined operator id equals its generated operator id.
     *
     * @param operatorID id used for both the generated and the user-defined operator id
     * @param parallelism parallelism of the vertex
     * @return the new job vertex
     */
    private JobVertex createJobVertex(OperatorID operatorID, int parallelism) {
        final OperatorID userDefinedOperatorId = operatorID;
        return this.createJobVertex(operatorID, userDefinedOperatorId, parallelism);
    }

    /**
     * Creates a job vertex with a single operator, a fresh {@code JobVertexID}, and
     * {@code NoOpInvokable} as its invokable class. The vertex is named after the hex string
     * of {@code operatorID}.
     *
     * @param operatorID generated operator id
     * @param userDefinedOperatorId user-defined operator id
     * @param parallelism parallelism of the vertex
     * @return the new job vertex
     */
    private JobVertex createJobVertex(OperatorID operatorID, OperatorID userDefinedOperatorId, int parallelism) {
        String vertexName = operatorID.toHexString();
        JobVertexID vertexId = new JobVertexID();
        OperatorIDPair operatorIdPair = OperatorIDPair.of(operatorID, userDefinedOperatorId, "operatorName", "operatorUid");

        JobVertex vertex = new JobVertex(vertexName, vertexId, Collections.singletonList(operatorIdPair));
        vertex.setInvokableClass(NoOpInvokable.class);
        vertex.setParallelism(parallelism);
        return vertex;
    }

    /**
     * Collects the restored {@link TaskStateSnapshot} of every task of the given vertex, in
     * task order.
     *
     * @param executionJobVertex vertex whose tasks are inspected
     * @return one snapshot per task
     */
    private List<TaskStateSnapshot> getTaskStateSnapshotFromVertex(ExecutionJobVertex executionJobVertex) {
        ExecutionVertex[] taskVertices = executionJobVertex.getTaskVertices();
        List<TaskStateSnapshot> snapshots = new ArrayList<TaskStateSnapshot>(taskVertices.length);
        for (ExecutionVertex taskVertex : taskVertices) {
            Execution attempt = taskVertex.getCurrentExecutionAttempt();
            JobManagerTaskRestore restore = attempt.getTaskRestore();
            snapshots.add(restore.getTaskStateSnapshot());
        }
        return snapshots;
    }

    /**
     * Returns the state restored for one operator on one subtask of the given vertex.
     *
     * @param executionJobVertex vertex that owns the subtask
     * @param operatorId operator to look up inside the task's snapshot
     * @param subtaskIdx index of the subtask within the vertex
     * @return the subtask state registered under {@code operatorId} in the task's snapshot
     */
    private OperatorSubtaskState getAssignedState(ExecutionJobVertex executionJobVertex, OperatorID operatorId, int subtaskIdx) {
        ExecutionVertex subtaskVertex = executionJobVertex.getTaskVertices()[subtaskIdx];
        JobManagerTaskRestore restore = subtaskVertex.getCurrentExecutionAttempt().getTaskRestore();
        TaskStateSnapshot snapshot = restore.getTaskStateSnapshot();
        return snapshot.getSubtaskStateByOperatorID(operatorId);
    }

    /**
     * Topology: {@code source} (parallelism 2) feeds both {@code map1} (parallelism 2) and
     * {@code map2} (parallelism 3). All three operators are recorded with parallelism 2:
     * the source holds one result subpartition handle per subtask, {@code map1} holds empty
     * subtask states, and {@code map2} holds one input channel handle per subtask.
     *
     * <p>State assignment must complete, and subtask 0 of {@code map2} must receive a state.
     */
    @Test
    void testMixedExchangesForwardAndHashNoStateOnForward() throws JobException, JobExecutionException {
        JobVertex source = this.createJobVertex(new OperatorID(), 2);
        JobVertex map1 = this.createJobVertex(new OperatorID(), 2);
        JobVertex map2 = this.createJobVertex(new OperatorID(), 3);
        OperatorID sourceId = ((OperatorIDPair)source.getOperatorIDs().get(0)).getGeneratedOperatorID();
        OperatorID map1Id = ((OperatorIDPair)map1.getOperatorIDs().get(0)).getGeneratedOperatorID();
        OperatorID map2Id = ((OperatorIDPair)map2.getOperatorIDs().get(0)).getGeneratedOperatorID();

        final int recordedParallelism = 2;
        Map<OperatorID, OperatorState> states = new HashMap<OperatorID, OperatorState>();
        Random random = new Random();

        OperatorState sourceState = new OperatorState("", "", sourceId, recordedParallelism, 256);
        for (int subtask = 0; subtask < recordedParallelism; ++subtask) {
            sourceState.putState(subtask, OperatorSubtaskState.builder().setResultSubpartitionState(new StateObjectCollection(Arrays.asList(StateHandleDummyUtil.createNewResultSubpartitionStateHandle(10, 1, random)))).build());
        }
        states.put(sourceId, sourceState);

        // map1 has subtask states, but they are all empty.
        OperatorState map1State = new OperatorState("", "", map1Id, recordedParallelism, 256);
        for (int subtask = 0; subtask < recordedParallelism; ++subtask) {
            map1State.putState(subtask, OperatorSubtaskState.builder().build());
        }
        states.put(map1Id, map1State);

        OperatorState map2State = new OperatorState("", "", map2Id, recordedParallelism, 256);
        for (int subtask = 0; subtask < recordedParallelism; ++subtask) {
            map2State.putState(subtask, OperatorSubtaskState.builder().setInputChannelState(new StateObjectCollection(Arrays.asList(StateHandleDummyUtil.createNewInputChannelStateHandle(10, 0, random)))).build());
        }
        states.put(map2Id, map2State);

        this.connectVertices(source, map1, SubtaskStateMapper.RANGE, SubtaskStateMapper.RANGE);
        this.connectVertices(source, map2, SubtaskStateMapper.ARBITRARY, SubtaskStateMapper.RANGE);

        Map<OperatorID, ExecutionJobVertex> vertices = this.toExecutionVertices(source, map1, map2);
        new StateAssignmentOperation(0L, new HashSet<ExecutionJobVertex>(vertices.values()), states, false).assignStates();

        Assertions.assertThat((Object)this.getAssignedState(vertices.get(map2Id), map2Id, 0)).isNotNull();
    }

    /**
     * Topology: three upstream vertices feed one downstream vertex, all with parallelism 2 and
     * all recorded with parallelism 3. Only {@code upstream2} carries result subpartition state;
     * {@code upstream1} and {@code upstream3} have empty subtask states. The downstream operator
     * carries one input channel handle per subtask.
     *
     * <p>State assignment must complete, and downstream subtask 0 must receive a state with
     * non-empty input channel state.
     */
    @Test
    void testMixedExchangesMultipleGatesWithPartialState() throws JobException, JobExecutionException {
        JobVertex upstream1 = this.createJobVertex(new OperatorID(), 2);
        JobVertex upstream2 = this.createJobVertex(new OperatorID(), 2);
        JobVertex upstream3 = this.createJobVertex(new OperatorID(), 2);
        JobVertex downstream = this.createJobVertex(new OperatorID(), 2);
        OperatorID upstream1Id = ((OperatorIDPair)upstream1.getOperatorIDs().get(0)).getGeneratedOperatorID();
        OperatorID upstream2Id = ((OperatorIDPair)upstream2.getOperatorIDs().get(0)).getGeneratedOperatorID();
        OperatorID upstream3Id = ((OperatorIDPair)upstream3.getOperatorIDs().get(0)).getGeneratedOperatorID();
        OperatorID downstreamId = ((OperatorIDPair)downstream.getOperatorIDs().get(0)).getGeneratedOperatorID();

        final int recordedParallelism = 3;
        Map<OperatorID, OperatorState> states = new HashMap<OperatorID, OperatorState>();
        Random random = new Random();

        OperatorState upstream1State = new OperatorState("", "", upstream1Id, recordedParallelism, 256);
        for (int subtask = 0; subtask < recordedParallelism; ++subtask) {
            upstream1State.putState(subtask, OperatorSubtaskState.builder().build());
        }
        states.put(upstream1Id, upstream1State);

        // The only upstream operator that recorded in-flight output.
        OperatorState upstream2State = new OperatorState("", "", upstream2Id, recordedParallelism, 256);
        for (int subtask = 0; subtask < recordedParallelism; ++subtask) {
            upstream2State.putState(subtask, OperatorSubtaskState.builder().setResultSubpartitionState(new StateObjectCollection(Arrays.asList(StateHandleDummyUtil.createNewResultSubpartitionStateHandle(10, 0, random)))).build());
        }
        states.put(upstream2Id, upstream2State);

        OperatorState upstream3State = new OperatorState("", "", upstream3Id, recordedParallelism, 256);
        for (int subtask = 0; subtask < recordedParallelism; ++subtask) {
            upstream3State.putState(subtask, OperatorSubtaskState.builder().build());
        }
        states.put(upstream3Id, upstream3State);

        OperatorState downstreamState = new OperatorState("", "", downstreamId, recordedParallelism, 256);
        for (int subtask = 0; subtask < recordedParallelism; ++subtask) {
            downstreamState.putState(subtask, OperatorSubtaskState.builder().setInputChannelState(new StateObjectCollection(Arrays.asList(StateHandleDummyUtil.createNewInputChannelStateHandle(10, 1, random)))).build());
        }
        states.put(downstreamId, downstreamState);

        this.connectVertices(upstream1, downstream, SubtaskStateMapper.RANGE, SubtaskStateMapper.RANGE);
        this.connectVertices(upstream2, downstream, SubtaskStateMapper.ARBITRARY, SubtaskStateMapper.RANGE);
        this.connectVertices(upstream3, downstream, SubtaskStateMapper.ROUND_ROBIN, SubtaskStateMapper.ROUND_ROBIN);

        Map<OperatorID, ExecutionJobVertex> vertices = this.toExecutionVertices(upstream1, upstream2, upstream3, downstream);
        new StateAssignmentOperation(0L, new HashSet<ExecutionJobVertex>(vertices.values()), states, false).assignStates();

        OperatorSubtaskState downstreamAssignedState = this.getAssignedState(vertices.get(downstreamId), downstreamId, 0);
        Assertions.assertThat((Object)downstreamAssignedState).isNotNull();
        Assertions.assertThat((Collection)downstreamAssignedState.getInputChannelState()).isNotEmpty();
    }

    /**
     * Topology: {@code source} (parallelism 4) feeds {@code sink} (parallelism 2) over an edge
     * with round-robin mappers on both sides. Both operators are recorded with parallelism 4:
     * the source holds one result subpartition handle per subtask and the sink one input
     * channel handle per subtask.
     *
     * <p>State assignment must complete, and sink subtask 0 must receive a state.
     */
    @Test
    void testMixedExchangesRescaleAndRebalanceNoStateOnRescale() throws JobException, JobExecutionException {
        JobVertex source = this.createJobVertex(new OperatorID(), 4);
        JobVertex sink = this.createJobVertex(new OperatorID(), 2);
        OperatorID sourceId = ((OperatorIDPair)source.getOperatorIDs().get(0)).getGeneratedOperatorID();
        OperatorID sinkId = ((OperatorIDPair)sink.getOperatorIDs().get(0)).getGeneratedOperatorID();

        final int recordedParallelism = 4;
        Map<OperatorID, OperatorState> states = new HashMap<OperatorID, OperatorState>();
        Random random = new Random();

        OperatorState sourceState = new OperatorState("", "", sourceId, recordedParallelism, 256);
        for (int subtask = 0; subtask < recordedParallelism; ++subtask) {
            sourceState.putState(subtask, OperatorSubtaskState.builder().setResultSubpartitionState(new StateObjectCollection(Arrays.asList(StateHandleDummyUtil.createNewResultSubpartitionStateHandle(10, 0, random)))).build());
        }
        states.put(sourceId, sourceState);

        OperatorState sinkState = new OperatorState("", "", sinkId, recordedParallelism, 256);
        for (int subtask = 0; subtask < recordedParallelism; ++subtask) {
            sinkState.putState(subtask, OperatorSubtaskState.builder().setInputChannelState(new StateObjectCollection(Arrays.asList(StateHandleDummyUtil.createNewInputChannelStateHandle(10, 0, random)))).build());
        }
        states.put(sinkId, sinkState);

        this.connectVertices(source, sink, SubtaskStateMapper.ROUND_ROBIN, SubtaskStateMapper.ROUND_ROBIN);

        Map<OperatorID, ExecutionJobVertex> vertices = this.toExecutionVertices(source, sink);
        new StateAssignmentOperation(0L, new HashSet<ExecutionJobVertex>(vertices.values()), states, false).assignStates();

        OperatorSubtaskState sinkAssignedState = this.getAssignedState(vertices.get(sinkId), sinkId, 0);
        Assertions.assertThat((Object)sinkAssignedState).isNotNull();
    }

    /**
     * Decompiler-emitted lambda body used by {@code toExecutionVertices}: looks up the
     * {@link ExecutionJobVertex} for a job vertex by its id.
     *
     * @param eg execution graph to query
     * @param jobVertex job vertex whose id is looked up
     * @return the result of {@code eg.getJobVertex(jobVertex.getID())}
     * @throws RuntimeException wrapping any exception raised by the lookup
     */
    private static /* synthetic */ ExecutionJobVertex lambda$toExecutionVertices$15(ExecutionGraph eg, JobVertex jobVertex) {
        final ExecutionJobVertex executionJobVertex;
        try {
            executionJobVertex = eg.getJobVertex(jobVertex.getID());
        }
        catch (Exception failure) {
            throw new RuntimeException(failure);
        }
        return executionJobVertex;
    }

    /** Immutable pair of an operator id and the parallelism its job vertex should get. */
    private static class OperatorIdWithParallelism {
        private final OperatorID operatorID;
        private final int parallelism;

        /**
         * @param operatorID id of the operator
         * @param parallelism parallelism for the operator's vertex
         */
        public OperatorIdWithParallelism(OperatorID operatorID, int parallelism) {
            this.parallelism = parallelism;
            this.operatorID = operatorID;
        }

        /** @return the operator id passed to the constructor */
        public OperatorID getOperatorID() {
            return operatorID;
        }

        /** @return the parallelism passed to the constructor */
        public int getParallelism() {
            return parallelism;
        }
    }
}

