/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.operators;

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.functions.CoGroupFunction;
import org.apache.flink.api.common.functions.Function;
import org.apache.flink.api.common.functions.RichCoGroupFunction;
import org.apache.flink.api.common.typeutils.TypePairComparatorFactory;
import org.apache.flink.core.testutils.CheckedThread;
import org.apache.flink.core.testutils.OneShotLatch;
import org.apache.flink.runtime.operators.CoGroupDriver;
import org.apache.flink.runtime.operators.CoGroupTaskExternalITCase;
import org.apache.flink.runtime.operators.Driver;
import org.apache.flink.runtime.operators.DriverStrategy;
import org.apache.flink.runtime.operators.testutils.DelayingInfinitiveInputIterator;
import org.apache.flink.runtime.operators.testutils.DriverTestBase;
import org.apache.flink.runtime.operators.testutils.ExpectedTestException;
import org.apache.flink.runtime.operators.testutils.TaskCancelThread;
import org.apache.flink.runtime.operators.testutils.UniformRecordGenerator;
import org.apache.flink.runtime.testutils.recordutils.RecordComparator;
import org.apache.flink.runtime.testutils.recordutils.RecordPairComparatorFactory;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.Record;
import org.apache.flink.util.Collector;
import org.assertj.core.api.AbstractIntegerAssert;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.TestTemplate;

class CoGroupTaskTest
extends DriverTestBase<CoGroupFunction<Record, Record, Record>> {
    private static final long SORT_MEM = 0x300000L;
    private final RecordComparator comparator1 = new RecordComparator(new int[]{0}, new Class[]{IntValue.class});
    private final RecordComparator comparator2 = new RecordComparator(new int[]{0}, new Class[]{IntValue.class});
    private final DriverTestBase.CountingOutputCollector output = new DriverTestBase.CountingOutputCollector();

    CoGroupTaskTest(ExecutionConfig config) {
        super(config, 0L, 2, 0x300000L);
    }

    @TestTemplate
    void testSortBoth1CoGroupTask() throws Exception {
        int keyCnt1 = 100;
        int valCnt1 = 2;
        int keyCnt2 = 200;
        int valCnt2 = 1;
        int expCnt = valCnt1 * valCnt2 * Math.min(keyCnt1, keyCnt2) + (keyCnt1 > keyCnt2 ? (keyCnt1 - keyCnt2) * valCnt1 : (keyCnt2 - keyCnt1) * valCnt2);
        this.setOutput(this.output);
        this.addDriverComparator(this.comparator1);
        this.addDriverComparator(this.comparator2);
        this.getTaskConfig().setDriverPairComparator((TypePairComparatorFactory)RecordPairComparatorFactory.get());
        this.getTaskConfig().setDriverStrategy(DriverStrategy.CO_GROUP);
        CoGroupDriver testTask = new CoGroupDriver();
        this.addInputSorted(new UniformRecordGenerator(keyCnt1, valCnt1, false), this.comparator1.duplicate());
        this.addInputSorted(new UniformRecordGenerator(keyCnt2, valCnt2, false), this.comparator2.duplicate());
        this.testDriver((Driver)testTask, CoGroupTaskExternalITCase.MockCoGroupStub.class);
        ((AbstractIntegerAssert)Assertions.assertThat((int)this.output.getNumberOfRecords()).withFailMessage("Wrong result set size.", new Object[0])).isEqualTo(expCnt);
    }

    @TestTemplate
    void testSortBoth2CoGroupTask() throws Exception {
        int keyCnt1 = 200;
        int valCnt1 = 2;
        int keyCnt2 = 200;
        int valCnt2 = 4;
        int expCnt = valCnt1 * valCnt2 * Math.min(keyCnt1, keyCnt2) + (keyCnt1 > keyCnt2 ? (keyCnt1 - keyCnt2) * valCnt1 : (keyCnt2 - keyCnt1) * valCnt2);
        this.setOutput(this.output);
        this.addDriverComparator(this.comparator1);
        this.addDriverComparator(this.comparator2);
        this.getTaskConfig().setDriverPairComparator((TypePairComparatorFactory)RecordPairComparatorFactory.get());
        this.getTaskConfig().setDriverStrategy(DriverStrategy.CO_GROUP);
        CoGroupDriver testTask = new CoGroupDriver();
        this.addInputSorted(new UniformRecordGenerator(keyCnt1, valCnt1, false), this.comparator1.duplicate());
        this.addInputSorted(new UniformRecordGenerator(keyCnt2, valCnt2, false), this.comparator2.duplicate());
        this.testDriver((Driver)testTask, CoGroupTaskExternalITCase.MockCoGroupStub.class);
        ((AbstractIntegerAssert)Assertions.assertThat((int)this.output.getNumberOfRecords()).withFailMessage("Wrong result set size.", new Object[0])).isEqualTo(expCnt);
    }

    @TestTemplate
    void testSortFirstCoGroupTask() throws Exception {
        int keyCnt1 = 200;
        int valCnt1 = 2;
        int keyCnt2 = 200;
        int valCnt2 = 4;
        int expCnt = valCnt1 * valCnt2 * Math.min(keyCnt1, keyCnt2) + (keyCnt1 > keyCnt2 ? (keyCnt1 - keyCnt2) * valCnt1 : (keyCnt2 - keyCnt1) * valCnt2);
        this.setOutput(this.output);
        this.addDriverComparator(this.comparator1);
        this.addDriverComparator(this.comparator2);
        this.getTaskConfig().setDriverPairComparator((TypePairComparatorFactory)RecordPairComparatorFactory.get());
        this.getTaskConfig().setDriverStrategy(DriverStrategy.CO_GROUP);
        CoGroupDriver testTask = new CoGroupDriver();
        this.addInputSorted(new UniformRecordGenerator(keyCnt1, valCnt1, false), this.comparator1.duplicate());
        this.addInput(new UniformRecordGenerator(keyCnt2, valCnt2, true));
        this.testDriver((Driver)testTask, CoGroupTaskExternalITCase.MockCoGroupStub.class);
        ((AbstractIntegerAssert)Assertions.assertThat((int)this.output.getNumberOfRecords()).withFailMessage("Wrong result set size.", new Object[0])).isEqualTo(expCnt);
    }

    @TestTemplate
    void testSortSecondCoGroupTask() throws Exception {
        int keyCnt1 = 200;
        int valCnt1 = 2;
        int keyCnt2 = 200;
        int valCnt2 = 4;
        int expCnt = valCnt1 * valCnt2 * Math.min(keyCnt1, keyCnt2) + (keyCnt1 > keyCnt2 ? (keyCnt1 - keyCnt2) * valCnt1 : (keyCnt2 - keyCnt1) * valCnt2);
        this.setOutput(this.output);
        this.addDriverComparator(this.comparator1);
        this.addDriverComparator(this.comparator2);
        this.getTaskConfig().setDriverPairComparator((TypePairComparatorFactory)RecordPairComparatorFactory.get());
        this.getTaskConfig().setDriverStrategy(DriverStrategy.CO_GROUP);
        CoGroupDriver testTask = new CoGroupDriver();
        this.addInput(new UniformRecordGenerator(keyCnt1, valCnt1, true));
        this.addInputSorted(new UniformRecordGenerator(keyCnt2, valCnt2, false), this.comparator2.duplicate());
        this.testDriver((Driver)testTask, CoGroupTaskExternalITCase.MockCoGroupStub.class);
        ((AbstractIntegerAssert)Assertions.assertThat((int)this.output.getNumberOfRecords()).withFailMessage("Wrong result set size.", new Object[0])).isEqualTo(expCnt);
    }

    @TestTemplate
    void testMergeCoGroupTask() throws Exception {
        int keyCnt1 = 200;
        int valCnt1 = 2;
        int keyCnt2 = 200;
        int valCnt2 = 4;
        int expCnt = valCnt1 * valCnt2 * Math.min(keyCnt1, keyCnt2) + (keyCnt1 > keyCnt2 ? (keyCnt1 - keyCnt2) * valCnt1 : (keyCnt2 - keyCnt1) * valCnt2);
        this.setOutput(this.output);
        this.addInput(new UniformRecordGenerator(keyCnt1, valCnt1, true));
        this.addInput(new UniformRecordGenerator(keyCnt2, valCnt2, true));
        this.addDriverComparator(this.comparator1);
        this.addDriverComparator(this.comparator2);
        this.getTaskConfig().setDriverPairComparator((TypePairComparatorFactory)RecordPairComparatorFactory.get());
        this.getTaskConfig().setDriverStrategy(DriverStrategy.CO_GROUP);
        CoGroupDriver testTask = new CoGroupDriver();
        this.testDriver((Driver)testTask, CoGroupTaskExternalITCase.MockCoGroupStub.class);
        ((AbstractIntegerAssert)Assertions.assertThat((int)this.output.getNumberOfRecords()).withFailMessage("Wrong result set size.", new Object[0])).isEqualTo(expCnt);
    }

    @TestTemplate
    void testFailingSortCoGroupTask() {
        int keyCnt1 = 100;
        int valCnt1 = 2;
        int keyCnt2 = 200;
        int valCnt2 = 1;
        this.setOutput(this.output);
        this.addInput(new UniformRecordGenerator(keyCnt1, valCnt1, true));
        this.addInput(new UniformRecordGenerator(keyCnt2, valCnt2, true));
        this.addDriverComparator(this.comparator1);
        this.addDriverComparator(this.comparator2);
        this.getTaskConfig().setDriverPairComparator((TypePairComparatorFactory)RecordPairComparatorFactory.get());
        this.getTaskConfig().setDriverStrategy(DriverStrategy.CO_GROUP);
        CoGroupDriver testTask = new CoGroupDriver();
        Assertions.assertThatThrownBy(() -> this.testDriver((Driver)testTask, MockFailingCoGroupStub.class)).isInstanceOf(ExpectedTestException.class);
    }

    @TestTemplate
    void testCancelCoGroupTaskWhileSorting1() throws Exception {
        int keyCnt = 10;
        int valCnt = 2;
        this.setOutput(this.output);
        this.addDriverComparator(this.comparator1);
        this.addDriverComparator(this.comparator2);
        this.getTaskConfig().setDriverPairComparator((TypePairComparatorFactory)RecordPairComparatorFactory.get());
        this.getTaskConfig().setDriverStrategy(DriverStrategy.CO_GROUP);
        final CoGroupDriver testTask = new CoGroupDriver();
        this.addInputSorted(new DelayingInfinitiveInputIterator(1000), this.comparator1.duplicate());
        this.addInput(new UniformRecordGenerator(keyCnt, valCnt, true));
        CheckedThread taskRunner = new CheckedThread(){

            public void go() throws Exception {
                CoGroupTaskTest.this.testDriver((Driver)testTask, CoGroupTaskExternalITCase.MockCoGroupStub.class);
            }
        };
        taskRunner.start();
        TaskCancelThread tct = new TaskCancelThread(1, (Thread)taskRunner, this);
        tct.start();
        tct.join();
        taskRunner.sync();
    }

    @TestTemplate
    void testCancelCoGroupTaskWhileSorting2() throws Exception {
        int keyCnt = 10;
        int valCnt = 2;
        this.setOutput(this.output);
        this.addDriverComparator(this.comparator1);
        this.addDriverComparator(this.comparator2);
        this.getTaskConfig().setDriverPairComparator((TypePairComparatorFactory)RecordPairComparatorFactory.get());
        this.getTaskConfig().setDriverStrategy(DriverStrategy.CO_GROUP);
        final CoGroupDriver testTask = new CoGroupDriver();
        this.addInput(new UniformRecordGenerator(keyCnt, valCnt, true));
        this.addInputSorted(new DelayingInfinitiveInputIterator(1000), this.comparator2.duplicate());
        CheckedThread taskRunner = new CheckedThread(){

            public void go() throws Exception {
                CoGroupTaskTest.this.testDriver((Driver)testTask, CoGroupTaskExternalITCase.MockCoGroupStub.class);
            }
        };
        taskRunner.start();
        TaskCancelThread tct = new TaskCancelThread(1, (Thread)taskRunner, this);
        tct.start();
        tct.join();
        taskRunner.sync();
    }

    @TestTemplate
    void testCancelCoGroupTaskWhileCoGrouping() throws Exception {
        int keyCnt = 100;
        int valCnt = 5;
        this.setOutput(this.output);
        this.addDriverComparator(this.comparator1);
        this.addDriverComparator(this.comparator2);
        this.getTaskConfig().setDriverPairComparator((TypePairComparatorFactory)RecordPairComparatorFactory.get());
        this.getTaskConfig().setDriverStrategy(DriverStrategy.CO_GROUP);
        final CoGroupDriver testTask = new CoGroupDriver();
        this.addInput(new UniformRecordGenerator(keyCnt, valCnt, true));
        this.addInput(new UniformRecordGenerator(keyCnt, valCnt, true));
        final OneShotLatch delayCoGroupProcessingLatch = new OneShotLatch();
        CheckedThread taskRunner = new CheckedThread(){

            public void go() throws Exception {
                CoGroupTaskTest.this.testDriver((Driver)testTask, (Function)new MockDelayingCoGroupStub(delayCoGroupProcessingLatch));
            }
        };
        taskRunner.start();
        TaskCancelThread tct = new TaskCancelThread(1, (Thread)taskRunner, this);
        tct.start();
        tct.join();
        delayCoGroupProcessingLatch.trigger();
        taskRunner.sync();
    }

    public static final class MockDelayingCoGroupStub
    extends RichCoGroupFunction<Record, Record, Record> {
        private static final long serialVersionUID = 1L;
        private final OneShotLatch delayCoGroupProcessingLatch;

        public MockDelayingCoGroupStub(OneShotLatch delayCoGroupProcessingLatch) {
            this.delayCoGroupProcessingLatch = delayCoGroupProcessingLatch;
        }

        public void coGroup(Iterable<Record> records1, Iterable<Record> records2, Collector<Record> out) throws InterruptedException {
            this.delayCoGroupProcessingLatch.await();
        }
    }

    public static class MockFailingCoGroupStub
    extends RichCoGroupFunction<Record, Record, Record> {
        private static final long serialVersionUID = 1L;
        private int cnt = 0;

        public void coGroup(Iterable<Record> records1, Iterable<Record> records2, Collector<Record> out) {
            int val1Cnt = 0;
            for (Record r : records1) {
                ++val1Cnt;
            }
            for (Record record2 : records2) {
                if (val1Cnt == 0) {
                    if (++this.cnt >= 10) {
                        throw new ExpectedTestException();
                    }
                    out.collect((Object)record2);
                    continue;
                }
                for (int i = 0; i < val1Cnt; ++i) {
                    if (++this.cnt >= 10) {
                        throw new ExpectedTestException();
                    }
                    out.collect((Object)record2);
                }
            }
        }
    }
}

