package com.facebook.presto.operator.aggregation;

import com.facebook.presto.bytecode.DynamicClassLoader;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.DecimalType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.TypeSignature;
import com.facebook.presto.common.type.UnscaledDecimal128Arithmetic;
import com.facebook.presto.metadata.BoundVariables;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.SignatureBinder;
import com.facebook.presto.metadata.SqlAggregationFunction;
import com.facebook.presto.operator.aggregation.AggregationMetadata;
import com.facebook.presto.operator.aggregation.state.LongDecimalWithOverflowState;
import com.facebook.presto.operator.aggregation.state.LongDecimalWithOverflowStateFactory;
import com.facebook.presto.operator.aggregation.state.LongDecimalWithOverflowStateSerializer;
import com.facebook.presto.spi.function.FunctionKind;
import com.facebook.presto.util.Reflection;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;

import java.lang.invoke.MethodHandle;
import java.util.List;

/* loaded from: input_file:com/facebook/presto/operator/aggregation/DecimalSumAggregation.class */
/**
 * Implements {@code sum(decimal(p, s)) -> decimal(38, s)}.
 *
 * <p>Short and long decimal inputs are both added into an unscaled-decimal slice (see
 * {@link UnscaledDecimal128Arithmetic}) held by a {@link LongDecimalWithOverflowState}. The value
 * returned by {@code UnscaledDecimal128Arithmetic.addWithOverflow} is accumulated as the state's
 * overflow, and {@link #outputLongDecimal} throws if the final sum overflowed.
 */
public class DecimalSumAggregation extends SqlAggregationFunction {
    private static final String NAME = "sum";

    // Used only as accessors for the raw unscaled value in a block (getLong / getSlice / writeSlice);
    // the scale of the result comes from the bound signature "decimal(38,s)".
    private static final DecimalType LONG_DECIMAL_TYPE = DecimalType.createDecimalType(38, 0);
    private static final DecimalType SHORT_DECIMAL_TYPE = DecimalType.createDecimalType(18, 0);

    // Resolved reflectively by name and parameter types; they must stay in sync with the public
    // static methods declared below.
    private static final MethodHandle SHORT_DECIMAL_INPUT_FUNCTION = Reflection.methodHandle(DecimalSumAggregation.class, "inputShortDecimal", LongDecimalWithOverflowState.class, Block.class, Integer.TYPE);
    private static final MethodHandle LONG_DECIMAL_INPUT_FUNCTION = Reflection.methodHandle(DecimalSumAggregation.class, "inputLongDecimal", LongDecimalWithOverflowState.class, Block.class, Integer.TYPE);
    private static final MethodHandle LONG_DECIMAL_OUTPUT_FUNCTION = Reflection.methodHandle(DecimalSumAggregation.class, "outputLongDecimal", LongDecimalWithOverflowState.class, BlockBuilder.class);
    private static final MethodHandle COMBINE_FUNCTION = Reflection.methodHandle(DecimalSumAggregation.class, "combine", LongDecimalWithOverflowState.class, LongDecimalWithOverflowState.class);

    // Declared after the method handles so that every static field is initialized before the
    // singleton is constructed.
    public static final DecimalSumAggregation DECIMAL_SUM_AGGREGATION = new DecimalSumAggregation();

    /**
     * Registers {@code sum} with signature {@code decimal(p,s) -> decimal(38,s)}.
     */
    public DecimalSumAggregation() {
        super(
                NAME,
                ImmutableList.of(),
                ImmutableList.of(),
                TypeSignature.parseTypeSignature("decimal(38,s)", ImmutableSet.of("s")),
                ImmutableList.of(TypeSignature.parseTypeSignature("decimal(p,s)", ImmutableSet.of("p", "s"))),
                FunctionKind.AGGREGATE);
    }

    @Override
    public String getDescription() {
        return "Calculates the sum over the input values";
    }

    /**
     * Binds the signature's type variables and builds the concrete aggregation.
     *
     * @param boundVariables values bound to {@code p} and {@code s}
     * @param arity number of arguments (unused; the signature has exactly one)
     * @param functionAndTypeManager used to resolve the bound type signatures
     * @return the compiled aggregation for the bound input and output types
     */
    @Override // com.facebook.presto.metadata.SqlAggregationFunction
    public InternalAggregationFunction specialize(BoundVariables boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager) {
        Type inputType = Iterables.getOnlyElement(SignatureBinder.applyBoundVariables(functionAndTypeManager, getSignature().getArgumentTypes(), boundVariables));
        Type outputType = SignatureBinder.applyBoundVariables(functionAndTypeManager, getSignature().getReturnType(), boundVariables);
        return generateAggregation(inputType, outputType);
    }

    /**
     * Builds the aggregation metadata and compiles an accumulator for one concrete decimal type.
     *
     * @param inputType the bound argument type; must be a {@link DecimalType}
     * @param outputType the bound return type
     * @throws IllegalArgumentException if {@code inputType} is not a decimal type
     */
    private static InternalAggregationFunction generateAggregation(Type inputType, Type outputType) {
        Preconditions.checkArgument(inputType instanceof DecimalType, "type must be Decimal");
        DynamicClassLoader classLoader = new DynamicClassLoader(DecimalSumAggregation.class.getClassLoader());
        List<Type> inputTypes = ImmutableList.of(inputType);
        LongDecimalWithOverflowStateSerializer stateSerializer = new LongDecimalWithOverflowStateSerializer();

        // Short decimals are read as a long, long decimals as a slice; both feed the same state.
        MethodHandle inputFunction = ((DecimalType) inputType).isShort()
                ? SHORT_DECIMAL_INPUT_FUNCTION
                : LONG_DECIMAL_INPUT_FUNCTION;

        List<TypeSignature> inputSignatures = inputTypes.stream()
                .map(Type::getTypeSignature)
                .collect(ImmutableList.toImmutableList());

        AggregationMetadata metadata = new AggregationMetadata(
                AggregationUtils.generateAggregationName(NAME, outputType.getTypeSignature(), inputSignatures),
                createInputParameterMetadata(inputType),
                inputFunction,
                COMBINE_FUNCTION,
                LONG_DECIMAL_OUTPUT_FUNCTION,
                ImmutableList.of(new AggregationMetadata.AccumulatorStateDescriptor(
                        LongDecimalWithOverflowState.class,
                        stateSerializer,
                        new LongDecimalWithOverflowStateFactory())),
                outputType);

        return new InternalAggregationFunction(
                NAME,
                inputTypes,
                ImmutableList.of(stateSerializer.getSerializedType()),
                outputType,
                true,
                false,
                AccumulatorCompiler.generateAccumulatorFactoryBinder(metadata, classLoader));
    }

    /**
     * Describes the input function's parameters: the state, the value block for {@code type},
     * and the position within that block.
     */
    private static List<AggregationMetadata.ParameterMetadata> createInputParameterMetadata(Type type) {
        return ImmutableList.of(
                new AggregationMetadata.ParameterMetadata(AggregationMetadata.ParameterMetadata.ParameterType.STATE),
                new AggregationMetadata.ParameterMetadata(AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INPUT_CHANNEL, type),
                new AggregationMetadata.ParameterMetadata(AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INDEX));
    }

    /**
     * Returns the state's running-sum slice, allocating a zero-initialized one on first use.
     */
    private static Slice getOrCreateSum(LongDecimalWithOverflowState state) {
        Slice sum = state.getLongDecimal();
        if (sum == null) {
            sum = UnscaledDecimal128Arithmetic.unscaledDecimal();
            state.setLongDecimal(sum);
        }
        return sum;
    }

    /**
     * Adds the short decimal at {@code position} in {@code block} to the state's running sum.
     */
    public static void inputShortDecimal(LongDecimalWithOverflowState state, Block block, int position) {
        Slice sum = getOrCreateSum(state);
        Slice value = UnscaledDecimal128Arithmetic.unscaledDecimal(SHORT_DECIMAL_TYPE.getLong(block, position));
        // The result is written back into `sum` in place.
        state.addOverflow(UnscaledDecimal128Arithmetic.addWithOverflow(sum, value, sum));
    }

    /**
     * Adds the long decimal at {@code position} in {@code block} to the state's running sum.
     */
    public static void inputLongDecimal(LongDecimalWithOverflowState state, Block block, int position) {
        Slice sum = getOrCreateSum(state);
        // The result is written back into `sum` in place.
        state.addOverflow(UnscaledDecimal128Arithmetic.addWithOverflow(sum, LONG_DECIMAL_TYPE.getSlice(block, position), sum));
    }

    /**
     * Merges {@code otherState} into {@code state}.
     *
     * <p>If {@code otherState} holds no sum, only its overflow is carried over. If {@code state}
     * holds no sum yet, it takes a copy of the other sum.
     */
    public static void combine(LongDecimalWithOverflowState state, LongDecimalWithOverflowState otherState) {
        long overflow = otherState.getOverflow();
        Slice otherSum = otherState.getLongDecimal();
        if (otherSum != null) {
            Slice sum = state.getLongDecimal();
            if (sum == null) {
                // Copy rather than share: this state's slice is later mutated in place by
                // addWithOverflow, which would otherwise also change the value held by otherState.
                state.setLongDecimal(Slices.copyOf(otherSum));
            }
            else {
                overflow += UnscaledDecimal128Arithmetic.addWithOverflow(sum, otherSum, sum);
            }
        }
        state.addOverflow(overflow);
    }

    /**
     * Writes the final sum, or appends null when no value was accumulated.
     *
     * @throws RuntimeException (as raised by {@link UnscaledDecimal128Arithmetic}) if the sum
     *     overflowed
     */
    public static void outputLongDecimal(LongDecimalWithOverflowState state, BlockBuilder blockBuilder) {
        Slice sum = state.getLongDecimal();
        if (sum == null) {
            blockBuilder.appendNull();
            return;
        }
        if (state.getOverflow() != 0) {
            UnscaledDecimal128Arithmetic.throwOverflowException();
        }
        UnscaledDecimal128Arithmetic.throwIfOverflows(sum);
        LONG_DECIMAL_TYPE.writeSlice(blockBuilder, sum);
    }
}
