/*
 * Decompiled with CFR 0.152.
 */
package org.apache.pinot.core.query.aggregation.function;

import java.math.BigDecimal;
import java.math.MathContext;
import java.math.RoundingMode;
import java.util.List;
import java.util.Map;
import org.apache.pinot.common.function.AggregationFunctionType;
import org.apache.pinot.common.function.scalar.DataTypeConversionFunctions;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
import org.apache.pinot.core.query.aggregation.function.BaseSingleInputAggregationFunction;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
import org.apache.pinot.core.query.request.context.ExpressionContext;

/**
 * Aggregation function that sums {@link BigDecimal} values exactly.
 *
 * <p>Arguments: {@code (expression[, precision[, scale]])}. Input values are read as byte arrays
 * and decoded with {@link DataTypeConversionFunctions#bytesToBigDecimalObject}. Partial sums (both
 * per-segment and merged) are kept exact; the optional precision and scale are applied only when
 * the final result is extracted.
 */
public class SumPrecisionAggregationFunction
extends BaseSingleInputAggregationFunction<BigDecimal, BigDecimal> {
    /**
     * Rounding context applied to the final result. Precision 0 (the default) means unlimited,
     * i.e. no rounding. {@code new MathContext(int)} uses {@link RoundingMode#HALF_UP}.
     */
    MathContext _mathContext = new MathContext(0);
    /** Scale applied to the final result, or {@code null} to keep the scale produced by the sum. */
    Integer _scale = null;

    /**
     * Creates the function from its call arguments.
     *
     * @param arguments the summed expression, optionally followed by a precision literal and then
     *     a scale literal; any other argument count leaves precision unlimited and scale unset
     * @throws NumberFormatException if the precision or scale literal is not an integer
     * @throws IllegalArgumentException if the precision is negative (from {@link MathContext})
     */
    public SumPrecisionAggregationFunction(List<ExpressionContext> arguments) {
        super(arguments.get(0));
        int numArguments = arguments.size();
        if (numArguments == 3) {
            int precision = Integer.parseInt(arguments.get(1).getLiteral());
            _scale = Integer.parseInt(arguments.get(2).getLiteral());
            _mathContext = new MathContext(precision);
        } else if (numArguments == 2) {
            int precision = Integer.parseInt(arguments.get(1).getLiteral());
            _mathContext = new MathContext(precision);
        }
    }

    @Override
    public AggregationFunctionType getType() {
        return AggregationFunctionType.SUMPRECISION;
    }

    @Override
    public AggregationResultHolder createAggregationResultHolder() {
        return new ObjectAggregationResultHolder();
    }

    @Override
    public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int maxCapacity) {
        return new ObjectGroupByResultHolder(initialCapacity, maxCapacity);
    }

    /** Adds the first {@code length} decoded values to the running (non-group-by) sum. */
    @Override
    public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
        Map<ExpressionContext, BlockValSet> blockValSetMap) {
        byte[][] valueArray = blockValSetMap.get(_expression).getBytesValuesSV();
        BigDecimal sum = getDefaultResult(aggregationResultHolder);
        for (int i = 0; i < length; i++) {
            sum = sum.add(DataTypeConversionFunctions.bytesToBigDecimalObject(valueArray[i]));
        }
        aggregationResultHolder.setValue(sum);
    }

    /** Adds each decoded value to the sum of its single group key. */
    @Override
    public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
        Map<ExpressionContext, BlockValSet> blockValSetMap) {
        byte[][] valueArray = blockValSetMap.get(_expression).getBytesValuesSV();
        for (int i = 0; i < length; i++) {
            int groupKey = groupKeyArray[i];
            BigDecimal value = DataTypeConversionFunctions.bytesToBigDecimalObject(valueArray[i]);
            groupByResultHolder.setValueForKey(groupKey, getDefaultResult(groupByResultHolder, groupKey).add(value));
        }
    }

    /** Adds each decoded value to the sum of every group key associated with its row. */
    @Override
    public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
        Map<ExpressionContext, BlockValSet> blockValSetMap) {
        byte[][] valueArray = blockValSetMap.get(_expression).getBytesValuesSV();
        for (int i = 0; i < length; i++) {
            // Decode once per row: the same value is added to every group key of this row.
            BigDecimal value = DataTypeConversionFunctions.bytesToBigDecimalObject(valueArray[i]);
            for (int groupKey : groupKeysArray[i]) {
                groupByResultHolder.setValueForKey(groupKey, getDefaultResult(groupByResultHolder, groupKey).add(value));
            }
        }
    }

    @Override
    public BigDecimal extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
        return getDefaultResult(aggregationResultHolder);
    }

    @Override
    public BigDecimal extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
        return getDefaultResult(groupByResultHolder, groupKey);
    }

    /**
     * Merges two exact partial sums.
     *
     * @throws RuntimeException wrapping any failure (e.g. a {@code null} operand)
     */
    @Override
    public BigDecimal merge(BigDecimal intermediateResult1, BigDecimal intermediateResult2) {
        try {
            return intermediateResult1.add(intermediateResult2);
        } catch (Exception e) {
            throw new RuntimeException("Caught Exception while merging results in sum with precision function", e);
        }
    }

    @Override
    public boolean isIntermediateResultComparable() {
        return true;
    }

    @Override
    public DataSchema.ColumnDataType getIntermediateResultColumnType() {
        return DataSchema.ColumnDataType.OBJECT;
    }

    // NOTE(review): declared as STRING although extractFinalResult returns a BigDecimal;
    // presumably converted to a string downstream — confirm against callers.
    @Override
    public DataSchema.ColumnDataType getFinalResultColumnType() {
        return DataSchema.ColumnDataType.STRING;
    }

    /**
     * Rounds the exact sum to the configured precision, then applies the configured scale (if any)
     * using {@link RoundingMode#HALF_EVEN}.
     */
    @Override
    public BigDecimal extractFinalResult(BigDecimal intermediateResult) {
        // round() applies the MathContext directly; no toString()/re-parse round trip needed.
        return setScale(intermediateResult.round(_mathContext));
    }

    /**
     * Returns the current non-group-by sum, initializing (and storing) zero if none exists yet.
     */
    public BigDecimal getDefaultResult(AggregationResultHolder aggregationResultHolder) {
        BigDecimal result = (BigDecimal) aggregationResultHolder.getResult();
        if (result == null) {
            result = BigDecimal.ZERO;
            aggregationResultHolder.setValue(result);
        }
        return result;
    }

    /**
     * Returns the current sum for {@code groupKey}, initializing (and storing) zero if none exists
     * yet.
     */
    public BigDecimal getDefaultResult(GroupByResultHolder groupByResultHolder, int groupKey) {
        BigDecimal result = (BigDecimal) groupByResultHolder.getResult(groupKey);
        if (result == null) {
            result = BigDecimal.ZERO;
            groupByResultHolder.setValueForKey(groupKey, result);
        }
        return result;
    }

    /** Applies {@link #_scale} with HALF_EVEN rounding, or returns {@code value} unchanged if unset. */
    private BigDecimal setScale(BigDecimal value) {
        return _scale == null ? value : value.setScale(_scale, RoundingMode.HALF_EVEN);
    }
}

