package io.trino.operator.aggregation;

import com.google.common.collect.ImmutableList;
import io.trino.block.BlockAssertions;
import io.trino.jmh.Benchmarks;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.BlockBuilderStatus;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Type;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.planner.TestTableScanNodePartitioning;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.tree.QualifiedName;
import java.util.OptionalInt;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OperationsPerInvocation;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.runner.options.WarmupMode;
import org.testng.Assert;
import org.testng.annotations.Test;

/**
 * JMH benchmark for the decimal {@code avg} and {@code sum} aggregation functions.
 *
 * <p>Three phases are measured: feeding raw values through the PARTIAL step ({@link #benchmark}),
 * the PARTIAL step plus evaluation of its per-group output ({@link #benchmarkEvaluateIntermediate}),
 * and feeding that intermediate output through the FINAL step ({@link #benchmarkEvaluateFinal}).
 */
// NOTE(review): the warmup/measurement iteration counts are borrowed from an unrelated test
// constant (TestTableScanNodePartitioning.BUCKET_COUNT); a local literal would decouple them.
@Warmup(iterations = TestTableScanNodePartitioning.BUCKET_COUNT)
@State(Scope.Thread)
@Measurement(iterations = TestTableScanNodePartitioning.BUCKET_COUNT)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@Fork(3)
@BenchmarkMode(Mode.AverageTime)
public class BenchmarkDecimalAggregation {
    // Fixed seed so the group-id assignment is reproducible from run to run.
    private static final Random RANDOM = new Random(633969769);
    // Number of input rows in the values page (and length of the group-id array).
    private static final int ELEMENT_COUNT = 1000000;

    /**
     * Per-thread benchmark inputs. JMH creates one instance for each combination of the
     * {@code @Param} values and calls {@link #setup()} before measuring.
     */
    @State(Scope.Thread)
    public static class BenchmarkData {

        /** Decimal type under test: {@code SHORT} is decimal(14, 3), {@code LONG} is decimal(30, 10). */
        @Param({"SHORT", "LONG"})
        private String type = "SHORT";

        /** Name of the aggregation function to resolve. */
        @Param({"avg", "sum"})
        private String function = "avg";

        /** Number of distinct groups the input rows are spread over. */
        @Param({"10", "1000"})
        private int groupCount = 10;

        /** Null rate passed to the random block generator. */
        @Param({"0.0", "0.05"})
        private float nullRate;

        private AggregatorFactory partialAggregatorFactory;
        private AggregatorFactory finalAggregatorFactory;
        private int[] groupIds;
        private Page values;
        private Page intermediateValues;

        /**
         * Resolves the aggregation, generates the random input page and group ids, and precomputes
         * the PARTIAL step's per-group output used by the FINAL-step benchmark.
         *
         * @throws IllegalArgumentException if {@code type} is neither {@code SHORT} nor {@code LONG}
         */
        @Setup
        public void setup() {
            TestingFunctionResolution functionResolution = new TestingFunctionResolution();
            Type decimalType = resolveDecimalType(this.type);

            initAggregatorFactories(functionResolution, decimalType);
            this.values = new Page(new Block[]{BlockAssertions.createRandomBlockForType(decimalType, ELEMENT_COUNT, this.nullRate)});

            int[] ids = new int[ELEMENT_COUNT];
            for (int i = 0; i < ELEMENT_COUNT; i++) {
                ids[i] = RANDOM.nextInt(this.groupCount);
            }
            this.groupIds = ids;

            this.intermediateValues = new Page(new Block[]{createIntermediateValues(this.partialAggregatorFactory.createGroupedAggregator(), this.groupIds, this.values)});
        }

        /**
         * Maps the {@code type} parameter to a concrete decimal type.
         *
         * @throws IllegalArgumentException for any value other than {@code SHORT} or {@code LONG};
         *         previously such a value left the state half-initialized and failed later with an NPE
         */
        private static DecimalType resolveDecimalType(String type) {
            switch (type) {
                case "SHORT":
                    return DecimalType.createDecimalType(14, 3);
                case "LONG":
                    return DecimalType.createDecimalType(30, 10);
                default:
                    throw new IllegalArgumentException("Unsupported decimal type parameter: " + type);
            }
        }

        /** Resolves {@code function} for {@code type} and builds its PARTIAL and FINAL aggregator factories over channel 0. */
        private void initAggregatorFactories(TestingFunctionResolution functionResolution, Type type) {
            TestingAggregationFunction aggregateFunction = functionResolution.getAggregateFunction(QualifiedName.of(this.function), TypeSignatureProvider.fromTypes(new Type[]{type}));
            this.partialAggregatorFactory = aggregateFunction.createAggregatorFactory(AggregationNode.Step.PARTIAL, ImmutableList.of(0), OptionalInt.empty());
            this.finalAggregatorFactory = aggregateFunction.createAggregatorFactory(AggregationNode.Step.FINAL, ImmutableList.of(0), OptionalInt.empty());
        }

        /** Runs {@code page} through {@code aggregator} and returns one evaluated row per group. */
        private Block createIntermediateValues(GroupedAggregator aggregator, int[] ids, Page page) {
            aggregator.processPage(this.groupCount, ids, page);
            return evaluateAllGroups(aggregator, this.groupCount);
        }

        public AggregatorFactory getPartialAggregatorFactory() {
            return this.partialAggregatorFactory;
        }

        public AggregatorFactory getFinalAggregatorFactory() {
            return this.finalAggregatorFactory;
        }

        public Page getValues() {
            return this.values;
        }

        public int[] getGroupIds() {
            return this.groupIds;
        }

        public int getGroupCount() {
            return this.groupCount;
        }

        public Page getIntermediateValues() {
            return this.intermediateValues;
        }
    }

    /**
     * Evaluates groups {@code 0 .. groupCount-1} of {@code aggregator} into a new block,
     * one position per group, in group-id order.
     */
    private static Block evaluateAllGroups(GroupedAggregator aggregator, int groupCount) {
        BlockBuilder builder = aggregator.getType().createBlockBuilder((BlockBuilderStatus) null, groupCount);
        for (int groupId = 0; groupId < groupCount; groupId++) {
            aggregator.evaluate(groupId, builder);
        }
        return builder.build();
    }

    /** Measures the PARTIAL step consuming the raw values page; reported per input row. */
    @Benchmark
    @OperationsPerInvocation(ELEMENT_COUNT)
    public GroupedAggregator benchmark(BenchmarkData benchmarkData) {
        GroupedAggregator aggregator = benchmarkData.getPartialAggregatorFactory().createGroupedAggregator();
        aggregator.processPage(benchmarkData.getGroupCount(), benchmarkData.getGroupIds(), benchmarkData.getValues());
        // Returned so JMH consumes the result and the work cannot be eliminated.
        return aggregator;
    }

    /** Measures the PARTIAL step plus evaluation of every group's output; reported per input row. */
    @Benchmark
    @OperationsPerInvocation(ELEMENT_COUNT)
    public Block benchmarkEvaluateIntermediate(BenchmarkData benchmarkData) {
        GroupedAggregator aggregator = benchmarkData.getPartialAggregatorFactory().createGroupedAggregator();
        aggregator.processPage(benchmarkData.getGroupCount(), benchmarkData.getGroupIds(), benchmarkData.getValues());
        return evaluateAllGroups(aggregator, benchmarkData.getGroupCount());
    }

    /**
     * Measures the FINAL step over the precomputed intermediate page, then evaluates every group.
     * Reported per invocation (no {@code @OperationsPerInvocation}).
     */
    @Benchmark
    public Block benchmarkEvaluateFinal(BenchmarkData benchmarkData) {
        GroupedAggregator aggregator = benchmarkData.getFinalAggregatorFactory().createGroupedAggregator();
        // The intermediate page is fed twice, so groups receive intermediate state more than once.
        // NOTE(review): intermediate row p is routed to the random groupIds[p], not to group p;
        // fine for timing, but the results are not meaningful aggregates.
        aggregator.processPage(benchmarkData.getGroupCount(), benchmarkData.getGroupIds(), benchmarkData.getIntermediateValues());
        aggregator.processPage(benchmarkData.getGroupCount(), benchmarkData.getGroupIds(), benchmarkData.getIntermediateValues());
        return evaluateAllGroups(aggregator, benchmarkData.getGroupCount());
    }

    /** Smoke test: sets up the default parameters and runs every benchmark once, checking output shapes. */
    @Test
    public void verify() {
        BenchmarkData benchmarkData = new BenchmarkData();
        benchmarkData.setup();
        Assert.assertEquals(benchmarkData.getGroupIds().length, benchmarkData.getValues().getPositionCount());
        Assert.assertEquals(benchmarkData.getIntermediateValues().getPositionCount(), benchmarkData.getGroupCount());

        BenchmarkDecimalAggregation benchmark = new BenchmarkDecimalAggregation();
        benchmark.benchmark(benchmarkData);
        Assert.assertEquals(benchmark.benchmarkEvaluateIntermediate(benchmarkData).getPositionCount(), benchmarkData.getGroupCount());
        Assert.assertEquals(benchmark.benchmarkEvaluateFinal(benchmarkData).getPositionCount(), benchmarkData.getGroupCount());
    }

    public static void main(String[] strArr) throws Exception {
        new BenchmarkDecimalAggregation().verify();
        Benchmarks.benchmark(BenchmarkDecimalAggregation.class, WarmupMode.BULK).run();
    }
}
