package org.apache.flink.runtime.checkpoint;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ScheduledExecutorService;
import org.apache.flink.api.common.JobID;
import org.apache.flink.runtime.OperatorIDPair;
import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils;
import org.apache.flink.runtime.execution.ExecutionState;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionGraph;
import org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.executiongraph.IntermediateResult;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupsStateHandle;
import org.apache.flink.runtime.state.testutils.TestCompletedCheckpointStorageLocation;
import org.apache.flink.testutils.TestingUtils;
import org.apache.flink.testutils.executor.TestExecutorExtension;
import org.apache.flink.util.SerializableObject;
import org.apache.flink.util.concurrent.ManuallyTriggeredScheduledExecutor;
import org.apache.flink.util.concurrent.ScheduledExecutor;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.Mockito;

/**
 * Tests that {@link CheckpointCoordinator#restoreLatestCheckpointedStateToAll} hands the state of
 * the latest completed checkpoint to the current executions of the given job vertices.
 */
class CheckpointStateRestoreTest {

    @RegisterExtension
    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
            TestingUtils.defaultExecutorExtension();

    private static final String TASK_MANAGER_LOCATION_INFO = "Unknown location";

    /**
     * Completes a checkpoint in which only the subtasks of the first job vertex report keyed
     * state, restores it, and checks that exactly those subtasks receive the snapshot while the
     * subtasks of the second vertex receive no restore at all.
     */
    @Test
    void testSetState() throws Exception {
        final KeyGroupsStateHandle serializedKeyGroupStates =
                CheckpointCoordinatorTestingUtils.generateKeyGroupState(
                        KeyGroupRange.of(0, 0),
                        Collections.singletonList(new SerializableObject()));

        final JobVertexID statefulId = new JobVertexID();
        final JobVertexID statelessId = new JobVertexID();

        final ExecutionGraph graph =
                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
                        .addJobVertex(statefulId, 3, 256)
                        .addJobVertex(statelessId, 2, 256)
                        .build(EXECUTOR_RESOURCE.getExecutor());

        final ExecutionJobVertex stateful = graph.getJobVertex(statefulId);
        final ExecutionJobVertex stateless = graph.getJobVertex(statelessId);

        final List<Execution> statefulAttempts =
                Arrays.asList(
                        stateful.getTaskVertices()[0].getCurrentExecutionAttempt(),
                        stateful.getTaskVertices()[1].getCurrentExecutionAttempt(),
                        stateful.getTaskVertices()[2].getCurrentExecutionAttempt());
        final List<Execution> statelessAttempts =
                Arrays.asList(
                        stateless.getTaskVertices()[0].getCurrentExecutionAttempt(),
                        stateless.getTaskVertices()[1].getCurrentExecutionAttempt());

        // Declared with the concrete type: triggerAll() is not part of the ScheduledExecutor
        // interface, so an interface-typed local would not compile.
        final ManuallyTriggeredScheduledExecutor manuallyTriggeredScheduledExecutor =
                new ManuallyTriggeredScheduledExecutor();

        final CheckpointCoordinator coord =
                new CheckpointCoordinatorTestingUtils.CheckpointCoordinatorBuilder()
                        .setTimer(manuallyTriggeredScheduledExecutor)
                        .build(graph);

        // Trigger a checkpoint and run the work it scheduled on the manual timer.
        coord.triggerCheckpoint(false);
        manuallyTriggeredScheduledExecutor.triggerAll();

        final long checkpointId =
                coord.getPendingCheckpoints().values().iterator().next().getCheckpointID();

        // The same snapshot instance is reported by every stateful subtask.
        final TaskStateSnapshot subtaskStates = new TaskStateSnapshot();
        subtaskStates.putSubtaskStateByOperatorID(
                OperatorID.fromJobVertexID(statefulId),
                OperatorSubtaskState.builder()
                        .setManagedKeyedState(serializedKeyGroupStates)
                        .build());

        // Stateful subtasks acknowledge with state, stateless ones without.
        for (Execution attempt : statefulAttempts) {
            coord.receiveAcknowledgeMessage(
                    new AcknowledgeCheckpoint(
                            graph.getJobID(),
                            attempt.getAttemptId(),
                            checkpointId,
                            new CheckpointMetrics(),
                            subtaskStates),
                    TASK_MANAGER_LOCATION_INFO);
        }
        for (Execution attempt : statelessAttempts) {
            coord.receiveAcknowledgeMessage(
                    new AcknowledgeCheckpoint(
                            graph.getJobID(), attempt.getAttemptId(), checkpointId),
                    TASK_MANAGER_LOCATION_INFO);
        }

        // All acknowledgements arrived, so the checkpoint must have completed.
        Assertions.assertThat(coord.getNumberOfRetainedSuccessfulCheckpoints()).isOne();
        Assertions.assertThat(coord.getNumberOfPendingCheckpoints()).isZero();

        final Set<ExecutionJobVertex> tasks = new HashSet<>(Arrays.asList(stateful, stateless));
        Assertions.assertThat(coord.restoreLatestCheckpointedStateToAll(tasks, false)).isTrue();

        for (Execution attempt : statefulAttempts) {
            Assertions.assertThat(attempt.getTaskRestore().getTaskStateSnapshot())
                    .isEqualTo(subtaskStates);
        }
        for (Execution attempt : statelessAttempts) {
            Assertions.assertThat(attempt.getTaskRestore()).isNull();
        }
    }

    /** Restoring from a coordinator whose store holds no checkpoint must report {@code false}. */
    @Test
    void testNoCheckpointAvailable() throws Exception {
        final CheckpointCoordinator coord =
                new CheckpointCoordinatorTestingUtils.CheckpointCoordinatorBuilder()
                        .build(EXECUTOR_RESOURCE.getExecutor());

        final boolean restored =
                coord.restoreLatestCheckpointedStateToAll(Collections.emptySet(), false);

        Assertions.assertThat(restored).isFalse();
    }

    /**
     * Checks the handling of checkpoint state that no job vertex claims.
     *
     * <p>A checkpoint whose state all maps to known operators restores regardless of the
     * {@code allowNonRestoredState} flag. Once the latest checkpoint also carries state for an
     * unknown operator, restoring succeeds only when that flag is {@code true}. With the flag
     * {@code false}, an {@link IllegalStateException} is expected.
     */
    @Test
    void testNonRestoredState() throws Exception {
        final JobVertexID jobVertexId1 = new JobVertexID();
        final JobVertexID jobVertexId2 = new JobVertexID();
        final OperatorID operatorId1 = OperatorID.fromJobVertexID(jobVertexId1);

        final ExecutionJobVertex jobVertex1 =
                mockExecutionJobVertex(
                        jobVertexId1,
                        new ExecutionVertex[] {
                            mockExecutionVertex(mockExecution(), jobVertexId1, 0, 3),
                            mockExecutionVertex(mockExecution(), jobVertexId1, 1, 3),
                            mockExecutionVertex(mockExecution(), jobVertexId1, 2, 3)
                        });
        final ExecutionJobVertex jobVertex2 =
                mockExecutionJobVertex(
                        jobVertexId2,
                        new ExecutionVertex[] {
                            mockExecutionVertex(mockExecution(), jobVertexId2, 0, 2),
                            mockExecutionVertex(mockExecution(), jobVertexId2, 1, 2)
                        });

        final Set<ExecutionJobVertex> tasks = new HashSet<>(Arrays.asList(jobVertex1, jobVertex2));

        final CheckpointCoordinator coord =
                new CheckpointCoordinatorTestingUtils.CheckpointCoordinatorBuilder()
                        .build(EXECUTOR_RESOURCE.getExecutor());

        // (1) State only for the operator of jobVertex1; jobVertex2 has none, which is allowed.
        final Map<OperatorID, OperatorState> checkpointTaskStates = new HashMap<>();
        final OperatorState taskState = new OperatorState(null, null, operatorId1, 3, 3);
        for (int subtask = 0; subtask < 3; subtask++) {
            taskState.putState(subtask, OperatorSubtaskState.builder().build());
        }
        checkpointTaskStates.put(operatorId1, taskState);

        storeCompletedCheckpoint(coord, 0L, 1L, 2L, checkpointTaskStates);

        Assertions.assertThat(coord.restoreLatestCheckpointedStateToAll(tasks, false)).isTrue();
        Assertions.assertThat(coord.restoreLatestCheckpointedStateToAll(tasks, true)).isTrue();

        // (2) Additionally include state for an operator that no job vertex maps to.
        final OperatorID unmappedOperatorId = OperatorID.fromJobVertexID(new JobVertexID());
        final OperatorState unmappedState =
                new OperatorState(null, null, unmappedOperatorId, 1, 1);
        unmappedState.putState(0, OperatorSubtaskState.builder().build());
        checkpointTaskStates.put(unmappedOperatorId, unmappedState);

        storeCompletedCheckpoint(coord, 1L, 2L, 3L, checkpointTaskStates);

        // Allowed to skip the unmapped state.
        Assertions.assertThat(coord.restoreLatestCheckpointedStateToAll(tasks, true)).isTrue();

        // Not allowed to skip it: restoring must be rejected.
        Assertions.assertThatThrownBy(() -> coord.restoreLatestCheckpointedStateToAll(tasks, false))
                .isInstanceOf(IllegalStateException.class);
    }

    /**
     * Adds a completed checkpoint with a copy of {@code operatorStates} to the coordinator's
     * checkpoint store.
     */
    private static void storeCompletedCheckpoint(
            CheckpointCoordinator coord,
            long checkpointId,
            long timestamp,
            long completionTimestamp,
            Map<OperatorID, OperatorState> operatorStates)
            throws Exception {
        final CompletedCheckpoint checkpoint =
                new CompletedCheckpoint(
                        new JobID(),
                        checkpointId,
                        timestamp,
                        completionTimestamp,
                        new HashMap<>(operatorStates),
                        Collections.emptyList(),
                        CheckpointProperties.forCheckpoint(
                                CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION),
                        new TestCompletedCheckpointStorageLocation(),
                        null);
        coord.getCheckpointStore()
                .addCheckpointAndSubsumeOldestOne(checkpoint, new CheckpointsCleaner(), () -> {});
    }

    /** Mocks an {@link Execution} in state {@link ExecutionState#RUNNING}. */
    private Execution mockExecution() {
        return mockExecution(ExecutionState.RUNNING);
    }

    /** Mocks an {@link Execution} with a fresh attempt id that reports the given state. */
    private Execution mockExecution(ExecutionState state) {
        final Execution execution = Mockito.mock(Execution.class);
        Mockito.when(execution.getAttemptId())
                .thenReturn(ExecutionGraphTestUtils.createExecutionAttemptId());
        Mockito.when(execution.getState()).thenReturn(state);
        return execution;
    }

    /**
     * Mocks an {@link ExecutionVertex} for subtask {@code subtaskIndex} of {@code parallelism}.
     * The max parallelism is set equal to {@code parallelism}.
     */
    private ExecutionVertex mockExecutionVertex(
            Execution execution, JobVertexID vertexId, int subtaskIndex, int parallelism) {
        final ExecutionVertex vertex = Mockito.mock(ExecutionVertex.class);
        Mockito.when(vertex.getJobvertexId()).thenReturn(vertexId);
        Mockito.when(vertex.getParallelSubtaskIndex()).thenReturn(subtaskIndex);
        Mockito.when(vertex.getCurrentExecutionAttempt()).thenReturn(execution);
        Mockito.when(vertex.getTotalNumberOfParallelSubtasks()).thenReturn(parallelism);
        Mockito.when(vertex.getMaxParallelism()).thenReturn(parallelism);
        return vertex;
    }

    /**
     * Mocks an {@link ExecutionJobVertex} over the given task vertices.
     *
     * <p>The mock has a single generated operator id derived from {@code id} and no produced
     * data sets. Its parallelism and max parallelism equal the number of vertices. Each task
     * vertex is also stubbed to return this job vertex.
     */
    private ExecutionJobVertex mockExecutionJobVertex(JobVertexID id, ExecutionVertex[] vertices) {
        final ExecutionJobVertex jobVertex = Mockito.mock(ExecutionJobVertex.class);
        Mockito.when(jobVertex.getParallelism()).thenReturn(vertices.length);
        Mockito.when(jobVertex.getMaxParallelism()).thenReturn(vertices.length);
        Mockito.when(jobVertex.getJobVertexId()).thenReturn(id);
        Mockito.when(jobVertex.getTaskVertices()).thenReturn(vertices);
        Mockito.when(jobVertex.getOperatorIDs())
                .thenReturn(
                        Collections.singletonList(
                                OperatorIDPair.generatedIDOnly(OperatorID.fromJobVertexID(id))));
        Mockito.when(jobVertex.getProducedDataSets()).thenReturn(new IntermediateResult[0]);

        for (ExecutionVertex vertex : vertices) {
            Mockito.when(vertex.getJobVertex()).thenReturn(jobVertex);
        }
        return jobVertex;
    }
}
