package org.apache.flink.runtime.jobgraph.forwardgroup;

import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import org.apache.flink.streaming.api.graph.StreamNode;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:org/apache/flink/runtime/jobgraph/forwardgroup/StreamNodeForwardGroupTest.class */
/**
 * Unit tests for {@link StreamNodeForwardGroup}: construction from a set of stream nodes and
 * merging of forward groups with differing parallelism / max-parallelism settings.
 */
class StreamNodeForwardGroupTest {

    /**
     * Builds a forward group from nodes that all use parallelism 1 / max parallelism 1 and checks
     * the reported parallelism, max parallelism and size; then checks that a group built from the
     * same set after adding one more node reports the larger size.
     */
    @Test
    void testStreamNodeForwardGroup() {
        HashSet<StreamNode> streamNodes = new HashSet<>();
        streamNodes.add(createStreamNode(0, 1, 1));
        streamNodes.add(createStreamNode(1, 1, 1));

        StreamNodeForwardGroup forwardGroup = new StreamNodeForwardGroup(streamNodes);
        Assertions.assertThat(forwardGroup.getParallelism()).isEqualTo(1);
        Assertions.assertThat(forwardGroup.getMaxParallelism()).isEqualTo(1);
        Assertions.assertThat(forwardGroup.size()).isEqualTo(2);

        // A new group created from the enlarged set must see the additional node.
        streamNodes.add(createStreamNode(3, 1, 1));
        StreamNodeForwardGroup enlargedGroup = new StreamNodeForwardGroup(streamNodes);
        Assertions.assertThat(enlargedGroup.size()).isEqualTo(3);
    }

    /**
     * Merges a sequence of single-node forward groups into one group and verifies, step by step,
     * the merge result, the group's parallelism / max parallelism, and finally that every stream
     * node of the merged group carries the group's settings.
     */
    @Test
    void testMergeForwardGroup() {
        Map<Integer, StreamNode> streamNodeRetriever = new HashMap<>();
        StreamNodeForwardGroup forwardGroup =
                createForwardGroupAndUpdateStreamNodeRetriever(
                        createStreamNode(0, -1, -1), streamNodeRetriever);

        // Merging two groups with nothing configured leaves both settings undecided.
        forwardGroup.mergeForwardGroup(
                createForwardGroupAndUpdateStreamNodeRetriever(
                        createStreamNode(1, -1, -1), streamNodeRetriever));
        Assertions.assertThat(forwardGroup.isParallelismDecided()).isFalse();
        Assertions.assertThat(forwardGroup.isMaxParallelismDecided()).isFalse();

        // The first group with explicit settings determines parallelism and max parallelism.
        forwardGroup.mergeForwardGroup(
                createForwardGroupAndUpdateStreamNodeRetriever(
                        createStreamNode(2, 2, 4), streamNodeRetriever));
        Assertions.assertThat(forwardGroup.getParallelism()).isEqualTo(2);
        Assertions.assertThat(forwardGroup.getMaxParallelism()).isEqualTo(4);

        // Same parallelism, larger max parallelism (5): merge succeeds, max parallelism stays 4.
        Assertions.assertThat(
                        forwardGroup.mergeForwardGroup(
                                createForwardGroupAndUpdateStreamNodeRetriever(
                                        createStreamNode(3, 2, 5), streamNodeRetriever)))
                .isTrue();
        Assertions.assertThat(forwardGroup.getMaxParallelism()).isEqualTo(4);

        // Same parallelism, smaller max parallelism (3): merge succeeds, max parallelism drops to 3.
        Assertions.assertThat(
                        forwardGroup.mergeForwardGroup(
                                createForwardGroupAndUpdateStreamNodeRetriever(
                                        createStreamNode(4, 2, 3), streamNodeRetriever)))
                .isTrue();
        Assertions.assertThat(forwardGroup.getMaxParallelism()).isEqualTo(3);

        // Expected to be rejected -- presumably because max parallelism 1 is below the group's
        // decided parallelism of 2 (the merge rules live in StreamNodeForwardGroup).
        Assertions.assertThat(
                        forwardGroup.mergeForwardGroup(
                                createForwardGroupAndUpdateStreamNodeRetriever(
                                        createStreamNode(5, -1, 1), streamNodeRetriever)))
                .isFalse();

        // Expected to be rejected -- presumably because parallelism 1 differs from the group's
        // decided parallelism of 2.
        Assertions.assertThat(
                        forwardGroup.mergeForwardGroup(
                                createForwardGroupAndUpdateStreamNodeRetriever(
                                        createStreamNode(6, 1, 3), streamNodeRetriever)))
                .isFalse();

        // Undecided parallelism with max parallelism 2: merge succeeds.
        Assertions.assertThat(
                        forwardGroup.mergeForwardGroup(
                                createForwardGroupAndUpdateStreamNodeRetriever(
                                        createStreamNode(7, -1, 2), streamNodeRetriever)))
                .isTrue();

        // Six merges were accepted in total (nodes 0, 1, 2, 3, 4 and 7).
        Assertions.assertThat(forwardGroup.size()).isEqualTo(6);
        Assertions.assertThat(forwardGroup.getParallelism()).isEqualTo(2);
        Assertions.assertThat(forwardGroup.getMaxParallelism()).isEqualTo(2);

        // Every node that belongs to the merged group must carry the group's final settings.
        for (Integer nodeId : forwardGroup.getVertexIds()) {
            StreamNode streamNode = streamNodeRetriever.get(nodeId);
            Assertions.assertThat(streamNode.getParallelism())
                    .isEqualTo(forwardGroup.getParallelism());
            Assertions.assertThat(streamNode.getMaxParallelism())
                    .isEqualTo(forwardGroup.getMaxParallelism());
        }
    }

    /**
     * Creates a bare {@link StreamNode} with the given id.
     *
     * @param id the node id
     * @param parallelism parallelism to set; values {@code <= 0} leave the node's default untouched
     * @param maxParallelism max parallelism to set; values {@code <= 0} leave the node's default
     *     untouched
     * @return the newly created stream node
     */
    private static StreamNode createStreamNode(int id, int parallelism, int maxParallelism) {
        // The operator argument is cast so that the call resolves to the StreamOperator-based
        // constructor overload; all other arguments are irrelevant for these tests.
        StreamNode streamNode =
                new StreamNode(id, null, null, (StreamOperator<?>) null, null, null);
        if (parallelism > 0) {
            streamNode.setParallelism(parallelism);
        }
        if (maxParallelism > 0) {
            streamNode.setMaxParallelism(maxParallelism);
        }
        return streamNode;
    }

    /**
     * Registers the given node in the id-to-node map and wraps it in a single-node forward group.
     *
     * @param streamNode the node to register and wrap
     * @param map lookup table from node id to node, updated in place
     * @return a forward group containing only {@code streamNode}
     */
    private static StreamNodeForwardGroup createForwardGroupAndUpdateStreamNodeRetriever(
            StreamNode streamNode, Map<Integer, StreamNode> map) {
        map.put(streamNode.getId(), streamNode);
        return new StreamNodeForwardGroup(Collections.singleton(streamNode));
    }
}
