/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.checkpoint;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;
import org.apache.flink.runtime.OperatorIDPair;
import org.apache.flink.runtime.checkpoint.OperatorState;
import org.apache.flink.runtime.checkpoint.OperatorStateRepartitioner;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.StateAssignmentOperation;
import org.apache.flink.runtime.checkpoint.StateObjectCollection;
import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.state.AbstractChannelStateHandle;
import org.apache.flink.runtime.state.InputChannelStateHandle;
import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
import org.apache.flink.runtime.state.StateObject;
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.runtime.state.testutils.EmptyStreamStateHandle;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
 * Parameterized test for {@code StateAssignmentOperation.channelStateNonRescalingRepartitioner}.
 *
 * <p>A single-operator {@link OperatorState} is redistributed from {@code oldParallelism} to
 * {@code newParallelism} using the non-rescaling channel-state repartitioner. The test expects an
 * {@link IllegalArgumentException} exactly when {@link #shouldFail()} holds, i.e. when the
 * parallelism changes and the channel-state handles carry at least one offset. In every other
 * case the redistribution must succeed.
 */
@RunWith(value = Parameterized.class)
public class ChannelStateNoRescalingPartitionerTest {
    private static final OperatorID OPERATOR_ID = new OperatorID();

    private final int oldParallelism;
    private final int newParallelism;
    /** Number of offsets stored in each channel-state handle. */
    private final int offsetsSize;
    /** Selects which channel state (input channel or result subpartition) is redistributed. */
    private final Function<OperatorSubtaskState, ? extends StateObjectCollection<?>> extractState;

    /**
     * Builds the cross product of old parallelism {1, 2}, new parallelism {1, 2}, offsets size
     * {0, 1, 2} and state extractor (input channel state / result subpartition state).
     *
     * <p>The display name includes {@code {index}} because two cases that differ only in the
     * extractor would otherwise get identical names.
     *
     * @return one {@code Object[]} per case, matching the constructor's parameter order
     */
    @Parameterized.Parameters(
            name = "[{index}] oldParallelism: {0}, newParallelism: {1}, offsetSize: {2}")
    public static Collection<Object[]> parameters() {
        int[] parLevels = {1, 2};
        int[] offsetSizes = {0, 1, 2};
        List<Function<OperatorSubtaskState, ? extends StateObjectCollection<?>>> extractors =
                Arrays.asList(
                        OperatorSubtaskState::getInputChannelState,
                        OperatorSubtaskState::getResultSubpartitionState);

        List<Object[]> params = new ArrayList<>();
        for (int oldParallelism : parLevels) {
            for (int newParallelism : parLevels) {
                for (int offsetSize : offsetSizes) {
                    for (Function<OperatorSubtaskState, ? extends StateObjectCollection<?>>
                            stateExtractor : extractors) {
                        params.add(
                                new Object[] {
                                    oldParallelism, newParallelism, offsetSize, stateExtractor
                                });
                    }
                }
            }
        }
        return params;
    }

    /**
     * @param oldParallelism parallelism (and max parallelism) of the restored operator state
     * @param newParallelism parallelism to redistribute the state to
     * @param offsetsSize number of offsets put into each channel-state handle
     * @param extractState picks the channel-state collection to redistribute from a subtask state
     */
    public ChannelStateNoRescalingPartitionerTest(
            int oldParallelism,
            int newParallelism,
            int offsetsSize,
            Function<OperatorSubtaskState, ? extends StateObjectCollection<?>> extractState) {
        this.oldParallelism = oldParallelism;
        this.newParallelism = newParallelism;
        this.offsetsSize = offsetsSize;
        this.extractState = extractState;
    }

    /**
     * Redistributes the channel state with the non-rescaling repartitioner and checks that it
     * throws {@link IllegalArgumentException} if and only if {@link #shouldFail()}.
     *
     * @param <T> the channel-state handle type selected by {@link #extractState}
     */
    @Test
    public <T extends AbstractChannelStateHandle<?>> void testNoRescaling() {
        OperatorState state = new OperatorState(OPERATOR_ID, oldParallelism, oldParallelism);
        // Only subtask 0 is given state: one input-channel handle and one result-subpartition
        // handle, each carrying offsetsSize offsets.
        state.putState(
                0,
                new OperatorSubtaskState(
                        StateObjectCollection.empty(),
                        StateObjectCollection.empty(),
                        StateObjectCollection.empty(),
                        StateObjectCollection.empty(),
                        StateObjectCollection.singleton(
                                new InputChannelStateHandle(
                                        new InputChannelInfo(0, 0),
                                        new EmptyStreamStateHandle(),
                                        getOffset())),
                        StateObjectCollection.singleton(
                                new ResultSubpartitionStateHandle(
                                        new ResultSubpartitionInfo(0, 0),
                                        new EmptyStreamStateHandle(),
                                        getOffset()))));

        // The extractor is stored with a wildcard type. Binding it to T lets the generic
        // redistribution call and the repartitioner agree on the handle type. The cast is safe
        // because each extractor returns the collection of exactly one handle type.
        @SuppressWarnings("unchecked")
        Function<OperatorSubtaskState, StateObjectCollection<T>> extractor =
                (Function<OperatorSubtaskState, StateObjectCollection<T>>) extractState;

        try {
            StateAssignmentOperation.reDistributePartitionableStates(
                    Collections.singletonList(state),
                    newParallelism,
                    Collections.singletonList(OperatorIDPair.generatedIDOnly(OPERATOR_ID)),
                    extractor,
                    StateAssignmentOperation.channelStateNonRescalingRepartitioner("test"));
        } catch (IllegalArgumentException e) {
            if (!shouldFail()) {
                throw e;
            }
            return;
        }
        if (shouldFail()) {
            Assert.fail(
                    "expected to fail for: oldParallelism="
                            + oldParallelism
                            + ", newParallelism="
                            + newParallelism
                            + ", offsetsSize="
                            + offsetsSize
                            + ", extractState="
                            + extractState);
        }
    }

    /**
     * Returns whether this case expects the redistribution to be rejected.
     *
     * @return {@code true} iff the parallelism changes and the handles carry at least one offset
     */
    private boolean shouldFail() {
        return oldParallelism != newParallelism && offsetsSize > 0;
    }

    /**
     * Builds the offsets list for a channel-state handle.
     *
     * @return a new mutable list containing {@link #offsetsSize} zero offsets
     */
    private List<Long> getOffset() {
        return new ArrayList<>(Collections.nCopies(offsetsSize, 0L));
    }
}

