/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.io.network.partition;

import java.io.IOException;
import java.lang.reflect.Field;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Optional;
import org.apache.flink.core.memory.MemorySegmentProvider;
import org.apache.flink.runtime.event.AbstractEvent;
import org.apache.flink.runtime.io.network.ConnectionManager;
import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.BufferBuilderTestUtils;
import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
import org.apache.flink.runtime.io.network.buffer.BufferPool;
import org.apache.flink.runtime.io.network.buffer.NoOpBufferPool;
import org.apache.flink.runtime.io.network.partition.InputChannelTestUtils;
import org.apache.flink.runtime.io.network.partition.PipelinedSubpartition;
import org.apache.flink.runtime.io.network.partition.ResultPartition;
import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.io.network.partition.ResultSubpartition;
import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
import org.apache.flink.runtime.io.network.partition.consumer.InputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.InputChannelBuilder;
import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateBuilder;
import org.apache.flink.runtime.io.network.util.TestBufferFactory;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.util.function.SupplierWithException;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.Mockito;

/**
 * Tests that a {@link SingleInputGate} consumes its input channels fairly: after every read, the
 * number of queued buffers per source must differ by at most one between any two sources.
 *
 * <p>All tests use a {@link FairnessVerifyingInputGate}. On every {@link SingleInputGate#getNext()}
 * call, that gate additionally checks that its internal queue of channels with data is not larger
 * than the number of input channels and contains no duplicate channels.
 */
public class InputGateFairnessTest {

    /**
     * Local channels whose subpartitions are completely filled and finished before reading starts.
     */
    @Test
    public void testFairConsumptionLocalChannelsPreFilled() throws Exception {
        final int numberOfChannels = 37;
        final int buffersPerChannel = 27;

        final ResultPartition resultPartition = Mockito.mock(ResultPartition.class);

        // Closed at the end, like in testFairConsumptionLocalChannels; every add() gets its own copy.
        try (BufferConsumer bufferConsumer = BufferBuilderTestUtils.createFilledFinishedBufferConsumer(42)) {

            // ----- create the source subpartitions and fill them with buffers -----
            final PipelinedSubpartition[] sources = new PipelinedSubpartition[numberOfChannels];
            for (int i = 0; i < numberOfChannels; i++) {
                PipelinedSubpartition partition = new PipelinedSubpartition(0, resultPartition);
                for (int p = 0; p < buffersPerChannel; p++) {
                    partition.add(bufferConsumer.copy());
                }
                partition.finish();
                sources[i] = partition;
            }

            // ----- create the reading side -----
            ResultPartitionManager resultPartitionManager =
                    InputChannelTestUtils.createResultPartitionManager((ResultSubpartition[]) sources);
            SingleInputGate gate = createFairnessVerifyingInputGate(numberOfChannels);
            InputChannel[] inputChannels = new InputChannel[numberOfChannels];
            for (int i = 0; i < numberOfChannels; i++) {
                inputChannels[i] = InputChannelTestUtils.createLocalInputChannel(gate, i, resultPartitionManager);
            }
            setupInputGate(gate, inputChannels);

            // ----- read everything: the buffers plus the one extra element finish() adds per channel -----
            for (int remaining = numberOfChannels * (buffersPerChannel + 1); remaining > 0; --remaining) {
                // getNext() returns an Optional, so a null check would be vacuous.
                Assert.assertTrue(gate.getNext().isPresent());
                assertBalanced(queueLengths(sources));
            }

            Assert.assertFalse(gate.getNext().isPresent());
        }
    }

    /**
     * Local channels that start almost empty and are refilled at random while reading.
     */
    @Test
    public void testFairConsumptionLocalChannels() throws Exception {
        final int numberOfChannels = 37;
        final int buffersPerChannel = 27;

        final ResultPartition resultPartition = Mockito.mock(ResultPartition.class);

        try (BufferConsumer bufferConsumer = BufferBuilderTestUtils.createFilledFinishedBufferConsumer(42)) {

            // ----- create the (initially empty) source subpartitions -----
            final PipelinedSubpartition[] sources = new PipelinedSubpartition[numberOfChannels];
            for (int i = 0; i < numberOfChannels; i++) {
                sources[i] = new PipelinedSubpartition(0, resultPartition);
            }

            // ----- create the reading side -----
            ResultPartitionManager resultPartitionManager =
                    InputChannelTestUtils.createResultPartitionManager((ResultSubpartition[]) sources);
            SingleInputGate gate = createFairnessVerifyingInputGate(numberOfChannels);
            InputChannel[] inputChannels = new InputChannel[numberOfChannels];
            for (int i = 0; i < numberOfChannels; i++) {
                inputChannels[i] = InputChannelTestUtils.createLocalInputChannel(gate, i, resultPartitionManager);
            }

            // seed a single buffer so that the first read has something to return
            sources[12].add(bufferConsumer.copy());

            setupInputGate(gate, inputChannels);

            // ----- read, periodically adding three buffers to every source in random order -----
            for (int i = 0; i < numberOfChannels * buffersPerChannel; i++) {
                Assert.assertTrue(gate.getNext().isPresent());
                assertBalanced(queueLengths(sources));

                if (i % (2 * numberOfChannels) == 0) {
                    fillRandom(sources, 3, bufferConsumer);
                }
            }
        }
    }

    /**
     * Remote channels that each receive all their buffers, followed by an end-of-partition event,
     * before reading starts.
     */
    @Test
    public void testFairConsumptionRemoteChannelsPreFilled() throws Exception {
        final int numberOfChannels = 37;
        final int buffersPerChannel = 27;

        final Buffer mockBuffer = TestBufferFactory.createBuffer(42);

        // ----- create the remote channels and fill them with buffers -----
        final SingleInputGate gate = createFairnessVerifyingInputGate(numberOfChannels);
        final ConnectionManager connManager = InputChannelTestUtils.createDummyConnectionManager();
        final RemoteInputChannel[] channels = new RemoteInputChannel[numberOfChannels];

        for (int i = 0; i < numberOfChannels; i++) {
            RemoteInputChannel channel = createRemoteInputChannel(gate, i, connManager);
            channels[i] = channel;

            for (int p = 0; p < buffersPerChannel; p++) {
                channel.onBuffer(mockBuffer, p, -1);
            }
            // the event takes the sequence number right after the last data buffer
            channel.onBuffer(EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE), buffersPerChannel, -1);
        }

        setupInputGate(gate, (InputChannel[]) channels);

        // ----- read everything: the buffers plus one end-of-partition event per channel -----
        for (int remaining = numberOfChannels * (buffersPerChannel + 1); remaining > 0; --remaining) {
            Assert.assertTrue(gate.getNext().isPresent());
            assertBalanced(queueLengths(channels));
        }

        Assert.assertFalse(gate.getNext().isPresent());
    }

    /**
     * Remote channels that start almost empty and are refilled at random while reading.
     */
    @Test
    public void testFairConsumptionRemoteChannels() throws Exception {
        final int numberOfChannels = 37;
        final int buffersPerChannel = 27;

        final Buffer mockBuffer = TestBufferFactory.createBuffer(42);

        // ----- create the (initially empty) remote channels -----
        final SingleInputGate gate = createFairnessVerifyingInputGate(numberOfChannels);
        final ConnectionManager connManager = InputChannelTestUtils.createDummyConnectionManager();
        final RemoteInputChannel[] channels = new RemoteInputChannel[numberOfChannels];
        // next sequence number to hand to onBuffer() for each channel
        final int[] channelSequenceNums = new int[numberOfChannels];

        for (int i = 0; i < numberOfChannels; i++) {
            channels[i] = createRemoteInputChannel(gate, i, connManager);
        }

        // seed a single buffer so that the first read has something to return
        channels[11].onBuffer(mockBuffer, 0, -1);
        channelSequenceNums[11]++;

        setupInputGate(gate, (InputChannel[]) channels);

        // ----- read, periodically adding three buffers to every channel in random order -----
        for (int i = 0; i < numberOfChannels * buffersPerChannel; i++) {
            Assert.assertTrue(gate.getNext().isPresent());
            assertBalanced(queueLengths(channels));

            if (i % (2 * numberOfChannels) == 0) {
                fillRandom(channels, channelSequenceNums, 3, mockBuffer);
            }
        }
    }

    // ------------------------------------------------------------------------

    private SingleInputGate createFairnessVerifyingInputGate(int numberOfChannels) {
        return new FairnessVerifyingInputGate("Test Task Name", new IntermediateDataSetID(), 0, numberOfChannels);
    }

    /**
     * Adds {@code numPerPartition} copies of {@code buffer} to every partition, interleaving the
     * partitions in a random order.
     */
    private void fillRandom(PipelinedSubpartition[] partitions, int numPerPartition, BufferConsumer buffer) throws Exception {
        for (int index : shuffledIndices(partitions.length, numPerPartition)) {
            partitions[index].add(buffer.copy());
        }
    }

    /**
     * Delivers {@code buffer} {@code numPerPartition} times to every channel, interleaving the
     * channels in a random order. {@code sequenceNumbers[i]} holds the next sequence number for
     * channel {@code i} and is advanced for every delivered buffer.
     */
    private void fillRandom(RemoteInputChannel[] partitions, int[] sequenceNumbers, int numPerPartition, Buffer buffer) throws Exception {
        for (int index : shuffledIndices(partitions.length, numPerPartition)) {
            partitions[index].onBuffer(buffer, sequenceNumbers[index]++, -1);
        }
    }

    /** Returns every index in {@code [0, numIndices)} repeated {@code numPerIndex} times, shuffled. */
    private static Collection<Integer> shuffledIndices(int numIndices, int numPerIndex) {
        ArrayList<Integer> indices = new ArrayList<>(numIndices * numPerIndex);
        for (int i = 0; i < numIndices; i++) {
            for (int k = 0; k < numPerIndex; k++) {
                indices.add(i);
            }
        }
        Collections.shuffle(indices);
        return indices;
    }

    /** Returns the current number of buffers of each subpartition. */
    private static int[] queueLengths(PipelinedSubpartition[] sources) {
        int[] lengths = new int[sources.length];
        for (int i = 0; i < sources.length; i++) {
            lengths[i] = sources[i].getCurrentNumberOfBuffers();
        }
        return lengths;
    }

    /** Returns the current number of queued buffers of each remote channel. */
    private static int[] queueLengths(RemoteInputChannel[] channels) {
        int[] lengths = new int[channels.length];
        for (int i = 0; i < channels.length; i++) {
            lengths[i] = channels[i].getNumberOfQueuedBuffers();
        }
        return lengths;
    }

    /**
     * Asserts that the given queue lengths differ by at most one, i.e. that no source has been
     * drained noticeably faster than another.
     */
    private static void assertBalanced(int[] queueLengths) {
        int min = Integer.MAX_VALUE;
        int max = 0;
        for (int length : queueLengths) {
            min = Math.min(min, length);
            max = Math.max(max, length);
        }
        Assert.assertTrue(
                "unfair consumption: min queue length " + min + ", max queue length " + max,
                max == min || max == min + 1);
    }

    /** Creates a remote input channel with the given index and connection manager for {@code inputGate}. */
    public static RemoteInputChannel createRemoteInputChannel(SingleInputGate inputGate, int channelIndex, ConnectionManager connectionManager) {
        return InputChannelBuilder.newBuilder()
                .setChannelIndex(channelIndex)
                .setConnectionManager(connectionManager)
                .buildRemoteChannel(inputGate);
    }

    /** Installs {@code channels} on {@code gate}, sets the gate up and requests its partitions. */
    public static void setupInputGate(SingleInputGate gate, InputChannel... channels) throws IOException {
        gate.setInputChannels(channels);
        gate.setup();
        gate.requestPartitions();
    }

    // ------------------------------------------------------------------------

    /**
     * A {@link SingleInputGate} that checks its internal queue of channels with data before every
     * {@link #getNext()} call. The queue must not be larger than the number of input channels, and
     * it must not contain the same channel twice.
     */
    private static class FairnessVerifyingInputGate extends SingleInputGate {

        private static final SupplierWithException<BufferPool, IOException> STUB_BUFFER_POOL_FACTORY = NoOpBufferPool::new;

        /** The superclass's private {@code inputChannelsWithData} queue, obtained via reflection. */
        private final ArrayDeque<InputChannel> channelsWithData;

        /** Scratch set reused by {@link #ensureUnique}; empty between calls. */
        private final HashSet<InputChannel> uniquenessChecker;

        public FairnessVerifyingInputGate(String owningTaskName, IntermediateDataSetID consumedResultId, int consumedSubpartitionIndex, int numberOfInputChannels) {
            super(owningTaskName, 0, consumedResultId, ResultPartitionType.PIPELINED, consumedSubpartitionIndex, numberOfInputChannels, SingleInputGateBuilder.NO_OP_PRODUCER_CHECKER, STUB_BUFFER_POOL_FACTORY, null, (MemorySegmentProvider) new InputChannelTestUtils.UnpooledMemorySegmentProvider(32768));

            try {
                Field f = SingleInputGate.class.getDeclaredField("inputChannelsWithData");
                f.setAccessible(true);
                // The declared type of the field is not visible here; this cast assumes it holds an
                // ArrayDeque of InputChannel (generic type is erased at runtime).
                @SuppressWarnings("unchecked")
                ArrayDeque<InputChannel> queue = (ArrayDeque<InputChannel>) f.get(this);
                this.channelsWithData = queue;
            } catch (ReflectiveOperationException e) {
                throw new RuntimeException("Cannot access SingleInputGate#inputChannelsWithData", e);
            }

            this.uniquenessChecker = new HashSet<>();
        }

        @Override
        public Optional<BufferOrEvent> getNext() throws IOException, InterruptedException {
            // Inspect the queue while holding its monitor so the checks see a consistent snapshot.
            synchronized (channelsWithData) {
                Assert.assertTrue("too many input channels", channelsWithData.size() <= getNumberOfInputChannels());
                ensureUnique(channelsWithData);
            }

            return super.getNext();
        }

        /** Fails the test if {@code channels} contains the same channel more than once. */
        private void ensureUnique(Collection<InputChannel> channels) {
            HashSet<InputChannel> seen = this.uniquenessChecker;
            try {
                for (InputChannel channel : channels) {
                    if (!seen.add(channel)) {
                        Assert.fail("Duplicate channel in input gate: " + channel);
                    }
                }
                Assert.assertTrue("found duplicate input channels", seen.size() == channels.size());
            } finally {
                // keep the scratch set empty for the next call, even if an assertion failed
                seen.clear();
            }
        }
    }
}

