/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.impl.multilayer.scoring;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSetUtil;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

public class FeedForwardWithKeyFunction<K>
implements PairFlatMapFunction<Iterator<Tuple2<K, Tuple2<INDArray, INDArray>>>, K, INDArray> {
    protected static Logger log = LoggerFactory.getLogger(FeedForwardWithKeyFunction.class);
    private final Broadcast<INDArray> params;
    private final Broadcast<String> jsonConfig;
    private final int batchSize;

    public FeedForwardWithKeyFunction(Broadcast<INDArray> params, Broadcast<String> jsonConfig, int batchSize) {
        this.params = params;
        this.jsonConfig = jsonConfig;
        this.batchSize = batchSize;
    }

    public Iterator<Tuple2<K, INDArray>> call(Iterator<Tuple2<K, Tuple2<INDArray, INDArray>>> iterator) throws Exception {
        int firstIdx;
        int nextIdx;
        if (!iterator.hasNext()) {
            return Collections.emptyIterator();
        }
        MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson((String)((String)this.jsonConfig.getValue())));
        network.init();
        INDArray val = ((INDArray)this.params.value()).dup();
        if (val.length() != network.numParams(false)) {
            throw new IllegalStateException("Network did not have same number of parameters as the broadcasted set parameters");
        }
        network.setParameters(val);
        ArrayList<INDArray> featuresList = new ArrayList<INDArray>(this.batchSize);
        ArrayList<INDArray> fMaskList = new ArrayList<INDArray>(this.batchSize);
        ArrayList<Object> keyList = new ArrayList<Object>(this.batchSize);
        ArrayList<Integer> origSizeList = new ArrayList<Integer>();
        long[] firstShape = null;
        boolean sizesDiffer = false;
        int tupleCount = 0;
        while (iterator.hasNext()) {
            Tuple2<K, Tuple2<INDArray, INDArray>> t2 = iterator.next();
            if (firstShape == null) {
                firstShape = ((INDArray)((Tuple2)t2._2())._1()).shape();
            } else if (!sizesDiffer) {
                for (int i = 1; i < firstShape.length; ++i) {
                    if (firstShape[i] == ((INDArray)featuresList.get(tupleCount - 1)).size(i)) continue;
                    sizesDiffer = true;
                    break;
                }
            }
            featuresList.add((INDArray)((Tuple2)t2._2())._1());
            fMaskList.add((INDArray)((Tuple2)t2._2())._2());
            keyList.add(t2._1());
            origSizeList.add((int)((INDArray)((Tuple2)t2._2())._1()).size(0));
            ++tupleCount;
        }
        if (tupleCount == 0) {
            return Collections.emptyIterator();
        }
        ArrayList<Tuple2> output = new ArrayList<Tuple2>(tupleCount);
        for (int currentArrayIndex = 0; currentArrayIndex < featuresList.size(); currentArrayIndex += nextIdx - firstIdx) {
            firstIdx = currentArrayIndex;
            int examplesInBatch = 0;
            ArrayList<INDArray> toMerge = new ArrayList<INDArray>();
            ArrayList<INDArray> toMergeMask = new ArrayList<INDArray>();
            firstShape = null;
            for (nextIdx = currentArrayIndex; nextIdx < featuresList.size() && examplesInBatch < this.batchSize; ++nextIdx) {
                if (firstShape == null) {
                    firstShape = ((INDArray)featuresList.get(nextIdx)).shape();
                } else if (sizesDiffer) {
                    boolean breakWhile = false;
                    for (int i = 1; i < firstShape.length; ++i) {
                        if (firstShape[i] == ((INDArray)featuresList.get(nextIdx)).size(i)) continue;
                        breakWhile = true;
                        break;
                    }
                    if (breakWhile) break;
                }
                INDArray f = (INDArray)featuresList.get(nextIdx);
                INDArray fm = (INDArray)fMaskList.get(nextIdx);
                toMerge.add(f);
                toMergeMask.add(fm);
                examplesInBatch = (int)((long)examplesInBatch + f.size(0));
            }
            Pair p = DataSetUtil.mergeFeatures((INDArray[])toMerge.toArray(new INDArray[toMerge.size()]), (INDArray[])toMergeMask.toArray(new INDArray[toMergeMask.size()]));
            INDArray out = network.output((INDArray)p.getFirst(), false, (INDArray)p.getSecond(), null);
            examplesInBatch = 0;
            for (int i = firstIdx; i < nextIdx; ++i) {
                int numExamples = (Integer)origSizeList.get(i);
                INDArray outputSubset = this.getSubset(examplesInBatch, examplesInBatch + numExamples, out);
                examplesInBatch += numExamples;
                output.add(new Tuple2(keyList.get(i), (Object)outputSubset));
            }
        }
        Nd4j.getExecutioner().commit();
        return output.iterator();
    }

    private INDArray getSubset(int exampleStart, int exampleEnd, INDArray from) {
        switch (from.rank()) {
            case 2: {
                return from.get(new INDArrayIndex[]{NDArrayIndex.interval((int)exampleStart, (int)exampleEnd), NDArrayIndex.all()});
            }
            case 3: {
                return from.get(new INDArrayIndex[]{NDArrayIndex.interval((int)exampleStart, (int)exampleEnd), NDArrayIndex.all(), NDArrayIndex.all()});
            }
            case 4: {
                return from.get(new INDArrayIndex[]{NDArrayIndex.interval((int)exampleStart, (int)exampleEnd), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()});
            }
        }
        throw new RuntimeException("Invalid rank: " + from.rank());
    }
}

