package org.datavec.spark.transform;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.FloatWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.sparkfunction.SequenceToRows;
import org.datavec.spark.transform.sparkfunction.ToRecord;
import org.datavec.spark.transform.sparkfunction.ToRow;
import org.datavec.spark.transform.sparkfunction.sequence.DataFrameToSequenceCreateCombiner;
import org.datavec.spark.transform.sparkfunction.sequence.DataFrameToSequenceMergeCombiner;
import org.datavec.spark.transform.sparkfunction.sequence.DataFrameToSequenceMergeValue;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/datavec/spark/transform/DataFrames.class */
public class DataFrames {
    /**
     * Name of the string-typed column that {@link #fromSchemaSequence(Schema)} places at
     * index 0 of the struct type it builds.
     */
    public static final String SEQUENCE_UUID_COLUMN = "__SEQ_UUID";
    /**
     * Name of the integer-typed column that {@link #fromSchemaSequence(Schema)} places at
     * index 1 of the struct type it builds.
     */
    public static final String SEQUENCE_INDEX_COLUMN = "__SEQ_IDX";

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: org.datavec.spark.transform.DataFrames$3, reason: invalid class name */
    /* loaded from: input_file:org/datavec/spark/transform/DataFrames$3.class */
    /*
     * Lookup table used by the switch statements in this file that branch on ColumnType.
     * The table is indexed by ColumnType.ordinal() and holds the case number used in those
     * switches: Double -> 1, Integer -> 2, Long -> 3, Float -> 4, String -> 5.
     * Any constant that is not assigned below keeps the int-array default of 0 and therefore
     * reaches the "default" branch of those switches.
     */
    public static /* synthetic */ class AnonymousClass3 {
        // One slot per ColumnType constant; filled in by the static initializer below.
        static final /* synthetic */ int[] $SwitchMap$org$datavec$api$transform$ColumnType = new int[ColumnType.values().length];

        static {
            // Each assignment is wrapped separately and NoSuchFieldError is ignored, so a
            // failure on one constant leaves its slot at 0 without aborting the remaining
            // assignments (presumably to tolerate a ColumnType version lacking that constant).
            try {
                $SwitchMap$org$datavec$api$transform$ColumnType[ColumnType.Double.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$org$datavec$api$transform$ColumnType[ColumnType.Integer.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$org$datavec$api$transform$ColumnType[ColumnType.Long.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$org$datavec$api$transform$ColumnType[ColumnType.Float.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
            try {
                $SwitchMap$org$datavec$api$transform$ColumnType[ColumnType.String.ordinal()] = 5;
            } catch (NoSuchFieldError e5) {
            }
        }
    }

    /**
     * Private no-op constructor: prevents instantiation from outside this class, whose
     * visible API consists of static constants and static methods.
     */
    private DataFrames() {
    }

    /**
     * Builds a column expression that is the square root of the column returned by
     * {@link #var(Dataset, String)} for the same arguments.
     *
     * @param dataset the dataset passed through to {@code var}
     * @param str the column name passed through to {@code var}
     * @return {@code sqrt} applied to the result of {@code var(dataset, str)}
     */
    public static Column std(Dataset<Row> dataset, String str) {
        final Column variance = var(dataset, str);
        return functions.sqrt(variance);
    }

    /**
     * Groups {@code dataset} by the column {@code str}, aggregates with
     * {@code functions.variance(str)}, and returns the column named {@code str} from the
     * aggregated dataset.
     *
     * <p>NOTE(review): the returned column is looked up by the original name {@code str}
     * (the grouping column), not by the name of the aggregate expression - confirm that
     * this is the intended column.
     *
     * @param dataset the dataset to group and aggregate
     * @param str the name of the column used for grouping, aggregation and lookup
     * @return the column named {@code str} of the aggregated dataset
     */
    public static Column var(Dataset<Row> dataset, String str) {
        final Dataset<Row> aggregated = dataset.groupBy(str).agg(functions.variance(str));
        return aggregated.col(str);
    }

    /**
     * Groups {@code dataset} by the column {@code str}, aggregates with
     * {@code functions.min(str)}, and returns the column named {@code str} from the
     * aggregated dataset.
     *
     * <p>NOTE(review): the returned column is looked up by the original name {@code str}
     * (the grouping column), not by the name of the aggregate expression - confirm that
     * this is the intended column.
     *
     * @param dataset the dataset to group and aggregate
     * @param str the name of the column used for grouping, aggregation and lookup
     * @return the column named {@code str} of the aggregated dataset
     */
    public static Column min(Dataset<Row> dataset, String str) {
        final Dataset<Row> aggregated = dataset.groupBy(str).agg(functions.min(str));
        return aggregated.col(str);
    }

    /**
     * Groups {@code dataset} by the column {@code str}, aggregates with
     * {@code functions.max(str)}, and returns the column named {@code str} from the
     * aggregated dataset.
     *
     * <p>NOTE(review): the returned column is looked up by the original name {@code str}
     * (the grouping column), not by the name of the aggregate expression - confirm that
     * this is the intended column.
     *
     * @param dataset the dataset to group and aggregate
     * @param str the name of the column used for grouping, aggregation and lookup
     * @return the column named {@code str} of the aggregated dataset
     */
    public static Column max(Dataset<Row> dataset, String str) {
        final Dataset<Row> aggregated = dataset.groupBy(str).agg(functions.max(str));
        return aggregated.col(str);
    }

    /**
     * Groups {@code dataset} by the column {@code str}, aggregates with
     * {@code functions.avg(str)}, and returns the column named {@code str} from the
     * aggregated dataset.
     *
     * <p>NOTE(review): the returned column is looked up by the original name {@code str}
     * (the grouping column), not by the name of the aggregate expression - confirm that
     * this is the intended column.
     *
     * @param dataset the dataset to group and aggregate
     * @param str the name of the column used for grouping, aggregation and lookup
     * @return the column named {@code str} of the aggregated dataset
     */
    public static Column mean(Dataset<Row> dataset, String str) {
        final Dataset<Row> aggregated = dataset.groupBy(str).agg(functions.avg(str));
        return aggregated.col(str);
    }

    /**
     * Converts a DataVec {@link Schema} into a Spark {@link StructType} with one
     * non-nullable field per schema column, in schema order.
     *
     * <p>Supported column types are Double, Integer, Long and Float; each maps to the Spark
     * data type of the same name and carries empty metadata.
     *
     * @param schema the schema whose columns are converted
     * @return a struct type with {@code schema.numColumns()} fields
     * @throws IllegalStateException if a column has any other type
     */
    public static StructType fromSchema(Schema schema) {
        final int columnCount = schema.numColumns();
        final StructField[] fields = new StructField[columnCount];
        for (int col = 0; col < columnCount; col++) {
            final ColumnType type = (ColumnType) schema.getColumnTypes().get(col);
            switch (type) {
                case Double:
                    fields[col] = new StructField(schema.getName(col), DataTypes.DoubleType, false, Metadata.empty());
                    break;
                case Integer:
                    fields[col] = new StructField(schema.getName(col), DataTypes.IntegerType, false, Metadata.empty());
                    break;
                case Long:
                    fields[col] = new StructField(schema.getName(col), DataTypes.LongType, false, Metadata.empty());
                    break;
                case Float:
                    fields[col] = new StructField(schema.getName(col), DataTypes.FloatType, false, Metadata.empty());
                    break;
                default:
                    throw new IllegalStateException("This api should not be used with strings , binary data or ndarrays. This is only for columnar data");
            }
        }
        return new StructType(fields);
    }

    /**
     * Converts a DataVec {@link Schema} into a Spark {@link StructType} for sequence data.
     *
     * <p>The result starts with two non-nullable bookkeeping fields -
     * {@link #SEQUENCE_UUID_COLUMN} (string) at index 0 and {@link #SEQUENCE_INDEX_COLUMN}
     * (integer) at index 1 - followed by one non-nullable field per schema column. Supported
     * column types are Double, Integer, Long and Float.
     *
     * @param schema the schema whose columns are converted
     * @return a struct type with {@code schema.numColumns() + 2} fields
     * @throws IllegalStateException if a column has any other type
     */
    public static StructType fromSchemaSequence(Schema schema) {
        final int offset = 2;
        final StructField[] fields = new StructField[schema.numColumns() + offset];
        fields[0] = new StructField(SEQUENCE_UUID_COLUMN, DataTypes.StringType, false, Metadata.empty());
        fields[1] = new StructField(SEQUENCE_INDEX_COLUMN, DataTypes.IntegerType, false, Metadata.empty());
        for (int col = 0; col < schema.numColumns(); col++) {
            final ColumnType type = (ColumnType) schema.getColumnTypes().get(col);
            final int target = col + offset;
            switch (type) {
                case Double:
                    fields[target] = new StructField(schema.getName(col), DataTypes.DoubleType, false, Metadata.empty());
                    break;
                case Integer:
                    fields[target] = new StructField(schema.getName(col), DataTypes.IntegerType, false, Metadata.empty());
                    break;
                case Long:
                    fields[target] = new StructField(schema.getName(col), DataTypes.LongType, false, Metadata.empty());
                    break;
                case Float:
                    fields[target] = new StructField(schema.getName(col), DataTypes.FloatType, false, Metadata.empty());
                    break;
                default:
                    throw new IllegalStateException("This api should not be used with strings , binary data or ndarrays. This is only for columnar data");
            }
        }
        return new StructType(fields);
    }

    /**
     * Converts a Spark {@link StructType} into a DataVec {@link Schema}.
     *
     * <p>Each field becomes one schema column with the same name, in field order. The column
     * kind is chosen from the lower-cased Spark type name of the field:
     * {@code double}, {@code float}, {@code long}, {@code int}/{@code integer} and
     * {@code string} map to the Double, Float, Long, Integer and String column kinds.
     *
     * <p>This replaces the decompiler stub that unconditionally threw
     * {@link UnsupportedOperationException}, which made {@link #toRecords(Dataset)} and
     * {@link #toRecordsSequence(Dataset)} unusable.
     *
     * @param r4 the Spark struct type to convert (parameter name kept from the existing signature)
     * @return a schema with one column per field of {@code r4}
     * @throws IllegalArgumentException if a field has a Spark type other than those listed above
     */
    public static org.datavec.api.transform.schema.Schema fromStructType(org.apache.spark.sql.types.StructType r4) {
        final Schema.Builder builder = new Schema.Builder();
        for (StructField field : r4.fields()) {
            final String columnName = field.name();
            // Locale.ROOT keeps the lower-casing independent of the JVM default locale.
            final String typeName = field.dataType().typeName().toLowerCase(java.util.Locale.ROOT);
            switch (typeName) {
                case "double":
                    builder.addColumnDouble(columnName);
                    break;
                case "float":
                    builder.addColumnFloat(columnName);
                    break;
                case "long":
                    builder.addColumnLong(columnName);
                    break;
                case "int":
                case "integer":
                    builder.addColumnInteger(columnName);
                    break;
                case "string":
                    builder.addColumnString(columnName);
                    break;
                default:
                    throw new IllegalArgumentException("Unknown type: " + typeName + " (column \"" + columnName + "\")");
            }
        }
        return builder.build();
    }

    /**
     * Converts a data frame into a DataVec schema plus an RDD of records.
     *
     * <p>The schema is derived from the data frame's struct type via
     * {@link #fromStructType(StructType)}; each row is mapped with a {@link ToRecord}
     * function constructed from that schema.
     *
     * @param dataset the data frame to convert
     * @return a pair of the derived schema and the mapped records
     */
    public static Pair<Schema, JavaRDD<List<Writable>>> toRecords(Dataset<Row> dataset) {
        final Schema schema = fromStructType(dataset.schema());
        final JavaRDD<List<Writable>> records = dataset.javaRDD().map(new ToRecord(schema));
        return new Pair<>(schema, records);
    }

    /**
     * Converts a flattened sequence data frame into a DataVec schema plus an RDD of sequences.
     *
     * <p>Rows are grouped by the string value in column 0, the groups are combined into
     * sequences with the {@code DataFrameToSequence*} functions, and the first two entries
     * of every time step are then dropped.
     *
     * <p>NOTE(review): the dropped entries are presumably the {@link #SEQUENCE_UUID_COLUMN}
     * and {@link #SEQUENCE_INDEX_COLUMN} values (the positions {@link #fromSchemaSequence(Schema)}
     * assigns them), while the returned schema is built from the complete data frame schema -
     * confirm that callers expect this.
     *
     * @param dataset the flattened sequence data frame
     * @return a pair of the schema derived from {@code dataset.schema()} and the sequences
     */
    public static Pair<Schema, JavaRDD<List<List<Writable>>>> toRecordsSequence(Dataset<Row> dataset) {
        // Key every row by the string stored in its first column.
        JavaPairRDD groupBy = dataset.javaRDD().groupBy(new Function<Row, String>() { // from class: org.datavec.spark.transform.DataFrames.1
            public String call(Row row) throws Exception {
                return row.getString(0);
            }
        });
        Schema fromStructType = fromStructType(dataset.schema());
        // Combine each group into a sequence (combiner behavior is defined in the DataFrameToSequence* classes),
        // then strip the two leading entries from every time step.
        return new Pair<>(fromStructType, groupBy.combineByKey(new DataFrameToSequenceCreateCombiner(fromStructType), new DataFrameToSequenceMergeValue(fromStructType), new DataFrameToSequenceMergeCombiner()).values().map(new Function<List<List<Writable>>, List<List<Writable>>>() { // from class: org.datavec.spark.transform.DataFrames.2
            public List<List<Writable>> call(List<List<Writable>> list) throws Exception {
                ArrayList arrayList = new ArrayList(list.size());
                for (List<Writable> list2 : list) {
                    ArrayList arrayList2 = new ArrayList();
                    // Start at index 2: entries 0 and 1 are not copied into the output step.
                    for (int i = 2; i < list2.size(); i++) {
                        arrayList2.add(list2.get(i));
                    }
                    arrayList.add(arrayList2);
                }
                return arrayList;
            }
        }));
    }

    /**
     * Creates a data frame from an RDD of records.
     *
     * <p>A new {@link SQLContext} is built on a {@link JavaSparkContext} wrapping the RDD's
     * context; records are mapped to rows with {@link ToRow} and the struct type comes from
     * {@link #fromSchema(Schema)}.
     *
     * @param schema the schema describing the records
     * @param javaRDD the records to convert
     * @return the resulting data frame
     */
    public static Dataset<Row> toDataFrame(Schema schema, JavaRDD<List<Writable>> javaRDD) {
        final JavaSparkContext sparkContext = new JavaSparkContext(javaRDD.context());
        final SQLContext sqlContext = new SQLContext(sparkContext);
        final JavaRDD<Row> rows = javaRDD.map(new ToRow(schema));
        final StructType structType = fromSchema(schema);
        return sqlContext.createDataFrame(rows, structType);
    }

    /**
     * Creates a data frame from an RDD of sequences.
     *
     * <p>A new {@link SQLContext} is built on a {@link JavaSparkContext} wrapping the RDD's
     * context; sequences are flat-mapped to rows with {@link SequenceToRows} and the struct
     * type comes from {@link #fromSchemaSequence(Schema)}.
     *
     * @param schema the schema describing each time step
     * @param javaRDD the sequences to convert
     * @return the resulting data frame
     */
    public static Dataset<Row> toDataFrameSequence(Schema schema, JavaRDD<List<List<Writable>>> javaRDD) {
        final JavaSparkContext sparkContext = new JavaSparkContext(javaRDD.context());
        final SQLContext sqlContext = new SQLContext(sparkContext);
        final JavaRDD<Row> rows = javaRDD.flatMap(new SequenceToRows(schema));
        final StructType structType = fromSchemaSequence(schema);
        return sqlContext.createDataFrame(rows, structType);
    }

    /**
     * Converts a Spark {@link Row} into a list of writables, one per row position.
     *
     * <p>The writable kind for position {@code i} is chosen from {@code schema.getType(i)}:
     * Double, Integer, Long, Float and String map to {@link DoubleWritable},
     * {@link IntWritable}, {@link LongWritable}, {@link FloatWritable} and {@link Text}.
     *
     * @param schema the schema supplying the column type for each row position
     * @param row the row to convert
     * @return a new list with {@code row.size()} writables, in row order
     * @throws IllegalStateException if a position has any other column type
     */
    public static List<Writable> rowToWritables(Schema schema, Row row) {
        final List<Writable> writables = new ArrayList<>();
        for (int col = 0; col < row.size(); col++) {
            final ColumnType type = schema.getType(col);
            switch (type) {
                case Double:
                    writables.add(new DoubleWritable(row.getDouble(col)));
                    break;
                case Integer:
                    writables.add(new IntWritable(row.getInt(col)));
                    break;
                case Long:
                    writables.add(new LongWritable(row.getLong(col)));
                    break;
                case Float:
                    writables.add(new FloatWritable(row.getFloat(col)));
                    break;
                case String:
                    writables.add(new Text(row.getString(col)));
                    break;
                default:
                    throw new IllegalStateException("Illegal type");
            }
        }
        return writables;
    }

    /**
     * Copies the elements of an array into a new mutable list, preserving order.
     *
     * @param strArr the array to copy; must not be {@code null}
     * @return a new {@link ArrayList} containing the array elements
     */
    public static List<String> toList(String[] strArr) {
        final List<String> values = new ArrayList<>();
        for (int idx = 0; idx < strArr.length; idx++) {
            values.add(strArr[idx]);
        }
        return values;
    }

    /**
     * Copies the elements of a list into a new array, preserving iteration order.
     *
     * <p>Delegates to {@link List#toArray(Object[])} rather than copying with indexed
     * {@code get(i)} calls, which would be quadratic on lists without random access
     * (for example a linked list).
     *
     * @param list the list to copy; must not be {@code null}
     * @return a new array of length {@code list.size()} holding the list elements
     */
    public static String[] toArray(List<String> list) {
        return list.toArray(new String[0]);
    }

    /**
     * Copies a list of rows into a newly created matrix.
     *
     * <p>The array is created with shape {@code [list.size(), list.get(0).size()]} and each
     * cell {@code (r, c)} is filled with {@code list.get(r).getDouble(c)}.
     *
     * @param list the rows to copy; the first row is read to determine the column count
     * @return the populated array
     */
    public static INDArray toMatrix(List<Row> list) {
        final int rowCount = list.size();
        final int columnCount = list.get(0).size();
        final INDArray matrix = Nd4j.create(new int[]{rowCount, columnCount});
        for (int r = 0; r < matrix.rows(); r++) {
            final Row source = list.get(r);
            for (int c = 0; c < matrix.columns(); c++) {
                matrix.putScalar(r, c, source.getDouble(c));
            }
        }
        return matrix;
    }

    /**
     * Converts column names into Spark {@link Column} references via {@code functions.col}.
     *
     * @param list the column names, in the desired order
     * @return a new list holding one column reference per name, in the same order
     */
    public static List<Column> toColumn(List<String> list) {
        final List<Column> columns = new ArrayList<>();
        for (String name : list) {
            columns.add(functions.col(name));
        }
        return columns;
    }

    /**
     * Converts column names into an array of Spark {@link Column} references via
     * {@code functions.col}.
     *
     * @param strArr the column names, in the desired order
     * @return a new array holding one column reference per name, in the same order
     */
    public static Column[] toColumns(String... strArr) {
        final Column[] columns = new Column[strArr.length];
        int next = 0;
        for (String name : strArr) {
            columns[next++] = functions.col(name);
        }
        return columns;
    }
}
