/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.api.worker;

import java.util.Collections;
import java.util.Iterator;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.spark.api.TrainingResult;
import org.deeplearning4j.spark.api.TrainingWorker;
import org.deeplearning4j.spark.api.WorkerConfiguration;
import org.deeplearning4j.spark.api.stats.CommonSparkTrainingStats;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.deeplearning4j.spark.api.stats.StatsCalculationHelper;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;

public class ExecuteWorkerMultiDataSetFlatMap<R extends TrainingResult>
implements FlatMapFunction<Iterator<MultiDataSet>, R> {
    private final TrainingWorker<R> worker;

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public Iterator<R> call(Iterator<MultiDataSet> dataSetIterator) throws Exception {
        StatsCalculationHelper s;
        WorkerConfiguration dataConfig = this.worker.getDataConfiguration();
        boolean stats = dataConfig.isCollectTrainingStats();
        StatsCalculationHelper statsCalculationHelper = s = stats ? new StatsCalculationHelper() : null;
        if (stats) {
            s.logMethodStartTime();
        }
        if (!dataSetIterator.hasNext()) {
            if (stats) {
                s.logReturnTime();
            }
            return Collections.emptyIterator();
        }
        int batchSize = dataConfig.getBatchSizePerWorker();
        int prefetchCount = dataConfig.getPrefetchNumBatches();
        IteratorMultiDataSetIterator batchedIterator = new IteratorMultiDataSetIterator(dataSetIterator, batchSize);
        if (prefetchCount > 0) {
            batchedIterator = new AsyncMultiDataSetIterator((MultiDataSetIterator)batchedIterator, prefetchCount);
        }
        try {
            int maxMinibatches;
            if (stats) {
                s.logInitialModelBefore();
            }
            ComputationGraph net = this.worker.getInitialModelGraph();
            if (stats) {
                s.logInitialModelAfter();
            }
            int miniBatchCount = 0;
            int n = maxMinibatches = dataConfig.getMaxBatchesPerWorker() > 0 ? dataConfig.getMaxBatchesPerWorker() : Integer.MAX_VALUE;
            while (batchedIterator.hasNext() && miniBatchCount++ < maxMinibatches) {
                Object result;
                if (stats) {
                    s.logNextDataSetBefore();
                }
                MultiDataSet next = (MultiDataSet)batchedIterator.next();
                if (stats) {
                    s.logNextDataSetAfter(next.getFeatures(0).size(0));
                }
                if (stats) {
                    s.logProcessMinibatchBefore();
                    result = this.worker.processMinibatchWithStats(next, net, !batchedIterator.hasNext());
                    s.logProcessMinibatchAfter();
                    if (result == null) continue;
                    s.logReturnTime();
                    SparkTrainingStats workerStats = (SparkTrainingStats)result.getSecond();
                    CommonSparkTrainingStats returnStats = s.build(workerStats);
                    ((TrainingResult)result.getFirst()).setStats(returnStats);
                    Iterator<TrainingResult> iterator = Collections.singletonList((TrainingResult)result.getFirst()).iterator();
                    return iterator;
                }
                result = this.worker.processMinibatch(next, net, !batchedIterator.hasNext());
                if (result == null) continue;
                Iterator<R> iterator = Collections.singletonList(result).iterator();
                return iterator;
            }
            if (stats) {
                s.logReturnTime();
                Pair<R, SparkTrainingStats> pair = this.worker.getFinalResultWithStats(net);
                ((TrainingResult)pair.getFirst()).setStats(s.build((SparkTrainingStats)pair.getSecond()));
                Iterator<TrainingResult> iterator = Collections.singletonList((TrainingResult)pair.getFirst()).iterator();
                return iterator;
            }
            Iterator<R> iterator = Collections.singletonList(this.worker.getFinalResult(net)).iterator();
            return iterator;
        }
        finally {
            Nd4j.getExecutioner().commit();
            if (batchedIterator instanceof AsyncMultiDataSetIterator) {
                ((AsyncMultiDataSetIterator)batchedIterator).shutdown();
            }
        }
    }

    public ExecuteWorkerMultiDataSetFlatMap(TrainingWorker<R> worker) {
        this.worker = worker;
    }
}

