/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.impl.graph.scoring;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

/**
 * Spark partition function that runs inference with a {@link ComputationGraph} over keyed
 * feature arrays and returns the network outputs paired with the original keys.
 *
 * <p>The network is rebuilt on each partition from the broadcast JSON configuration and
 * parameters. Feature arrays of consecutive records are concatenated along dimension 0 into
 * mini-batches. Records are added while the batch holds fewer than {@code batchSize} examples,
 * so a batch may overshoot by up to one record. Every batch contains at least one record.
 * Records whose feature arrays differ in any dimension other than 0 are never put in the same
 * batch. After each forward pass, the outputs are sliced back into per-record pieces along
 * dimension 0.
 *
 * @param <K> type of the record keys
 */
public class GraphFeedForwardWithKeyFunction<K>
        implements PairFlatMapFunction<Iterator<Tuple2<K, INDArray[]>>, K, INDArray[]> {
    private static final Logger log = LoggerFactory.getLogger(GraphFeedForwardWithKeyFunction.class);

    private final Broadcast<INDArray> params;
    private final Broadcast<String> jsonConfig;
    private final int batchSize;

    /**
     * @param params     broadcast parameter array; its length must equal the network's parameter count
     * @param jsonConfig broadcast JSON form of the {@link ComputationGraphConfiguration}
     * @param batchSize  target number of examples (dimension 0) per forward pass; values
     *                   {@code <= 0} result in one record per forward pass
     */
    public GraphFeedForwardWithKeyFunction(Broadcast<INDArray> params, Broadcast<String> jsonConfig, int batchSize) {
        this.params = params;
        this.jsonConfig = jsonConfig;
        this.batchSize = batchSize;
    }

    /**
     * Runs the network over every record of one partition.
     *
     * @param iterator records of the partition: key plus one feature array per network input
     * @return one {@code (key, outputs)} tuple per input record, in input order; each output
     *         array holds only that record's examples
     * @throws IllegalStateException if the broadcast parameter count does not match the network,
     *                               or a network output has rank below 2
     */
    @Override
    public Iterator<Tuple2<K, INDArray[]>> call(Iterator<Tuple2<K, INDArray[]>> iterator) throws Exception {
        if (!iterator.hasNext()) {
            return Collections.emptyIterator();
        }
        ComputationGraph network = buildNetwork();

        // Drain the partition. Track whether any record's non-batch dimensions differ from the
        // first record's; if none do, per-batch shape checks can be skipped entirely.
        List<INDArray[]> featuresList = new ArrayList<>();
        List<K> keyList = new ArrayList<>();
        long[][] firstShapes = null;
        boolean sizesDiffer = false;
        while (iterator.hasNext()) {
            Tuple2<K, INDArray[]> t2 = iterator.next();
            INDArray[] features = t2._2();
            if (firstShapes == null) {
                firstShapes = shapesOf(features);
            } else if (!sizesDiffer) {
                // Compare the CURRENT record. Comparing the previously added one would never
                // check the last record, letting mismatched shapes reach Nd4j.concat.
                sizesDiffer = !sameNonBatchShapes(firstShapes, features);
            }
            featuresList.add(features);
            keyList.add(t2._1());
        }

        List<Tuple2<K, INDArray[]>> output = new ArrayList<>(featuresList.size());
        int firstIdx = 0;
        while (firstIdx < featuresList.size()) {
            int nextIdx = endOfBatch(featuresList, firstIdx, sizesDiffer);
            INDArray[] batchFeatures = concatBatch(featuresList.subList(firstIdx, nextIdx));
            INDArray[] out = network.output(false, batchFeatures);

            // Slice each batched output back into per-record pieces along dimension 0.
            // The offset is a long so that large batches cannot overflow it.
            long offset = 0;
            for (int i = firstIdx; i < nextIdx; i++) {
                long numExamples = featuresList.get(i)[0].size(0);
                INDArray[] outSubset = new INDArray[out.length];
                for (int j = 0; j < out.length; j++) {
                    outSubset[j] = getSubset(offset, offset + numExamples, out[j]);
                }
                offset += numExamples;
                output.add(new Tuple2<>(keyList.get(i), outSubset));
            }
            firstIdx = nextIdx;
        }
        Nd4j.getExecutioner().commit();
        return output.iterator();
    }

    /** Builds and initializes the network from the broadcast configuration and a copy of the broadcast parameters. */
    private ComputationGraph buildNetwork() {
        ComputationGraph network = new ComputationGraph(ComputationGraphConfiguration.fromJson(jsonConfig.getValue()));
        network.init();
        // Work on a copy rather than the broadcast array itself.
        INDArray val = params.value().dup();
        long expected = network.numParams(false);
        if (val.length() != expected) {
            throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters"
                    + " (network: " + expected + ", broadcast: " + val.length() + ")");
        }
        network.setParams(val);
        return network;
    }

    /**
     * Returns the exclusive end index of the mini-batch starting at {@code start}. The batch
     * always contains the record at {@code start}. Further records are added while the batch
     * holds fewer than {@code batchSize} examples. When {@code checkShapes} is set, the batch
     * also stops at the first record whose non-batch dimensions differ from the first record's.
     */
    private int endOfBatch(List<INDArray[]> featuresList, int start, boolean checkShapes) {
        INDArray[] first = featuresList.get(start);
        long[][] batchShapes = checkShapes ? shapesOf(first) : null;
        long examplesInBatch = first[0].size(0);
        int end = start + 1;
        while (end < featuresList.size() && examplesInBatch < batchSize) {
            INDArray[] features = featuresList.get(end);
            if (checkShapes && !sameNonBatchShapes(batchShapes, features)) {
                break;
            }
            examplesInBatch += features[0].size(0);
            end++;
        }
        return end;
    }

    /** Concatenates, for each network input position, the records' arrays along dimension 0. */
    private static INDArray[] concatBatch(List<INDArray[]> records) {
        int numInputs = records.get(0).length;
        INDArray[] batch = new INDArray[numInputs];
        for (int input = 0; input < numInputs; input++) {
            INDArray[] parts = new INDArray[records.size()];
            for (int r = 0; r < parts.length; r++) {
                parts[r] = records.get(r)[input];
            }
            batch[input] = Nd4j.concat(0, parts);
        }
        return batch;
    }

    /** Captures the shape of every array of one record. */
    private static long[][] shapesOf(INDArray[] features) {
        long[][] shapes = new long[features.length][];
        for (int i = 0; i < features.length; i++) {
            shapes[i] = features[i].shape();
        }
        return shapes;
    }

    /**
     * Returns true if {@code features} has as many arrays as {@code reference}, each of the same
     * rank and matching in every dimension except dimension 0.
     */
    private static boolean sameNonBatchShapes(long[][] reference, INDArray[] features) {
        if (reference.length != features.length) {
            return false;
        }
        for (int i = 0; i < reference.length; i++) {
            if (features[i].rank() != reference[i].length) {
                return false;
            }
            for (int d = 1; d < reference[i].length; d++) {
                if (reference[i][d] != features[i].size(d)) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Returns the view of {@code from} restricted to examples {@code [exampleStart, exampleEnd)}
     * along dimension 0, keeping every other dimension whole. Supports any rank of at least 2.
     *
     * @throws IllegalStateException if {@code from} has rank below 2
     */
    private INDArray getSubset(long exampleStart, long exampleEnd, INDArray from) {
        int rank = from.rank();
        if (rank < 2) {
            throw new IllegalStateException("Invalid rank: " + rank);
        }
        INDArrayIndex[] indices = new INDArrayIndex[rank];
        indices[0] = NDArrayIndex.interval(exampleStart, exampleEnd);
        for (int d = 1; d < rank; d++) {
            // all() returns a fresh index object per call, so one is created per dimension.
            indices[d] = NDArrayIndex.all();
        }
        return from.get(indices);
    }
}

