package org.deeplearning4j.spark.api.stats;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import org.apache.commons.io.FilenameUtils;
import org.apache.spark.SparkContext;
import org.deeplearning4j.spark.stats.EventStats;
import org.deeplearning4j.spark.stats.StatsUtils;

/**
 * {@link SparkTrainingStats} implementation holding four lists of {@link EventStats}, each
 * addressable by one of the {@code WORKER_FLAT_MAP_*} keys, plus an optional nested
 * {@link SparkTrainingStats} instance.
 *
 * <p>Any key that is not one of the four keys handled here is delegated to the nested stats
 * instance when one is present.
 *
 * <p>Instances are mutable (setters and {@link #addOtherTrainingStats(SparkTrainingStats)}) and
 * perform no synchronization of their own.
 */
public class CommonSparkTrainingStats implements SparkTrainingStats {
    /** Delimiter passed to {@link StatsUtils} when formatting or exporting the event lists. */
    public static final String DEFAULT_DELIMITER = ",";

    public static final String FILENAME_TOTAL_TIME_STATS = "workerFlatMapTotalTimeMs.txt";
    public static final String FILENAME_GET_INITIAL_MODEL_STATS = "workerFlatMapGetInitialModelTimeMs.txt";
    public static final String FILENAME_DATASET_GET_TIME_STATS = "workerFlatMapDataSetGetTimesMs.txt";
    public static final String FILENAME_PROCESS_MINIBATCH_TIME_STATS = "workerFlatMapProcessMiniBatchTimesMs.txt";

    public static final String WORKER_FLAT_MAP_TOTAL_TIME_MS = "WorkerFlatMapTotalTimeMs";
    public static final String WORKER_FLAT_MAP_GET_INITIAL_MODEL_TIME_MS = "WorkerFlatMapGetInitialModelTimeMs";
    public static final String WORKER_FLAT_MAP_DATA_SET_GET_TIMES_MS = "WorkerFlatMapDataSetGetTimesMs";
    public static final String WORKER_FLAT_MAP_PROCESS_MINI_BATCH_TIMES_MS = "WorkerFlatMapProcessMiniBatchTimesMs";

    /** The keys handled directly by this class, in a fixed iteration order. */
    private static final Set<String> columnNames = Collections.unmodifiableSet(
            new LinkedHashSet<>(Arrays.asList(
                    WORKER_FLAT_MAP_TOTAL_TIME_MS,
                    WORKER_FLAT_MAP_GET_INITIAL_MODEL_TIME_MS,
                    WORKER_FLAT_MAP_DATA_SET_GET_TIMES_MS,
                    WORKER_FLAT_MAP_PROCESS_MINI_BATCH_TIMES_MS)));

    /** Optional nested stats; may be {@code null}. Unknown keys are delegated to it. */
    private SparkTrainingStats trainingWorkerSpecificStats;
    private List<EventStats> workerFlatMapTotalTimeMs;
    private List<EventStats> workerFlatMapGetInitialModelTimeMs;
    private List<EventStats> workerFlatMapDataSetGetTimesMs;
    private List<EventStats> workerFlatMapProcessMiniBatchTimesMs;

    /**
     * Fluent builder for {@link CommonSparkTrainingStats}. Every value is optional; anything left
     * unset stays {@code null} in the built instance.
     */
    public static class Builder {
        private SparkTrainingStats trainingMasterSpecificStats;
        private List<EventStats> workerFlatMapTotalTimeMs;
        private List<EventStats> workerFlatMapGetInitialModelTimeMs;
        private List<EventStats> workerFlatMapDataSetGetTimesMs;
        private List<EventStats> workerFlatMapProcessMiniBatchTimesMs;

        /**
         * Sets the nested stats instance. Despite the "master" in the method name, the value is
         * stored in the built object's {@code trainingWorkerSpecificStats} field.
         *
         * @param sparkTrainingStats nested stats, may be {@code null}
         * @return this builder
         */
        public Builder trainingMasterSpecificStats(SparkTrainingStats sparkTrainingStats) {
            this.trainingMasterSpecificStats = sparkTrainingStats;
            return this;
        }

        /**
         * @param list events returned for {@link #WORKER_FLAT_MAP_TOTAL_TIME_MS}
         * @return this builder
         */
        public Builder workerFlatMapTotalTimeMs(List<EventStats> list) {
            this.workerFlatMapTotalTimeMs = list;
            return this;
        }

        /**
         * @param list events returned for {@link #WORKER_FLAT_MAP_GET_INITIAL_MODEL_TIME_MS}
         * @return this builder
         */
        public Builder workerFlatMapGetInitialModelTimeMs(List<EventStats> list) {
            this.workerFlatMapGetInitialModelTimeMs = list;
            return this;
        }

        /**
         * @param list events returned for {@link #WORKER_FLAT_MAP_DATA_SET_GET_TIMES_MS}
         * @return this builder
         */
        public Builder workerFlatMapDataSetGetTimesMs(List<EventStats> list) {
            this.workerFlatMapDataSetGetTimesMs = list;
            return this;
        }

        /**
         * @param list events returned for {@link #WORKER_FLAT_MAP_PROCESS_MINI_BATCH_TIMES_MS}
         * @return this builder
         */
        public Builder workerFlatMapProcessMiniBatchTimesMs(List<EventStats> list) {
            this.workerFlatMapProcessMiniBatchTimesMs = list;
            return this;
        }

        /**
         * @return a new {@link CommonSparkTrainingStats} holding the values set on this builder
         */
        public CommonSparkTrainingStats build() {
            return new CommonSparkTrainingStats(this);
        }
    }

    /** Creates an instance with all fields {@code null}; populate it through the setters. */
    public CommonSparkTrainingStats() {
    }

    private CommonSparkTrainingStats(Builder builder) {
        this.trainingWorkerSpecificStats = builder.trainingMasterSpecificStats;
        this.workerFlatMapTotalTimeMs = builder.workerFlatMapTotalTimeMs;
        this.workerFlatMapGetInitialModelTimeMs = builder.workerFlatMapGetInitialModelTimeMs;
        this.workerFlatMapDataSetGetTimesMs = builder.workerFlatMapDataSetGetTimesMs;
        this.workerFlatMapProcessMiniBatchTimesMs = builder.workerFlatMapProcessMiniBatchTimesMs;
    }

    /**
     * @return a new mutable set containing the four keys of this class followed by the keys of
     *         the nested stats instance, if there is one
     */
    @Override
    public Set<String> getKeySet() {
        Set<String> keys = new LinkedHashSet<>(columnNames);
        if (this.trainingWorkerSpecificStats != null) {
            keys.addAll(this.trainingWorkerSpecificStats.getKeySet());
        }
        return keys;
    }

    /**
     * Looks up the event list stored under the given key.
     *
     * @param str the stats key; must not be {@code null}
     * @return the list stored for the key (may be {@code null} if never set), or the nested
     *         stats' answer for keys not handled here
     * @throws IllegalArgumentException if the key is unknown and there is no nested stats instance
     */
    @Override
    public List<EventStats> getValue(String str) {
        switch (str) {
            case WORKER_FLAT_MAP_TOTAL_TIME_MS:
                return this.workerFlatMapTotalTimeMs;
            case WORKER_FLAT_MAP_GET_INITIAL_MODEL_TIME_MS:
                return this.workerFlatMapGetInitialModelTimeMs;
            case WORKER_FLAT_MAP_DATA_SET_GET_TIMES_MS:
                return this.workerFlatMapDataSetGetTimesMs;
            case WORKER_FLAT_MAP_PROCESS_MINI_BATCH_TIMES_MS:
                return this.workerFlatMapProcessMiniBatchTimesMs;
            default:
                if (this.trainingWorkerSpecificStats != null) {
                    return this.trainingWorkerSpecificStats.getValue(str);
                }
                throw new IllegalArgumentException("Unknown key: \"" + str + "\"");
        }
    }

    /**
     * Returns an abbreviated display name for the given key.
     *
     * @param str the stats key; must not be {@code null}
     * @return the short name, or the nested stats' answer for keys not handled here
     * @throws IllegalArgumentException if the key is unknown and there is no nested stats instance
     */
    @Override
    public String getShortNameForKey(String str) {
        switch (str) {
            case WORKER_FLAT_MAP_TOTAL_TIME_MS:
                return "Total";
            case WORKER_FLAT_MAP_GET_INITIAL_MODEL_TIME_MS:
                return "GetInitModel";
            case WORKER_FLAT_MAP_DATA_SET_GET_TIMES_MS:
                return "GetDataSet";
            case WORKER_FLAT_MAP_PROCESS_MINI_BATCH_TIMES_MS:
                return "ProcessBatch";
            default:
                if (this.trainingWorkerSpecificStats != null) {
                    return this.trainingWorkerSpecificStats.getShortNameForKey(str);
                }
                throw new IllegalArgumentException("Unknown key: \"" + str + "\"");
        }
    }

    /**
     * Reports whether the given key is included in plots by default. Of the keys handled by
     * this class, only {@link #WORKER_FLAT_MAP_DATA_SET_GET_TIMES_MS} is.
     *
     * @param str the stats key; must not be {@code null}
     * @return the default for the key; for keys not handled here, the nested stats' answer, or
     *         {@code false} when there is no nested stats instance
     */
    @Override
    public boolean defaultIncludeInPlots(String str) {
        switch (str) {
            case WORKER_FLAT_MAP_TOTAL_TIME_MS:
            case WORKER_FLAT_MAP_GET_INITIAL_MODEL_TIME_MS:
            case WORKER_FLAT_MAP_PROCESS_MINI_BATCH_TIMES_MS:
                return false;
            case WORKER_FLAT_MAP_DATA_SET_GET_TIMES_MS:
                return true;
            default:
                if (this.trainingWorkerSpecificStats != null) {
                    return this.trainingWorkerSpecificStats.defaultIncludeInPlots(str);
                }
                return false;
        }
    }

    /**
     * Appends the events of another {@link CommonSparkTrainingStats} to this instance and merges
     * the nested stats.
     *
     * <p>A {@code null} list on the other side contributes nothing; a {@code null} list on this
     * side is replaced by a copy of the other side's list.
     *
     * @param sparkTrainingStats the stats to merge in
     * @throws IllegalArgumentException if the argument is not a {@link CommonSparkTrainingStats}
     *         (this includes {@code null})
     * @throws IllegalStateException if this instance has no nested stats but the other one does
     */
    @Override
    public void addOtherTrainingStats(SparkTrainingStats sparkTrainingStats) {
        if (!(sparkTrainingStats instanceof CommonSparkTrainingStats)) {
            throw new IllegalArgumentException("Cannot add other training stats: not an instance of CommonSparkTrainingStats");
        }
        CommonSparkTrainingStats other = (CommonSparkTrainingStats) sparkTrainingStats;

        // Reject the incompatible combination up front so a failed merge does not leave this
        // instance with some lists already extended.
        if (this.trainingWorkerSpecificStats == null && other.trainingWorkerSpecificStats != null) {
            throw new IllegalStateException("Cannot merge: training master specific stats is null in one, but not the other");
        }

        this.workerFlatMapTotalTimeMs =
                mergeEvents(this.workerFlatMapTotalTimeMs, other.workerFlatMapTotalTimeMs);
        this.workerFlatMapGetInitialModelTimeMs =
                mergeEvents(this.workerFlatMapGetInitialModelTimeMs, other.workerFlatMapGetInitialModelTimeMs);
        this.workerFlatMapDataSetGetTimesMs =
                mergeEvents(this.workerFlatMapDataSetGetTimesMs, other.workerFlatMapDataSetGetTimesMs);
        this.workerFlatMapProcessMiniBatchTimesMs =
                mergeEvents(this.workerFlatMapProcessMiniBatchTimesMs, other.workerFlatMapProcessMiniBatchTimesMs);

        if (this.trainingWorkerSpecificStats != null) {
            this.trainingWorkerSpecificStats.addOtherTrainingStats(other.trainingWorkerSpecificStats);
        }
    }

    /**
     * Null-tolerant {@code target.addAll(source)}.
     *
     * @return {@code target} with {@code source} appended, or a fresh copy of {@code source}
     *         when {@code target} is {@code null}
     */
    private static List<EventStats> mergeEvents(List<EventStats> target, List<EventStats> source) {
        if (source == null) {
            return target;
        }
        if (target == null) {
            // Copy rather than alias so later merges do not mutate the other instance's list.
            return new ArrayList<>(source);
        }
        target.addAll(source);
        return target;
    }

    /**
     * @return the nested stats instance, or {@code null} if none is set
     */
    @Override
    public SparkTrainingStats getNestedTrainingStats() {
        return this.trainingWorkerSpecificStats;
    }

    /**
     * Renders one line per key handled by this class: the key formatted with
     * {@link SparkTrainingStats#DEFAULT_PRINT_FORMAT}, followed by the durations of its events or
     * {@code "-"} when the list is {@code null}. The nested stats' own string, if any, is
     * appended afterwards.
     *
     * @return the multi-line summary
     */
    @Override
    public String statsAsString() {
        StringBuilder sb = new StringBuilder();
        appendStatLine(sb, WORKER_FLAT_MAP_TOTAL_TIME_MS, this.workerFlatMapTotalTimeMs);
        appendStatLine(sb, WORKER_FLAT_MAP_GET_INITIAL_MODEL_TIME_MS, this.workerFlatMapGetInitialModelTimeMs);
        appendStatLine(sb, WORKER_FLAT_MAP_DATA_SET_GET_TIMES_MS, this.workerFlatMapDataSetGetTimesMs);
        appendStatLine(sb, WORKER_FLAT_MAP_PROCESS_MINI_BATCH_TIMES_MS, this.workerFlatMapProcessMiniBatchTimesMs);
        if (this.trainingWorkerSpecificStats != null) {
            sb.append(this.trainingWorkerSpecificStats.statsAsString()).append("\n");
        }
        return sb.toString();
    }

    /** Appends the formatted key and the durations of {@code events} (or "-") plus a newline. */
    private static void appendStatLine(StringBuilder sb, String key, List<EventStats> events) {
        sb.append(String.format(SparkTrainingStats.DEFAULT_PRINT_FORMAT, key));
        if (events == null) {
            sb.append("-\n");
        } else {
            sb.append(StatsUtils.getDurationAsString(events, DEFAULT_DELIMITER)).append("\n");
        }
    }

    /**
     * Exports each of the four event lists to its own file under the given directory via
     * {@link StatsUtils#exportStats}, then asks the nested stats (if any) to export as well.
     *
     * @param str          output directory the {@code FILENAME_*} names are resolved against
     * @param sparkContext Spark context handed through to {@link StatsUtils}
     * @throws IOException if an export fails
     */
    @Override
    public void exportStatFiles(String str, SparkContext sparkContext) throws IOException {
        // NOTE(review): lists are passed through as-is, including null; whether
        // StatsUtils.exportStats accepts null is not visible from this file - confirm.
        StatsUtils.exportStats(this.workerFlatMapTotalTimeMs,
                FilenameUtils.concat(str, FILENAME_TOTAL_TIME_STATS), DEFAULT_DELIMITER, sparkContext);
        StatsUtils.exportStats(this.workerFlatMapGetInitialModelTimeMs,
                FilenameUtils.concat(str, FILENAME_GET_INITIAL_MODEL_STATS), DEFAULT_DELIMITER, sparkContext);
        StatsUtils.exportStats(this.workerFlatMapDataSetGetTimesMs,
                FilenameUtils.concat(str, FILENAME_DATASET_GET_TIME_STATS), DEFAULT_DELIMITER, sparkContext);
        StatsUtils.exportStats(this.workerFlatMapProcessMiniBatchTimesMs,
                FilenameUtils.concat(str, FILENAME_PROCESS_MINIBATCH_TIME_STATS), DEFAULT_DELIMITER, sparkContext);
        if (this.trainingWorkerSpecificStats != null) {
            this.trainingWorkerSpecificStats.exportStatFiles(str, sparkContext);
        }
    }

    public SparkTrainingStats getTrainingWorkerSpecificStats() {
        return this.trainingWorkerSpecificStats;
    }

    public List<EventStats> getWorkerFlatMapTotalTimeMs() {
        return this.workerFlatMapTotalTimeMs;
    }

    public List<EventStats> getWorkerFlatMapGetInitialModelTimeMs() {
        return this.workerFlatMapGetInitialModelTimeMs;
    }

    public List<EventStats> getWorkerFlatMapDataSetGetTimesMs() {
        return this.workerFlatMapDataSetGetTimesMs;
    }

    public List<EventStats> getWorkerFlatMapProcessMiniBatchTimesMs() {
        return this.workerFlatMapProcessMiniBatchTimesMs;
    }

    public void setTrainingWorkerSpecificStats(SparkTrainingStats sparkTrainingStats) {
        this.trainingWorkerSpecificStats = sparkTrainingStats;
    }

    public void setWorkerFlatMapTotalTimeMs(List<EventStats> list) {
        this.workerFlatMapTotalTimeMs = list;
    }

    public void setWorkerFlatMapGetInitialModelTimeMs(List<EventStats> list) {
        this.workerFlatMapGetInitialModelTimeMs = list;
    }

    public void setWorkerFlatMapDataSetGetTimesMs(List<EventStats> list) {
        this.workerFlatMapDataSetGetTimesMs = list;
    }

    public void setWorkerFlatMapProcessMiniBatchTimesMs(List<EventStats> list) {
        this.workerFlatMapProcessMiniBatchTimesMs = list;
    }

    /**
     * Two instances are equal when the nested stats and all four event lists are equal (with
     * {@code null} equal only to {@code null}) and the other object agrees via
     * {@link #canEqual(Object)}.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof CommonSparkTrainingStats)) {
            return false;
        }
        CommonSparkTrainingStats other = (CommonSparkTrainingStats) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        return Objects.equals(getTrainingWorkerSpecificStats(), other.getTrainingWorkerSpecificStats())
                && Objects.equals(getWorkerFlatMapTotalTimeMs(), other.getWorkerFlatMapTotalTimeMs())
                && Objects.equals(getWorkerFlatMapGetInitialModelTimeMs(), other.getWorkerFlatMapGetInitialModelTimeMs())
                && Objects.equals(getWorkerFlatMapDataSetGetTimesMs(), other.getWorkerFlatMapDataSetGetTimesMs())
                && Objects.equals(getWorkerFlatMapProcessMiniBatchTimesMs(), other.getWorkerFlatMapProcessMiniBatchTimesMs());
    }

    /**
     * Hook that lets a subclass refuse equality with instances of this class.
     *
     * @param obj the object attempting the comparison
     * @return {@code true} if {@code obj} is a {@link CommonSparkTrainingStats}
     */
    protected boolean canEqual(Object obj) {
        return obj instanceof CommonSparkTrainingStats;
    }

    /** Combines the same five fields as {@link #equals(Object)}, multiplying by 59 per field. */
    @Override
    public int hashCode() {
        int result = 1;
        result = result * 59 + hashOf(getTrainingWorkerSpecificStats());
        result = result * 59 + hashOf(getWorkerFlatMapTotalTimeMs());
        result = result * 59 + hashOf(getWorkerFlatMapGetInitialModelTimeMs());
        result = result * 59 + hashOf(getWorkerFlatMapDataSetGetTimesMs());
        result = result * 59 + hashOf(getWorkerFlatMapProcessMiniBatchTimesMs());
        return result;
    }

    /** Hash of a possibly-null field; {@code null} contributes the constant 43. */
    private static int hashOf(Object value) {
        return value == null ? 43 : value.hashCode();
    }

    @Override
    public String toString() {
        return "CommonSparkTrainingStats(trainingWorkerSpecificStats=" + getTrainingWorkerSpecificStats()
                + ", workerFlatMapTotalTimeMs=" + getWorkerFlatMapTotalTimeMs()
                + ", workerFlatMapGetInitialModelTimeMs=" + getWorkerFlatMapGetInitialModelTimeMs()
                + ", workerFlatMapDataSetGetTimesMs=" + getWorkerFlatMapDataSetGetTimesMs()
                + ", workerFlatMapProcessMiniBatchTimesMs=" + getWorkerFlatMapProcessMiniBatchTimesMs()
                + ")";
    }
}
