/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.datavec.iterator;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.datavec.api.writable.Writable;
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator;
import org.deeplearning4j.spark.datavec.iterator.DataVecRecord;
import org.deeplearning4j.spark.datavec.iterator.DataVecRecords;
import org.deeplearning4j.spark.datavec.iterator.RRMDSIFunction;
import org.deeplearning4j.spark.datavec.iterator.SparkSourceDummyReader;
import org.deeplearning4j.spark.datavec.iterator.SparkSourceDummySeqReader;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import scala.Tuple2;

/**
 * Static helpers that turn Spark RDDs of DataVec records ({@code List<Writable>}) and/or
 * sequences ({@code List<List<Writable>>}) into an RDD of {@link MultiDataSet} by applying an
 * {@link RRMDSIFunction} built from a {@link RecordReaderMultiDataSetIterator}.
 *
 * <p>Every method first validates the iterator: it may not declare more readers than RDDs were
 * supplied, and all of its readers must be {@link SparkSourceDummyReader} /
 * {@link SparkSourceDummySeqReader} instances.
 */
public class IteratorUtils {

    /**
     * Maps a single RDD of (non-sequence) records to MultiDataSets.
     *
     * <p>Each record is wrapped in a {@link DataVecRecords} holding a single-element record list
     * and a {@code null} sequence list before being passed on to
     * {@link #mapRRMDSIRecords(JavaRDD, RecordReaderMultiDataSetIterator)}.
     *
     * @param rdd      records to convert
     * @param iterator iterator used to construct the {@link RRMDSIFunction}; may declare at most
     *                 one record reader and no sequence readers
     * @return RDD of MultiDataSets
     * @throws IllegalStateException if the iterator declares too many readers or a reader of an
     *                               unexpected type
     */
    public static JavaRDD<MultiDataSet> mapRRMDSI(JavaRDD<List<Writable>> rdd, RecordReaderMultiDataSetIterator iterator) {
        checkIterator(iterator, 1, 0);
        JavaRDD<DataVecRecords> wrapped = rdd.map(new Function<List<Writable>, DataVecRecords>() {
            @Override
            public DataVecRecords call(List<Writable> v1) throws Exception {
                return new DataVecRecords(Collections.singletonList(v1), null);
            }
        });
        return mapRRMDSIRecords(wrapped, iterator);
    }

    /**
     * Maps a single RDD of sequences to MultiDataSets.
     *
     * <p>Each sequence is wrapped in a {@link DataVecRecords} holding a {@code null} record list
     * and a single-element sequence list before being passed on to
     * {@link #mapRRMDSIRecords(JavaRDD, RecordReaderMultiDataSetIterator)}.
     *
     * @param rdd      sequences to convert
     * @param iterator iterator used to construct the {@link RRMDSIFunction}; may declare at most
     *                 one sequence reader and no record readers
     * @return RDD of MultiDataSets
     * @throws IllegalStateException if the iterator declares too many readers or a reader of an
     *                               unexpected type
     */
    public static JavaRDD<MultiDataSet> mapRRMDSISeq(JavaRDD<List<List<Writable>>> rdd, RecordReaderMultiDataSetIterator iterator) {
        checkIterator(iterator, 0, 1);
        JavaRDD<DataVecRecords> wrapped = rdd.map(new Function<List<List<Writable>>, DataVecRecords>() {
            @Override
            public DataVecRecords call(List<List<Writable>> v1) throws Exception {
                return new DataVecRecords(null, Collections.singletonList(v1));
            }
        });
        return mapRRMDSIRecords(wrapped, iterator);
    }

    /**
     * Joins several record RDDs and/or sequence RDDs by key and maps the joined rows to
     * MultiDataSets.
     *
     * <p>Each input element is keyed by the {@link Writable} found at the corresponding key
     * column (for sequences: the key column of the first time step). All keyed RDDs are unioned
     * and grouped by key; each group is then combined into one {@link DataVecRecords}, with
     * entries ordered by the index of the RDD they came from.
     *
     * @param rdds              record RDDs, may be {@code null}
     * @param seqRdds           sequence RDDs, may be {@code null}
     * @param rddsKeyColumns    key column index for each entry of {@code rdds}; must have the same
     *                          length as {@code rdds}
     * @param seqRddsKeyColumns key column index for each entry of {@code seqRdds}; must have the
     *                          same length as {@code seqRdds}
     * @param filterMissing     if {@code true}, keys that are not present in every input RDD are
     *                          dropped; if {@code false}, such keys cause an
     *                          {@link IllegalStateException} when the resulting RDD is evaluated
     * @param iterator          iterator used to construct the {@link RRMDSIFunction}
     * @return RDD of MultiDataSets
     * @throws IllegalStateException    if the iterator is invalid for the supplied RDDs, or if an
     *                                  RDD list and its key column array are inconsistent
     * @throws IllegalArgumentException if no RDD at all was supplied
     */
    public static JavaRDD<MultiDataSet> mapRRMDSI(List<JavaRDD<List<Writable>>> rdds, List<JavaRDD<List<List<Writable>>>> seqRdds, int[] rddsKeyColumns, int[] seqRddsKeyColumns, boolean filterMissing, RecordReaderMultiDataSetIterator iterator) {
        checkIterator(iterator, rdds == null ? 0 : rdds.size(), seqRdds == null ? 0 : seqRdds.size());
        assertNullOrSameLength(rdds, rddsKeyColumns, false);
        assertNullOrSameLength(seqRdds, seqRddsKeyColumns, true);
        if ((rdds == null || rdds.isEmpty()) && (seqRdds == null || seqRdds.isEmpty())) {
            throw new IllegalArgumentException("At least one RDD must be provided: both rdds and seqRdds are null or empty");
        }

        // Key every input by its key column and union everything into a single pair RDD.
        // The check above guarantees that at least one of the two loops assigns allPairs.
        JavaPairRDD<Writable, DataVecRecord> allPairs = null;
        if (rdds != null) {
            for (int i = 0; i < rdds.size(); i++) {
                JavaPairRDD<Writable, DataVecRecord> currPairs = rdds.get(i).mapToPair(new MapToPairFn(i, rddsKeyColumns[i]));
                allPairs = allPairs == null ? currPairs : allPairs.union(currPairs);
            }
        }
        if (seqRdds != null) {
            for (int i = 0; i < seqRdds.size(); i++) {
                JavaPairRDD<Writable, DataVecRecord> currPairs = seqRdds.get(i).mapToPair(new MapToPairSeqFn(i, seqRddsKeyColumns[i]));
                allPairs = allPairs == null ? currPairs : allPairs.union(currPairs);
            }
        }

        int expNumRec = rddsKeyColumns == null ? 0 : rddsKeyColumns.length;
        int expNumSeqRec = seqRddsKeyColumns == null ? 0 : seqRddsKeyColumns.length;

        JavaPairRDD<Writable, Iterable<DataVecRecord>> grouped = allPairs.groupByKey();
        if (filterMissing) {
            grouped = grouped.filter(new FilterMissingFn(expNumRec, expNumSeqRec));
        }
        JavaRDD<DataVecRecords> combined = grouped.map(new CombineFunction(expNumRec, expNumSeqRec));
        return mapRRMDSIRecords(combined, iterator);
    }

    /**
     * Verifies that an RDD list and its key column array are either both absent/empty or have
     * the same length.
     *
     * @param list  RDD list, may be {@code null}
     * @param arr   key column array, may be {@code null} only if {@code list} is {@code null}
     * @param isSeq whether the arguments describe sequence RDDs (used for error messages only)
     * @throws IllegalStateException if the two arguments are inconsistent
     */
    private static void assertNullOrSameLength(List<?> list, int[] arr, boolean isSeq) {
        String kind = isSeq ? "sequence" : "non-sequence";
        if (list == null) {
            if (arr != null && arr.length > 0) {
                throw new IllegalStateException("No " + kind + " RDDs were provided, but the key columns array has length " + arr.length);
            }
            return;
        }
        if (arr == null) {
            throw new IllegalStateException("Key columns array is null, but " + list.size() + " " + kind + " RDDs were provided");
        }
        if (list.size() != arr.length) {
            throw new IllegalStateException("Number of " + kind + " RDDs (" + list.size() + ") does not match the number of key columns (" + arr.length + ")");
        }
    }

    /**
     * Maps already-assembled {@link DataVecRecords} to MultiDataSets by applying an
     * {@link RRMDSIFunction} constructed from the given iterator. No validation of the iterator
     * is performed here.
     *
     * @param rdd      records to convert
     * @param iterator iterator passed to the {@link RRMDSIFunction} constructor
     * @return RDD of MultiDataSets
     */
    public static JavaRDD<MultiDataSet> mapRRMDSIRecords(JavaRDD<DataVecRecords> rdd, RecordReaderMultiDataSetIterator iterator) {
        return rdd.map(new RRMDSIFunction(iterator));
    }

    /**
     * Validates that the iterator does not declare more readers than there are RDDs, and that
     * all declared readers are the Spark dummy reader types.
     *
     * @param iterator      iterator to validate
     * @param maxReaders    number of record RDDs supplied by the caller
     * @param maxSeqReaders number of sequence RDDs supplied by the caller
     * @throws IllegalStateException if a limit is exceeded or a reader has an unexpected type
     */
    private static void checkIterator(RecordReaderMultiDataSetIterator iterator, int maxReaders, int maxSeqReaders) {
        Map<?, ?> rrs = iterator.getRecordReaders();
        Map<?, ?> seqRRs = iterator.getSequenceRecordReaders();

        if (rrs != null && rrs.size() > maxReaders) {
            throw new IllegalStateException("Invalid state: iterator has " + rrs.size() + " readers but " + maxReaders + " RDDs of List<Writable> were provided");
        }
        if (seqRRs != null && seqRRs.size() > maxSeqReaders) {
            throw new IllegalStateException("Invalid state: iterator has " + seqRRs.size() + " sequence readers but " + maxSeqReaders + " RDDs of sequences - List<List<Writable>> were provided");
        }

        if (rrs != null) {
            for (Map.Entry<?, ?> e : rrs.entrySet()) {
                Object reader = e.getValue();
                if (reader instanceof SparkSourceDummyReader) {
                    continue;
                }
                // Report the class of the offending reader (the map value), not of its name.
                throw new IllegalStateException("Invalid state: expected SparkSourceDummyReader for reader with name \"" + e.getKey() + "\", but got reader type: " + (reader == null ? null : reader.getClass()));
            }
        }
        if (seqRRs != null) {
            for (Map.Entry<?, ?> e : seqRRs.entrySet()) {
                Object reader = e.getValue();
                if (reader instanceof SparkSourceDummySeqReader) {
                    continue;
                }
                // Report the class of the offending reader (the map value), not of its name.
                throw new IllegalStateException("Invalid state: expected SparkSourceDummySeqReader for sequence reader with name \"" + e.getKey() + "\", but got reader type: " + (reader == null ? null : reader.getClass()));
            }
        }
    }

    /**
     * Filter that keeps a key group only if the number of distinct reader indices seen among its
     * records equals {@code expNumRec} and the number seen among its sequence records equals
     * {@code expNumSeqRec}.
     */
    private static class FilterMissingFn
    implements Function<Tuple2<Writable, Iterable<DataVecRecord>>, Boolean> {
        private final int expNumRec;
        private final int expNumSeqRec;
        // Scratch sets reused between calls on the same thread. Transient, so they are
        // re-created lazily in call() after deserialization.
        private transient ThreadLocal<Set<Integer>> recIdxs;
        private transient ThreadLocal<Set<Integer>> seqRecIdxs;

        private FilterMissingFn(int expNumRec, int expNumSeqRec) {
            this.expNumRec = expNumRec;
            this.expNumSeqRec = expNumSeqRec;
        }

        @Override
        public Boolean call(Tuple2<Writable, Iterable<DataVecRecord>> iter) throws Exception {
            if (this.recIdxs == null) {
                this.recIdxs = new ThreadLocal<Set<Integer>>();
            }
            if (this.seqRecIdxs == null) {
                this.seqRecIdxs = new ThreadLocal<Set<Integer>>();
            }
            Set<Integer> ri = this.recIdxs.get();
            if (ri == null) {
                ri = new HashSet<Integer>();
                this.recIdxs.set(ri);
            }
            Set<Integer> sri = this.seqRecIdxs.get();
            if (sri == null) {
                sri = new HashSet<Integer>();
                this.seqRecIdxs.set(sri);
            }

            try {
                for (DataVecRecord r : iter._2()) {
                    if (r.getRecord() != null) {
                        ri.add(r.getReaderIdx());
                    } else if (r.getSeqRecord() != null) {
                        sri.add(r.getReaderIdx());
                    }
                }
                return ri.size() == this.expNumRec && sri.size() == this.expNumSeqRec;
            } finally {
                // Always reset the scratch sets so a failure cannot leak indices into the
                // next call made on this thread.
                ri.clear();
                sri.clear();
            }
        }

        public FilterMissingFn(int expNumRec, int expNumSeqRec, ThreadLocal<Set<Integer>> recIdxs, ThreadLocal<Set<Integer>> seqRecIdxs) {
            this.expNumRec = expNumRec;
            this.expNumSeqRec = expNumSeqRec;
            this.recIdxs = recIdxs;
            this.seqRecIdxs = seqRecIdxs;
        }
    }

    /**
     * Combines all {@link DataVecRecord}s that share a key into a single {@link DataVecRecords},
     * placing each record / sequence at the position given by its reader index.
     */
    private static class CombineFunction
    implements Function<Tuple2<Writable, Iterable<DataVecRecord>>, DataVecRecords> {
        private final int expNumRecords;
        private final int expNumSeqRecords;

        /**
         * @throws IllegalStateException if, after processing the group, any expected record or
         *                               sequence slot is still empty
         */
        @Override
        public DataVecRecords call(Tuple2<Writable, Iterable<DataVecRecord>> all) throws Exception {
            // Generic arrays cannot be created directly, hence the unchecked raw creation.
            @SuppressWarnings("unchecked")
            List<Writable>[] allRecordsArr = this.expNumRecords > 0 ? new List[this.expNumRecords] : null;
            @SuppressWarnings("unchecked")
            List<List<Writable>>[] allRecordsSeqArr = this.expNumSeqRecords > 0 ? new List[this.expNumSeqRecords] : null;

            for (DataVecRecord rec : all._2()) {
                if (rec.getRecord() != null) {
                    allRecordsArr[rec.getReaderIdx()] = rec.getRecord();
                } else {
                    allRecordsSeqArr[rec.getReaderIdx()] = rec.getSeqRecord();
                }
            }

            if (allRecordsArr != null) {
                for (int i = 0; i < allRecordsArr.length; i++) {
                    if (allRecordsArr[i] == null) {
                        throw new IllegalStateException("Encountered null records for input index " + i);
                    }
                }
            }
            if (allRecordsSeqArr != null) {
                for (int i = 0; i < allRecordsSeqArr.length; i++) {
                    if (allRecordsSeqArr[i] == null) {
                        throw new IllegalStateException("Encountered null sequence records for input index " + i);
                    }
                }
            }

            List<List<Writable>> r = allRecordsArr == null ? null : Arrays.asList(allRecordsArr);
            List<List<List<Writable>>> sr = allRecordsSeqArr == null ? null : Arrays.asList(allRecordsSeqArr);
            return new DataVecRecords(r, sr);
        }

        public CombineFunction(int expNumRecords, int expNumSeqRecords) {
            this.expNumRecords = expNumRecords;
            this.expNumSeqRecords = expNumSeqRecords;
        }
    }

    /**
     * Keys a sequence by the value at {@code keyIndex} of its first time step and tags it with
     * the index of the RDD it came from.
     */
    private static class MapToPairSeqFn
    implements PairFunction<List<List<Writable>>, Writable, DataVecRecord> {
        private final int readerIdx;
        private final int keyIndex;

        /**
         * @throws IllegalStateException if the sequence has no time steps
         */
        @Override
        public Tuple2<Writable, DataVecRecord> call(List<List<Writable>> seq) throws Exception {
            if (seq.isEmpty()) {
                throw new IllegalStateException("Sequence of length 0 encountered");
            }
            Writable key = seq.get(0).get(this.keyIndex);
            return new Tuple2<Writable, DataVecRecord>(key, new DataVecRecord(this.readerIdx, null, seq));
        }

        public MapToPairSeqFn(int readerIdx, int keyIndex) {
            this.readerIdx = readerIdx;
            this.keyIndex = keyIndex;
        }
    }

    /**
     * Keys a record by the value at {@code keyIndex} and tags it with the index of the RDD it
     * came from.
     */
    private static class MapToPairFn
    implements PairFunction<List<Writable>, Writable, DataVecRecord> {
        private final int readerIdx;
        private final int keyIndex;

        @Override
        public Tuple2<Writable, DataVecRecord> call(List<Writable> writables) throws Exception {
            Writable key = writables.get(this.keyIndex);
            return new Tuple2<Writable, DataVecRecord>(key, new DataVecRecord(this.readerIdx, writables, null));
        }

        public MapToPairFn(int readerIdx, int keyIndex) {
            this.readerIdx = readerIdx;
            this.keyIndex = keyIndex;
        }
    }
}

