/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.io.network.api.writer;

import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import org.apache.flink.core.io.IOReadableWritable;
import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer;
import org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer;
import org.apache.flink.runtime.io.network.api.writer.BroadcastRecordWriter;
import org.apache.flink.runtime.io.network.api.writer.RecordWriterTest;
import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
import org.apache.flink.runtime.io.network.util.TestPooledBufferProvider;
import org.apache.flink.testutils.serialization.types.IntType;
import org.apache.flink.testutils.serialization.types.SerializationTestType;
import org.apache.flink.testutils.serialization.types.SerializationTestTypeFactory;
import org.apache.flink.testutils.serialization.types.Util;
import org.junit.Assert;
import org.junit.Test;

public class BroadcastRecordWriterTest
extends RecordWriterTest {
    public BroadcastRecordWriterTest() {
        super(true);
    }

    @Test
    public void testBroadcastMixedRandomEmitRecord() throws Exception {
        int numberOfChannels = 4;
        int numberOfRecords = 8;
        int bufferSize = 32;
        Queue[] queues = new Queue[4];
        for (int i = 0; i < 4; ++i) {
            queues[i] = new ArrayDeque();
        }
        TestPooledBufferProvider bufferProvider = new TestPooledBufferProvider(Integer.MAX_VALUE, 32);
        RecordWriterTest.CollectingPartitionWriter partitionWriter = new RecordWriterTest.CollectingPartitionWriter(queues, bufferProvider);
        BroadcastRecordWriter writer = new BroadcastRecordWriter((ResultPartitionWriter)partitionWriter, 0L, "test");
        SpillingAdaptiveSpanningRecordDeserializer deserializer = new SpillingAdaptiveSpanningRecordDeserializer(new String[]{this.tempFolder.getRoot().getAbsolutePath()});
        Util.MockRecords records = Util.randomRecords((int)8, (SerializationTestTypeFactory)SerializationTestTypeFactory.INT);
        HashMap serializedRecords = new HashMap();
        for (int i = 0; i < 4; ++i) {
            serializedRecords.put(i, new ArrayDeque());
        }
        int index = 0;
        for (SerializationTestType record : records) {
            int randomChannel = index++ % 4;
            writer.randomEmit((IOReadableWritable)record, randomChannel);
            ((ArrayDeque)serializedRecords.get(randomChannel)).add(record);
            writer.broadcastEmit((IOReadableWritable)record);
            for (int i = 0; i < 4; ++i) {
                ((ArrayDeque)serializedRecords.get(i)).add(record);
            }
        }
        int numberOfCreatedBuffers = bufferProvider.getNumberOfCreatedBuffers();
        Assert.assertEquals((long)8L, (long)numberOfCreatedBuffers);
        for (int i = 0; i < 4; ++i) {
            Assert.assertEquals((long)8L, (long)queues[i].size());
            int excessRandomRecords = i < 0 ? 1 : 0;
            int numberOfRandomRecords = 2 + excessRandomRecords;
            int numberOfTotalRecords = 8 + numberOfRandomRecords;
            this.verifyDeserializationResults(queues[i], (RecordDeserializer<SerializationTestType>)deserializer, (ArrayDeque)serializedRecords.get(i), numberOfCreatedBuffers, numberOfTotalRecords);
        }
    }

    @Test
    public void testRandomEmitAndBufferRecycling() throws Exception {
        int recordSize = 8;
        TestPooledBufferProvider bufferProvider = new TestPooledBufferProvider(2, 2 * recordSize);
        RecordWriterTest.KeepingPartitionWriter partitionWriter = new RecordWriterTest.KeepingPartitionWriter(bufferProvider){

            @Override
            public int getNumberOfSubpartitions() {
                return 2;
            }
        };
        BroadcastRecordWriter writer = new BroadcastRecordWriter((ResultPartitionWriter)partitionWriter, 0L, "test");
        List<Buffer> buffers = Arrays.asList(bufferProvider.requestBuffer(), bufferProvider.requestBuffer());
        buffers.forEach(Buffer::recycleBuffer);
        Assert.assertEquals((long)2L, (long)bufferProvider.getNumberOfAvailableBuffers());
        writer.randomEmit((IOReadableWritable)new IntType(1), 0);
        writer.broadcastEmit((IOReadableWritable)new IntType(2));
        Assert.assertEquals((long)1L, (long)bufferProvider.getNumberOfAvailableBuffers());
        Assert.assertEquals((long)1L, (long)partitionWriter.getAddedBufferConsumers(0).size());
        this.closeConsumer(partitionWriter, 0, 2 * recordSize);
        Assert.assertEquals((long)1L, (long)bufferProvider.getNumberOfAvailableBuffers());
        writer.broadcastEmit((IOReadableWritable)new IntType(3));
        Assert.assertEquals((long)0L, (long)bufferProvider.getNumberOfAvailableBuffers());
        Assert.assertEquals((long)2L, (long)partitionWriter.getAddedBufferConsumers(1).size());
        this.closeConsumer(partitionWriter, 1, recordSize);
        Assert.assertEquals((long)1L, (long)bufferProvider.getNumberOfAvailableBuffers());
    }

    public void closeConsumer(RecordWriterTest.KeepingPartitionWriter partitionWriter, int subpartitionIndex, int expectedSize) {
        BufferConsumer bufferConsumer = partitionWriter.getAddedBufferConsumers(subpartitionIndex).get(0);
        Buffer buffer = bufferConsumer.build();
        bufferConsumer.close();
        Assert.assertEquals((long)expectedSize, (long)buffer.getSize());
        buffer.recycleBuffer();
    }
}

