/*
 * Decompiled with CFR 0.152.
 */
package test.org.apache.spark.sql;

import java.io.Serializable;
import java.util.Arrays;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.KeyValueGroupedDataset;
import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.expressions.javalang.typed;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import scala.Tuple2;
import test.org.apache.spark.sql.JavaDatasetAggregatorSuiteBase;

/**
 * Java API tests for typed aggregation on a {@link KeyValueGroupedDataset}: a custom
 * {@link Aggregator} ({@link IntSumOf}) and the {@code javalang.typed} helper aggregators.
 *
 * <p>Every test aggregates the dataset returned by {@code generateGroupedDataset()} from
 * {@link JavaDatasetAggregatorSuiteBase}. Judging only from the expected results asserted below,
 * key "a" holds two rows whose integer values sum to 3 and key "b" a single row with value 3
 * (NOTE(review): confirm against the base-class fixture).
 */
public class JavaDatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase {

    /**
     * Sums each group with {@link IntSumOf}, first with the default result encoding and then
     * re-encoded explicitly as a {@code (STRING, INT)} tuple; both must yield the same rows.
     */
    @Test
    public void testTypedAggregationAnonClass() {
        KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();

        Dataset<Tuple2<String, Integer>> aggregated = grouped.agg(new IntSumOf().toColumn());
        Assertions.assertEquals(
            Arrays.asList(new Tuple2<>("a", 3), new Tuple2<>("b", 3)),
            aggregated.collectAsList());

        // Applying an explicit tuple encoder to the aggregated result must not change its rows.
        Dataset<Tuple2<String, Integer>> aggregated2 = grouped.agg(new IntSumOf().toColumn())
            .as(Encoders.tuple(Encoders.STRING(), Encoders.INT()));
        Assertions.assertEquals(
            Arrays.asList(new Tuple2<>("a", 3), new Tuple2<>("b", 3)),
            aggregated2.collectAsList());
    }

    /** Averages twice the integer value of each row per group via {@code typed.avg}. */
    @SuppressWarnings("deprecation")
    @Test
    public void testTypedAggregationAverage() {
        KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
        // Implicitly typed lambda: T is inferred from the grouped dataset's value type, so
        // value is a Tuple2<String, Integer> rather than a raw Object.
        Dataset<Tuple2<String, Double>> aggregated =
            grouped.agg(typed.avg(value -> value._2() * 2.0));
        Assertions.assertEquals(
            Arrays.asList(new Tuple2<>("a", 3.0), new Tuple2<>("b", 6.0)),
            aggregated.collectAsList());
    }

    /** Counts the rows of each group via {@code typed.count}. */
    @SuppressWarnings("deprecation")
    @Test
    public void testTypedAggregationCount() {
        KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
        Dataset<Tuple2<String, Long>> aggregated = grouped.agg(typed.count(value -> value));
        Assertions.assertEquals(
            Arrays.asList(new Tuple2<>("a", 2L), new Tuple2<>("b", 1L)),
            aggregated.collectAsList());
    }

    /** Sums each group's integer values as doubles via {@code typed.sum}. */
    @SuppressWarnings("deprecation")
    @Test
    public void testTypedAggregationSumDouble() {
        KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
        Dataset<Tuple2<String, Double>> aggregated =
            grouped.agg(typed.sum(value -> (double) value._2()));
        Assertions.assertEquals(
            Arrays.asList(new Tuple2<>("a", 3.0), new Tuple2<>("b", 3.0)),
            aggregated.collectAsList());
    }

    /** Sums each group's integer values as longs via {@code typed.sumLong}. */
    @SuppressWarnings("deprecation")
    @Test
    public void testTypedAggregationSumLong() {
        KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
        Dataset<Tuple2<String, Long>> aggregated =
            grouped.agg(typed.sumLong(value -> (long) value._2()));
        Assertions.assertEquals(
            Arrays.asList(new Tuple2<>("a", 3L), new Tuple2<>("b", 3L)),
            aggregated.collectAsList());
    }

    /**
     * Aggregator that sums the {@code Integer} component of {@code (String, Integer)} inputs.
     * The intermediate buffer and the final result are both the running {@code Integer} sum.
     */
    static class IntSumOf extends Aggregator<Tuple2<String, Integer>, Integer, Integer> {

        /** Starting buffer value for an empty group. */
        @Override
        public Integer zero() {
            return 0;
        }

        /** Adds the input row's integer component to the running sum. */
        @Override
        public Integer reduce(Integer l, Tuple2<String, Integer> t) {
            return l + t._2();
        }

        /** Combines two partial sums. */
        @Override
        public Integer merge(Integer b1, Integer b2) {
            return b1 + b2;
        }

        /** Returns the accumulated sum unchanged. */
        @Override
        public Integer finish(Integer reduction) {
            return reduction;
        }

        /** Encoder for the intermediate {@code Integer} buffer. */
        @Override
        public Encoder<Integer> bufferEncoder() {
            return Encoders.INT();
        }

        /** Encoder for the final {@code Integer} result. */
        @Override
        public Encoder<Integer> outputEncoder() {
            return Encoders.INT();
        }
    }
}

