package org.deeplearning4j.spark.parameterserver.functions;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.attribute.FileAttribute;
import java.util.Collections;
import java.util.Iterator;
import org.apache.commons.io.LineIterator;
import org.apache.hadoop.conf.Configuration;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.datavec.spark.util.SerializableHadoopConfig;
import org.deeplearning4j.core.loader.DataSetLoader;
import org.deeplearning4j.spark.api.TrainingResult;
import org.deeplearning4j.spark.api.TrainingWorker;
import org.deeplearning4j.spark.iterator.PathSparkDataSetIterator;
import org.deeplearning4j.spark.parameterserver.pw.SharedTrainingWrapper;
import org.deeplearning4j.spark.parameterserver.training.SharedTrainingWorker;

/* loaded from: input_file:org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapPaths.class */
public class SharedFlatMapPaths<R extends TrainingResult> implements FlatMapFunction<Iterator<String>, R> {
    public static Configuration defaultConfig;
    protected final SharedTrainingWorker worker;
    protected final DataSetLoader loader;
    protected final Broadcast<SerializableHadoopConfig> hadoopConfig;

    public static File toTempFile(Iterator<String> it) throws IOException {
        File file = Files.createTempFile("SharedFlatMapPaths", ".txt", new FileAttribute[0]).toFile();
        file.deleteOnExit();
        BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(file));
        while (it.hasNext()) {
            try {
                bufferedWriter.write(it.next());
                bufferedWriter.write("\n");
            } catch (Throwable th) {
                try {
                    bufferedWriter.close();
                } catch (Throwable th2) {
                    th.addSuppressed(th2);
                }
                throw th;
            }
        }
        bufferedWriter.close();
        return file;
    }

    public SharedFlatMapPaths(TrainingWorker<R> trainingWorker, DataSetLoader dataSetLoader, Broadcast<SerializableHadoopConfig> broadcast) {
        this.worker = (SharedTrainingWorker) trainingWorker;
        this.loader = dataSetLoader;
        this.hadoopConfig = broadcast;
    }

    public Iterator<R> call(Iterator<String> it) throws Exception {
        if (!it.hasNext()) {
            return Collections.emptyIterator();
        }
        File tempFile = toTempFile(it);
        LineIterator lineIterator = new LineIterator(new FileReader(tempFile));
        try {
            SharedTrainingWrapper.getInstance(this.worker.getInstanceId()).attachDS(new PathSparkDataSetIterator(lineIterator, this.loader, this.hadoopConfig));
            Iterator<R> it2 = Collections.singletonList(SharedTrainingWrapper.getInstance(this.worker.getInstanceId()).run(this.worker)).iterator();
            lineIterator.close();
            tempFile.delete();
            return it2;
        } catch (Throwable th) {
            lineIterator.close();
            tempFile.delete();
            throw th;
        }
    }
}
