package io.trino.operator.aggregation;

import io.trino.operator.aggregation.state.LongDecimalWithOverflowState;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.Int128ArrayBlock;
import io.trino.spi.block.Int128ArrayBlockBuilder;
import io.trino.spi.function.AggregationFunction;
import io.trino.spi.function.AggregationState;
import io.trino.spi.function.BlockIndex;
import io.trino.spi.function.BlockPosition;
import io.trino.spi.function.CombineFunction;
import io.trino.spi.function.Description;
import io.trino.spi.function.InputFunction;
import io.trino.spi.function.LiteralParameters;
import io.trino.spi.function.OutputFunction;
import io.trino.spi.function.SqlType;
import io.trino.spi.type.Decimals;
import io.trino.spi.type.Int128;
import io.trino.spi.type.Int128Math;

@AggregationFunction("sum")
@Description("Calculates the sum over the input values")
public final class DecimalSumAggregation {
    private DecimalSumAggregation() {
    }

    /**
     * Adds a decimal value whose native container is a single {@code long} to the running sum.
     *
     * <p>The running sum lives in the state's decimal array as two 64-bit words: the high word at
     * {@code getDecimalArrayOffset()} and the low word in the following slot. The input is
     * sign-extended to 128 bits before being added. A carry out of the 128-bit accumulator is not
     * thrown here; it is accumulated in the state's overflow counter. Carries in opposite directions
     * can therefore cancel out, and only a nonzero net count is reported, in {@link #outputDecimal}.
     *
     * @param state running sum for the current group
     * @param value the input value as a signed {@code long}
     */
    @InputFunction
    @LiteralParameters({"p", "s"})
    public static void inputShortDecimal(
            @AggregationState LongDecimalWithOverflowState state,
            @SqlType("decimal(p,s)") long value) {
        state.setNotNull();

        long[] decimal = state.getDecimalArray();
        int offset = state.getDecimalArrayOffset();

        // Arithmetic shift replicates the sign bit, yielding the high word of the 128-bit value.
        long valueHigh = value >> 63;

        long overflow = Int128Math.addWithOverflow(
                decimal[offset],
                decimal[offset + 1],
                valueHigh,
                value,
                decimal,
                offset);
        state.addOverflow(overflow);
    }

    /**
     * Adds a decimal value stored as a 128-bit entry of an {@link Int128ArrayBlock} to the running sum.
     *
     * <p>Overflow handling is identical to {@link #inputShortDecimal}: carries are accumulated in
     * the state's overflow counter and only checked at output time.
     *
     * @param state running sum for the current group
     * @param block block holding the input values
     * @param position position of the value to add within {@code block}
     */
    @InputFunction
    @LiteralParameters({"p", "s"})
    public static void inputLongDecimal(
            @AggregationState LongDecimalWithOverflowState state,
            @BlockPosition @SqlType(value = "decimal(p,s)", nativeContainerType = Int128.class) Int128ArrayBlock block,
            @BlockIndex int position) {
        state.setNotNull();

        long[] decimal = state.getDecimalArray();
        int offset = state.getDecimalArrayOffset();

        // Offset 0 supplies the high word and offset 8 the low word of the stored value.
        long valueHigh = block.getLong(position, 0);
        long valueLow = block.getLong(position, 8);

        long overflow = Int128Math.addWithOverflow(
                decimal[offset],
                decimal[offset + 1],
                valueHigh,
                valueLow,
                decimal,
                offset);
        state.addOverflow(overflow);
    }

    /**
     * Merges the partial sum in {@code otherState} into {@code state}.
     *
     * <p>If {@code otherState} never received a value, {@code state} is left untouched, so that
     * combining two empty partial states still yields NULL rather than 0. If {@code state} is
     * empty but {@code otherState} is not, the other sum and its overflow count are copied over.
     * Otherwise the two 128-bit sums are added, and both overflow counts are folded into
     * {@code state}.
     *
     * @param state the state that receives the merged result
     * @param otherState the partial state to merge in; it is not modified
     */
    @CombineFunction
    public static void combine(
            @AggregationState LongDecimalWithOverflowState state,
            @AggregationState LongDecimalWithOverflowState otherState) {
        if (!otherState.isNotNull()) {
            // Nothing to merge. outputDecimal emits NULL for a never-set state, so an empty
            // partial state must neither mark this state non-null nor change its value.
            return;
        }

        long[] decimal = state.getDecimalArray();
        int offset = state.getDecimalArrayOffset();

        long[] otherDecimal = otherState.getDecimalArray();
        int otherOffset = otherState.getDecimalArrayOffset();

        if (state.isNotNull()) {
            long overflow = Int128Math.addWithOverflow(
                    decimal[offset],
                    decimal[offset + 1],
                    otherDecimal[otherOffset],
                    otherDecimal[otherOffset + 1],
                    decimal,
                    offset);
            state.addOverflow(Math.addExact(overflow, otherState.getOverflow()));
        } else {
            state.setNotNull();
            decimal[offset] = otherDecimal[otherOffset];
            decimal[offset + 1] = otherDecimal[otherOffset + 1];
            state.setOverflow(otherState.getOverflow());
        }
    }

    /**
     * Writes the final sum to {@code out}, or a NULL if no value was ever added.
     *
     * @param state the final aggregation state
     * @param out destination builder; expected to be an {@link Int128ArrayBlockBuilder} for non-null results
     * @throws ArithmeticException if the net overflow count is nonzero, or if
     *         {@link Decimals#throwIfOverflows} rejects the 128-bit result
     */
    @OutputFunction("decimal(38,s)")
    public static void outputDecimal(@AggregationState LongDecimalWithOverflowState state, BlockBuilder out) {
        if (!state.isNotNull()) {
            out.appendNull();
            return;
        }

        // A nonzero net carry count means the true sum does not fit in 128 bits.
        if (state.getOverflow() != 0) {
            throw new ArithmeticException("Decimal overflow");
        }

        long[] decimal = state.getDecimalArray();
        int offset = state.getDecimalArrayOffset();

        long high = decimal[offset];
        long low = decimal[offset + 1];

        // Even without a carry out of 128 bits, the value may exceed the declared decimal range.
        Decimals.throwIfOverflows(high, low);
        ((Int128ArrayBlockBuilder) out).writeInt128(high, low);
    }
}
