package io.trino.operator.aggregation;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Maps;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.spi.Page;
import io.trino.spi.PageBuilder;
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.Type;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.util.StructuralTestUtil;
import java.util.Map;
import java.util.OptionalInt;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.testng.Assert;

/**
 * Tests for the {@code numeric_histogram(bigint, double, double)} aggregation function.
 *
 * <p>Covers:
 * <ul>
 *   <li>equivalence of a SINGLE step with a PARTIAL + FINAL pipeline;</li>
 *   <li>merging of intermediate states;</li>
 *   <li>the result for empty input;</li>
 *   <li>rejection of an invalid bucket count.</li>
 * </ul>
 */
public class TestDoubleHistogramAggregation {
    /** Number of rows generated by {@link #makeInput(int)}. */
    private static final int ROW_COUNT = 100;

    private final TestingAggregationFunction function = new TestingFunctionResolution().getAggregateFunction(
            "numeric_histogram",
            TypeSignatureProvider.fromTypes(new Type[]{BigintType.BIGINT, DoubleType.DOUBLE, DoubleType.DOUBLE}));
    private final Type intermediateType = this.function.getIntermediateType();
    private final Type finalType = this.function.getFinalType();
    private final Page input = makeInput(10);

    /**
     * Verifies that a PARTIAL step followed by a FINAL step produces the same histogram
     * as a SINGLE step over the same input.
     */
    @Test
    public void test() {
        Block expected = aggregateSingleStep(this.input);
        Block intermediate = aggregatePartialStep(this.input);
        Block actual = aggregateFinalStep(intermediate, 1);

        Assertions.assertThat(extractSingleValue(actual)).isEqualTo(extractSingleValue(expected));
    }

    /**
     * Verifies that feeding the same intermediate state into a FINAL step twice produces the
     * single-step histogram with every bucket value doubled.
     */
    @Test
    public void testMerge() {
        Block singleStepResult = aggregateSingleStep(this.input);
        Block intermediate = aggregatePartialStep(this.input);
        Block merged = aggregateFinalStep(intermediate, 2);

        Map<Double, Double> expected = Maps.transformValues(extractSingleValue(singleStepResult), value -> value * 2.0);
        Assertions.assertThat(extractSingleValue(merged)).isEqualTo(expected);
    }

    /** Verifies that aggregating no rows yields a single null result position. */
    @Test
    public void testNull() {
        Block result = AggregationTestUtils.getFinalBlock(this.finalType, getAggregator(AggregationNode.Step.SINGLE));

        Assertions.assertThat(result.getPositionCount()).isEqualTo(1);
        Assertions.assertThat(result.isNull(0)).isTrue();
    }

    /** Verifies that a bucket count of zero is rejected with a {@link TrinoException}. */
    @Test
    public void testBadNumberOfBuckets() {
        Aggregator aggregator = getAggregator(AggregationNode.Step.SINGLE);
        Assertions.assertThatThrownBy(() -> aggregator.processPage(makeInput(0)))
                .isInstanceOf(TrinoException.class)
                .hasMessage("numeric_histogram bucket count must be greater than one");
        // Presumably checks that the aggregator can still be finalized after the rejected page.
        AggregationTestUtils.getFinalBlock(this.finalType, aggregator);
    }

    /** Runs a SINGLE-step aggregation over {@code page} and returns the final result block. */
    private Block aggregateSingleStep(Page page) {
        Aggregator aggregator = getAggregator(AggregationNode.Step.SINGLE);
        aggregator.processPage(page);
        return AggregationTestUtils.getFinalBlock(this.finalType, aggregator);
    }

    /** Runs a PARTIAL-step aggregation over {@code page} and returns the intermediate state block. */
    private Block aggregatePartialStep(Page page) {
        Aggregator aggregator = getAggregator(AggregationNode.Step.PARTIAL);
        aggregator.processPage(page);
        return AggregationTestUtils.getIntermediateBlock(this.intermediateType, aggregator);
    }

    /**
     * Feeds {@code intermediate} into a fresh FINAL-step aggregator {@code times} times and
     * returns the final result block.
     */
    private Block aggregateFinalStep(Block intermediate, int times) {
        Aggregator aggregator = getAggregator(AggregationNode.Step.FINAL);
        for (int pass = 0; pass < times; pass++) {
            aggregator.processPage(new Page(new Block[]{intermediate}));
        }
        return AggregationTestUtils.getFinalBlock(this.finalType, aggregator);
    }

    /**
     * Creates an aggregator for {@code step}.
     *
     * <p>Raw-input steps read input channels 0, 1 and 2. Steps consuming intermediate state
     * read only channel 0.
     */
    private Aggregator getAggregator(AggregationNode.Step step) {
        return this.function.createAggregatorFactory(
                        step,
                        step.isInputRaw() ? ImmutableList.of(0, 1, 2) : ImmutableList.of(0),
                        OptionalInt.empty())
                .createAggregator();
    }

    /** Reads position 0 of {@code block} as a {@code map(double, double)} value. */
    @SuppressWarnings("unchecked") // getObjectValue returns Object; for a map(double, double) type it is assumed to be a Map<Double, Double>
    private static Map<Double, Double> extractSingleValue(Block block) {
        return (Map<Double, Double>) StructuralTestUtil.mapType(DoubleType.DOUBLE, DoubleType.DOUBLE)
                .getObjectValue((ConnectorSession) null, block, 0);
    }

    /**
     * Builds a page of {@link #ROW_COUNT} rows in the form
     * {@code (numberOfBuckets, value, 1.0)}, with {@code value} running over 0..ROW_COUNT-1.
     *
     * @param numberOfBuckets value written to the first (bucket count) column of every row
     */
    private static Page makeInput(int numberOfBuckets) {
        PageBuilder builder = new PageBuilder(ImmutableList.of(BigintType.BIGINT, DoubleType.DOUBLE, DoubleType.DOUBLE));
        for (int value = 0; value < ROW_COUNT; value++) {
            builder.declarePosition();
            BigintType.BIGINT.writeLong(builder.getBlockBuilder(0), numberOfBuckets);
            DoubleType.DOUBLE.writeDouble(builder.getBlockBuilder(1), value);
            // Constant third argument; presumably the per-row weight of numeric_histogram.
            DoubleType.DOUBLE.writeDouble(builder.getBlockBuilder(2), 1.0);
        }
        return builder.build();
    }
}
