package org.apache.flink.runtime.checkpoint;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import org.apache.flink.api.common.JobID;
import org.apache.flink.runtime.OperatorIDPair;
import org.apache.flink.runtime.checkpoint.metadata.CheckpointMetadata;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.state.CompletedCheckpointStorageLocation;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.OperatorStreamStateHandle;
import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
import org.apache.flink.runtime.state.testutils.TestCompletedCheckpointStorageLocation;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.Mockito;

/* loaded from: input_file:org/apache/flink/runtime/checkpoint/CheckpointMetadataLoadingTest.class */
/**
 * Tests for {@link Checkpoints#loadAndValidateCheckpoint}, which restores serialized checkpoint
 * metadata and validates it against the job's vertices.
 *
 * <p>Covers a fully matching restore, a max-parallelism mismatch, and operator state that cannot be
 * mapped to any vertex (with and without {@code allowNonRestoredState}).
 */
public class CheckpointMetadataLoadingTest {

    /**
     * A checkpoint ID beyond the {@code int} range; if the ID were truncated to an {@code int}
     * anywhere on the load path, the equality assertion on the restored ID would fail.
     */
    private static final long LARGE_CHECKPOINT_ID = Integer.MAX_VALUE + 123123L;

    /** Parallelism (and, for the matching case, max parallelism) used by the large-state tests. */
    private static final int LARGE_PARALLELISM = 128128;

    private final ClassLoader cl = getClass().getClassLoader();

    /** All operator state maps onto a vertex with matching parallelism, so loading succeeds. */
    @Test
    public void testAllStateRestored() throws Exception {
        final JobID jobId = new JobID();
        final OperatorID operatorId = new OperatorID();

        final CompletedCheckpointStorageLocation savepoint =
                createSavepointWithOperatorSubtaskState(
                        LARGE_CHECKPOINT_ID, operatorId, LARGE_PARALLELISM);
        final Map<JobVertexID, ExecutionJobVertex> tasks =
                createTasks(operatorId, LARGE_PARALLELISM, LARGE_PARALLELISM);

        final CompletedCheckpoint loaded =
                Checkpoints.loadAndValidateCheckpoint(jobId, tasks, savepoint, this.cl, false);

        Assert.assertEquals(jobId, loaded.getJobId());
        Assert.assertEquals(LARGE_CHECKPOINT_ID, loaded.getCheckpointID());
    }

    /** A vertex whose max parallelism differs from the stored state's must be rejected. */
    @Test
    public void testMaxParallelismMismatch() throws Exception {
        final OperatorID operatorId = new OperatorID();
        final CompletedCheckpointStorageLocation savepoint =
                createSavepointWithOperatorSubtaskState(242L, operatorId, LARGE_PARALLELISM);
        final Map<JobVertexID, ExecutionJobVertex> tasks =
                createTasks(operatorId, LARGE_PARALLELISM, LARGE_PARALLELISM + 1);

        try {
            Checkpoints.loadAndValidateCheckpoint(new JobID(), tasks, savepoint, this.cl, false);
            Assert.fail("Did not throw expected Exception");
        } catch (IllegalStateException e) {
            assertMessageContains(e, "Max parallelism mismatch");
        }
    }

    /** State for an unknown operator fails the load when non-restored state is not allowed. */
    @Test
    public void testNonRestoredStateWhenDisallowed() throws Exception {
        final CompletedCheckpointStorageLocation savepoint =
                createSavepointWithOperatorSubtaskState(242L, new OperatorID(), 9);

        try {
            Checkpoints.loadAndValidateCheckpoint(
                    new JobID(), Collections.emptyMap(), savepoint, this.cl, false);
            Assert.fail("Did not throw expected Exception");
        } catch (IllegalStateException e) {
            assertMessageContains(e, "allowNonRestoredState");
        }
    }

    /** State for an unknown operator is dropped when non-restored state is allowed. */
    @Test
    public void testNonRestoredStateWhenAllowed() throws Exception {
        final CompletedCheckpointStorageLocation savepoint =
                createSavepointWithOperatorSubtaskState(242L, new OperatorID(), 9);

        final CompletedCheckpoint loaded =
                Checkpoints.loadAndValidateCheckpoint(
                        new JobID(), Collections.emptyMap(), savepoint, this.cl, true);

        Assert.assertTrue(loaded.getOperatorStates().isEmpty());
    }

    /**
     * An unmatched operator that carries only coordinator state (no subtask state) still counts as
     * non-restored state and fails the load when that is not allowed.
     */
    @Test
    public void testUnmatchedCoordinatorOnlyStateFails() throws Exception {
        final int parallelism = 617;
        final int maxParallelism = 1234;
        final OperatorState operatorState =
                new OperatorState(new OperatorID(), parallelism, maxParallelism);
        operatorState.setCoordinatorState(
                new ByteStreamStateHandle("coordinatorState", new byte[0]));
        final CompletedCheckpointStorageLocation savepoint =
                createSavepointWithOperatorState(42L, operatorState);

        try {
            Checkpoints.loadAndValidateCheckpoint(
                    new JobID(), Collections.emptyMap(), savepoint, this.cl, false);
            Assert.fail("Did not throw expected Exception");
        } catch (IllegalStateException e) {
            assertMessageContains(e, "allowNonRestoredState");
        }
    }

    // ------------------------------------------------------------------------
    //  Helpers
    // ------------------------------------------------------------------------

    /**
     * Asserts that {@code error}'s message is non-null and contains {@code expectedFragment}. On
     * failure, the assertion message includes the actual exception message.
     */
    private static void assertMessageContains(Throwable error, String expectedFragment) {
        final String message = error.getMessage();
        Assert.assertTrue(
                "Expected exception message containing '"
                        + expectedFragment
                        + "' but was: "
                        + message,
                message != null && message.contains(expectedFragment));
    }

    /**
     * Serializes checkpoint metadata holding the single given operator state and wraps the bytes in
     * a test storage location.
     *
     * @param checkpointId the checkpoint ID recorded in the metadata
     * @param operatorState the only operator state in the metadata; no master states are stored
     * @return a storage location whose metadata handle holds the serialized bytes
     * @throws IOException if serializing the metadata fails
     */
    private static CompletedCheckpointStorageLocation createSavepointWithOperatorState(
            long checkpointId, OperatorState operatorState) throws IOException {
        final CheckpointMetadata metadata =
                new CheckpointMetadata(
                        checkpointId,
                        Collections.singletonList(operatorState),
                        Collections.emptyList());

        final ByteStreamStateHandle serializedMetadata;
        try (ByteArrayOutputStream out = new ByteArrayOutputStream()) {
            Checkpoints.storeCheckpointMetadata(metadata, out);
            serializedMetadata = new ByteStreamStateHandle("checkpoint", out.toByteArray());
        }
        return new TestCompletedCheckpointStorageLocation(serializedMetadata, "dummy/pointer");
    }

    /**
     * Creates a savepoint for {@code operatorId} containing one subtask state (for subtask 0).
     *
     * <p>The subtask state consists of:
     * <ul>
     *   <li>an empty managed operator state handle;</li>
     *   <li>one dummy input-channel handle;</li>
     *   <li>one dummy result-subpartition handle.</li>
     * </ul>
     * The raw operator state and both keyed state handles are {@code null}.
     *
     * @param checkpointId the checkpoint ID recorded in the metadata
     * @param operatorId the operator the state belongs to
     * @param parallelism used as both the parallelism and the max parallelism of the operator state
     */
    private static CompletedCheckpointStorageLocation createSavepointWithOperatorSubtaskState(
            long checkpointId, OperatorID operatorId, int parallelism) throws IOException {
        final Random random = new Random();
        final OperatorSubtaskState subtaskState =
                new OperatorSubtaskState(
                        new OperatorStreamStateHandle(
                                Collections.emptyMap(),
                                new ByteStreamStateHandle("testHandler", new byte[0])),
                        (OperatorStateHandle) null,
                        (KeyedStateHandle) null,
                        (KeyedStateHandle) null,
                        StateObjectCollection.singleton(
                                StateHandleDummyUtil.createNewInputChannelStateHandle(10, random)),
                        StateObjectCollection.singleton(
                                StateHandleDummyUtil.createNewResultSubpartitionStateHandle(
                                        10, random)));

        final OperatorState operatorState = new OperatorState(operatorId, parallelism, parallelism);
        operatorState.putState(0, subtaskState);

        return createSavepointWithOperatorState(checkpointId, operatorState);
    }

    /**
     * Builds a task map with a single mocked vertex that hosts exactly {@code operatorId}.
     *
     * <p>The vertex ID is derived from the operator ID's lower and upper parts.
     *
     * @param operatorId the operator hosted by the vertex (registered as a generated ID only)
     * @param parallelism the stubbed vertex parallelism
     * @param maxParallelism the stubbed vertex max parallelism
     * @return a mutable map from the vertex ID to the mocked vertex
     */
    private static Map<JobVertexID, ExecutionJobVertex> createTasks(
            OperatorID operatorId, int parallelism, int maxParallelism) {
        final JobVertexID vertexId =
                new JobVertexID(operatorId.getLowerPart(), operatorId.getUpperPart());

        final ExecutionJobVertex vertex = Mockito.mock(ExecutionJobVertex.class);
        Mockito.when(vertex.getParallelism()).thenReturn(parallelism);
        Mockito.when(vertex.getMaxParallelism()).thenReturn(maxParallelism);
        Mockito.when(vertex.getOperatorIDs())
                .thenReturn(Collections.singletonList(OperatorIDPair.generatedIDOnly(operatorId)));

        if (parallelism != maxParallelism) {
            // NOTE(review): presumably marks the differing max parallelism as explicitly configured
            // so the loader must validate it rather than adapt it — confirm against Checkpoints.
            Mockito.when(vertex.isMaxParallelismConfigured()).thenReturn(true);
        }

        final Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>();
        tasks.put(vertexId, vertex);
        return tasks;
    }
}
