/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.runtime.operators.sink;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.function.ToIntFunction;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.core.execution.SavepointFormatType;
import org.apache.flink.runtime.checkpoint.CheckpointType;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.SavepointType;
import org.apache.flink.runtime.checkpoint.SnapshotType;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.runtime.generated.GeneratedRecordEqualiser;
import org.apache.flink.table.runtime.generated.HashFunction;
import org.apache.flink.table.runtime.generated.RecordEqualiser;
import org.apache.flink.table.runtime.keyselector.RowDataKeySelector;
import org.apache.flink.table.runtime.operators.sink.SinkUpsertMaterializer;
import org.apache.flink.table.runtime.operators.sink.SinkUpsertMaterializerStateBackend;
import org.apache.flink.table.runtime.util.RowDataHarnessAssertor;
import org.apache.flink.table.runtime.util.StateConfigUtil;
import org.apache.flink.table.runtime.util.StreamRecordUtils;
import org.apache.flink.table.types.logical.BigIntType;
import org.apache.flink.table.types.logical.IntType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.VarCharType;
import org.apache.flink.table.utils.HandwrittenSelectorUtil;
import org.apache.flink.types.RowKind;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
 * Rescaling and state-backend-migration tests for {@link SinkUpsertMaterializer}.
 *
 * <p>Each test runs the materializer at one parallelism, snapshots it, restores the snapshot at a
 * second parallelism (optionally on a different state backend), snapshots again and restores at
 * the first parallelism, asserting the emitted changelog after every input record.
 */
@RunWith(Parameterized.class)
public class SinkUpsertMaterializerRescalingTest {

    /** State backend the scenario starts on; injected by the {@link Parameterized} runner. */
    @Parameterized.Parameter
    public SinkUpsertMaterializerStateBackend backend;

    private static final StateTtlConfig TTL_CONFIG = StateConfigUtil.createTtlConfig(1000L);

    /** Schema of every test row: (BIGINT, INT, VARCHAR). */
    private static final LogicalType[] LOGICAL_TYPES =
            new LogicalType[] {new BigIntType(), new IntType(), new VarCharType()};

    /** Keys rows by field 1 (the INT column). */
    private static final RowDataKeySelector KEY_SELECTOR =
            HandwrittenSelectorUtil.getRowDataSelector(new int[] {1}, LOGICAL_TYPES);

    private static final RowDataHarnessAssertor ASSERTOR =
            new RowDataHarnessAssertor(LOGICAL_TYPES);

    /** Full-row equaliser, backed by {@link TestRecordEqualiser}. */
    private static final GeneratedRecordEqualiser EQUALISER = new MyGeneratedRecordEqualiser();

    /** Upsert-key equaliser, backed by {@link TestUpsertKeyEqualiser} (row kind + field 0). */
    private static final GeneratedRecordEqualiser UPSERT_KEY_EQUALISER =
            new GeneratedRecordEqualiser("", "", new Object[0]) {
                @Override
                public RecordEqualiser newInstance(ClassLoader classLoader) {
                    return new TestUpsertKeyEqualiser();
                }
            };

    /** Runs every test once per {@link SinkUpsertMaterializerStateBackend} value. */
    @Parameterized.Parameters(name = "stateBackend={0}")
    public static Object[][] generateTestParameters() {
        List<Object[]> result = new ArrayList<>();
        for (SinkUpsertMaterializerStateBackend stateBackend :
                SinkUpsertMaterializerStateBackend.values()) {
            result.add(new Object[] {stateBackend});
        }
        return result.toArray(new Object[0][]);
    }

    @Test
    public void testScaleUpThenDown() throws Exception {
        testRescaleFromToFrom(10, 2, 3, backend, backend);
    }

    @Test
    public void testScaleDownThenUp() throws Exception {
        testRescaleFromToFrom(10, 3, 2, backend, backend);
    }

    @Test
    public void testRecovery() throws Exception {
        testRescaleFromToFrom(1, 1, 1, backend, backend);
    }

    @Test
    public void testForwardAndBackwardMigration() throws Exception {
        testRescaleFromToFrom(7, 3, 3, backend, getOtherBackend(backend));
    }

    @Test
    public void testScaleUpThenDownWithMigration() throws Exception {
        testRescaleFromToFrom(7, 1, 5, backend, getOtherBackend(backend));
    }

    @Test
    public void testScaleDownThenUpWithMigration() throws Exception {
        // Migrate away from the parameterized backend, like the other *WithMigration tests;
        // a hard-coded HEAP here would make the ROCKSDB run migrate ROCKSDB -> ROCKSDB.
        testRescaleFromToFrom(7, 5, 1, backend, getOtherBackend(backend));
    }

    /** Returns the backend the scenario migrates to: HEAP becomes ROCKSDB, anything else HEAP. */
    private static SinkUpsertMaterializerStateBackend getOtherBackend(
            SinkUpsertMaterializerStateBackend backend) {
        return backend == SinkUpsertMaterializerStateBackend.HEAP
                ? SinkUpsertMaterializerStateBackend.ROCKSDB
                : SinkUpsertMaterializerStateBackend.HEAP;
    }

    /**
     * Runs a fixed changelog scenario while rescaling {@code from -> to -> from}.
     *
     * <p>All records share key {@code 1} (field 1) and differ in field 0, so the materializer keeps
     * several rows for one key across every restore.
     *
     * @param maxParallelism max parallelism used by every harness
     * @param fromParallelism parallelism of the first and third phase
     * @param toParallelism parallelism of the second phase
     * @param fromBackend state backend of the first and third phase
     * @param toBackend state backend of the second phase; when it differs from {@code fromBackend}
     *     snapshots are taken as canonical savepoints instead of checkpoints
     * @throws Exception if a harness fails to process, snapshot, restore or close
     */
    private void testRescaleFromToFrom(
            int maxParallelism,
            int fromParallelism,
            int toParallelism,
            SinkUpsertMaterializerStateBackend fromBackend,
            SinkUpsertMaterializerStateBackend toBackend)
            throws Exception {
        // Mutable holder so the routing lambda always sees the parallelism of the current phase.
        int[] currentParallelismRef = new int[] {fromParallelism};
        // A backend switch is snapshotted as a canonical savepoint (presumably because checkpoint
        // formats are backend-specific); same-backend restores use a regular checkpoint.
        boolean useSavepoint = fromBackend != toBackend;

        // Generic arrays cannot be created directly; slots are indexed by subtask.
        @SuppressWarnings("unchecked")
        OneInputStreamOperator<RowData, RowData>[] materializers =
                new OneInputStreamOperator[maxParallelism];
        @SuppressWarnings("unchecked")
        KeyedOneInputStreamOperatorTestHarness<RowData, RowData, RowData>[] harnesses =
                new KeyedOneInputStreamOperatorTestHarness[maxParallelism];

        // Sends a record to the subtask owning its key group and returns that subtask's index so
        // the caller can assert on the matching harness output.
        ToIntFunction<StreamRecord<RowData>> combinedHarnesses =
                r -> {
                    try {
                        int subtaskIndex =
                                KeyGroupRangeAssignment.assignKeyToParallelOperator(
                                        KEY_SELECTOR.getKey(r.getValue()),
                                        maxParallelism,
                                        currentParallelismRef[0]);
                        harnesses[subtaskIndex].processElement(r);
                        return subtaskIndex;
                    } catch (Exception e) {
                        throw new RuntimeException(e);
                    }
                };

        try {
            // Phase 1: fresh state at fromParallelism on fromBackend.
            initHarnessesAndMaterializers(
                    harnesses, materializers, fromBackend, maxParallelism, fromParallelism, null);

            int idx = combinedHarnesses.applyAsInt(StreamRecordUtils.insertRecord(1L, 1, "a1"));
            ASSERTOR.shouldEmit(
                    harnesses[idx], StreamRecordUtils.rowOfKind(RowKind.INSERT, 1L, 1, "a1"));
            idx = combinedHarnesses.applyAsInt(StreamRecordUtils.insertRecord(2L, 1, "a2"));
            ASSERTOR.shouldEmit(
                    harnesses[idx],
                    StreamRecordUtils.rowOfKind(RowKind.UPDATE_AFTER, 2L, 1, "a2"));

            List<OperatorSubtaskState> subtaskStates =
                    snapshotHarnesses(harnesses, fromParallelism, 1L, useSavepoint);

            // Phase 2: restore at toParallelism on toBackend. Old harnesses are closed first
            // so none leak and no stale subtask survives a scale-down.
            closeHarnesses(harnesses);
            currentParallelismRef[0] = toParallelism;
            initHarnessesAndMaterializers(
                    harnesses,
                    materializers,
                    toBackend,
                    maxParallelism,
                    toParallelism,
                    subtaskStates);

            idx = combinedHarnesses.applyAsInt(StreamRecordUtils.insertRecord(3L, 1, "a3"));
            ASSERTOR.shouldEmit(
                    harnesses[idx],
                    StreamRecordUtils.rowOfKind(RowKind.UPDATE_AFTER, 3L, 1, "a3"));
            idx = combinedHarnesses.applyAsInt(StreamRecordUtils.insertRecord(4L, 1, "a4"));
            ASSERTOR.shouldEmit(
                    harnesses[idx],
                    StreamRecordUtils.rowOfKind(RowKind.UPDATE_AFTER, 4L, 1, "a4"));

            subtaskStates = snapshotHarnesses(harnesses, toParallelism, 2L, useSavepoint);

            // Phase 3: restore back at fromParallelism on fromBackend.
            closeHarnesses(harnesses);
            currentParallelismRef[0] = fromParallelism;
            initHarnessesAndMaterializers(
                    harnesses,
                    materializers,
                    fromBackend,
                    maxParallelism,
                    fromParallelism,
                    subtaskStates);

            idx = combinedHarnesses.applyAsInt(StreamRecordUtils.deleteRecord(4L, 1, "a4"));
            ASSERTOR.shouldEmit(
                    harnesses[idx],
                    StreamRecordUtils.rowOfKind(RowKind.UPDATE_AFTER, 3L, 1, "a3"));
            idx = combinedHarnesses.applyAsInt(StreamRecordUtils.deleteRecord(2L, 1, "a2"));
            ASSERTOR.shouldEmitNothing(harnesses[idx]);
            idx = combinedHarnesses.applyAsInt(StreamRecordUtils.deleteRecord(3L, 1, "a3"));
            ASSERTOR.shouldEmit(
                    harnesses[idx],
                    StreamRecordUtils.rowOfKind(RowKind.UPDATE_AFTER, 1L, 1, "a1"));
            idx = combinedHarnesses.applyAsInt(StreamRecordUtils.deleteRecord(1L, 1, "a1"));
            ASSERTOR.shouldEmit(
                    harnesses[idx], StreamRecordUtils.rowOfKind(RowKind.DELETE, 1L, 1, "a1"));
            idx = combinedHarnesses.applyAsInt(StreamRecordUtils.insertRecord(4L, 1, "a4"));
            ASSERTOR.shouldEmit(
                    harnesses[idx], StreamRecordUtils.rowOfKind(RowKind.INSERT, 4L, 1, "a4"));

            // TTL processing time was set to 1 during init; moving it to 1002 goes past the
            // 1000 in TTL_CONFIG, after which deleting a4 is expected to emit nothing.
            Arrays.stream(harnesses)
                    .filter(Objects::nonNull)
                    .forEach(h -> h.setStateTtlProcessingTime(1002L));
            idx = combinedHarnesses.applyAsInt(StreamRecordUtils.deleteRecord(4L, 1, "a4"));
            ASSERTOR.shouldEmitNothing(harnesses[idx]);
        } finally {
            closeHarnesses(harnesses);
        }
    }

    /**
     * Creates, optionally restores, and opens one materializer and harness per subtask.
     *
     * @param harnesses destination array; slots {@code [0, parallelism)} are overwritten
     * @param materializers destination array; slots {@code [0, parallelism)} are overwritten
     * @param backend state backend to install on every new harness
     * @param maxParallelism max parallelism of the new harnesses
     * @param parallelism number of subtasks to create
     * @param subtaskStates snapshots of the previous run (one per old subtask), or {@code null} to
     *     start with empty state
     * @throws Exception if a harness cannot be created, restored or opened
     */
    private void initHarnessesAndMaterializers(
            KeyedOneInputStreamOperatorTestHarness<RowData, RowData, RowData>[] harnesses,
            OneInputStreamOperator<RowData, RowData>[] materializers,
            SinkUpsertMaterializerStateBackend backend,
            int maxParallelism,
            int parallelism,
            @Nullable List<OperatorSubtaskState> subtaskStates)
            throws Exception {
        // Merging the old subtasks' snapshots does not depend on the target subtask: do it once.
        OperatorSubtaskState mergedState =
                subtaskStates == null
                        ? null
                        : AbstractStreamOperatorTestHarness.repackageState(
                                subtaskStates.toArray(new OperatorSubtaskState[0]));
        RowType rowType = RowType.of(LOGICAL_TYPES);
        KeySelector<RowData, RowData> keySelector = KEY_SELECTOR;
        TypeInformation<RowData> keyType = KEY_SELECTOR.getProducedType();

        for (int i = 0; i < parallelism; ++i) {
            materializers[i] =
                    SinkUpsertMaterializer.create(
                            TTL_CONFIG, rowType, EQUALISER, UPSERT_KEY_EQUALISER, null);
            harnesses[i] =
                    new KeyedOneInputStreamOperatorTestHarness<>(
                            materializers[i],
                            keySelector,
                            keyType,
                            maxParallelism,
                            parallelism,
                            i);
            harnesses[i].setStateBackend(backend.create(false));
            if (mergedState != null) {
                // Hand subtask i the key groups it owns under the new parallelism.
                harnesses[i].initializeState(
                        AbstractStreamOperatorTestHarness.repartitionOperatorState(
                                mergedState,
                                maxParallelism,
                                subtaskStates.size(),
                                parallelism,
                                i));
            }
            harnesses[i].open();
            harnesses[i].setStateTtlProcessingTime(1L);
        }
    }

    /**
     * Snapshots the first {@code parallelism} harnesses.
     *
     * @return the job-manager-owned state of each snapshotted subtask, in subtask order
     */
    private List<OperatorSubtaskState> snapshotHarnesses(
            KeyedOneInputStreamOperatorTestHarness<RowData, RowData, RowData>[] harnesses,
            int parallelism,
            long checkpointId,
            boolean useSavepoint) {
        SnapshotType snapshotType =
                useSavepoint
                        ? SavepointType.savepoint(SavepointFormatType.CANONICAL)
                        : CheckpointType.CHECKPOINT;
        return Arrays.stream(harnesses, 0, parallelism)
                .map(
                        h -> {
                            try {
                                return h.snapshotWithLocalState(checkpointId, 0L, snapshotType)
                                        .getJobManagerOwnedState();
                            } catch (Exception e) {
                                throw new RuntimeException(e);
                            }
                        })
                .collect(Collectors.toList());
    }

    /**
     * Closes every non-null harness and clears its slot.
     *
     * <p>Every harness is closed even if an earlier close fails. The first failure is rethrown at
     * the end, with later ones attached as suppressed exceptions.
     */
    private static void closeHarnesses(
            KeyedOneInputStreamOperatorTestHarness<RowData, RowData, RowData>[] harnesses)
            throws Exception {
        Exception failure = null;
        for (int i = 0; i < harnesses.length; ++i) {
            if (harnesses[i] == null) {
                continue;
            }
            try {
                harnesses[i].close();
            } catch (Exception e) {
                if (failure == null) {
                    failure = e;
                } else {
                    failure.addSuppressed(e);
                }
            } finally {
                harnesses[i] = null;
            }
        }
        if (failure != null) {
            throw failure;
        }
    }

    /** Generated-equaliser stub that hands out a {@link TestRecordEqualiser}. */
    private static class MyGeneratedRecordEqualiser extends GeneratedRecordEqualiser {
        public MyGeneratedRecordEqualiser() {
            super("", "", new Object[0]);
        }

        @Override
        public RecordEqualiser newInstance(ClassLoader classLoader) {
            return new TestRecordEqualiser();
        }
    }

    /** Treats two rows as equal when their row kind and field 0 (BIGINT) match. */
    protected static class TestUpsertKeyEqualiser implements RecordEqualiser, HashFunction {
        protected TestUpsertKeyEqualiser() {}

        @Override
        public boolean equals(RowData row1, RowData row2) {
            return row1.getRowKind() == row2.getRowKind() && row1.getLong(0) == row2.getLong(0);
        }

        @Override
        public int hashCode(Object data) {
            RowData rd = (RowData) data;
            return Objects.hash(rd.getRowKind(), rd.getLong(0));
        }
    }

    /** Treats two rows as equal when their row kind and all three fields match. */
    protected static class TestRecordEqualiser implements RecordEqualiser, HashFunction {
        protected TestRecordEqualiser() {}

        @Override
        public boolean equals(RowData row1, RowData row2) {
            return row1.getRowKind() == row2.getRowKind()
                    && row1.getLong(0) == row2.getLong(0)
                    && row1.getInt(1) == row2.getInt(1)
                    && row1.getString(2).equals(row2.getString(2));
        }

        @Override
        public int hashCode(Object data) {
            RowData rd = (RowData) data;
            return Objects.hash(rd.getRowKind(), rd.getLong(0), rd.getInt(1), rd.getString(2));
        }
    }
}

