package org.deeplearning4j.spark.datavec;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.nd4j.linalg.dataset.DataSet;

/**
 * Re-groups the {@link DataSet} records of a {@link JavaRDD} into mini-batches.
 *
 * <p>Batching is performed independently inside each partition (via {@code mapPartitions}):
 * consecutive records of a partition are copied and merged with {@link DataSet#merge} into groups
 * of {@code miniBatches} records. If a partition's size is not a multiple of the batch size, its
 * last group contains the remaining (fewer) records. Records from different partitions are never
 * merged together.
 */
public class RDDMiniBatches implements Serializable {
    /** Number of records merged into each output {@link DataSet}; always positive. */
    private final int miniBatches;
    /** The RDD whose records are grouped into mini-batches. */
    private final JavaRDD<DataSet> toSplitJava;

    /**
     * Spark partition function that merges consecutive records of one partition into batches of
     * {@code batchSize} records each (the final batch of a partition may be smaller).
     */
    public static class MiniBatchFunction implements FlatMapFunction<Iterator<DataSet>, DataSet> {
        /** Number of records merged into each emitted {@link DataSet}; always positive. */
        private final int batchSize;

        /**
         * Creates a partition function emitting batches of {@code batchSize} records.
         *
         * @param batchSize number of records per merged DataSet; must be positive
         * @throws IllegalArgumentException if {@code batchSize} is zero or negative
         */
        public MiniBatchFunction(int batchSize) {
            this.batchSize = checkBatchSize(batchSize);
        }

        /**
         * Merges the records of a single partition into mini-batches.
         *
         * <p>Implements {@code FlatMapFunction#call}. An empty partition yields an empty iterator.
         *
         * @param records the records of one partition
         * @return an iterator over the merged batches, in partition order
         * @throws Exception as permitted by the {@code FlatMapFunction} contract
         */
        public Iterator<DataSet> call(Iterator<DataSet> records) throws Exception {
            List<DataSet> batches = new ArrayList<>();
            List<DataSet> pending = new ArrayList<>();
            while (records.hasNext()) {
                // Copy each record so the merged batches never alias the input DataSets.
                pending.add(records.next().copy());
                if (pending.size() == batchSize) {
                    batches.add(DataSet.merge(pending));
                    pending.clear();
                }
            }
            // Emit the trailing, possibly smaller, batch left over at the end of the partition.
            if (!pending.isEmpty()) {
                batches.add(DataSet.merge(pending));
            }
            return batches.iterator();
        }
    }

    /**
     * Creates a splitter that groups the records of {@code toSplit} into mini-batches.
     *
     * @param miniBatches number of records per merged DataSet; must be positive
     * @param toSplit the RDD to re-batch; must not be {@code null}
     * @throws IllegalArgumentException if {@code miniBatches} is zero or negative
     * @throws NullPointerException if {@code toSplit} is {@code null}
     */
    public RDDMiniBatches(int miniBatches, JavaRDD<DataSet> toSplit) {
        this.miniBatches = checkBatchSize(miniBatches);
        this.toSplitJava = Objects.requireNonNull(toSplit, "toSplit");
    }

    /**
     * Returns a new RDD whose elements are the merged mini-batches of each partition.
     *
     * @return the lazily-evaluated, re-batched RDD
     */
    public JavaRDD<DataSet> miniBatchesJava() {
        return toSplitJava.mapPartitions(new MiniBatchFunction(miniBatches));
    }

    /**
     * Validates a batch size. A non-positive size would never trigger the "batch full" check and
     * would silently merge whole partitions into a single DataSet.
     */
    private static int checkBatchSize(int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batch size must be positive, got " + batchSize);
        }
        return batchSize;
    }
}
