package test.org.apache.spark.sql;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.test.TestSparkSession;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.MapType;
import org.apache.spark.sql.types.StructType;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import scala.collection.JavaConverters;

/**
 * Java-API tests for Spark SQL's higher-order functions: {@code transform}, {@code filter},
 * {@code exists}, {@code forall}, {@code aggregate}, {@code zip_with}, {@code transform_keys},
 * {@code transform_values}, {@code map_filter} and {@code map_zip_with}.
 *
 * <p>Each test runs against one of two single-column fixtures that are rebuilt before every test:
 * <ul>
 *   <li>{@code arrDf}: column {@code x} of type {@code array<int>} with the rows
 *       {@code [1, 9, 8, 7]}, {@code [5, 8, 9, 7, 2]}, {@code []} and {@code null};</li>
 *   <li>{@code mapDf}: column {@code x} of type {@code map<int, int>} with the rows
 *       {@code {1 -> 1, 2 -> 2}} and {@code null}.</li>
 * </ul>
 */
public class JavaHigherOrderFunctionsSuite {
    private transient TestSparkSession spark;
    private Dataset<Row> arrDf;
    private Dataset<Row> mapDf;

    /**
     * Collects {@code actualDs} and compares it row by row, column by column, with {@code expected}.
     *
     * <p>An expected value that is a Java array is compared element-wise against the collected
     * array column. Every other value (including {@code null} and Scala maps) is compared with
     * {@code equals}.
     *
     * @param actualDs dataset to collect
     * @param expected expected rows, in collection order
     */
    private static void checkAnswer(Dataset<Row> actualDs, List<Row> expected) {
        List<Row> actual = actualDs.collectAsList();
        Assert.assertEquals("row count", expected.size(), actual.size());
        for (int i = 0; i < expected.size(); i++) {
            Row expectedRow = expected.get(i);
            Row actualRow = actual.get(i);
            Assert.assertEquals("column count of row " + i, expectedRow.size(), actualRow.size());
            for (int j = 0; j < expectedRow.size(); j++) {
                String where = "row " + i + ", column " + j;
                Object expectedValue = expectedRow.get(j);
                if (expectedValue != null && expectedValue.getClass().isArray()) {
                    // Array columns are collected as Scala sequences; Row#getList exposes them as a
                    // java.util.List without reflecting on a Scala-version-specific accessor.
                    List<Object> actualValues = actualRow.getList(j);
                    Assert.assertNotNull(where, actualValues);
                    Assert.assertArrayEquals(where, (Object[]) expectedValue, actualValues.toArray());
                } else {
                    Assert.assertEquals(where, expectedValue, actualRow.get(j));
                }
            }
        }
    }

    /**
     * Wraps each value in its own single-column {@link Row}.
     *
     * @param values cell values, one per row; {@code null} produces a row holding {@code null}
     * @return one row per value, in order
     */
    @SafeVarargs
    private static <T> List<Row> toRows(T... values) {
        return Arrays.stream(values)
            // Wrap explicitly so an array-typed value becomes a single cell rather than being
            // spread across RowFactory.create's varargs.
            .map(value -> RowFactory.create(new Object[] {value}))
            .collect(Collectors.toList());
    }

    /** Returns its varargs as an array; convenient for writing array-typed cell values inline. */
    @SafeVarargs
    private static <T> T[] makeArray(T... values) {
        return values;
    }

    /**
     * Builds a two-entry {@link HashMap}. Used instead of double-brace initialization, which
     * would create an anonymous subclass that captures the enclosing test instance.
     */
    private static <K, V> HashMap<K, V> mapOf(K k1, V v1, K k2, V v2) {
        HashMap<K, V> map = new HashMap<>();
        map.put(k1, v1);
        map.put(k2, v2);
        return map;
    }

    /** The single input column {@code x} shared by both fixtures. */
    private static Column colX() {
        return functions.col("x");
    }

    /** Builds {@code arrDf}: an {@code array<int>} column with a non-empty, an empty and a null row. */
    private void setUpArrDf() {
        List<Row> data = toRows(
            makeArray(1, 9, 8, 7),
            makeArray(5, 8, 9, 7, 2),
            new Integer[0],
            null);
        StructType schema = new StructType()
            .add("x", new ArrayType(DataTypes.IntegerType, true), true);
        arrDf = spark.createDataFrame(data, schema);
    }

    /** Builds {@code mapDf}: a {@code map<int, int>} column with one two-entry row and one null row. */
    private void setUpMapDf() {
        List<Row> data = toRows(mapOf(1, 1, 2, 2), null);
        StructType schema = new StructType()
            .add("x", new MapType(DataTypes.IntegerType, DataTypes.IntegerType, true));
        mapDf = spark.createDataFrame(data, schema);
    }

    @Before
    public void setUp() {
        spark = new TestSparkSession();
        setUpArrDf();
        setUpMapDf();
    }

    @After
    public void tearDown() {
        spark.stop();
        spark = null;
    }

    @Test
    public void testTransform() throws Exception {
        checkAnswer(
            arrDf.select(functions.transform(colX(), x -> x.plus(1))),
            toRows(makeArray(2, 10, 9, 8), makeArray(6, 9, 10, 8, 3), new Integer[0], null));
        // The two-argument form also receives the element's index.
        checkAnswer(
            arrDf.select(functions.transform(colX(), (x, i) -> x.plus(i))),
            toRows(makeArray(1, 10, 10, 10), makeArray(5, 9, 11, 10, 6), new Integer[0], null));
    }

    @Test
    public void testFilter() throws Exception {
        checkAnswer(
            arrDf.select(functions.filter(colX(), x -> x.plus(1).equalTo(10))),
            toRows(makeArray(9), makeArray(9), new Integer[0], null));
        // The two-argument form also receives the element's index.
        checkAnswer(
            arrDf.select(functions.filter(colX(), (x, i) -> x.plus(i).equalTo(10))),
            toRows(makeArray(9, 8, 7), makeArray(7), new Integer[0], null));
    }

    @Test
    public void testExists() throws Exception {
        checkAnswer(
            arrDf.select(functions.exists(colX(), x -> x.plus(1).equalTo(10))),
            toRows(true, true, false, null));
    }

    @Test
    public void testForall() throws Exception {
        checkAnswer(
            arrDf.select(functions.forall(colX(), x -> x.plus(1).equalTo(10))),
            toRows(false, false, true, null));
    }

    @Test
    public void testAggregate() throws Exception {
        checkAnswer(
            arrDf.select(functions.aggregate(colX(), functions.lit(0), (acc, x) -> acc.plus(x))),
            toRows(25, 31, 0, null));
        // Same fold with an explicit identity finish function.
        checkAnswer(
            arrDf.select(functions.aggregate(
                colX(), functions.lit(0), (acc, x) -> acc.plus(x), acc -> acc)),
            toRows(25, 31, 0, null));
    }

    @Test
    public void testZipWith() throws Exception {
        checkAnswer(
            arrDf.select(functions.zip_with(colX(), colX(), (x, y) -> functions.lit(42))),
            toRows(makeArray(42, 42, 42, 42), makeArray(42, 42, 42, 42, 42), new Integer[0], null));
    }

    @Test
    public void testTransformKeys() throws Exception {
        checkAnswer(
            mapDf.select(functions.transform_keys(colX(), (k, v) -> k.plus(v))),
            toRows(JavaConverters.mapAsScalaMap(mapOf(2, 1, 4, 2)), null));
    }

    @Test
    public void testTransformValues() throws Exception {
        checkAnswer(
            mapDf.select(functions.transform_values(colX(), (k, v) -> k.plus(v))),
            toRows(JavaConverters.mapAsScalaMap(mapOf(1, 2, 2, 4)), null));
    }

    @Test
    public void testMapFilter() throws Exception {
        checkAnswer(
            mapDf.select(functions.map_filter(colX(), (k, v) -> functions.lit(false))),
            toRows(JavaConverters.mapAsScalaMap(new HashMap<Integer, Integer>()), null));
    }

    @Test
    public void testMapZipWith() throws Exception {
        checkAnswer(
            mapDf.select(functions.map_zip_with(colX(), colX(), (k, v1, v2) -> functions.lit(false))),
            toRows(JavaConverters.mapAsScalaMap(mapOf(1, false, 2, false)), null));
    }
}
