package org.apache.flink.streaming.api.graph;

import java.util.Objects;
import java.util.function.Predicate;
import org.apache.flink.api.common.RuntimeExecutionMode;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.connector.sink2.Sink;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.ExecutionOptions;
import org.apache.flink.core.io.SimpleVersionedSerializerTypeSerializerProxy;
import org.apache.flink.streaming.api.datastream.CustomSinkOperatorUidHashes;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.DataStreamSink;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.operators.ChainingStrategy;
import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
import org.apache.flink.streaming.runtime.operators.sink.CommitterOperatorFactory;
import org.apache.flink.streaming.runtime.operators.sink.SinkWriterOperatorFactory;
import org.apache.flink.streaming.runtime.operators.sink.TestSinkV2;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.EnumSource;

/**
 * Tests how a Sink V2 is expanded into writer / committer / global committer nodes of a {@link
 * StreamGraph}: topology, operator names, uids, uid hashes, parallelism and chaining strategy.
 */
class SinkV2TransformationTranslatorITCase {
    static final String NAME = "FileSink";
    static final String SLOT_SHARE_GROUP = "FileGroup";
    static final String UID = "FileUid";
    static final int PARALLELISM = 2;

    private static final String WRITER_UID_HASH = "f6b178ce445dc3ffaa06bad27a51fead";
    private static final String COMMITTER_UID_HASH = "68ac8ae79eae4e3135a54f9689c4aa10";
    private static final String GLOBAL_COMMITTER_UID_HASH = "77e6aa6eeb1643b3765e1e4a7a672f37";

    /** Asserts that none of the outgoing edges of {@code streamNode} support unaligned checkpoints. */
    protected static void assertNoUnalignedOutput(StreamNode streamNode) {
        Assertions.assertThat(streamNode.getOutEdges())
                .allMatch(edge -> !edge.supportsUnalignedCheckpoints());
    }

    /** Returns a test sink consisting of a writer only. */
    Sink<Integer> simpleSink() {
        return TestSinkV2.newBuilder().build();
    }

    /** Returns a test sink with a writer and a default committer. */
    Sink<Integer> sinkWithCommitter() {
        return TestSinkV2.newBuilder().setDefaultCommitter().build();
    }

    /** Returns a test sink with a writer, a default committer and a post-commit topology. */
    Sink<Integer> sinkWithCommitterAndGlobalCommitter() {
        return TestSinkV2.newBuilder().setDefaultCommitter().setWithPostCommitTopology(true).build();
    }

    /** Attaches {@code sink} to {@code dataStream}; kept as a hook so the wiring can be varied. */
    DataStreamSink<Integer> sinkTo(DataStream<Integer> dataStream, Sink<Integer> sink) {
        return dataStream.sinkTo(sink);
    }

    @Test
    void testSettingOperatorUidHash() {
        final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        final CustomSinkOperatorUidHashes uidHashes =
                CustomSinkOperatorUidHashes.builder()
                        .setWriterUidHash(WRITER_UID_HASH)
                        .setCommitterUidHash(COMMITTER_UID_HASH)
                        .setGlobalCommitterUidHash(GLOBAL_COMMITTER_UID_HASH)
                        .build();
        env.fromData(1, 2).sinkTo(sinkWithCommitterAndGlobalCommitter(), uidHashes).name(NAME);

        final StreamGraph streamGraph = env.getStreamGraph();

        Assertions.assertThat(findWriter(streamGraph).getUserHash()).isEqualTo(WRITER_UID_HASH);
        Assertions.assertThat(findCommitter(streamGraph).getUserHash())
                .isEqualTo(COMMITTER_UID_HASH);
        Assertions.assertThat(findGlobalCommitter(streamGraph).getUserHash())
                .isEqualTo(GLOBAL_COMMITTER_UID_HASH);
    }

    @Test
    void testSettingOperatorUids() {
        final String sinkUid = "f6b178ce445dc3ffaa06bad27a51fead";
        final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.fromData(1, 2).sinkTo(sinkWithCommitterAndGlobalCommitter()).name(NAME).uid(sinkUid);

        final StreamGraph streamGraph = env.getStreamGraph();

        // The writer takes the sink uid verbatim; the other operators derive theirs from it.
        Assertions.assertThat(findWriter(streamGraph).getTransformationUID()).isEqualTo(sinkUid);
        Assertions.assertThat(findCommitter(streamGraph).getTransformationUID())
                .isEqualTo(String.format("Sink Committer: %s", sinkUid));
        Assertions.assertThat(findGlobalCommitter(streamGraph).getTransformationUID())
                .isEqualTo(String.format("Sink %s Global Committer", sinkUid));
    }

    @Test
    void testSettingOperatorNames() {
        final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.fromData(1, 2).sinkTo(sinkWithCommitterAndGlobalCommitter()).name(NAME);

        final StreamGraph streamGraph = env.getStreamGraph();

        Assertions.assertThat(findWriter(streamGraph).getOperatorName())
                .isEqualTo(String.format("%s: Writer", NAME));
        Assertions.assertThat(findCommitter(streamGraph).getOperatorName())
                .isEqualTo(String.format("%s: Committer", NAME));
        Assertions.assertThat(findGlobalCommitter(streamGraph).getOperatorName())
                .isEqualTo(String.format("%s: Global Committer", NAME));
    }

    @EnumSource(RuntimeExecutionMode.class)
    @ParameterizedTest
    void generateWriterTopology(RuntimeExecutionMode runtimeExecutionMode) {
        final StreamGraph streamGraph = buildGraph(simpleSink(), runtimeExecutionMode);

        final StreamNode sourceNode = findNodeName(streamGraph, name -> name.contains("Source"));
        final StreamNode writerNode = findWriter(streamGraph);

        // Exactly two nodes are expected: the source and the writer.
        Assertions.assertThat(streamGraph.getStreamNodes()).hasSize(2);

        validateTopology(
                sourceNode,
                IntSerializer.class,
                writerNode,
                SinkWriterOperatorFactory.class,
                PARALLELISM,
                -1);
    }

    @EnumSource(RuntimeExecutionMode.class)
    @ParameterizedTest
    void generateWriterCommitterTopology(RuntimeExecutionMode runtimeExecutionMode) {
        final StreamGraph streamGraph = buildGraph(sinkWithCommitter(), runtimeExecutionMode);

        final StreamNode sourceNode = findNodeName(streamGraph, name -> name.contains("Source"));
        final StreamNode writerNode = findWriter(streamGraph);

        validateTopology(
                sourceNode,
                IntSerializer.class,
                writerNode,
                SinkWriterOperatorFactory.class,
                PARALLELISM,
                -1);

        final StreamNode committerNode = findCommitter(streamGraph);

        // Source, writer and committer.
        Assertions.assertThat(streamGraph.getStreamNodes()).hasSize(3);
        assertNoUnalignedOutput(writerNode);

        validateTopology(
                writerNode,
                SimpleVersionedSerializerTypeSerializerProxy.class,
                committerNode,
                CommitterOperatorFactory.class,
                PARALLELISM,
                -1);
    }

    @ParameterizedTest
    @CsvSource({"STREAMING, true", "STREAMING, false", "BATCH, true", "BATCH, false"})
    void testParallelismConfigured(
            RuntimeExecutionMode runtimeExecutionMode, boolean setParallelism) {
        final StreamGraph streamGraph =
                buildGraph(sinkWithCommitter(), runtimeExecutionMode, setParallelism);

        Assertions.assertThat(findWriter(streamGraph).isParallelismConfigured())
                .isEqualTo(setParallelism);
        Assertions.assertThat(findCommitter(streamGraph).isParallelismConfigured())
                .isEqualTo(setParallelism);
    }

    @EnumSource(RuntimeExecutionMode.class)
    @ParameterizedTest
    void throwExceptionWithoutSettingUid(RuntimeExecutionMode runtimeExecutionMode) {
        final StreamExecutionEnvironment env = createEnvironment(runtimeExecutionMode);
        env.getConfig().disableAutoGeneratedUIDs();

        // The source gets an explicit uid so that the expected failure can only be caused by the
        // uid-less sink; otherwise the source alone would already trigger the exception.
        sinkTo(env.fromData(1, 2).uid("source"), simpleSink());

        Assertions.assertThatThrownBy(env::getStreamGraph)
                .isInstanceOf(IllegalStateException.class);
    }

    @Test
    void disableOperatorChain() {
        final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        sinkTo(env.fromData(1, 2), sinkWithCommitter()).name(NAME).disableChaining();

        final StreamGraph streamGraph = env.getStreamGraph();

        Assertions.assertThat(findWriter(streamGraph).getOperatorFactory().getChainingStrategy())
                .isEqualTo(ChainingStrategy.NEVER);
        Assertions.assertThat(
                        findCommitter(streamGraph).getOperatorFactory().getChainingStrategy())
                .isEqualTo(ChainingStrategy.NEVER);
    }

    /**
     * Verifies that {@code src} is connected to {@code dest} and that {@code dest} is configured
     * as produced by {@link #setSinkProperty}.
     *
     * @param src upstream node whose first output edge must lead to {@code dest}
     * @param srcOutSerializerClass expected serializer class on the src -> dest edge
     * @param dest downstream node under test
     * @param operatorFactoryClass expected operator factory class of {@code dest}
     * @param expectedParallelism expected parallelism of {@code dest}
     * @param expectedMaxParallelism expected max parallelism of {@code dest}
     */
    void validateTopology(
            StreamNode src,
            Class<?> srcOutSerializerClass,
            StreamNode dest,
            Class<? extends StreamOperatorFactory> operatorFactoryClass,
            int expectedParallelism,
            int expectedMaxParallelism) {
        // src -> dest, as seen from the source node.
        final StreamEdge srcOutEdge = src.getOutEdges().get(0);
        Assertions.assertThat(srcOutEdge.getTargetId()).isEqualTo(dest.getId());
        Assertions.assertThat(src.getTypeSerializerOut()).isInstanceOf(srcOutSerializerClass);

        // src -> dest, as seen from the destination node. Checking the *source* id here; the
        // target id of an in-edge of dest is trivially dest itself.
        final StreamEdge destInEdge = dest.getInEdges().get(0);
        Assertions.assertThat(destInEdge.getSourceId()).isEqualTo(src.getId());
        Assertions.assertThat(dest.getTypeSerializersIn()[0]).isInstanceOf(srcOutSerializerClass);

        // The two operators must be distinguishable by name and uid.
        Assertions.assertThat(dest.getOperatorName()).isNotEqualTo(src.getOperatorName());
        Assertions.assertThat(dest.getTransformationUID())
                .isNotEqualTo(src.getTransformationUID());

        Assertions.assertThat(dest.getOperatorFactory()).isInstanceOf(operatorFactoryClass);
        Assertions.assertThat(dest.getParallelism()).isEqualTo(expectedParallelism);
        Assertions.assertThat(dest.getMaxParallelism()).isEqualTo(expectedMaxParallelism);
        Assertions.assertThat(dest.getOperatorFactory().getChainingStrategy())
                .isEqualTo(ChainingStrategy.ALWAYS);
        Assertions.assertThat(dest.getSlotSharingGroup()).isEqualTo(SLOT_SHARE_GROUP);
    }

    /** Builds the stream graph for {@code sink} with an explicitly configured sink parallelism. */
    StreamGraph buildGraph(Sink<Integer> sink, RuntimeExecutionMode runtimeExecutionMode) {
        return buildGraph(sink, runtimeExecutionMode, true);
    }

    /**
     * Builds the stream graph of {@code fromData(1, 2).rebalance().sinkTo(sink)} in the given
     * runtime mode.
     *
     * @param setParallelism whether the sink parallelism is explicitly set to {@link #PARALLELISM}
     */
    StreamGraph buildGraph(
            Sink<Integer> sink, RuntimeExecutionMode runtimeExecutionMode, boolean setParallelism) {
        final StreamExecutionEnvironment env = createEnvironment(runtimeExecutionMode);

        setSinkProperty(sinkTo(env.fromData(1, 2).rebalance(), sink), setParallelism);

        // Also render the execution plan before building the graph; the result is ignored.
        // NOTE(review): relies on getExecutionPlan() leaving the transformations in place.
        env.getExecutionPlan();
        return env.getStreamGraph();
    }

    /** Creates an execution environment configured with the given runtime mode. */
    private StreamExecutionEnvironment createEnvironment(RuntimeExecutionMode runtimeExecutionMode) {
        final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        final Configuration config = new Configuration();
        config.set(ExecutionOptions.RUNTIME_MODE, runtimeExecutionMode);
        env.configure(config, getClass().getClassLoader());
        return env;
    }

    /** Applies the common name, uid, slot sharing group and (optionally) parallelism. */
    private void setSinkProperty(DataStreamSink<Integer> dataStreamSink, boolean setParallelism) {
        dataStreamSink.name(NAME);
        dataStreamSink.uid(UID);
        if (setParallelism) {
            dataStreamSink.setParallelism(PARALLELISM);
        }
        dataStreamSink.slotSharingGroup(SLOT_SHARE_GROUP);
    }

    /**
     * Returns the first node whose operator name matches {@code predicate}.
     *
     * @throws IllegalStateException if no node matches
     */
    StreamNode findNodeName(StreamGraph streamGraph, Predicate<String> predicate) {
        return streamGraph.getStreamNodes().stream()
                .filter(node -> predicate.test(node.getOperatorName()))
                .findFirst()
                .orElseThrow(() -> new IllegalStateException("Can not find the node"));
    }

    /** Returns the writer node (name contains "Writer" but not "Committer"). */
    StreamNode findWriter(StreamGraph streamGraph) {
        return findNodeName(
                streamGraph, name -> name.contains("Writer") && !name.contains("Committer"));
    }

    /** Returns the committer node, excluding the global committer. */
    StreamNode findCommitter(StreamGraph streamGraph) {
        return findNodeName(
                streamGraph,
                name -> name.contains("Committer") && !name.contains("Global Committer"));
    }

    /** Returns the global committer node. */
    StreamNode findGlobalCommitter(StreamGraph streamGraph) {
        return findNodeName(streamGraph, name -> name.contains("Global Committer"));
    }
}
