/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.impl.graph.scoring;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.api.java.function.DoubleFlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ScoreExamplesFunction
implements DoubleFlatMapFunction<Iterator<org.nd4j.linalg.dataset.api.MultiDataSet>> {
    private static final Logger log = LoggerFactory.getLogger(ScoreExamplesFunction.class);
    private final Broadcast<INDArray> params;
    private final Broadcast<String> jsonConfig;
    private final boolean addRegularization;
    private final int batchSize;

    public ScoreExamplesFunction(Broadcast<INDArray> params, Broadcast<String> jsonConfig, boolean addRegularizationTerms, int batchSize) {
        this.params = params;
        this.jsonConfig = jsonConfig;
        this.addRegularization = addRegularizationTerms;
        this.batchSize = batchSize;
    }

    public Iterator<Double> call(Iterator<org.nd4j.linalg.dataset.api.MultiDataSet> iterator) throws Exception {
        if (!iterator.hasNext()) {
            return Collections.emptyIterator();
        }
        ComputationGraph network = new ComputationGraph(ComputationGraphConfiguration.fromJson((String)((String)this.jsonConfig.getValue())));
        network.init();
        INDArray val = ((INDArray)this.params.value()).dup();
        if (val.length() != network.numParams(false)) {
            throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters");
        }
        network.setParams(val);
        ArrayList<Double> ret = new ArrayList<Double>();
        ArrayList<org.nd4j.linalg.dataset.api.MultiDataSet> collect = new ArrayList<org.nd4j.linalg.dataset.api.MultiDataSet>(this.batchSize);
        int totalCount = 0;
        while (iterator.hasNext()) {
            double[] doubleScores;
            collect.clear();
            int nExamples = 0;
            while (iterator.hasNext() && nExamples < this.batchSize) {
                org.nd4j.linalg.dataset.api.MultiDataSet ds = iterator.next();
                long n = ds.getFeatures(0).size(0);
                collect.add(ds);
                nExamples = (int)((long)nExamples + n);
            }
            totalCount += nExamples;
            MultiDataSet data = MultiDataSet.merge(collect);
            INDArray scores = network.scoreExamples((org.nd4j.linalg.dataset.api.MultiDataSet)data, this.addRegularization);
            for (double doubleScore : doubleScores = scores.data().asDouble()) {
                ret.add(doubleScore);
            }
        }
        Nd4j.getExecutioner().commit();
        if (log.isDebugEnabled()) {
            log.debug("Scored {} examples ", (Object)totalCount);
        }
        return ret.iterator();
    }
}

