package org.apache.flink.streaming.api.graph;

import java.util.Arrays;
import java.util.Collection;
import org.apache.flink.api.common.RuntimeExecutionMode;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.ExecutionOptions;
import org.apache.flink.core.io.SimpleVersionedSerializerTypeSerializerProxy;
import org.apache.flink.streaming.api.datastream.DataStreamSink;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.operators.ChainingStrategy;
import org.apache.flink.streaming.runtime.operators.sink.BatchCommitterOperatorFactory;
import org.apache.flink.streaming.runtime.operators.sink.BatchGlobalCommitterOperatorFactory;
import org.apache.flink.streaming.runtime.operators.sink.StatelessSinkWriterOperatorFactory;
import org.apache.flink.streaming.runtime.operators.sink.StreamingCommitterOperatorFactory;
import org.apache.flink.streaming.runtime.operators.sink.StreamingGlobalCommitterOperatorFactory;
import org.apache.flink.streaming.runtime.operators.sink.TestSink;
import org.apache.flink.util.TestLogger;
import org.hamcrest.CoreMatchers;
import org.hamcrest.MatcherAssert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
 * Parameterized tests that build a {@link StreamGraph} for a job ending in a {@link TestSink} and
 * assert on the writer, committer and global committer stream nodes found in that graph.
 *
 * <p>Every test runs once per row returned by {@link #data()}: one row for {@code STREAMING} and
 * one for {@code BATCH}, each carrying the operator factory classes expected for that mode.
 */
@RunWith(Parameterized.class)
public class SinkTransformationTranslatorTest extends TestLogger {

    /** Name assigned to the sink by {@link #setSinkProperty(DataStreamSink)}. */
    static final String NAME = "FileSink";

    /** Slot sharing group assigned to the sink by {@link #setSinkProperty(DataStreamSink)}. */
    static final String SLOT_SHARE_GROUP = "FileGroup";

    /** UID assigned to the sink by {@link #setSinkProperty(DataStreamSink)}. */
    static final String UID = "FileUid";

    /** Parallelism assigned to the sink by {@link #setSinkProperty(DataStreamSink)}. */
    static final int PARALLELISM = 2;

    /** Execution mode under test (column 0 of {@link #data()}). */
    @Parameterized.Parameter
    public RuntimeExecutionMode runtimeExecutionMode;

    /** Expected committer operator factory class (column 1 of {@link #data()}). */
    @Parameterized.Parameter(1)
    public Class<?> committerClass;

    /**
     * Expected global committer operator factory class (column 2 of {@link #data()}).
     *
     * <p>The index is the literal column position; it is deliberately not tied to {@link
     * #PARALLELISM}, which merely happens to have the same value.
     */
    @Parameterized.Parameter(2)
    public Class<?> globalCommitterClass;

    /**
     * Supplies the parameter rows: execution mode, expected committer factory class, expected
     * global committer factory class.
     *
     * @return one row for {@code STREAMING} and one row for {@code BATCH}
     */
    @Parameterized.Parameters(name = "Execution Mode: {0}, Expected Committer Operator: {1}, Expected Global Committer Operator: {2}")
    public static Collection<Object[]> data() {
        return Arrays.asList(
                new Object[] {
                    RuntimeExecutionMode.STREAMING,
                    StreamingCommitterOperatorFactory.class,
                    StreamingGlobalCommitterOperatorFactory.class
                },
                new Object[] {
                    RuntimeExecutionMode.BATCH,
                    BatchCommitterOperatorFactory.class,
                    BatchGlobalCommitterOperatorFactory.class
                });
    }

    /** A sink with only a writer: expects two nodes in the graph and a source -> writer link. */
    @Test
    public void generateWriterTopology() {
        final StreamGraph streamGraph =
                buildGraph(TestSink.newBuilder().build(), this.runtimeExecutionMode);

        final StreamNode sourceNode = findNodeNameContains(streamGraph, "Source");
        final StreamNode writerNode = findNodeNameContains(streamGraph, "Writer");

        // Two nodes are expected: the source and the writer. This count is unrelated to
        // PARALLELISM.
        MatcherAssert.assertThat(streamGraph.getStreamNodes().size(), CoreMatchers.equalTo(2));

        validateTopology(
                sourceNode,
                IntSerializer.class,
                writerNode,
                String.format("Sink Writer: %s", NAME),
                UID,
                StatelessSinkWriterOperatorFactory.class,
                PARALLELISM,
                -1);
    }

    /**
     * A sink with a writer and a committer: expects three nodes and a writer -> committer link.
     *
     * <p>The expected committer parallelism / max parallelism are {@code PARALLELISM} / {@code -1}
     * in {@code STREAMING} mode and {@code 1} / {@code 1} otherwise.
     */
    @Test
    public void generateWriterCommitterTopology() {
        final StreamGraph streamGraph =
                buildGraph(
                        TestSink.newBuilder().setDefaultCommitter().build(),
                        this.runtimeExecutionMode);

        final StreamNode writerNode = findNodeNameContains(streamGraph, "Writer");
        final StreamNode committerNode = findNodeNameContains(streamGraph, "Committer");

        MatcherAssert.assertThat(streamGraph.getStreamNodes().size(), CoreMatchers.equalTo(3));

        final boolean streaming = this.runtimeExecutionMode == RuntimeExecutionMode.STREAMING;
        final int expectedParallelism = streaming ? PARALLELISM : 1;
        final int expectedMaxParallelism = streaming ? -1 : 1;

        validateTopology(
                writerNode,
                SimpleVersionedSerializerTypeSerializerProxy.class,
                committerNode,
                String.format("Sink Committer: %s", NAME),
                String.format("Sink Committer: %s", UID),
                this.committerClass,
                expectedParallelism,
                expectedMaxParallelism);
    }

    /** A sink with writer, committer and global committer: checks the committer -> global link. */
    @Test
    public void generateWriterCommitterGlobalCommitterTopology() {
        final StreamGraph streamGraph =
                buildGraph(
                        TestSink.newBuilder()
                                .setDefaultCommitter()
                                .setDefaultGlobalCommitter()
                                .build(),
                        this.runtimeExecutionMode);

        // NOTE(review): "Committer" is also a substring of the global committer's name, so this
        // lookup depends on the committer node being iterated first - confirm this is stable.
        final StreamNode committerNode = findNodeNameContains(streamGraph, "Committer");
        final StreamNode globalCommitterNode =
                findNodeNameContains(streamGraph, "Global Committer");

        validateTopology(
                committerNode,
                SimpleVersionedSerializerTypeSerializerProxy.class,
                globalCommitterNode,
                String.format("Sink Global Committer: %s", NAME),
                String.format("Sink Global Committer: %s", UID),
                this.globalCommitterClass,
                1,
                1);
    }

    /** A sink with a writer and a global committer only: checks the writer -> global link. */
    @Test
    public void generateWriterGlobalCommitterTopology() {
        final StreamGraph streamGraph =
                buildGraph(
                        TestSink.newBuilder()
                                .setCommittableSerializer(
                                        TestSink.StringCommittableSerializer.INSTANCE)
                                .setDefaultGlobalCommitter()
                                .build(),
                        this.runtimeExecutionMode);

        final StreamNode writerNode = findNodeNameContains(streamGraph, "Writer");
        final StreamNode globalCommitterNode =
                findNodeNameContains(streamGraph, "Global Committer");

        validateTopology(
                writerNode,
                SimpleVersionedSerializerTypeSerializerProxy.class,
                globalCommitterNode,
                String.format("Sink Global Committer: %s", NAME),
                String.format("Sink Global Committer: %s", UID),
                this.globalCommitterClass,
                1,
                1);
    }

    /**
     * With auto-generated UIDs disabled and no UID set on the sink, building the stream graph is
     * expected to throw {@link IllegalStateException}.
     */
    @Test(expected = IllegalStateException.class)
    public void throwExceptionWithoutSettingUid() {
        final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

        final Configuration config = new Configuration();
        config.set(ExecutionOptions.RUNTIME_MODE, this.runtimeExecutionMode);
        env.configure(config, getClass().getClassLoader());

        env.getConfig().disableAutoGeneratedUIDs();
        env.fromElements(1, 2).sinkTo(TestSink.newBuilder().build());
        env.getStreamGraph();
    }

    /**
     * After {@code disableChaining()} on the sink, the writer is expected to report {@link
     * ChainingStrategy#NEVER} while the committer and global committer report {@link
     * ChainingStrategy#ALWAYS}.
     *
     * <p>NOTE(review): this test does not apply {@link #runtimeExecutionMode} to the environment,
     * so both parameterized runs appear to exercise the same configuration - confirm intended.
     */
    @Test
    public void disableOperatorChain() {
        final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

        env.fromElements(1, 2)
                .sinkTo(
                        TestSink.newBuilder()
                                .setDefaultCommitter()
                                .setDefaultGlobalCommitter()
                                .build())
                .disableChaining();

        final StreamGraph streamGraph = env.getStreamGraph();
        final StreamNode writerNode = findNodeNameContains(streamGraph, "Writer");
        final StreamNode committerNode = findNodeNameContains(streamGraph, "Committer");
        final StreamNode globalCommitterNode =
                findNodeNameContains(streamGraph, "Global Committer");

        MatcherAssert.assertThat(
                writerNode.getOperatorFactory().getChainingStrategy(),
                CoreMatchers.is(ChainingStrategy.NEVER));
        MatcherAssert.assertThat(
                committerNode.getOperatorFactory().getChainingStrategy(),
                CoreMatchers.is(ChainingStrategy.ALWAYS));
        MatcherAssert.assertThat(
                globalCommitterNode.getOperatorFactory().getChainingStrategy(),
                CoreMatchers.is(ChainingStrategy.ALWAYS));
    }

    /**
     * Asserts that {@code src} feeds {@code dest} and that {@code dest} carries the expected
     * properties.
     *
     * @param src upstream node; its first out edge must target {@code dest}
     * @param expectedSerializer expected class of {@code src}'s output serializer and of {@code
     *     dest}'s first input serializer
     * @param dest node under inspection; its first in edge must originate from {@code src} and it
     *     must have no out edges
     * @param expectedName expected operator name of {@code dest}
     * @param expectedUid expected transformation UID of {@code dest}
     * @param expectedOperatorFactory expected class of {@code dest}'s operator factory
     * @param expectedParallelism expected parallelism of {@code dest}
     * @param expectedMaxParallelism expected max parallelism of {@code dest}
     */
    private void validateTopology(
            StreamNode src,
            Class<?> expectedSerializer,
            StreamNode dest,
            String expectedName,
            String expectedUid,
            Class<?> expectedOperatorFactory,
            int expectedParallelism,
            int expectedMaxParallelism) {

        // Wiring between the two nodes, checked from both ends.
        MatcherAssert.assertThat(
                src.getOutEdges().get(0).getTargetId(), CoreMatchers.equalTo(dest.getId()));
        MatcherAssert.assertThat(
                src.getTypeSerializerOut(), CoreMatchers.instanceOf(expectedSerializer));
        MatcherAssert.assertThat(
                dest.getInEdges().get(0).getSourceId(), CoreMatchers.equalTo(src.getId()));
        MatcherAssert.assertThat(
                dest.getTypeSerializersIn()[0], CoreMatchers.instanceOf(expectedSerializer));

        // Properties of the downstream node.
        MatcherAssert.assertThat(dest.getOperatorName(), CoreMatchers.equalTo(expectedName));
        MatcherAssert.assertThat(dest.getTransformationUID(), CoreMatchers.equalTo(expectedUid));
        MatcherAssert.assertThat(
                dest.getOperatorFactory(), CoreMatchers.instanceOf(expectedOperatorFactory));
        MatcherAssert.assertThat(
                dest.getParallelism(), CoreMatchers.equalTo(expectedParallelism));
        MatcherAssert.assertThat(
                dest.getMaxParallelism(), CoreMatchers.equalTo(expectedMaxParallelism));
        MatcherAssert.assertThat(
                dest.getOperatorFactory().getChainingStrategy(),
                CoreMatchers.is(ChainingStrategy.ALWAYS));
        MatcherAssert.assertThat(
                dest.getSlotSharingGroup(), CoreMatchers.equalTo(SLOT_SHARE_GROUP));

        // The inspected node must be terminal.
        MatcherAssert.assertThat(dest.getOutEdges().size(), CoreMatchers.equalTo(0));
    }

    /**
     * Builds the stream graph of a job {@code fromElements(1, 2) -> rebalance -> sink}, with the
     * environment configured for the given execution mode and the sink configured through {@link
     * #setSinkProperty(DataStreamSink)}.
     *
     * @param sink the sink to attach
     * @param executionMode the runtime mode to set on the environment
     * @return the stream graph obtained from the environment
     */
    private StreamGraph buildGraph(TestSink sink, RuntimeExecutionMode executionMode) {
        final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

        final Configuration config = new Configuration();
        config.set(ExecutionOptions.RUNTIME_MODE, executionMode);
        env.configure(config, getClass().getClassLoader());

        final DataStreamSink<Integer> dataStreamSink =
                env.fromElements(1, 2).rebalance().sinkTo(sink);
        setSinkProperty(dataStreamSink);

        return env.getStreamGraph("test");
    }

    /** Applies {@link #NAME}, {@link #UID}, {@link #PARALLELISM} and {@link #SLOT_SHARE_GROUP}. */
    private void setSinkProperty(DataStreamSink<Integer> dataStreamSink) {
        dataStreamSink.name(NAME);
        dataStreamSink.uid(UID);
        dataStreamSink.setParallelism(PARALLELISM);
        dataStreamSink.slotSharingGroup(SLOT_SHARE_GROUP);
    }

    /**
     * Returns the first node of the graph whose operator name contains the given text.
     *
     * @param streamGraph graph to search
     * @param namePart text the operator name must contain
     * @return the first matching node in the graph's node iteration order
     * @throws IllegalStateException if no node matches
     */
    private StreamNode findNodeNameContains(StreamGraph streamGraph, String namePart) {
        return streamGraph.getStreamNodes().stream()
                .filter(node -> node.getOperatorName().contains(namePart))
                .findFirst()
                .orElseThrow(
                        () ->
                                new IllegalStateException(
                                        "Can not find the node contains " + namePart));
    }
}
