/*
 * Decompiled with CFR 0.152.
 */
package org.datavec.arrow.recordreader;

import java.io.DataInputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.io.IOUtils;
import org.datavec.api.conf.Configuration;
import org.datavec.api.records.Record;
import org.datavec.api.records.listener.RecordListener;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.metadata.RecordMetaDataIndex;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Writable;
import org.datavec.arrow.ArrowConverter;
import org.datavec.arrow.recordreader.ArrowRecord;
import org.datavec.arrow.recordreader.ArrowWritableRecordBatch;
import org.nd4j.common.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link RecordReader} over Arrow-serialized record batches.
 *
 * <p>Every location produced by the {@link InputSplit} is read fully into memory and decoded with
 * {@link ArrowConverter#readFromBytes(byte[])} into an {@link ArrowWritableRecordBatch}. Records are
 * then served from that batch one at a time ({@link #next()}, {@link #nextRecord()}), or as the whole
 * batch when {@link #next(int)} asks for exactly the size of a freshly loaded batch. The schema of
 * the first batch read is kept in {@code schema}.
 *
 * <p>A batch that is replaced by a newer one is not closed, because {@link #next(int)} may have
 * handed that very batch object to a caller. Only the batch that is current when {@link #close()}
 * is called gets closed.
 */
public class ArrowRecordReader implements RecordReader {
    private static final Logger log = LoggerFactory.getLogger(ArrowRecordReader.class);
    private InputSplit split;
    private Configuration configuration;
    /** Locations of {@link #split} that have not been loaded yet. */
    private Iterator<String> pathsIter;
    /** Index in {@link #currentBatch} of the next record to return. */
    private int currIdx;
    /** Location that {@link #currentBatch} was loaded from. */
    private String currentPath;
    private Schema schema;
    /** Most recently returned record. */
    private List<Writable> recordAllocation = new ArrayList<Writable>();
    private ArrowWritableRecordBatch currentBatch;
    private List<RecordListener> recordListeners;

    /**
     * Prepares the reader to iterate over the locations of {@code split}, discarding any batch
     * loaded from a previous split.
     *
     * @param split source of Arrow-serialized batches, one batch per location
     */
    public void initialize(InputSplit split) {
        this.split = split;
        this.pathsIter = split.locationsPathIterator();
        this.clearBatchState();
    }

    /**
     * Stores {@code conf} (see {@link #getConf()}) and prepares the reader to iterate over
     * {@code split}, discarding any batch loaded from a previous split.
     *
     * @param conf  configuration to retain
     * @param split source of Arrow-serialized batches, one batch per location
     */
    public void initialize(Configuration conf, InputSplit split) {
        this.configuration = conf;
        this.split = split;
        this.pathsIter = split.locationsPathIterator();
        this.clearBatchState();
    }

    /** @return {@code true}: {@link #next(int)} is supported. */
    public boolean batchesSupported() {
        return true;
    }

    /**
     * Returns up to {@code num} records, fewer only if the input runs out.
     *
     * <p>If no records of the current batch have been returned yet and {@code num} equals its size,
     * the batch object itself is returned instead of a copy.
     *
     * @param num maximum number of records to return
     * @return the next records, in order
     */
    public List<List<Writable>> next(int num) {
        if (this.currentBatch == null || this.currIdx >= this.currentBatch.size()) {
            this.loadNextBatch();
        }
        // Whole-batch fast path is only valid when nothing from this batch was handed out yet.
        if (this.currIdx == 0 && num == this.currentBatch.size()) {
            this.currIdx = num;
            return this.currentBatch;
        }
        List<List<Writable>> ret = new ArrayList<List<Writable>>(Math.max(num, 0));
        while (ret.size() < num && this.hasNext()) {
            ret.add(this.next());
        }
        return ret;
    }

    /**
     * Returns the next record, loading the next location's batch when the current one is used up.
     *
     * @return the next record
     * @throws java.util.NoSuchElementException if no locations remain (from the path iterator)
     * @throws IllegalStateException if a batch cannot be read or decoded
     */
    public List<Writable> next() {
        if (this.currentBatch == null || this.currIdx >= this.currentBatch.size()) {
            this.loadNextBatch();
        }
        this.recordAllocation = this.currentBatch.get(this.currIdx++);
        return this.recordAllocation;
    }

    /**
     * Loads the batch from the next location, skipping empty batches while further locations
     * remain.
     */
    private void loadNextBatch() {
        do {
            this.loadBatch(this.pathsIter.next());
        } while (this.currentBatch.size() == 0 && this.pathsIter.hasNext());
    }

    /**
     * Reads and decodes the batch at {@code url} and makes it current, positioned at record 0.
     * The schema is captured from the first batch ever read.
     *
     * @param url location within {@link #split}
     * @throws IllegalStateException wrapping any failure to open, read or decode the location
     */
    private void loadBatch(String url) {
        try (InputStream inputStream = this.split.openInputStreamFor(url)) {
            byte[] arr = IOUtils.toByteArray(inputStream);
            Pair<Schema, ArrowWritableRecordBatch> read = ArrowConverter.readFromBytes(arr);
            if (this.schema == null) {
                this.schema = read.getFirst();
            }
            this.currentBatch = read.getRight();
            this.currIdx = 0;
            this.currentPath = url;
        } catch (Exception e) {
            // Failing loudly: continuing would silently serve stale or duplicated records.
            throw new IllegalStateException("Unable to load Arrow record batch from " + url, e);
        }
    }

    /** Forgets the currently loaded batch so the next read starts from a fresh location. */
    private void clearBatchState() {
        this.currentBatch = null;
        this.currIdx = 0;
        this.currentPath = null;
    }

    /**
     * @return {@code true} if the current batch has unreturned records or more locations remain
     */
    public boolean hasNext() {
        if (this.currentBatch != null && this.currIdx < this.currentBatch.size()) {
            return true;
        }
        return this.pathsIter != null && this.pathsIter.hasNext();
    }

    /** Not supported. */
    public List<String> getLabels() {
        throw new UnsupportedOperationException();
    }

    /**
     * Restarts iteration from the first location of the split, after calling
     * {@link InputSplit#reset()}. Has no effect if the reader was never initialized.
     */
    public void reset() {
        if (this.split != null) {
            this.split.reset();
            this.pathsIter = this.split.locationsPathIterator();
            this.clearBatchState();
        }
    }

    /** @return {@code true}: {@link #reset()} is supported. */
    public boolean resetSupported() {
        return true;
    }

    /** Not supported. */
    public List<Writable> record(URI uri, DataInputStream dataInputStream) {
        throw new UnsupportedOperationException();
    }

    /**
     * Advances by one record (as {@link #next()}) and wraps it as an {@link ArrowRecord} carrying
     * the current batch, the record's index in it and the batch's location.
     */
    public Record nextRecord() {
        this.next();
        return new ArrowRecord(this.currentBatch, this.currIdx - 1, URI.create(this.currentPath));
    }

    /**
     * Loads the single record described by {@code recordMetaData}.
     *
     * <p>This re-initializes the reader on a {@link FileSplit} for the record's file, so any
     * iteration in progress is abandoned.
     *
     * @param recordMetaData must be a {@link RecordMetaDataIndex}
     * @return the record at the given index of the given file
     * @throws IllegalArgumentException if the metadata carries no index or the index is out of range
     */
    public Record loadFromMetaData(RecordMetaData recordMetaData) {
        RecordMetaDataIndex index = ArrowRecordReader.requireIndex(recordMetaData);
        this.openFile(index.getURI());
        return this.recordAt(index.getIndex());
    }

    /**
     * Loads every record described by {@code recordMetaDatas}, reading each distinct file once.
     *
     * <p>This re-initializes the reader on each file in turn, so any iteration in progress is
     * abandoned.
     *
     * @param recordMetaDatas metadata entries; all must be {@link RecordMetaDataIndex}
     * @return the records, in the same order as {@code recordMetaDatas}
     * @throws IllegalArgumentException if any entry carries no index or an index is out of range
     */
    public List<Record> loadFromMetaData(List<RecordMetaData> recordMetaDatas) {
        // Validate everything up front, remembering each entry's position so that results can be
        // returned in input order even though they are loaded grouped by file.
        List<RecordMetaDataIndex> indices = new ArrayList<RecordMetaDataIndex>(recordMetaDatas.size());
        Map<String, List<Integer>> positionsByUri = new LinkedHashMap<String, List<Integer>>();
        for (RecordMetaData recordMetaData : recordMetaDatas) {
            RecordMetaDataIndex index = ArrowRecordReader.requireIndex(recordMetaData);
            String uri = index.getURI().toString();
            List<Integer> positions = positionsByUri.get(uri);
            if (positions == null) {
                positions = new ArrayList<Integer>();
                positionsByUri.put(uri, positions);
            }
            positions.add(indices.size());
            indices.add(index);
        }

        Record[] loaded = new Record[indices.size()];
        for (Map.Entry<String, List<Integer>> entry : positionsByUri.entrySet()) {
            this.openFile(URI.create(entry.getKey()));
            for (int position : entry.getValue()) {
                loaded[position] = this.recordAt(indices.get(position).getIndex());
            }
        }
        return new ArrayList<Record>(Arrays.asList(loaded));
    }

    /** Throws unless {@code recordMetaData} is a {@link RecordMetaDataIndex}; returns it cast. */
    private static RecordMetaDataIndex requireIndex(RecordMetaData recordMetaData) {
        if (!(recordMetaData instanceof RecordMetaDataIndex)) {
            throw new IllegalArgumentException("Unable to load from meta data. No index specified for record");
        }
        return (RecordMetaDataIndex) recordMetaData;
    }

    /** Re-initializes the reader on the single file at {@code uri} and loads its batch. */
    private void openFile(URI uri) {
        this.initialize(new FileSplit(new File(uri)));
        this.loadNextBatch();
    }

    /** Returns the record at {@code index} of the already-loaded current batch. */
    private Record recordAt(long index) {
        int size = this.currentBatch.size();
        if (index < 0 || index >= size) {
            throw new IllegalArgumentException("Record index " + index + " is out of range for "
                    + this.currentPath + " (" + size + " records)");
        }
        this.currIdx = (int) index;
        return this.nextRecord();
    }

    /** @return the registered listeners, or {@code null} if none were ever set */
    public List<RecordListener> getListeners() {
        return this.recordListeners;
    }

    /** Replaces the registered listeners with {@code listeners}. */
    public void setListeners(RecordListener ... listeners) {
        this.recordListeners = new ArrayList<RecordListener>(Arrays.asList(listeners));
    }

    /** Replaces the registered listeners with a copy of {@code listeners}. */
    public void setListeners(Collection<RecordListener> listeners) {
        this.recordListeners = new ArrayList<RecordListener>(listeners);
    }

    /**
     * Closes the current batch, if any, and drops the reference so it is not closed twice.
     * Failures while closing are logged, not thrown.
     */
    public void close() {
        if (this.currentBatch != null) {
            try {
                this.currentBatch.close();
            } catch (IOException e) {
                log.error("Error closing Arrow record batch loaded from " + this.currentPath, e);
            } finally {
                this.currentBatch = null;
            }
        }
    }

    public void setConf(Configuration conf) {
        this.configuration = conf;
    }

    public Configuration getConf() {
        return this.configuration;
    }

    /** @return the batch records are currently served from, or {@code null} if none is loaded */
    public ArrowWritableRecordBatch getCurrentBatch() {
        return this.currentBatch;
    }
}

