/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.api.worker;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.datavec.spark.util.SerializableHadoopConfig;
import org.deeplearning4j.core.loader.MultiDataSetLoader;
import org.deeplearning4j.spark.api.TrainingResult;
import org.deeplearning4j.spark.api.TrainingWorker;
import org.deeplearning4j.spark.api.WorkerConfiguration;
import org.deeplearning4j.spark.api.worker.ExecuteWorkerMultiDataSetFlatMap;
import org.deeplearning4j.spark.iterator.PathSparkMultiDataSetIterator;
import org.nd4j.linalg.dataset.api.MultiDataSet;

public class ExecuteWorkerPathMDSFlatMap<R extends TrainingResult>
implements FlatMapFunction<Iterator<String>, R> {
    private final FlatMapFunction<Iterator<MultiDataSet>, R> workerFlatMap;
    private MultiDataSetLoader loader;
    private final int maxDataSetObjects;
    private final Broadcast<SerializableHadoopConfig> hadoopConfig;

    public ExecuteWorkerPathMDSFlatMap(TrainingWorker<R> worker, MultiDataSetLoader loader, Broadcast<SerializableHadoopConfig> hadoopConfig) {
        this.workerFlatMap = new ExecuteWorkerMultiDataSetFlatMap<R>(worker);
        this.loader = loader;
        this.hadoopConfig = hadoopConfig;
        WorkerConfiguration conf = worker.getDataConfiguration();
        int dataSetObjectNumExamples = conf.getDataSetObjectSizeExamples();
        int workerMinibatchSize = conf.getBatchSizePerWorker();
        int maxMinibatches = conf.getMaxBatchesPerWorker() > 0 ? conf.getMaxBatchesPerWorker() : Integer.MAX_VALUE;
        this.maxDataSetObjects = maxMinibatches == Integer.MAX_VALUE ? Integer.MAX_VALUE : (int)Math.ceil((double)(maxMinibatches * workerMinibatchSize) / (double)dataSetObjectNumExamples);
    }

    public Iterator<R> call(Iterator<String> iter) throws Exception {
        ArrayList<String> list = new ArrayList<String>();
        int count = 0;
        while (iter.hasNext() && count++ < this.maxDataSetObjects) {
            list.add(iter.next());
        }
        return this.workerFlatMap.call((Object)new PathSparkMultiDataSetIterator(list.iterator(), this.loader, this.hadoopConfig));
    }
}

