/*
 * Decompiled with CFR 0.152.
 */
package com.hazelcast.jet.sql.impl.opt.physical;

import com.hazelcast.function.BiConsumerEx;
import com.hazelcast.function.FunctionEx;
import com.hazelcast.function.SupplierEx;
import com.hazelcast.internal.serialization.SerializationService;
import com.hazelcast.jet.aggregate.AggregateOperation;
import com.hazelcast.jet.impl.execution.init.Contexts;
import com.hazelcast.jet.sql.impl.aggregate.AvgSqlAggregations;
import com.hazelcast.jet.sql.impl.aggregate.CountSqlAggregations;
import com.hazelcast.jet.sql.impl.aggregate.MaxSqlAggregation;
import com.hazelcast.jet.sql.impl.aggregate.MinSqlAggregation;
import com.hazelcast.jet.sql.impl.aggregate.SqlAggregation;
import com.hazelcast.jet.sql.impl.aggregate.SumSqlAggregations;
import com.hazelcast.jet.sql.impl.aggregate.ValueSqlAggregation;
import com.hazelcast.jet.sql.impl.opt.OptUtils;
import com.hazelcast.org.apache.calcite.plan.RelRule;
import com.hazelcast.org.apache.calcite.rel.core.AggregateCall;
import com.hazelcast.org.apache.calcite.rel.type.RelDataType;
import com.hazelcast.org.apache.calcite.sql.SqlKind;
import com.hazelcast.org.apache.calcite.util.ImmutableBitSet;
import com.hazelcast.sql.impl.QueryException;
import com.hazelcast.sql.impl.row.JetSqlRow;
import com.hazelcast.sql.impl.type.QueryDataType;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

abstract class AggregateAbstractPhysicalRule
extends RelRule<RelRule.Config> {
    protected AggregateAbstractPhysicalRule(RelRule.Config config) {
        super(config);
    }

    protected static AggregateOperation<?, JetSqlRow> aggregateOperation(RelDataType inputType, ImmutableBitSet groupSet, List<AggregateCall> aggregateCalls) {
        List operandTypes = OptUtils.schema(inputType).getTypes();
        ArrayList<Object> aggregationProviders = new ArrayList<Object>();
        ArrayList<Object> valueProviders = new ArrayList<Object>();
        for (Integer groupIndex : groupSet.toList()) {
            aggregationProviders.add(ValueSqlAggregation::new);
            valueProviders.add((FunctionEx & Serializable)row -> row.getMaybeSerialized(groupIndex.intValue()));
        }
        block8: for (AggregateCall aggregateCall : aggregateCalls) {
            boolean distinct = aggregateCall.isDistinct();
            List<Integer> aggregateCallArguments = aggregateCall.getArgList();
            SqlKind kind = aggregateCall.getAggregation().getKind();
            switch (kind) {
                case COUNT: {
                    int countIndex;
                    if (distinct) {
                        countIndex = aggregateCallArguments.get(0);
                        aggregationProviders.add((SupplierEx & Serializable)() -> CountSqlAggregations.from(true, true));
                        valueProviders.add((FunctionEx & Serializable)row -> row.getMaybeSerialized(countIndex));
                        continue block8;
                    }
                    if (aggregateCallArguments.size() == 1) {
                        countIndex = aggregateCallArguments.get(0);
                        aggregationProviders.add((SupplierEx & Serializable)() -> CountSqlAggregations.from(true, false));
                        valueProviders.add((FunctionEx & Serializable)row -> row.getMaybeSerialized(countIndex));
                        continue block8;
                    }
                    aggregationProviders.add((SupplierEx & Serializable)() -> CountSqlAggregations.from(false, false));
                    valueProviders.add((FunctionEx & Serializable)row -> null);
                    continue block8;
                }
                case MIN: {
                    int minIndex = aggregateCallArguments.get(0);
                    aggregationProviders.add(MinSqlAggregation::new);
                    valueProviders.add((FunctionEx & Serializable)row -> row.get(minIndex));
                    continue block8;
                }
                case MAX: {
                    int maxIndex = aggregateCallArguments.get(0);
                    aggregationProviders.add(MaxSqlAggregation::new);
                    valueProviders.add((FunctionEx & Serializable)row -> row.get(maxIndex));
                    continue block8;
                }
                case SUM: {
                    int sumIndex = aggregateCallArguments.get(0);
                    QueryDataType sumOperandType = (QueryDataType)operandTypes.get(sumIndex);
                    aggregationProviders.add((SupplierEx & Serializable)() -> SumSqlAggregations.from(sumOperandType, distinct));
                    valueProviders.add((FunctionEx & Serializable)row -> row.get(sumIndex));
                    continue block8;
                }
                case AVG: {
                    int avgIndex = aggregateCallArguments.get(0);
                    QueryDataType avgOperandType = (QueryDataType)operandTypes.get(avgIndex);
                    aggregationProviders.add((SupplierEx & Serializable)() -> AvgSqlAggregations.from(avgOperandType, distinct));
                    valueProviders.add((FunctionEx & Serializable)row -> row.get(avgIndex));
                    continue block8;
                }
            }
            throw QueryException.error((String)("Unsupported aggregation function: " + (Object)((Object)kind)));
        }
        return AggregateOperation.withCreate((SupplierEx & Serializable)() -> {
            ArrayList<Object> aggregations = new ArrayList<Object>(aggregationProviders.size());
            for (SupplierEx aggregationProvider : aggregationProviders) {
                aggregations.add(aggregationProvider.get());
            }
            return aggregations;
        }).andAccumulate((BiConsumerEx & Serializable)(aggregations, row) -> {
            for (int i = 0; i < aggregations.size(); ++i) {
                ((SqlAggregation)aggregations.get(i)).accumulate(((FunctionEx)valueProviders.get(i)).apply(row));
            }
        }).andCombine((BiConsumerEx & Serializable)(lefts, rights) -> {
            assert (lefts.size() == rights.size());
            for (int i = 0; i < lefts.size(); ++i) {
                ((SqlAggregation)lefts.get(i)).combine((SqlAggregation)rights.get(i));
            }
        }).andExportFinish((FunctionEx & Serializable)aggregations -> {
            Object[] row = new Object[aggregations.size()];
            for (int i = 0; i < aggregations.size(); ++i) {
                row[i] = ((SqlAggregation)aggregations.get(i)).collect();
            }
            return new JetSqlRow((SerializationService)Contexts.getCastedThreadContext().serializationService(), row);
        });
    }
}

