package org.apache.flink.runtime.shuffle;

import java.io.IOException;
import java.util.Collections;
import java.util.Iterator;
import java.util.Optional;
import org.apache.flink.metrics.groups.UnregisteredMetricsGroup;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
import org.apache.flink.runtime.clusterframework.types.ResourceID;
import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory;
import org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils;
import org.apache.flink.runtime.io.network.NettyShuffleEnvironment;
import org.apache.flink.runtime.io.network.NettyShuffleEnvironmentBuilder;
import org.apache.flink.runtime.io.network.partition.ResultPartition;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.io.network.partition.consumer.InputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateBuilder;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.util.NettyShuffleDescriptorBuilder;
import org.apache.flink.shaded.guava32.com.google.common.collect.ImmutableMap;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:org/apache/flink/runtime/shuffle/NettyShuffleUtilsTest.class */
/**
 * Tests for {@link NettyShuffleUtils}.
 *
 * <p>Checks that the number of network buffers announced up front by {@link
 * NettyShuffleUtils#computeNetworkBuffersForAnnouncing} equals the number of buffers the input
 * gates and result partitions of a {@link NettyShuffleEnvironment} actually account for once
 * they have been set up.
 */
class NettyShuffleUtilsTest {

    @Test
    void testComputeRequiredNetworkBuffers() throws Exception {
        // Shuffle settings. The same values feed both the calculation and the environment
        // builder below so that the two cannot silently drift apart.
        final int numBuffersPerChannel = 5;
        final int numBuffersPerGate = 8;
        final Optional<Integer> maxRequiredBuffersPerGate = Optional.of(Integer.MAX_VALUE);
        final int sortShuffleMinParallelism = 8;
        final int sortShuffleMinBuffers = 12;

        // Consumed data sets (one input gate each) and their channel counts.
        final IntermediateDataSetID inputId1 = new IntermediateDataSetID();
        final IntermediateDataSetID inputId2 = new IntermediateDataSetID();
        final int numChannels1 = 3;
        final int numChannels2 = 4;

        // Produced data sets (one result partition each) and their subpartition counts.
        final IntermediateDataSetID outputId1 = new IntermediateDataSetID();
        final IntermediateDataSetID outputId2 = new IntermediateDataSetID();
        final IntermediateDataSetID outputId3 = new IntermediateDataSetID();
        final int numSubpartitions1 = 5;
        final int numSubpartitions2 = 6;
        final int numSubpartitions3 = 10;

        final int numTotalBuffers =
                NettyShuffleUtils.computeNetworkBuffersForAnnouncing(
                        numBuffersPerChannel,
                        numBuffersPerGate,
                        maxRequiredBuffersPerGate,
                        sortShuffleMinParallelism,
                        sortShuffleMinBuffers,
                        ImmutableMap.of(inputId1, numChannels1, inputId2, numChannels2),
                        // NOTE(review): presumably a per-input reuse count; 1 means each consumed
                        // data set is read by exactly one gate, matching the gates created below.
                        ImmutableMap.of(inputId1, 1, inputId2, 1),
                        ImmutableMap.of(
                                outputId1, numSubpartitions1,
                                outputId2, numSubpartitions2,
                                outputId3, numSubpartitions3),
                        ImmutableMap.of(
                                inputId1, ResultPartitionType.PIPELINED_BOUNDED,
                                inputId2, ResultPartitionType.BLOCKING),
                        ImmutableMap.of(
                                outputId1, ResultPartitionType.PIPELINED_BOUNDED,
                                outputId2, ResultPartitionType.BLOCKING,
                                outputId3, ResultPartitionType.BLOCKING));

        // try-with-resources guarantees that the gates, the partitions and the environment itself
        // are closed even when setup or the assertion fails.
        try (NettyShuffleEnvironment shuffleEnvironment =
                        new NettyShuffleEnvironmentBuilder()
                                .setNumNetworkBuffers(numTotalBuffers)
                                .setNetworkBuffersPerChannel(numBuffersPerChannel)
                                .setSortShuffleMinBuffers(sortShuffleMinBuffers)
                                .setSortShuffleMinParallelism(sortShuffleMinParallelism)
                                .build();
                SingleInputGate inputGate1 =
                        createInputGate(
                                shuffleEnvironment,
                                ResultPartitionType.PIPELINED_BOUNDED,
                                numChannels1);
                SingleInputGate inputGate2 =
                        createInputGate(
                                shuffleEnvironment, ResultPartitionType.BLOCKING, numChannels2);
                ResultPartition resultPartition1 =
                        createResultPartition(
                                shuffleEnvironment,
                                ResultPartitionType.PIPELINED_BOUNDED,
                                numSubpartitions1);
                ResultPartition resultPartition2 =
                        createResultPartition(
                                shuffleEnvironment,
                                ResultPartitionType.BLOCKING,
                                numSubpartitions2);
                ResultPartition resultPartition3 =
                        createResultPartition(
                                shuffleEnvironment,
                                ResultPartitionType.BLOCKING,
                                numSubpartitions3)) {
            inputGate1.setup();
            inputGate2.setup();
            resultPartition1.setup();
            resultPartition2.setup();
            resultPartition3.setup();

            final int numConsumedBuffers =
                    calculateBuffersConsumption(inputGate1)
                            + calculateBuffersConsumption(inputGate2)
                            + calculateBuffersConsumption(resultPartition1)
                            + calculateBuffersConsumption(resultPartition2)
                            + calculateBuffersConsumption(resultPartition3);

            Assertions.assertThat(numTotalBuffers).isEqualTo(numConsumedBuffers);
        }
    }

    /** Creates a throw-away owner context for gates and partitions built by this test. */
    private ShuffleIOOwnerContext createOwnerContext(NettyShuffleEnvironment network) {
        return network.createShuffleIOOwnerContext(
                "", ExecutionGraphTestUtils.createExecutionAttemptId(), new UnregisteredMetricsGroup());
    }

    /**
     * Creates one input gate with {@code numChannels} channels, each described by a freshly
     * generated remote shuffle descriptor.
     *
     * @param network environment that creates the gate
     * @param resultPartitionType type of the consumed result partitions
     * @param numChannels number of input channels of the gate
     * @return the (not yet set up) input gate
     */
    private SingleInputGate createInputGate(
            NettyShuffleEnvironment network,
            ResultPartitionType resultPartitionType,
            int numChannels)
            throws IOException {
        TaskDeploymentDescriptorFactory.ShuffleDescriptorAndIndex[] shuffleDescriptors =
                new TaskDeploymentDescriptorFactory.ShuffleDescriptorAndIndex[numChannels];
        for (int channelIndex = 0; channelIndex < numChannels; channelIndex++) {
            shuffleDescriptors[channelIndex] =
                    new TaskDeploymentDescriptorFactory.ShuffleDescriptorAndIndex(
                            NettyShuffleDescriptorBuilder.createRemoteWithIdAndLocation(
                                    new IntermediateResultPartitionID(), ResourceID.generate()),
                            channelIndex);
        }

        InputGateDeploymentDescriptor gateDescriptor =
                new InputGateDeploymentDescriptor(
                        new IntermediateDataSetID(), resultPartitionType, 0, shuffleDescriptors);

        return (SingleInputGate)
                network.createInputGates(
                                createOwnerContext(network),
                                SingleInputGateBuilder.NO_OP_PRODUCER_CHECKER,
                                Collections.singletonList(gateDescriptor))
                        .iterator()
                        .next();
    }

    /**
     * Creates one result partition of the given type with {@code numSubpartitions} subpartitions.
     *
     * @param network environment that creates the partition
     * @param resultPartitionType type of the produced partition
     * @param numSubpartitions number of subpartitions of the partition
     * @return the (not yet set up) result partition
     */
    private ResultPartition createResultPartition(
            NettyShuffleEnvironment network,
            ResultPartitionType resultPartitionType,
            int numSubpartitions) {
        NettyShuffleDescriptor shuffleDescriptor =
                NettyShuffleDescriptorBuilder.createRemoteWithIdAndLocation(
                        new IntermediateResultPartitionID(), ResourceID.generate());

        PartitionDescriptor partitionDescriptor =
                new PartitionDescriptor(
                        new IntermediateDataSetID(),
                        2,
                        shuffleDescriptor.getResultPartitionID().getPartitionId(),
                        resultPartitionType,
                        numSubpartitions,
                        0,
                        false,
                        true,
                        false);

        return (ResultPartition)
                network.createResultPartitionWriters(
                                createOwnerContext(network),
                                Collections.singletonList(
                                        new ResultPartitionDeploymentDescriptor(
                                                partitionDescriptor, shuffleDescriptor, 1)))
                        .iterator()
                        .next();
    }

    /**
     * Counts the buffers accounted to an input gate: the buffers available on each of its
     * channels plus the maximum number of memory segments of the gate's buffer pool.
     *
     * <p>The gate is first driven through recovered-state consumption and its recovered channels
     * are converted, so that the per-channel counts reflect the converted channels.
     */
    private int calculateBuffersConsumption(SingleInputGate inputGate) throws Exception {
        inputGate.setChannelStateWriter(ChannelStateWriter.NO_OP);
        inputGate.finishReadRecoveredState();
        // Keep polling until the gate reports that its recovered state has been consumed.
        while (!inputGate.getStateConsumedFuture().isDone()) {
            inputGate.pollNext();
        }
        inputGate.convertRecoveredInputChannels();

        int numBuffers = 0;
        for (InputChannel channel : inputGate.inputChannels()) {
            numBuffers += channel.getNumberOfAvailableBuffers();
        }
        return numBuffers + inputGate.getBufferPool().getMaxNumberOfMemorySegments();
    }

    /**
     * Counts the buffers accounted to a result partition.
     *
     * <p>Partitions that can be consumed in a pipelined fashion count with the maximum size of
     * their buffer pool. All other partitions count only the pool's required segments.
     */
    private int calculateBuffersConsumption(ResultPartition partition) {
        if (partition.getPartitionType().canBePipelinedConsumed()) {
            return partition.getBufferPool().getMaxNumberOfMemorySegments();
        }
        return partition.getBufferPool().getNumberOfRequiredMemorySegments();
    }
}
