/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.operators.chaining;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import org.apache.flink.api.common.functions.GroupCombineFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.operators.util.UserCodeClassWrapper;
import org.apache.flink.api.common.operators.util.UserCodeWrapper;
import org.apache.flink.api.common.typeutils.TypeComparatorFactory;
import org.apache.flink.api.common.typeutils.TypeSerializerFactory;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
import org.apache.flink.runtime.operators.BatchTask;
import org.apache.flink.runtime.operators.DataSourceTask;
import org.apache.flink.runtime.operators.DataSourceTaskTest;
import org.apache.flink.runtime.operators.DriverStrategy;
import org.apache.flink.runtime.operators.FlatMapDriver;
import org.apache.flink.runtime.operators.FlatMapTaskTest;
import org.apache.flink.runtime.operators.ReduceTaskTest;
import org.apache.flink.runtime.operators.chaining.ChainedFlatMapDriver;
import org.apache.flink.runtime.operators.chaining.SynchronousChainedCombineDriver;
import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
import org.apache.flink.runtime.operators.testutils.TaskTestBase;
import org.apache.flink.runtime.operators.testutils.UniformRecordGenerator;
import org.apache.flink.runtime.operators.util.TaskConfig;
import org.apache.flink.runtime.testutils.recordutils.RecordComparatorFactory;
import org.apache.flink.runtime.testutils.recordutils.RecordSerializerFactory;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.Record;
import org.apache.flink.util.Collector;
import org.assertj.core.api.AbstractBooleanAssert;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

class ChainTaskTest
extends TaskTestBase {
    private static final int MEMORY_MANAGER_SIZE = 0x300000;
    private static final int NETWORK_BUFFER_SIZE = 1024;
    private final List<Record> outList = new ArrayList<Record>();
    private final RecordComparatorFactory compFact = new RecordComparatorFactory(new int[]{0}, new Class[]{IntValue.class}, new boolean[]{true});
    private final RecordSerializerFactory serFact = RecordSerializerFactory.get();

    ChainTaskTest() {
    }

    @Test
    void testMapTask() {
        int keyCnt = 100;
        int valCnt = 20;
        double memoryFraction = 1.0;
        try {
            this.initEnvironment(0x300000L, 1024);
            this.addInput(new UniformRecordGenerator(100, 20, false), 0);
            this.addOutput(this.outList);
            TaskConfig combineConfig = new TaskConfig(new Configuration());
            combineConfig.addInputToGroup(0);
            combineConfig.setInputSerializer((TypeSerializerFactory)this.serFact, 0);
            combineConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);
            combineConfig.setOutputSerializer((TypeSerializerFactory)this.serFact);
            combineConfig.setDriverStrategy(DriverStrategy.SORTED_GROUP_COMBINE);
            combineConfig.setDriverComparator((TypeComparatorFactory)this.compFact, 0);
            combineConfig.setDriverComparator((TypeComparatorFactory)this.compFact, 1);
            combineConfig.setRelativeMemoryDriver(1.0);
            combineConfig.setStubWrapper((UserCodeWrapper)new UserCodeClassWrapper(ReduceTaskTest.MockCombiningReduceStub.class));
            this.getTaskConfig().addChainedTask(SynchronousChainedCombineDriver.class, combineConfig, "combine");
            this.registerTask(FlatMapDriver.class, FlatMapTaskTest.MockMapStub.class);
            BatchTask testTask = new BatchTask((Environment)this.mockEnv);
            try {
                testTask.invoke();
            }
            catch (Exception e) {
                e.printStackTrace();
                Assertions.fail((String)"Invoke method caused exception.");
            }
            Assertions.assertThat(this.outList).hasSize(100);
        }
        catch (Exception e) {
            e.printStackTrace();
            Assertions.fail((String)e.getMessage());
        }
    }

    @Test
    void testFailingMapTask() {
        int keyCnt = 100;
        int valCnt = 20;
        long memorySize = 0x300000L;
        int bufferSize = 1038336;
        double memoryFraction = 1.0;
        try {
            this.initEnvironment(0x300000L, 1038336);
            this.addInput(new UniformRecordGenerator(keyCnt, valCnt, false), 0);
            this.addOutput(this.outList);
            TaskConfig combineConfig = new TaskConfig(new Configuration());
            combineConfig.addInputToGroup(0);
            combineConfig.setInputSerializer((TypeSerializerFactory)this.serFact, 0);
            combineConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);
            combineConfig.setOutputSerializer((TypeSerializerFactory)this.serFact);
            combineConfig.setDriverStrategy(DriverStrategy.SORTED_GROUP_COMBINE);
            combineConfig.setDriverComparator((TypeComparatorFactory)this.compFact, 0);
            combineConfig.setDriverComparator((TypeComparatorFactory)this.compFact, 1);
            combineConfig.setRelativeMemoryDriver(1.0);
            combineConfig.setStubWrapper((UserCodeWrapper)new UserCodeClassWrapper(MockFailingCombineStub.class));
            this.getTaskConfig().addChainedTask(SynchronousChainedCombineDriver.class, combineConfig, "combine");
            this.registerTask(FlatMapDriver.class, FlatMapTaskTest.MockMapStub.class);
            BatchTask testTask = new BatchTask((Environment)this.mockEnv);
            boolean stubFailed = false;
            try {
                testTask.invoke();
            }
            catch (Exception e) {
                stubFailed = true;
            }
            ((AbstractBooleanAssert)Assertions.assertThat((boolean)stubFailed).withFailMessage("Function exception was not forwarded.", new Object[0])).isTrue();
        }
        catch (Exception e) {
            e.printStackTrace();
            Assertions.fail((String)e.getMessage());
        }
    }

    @Test
    void testBatchTaskOutputInCloseMethod() {
        int numChainedTasks = 10;
        int keyCnt = 100;
        int valCnt = 10;
        try {
            this.initEnvironment(0x300000L, 1024);
            this.addInput(new UniformRecordGenerator(100, 10, false), 0);
            this.addOutput(this.outList);
            this.registerTask(FlatMapDriver.class, FlatMapTaskTest.MockMapStub.class);
            for (int i = 0; i < 10; ++i) {
                TaskConfig taskConfig = new TaskConfig(new Configuration());
                taskConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);
                taskConfig.setOutputSerializer((TypeSerializerFactory)this.serFact);
                taskConfig.setStubWrapper((UserCodeWrapper)new UserCodeClassWrapper(MockDuplicateLastValueMapFunction.class));
                this.getTaskConfig().addChainedTask(ChainedFlatMapDriver.class, taskConfig, "chained-" + i);
            }
            BatchTask testTask = new BatchTask((Environment)this.mockEnv);
            testTask.invoke();
            Assertions.assertThat(this.outList).hasSize(1010);
        }
        catch (Exception e) {
            e.printStackTrace();
            Assertions.fail((String)e.getMessage());
        }
    }

    @Test
    void testDataSourceTaskOutputInCloseMethod() throws IOException {
        int numChainedTasks = 10;
        int keyCnt = 100;
        int valCnt = 10;
        File tempTestFile = new File(this.tempFolder.toFile(), UUID.randomUUID().toString());
        DataSourceTaskTest.InputFilePreparator.prepareInputFile(new UniformRecordGenerator(100, 10, false), tempTestFile, true);
        this.initEnvironment(0x300000L, 1024);
        this.addOutput(this.outList);
        DataSourceTask testTask = new DataSourceTask((Environment)this.mockEnv);
        this.registerFileInputTask((AbstractInvokable)testTask, DataSourceTaskTest.MockInputFormat.class, tempTestFile.toURI().toString(), "\n");
        for (int i = 0; i < 10; ++i) {
            TaskConfig taskConfig = new TaskConfig(new Configuration());
            taskConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);
            taskConfig.setOutputSerializer((TypeSerializerFactory)this.serFact);
            taskConfig.setStubWrapper((UserCodeWrapper)new UserCodeClassWrapper(MockDuplicateLastValueMapFunction.class));
            this.getTaskConfig().addChainedTask(ChainedFlatMapDriver.class, taskConfig, "chained-" + i);
        }
        try {
            testTask.invoke();
            Assertions.assertThat(this.outList).hasSize(1010);
        }
        catch (Exception e) {
            e.printStackTrace();
            Assertions.fail((String)"Invoke method caused exception.");
        }
    }

    public static class MockDuplicateLastValueMapFunction<T>
    extends RichFlatMapFunction<T, T> {
        private boolean closed = false;
        private transient T value;
        private transient Collector<T> out;

        public void flatMap(T value, Collector<T> out) throws Exception {
            if (this.closed) {
                throw new IllegalStateException("Task is already closed.");
            }
            this.value = value;
            this.out = out;
            out.collect(value);
        }

        public void close() throws Exception {
            this.closed = true;
            this.out.collect(this.value);
        }
    }

    public static final class MockFailingCombineStub
    implements GroupReduceFunction<Record, Record>,
    GroupCombineFunction<Record, Record> {
        private static final long serialVersionUID = 1L;
        private int cnt = 0;

        public void reduce(Iterable<Record> records, Collector<Record> out) throws Exception {
            if (++this.cnt >= 5) {
                throw new RuntimeException("Expected Test Exception");
            }
            for (Record r : records) {
                out.collect((Object)r);
            }
        }

        public void combine(Iterable<Record> values, Collector<Record> out) throws Exception {
            this.reduce(values, out);
        }
    }
}

