/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.impl.repartitioner;

import java.util.List;
import java.util.Random;
import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.deeplearning4j.spark.api.Repartitioner;
import org.deeplearning4j.spark.impl.common.CountPartitionsFunction;
import org.deeplearning4j.spark.impl.common.repartition.EqualPartitioner;
import org.deeplearning4j.spark.util.SparkUtils;
import org.nd4j.common.util.MathUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

/**
 * {@link Repartitioner} that spreads the elements of an RDD as evenly as possible over a target number of
 * partitions. After repartitioning, every partition is intended to hold either
 * {@code floor(total / numPartitions)} or {@code ceil(total / numPartitions)} elements, with the partitions
 * that receive the extra element chosen at random.
 */
public class EqualRepartitioner implements Repartitioner {
    private static final Logger log = LoggerFactory.getLogger(EqualRepartitioner.class);

    /**
     * Counts the elements in every partition of {@code rdd}, then delegates to
     * {@link #repartition(JavaRDD, int, List)} using {@code numExecutors} as the target partition count.
     *
     * @param rdd                    the RDD to repartition
     * @param minObjectsPerPartition not used by this implementation
     * @param numExecutors           target number of partitions; must be positive
     * @param <T>                    element type of the RDD
     * @return {@code rdd} itself if it already has {@code numExecutors} balanced partitions, otherwise a
     *         repartitioned RDD
     * @throws IllegalArgumentException if {@code numExecutors <= 0}
     */
    @Override
    public <T> JavaRDD<T> repartition(JavaRDD<T> rdd, int minObjectsPerPartition, int numExecutors) {
        // minObjectsPerPartition is intentionally ignored: this strategy always targets numExecutors partitions.
        // The collected list has one entry per partition; the second tuple element is that partition's count.
        List<Tuple2<Integer, Integer>> partitionCounts =
                rdd.mapPartitionsWithIndex(new CountPartitionsFunction<T>(), true).collect();
        return repartition(rdd, numExecutors, partitionCounts);
    }

    /**
     * Repartitions {@code rdd} into {@code numPartitions} partitions of (near-)equal size.
     *
     * <p>The RDD is returned unchanged when it already has {@code numPartitions} partitions and every
     * partition count lies within {@code [floor(total / numPartitions), ceil(total / numPartitions)]}.
     * Otherwise elements are indexed with {@link SparkUtils#indexedRDD(JavaRDD)} and redistributed by an
     * {@link EqualPartitioner}.
     *
     * @param rdd             the RDD to repartition
     * @param numPartitions   target number of partitions; must be positive
     * @param partitionCounts one entry per current partition of {@code rdd}; {@code _2()} of each tuple is read
     *                        as that partition's element count
     * @param <T>             element type of the RDD
     * @return {@code rdd} if it is already balanced, otherwise the repartitioned RDD
     * @throws IllegalArgumentException if {@code numPartitions <= 0}
     * @throws ArithmeticException      if {@code total / numPartitions} does not fit in an {@code int}
     */
    public static <T> JavaRDD<T> repartition(JavaRDD<T> rdd, int numPartitions,
                                             List<Tuple2<Integer, Integer>> partitionCounts) {
        if (numPartitions <= 0) {
            throw new IllegalArgumentException("numPartitions must be positive, got " + numPartitions);
        }

        // Sum in a long: each per-partition count is an int, but their total can exceed Integer.MAX_VALUE.
        long totalObjects = 0;
        for (Tuple2<Integer, Integer> t2 : partitionCounts) {
            totalObjects += t2._2();
        }

        // Exact integer floor/ceil of totalObjects / numPartitions (no floating-point rounding).
        long minAllowable = totalObjects / numPartitions;
        int remainder = (int) (totalObjects % numPartitions); // always < numPartitions, so fits in an int
        long maxAllowable = remainder == 0 ? minAllowable : minAllowable + 1;

        if (partitionCounts.size() == numPartitions
                && isBalanced(partitionCounts, minAllowable, maxAllowable)) {
            return rdd;
        }

        // EqualPartitioner takes an int base size; fail loudly instead of silently truncating.
        int partitionSizeExRemainder = Math.toIntExact(minAllowable);
        int[] remainderPartitions = remainder > 0 ? pickRemainderPartitions(numPartitions, remainder) : null;

        // NOTE(review): EqualPartitioner's source is not visible here; presumably it gives each partition
        // partitionSizeExRemainder elements plus one extra for each index listed in remainderPartitions.
        return SparkUtils.indexedRDD(rdd)
                .partitionBy(new EqualPartitioner(numPartitions, partitionSizeExRemainder, remainderPartitions))
                .values();
    }

    /** Returns true when every partition count lies within {@code [minAllowable, maxAllowable]}. */
    private static boolean isBalanced(List<Tuple2<Integer, Integer>> partitionCounts,
                                      long minAllowable, long maxAllowable) {
        for (Tuple2<Integer, Integer> t2 : partitionCounts) {
            int count = t2._2();
            if (count < minAllowable || count > maxAllowable) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns {@code count} distinct partition indices from {@code [0, numPartitions)}, taken from a
     * random permutation produced by {@link MathUtils#shuffleArray(int[], Random)}.
     */
    private static int[] pickRemainderPartitions(int numPartitions, int count) {
        int[] indices = new int[numPartitions];
        for (int i = 0; i < numPartitions; i++) {
            indices[i] = i;
        }
        MathUtils.shuffleArray(indices, new Random());
        int[] chosen = new int[count];
        System.arraycopy(indices, 0, chosen, 0, count);
        return chosen;
    }
}

