package org.apache.flink.runtime.checkpoint;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.flink.annotation.Internal;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.OperatorIDPair;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.IntermediateResult;
import org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper;
import org.apache.flink.runtime.jobgraph.IntermediateDataSet;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.jobgraph.OperatorInstanceID;
import org.apache.flink.runtime.state.AbstractChannelStateHandle;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.StateObject;
import org.apache.flink.util.CollectionUtil;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Assigns the operator state of a restored checkpoint to the current execution attempts of a set
 * of {@link ExecutionJobVertex execution job vertices}.
 *
 * <p>{@link #assignStates()} works in three passes: it first maps every restored {@link
 * OperatorState} onto the operators of the vertices, then repartitions the state of the vertices
 * for their new parallelism, and finally hands a per-subtask {@link TaskStateSnapshot} to each
 * subtask's current {@link Execution}.
 */
@Internal
public class StateAssignmentOperation {

    private static final Logger LOG = LoggerFactory.getLogger(StateAssignmentOperation.class);

    private final Set<ExecutionJobVertex> tasks;
    private final Map<OperatorID, OperatorState> operatorStates;
    private final long restoreCheckpointId;
    private final boolean allowNonRestoredState;

    /** Assignment bookkeeping per vertex, filled in the first pass of {@link #assignStates()}. */
    private final Map<ExecutionJobVertex, TaskStateAssignment> vertexAssignments;

    /** Maps the id of every consumed intermediate result to the assignment of its consumer. */
    private final Map<IntermediateDataSetID, TaskStateAssignment> consumerAssignment =
            new HashMap<>();

    /**
     * Creates the operation.
     *
     * @param restoreCheckpointId id of the checkpoint the state is restored from
     * @param tasks the vertices whose subtasks receive state
     * @param operatorStates the restored state, keyed by operator id
     * @param allowNonRestoredState whether restored state without a matching operator is skipped
     *     (logged) instead of failing the assignment
     * @throws NullPointerException if {@code tasks} or {@code operatorStates} is null
     */
    public StateAssignmentOperation(
            long restoreCheckpointId,
            Set<ExecutionJobVertex> tasks,
            Map<OperatorID, OperatorState> operatorStates,
            boolean allowNonRestoredState) {
        this.restoreCheckpointId = restoreCheckpointId;
        this.tasks = Preconditions.checkNotNull(tasks);
        this.operatorStates = Preconditions.checkNotNull(operatorStates);
        this.allowNonRestoredState = allowNonRestoredState;
        this.vertexAssignments = CollectionUtil.newHashMapWithExpectedSize(tasks.size());
    }

    /**
     * Maps, repartitions and assigns the restored state to all subtasks of {@link #tasks}.
     *
     * @throws IllegalStateException if restored state has no matching operator and non-restored
     *     state is not allowed, or if a parallelism/topology precondition is violated
     */
    public void assignStates() {
        checkStateMappingCompleteness(allowNonRestoredState, operatorStates, tasks);

        // Working copy: a state is removed once an operator has claimed it.
        final Map<OperatorID, OperatorState> localOperators = new HashMap<>(operatorStates);

        // First pass: collect the state of every operator of every vertex.
        for (ExecutionJobVertex executionJobVertex : tasks) {
            final List<OperatorIDPair> operatorIDPairs = executionJobVertex.getOperatorIDs();
            final Map<OperatorID, OperatorState> vertexOperatorStates =
                    CollectionUtil.newHashMapWithExpectedSize(operatorIDPairs.size());
            for (OperatorIDPair operatorIDPair : operatorIDPairs) {
                // Use the user-defined id only if the checkpoint holds state for it.
                final OperatorID operatorID =
                        operatorIDPair
                                .getUserDefinedOperatorID()
                                .filter(localOperators::containsKey)
                                .orElse(operatorIDPair.getGeneratedOperatorID());

                OperatorState operatorState = localOperators.remove(operatorID);
                if (operatorState == null) {
                    // Nothing restored for this operator: use a freshly created OperatorState.
                    operatorState =
                            new OperatorState(
                                    operatorIDPair.getUserDefinedOperatorName(),
                                    operatorIDPair.getUserDefinedOperatorUid(),
                                    operatorID,
                                    executionJobVertex.getParallelism(),
                                    executionJobVertex.getMaxParallelism());
                }
                // Keyed by the generated id regardless of which id matched the checkpoint.
                vertexOperatorStates.put(operatorIDPair.getGeneratedOperatorID(), operatorState);
            }

            final TaskStateAssignment stateAssignment =
                    new TaskStateAssignment(
                            executionJobVertex,
                            vertexOperatorStates,
                            consumerAssignment,
                            vertexAssignments);
            vertexAssignments.put(executionJobVertex, stateAssignment);
            for (IntermediateResult input : executionJobVertex.getInputs()) {
                consumerAssignment.put(input.getId(), stateAssignment);
            }
        }

        // Second pass: repartition state. Vertices without non-finished state are included when
        // their neighbours report channel state.
        for (TaskStateAssignment stateAssignment : vertexAssignments.values()) {
            if (stateAssignment.hasNonFinishedState
                    || stateAssignment.hasUpstreamOutputStates()
                    || stateAssignment.hasDownstreamInputStates()) {
                assignAttemptState(stateAssignment);
            }
        }

        // Third pass: hand the (possibly empty) snapshots to the execution attempts.
        for (TaskStateAssignment stateAssignment : vertexAssignments.values()) {
            if (stateAssignment.hasNonFinishedState
                    || stateAssignment.isFullyFinished
                    || stateAssignment.hasUpstreamOutputStates()
                    || stateAssignment.hasDownstreamInputStates()) {
                assignTaskStateToExecutionJobVertices(stateAssignment);
            }
        }
    }

    /** Repartitions operator, channel and keyed state of one vertex for its new parallelism. */
    private void assignAttemptState(TaskStateAssignment taskStateAssignment) {
        // Must run first: it may rescale the vertex's max parallelism, which is read below.
        checkParallelismPreconditions(taskStateAssignment);

        final List<KeyGroupRange> keyGroupPartitions =
                createKeyGroupPartitions(
                        taskStateAssignment.executionJobVertex.getMaxParallelism(),
                        taskStateAssignment.newParallelism);

        reDistributePartitionableStates(
                taskStateAssignment.oldState,
                taskStateAssignment.newParallelism,
                OperatorSubtaskState::getManagedOperatorState,
                RoundRobinOperatorStateRepartitioner.INSTANCE,
                taskStateAssignment.subManagedOperatorState);
        reDistributePartitionableStates(
                taskStateAssignment.oldState,
                taskStateAssignment.newParallelism,
                OperatorSubtaskState::getRawOperatorState,
                RoundRobinOperatorStateRepartitioner.INSTANCE,
                taskStateAssignment.subRawOperatorState);

        reDistributeInputChannelStates(taskStateAssignment);
        reDistributeResultSubpartitionStates(taskStateAssignment);

        reDistributeKeyedStates(keyGroupPartitions, taskStateAssignment);
    }

    /** Sets the initial state of the current execution attempt of every subtask of the vertex. */
    private void assignTaskStateToExecutionJobVertices(TaskStateAssignment taskStateAssignment) {
        final ExecutionJobVertex executionJobVertex = taskStateAssignment.executionJobVertex;
        final List<OperatorIDPair> operatorIDs = executionJobVertex.getOperatorIDs();
        final int parallelism = executionJobVertex.getParallelism();

        for (int subtaskIndex = 0; subtaskIndex < parallelism; subtaskIndex++) {
            final Execution currentExecutionAttempt =
                    executionJobVertex.getTaskVertices()[subtaskIndex].getCurrentExecutionAttempt();
            if (taskStateAssignment.isFullyFinished) {
                assignFinishedStateToTask(currentExecutionAttempt);
            } else {
                assignNonFinishedStateToTask(
                        taskStateAssignment, operatorIDs, subtaskIndex, currentExecutionAttempt);
            }
        }
    }

    /** Marks the attempt as finished on restore. */
    private void assignFinishedStateToTask(Execution attempt) {
        attempt.setInitialState(
                new JobManagerTaskRestore(
                        restoreCheckpointId, TaskStateSnapshot.FINISHED_ON_RESTORE));
    }

    /** Builds the snapshot of one subtask from all of its operators and sets it on the attempt. */
    private void assignNonFinishedStateToTask(
            TaskStateAssignment assignment,
            List<OperatorIDPair> operatorIDs,
            int subTaskIndex,
            Execution attempt) {
        final TaskStateSnapshot taskState = new TaskStateSnapshot(operatorIDs.size(), false);
        for (OperatorIDPair operatorID : operatorIDs) {
            final OperatorInstanceID instanceID =
                    OperatorInstanceID.of(subTaskIndex, operatorID.getGeneratedOperatorID());
            taskState.putSubtaskStateByOperatorID(
                    operatorID.getGeneratedOperatorID(), assignment.getSubtaskState(instanceID));
        }
        attempt.setInitialState(new JobManagerTaskRestore(restoreCheckpointId, taskState));
    }

    /**
     * Checks the parallelism preconditions for every old operator state of the assignment.
     *
     * @throws IllegalStateException if a precondition is violated
     */
    public void checkParallelismPreconditions(TaskStateAssignment taskStateAssignment) {
        for (OperatorState operatorState : taskStateAssignment.oldState.values()) {
            checkParallelismPreconditions(operatorState, taskStateAssignment.executionJobVertex);
        }
    }

    /** Computes managed and raw keyed state for every (subtask, operator) of the assignment. */
    private void reDistributeKeyedStates(
            List<KeyGroupRange> keyGroupPartitions, TaskStateAssignment stateAssignment) {
        stateAssignment.oldState.forEach(
                (operatorID, operatorState) -> {
                    for (int subTaskIndex = 0;
                            subTaskIndex < stateAssignment.newParallelism;
                            subTaskIndex++) {
                        final OperatorInstanceID instanceID =
                                OperatorInstanceID.of(subTaskIndex, operatorID);
                        final Tuple2<List<KeyedStateHandle>, List<KeyedStateHandle>>
                                subKeyedStates =
                                        reAssignSubKeyedStates(
                                                operatorState,
                                                keyGroupPartitions,
                                                subTaskIndex,
                                                stateAssignment.newParallelism,
                                                operatorState.getParallelism());
                        stateAssignment.subManagedKeyedState.put(instanceID, subKeyedStates.f0);
                        stateAssignment.subRawKeyedState.put(instanceID, subKeyedStates.f1);
                    }
                });
    }

    /**
     * Returns (managed, raw) keyed state handles for one subtask.
     *
     * <p>With unchanged parallelism the subtask keeps its own handles; otherwise it receives the
     * intersections of all old handles with its key-group range.
     */
    private Tuple2<List<KeyedStateHandle>, List<KeyedStateHandle>> reAssignSubKeyedStates(
            OperatorState operatorState,
            List<KeyGroupRange> keyGroupPartitions,
            int subTaskIndex,
            int newParallelism,
            int oldParallelism) {
        final List<KeyedStateHandle> subManagedKeyedState;
        final List<KeyedStateHandle> subRawKeyedState;

        if (newParallelism == oldParallelism) {
            final OperatorSubtaskState subtaskState = operatorState.getState(subTaskIndex);
            if (subtaskState != null) {
                subManagedKeyedState = subtaskState.getManagedKeyedState().asList();
                subRawKeyedState = subtaskState.getRawKeyedState().asList();
            } else {
                subManagedKeyedState = Collections.emptyList();
                subRawKeyedState = Collections.emptyList();
            }
        } else {
            final KeyGroupRange range = keyGroupPartitions.get(subTaskIndex);
            subManagedKeyedState = getManagedKeyedStateHandles(operatorState, range);
            subRawKeyedState = getRawKeyedStateHandles(operatorState, range);
        }

        if (subManagedKeyedState.isEmpty() && subRawKeyedState.isEmpty()) {
            return new Tuple2<>(Collections.emptyList(), Collections.emptyList());
        }
        return new Tuple2<>(subManagedKeyedState, subRawKeyedState);
    }

    /**
     * Repartitions one kind of partitionable operator state for {@code newParallelism} subtasks
     * and puts the results into {@code result}.
     *
     * @param oldOperatorStates old state per operator
     * @param newParallelism the target parallelism
     * @param extractHandle selects the handles to repartition from a subtask state
     * @param opStateRepartitioner the repartitioning strategy
     * @param result receives the repartitioned handles per operator instance
     */
    public static <T extends StateObject> void reDistributePartitionableStates(
            Map<OperatorID, OperatorState> oldOperatorStates,
            int newParallelism,
            Function<OperatorSubtaskState, StateObjectCollection<T>> extractHandle,
            OperatorStateRepartitioner<T> opStateRepartitioner,
            Map<OperatorInstanceID, List<T>> result) {
        final Map<OperatorID, List<List<T>>> oldStatesBySubtask =
                splitManagedAndRawOperatorStates(oldOperatorStates, extractHandle);

        oldOperatorStates.forEach(
                (operatorID, oldOperatorState) ->
                        result.putAll(
                                applyRepartitioner(
                                        operatorID,
                                        opStateRepartitioner,
                                        oldStatesBySubtask.get(operatorID),
                                        oldOperatorState.getParallelism(),
                                        newParallelism)));
    }

    /**
     * Assigns result-subpartition (output channel) state of the vertex to its subtasks.
     *
     * <p>With unchanged parallelism each subtask keeps its own state; otherwise the state is
     * repartitioned per produced data set using the assignment's output mappings.
     *
     * @throws IllegalStateException if operators other than the output operator hold such state
     */
    public void reDistributeResultSubpartitionStates(TaskStateAssignment assignment) {
        if (!assignment.hasOutputState && !assignment.hasDownstreamInputStates()) {
            return;
        }

        checkForUnsupportedTopologyChanges(
                assignment.oldState,
                OperatorSubtaskState::getResultSubpartitionState,
                assignment.outputOperatorID);

        final ExecutionJobVertex executionJobVertex = assignment.executionJobVertex;
        final OperatorState outputState = assignment.oldState.get(assignment.outputOperatorID);

        if (outputState.getParallelism() == executionJobVertex.getParallelism()) {
            assignment.resultSubpartitionStates.putAll(
                    toInstanceMap(
                            assignment.outputOperatorID,
                            splitBySubtasks(
                                    outputState,
                                    OperatorSubtaskState::getResultSubpartitionState)));
            return;
        }

        repartitionChannelStates(
                assignment.outputOperatorID,
                splitBySubtasks(outputState, OperatorSubtaskState::getResultSubpartitionState),
                executionJobVertex.getJobVertex().getProducedDataSets().size(),
                partitionIndex ->
                        new MappingBasedRepartitioner<>(
                                assignment.getOutputMapping(partitionIndex).getRescaleMappings()),
                info -> info.getPartitionIdx(),
                executionJobVertex.getParallelism(),
                assignment.resultSubpartitionStates);
    }

    /**
     * Assigns input channel state of the vertex to its subtasks.
     *
     * <p>Plain reassignment is used when the parallelism is unchanged, unless an input edge uses
     * {@link SubtaskStateMapper#FULL} while an upstream vertex's parallelism differs from the old
     * input state's parallelism. Otherwise the state is repartitioned per input gate.
     *
     * @throws IllegalStateException if operators other than the input operator hold such state
     */
    public void reDistributeInputChannelStates(TaskStateAssignment stateAssignment) {
        if (!stateAssignment.hasInputState && !stateAssignment.hasUpstreamOutputStates()) {
            return;
        }

        checkForUnsupportedTopologyChanges(
                stateAssignment.oldState,
                OperatorSubtaskState::getInputChannelState,
                stateAssignment.inputOperatorID);

        final ExecutionJobVertex executionJobVertex = stateAssignment.executionJobVertex;
        final OperatorState inputState =
                stateAssignment.oldState.get(stateAssignment.inputOperatorID);

        final boolean hasAnyFullMapper =
                executionJobVertex.getJobVertex().getInputs().stream()
                        .map(edge -> edge.getDownstreamSubtaskStateMapper())
                        .anyMatch(mapper -> mapper.equals(SubtaskStateMapper.FULL));
        final boolean hasAnyUpstreamParallelismChanged =
                executionJobVertex.getInputs().stream()
                        .map(IntermediateResult::getProducer)
                        .map(vertexAssignments::get)
                        .anyMatch(
                                upstream ->
                                        inputState.getParallelism()
                                                != upstream.executionJobVertex.getParallelism());

        if (inputState.getParallelism() == executionJobVertex.getParallelism()
                && !(hasAnyFullMapper && hasAnyUpstreamParallelismChanged)) {
            stateAssignment.inputChannelStates.putAll(
                    toInstanceMap(
                            stateAssignment.inputOperatorID,
                            splitBySubtasks(inputState, OperatorSubtaskState::getInputChannelState)));
            return;
        }

        repartitionChannelStates(
                stateAssignment.inputOperatorID,
                splitBySubtasks(inputState, OperatorSubtaskState::getInputChannelState),
                executionJobVertex.getInputs().size(),
                gateIndex ->
                        new MappingBasedRepartitioner<>(
                                stateAssignment.getInputMapping(gateIndex).getRescaleMappings()),
                info -> info.getGateIdx(),
                stateAssignment.newParallelism,
                stateAssignment.inputChannelStates);
    }

    /**
     * Repartitions channel state of one operator once per gate/partition and appends the results
     * to {@code target}.
     *
     * <p>With a single gate/partition all handles are used as-is; otherwise only the handles whose
     * info maps to the current index are passed to that index's repartitioner.
     */
    private static <T extends AbstractChannelStateHandle<I>, I> void repartitionChannelStates(
            OperatorID operatorID,
            List<List<T>> oldStatesBySubtask,
            int numGatesOrPartitions,
            Function<Integer, OperatorStateRepartitioner<T>> repartitionerForIndex,
            Function<I, Integer> indexOfInfo,
            int newParallelism,
            Map<OperatorInstanceID, List<T>> target) {
        for (int index = 0; index < numGatesOrPartitions; index++) {
            final List<List<T>> statesOfIndex =
                    numGatesOrPartitions == 1
                            ? oldStatesBySubtask
                            : getPartitionState(oldStatesBySubtask, indexOfInfo, index);
            addToSubtasks(
                    target,
                    applyRepartitioner(
                            operatorID,
                            repartitionerForIndex.apply(index),
                            statesOfIndex,
                            oldStatesBySubtask.size(),
                            newParallelism));
        }
    }

    /** Appends every list of {@code toAdd} to the list stored under the same key in target. */
    private static <K, V> void addToSubtasks(Map<K, List<V>> target, Map<K, List<V>> toAdd) {
        toAdd.forEach(
                (key, values) ->
                        target.computeIfAbsent(key, ignored -> new ArrayList<>(values.size()))
                                .addAll(values));
    }

    /**
     * Fails if any operator other than {@code expectedOperatorID} holds non-empty channel state.
     *
     * @throws IllegalStateException listing the offending operator ids
     */
    private <T extends AbstractChannelStateHandle<?>> void checkForUnsupportedTopologyChanges(
            Map<OperatorID, OperatorState> oldState,
            Function<OperatorSubtaskState, StateObjectCollection<T>> extractHandle,
            OperatorID expectedOperatorID) {
        final List<OperatorID> unexpectedState =
                oldState.entrySet().stream()
                        .filter(entry -> !entry.getKey().equals(expectedOperatorID))
                        .filter(entry -> hasChannelState(entry.getValue(), extractHandle))
                        .map(Map.Entry::getKey)
                        .collect(Collectors.toList());
        if (!unexpectedState.isEmpty()) {
            throw new IllegalStateException(
                    "Cannot recover from unaligned checkpoint when topology changes, such that "
                            + "data exchanges with persisted data are now chained.\n"
                            + "The following operators contain channel state: "
                            + unexpectedState);
        }
    }

    /** Returns whether any subtask state holds a channel handle with at least one offset. */
    private <T extends AbstractChannelStateHandle<?>> boolean hasChannelState(
            OperatorState operatorState,
            Function<OperatorSubtaskState, StateObjectCollection<T>> extractHandle) {
        return operatorState.getSubtaskStates().values().stream()
                .anyMatch(subtaskState -> !isEmpty(extractHandle.apply(subtaskState)));
    }

    /** Returns true if no handle in the collection has any offsets. */
    private <T extends AbstractChannelStateHandle<?>> boolean isEmpty(
            StateObjectCollection<T> handles) {
        return handles.stream().allMatch(handle -> handle.getOffsets().isEmpty());
    }

    /** Keeps, per subtask, only the handles whose info maps to {@code partitionId}. */
    private static <T extends AbstractChannelStateHandle<I>, I> List<List<T>> getPartitionState(
            List<List<T>> states, Function<I, Integer> partitionExtractor, int partitionId) {
        return states.stream()
                .map(
                        subtaskStates ->
                                subtaskStates.stream()
                                        .filter(
                                                handle ->
                                                        partitionExtractor.apply(handle.getInfo())
                                                                == partitionId)
                                        .collect(Collectors.toList()))
                .collect(Collectors.toList());
    }

    /** Splits the selected handles of every operator into one list per old subtask. */
    private static <T extends StateObject>
            Map<OperatorID, List<List<T>>> splitManagedAndRawOperatorStates(
                    Map<OperatorID, OperatorState> operatorStates,
                    Function<OperatorSubtaskState, StateObjectCollection<T>> extractHandle) {
        return operatorStates.entrySet().stream()
                .collect(
                        Collectors.toMap(
                                Map.Entry::getKey,
                                entry -> splitBySubtasks(entry.getValue(), extractHandle)));
    }

    /**
     * Returns the selected handles of every old subtask, in subtask order; subtasks without
     * state yield an empty list.
     */
    public static <T extends StateObject> List<List<T>> splitBySubtasks(
            OperatorState operatorState,
            Function<OperatorSubtaskState, StateObjectCollection<T>> extractHandle) {
        final int parallelism = operatorState.getParallelism();
        final List<List<T>> statesBySubtask = new ArrayList<>(parallelism);
        for (int subTaskIndex = 0; subTaskIndex < parallelism; subTaskIndex++) {
            final OperatorSubtaskState subtaskState = operatorState.getState(subTaskIndex);
            statesBySubtask.add(
                    subtaskState == null
                            ? Collections.<T>emptyList()
                            : extractHandle.apply(subtaskState).asList());
        }
        return statesBySubtask;
    }

    /**
     * Collects the parts of all managed keyed state handles of {@code operatorState} that
     * intersect {@code subtaskKeyGroupRange}.
     */
    public static List<KeyedStateHandle> getManagedKeyedStateHandles(
            OperatorState operatorState, KeyGroupRange subtaskKeyGroupRange) {
        return collectIntersectingKeyedState(
                operatorState, subtaskKeyGroupRange, OperatorSubtaskState::getManagedKeyedState);
    }

    /**
     * Collects the parts of all raw keyed state handles of {@code operatorState} that intersect
     * {@code subtaskKeyGroupRange}.
     */
    public static List<KeyedStateHandle> getRawKeyedStateHandles(
            OperatorState operatorState, KeyGroupRange subtaskKeyGroupRange) {
        return collectIntersectingKeyedState(
                operatorState, subtaskKeyGroupRange, OperatorSubtaskState::getRawKeyedState);
    }

    /** Shared implementation of the managed/raw keyed-state collectors. */
    private static List<KeyedStateHandle> collectIntersectingKeyedState(
            OperatorState operatorState,
            KeyGroupRange subtaskKeyGroupRange,
            Function<OperatorSubtaskState, StateObjectCollection<KeyedStateHandle>>
                    extractHandles) {
        final int parallelism = operatorState.getParallelism();
        List<KeyedStateHandle> intersecting = null;
        for (int subTaskIndex = 0; subTaskIndex < parallelism; subTaskIndex++) {
            final OperatorSubtaskState subtaskState = operatorState.getState(subTaskIndex);
            if (subtaskState == null) {
                continue;
            }
            final StateObjectCollection<KeyedStateHandle> handles =
                    extractHandles.apply(subtaskState);
            if (intersecting == null) {
                // Presized from the first subtask that has state.
                intersecting = new ArrayList<>(parallelism * handles.size());
            }
            extractIntersectingState(handles, subtaskKeyGroupRange, intersecting);
        }
        return intersecting != null ? intersecting : Collections.emptyList();
    }

    /**
     * Appends to {@code extractedStateCollector} the non-null intersection of every non-null
     * handle in {@code originalSubtaskStateHandles} with {@code rangeToExtract}.
     */
    @VisibleForTesting
    public static void extractIntersectingState(
            Collection<? extends KeyedStateHandle> originalSubtaskStateHandles,
            KeyGroupRange rangeToExtract,
            List<KeyedStateHandle> extractedStateCollector) {
        for (KeyedStateHandle keyedStateHandle : originalSubtaskStateHandles) {
            if (keyedStateHandle == null) {
                continue;
            }
            final KeyedStateHandle intersection = keyedStateHandle.getIntersection(rangeToExtract);
            if (intersection != null) {
                extractedStateCollector.add(intersection);
            }
        }
    }

    /**
     * Computes the key-group range of every operator index.
     *
     * @param numberKeyGroups the number of key groups (max parallelism)
     * @param parallelism the number of operator instances
     * @return one range per operator index, in index order
     * @throws IllegalArgumentException if {@code numberKeyGroups < parallelism}
     */
    public static List<KeyGroupRange> createKeyGroupPartitions(
            int numberKeyGroups, int parallelism) {
        Preconditions.checkArgument(numberKeyGroups >= parallelism);
        final List<KeyGroupRange> result = new ArrayList<>(parallelism);
        for (int operatorIndex = 0; operatorIndex < parallelism; operatorIndex++) {
            result.add(
                    KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(
                            numberKeyGroups, parallelism, operatorIndex));
        }
        return result;
    }

    /**
     * Verifies that the restored state fits the vertex and adopts the restored max parallelism
     * when the vertex allows it.
     *
     * @throws IllegalStateException if the restored max parallelism is below the configured
     *     parallelism, or differs from the vertex's and cannot be rescaled
     */
    private static void checkParallelismPreconditions(
            OperatorState operatorState, ExecutionJobVertex executionJobVertex) {
        if (operatorState.getMaxParallelism() < executionJobVertex.getParallelism()) {
            throw new IllegalStateException(
                    "The state for task "
                            + executionJobVertex.getJobVertexId()
                            + " can not be restored. The maximum parallelism ("
                            + operatorState.getMaxParallelism()
                            + ") of the restored state is lower than the configured parallelism ("
                            + executionJobVertex.getParallelism()
                            + "). Please reduce the parallelism of the task to be lower or equal to the maximum parallelism.");
        }

        if (operatorState.getMaxParallelism() == executionJobVertex.getMaxParallelism()) {
            return;
        }

        if (!executionJobVertex.canRescaleMaxParallelism(operatorState.getMaxParallelism())) {
            throw new IllegalStateException(
                    "The maximum parallelism ("
                            + operatorState.getMaxParallelism()
                            + ") with which the latest checkpoint of the execution job vertex "
                            + executionJobVertex
                            + " has been taken and the current maximum parallelism ("
                            + executionJobVertex.getMaxParallelism()
                            + ") changed. This is currently not supported.");
        }

        LOG.debug(
                "Rescaling maximum parallelism for JobVertex {} from {} to {}",
                executionJobVertex.getJobVertexId(),
                executionJobVertex.getMaxParallelism(),
                operatorState.getMaxParallelism());
        executionJobVertex.setMaxParallelism(operatorState.getMaxParallelism());
    }

    /**
     * Verifies that every restored state belongs to some operator of {@code tasks}, matching by
     * generated or user-defined operator id.
     *
     * @throws IllegalStateException if a state has no operator and non-restored state is not
     *     allowed
     */
    private static void checkStateMappingCompleteness(
            boolean allowNonRestoredState,
            Map<OperatorID, OperatorState> operatorStates,
            Set<ExecutionJobVertex> tasks) {
        final Set<OperatorID> allOperatorIDs = new HashSet<>();
        for (ExecutionJobVertex executionJobVertex : tasks) {
            for (OperatorIDPair operatorIDPair : executionJobVertex.getOperatorIDs()) {
                allOperatorIDs.add(operatorIDPair.getGeneratedOperatorID());
                operatorIDPair.getUserDefinedOperatorID().ifPresent(allOperatorIDs::add);
            }
        }

        for (Map.Entry<OperatorID, OperatorState> operatorGroupStateEntry :
                operatorStates.entrySet()) {
            if (allOperatorIDs.contains(operatorGroupStateEntry.getKey())) {
                continue;
            }
            final OperatorState operatorState = operatorGroupStateEntry.getValue();
            if (!allowNonRestoredState) {
                throw new IllegalStateException(
                        "There is no operator for the state " + operatorState.getOperatorID());
            }
            LOG.info("Skipped checkpoint state for operator {}.", operatorState.getOperatorID());
        }
    }

    /**
     * Repartitions the per-subtask states of one operator and keys the result by operator
     * instance.
     */
    public static <T> Map<OperatorInstanceID, List<T>> applyRepartitioner(
            OperatorID operatorID,
            OperatorStateRepartitioner<T> opStateRepartitioner,
            List<List<T>> chainOpParallelStates,
            int oldParallelism,
            int newParallelism) {
        return toInstanceMap(
                operatorID,
                applyRepartitioner(
                        opStateRepartitioner,
                        chainOpParallelStates,
                        oldParallelism,
                        newParallelism));
    }

    /**
     * Keys the list at position {@code i} by {@code OperatorInstanceID.of(i, operatorID)}.
     *
     * @throws NullPointerException if any per-subtask list is null
     */
    private static <T> Map<OperatorInstanceID, List<T>> toInstanceMap(
            OperatorID operatorID, List<List<T>> states) {
        final Map<OperatorInstanceID, List<T>> result =
                CollectionUtil.newHashMapWithExpectedSize(states.size());
        for (int subtaskIndex = 0; subtaskIndex < states.size(); subtaskIndex++) {
            // Check the list itself; checking a boxed boolean could never fail.
            final List<T> subtaskStates =
                    Preconditions.checkNotNull(
                            states.get(subtaskIndex), "states.get(subtaskIndex) is null");
            result.put(OperatorInstanceID.of(subtaskIndex, operatorID), subtaskStates);
        }
        return result;
    }

    /**
     * Repartitions {@code chainOpParallelStates} with the given repartitioner; a null input yields
     * an empty list.
     */
    public static <T> List<List<T>> applyRepartitioner(
            OperatorStateRepartitioner<T> opStateRepartitioner,
            List<List<T>> chainOpParallelStates,
            int oldParallelism,
            int newParallelism) {
        if (chainOpParallelStates == null) {
            return Collections.emptyList();
        }
        return opStateRepartitioner.repartitionState(
                chainOpParallelStates, oldParallelism, newParallelism);
    }
}
