package org.datavec.spark.transform;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.apache.commons.collections.map.ListOrderedMap;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.reduce.MapToPairForReducerFunction;

/* loaded from: input_file:org/datavec/spark/transform/Normalization.class */
/**
 * Static helpers that rescale the columns of a Spark {@code Dataset<Row>} (or of DataVec
 * record RDDs, which are converted to a data frame and back through {@link DataFrames}).
 *
 * <p>Two rescalings are offered:
 * <ul>
 *   <li>{@code zeromeanUnitVariance}: {@code (x - mean) / stddev} per column;</li>
 *   <li>{@code normalize}: {@code (x - min) / (max - min) * (d2 - d) + d} per column.</li>
 * </ul>
 * In both cases a zero divisor is replaced by {@code 1.0} so constant columns do not divide by zero.
 */
public class Normalization {

    /**
     * Standardizes every column of the data frame.
     *
     * @param dataset the data frame to transform
     * @return a data frame whose columns have been replaced by {@code (x - mean) / stddev}
     */
    public static Dataset<Row> zeromeanUnitVariance(Dataset<Row> dataset) {
        return zeromeanUnitVariance(dataset, Collections.<String>emptyList());
    }

    /**
     * Standardizes every column of the given records.
     *
     * @param schema  schema describing the records
     * @param javaRDD the records to transform
     * @return the transformed records
     */
    public static JavaRDD<List<Writable>> zeromeanUnitVariance(Schema schema, JavaRDD<List<Writable>> javaRDD) {
        return zeromeanUnitVariance(schema, javaRDD, Collections.<String>emptyList());
    }

    /**
     * Min-max scales every column of the data frame into the range {@code [d, d2]}.
     *
     * @param dataset the data frame to transform
     * @param d       lower bound of the target range
     * @param d2      upper bound of the target range
     * @return the scaled data frame
     */
    public static Dataset<Row> normalize(Dataset<Row> dataset, double d, double d2) {
        return normalize(dataset, d, d2, Collections.<String>emptyList());
    }

    /**
     * Min-max scales every column of the given records into the range {@code [d, d2]}.
     *
     * @param schema  schema describing the records
     * @param javaRDD the records to transform
     * @param d       lower bound of the target range
     * @param d2      upper bound of the target range
     * @return the scaled records
     */
    public static JavaRDD<List<Writable>> normalize(Schema schema, JavaRDD<List<Writable>> javaRDD, double d, double d2) {
        return normalize(schema, javaRDD, d, d2, Collections.<String>emptyList());
    }

    /**
     * Min-max scales every column of the data frame into the range {@code [0, 1]}.
     *
     * @param dataset the data frame to transform
     * @return the scaled data frame
     */
    public static Dataset<Row> normalize(Dataset<Row> dataset) {
        return normalize(dataset, 0.0d, 1.0d, Collections.<String>emptyList());
    }

    /**
     * Min-max scales every column of the given records into the range {@code [0, 1]}.
     *
     * @param schema  schema describing the records
     * @param javaRDD the records to transform
     * @return the scaled records
     */
    public static JavaRDD<List<Writable>> normalize(Schema schema, JavaRDD<List<Writable>> javaRDD) {
        return normalize(schema, javaRDD, 0.0d, 1.0d, Collections.<String>emptyList());
    }

    /**
     * Standardizes the columns of the data frame, leaving the columns named in {@code list} untouched.
     *
     * <p>Each processed column is replaced by {@code (x - mean) / stddev}; a standard deviation of
     * exactly zero is replaced by {@code 1.0}. If no column remains after removing the skipped ones,
     * the data frame is returned unchanged.
     *
     * @param dataset the data frame to transform
     * @param list    names of the columns to skip (handed to {@code List.removeAll})
     * @return the transformed data frame
     */
    public static Dataset<Row> zeromeanUnitVariance(Dataset<Row> dataset, List<String> list) {
        List<String> remaining = DataFrames.toList(dataset.columns());
        remaining.removeAll(list);
        String[] columns = DataFrames.toArray(remaining);
        if (columns.length == 0) {
            // Nothing to standardize; aggregating over zero columns is not possible.
            return dataset;
        }
        // Row 0 holds the standard deviations, row 1 the means, both in the order of `columns`.
        List<Row> stats = stdDevMeanColumns(dataset, columns);
        Row stdDevs = stats.get(0);
        Row means = stats.get(1);
        Dataset<Row> result = dataset;
        for (int i = 0; i < columns.length; i++) {
            String column = columns[i];
            double stdDev = ((Number) stdDevs.get(i)).doubleValue();
            double mean = ((Number) means.get(i)).doubleValue();
            if (stdDev == 0.0d) {
                // Constant column: avoid dividing by zero.
                stdDev = 1.0d;
            }
            result = result.withColumn(column, result.col(column).minus(mean).divide(stdDev));
        }
        return result;
    }

    /**
     * Standardizes the given records, leaving the columns named in {@code list} untouched.
     *
     * @param schema  schema describing the records
     * @param javaRDD the records to transform
     * @param list    names of the columns to skip
     * @return the transformed records
     */
    public static JavaRDD<List<Writable>> zeromeanUnitVariance(Schema schema, JavaRDD<List<Writable>> javaRDD, List<String> list) {
        Dataset<Row> standardized = zeromeanUnitVariance(DataFrames.toDataFrame(schema, javaRDD), list);
        // Raw cast kept from the original; the declared type of toRecords(...) is not visible in this file.
        return (JavaRDD) DataFrames.toRecords(standardized).getSecond();
    }

    /**
     * Standardizes every data column of the given sequences.
     *
     * @param schema  schema describing the sequence steps
     * @param javaRDD the sequences to transform
     * @return the transformed sequences
     */
    public static JavaRDD<List<List<Writable>>> zeroMeanUnitVarianceSequence(Schema schema, JavaRDD<List<List<Writable>>> javaRDD) {
        return zeroMeanUnitVarianceSequence(schema, javaRDD, null);
    }

    /**
     * Standardizes the given sequences. The sequence UUID and sequence index columns added by
     * {@link DataFrames#toDataFrameSequence} are always skipped, in addition to the columns in {@code list}.
     *
     * @param schema  schema describing the sequence steps
     * @param javaRDD the sequences to transform
     * @param list    names of additional columns to skip; {@code null} means none
     * @return the transformed sequences
     */
    public static JavaRDD<List<List<Writable>>> zeroMeanUnitVarianceSequence(Schema schema, JavaRDD<List<List<Writable>>> javaRDD, List<String> list) {
        Dataset<Row> frame = DataFrames.toDataFrameSequence(schema, javaRDD);
        Dataset<Row> standardized = zeromeanUnitVariance(frame, withSequenceColumns(list));
        // Raw cast kept from the original; the declared type of toRecordsSequence(...) is not visible in this file.
        return (JavaRDD) DataFrames.toRecordsSequence(standardized).getSecond();
    }

    /**
     * Computes the minimum and maximum of the named columns.
     *
     * @param dataset the data frame to aggregate
     * @param list    names of the columns to aggregate
     * @return two rows: the minimums followed by the maximums, each in the order of {@code list}
     */
    public static List<Row> minMaxColumns(Dataset<Row> dataset, List<String> list) {
        return minMaxColumns(dataset, list.toArray(new String[0]));
    }

    /**
     * Computes the minimum and maximum of the named columns.
     *
     * @param dataset the data frame to aggregate
     * @param strArr  names of the columns to aggregate
     * @return two rows: the minimums followed by the maximums, each in the order of {@code strArr}
     */
    public static List<Row> minMaxColumns(Dataset<Row> dataset, String... strArr) {
        return aggregate(dataset, strArr, new String[]{"min", "max"});
    }

    /**
     * Computes the standard deviation and mean of the named columns.
     *
     * @param dataset the data frame to aggregate
     * @param list    names of the columns to aggregate
     * @return two rows: the standard deviations followed by the means, each in the order of {@code list}
     */
    public static List<Row> stdDevMeanColumns(Dataset<Row> dataset, List<String> list) {
        return stdDevMeanColumns(dataset, list.toArray(new String[0]));
    }

    /**
     * Computes the standard deviation and mean of the named columns.
     *
     * @param dataset the data frame to aggregate
     * @param strArr  names of the columns to aggregate
     * @return two rows: the standard deviations followed by the means, each in the order of {@code strArr}
     */
    public static List<Row> stdDevMeanColumns(Dataset<Row> dataset, String... strArr) {
        return aggregate(dataset, strArr, new String[]{"stddev", "mean"});
    }

    /**
     * Applies each aggregate function to every named column and collects the results.
     *
     * <p>For each function, the aggregated columns are renamed back to the plain column names
     * (function name and parentheses removed) and selected in the order of {@code strArr}, so
     * the returned rows can be indexed by column position.
     *
     * @param dataset the data frame to aggregate
     * @param strArr  names of the columns to aggregate; must not be empty
     * @param strArr2 aggregate function names understood by {@code Dataset.agg}, e.g. {@code "min"}
     * @return the collected rows of each aggregation, appended in the order of {@code strArr2}
     * @throws IllegalArgumentException if {@code strArr} is empty
     */
    public static List<Row> aggregate(Dataset<Row> dataset, String[] strArr, String[] strArr2) {
        if (strArr.length == 0) {
            throw new IllegalArgumentException("At least one column is required for aggregation");
        }
        List<Row> rows = new ArrayList<>();
        for (String function : strArr2) {
            // ListOrderedMap from commons-collections 3 is not generic, hence the unchecked assignment.
            @SuppressWarnings("unchecked")
            Map<String, String> expressions = new ListOrderedMap();
            for (String column : strArr) {
                expressions.put(column, function);
            }
            Dataset<Row> aggregated = dataset.agg(expressions);

            // Map each aggregated column name to the name with the function wrapper stripped.
            // NOTE(review): GLOBAL_KEY is used as the replacement text; presumably it is the empty string - confirm.
            Map<String, String> renames = new TreeMap<>();
            for (String aggregatedName : aggregated.columns()) {
                boolean minOrMax = aggregatedName.contains("min(") || aggregatedName.contains("max(");
                // NOTE(review): presumably "mean" is reported by Spark under the name "avg" - confirm.
                String wrapper = (!minOrMax && aggregatedName.contains("avg")) ? "avg" : function;
                String plainName = aggregatedName
                        .replace(wrapper, MapToPairForReducerFunction.GLOBAL_KEY)
                        .replaceAll("[()]", MapToPairForReducerFunction.GLOBAL_KEY);
                renames.put(aggregatedName, plainName);
            }

            Dataset<Row> renamed = aggregated;
            for (Map.Entry<String, String> rename : renames.entrySet()) {
                renamed = renamed.withColumnRenamed(rename.getKey(), rename.getValue());
            }
            // Select explicitly so the result columns follow the caller's column order.
            rows.addAll(renamed.select(DataFrames.toColumns(strArr)).collectAsList());
        }
        return rows;
    }

    /**
     * Min-max scales the columns of the data frame into {@code [d, d2]}, leaving the columns
     * named in {@code list} untouched.
     *
     * <p>Each processed column is replaced by {@code (x - min) / (max - min) * (d2 - d) + d}; a
     * range of exactly zero is replaced by {@code 1.0}. If no column remains after removing the
     * skipped ones, the data frame is returned unchanged.
     *
     * @param dataset the data frame to transform
     * @param d       lower bound of the target range
     * @param d2      upper bound of the target range
     * @param list    names of the columns to skip (handed to {@code List.removeAll})
     * @return the scaled data frame
     */
    public static Dataset<Row> normalize(Dataset<Row> dataset, double d, double d2, List<String> list) {
        List<String> remaining = DataFrames.toList(dataset.columns());
        remaining.removeAll(list);
        String[] columns = DataFrames.toArray(remaining);
        if (columns.length == 0) {
            // Nothing to scale; aggregating over zero columns is not possible.
            return dataset;
        }
        // Row 0 holds the minimums, row 1 the maximums, both in the order of `columns`.
        List<Row> stats = minMaxColumns(dataset, columns);
        Row minimums = stats.get(0);
        Row maximums = stats.get(1);
        double targetRange = d2 - d;
        Dataset<Row> result = dataset;
        for (int i = 0; i < columns.length; i++) {
            String column = columns[i];
            double min = ((Number) minimums.get(i)).doubleValue();
            double range = ((Number) maximums.get(i)).doubleValue() - min;
            if (range == 0.0d) {
                // Constant column: avoid dividing by zero.
                range = 1.0d;
            }
            result = result.withColumn(column,
                    result.col(column).minus(min).divide(range).multiply(targetRange).plus(d));
        }
        return result;
    }

    /**
     * Min-max scales the given records into {@code [d, d2]}, leaving the columns named in
     * {@code list} untouched.
     *
     * @param schema  schema describing the records
     * @param javaRDD the records to transform
     * @param d       lower bound of the target range
     * @param d2      upper bound of the target range
     * @param list    names of the columns to skip
     * @return the scaled records
     */
    public static JavaRDD<List<Writable>> normalize(Schema schema, JavaRDD<List<Writable>> javaRDD, double d, double d2, List<String> list) {
        Dataset<Row> scaled = normalize(DataFrames.toDataFrame(schema, javaRDD), d, d2, list);
        // Raw cast kept from the original; the declared type of toRecords(...) is not visible in this file.
        return (JavaRDD) DataFrames.toRecords(scaled).getSecond();
    }

    /**
     * Min-max scales every data column of the given sequences into {@code [0, 1]}.
     *
     * @param schema  schema describing the sequence steps
     * @param javaRDD the sequences to transform
     * @return the scaled sequences
     */
    public static JavaRDD<List<List<Writable>>> normalizeSequence(Schema schema, JavaRDD<List<List<Writable>>> javaRDD) {
        return normalizeSequence(schema, javaRDD, 0.0d, 1.0d);
    }

    /**
     * Min-max scales every data column of the given sequences into {@code [d, d2]}.
     *
     * @param schema  schema describing the sequence steps
     * @param javaRDD the sequences to transform
     * @param d       lower bound of the target range
     * @param d2      upper bound of the target range
     * @return the scaled sequences
     */
    public static JavaRDD<List<List<Writable>>> normalizeSequence(Schema schema, JavaRDD<List<List<Writable>>> javaRDD, double d, double d2) {
        return normalizeSequence(schema, javaRDD, d, d2, null);
    }

    /**
     * Min-max scales the given sequences into {@code [d, d2]}. The sequence UUID and sequence
     * index columns added by {@link DataFrames#toDataFrameSequence} are always skipped, in
     * addition to the columns in {@code list}.
     *
     * @param schema  schema describing the sequence steps
     * @param javaRDD the sequences to transform
     * @param d       lower bound of the target range
     * @param d2      upper bound of the target range
     * @param list    names of additional columns to skip; {@code null} means none
     * @return the scaled sequences
     */
    public static JavaRDD<List<List<Writable>>> normalizeSequence(Schema schema, JavaRDD<List<List<Writable>>> javaRDD, double d, double d2, List<String> list) {
        List<String> skipped = withSequenceColumns(list);
        Dataset<Row> scaled = normalize(DataFrames.toDataFrameSequence(schema, javaRDD), d, d2, skipped);
        // Raw cast kept from the original; the declared type of toRecordsSequence(...) is not visible in this file.
        return (JavaRDD) DataFrames.toRecordsSequence(scaled).getSecond();
    }

    /**
     * Min-max scales the columns of the data frame into {@code [0, 1]}, leaving the columns
     * named in {@code list} untouched.
     *
     * @param dataset the data frame to transform
     * @param list    names of the columns to skip
     * @return the scaled data frame
     */
    public static Dataset<Row> normalize(Dataset<Row> dataset, List<String> list) {
        return normalize(dataset, 0.0d, 1.0d, list);
    }

    /**
     * Min-max scales the given records into {@code [0, 1]}, leaving the columns named in
     * {@code list} untouched.
     *
     * @param schema  schema describing the records
     * @param javaRDD the records to transform
     * @param list    names of the columns to skip
     * @return the scaled records
     */
    public static JavaRDD<List<Writable>> normalize(Schema schema, JavaRDD<List<Writable>> javaRDD, List<String> list) {
        return normalize(schema, javaRDD, 0.0d, 1.0d, list);
    }

    /** Returns a new list holding the caller's skip columns (if any) plus the two sequence bookkeeping columns. */
    private static List<String> withSequenceColumns(List<String> list) {
        List<String> skipped = (list == null) ? new ArrayList<String>() : new ArrayList<>(list);
        skipped.add(DataFrames.SEQUENCE_UUID_COLUMN);
        skipped.add(DataFrames.SEQUENCE_INDEX_COLUMN);
        return skipped;
    }
}
