package org.apache.flink.runtime.checkpoint;

import java.util.Collections;
import java.util.Random;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.state.InputChannelStateHandle;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
import org.apache.flink.util.TestLogger;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.Mockito;

/* loaded from: input_file:org/apache/flink/runtime/checkpoint/TaskStateSnapshotTest.class */
/**
 * Tests for {@link TaskStateSnapshot}. They cover:
 *
 * <ul>
 *   <li>putting and getting {@link OperatorSubtaskState}s by {@link OperatorID};
 *   <li>the aggregate {@code hasState()}, {@code discardState()} and {@code getStateSize()}
 *       views over all contained subtask states.
 * </ul>
 */
public class TaskStateSnapshotTest extends TestLogger {

    /** Fixed seed so that randomly generated dummy state handles are reproducible across runs. */
    private static final long RANDOM_SEED = 66L;

    /**
     * Checks three behaviors of the per-operator lookup:
     *
     * <ul>
     *   <li>unknown operators map to {@code null};
     *   <li>stored states are returned per operator;
     *   <li>putting a state for an existing operator returns the previously stored state.
     * </ul>
     */
    @Test
    public void putGetSubtaskStateByOperatorID() {
        TaskStateSnapshot taskStateSnapshot = new TaskStateSnapshot();

        OperatorID operatorID1 = new OperatorID();
        OperatorID operatorID2 = new OperatorID();
        OperatorSubtaskState operatorSubtaskState1 = new OperatorSubtaskState();
        OperatorSubtaskState operatorSubtaskState2 = new OperatorSubtaskState();
        OperatorSubtaskState operatorSubtaskState1replace = new OperatorSubtaskState();

        Assert.assertNull(taskStateSnapshot.getSubtaskStateByOperatorID(operatorID1));
        Assert.assertNull(taskStateSnapshot.getSubtaskStateByOperatorID(operatorID2));

        taskStateSnapshot.putSubtaskStateByOperatorID(operatorID1, operatorSubtaskState1);
        taskStateSnapshot.putSubtaskStateByOperatorID(operatorID2, operatorSubtaskState2);
        Assert.assertEquals(operatorSubtaskState1, taskStateSnapshot.getSubtaskStateByOperatorID(operatorID1));
        Assert.assertEquals(operatorSubtaskState2, taskStateSnapshot.getSubtaskStateByOperatorID(operatorID2));

        // Replacing an entry must hand back the state that was stored before.
        Assert.assertEquals(
                operatorSubtaskState1,
                taskStateSnapshot.putSubtaskStateByOperatorID(operatorID1, operatorSubtaskState1replace));
        Assert.assertEquals(
                operatorSubtaskState1replace,
                taskStateSnapshot.getSubtaskStateByOperatorID(operatorID1));
    }

    /**
     * Checks that the snapshot reports state only once it contains a subtask state that itself
     * reports state. Empty subtask states alone must not flip {@code hasState()}.
     */
    @Test
    public void hasState() {
        Random random = new Random(RANDOM_SEED);
        TaskStateSnapshot taskStateSnapshot = new TaskStateSnapshot();
        Assert.assertFalse(taskStateSnapshot.hasState());

        OperatorSubtaskState emptyOperatorSubtaskState = new OperatorSubtaskState();
        Assert.assertFalse(emptyOperatorSubtaskState.hasState());
        taskStateSnapshot.putSubtaskStateByOperatorID(new OperatorID(), emptyOperatorSubtaskState);
        Assert.assertFalse(taskStateSnapshot.hasState());

        OperatorSubtaskState nonEmptyOperatorSubtaskState =
                operatorStateOnly(StateHandleDummyUtil.createNewOperatorStateHandle(2, random), null);
        Assert.assertTrue(nonEmptyOperatorSubtaskState.hasState());
        taskStateSnapshot.putSubtaskStateByOperatorID(new OperatorID(), nonEmptyOperatorSubtaskState);
        Assert.assertTrue(taskStateSnapshot.hasState());
    }

    /** Checks that discarding the snapshot discards every contained subtask state. */
    @Test
    public void discardState() throws Exception {
        TaskStateSnapshot taskStateSnapshot = new TaskStateSnapshot();
        OperatorID operatorID1 = new OperatorID();
        OperatorID operatorID2 = new OperatorID();

        OperatorSubtaskState operatorSubtaskState1 = Mockito.mock(OperatorSubtaskState.class);
        OperatorSubtaskState operatorSubtaskState2 = Mockito.mock(OperatorSubtaskState.class);

        taskStateSnapshot.putSubtaskStateByOperatorID(operatorID1, operatorSubtaskState1);
        taskStateSnapshot.putSubtaskStateByOperatorID(operatorID2, operatorSubtaskState2);

        taskStateSnapshot.discardState();

        Mockito.verify(operatorSubtaskState1).discardState();
        Mockito.verify(operatorSubtaskState2).discardState();
    }

    /**
     * Checks that the snapshot's state size is zero while empty. Once filled, it must equal the
     * sum of the sizes of the contained managed and raw operator state handles.
     */
    @Test
    public void getStateSize() {
        Random random = new Random(RANDOM_SEED);
        TaskStateSnapshot taskStateSnapshot = new TaskStateSnapshot();
        Assert.assertEquals(0L, taskStateSnapshot.getStateSize());

        OperatorSubtaskState emptyOperatorSubtaskState = new OperatorSubtaskState();
        Assert.assertFalse(emptyOperatorSubtaskState.hasState());
        taskStateSnapshot.putSubtaskStateByOperatorID(new OperatorID(), emptyOperatorSubtaskState);
        Assert.assertEquals(0L, taskStateSnapshot.getStateSize());

        // One subtask state with only managed operator state, one with only raw operator state.
        OperatorStateHandle managedStateHandle = StateHandleDummyUtil.createNewOperatorStateHandle(2, random);
        OperatorSubtaskState withManagedState = operatorStateOnly(managedStateHandle, null);

        OperatorStateHandle rawStateHandle = StateHandleDummyUtil.createNewOperatorStateHandle(2, random);
        OperatorSubtaskState withRawState = operatorStateOnly(null, rawStateHandle);

        taskStateSnapshot.putSubtaskStateByOperatorID(new OperatorID(), withManagedState);
        taskStateSnapshot.putSubtaskStateByOperatorID(new OperatorID(), withRawState);

        long expectedSize = managedStateHandle.getStateSize() + rawStateHandle.getStateSize();
        Assert.assertEquals(expectedSize, taskStateSnapshot.getStateSize());
    }

    /**
     * Checks that input-channel and result-subpartition (channel) state handles count towards the
     * snapshot's state size and make it report state.
     */
    @Test
    public void testSizeIncludesChannelState() {
        // Seeded (like the other tests) so that a failure can be reproduced with the same handles.
        Random random = new Random(RANDOM_SEED);
        InputChannelStateHandle inputChannelStateHandle =
                StateHandleDummyUtil.createNewInputChannelStateHandle(10, random);
        ResultSubpartitionStateHandle resultSubpartitionStateHandle =
                StateHandleDummyUtil.createNewResultSubpartitionStateHandle(10, random);

        OperatorSubtaskState channelStateOnly = new OperatorSubtaskState(
                StateObjectCollection.empty(),
                StateObjectCollection.empty(),
                StateObjectCollection.empty(),
                StateObjectCollection.empty(),
                StateObjectCollection.singleton(inputChannelStateHandle),
                StateObjectCollection.singleton(resultSubpartitionStateHandle));
        TaskStateSnapshot taskStateSnapshot =
                new TaskStateSnapshot(Collections.singletonMap(new OperatorID(), channelStateOnly));

        Assert.assertEquals(
                inputChannelStateHandle.getStateSize() + resultSubpartitionStateHandle.getStateSize(),
                taskStateSnapshot.getStateSize());
        Assert.assertTrue(taskStateSnapshot.hasState());
    }

    /**
     * Creates an {@link OperatorSubtaskState} from the given managed and raw operator state
     * handles. All keyed-state and channel-state constructor arguments are passed as {@code null},
     * exactly as the tests above previously did inline.
     *
     * @param managedOperatorState managed operator state handle, may be {@code null}
     * @param rawOperatorState raw operator state handle, may be {@code null}
     * @return the constructed subtask state
     */
    private static OperatorSubtaskState operatorStateOnly(
            OperatorStateHandle managedOperatorState, OperatorStateHandle rawOperatorState) {
        // The typed casts select the single-handle constructor overload without using raw types.
        return new OperatorSubtaskState(
                managedOperatorState,
                rawOperatorState,
                (KeyedStateHandle) null,
                (KeyedStateHandle) null,
                (StateObjectCollection<InputChannelStateHandle>) null,
                (StateObjectCollection<ResultSubpartitionStateHandle>) null);
    }
}
