/*
 * Decompiled with CFR 0.152.
 */
package com.hazelcast.org.apache.calcite.rel.rules;

import com.hazelcast.com.google.common.base.Preconditions;
import com.hazelcast.com.google.common.collect.ImmutableList;
import com.hazelcast.com.google.common.collect.Iterables;
import com.hazelcast.com.google.common.collect.Lists;
import com.hazelcast.org.apache.calcite.plan.Contexts;
import com.hazelcast.org.apache.calcite.plan.RelOptRuleCall;
import com.hazelcast.org.apache.calcite.plan.RelRule;
import com.hazelcast.org.apache.calcite.rel.RelCollations;
import com.hazelcast.org.apache.calcite.rel.core.Aggregate;
import com.hazelcast.org.apache.calcite.rel.core.AggregateCall;
import com.hazelcast.org.apache.calcite.rel.core.JoinRelType;
import com.hazelcast.org.apache.calcite.rel.core.RelFactories;
import com.hazelcast.org.apache.calcite.rel.logical.LogicalAggregate;
import com.hazelcast.org.apache.calcite.rel.rules.ImmutableAggregateExpandDistinctAggregatesRule;
import com.hazelcast.org.apache.calcite.rel.rules.TransformationRule;
import com.hazelcast.org.apache.calcite.rel.type.RelDataTypeFactory;
import com.hazelcast.org.apache.calcite.rel.type.RelDataTypeField;
import com.hazelcast.org.apache.calcite.rex.RexBuilder;
import com.hazelcast.org.apache.calcite.rex.RexInputRef;
import com.hazelcast.org.apache.calcite.rex.RexNode;
import com.hazelcast.org.apache.calcite.sql.SqlAggFunction;
import com.hazelcast.org.apache.calcite.sql.SqlKind;
import com.hazelcast.org.apache.calcite.sql.SqlOperator;
import com.hazelcast.org.apache.calcite.sql.fun.SqlStdOperatorTable;
import com.hazelcast.org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction;
import com.hazelcast.org.apache.calcite.tools.RelBuilder;
import com.hazelcast.org.apache.calcite.tools.RelBuilderFactory;
import com.hazelcast.org.apache.calcite.util.ImmutableBeans;
import com.hazelcast.org.apache.calcite.util.ImmutableBitSet;
import com.hazelcast.org.apache.calcite.util.ImmutableIntList;
import com.hazelcast.org.apache.calcite.util.Optionality;
import com.hazelcast.org.apache.calcite.util.Pair;
import com.hazelcast.org.apache.calcite.util.Util;
import com.hazelcast.org.checkerframework.checker.nullness.qual.Nullable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.immutables.value.Value;

/**
 * Planner rule that expands distinct aggregate calls (such as
 * {@code COUNT(DISTINCT x)}) of an {@link Aggregate}.
 *
 * <p>{@link #onMatch} applies the first strategy that fits:
 *
 * <ol>
 *   <li>If every non-distinct call ignores DISTINCT
 *       ({@link Optionality#IGNORED}), all distinct/ignoring calls share a
 *       single (argument list, filter) pair, and the aggregate is
 *       {@link Aggregate.Group#SIMPLE}: project and de-duplicate the input,
 *       then aggregate again ({@code convertMonopole}).
 *   <li>If {@link Config#isUsingGroupingSets()} is true: compute everything in
 *       one GROUPING SETS aggregate, then pick the right grouping set for each
 *       call through filter columns ({@code rewriteUsingGroupingSets}).
 *   <li>If there is exactly one distinct call, no call has a filter, and there
 *       is at least one non-distinct call, all of which are COUNT, SUM, SUM0,
 *       MIN or MAX: produce a two-phase aggregate
 *       ({@code convertSingletonDistinct}).
 *   <li>Otherwise: aggregate the non-distinct calls once, aggregate each
 *       distinct (argument list, filter) pair separately, join the pieces on
 *       the group keys, and project the original row type ({@code doRewrite}).
 * </ol>
 */
@Value.Enclosing
public final class AggregateExpandDistinctAggregatesRule
extends RelRule<AggregateExpandDistinctAggregatesRule.Config>
implements TransformationRule {

    /** Creates the rule from its configuration. */
    AggregateExpandDistinctAggregatesRule(Config config) {
        super(config);
    }

    @Deprecated
    public AggregateExpandDistinctAggregatesRule(Class<? extends Aggregate> clazz, boolean useGroupingSets, RelBuilderFactory relBuilderFactory) {
        this(Config.DEFAULT.withRelBuilderFactory(relBuilderFactory).withOperandSupplier(b -> b.operand(clazz).anyInputs()).as(Config.class).withUsingGroupingSets(useGroupingSets));
    }

    @Deprecated
    public AggregateExpandDistinctAggregatesRule(Class<? extends LogicalAggregate> clazz, boolean useGroupingSets, RelFactories.JoinFactory joinFactory) {
        this(clazz, useGroupingSets, RelBuilder.proto(Contexts.of((Object) joinFactory)));
    }

    @Deprecated
    public AggregateExpandDistinctAggregatesRule(Class<? extends LogicalAggregate> clazz, RelFactories.JoinFactory joinFactory) {
        this(clazz, false, RelBuilder.proto(Contexts.of((Object) joinFactory)));
    }

    /**
     * Rewrites the matched aggregate if it contains at least one distinct call;
     * see the class documentation for the strategy selection order.
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        final Aggregate aggregate = (Aggregate) call.rel(0);
        if (!aggregate.containsDistinctCall()) {
            return;
        }
        final List<AggregateCall> aggCalls = aggregate.getAggCallList();
        final List<AggregateCall> distinctAggCalls = aggCalls.stream()
            .filter(AggregateCall::isDistinct)
            .collect(Collectors.toList());
        final List<AggregateCall> nonDistinctAggCalls = aggCalls.stream()
            .filter(aggCall -> !aggCall.isDistinct())
            .collect(Collectors.toList());
        final long filterCount = aggCalls.stream()
            .filter(aggCall -> aggCall.filterArg >= 0)
            .count();
        // Only these kinds can be split into the partial/final phases used by
        // convertSingletonDistinct.
        final long unsupportedNonDistinctAggCallCount = nonDistinctAggCalls.stream()
            .filter(aggCall -> {
                switch (aggCall.getAggregation().getKind()) {
                    case COUNT:
                    case SUM:
                    case SUM0:
                    case MIN:
                    case MAX:
                        return false;
                    default:
                        return true;
                }
            })
            .count();

        // (argument list, filter) of every distinct call. LinkedHashSet keeps the
        // rewrite order deterministic.
        final Set<Pair<List<Integer>, Integer>> distinctCallArgLists = distinctAggCalls.stream()
            .map(aggCall -> Pair.of(aggCall.getArgList(), aggCall.filterArg))
            .collect(Collectors.toCollection(LinkedHashSet::new));
        Preconditions.checkState(distinctCallArgLists.size() > 0, "containsDistinctCall lied");

        // Calls whose result does not depend on DISTINCT can be treated as if
        // they were distinct.
        final List<AggregateCall> nonDistinctAggCallsOfIgnoredOptionality = nonDistinctAggCalls.stream()
            .filter(aggCall -> aggCall.getAggregation().getDistinctOptionality() == Optionality.IGNORED)
            .collect(Collectors.toList());
        final Set<Pair<List<Integer>, Integer>> distinctCallArgLists2 =
            Stream.of(distinctAggCalls, nonDistinctAggCallsOfIgnoredOptionality)
                .flatMap(Collection::stream)
                .map(aggCall -> Pair.of(aggCall.getArgList(), aggCall.filterArg))
                .collect(Collectors.toCollection(LinkedHashSet::new));

        if (nonDistinctAggCalls.size() == nonDistinctAggCallsOfIgnoredOptionality.size()
            && distinctCallArgLists2.size() == 1
            && aggregate.getGroupType() == Aggregate.Group.SIMPLE) {
            final Pair<List<Integer>, Integer> pair = Iterables.getOnlyElement(distinctCallArgLists2);
            final RelBuilder relBuilder = call.builder();
            convertMonopole(relBuilder, aggregate, pair.left, pair.right);
            call.transformTo(relBuilder.build());
            return;
        }

        if (config.isUsingGroupingSets()) {
            rewriteUsingGroupingSets(call, aggregate);
            return;
        }

        if (distinctAggCalls.size() == 1
            && filterCount == 0L
            && unsupportedNonDistinctAggCallCount == 0L
            && nonDistinctAggCalls.size() > 0) {
            final RelBuilder relBuilder = call.builder();
            convertSingletonDistinct(relBuilder, aggregate, distinctCallArgLists);
            call.transformTo(relBuilder.build());
            return;
        }

        // Join-based rewrite. refs holds, per output field of the original
        // aggregate, the expression that will produce it. Distinct calls start
        // as null placeholders and are filled in by doRewrite.
        final List<RelDataTypeField> aggFields = aggregate.getRowType().getFieldList();
        final List<@Nullable RexInputRef> refs = new ArrayList<>();
        final List<String> fieldNames = aggregate.getRowType().getFieldNames();
        final ImmutableBitSet groupSet = aggregate.getGroupSet();
        final int groupCount = aggregate.getGroupCount();
        for (int i : Util.range(groupCount)) {
            refs.add(RexInputRef.of(i, aggFields));
        }

        final List<AggregateCall> newAggCallList = new ArrayList<>();
        int aggIndex = -1;
        for (AggregateCall aggCall : aggregate.getAggCallList()) {
            ++aggIndex;
            if (aggCall.isDistinct()) {
                refs.add(null);
                continue;
            }
            refs.add(new RexInputRef(groupCount + newAggCallList.size(),
                aggFields.get(groupCount + aggIndex).getType()));
            newAggCallList.add(aggCall);
        }

        final RelBuilder relBuilder = call.builder();
        relBuilder.push(aggregate.getInput());
        int n = 0;
        // Without non-distinct calls there is no need for the extra aggregate.
        if (!newAggCallList.isEmpty()) {
            final RelBuilder.GroupKey groupKey =
                relBuilder.groupKey(groupSet, (Iterable<? extends ImmutableBitSet>) aggregate.getGroupSets());
            relBuilder.aggregate(groupKey, newAggCallList);
            ++n;
        }
        for (Pair<List<Integer>, Integer> argList : distinctCallArgLists) {
            doRewrite(relBuilder, aggregate, n++, argList.left, argList.right, refs);
        }
        // doRewrite is expected to have replaced every null placeholder.
        assert !refs.contains(null) : "unresolved distinct aggregate reference: " + refs;
        final List<RexInputRef> nonNullRefs = refs;
        relBuilder.project(nonNullRefs, fieldNames);
        call.transformTo(relBuilder.build());
    }

    /**
     * Converts an aggregate with exactly one distinct call (and no filters)
     * into two aggregates. The bottom aggregate groups by the original keys
     * plus the distinct call's arguments and pre-aggregates the non-distinct
     * calls. The top aggregate groups by the original keys, applies the
     * distinct call (now without DISTINCT), and rolls up the partial results
     * (COUNT becomes a SUM that is zero on empty input).
     *
     * @param relBuilder builder; the result is left on top of its stack
     * @param aggregate  original aggregate
     * @param argLists   argument lists of the distinct calls; must have size 1
     * @return {@code relBuilder}
     */
    private static RelBuilder convertSingletonDistinct(RelBuilder relBuilder, Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
        Preconditions.checkArgument(argLists.size() == 1);
        relBuilder.push(aggregate.getInput());
        final List<AggregateCall> originalAggCalls = aggregate.getAggCallList();
        final ImmutableBitSet originalGroupSet = aggregate.getGroupSet();

        // Bottom group keys: original keys plus the single distinct call's args.
        final TreeSet<Integer> bottomGroups = new TreeSet<>(originalGroupSet.asList());
        for (AggregateCall aggCall : originalAggCalls) {
            if (aggCall.isDistinct()) {
                bottomGroups.addAll(aggCall.getArgList());
                break;
            }
        }
        final ImmutableBitSet bottomGroupSet = ImmutableBitSet.of(bottomGroups);

        final List<AggregateCall> bottomAggregateCalls = new ArrayList<>();
        for (AggregateCall aggCall : originalAggCalls) {
            if (aggCall.isDistinct()) {
                continue;
            }
            bottomAggregateCalls.add(AggregateCall.create(aggCall.getAggregation(), false, aggCall.isApproximate(), aggCall.ignoreNulls(), aggCall.getArgList(), -1, aggCall.distinctKeys, aggCall.collation, bottomGroupSet.cardinality(), relBuilder.peek(), null, aggCall.name));
        }
        relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), bottomGroupSet, null, bottomAggregateCalls));

        final List<AggregateCall> topAggregateCalls = new ArrayList<>();
        int nonDistinctAggCallProcessedSoFar = 0;
        for (AggregateCall aggCall : originalAggCalls) {
            final AggregateCall newCall;
            if (aggCall.isDistinct()) {
                // Each argument is now a bottom group key; its output position
                // is the number of smaller keys in the sorted set.
                final List<Integer> newArgList = new ArrayList<>();
                for (int arg : aggCall.getArgList()) {
                    newArgList.add(bottomGroups.headSet(arg, false).size());
                }
                newCall = AggregateCall.create(aggCall.getAggregation(), false, aggCall.isApproximate(), aggCall.ignoreNulls(), newArgList, -1, aggCall.distinctKeys, aggCall.collation, originalGroupSet.cardinality(), relBuilder.peek(), aggCall.getType(), aggCall.name);
            } else {
                // Partial results follow the bottom group keys, in call order.
                final int arg = bottomGroups.size() + nonDistinctAggCallProcessedSoFar;
                final List<Integer> newArgs = ImmutableList.of(arg);
                if (aggCall.getAggregation().getKind() == SqlKind.COUNT) {
                    final RelDataTypeFactory typeFactory = aggregate.getCluster().getTypeFactory();
                    newCall = AggregateCall.create(new SqlSumEmptyIsZeroAggFunction(), false, aggCall.isApproximate(), aggCall.ignoreNulls(), newArgs, -1, aggCall.distinctKeys, aggCall.collation, originalGroupSet.cardinality(), relBuilder.peek(), typeFactory.getTypeSystem().deriveSumType(typeFactory, aggCall.getType()), aggCall.getName());
                } else {
                    newCall = AggregateCall.create(aggCall.getAggregation(), false, aggCall.isApproximate(), aggCall.ignoreNulls(), newArgs, -1, aggCall.distinctKeys, aggCall.collation, originalGroupSet.cardinality(), relBuilder.peek(), aggCall.getType(), aggCall.name);
                }
                ++nonDistinctAggCallProcessedSoFar;
            }
            topAggregateCalls.add(newCall);
        }

        // Top group set: positions (within the bottom keys) of the original keys.
        final ImmutableBitSet.Builder topGroupSet = ImmutableBitSet.builder();
        int position = 0;
        for (int bottomGroup : bottomGroups) {
            if (originalGroupSet.get(bottomGroup)) {
                topGroupSet.set(position);
            }
            ++position;
        }
        relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), topGroupSet.build(), null, topAggregateCalls));
        relBuilder.convert(aggregate.getRowType(), true);
        return relBuilder;
    }

    /**
     * Rewrites the aggregate as a GROUPING SETS aggregate followed by a second
     * aggregate. The first aggregate has one grouping set per distinct
     * (argument list + filter + group keys) combination, plus the original
     * group set for the non-distinct calls. It also emits a {@code GROUPING}
     * column ({@code $g}). The second aggregate applies each call with a filter
     * that selects the rows belonging to that call's grouping set
     * (non-distinct calls become MIN over their pre-aggregated value).
     */
    private static void rewriteUsingGroupingSets(RelOptRuleCall call, Aggregate aggregate) {
        final TreeSet<ImmutableBitSet> groupSetTreeSet = new TreeSet<>(ImmutableBitSet.ORDERING);
        // For each grouping set, the filter arguments (-1 = none) that use it.
        final Map<ImmutableBitSet, Set<Integer>> distinctFilterArgMap = new HashMap<>();
        for (AggregateCall aggCall : aggregate.getAggCallList()) {
            final ImmutableBitSet callGroupSet;
            final int filterArg;
            if (!aggCall.isDistinct()) {
                filterArg = -1;
                callGroupSet = aggregate.getGroupSet();
            } else {
                filterArg = aggCall.filterArg;
                callGroupSet = ImmutableBitSet.of(aggCall.getArgList())
                    .setIf(filterArg, filterArg >= 0)
                    .union(aggregate.getGroupSet());
            }
            groupSetTreeSet.add(callGroupSet);
            distinctFilterArgMap.computeIfAbsent(callGroupSet, g -> new HashSet<>()).add(filterArg);
        }
        final ImmutableList<ImmutableBitSet> groupSets = ImmutableList.copyOf(groupSetTreeSet);
        final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets);

        // Calls of the bottom aggregate: the non-distinct calls, adapted to the
        // full group set.
        final List<AggregateCall> bottomAggCalls = new ArrayList<>();
        for (Pair<AggregateCall, String> namedCall : aggregate.getNamedAggCalls()) {
            final AggregateCall aggCall = namedCall.left;
            if (aggCall.isDistinct()) {
                continue;
            }
            final AggregateCall newAggCall = aggCall.adaptTo(aggregate.getInput(), aggCall.getArgList(), aggCall.filterArg, aggregate.getGroupCount(), fullGroupSet.cardinality());
            bottomAggCalls.add(newAggCall.withName(namedCall.right));
        }

        final RelBuilder relBuilder = call.builder();
        relBuilder.push(aggregate.getInput());
        final int groupCount = fullGroupSet.cardinality();

        // Assign an output column to each (grouping set, filter) pair. These
        // columns follow the group keys and the bottom calls (the $g column is
        // replaced by them in the project below).
        final Map<Pair<ImmutableBitSet, Integer>, Integer> filters = new LinkedHashMap<>();
        int z = groupCount + bottomAggCalls.size();
        for (ImmutableBitSet gs : groupSets) {
            final Set<Integer> filterArgList = Objects.requireNonNull(distinctFilterArgMap.get(gs), "filterArgList");
            for (Integer filterArg : filterArgList) {
                filters.put(Pair.of(gs, filterArg), z);
                ++z;
            }
        }

        bottomAggCalls.add(AggregateCall.create(SqlStdOperatorTable.GROUPING, false, false, false, ImmutableIntList.copyOf(fullGroupSet), -1, null, RelCollations.EMPTY, groupSets.size(), relBuilder.peek(), null, "$g"));
        relBuilder.aggregate(relBuilder.groupKey(fullGroupSet, (Iterable<? extends ImmutableBitSet>) groupSets), bottomAggCalls);

        if (!filters.isEmpty()) {
            final List<RexNode> nodes = new ArrayList<>(relBuilder.fields());
            final RexNode nodeZ = nodes.remove(nodes.size() - 1);
            for (Map.Entry<Pair<ImmutableBitSet, Integer>, Integer> entry : filters.entrySet()) {
                final long v = groupValue(fullGroupSet.asList(), entry.getKey().left);
                final int distinctFilterArg = remap(fullGroupSet, entry.getKey().right);
                RexNode expr = relBuilder.equals(nodeZ, relBuilder.literal(v));
                if (distinctFilterArg > -1) {
                    expr = relBuilder.and(expr, relBuilder.call((SqlOperator) SqlStdOperatorTable.IS_TRUE, relBuilder.field(distinctFilterArg)));
                }
                nodes.add(relBuilder.alias(expr, "$g_" + v + (distinctFilterArg < 0 ? "" : "_f_" + distinctFilterArg)));
            }
            relBuilder.project(nodes);
        }

        int x = groupCount;
        final ImmutableBitSet groupSet = aggregate.getGroupSet();
        final List<AggregateCall> newCalls = new ArrayList<>();
        for (AggregateCall aggCall : aggregate.getAggCallList()) {
            final int newFilterArg;
            final List<Integer> newArgList;
            final SqlAggFunction aggregation;
            if (!aggCall.isDistinct()) {
                // The value is already aggregated; MIN just picks it from the
                // row of the original grouping set.
                aggregation = SqlStdOperatorTable.MIN;
                newArgList = ImmutableIntList.of(x++);
                newFilterArg = Objects.requireNonNull(filters.get(Pair.of(groupSet, -1)), "filters.get(Pair.of(groupSet, -1))");
            } else {
                aggregation = aggCall.getAggregation();
                newArgList = remap(fullGroupSet, aggCall.getArgList());
                final ImmutableBitSet newGroupSet = ImmutableBitSet.of(aggCall.getArgList())
                    .setIf(aggCall.filterArg, aggCall.filterArg >= 0)
                    .union(groupSet);
                newFilterArg = Objects.requireNonNull(filters.get(Pair.of(newGroupSet, aggCall.filterArg)), "filters.get(of(newGroupSet, aggCall.filterArg))");
            }
            newCalls.add(AggregateCall.create(aggregation, false, aggCall.isApproximate(), aggCall.ignoreNulls(), newArgList, newFilterArg, aggCall.distinctKeys, aggCall.collation, aggregate.getGroupCount(), relBuilder.peek(), null, aggCall.name));
        }
        relBuilder.aggregate(relBuilder.groupKey(remap(fullGroupSet, groupSet), (Iterable<? extends ImmutableBitSet>) remap(fullGroupSet, aggregate.getGroupSets())), newCalls);
        relBuilder.convert(aggregate.getRowType(), true);
        call.transformTo(relBuilder.build());
    }

    /**
     * Computes the value the {@code GROUPING} function returns for a row of
     * {@code groupSet*}. Bits are ordered like {@code fullGroupSet}, most
     * significant first; a bit is 1 when that key is NOT in {@code groupSet}.
     *
     * @param fullGroupSet all keys, in iteration order
     * @param groupSet     subset of {@code fullGroupSet}
     * @return the grouping value
     */
    static long groupValue(Collection<Integer> fullGroupSet, ImmutableBitSet groupSet) {
        long v = 0L;
        long x = 1L << (fullGroupSet.size() - 1);
        assert ImmutableBitSet.of(fullGroupSet).contains(groupSet);
        for (int i : fullGroupSet) {
            if (!groupSet.get(i)) {
                v |= x;
            }
            x >>= 1;
        }
        return v;
    }

    /** Maps each bit of {@code bitSet} to its position within {@code groupSet}. */
    static ImmutableBitSet remap(ImmutableBitSet groupSet, ImmutableBitSet bitSet) {
        final ImmutableBitSet.Builder builder = ImmutableBitSet.builder();
        for (int bit : bitSet) {
            builder.set(remap(groupSet, bit));
        }
        return builder.build();
    }

    /** Applies {@link #remap(ImmutableBitSet, ImmutableBitSet)} to each bit set. */
    static ImmutableList<ImmutableBitSet> remap(ImmutableBitSet groupSet, Iterable<ImmutableBitSet> bitSets) {
        final ImmutableList.Builder<ImmutableBitSet> builder = ImmutableList.builder();
        for (ImmutableBitSet bitSet : bitSets) {
            builder.add(remap(groupSet, bitSet));
        }
        return builder.build();
    }

    /** Maps each argument ordinal to its position within {@code groupSet}. */
    private static List<Integer> remap(ImmutableBitSet groupSet, List<Integer> argList) {
        final int[] remapped = new int[argList.size()];
        for (int k = 0; k < remapped.length; k++) {
            remapped[k] = remap(groupSet, argList.get(k));
        }
        return ImmutableIntList.of(remapped);
    }

    /**
     * Returns the position of {@code arg} within {@code groupSet}, or -1 when
     * {@code arg} is negative (meaning "no column").
     */
    private static int remap(ImmutableBitSet groupSet, int arg) {
        return arg < 0 ? -1 : groupSet.indexOf(arg);
    }

    /**
     * Converts an aggregate whose distinct (or distinct-ignoring) calls all use
     * the same arguments. The input is projected and de-duplicated, then
     * re-aggregated with those calls' arguments remapped.
     *
     * @return {@code relBuilder}, with the new aggregate on top of its stack
     */
    private static RelBuilder convertMonopole(RelBuilder relBuilder, Aggregate aggregate, List<Integer> argList, int filterArg) {
        final Map<Integer, Integer> sourceOf = new HashMap<>();
        createSelectDistinct(relBuilder, aggregate, argList, filterArg, sourceOf);
        final List<AggregateCall> newAggCalls = Lists.newArrayList(aggregate.getAggCallList());
        rewriteAggCalls(newAggCalls, argList, sourceOf);
        final int cardinality = aggregate.getGroupSet().cardinality();
        relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), ImmutableBitSet.range(cardinality), null, newAggCalls));
        return relBuilder;
    }

    /**
     * Builds the aggregate for the distinct calls whose argument list equals
     * {@code argList}. Their entries in {@code refs} are set to point at the new
     * columns. When {@code n > 0}, the result is inner-joined to the expression
     * already on the builder's stack, matching group keys with
     * IS NOT DISTINCT FROM.
     *
     * @param n       number of pieces already built (0 for the first)
     * @param refs    output references; entries for rewritten calls must be null
     */
    private static void doRewrite(RelBuilder relBuilder, Aggregate aggregate, int n, List<Integer> argList, int filterArg, List<@Nullable RexInputRef> refs) {
        final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
        final @Nullable List<RelDataTypeField> leftFields = n == 0 ? null : relBuilder.peek().getRowType().getFieldList();
        final Map<Integer, Integer> sourceOf = new HashMap<>();
        createSelectDistinct(relBuilder, aggregate, argList, filterArg, sourceOf);

        final List<AggregateCall> aggCallList = new ArrayList<>();
        final List<AggregateCall> aggCalls = aggregate.getAggCallList();
        final int groupCount = aggregate.getGroupCount();
        // Offset of the new columns within the (possibly joined) output row.
        final int offset = leftFields == null ? groupCount : leftFields.size() + groupCount;
        int i = groupCount - 1;
        for (AggregateCall aggCall : aggCalls) {
            ++i;
            // NOTE(review): calls are matched on argList only; filterArg is not
            // compared here — confirm this is intended when two distinct calls
            // share arguments but differ in filter.
            if (!aggCall.isDistinct() || !aggCall.getArgList().equals(argList)) {
                continue;
            }
            final List<Integer> newArgs = new ArrayList<>(aggCall.getArgList().size());
            for (Integer arg : aggCall.getArgList()) {
                newArgs.add(Objects.requireNonNull(sourceOf.get(arg), () -> "sourceOf.get(" + arg + ")"));
            }
            final int newFilterArg = aggCall.filterArg < 0
                ? -1
                : Objects.requireNonNull(sourceOf.get(aggCall.filterArg), () -> "sourceOf.get(" + aggCall.filterArg + ")");
            final AggregateCall newAggCall = AggregateCall.create(aggCall.getAggregation(), false, aggCall.isApproximate(), aggCall.ignoreNulls(), newArgs, newFilterArg, aggCall.distinctKeys, aggCall.collation, aggCall.getType(), aggCall.getName());
            assert refs.get(i) == null;
            refs.set(i, new RexInputRef(offset + aggCallList.size(), newAggCall.getType()));
            aggCallList.add(newAggCall);
        }

        // The select-distinct puts the group keys first, so they become 0..k-1.
        final Map<Integer, Integer> map = new HashMap<>();
        for (Integer key : aggregate.getGroupSet()) {
            map.put(key, map.size());
        }
        final ImmutableBitSet newGroupSet = aggregate.getGroupSet().permute(map);
        assert newGroupSet.equals(ImmutableBitSet.range(aggregate.getGroupSet().cardinality()));
        relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), newGroupSet, null, aggCallList));
        if (leftFields == null) {
            return;
        }

        final List<RelDataTypeField> distinctFields = relBuilder.peek().getRowType().getFieldList();
        final List<RexNode> conditions = new ArrayList<>();
        for (int k = 0; k < groupCount; ++k) {
            conditions.add(rexBuilder.makeCall((SqlOperator) SqlStdOperatorTable.IS_NOT_DISTINCT_FROM, RexInputRef.of(k, leftFields), new RexInputRef(leftFields.size() + k, distinctFields.get(k).getType())));
        }
        relBuilder.join(JoinRelType.INNER, conditions);
    }

    /**
     * Replaces, in place, each call that is distinct or ignores DISTINCT and
     * whose argument list equals {@code argList}. The replacement is a
     * non-distinct call, without filter, over the columns given by
     * {@code sourceOf}.
     */
    private static void rewriteAggCalls(List<AggregateCall> newAggCalls, List<Integer> argList, Map<Integer, Integer> sourceOf) {
        for (int i = 0; i < newAggCalls.size(); ++i) {
            final AggregateCall aggCall = newAggCalls.get(i);
            final boolean treatAsDistinct = aggCall.isDistinct()
                || aggCall.getAggregation().getDistinctOptionality() == Optionality.IGNORED;
            if (!treatAsDistinct || !aggCall.getArgList().equals(argList)) {
                continue;
            }
            final List<Integer> newArgs = new ArrayList<>(aggCall.getArgList().size());
            for (Integer arg : aggCall.getArgList()) {
                newArgs.add(Objects.requireNonNull(sourceOf.get(arg), () -> "sourceOf.get(" + arg + ")"));
            }
            final AggregateCall newAggCall = AggregateCall.create(aggCall.getAggregation(), false, aggCall.isApproximate(), aggCall.ignoreNulls(), newArgs, -1, aggCall.distinctKeys, aggCall.collation, aggCall.getType(), aggCall.getName());
            newAggCalls.set(i, newAggCall);
        }
    }

    /**
     * Pushes onto {@code relBuilder} an aggregate with no calls (a SELECT
     * DISTINCT) over a projection of the group keys followed by the arguments
     * in {@code argList}. When {@code filterArg >= 0}, each argument is
     * projected as {@code CASE WHEN filter THEN arg ELSE NULL END}.
     *
     * @param sourceOf receives, for each input ordinal used, its position in
     *                 the projection
     * @return {@code relBuilder}
     */
    private static RelBuilder createSelectDistinct(RelBuilder relBuilder, Aggregate aggregate, List<Integer> argList, int filterArg, Map<Integer, Integer> sourceOf) {
        relBuilder.push(aggregate.getInput());
        final List<Pair<RexNode, String>> projects = new ArrayList<>();
        final List<RelDataTypeField> childFields = relBuilder.peek().getRowType().getFieldList();
        for (int i : aggregate.getGroupSet()) {
            sourceOf.put(i, projects.size());
            projects.add(RexInputRef.of2(i, childFields));
        }
        for (Integer arg : argList) {
            if (filterArg >= 0) {
                final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
                final RexInputRef filterRef = RexInputRef.of(filterArg, childFields);
                final Pair<RexNode, String> argRef = RexInputRef.of2(arg, childFields);
                final RexNode condition = rexBuilder.makeCall((SqlOperator) SqlStdOperatorTable.CASE, filterRef, argRef.left, rexBuilder.makeNullLiteral(argRef.left.getType()));
                sourceOf.put(arg, projects.size());
                projects.add(Pair.of(condition, "i$" + argRef.right));
                continue;
            }
            if (sourceOf.get(arg) != null) {
                continue;
            }
            sourceOf.put(arg, projects.size());
            projects.add(RexInputRef.of2(arg, childFields));
        }
        relBuilder.project(Pair.left(projects), Pair.right(projects));
        relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), ImmutableBitSet.range(projects.size()), null, ImmutableList.of()));
        return relBuilder;
    }

    /** Rule configuration. */
    @Value.Immutable
    public static interface Config
    extends RelRule.Config {
        /** Matches {@link LogicalAggregate}; uses grouping sets (the default). */
        public static final Config DEFAULT = ImmutableAggregateExpandDistinctAggregatesRule.Config.of().withOperandSupplier(b -> b.operand(LogicalAggregate.class).anyInputs());
        /** Like {@link #DEFAULT}, but without grouping sets, so the join-based rewrite can be used. */
        public static final Config JOIN = DEFAULT.withUsingGroupingSets(false);

        @Override
        default public AggregateExpandDistinctAggregatesRule toRule() {
            return new AggregateExpandDistinctAggregatesRule(this);
        }

        /** Whether to rewrite using GROUPING SETS; default true. */
        @ImmutableBeans.Property
        @ImmutableBeans.BooleanDefault(value=true)
        @Value.Default
        default public boolean isUsingGroupingSets() {
            return true;
        }

        /** Sets {@link #isUsingGroupingSets()}. */
        public Config withUsingGroupingSets(boolean var1);
    }
}

