/*
 * Decompiled with CFR 0.152.
 */
package com.hazelcast.org.apache.calcite.rel.rules;

import com.hazelcast.com.google.common.collect.ImmutableList;
import com.hazelcast.org.apache.calcite.linq4j.Ord;
import com.hazelcast.org.apache.calcite.plan.RelOptRuleCall;
import com.hazelcast.org.apache.calcite.plan.RelOptUtil;
import com.hazelcast.org.apache.calcite.plan.RelRule;
import com.hazelcast.org.apache.calcite.rel.RelNode;
import com.hazelcast.org.apache.calcite.rel.core.Aggregate;
import com.hazelcast.org.apache.calcite.rel.core.AggregateCall;
import com.hazelcast.org.apache.calcite.rel.core.Join;
import com.hazelcast.org.apache.calcite.rel.core.JoinRelType;
import com.hazelcast.org.apache.calcite.rel.core.RelFactories;
import com.hazelcast.org.apache.calcite.rel.logical.LogicalAggregate;
import com.hazelcast.org.apache.calcite.rel.logical.LogicalJoin;
import com.hazelcast.org.apache.calcite.rel.metadata.RelMetadataQuery;
import com.hazelcast.org.apache.calcite.rel.rules.ImmutableAggregateJoinTransposeRule;
import com.hazelcast.org.apache.calcite.rel.rules.TransformationRule;
import com.hazelcast.org.apache.calcite.rel.type.RelDataType;
import com.hazelcast.org.apache.calcite.rex.RexBuilder;
import com.hazelcast.org.apache.calcite.rex.RexCall;
import com.hazelcast.org.apache.calcite.rex.RexInputRef;
import com.hazelcast.org.apache.calcite.rex.RexNode;
import com.hazelcast.org.apache.calcite.rex.RexUtil;
import com.hazelcast.org.apache.calcite.sql.SqlAggFunction;
import com.hazelcast.org.apache.calcite.sql.SqlSplittableAggFunction;
import com.hazelcast.org.apache.calcite.tools.RelBuilder;
import com.hazelcast.org.apache.calcite.tools.RelBuilderFactory;
import com.hazelcast.org.apache.calcite.util.ImmutableBitSet;
import com.hazelcast.org.apache.calcite.util.Util;
import com.hazelcast.org.apache.calcite.util.mapping.Mapping;
import com.hazelcast.org.apache.calcite.util.mapping.Mappings;
import com.hazelcast.org.checkerframework.checker.nullness.qual.Nullable;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;
import org.immutables.value.Value;

/**
 * Planner rule that pushes an {@link Aggregate} past a {@link Join}.
 *
 * <p>The rule only proceeds when the join condition is a pure equi-join (nothing
 * remains after {@link RelOptUtil#splitJoinCondition}) and the aggregate passes
 * {@code isAggregateSupported}. For each join input it builds a replacement input
 * that is either
 * <ul>
 *   <li>a Project of the input's grouping/join columns, plus one "singleton"
 *   expression per aggregate call whose arguments come from that input, when the
 *   columns are treated as unique on that input; or</li>
 *   <li>an Aggregate of the input grouped on those columns, with each aggregate
 *   call split via {@link SqlSplittableAggFunction}.</li>
 * </ul>
 * The replacement inputs are joined and the partial results are combined by a
 * Project followed by either a top-level Aggregate or, when every join column is a
 * grouping column (or equal to one per pulled-up predicates) and the join is not
 * FULL, a second Project.
 */
@Value.Enclosing
public class AggregateJoinTransposeRule
extends RelRule<AggregateJoinTransposeRule.Config>
implements TransformationRule {

    /** Creates an AggregateJoinTransposeRule from a configuration. */
    protected AggregateJoinTransposeRule(Config config) {
        super(config);
    }

    @Deprecated
    public AggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass, Class<? extends Join> joinClass, RelBuilderFactory relBuilderFactory, boolean allowFunctions) {
        this(Config.DEFAULT.withRelBuilderFactory(relBuilderFactory).as(Config.class).withOperandFor(aggregateClass, joinClass, allowFunctions));
    }

    @Deprecated
    public AggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass, RelFactories.AggregateFactory aggregateFactory, Class<? extends Join> joinClass, RelFactories.JoinFactory joinFactory) {
        this(aggregateClass, joinClass, RelBuilder.proto(aggregateFactory, joinFactory), false);
    }

    @Deprecated
    public AggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass, RelFactories.AggregateFactory aggregateFactory, Class<? extends Join> joinClass, RelFactories.JoinFactory joinFactory, boolean allowFunctions) {
        this(aggregateClass, joinClass, RelBuilder.proto(aggregateFactory, joinFactory), allowFunctions);
    }

    @Deprecated
    public AggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass, RelFactories.AggregateFactory aggregateFactory, Class<? extends Join> joinClass, RelFactories.JoinFactory joinFactory, RelFactories.ProjectFactory projectFactory) {
        this(aggregateClass, joinClass, RelBuilder.proto(aggregateFactory, joinFactory, projectFactory), false);
    }

    @Deprecated
    public AggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass, RelFactories.AggregateFactory aggregateFactory, Class<? extends Join> joinClass, RelFactories.JoinFactory joinFactory, RelFactories.ProjectFactory projectFactory, boolean allowFunctions) {
        this(aggregateClass, joinClass, RelBuilder.proto(aggregateFactory, joinFactory, projectFactory), allowFunctions);
    }

    /**
     * Returns whether the rule can handle {@code aggregate}.
     *
     * <p>Requires: no aggregate calls unless {@code allowFunctions}; a
     * {@link Aggregate.Group#SIMPLE} group type; and every call must be splittable
     * (its function unwraps to {@link SqlSplittableAggFunction}), unfiltered and
     * non-DISTINCT.
     */
    private static boolean isAggregateSupported(Aggregate aggregate, boolean allowFunctions) {
        if (!allowFunctions && !aggregate.getAggCallList().isEmpty()) {
            return false;
        }
        if (aggregate.getGroupType() != Aggregate.Group.SIMPLE) {
            return false;
        }
        for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
            if (aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class) == null
                    || aggregateCall.filterArg >= 0
                    || aggregateCall.isDistinct()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns whether the rule can handle {@code join}: any join type when the
     * aggregate has no calls, otherwise INNER joins only.
     */
    private static boolean isJoinSupported(Join join, Aggregate aggregate) {
        return join.getJoinType() == JoinRelType.INNER || aggregate.getAggCallList().isEmpty();
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        final Aggregate aggregate = (Aggregate) call.rel(0);
        final Join join = (Join) call.rel(1);
        final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
        final RelBuilder relBuilder = call.builder();
        if (!isJoinSupported(join, aggregate)) {
            return;
        }

        // Is every column used by the join condition a grouping column, or equal
        // to one according to the predicates pulled up from the join?
        final ImmutableBitSet aggregateColumns = aggregate.getGroupSet();
        final RelMetadataQuery mq = call.getMetadataQuery();
        final ImmutableBitSet keyColumns = keyColumns(aggregateColumns,
            mq.getPulledUpPredicates(join).pulledUpPredicates);
        final ImmutableBitSet joinColumns = RelOptUtil.InputFinder.bits(join.getCondition());
        final boolean allColumnsInAggregate = keyColumns.contains(joinColumns);
        final ImmutableBitSet belowAggregateColumns = aggregateColumns.union(joinColumns);

        // Split the join condition; bail out if any non-equi conjunct remains.
        final List<Integer> leftKeys = new ArrayList<>();
        final List<Integer> rightKeys = new ArrayList<>();
        final List<Boolean> filterNulls = new ArrayList<>();
        final RexNode nonEquiConj = RelOptUtil.splitJoinCondition(join.getLeft(), join.getRight(),
            join.getCondition(), leftKeys, rightKeys, filterNulls);
        if (!nonEquiConj.isAlwaysTrue()) {
            return;
        }

        // Maps each kept column of the old join row type to its position in the
        // row type of the new join (key columns come first on each side).
        final Map<Integer, Integer> map = new HashMap<>();
        final List<Side> sides = new ArrayList<>();
        int uniqueCount = 0;
        int offset = 0;
        int belowOffset = 0;
        for (int s = 0; s < 2; s++) {
            final Side side = new Side();
            final RelNode joinInput = join.getInput(s);
            final int fieldCount = joinInput.getRowType().getFieldCount();
            final ImmutableBitSet fieldSet = ImmutableBitSet.range(offset, offset + fieldCount);
            final ImmutableBitSet belowAggregateKeyNotShifted = belowAggregateColumns.intersect(fieldSet);
            for (Ord<Integer> c : Ord.zip(belowAggregateKeyNotShifted)) {
                map.put(c.e, belowOffset + c.i);
            }
            // Maps old-join column indexes onto this input's column indexes. The two
            // factory methods return different Mapping subtypes, so use the common
            // TargetMapping interface.
            final Mappings.TargetMapping mapping = s == 0
                ? Mappings.createIdentity(fieldCount)
                : Mappings.createShiftMapping(fieldCount + offset, 0, offset, fieldCount);
            final ImmutableBitSet belowAggregateKey = belowAggregateKeyNotShifted.shift(-offset);

            final boolean unique;
            if (!config.isAllowFunctions()) {
                assert aggregate.getAggCallList().isEmpty();
                // No metadata query is issued; the input is simply treated as unique.
                // With no aggregate calls the side just projects its key columns.
                // NOTE(review): Util.discard keeps a marker for a known limitation
                // (argument was a compile-time constant) — confirm against upstream.
                Util.discard(false);
                unique = true;
            } else {
                final Boolean unique0 = mq.areColumnsUnique(joinInput, belowAggregateKey);
                unique = unique0 != null && unique0;
            }

            if (unique) {
                // Treated as unique on its key columns: project the keys (in bitset
                // order) plus one singleton expression per pushable call.
                ++uniqueCount;
                side.aggregate = false;
                relBuilder.push(joinInput);
                final List<RexNode> projects = new ArrayList<>();
                for (int i : belowAggregateKey) {
                    projects.add(relBuilder.field(i));
                }
                for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
                    final AggregateCall aggregateCall = aggCall.e;
                    final SqlSplittableAggFunction splitter =
                        aggregateCall.getAggregation().unwrapOrThrow(SqlSplittableAggFunction.class);
                    final List<Integer> argList = aggregateCall.getArgList();
                    // Only calls with arguments, all from this input, are pushed here.
                    if (argList.isEmpty() || !fieldSet.contains(ImmutableBitSet.of(argList))) {
                        continue;
                    }
                    final RexNode singleton = splitter.singleton(rexBuilder,
                        joinInput.getRowType(), aggregateCall.transform(mapping));
                    // If the singleton is an already-projected key column, reuse it. The
                    // recorded value must be its position in `projects`, not its index
                    // in the input; the two differ unless the keys form a prefix.
                    final int keyPosition = singleton instanceof RexInputRef
                        ? belowAggregateKey.indexOf(((RexInputRef) singleton).getIndex())
                        : -1;
                    if (keyPosition >= 0) {
                        side.split.put(aggCall.i, keyPosition);
                    } else {
                        projects.add(singleton);
                        side.split.put(aggCall.i, projects.size() - 1);
                    }
                }
                relBuilder.project(projects);
                side.newInput = relBuilder.build();
            } else {
                // Aggregate this input below the join, grouped on its key columns.
                side.aggregate = true;
                final List<AggregateCall> belowAggCalls = new ArrayList<>();
                final SqlSplittableAggFunction.Registry<AggregateCall> belowAggCallRegistry =
                    registry(belowAggCalls);
                final int oldGroupKeyCount = aggregate.getGroupCount();
                final int newGroupKeyCount = belowAggregateKey.cardinality();
                for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
                    final AggregateCall aggregateCall = aggCall.e;
                    final SqlSplittableAggFunction splitter =
                        aggregateCall.getAggregation().unwrapOrThrow(SqlSplittableAggFunction.class);
                    final AggregateCall call1;
                    if (fieldSet.contains(ImmutableBitSet.of(aggregateCall.getArgList()))) {
                        final AggregateCall splitCall = splitter.split(aggregateCall, mapping);
                        call1 = splitCall.adaptTo(joinInput, splitCall.getArgList(),
                            splitCall.filterArg, oldGroupKeyCount, newGroupKeyCount);
                    } else {
                        call1 = splitter.other(rexBuilder.getTypeFactory(), aggregateCall);
                    }
                    if (call1 != null) {
                        // Aggregate output = group keys, then registered calls.
                        side.split.put(aggCall.i,
                            newGroupKeyCount + belowAggCallRegistry.register(call1));
                    }
                }
                side.newInput = relBuilder.push(joinInput)
                    .aggregate(relBuilder.groupKey(belowAggregateKey), belowAggCalls)
                    .build();
            }
            offset += fieldCount;
            belowOffset += side.newInput.getRowType().getFieldCount();
            sides.add(side);
        }

        if (uniqueCount == 2) {
            // Neither side got an Aggregate pushed below the join, so nothing is gained.
            // NOTE(review): presumably continuing could also re-fire this rule on its
            // own output indefinitely.
            return;
        }

        // Rewrite the join condition against the new inputs.
        final Mapping mapping = (Mapping) Mappings.target(map::get,
            join.getRowType().getFieldCount(), belowOffset);
        final RexNode newCondition = RexUtil.apply(mapping, join.getCondition());

        final RelNode side0 = Objects.requireNonNull(sides.get(0).newInput, "sides.get(0).newInput");
        relBuilder.push(side0)
            .push(Objects.requireNonNull(sides.get(1).newInput, "sides.get(1).newInput"))
            .join(join.getJoinType(), newCondition);

        // Combine the per-side partial results above the new join.
        final List<AggregateCall> newAggCalls = new ArrayList<>();
        final int groupCount = aggregate.getGroupCount();
        final int newLeftWidth = side0.getRowType().getFieldCount();
        final List<RexNode> projects =
            new ArrayList<>(rexBuilder.identityProjects(relBuilder.peek().getRowType()));
        for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
            final SqlSplittableAggFunction splitter =
                aggCall.e.getAggregation().unwrapOrThrow(SqlSplittableAggFunction.class);
            final Integer leftSubTotal = sides.get(0).split.get(aggCall.i);
            final Integer rightSubTotal = sides.get(1).split.get(aggCall.i);
            newAggCalls.add(splitter.topSplit(rexBuilder, registry(projects), groupCount,
                relBuilder.peek().getRowType(), aggCall.e,
                leftSubTotal == null ? -1 : leftSubTotal,
                rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth));
        }
        relBuilder.project(projects);

        boolean aggConvertedToProjects = false;
        if (allColumnsInAggregate && join.getJoinType() != JoinRelType.FULL) {
            // Try to replace the top Aggregate by a Project (never for FULL joins).
            // Singletons are arbitrary expressions, so this list must hold RexNode:
            // casting them to RexInputRef would throw ClassCastException.
            final List<RexNode> projects2 = new ArrayList<>();
            for (int key : Mappings.apply(mapping, aggregate.getGroupSet())) {
                projects2.add(relBuilder.field(key));
            }
            for (AggregateCall newAggCall : newAggCalls) {
                newAggCall.getAggregation().maybeUnwrap(SqlSplittableAggFunction.class).ifPresent(splitter -> {
                    final RelDataType rowType = relBuilder.peek().getRowType();
                    projects2.add(splitter.singleton(rexBuilder, rowType, newAggCall));
                });
            }
            // Only switch if every call produced a singleton expression.
            if (projects2.size() == aggregate.getGroupSet().cardinality() + newAggCalls.size()) {
                relBuilder.project(projects2);
                aggConvertedToProjects = true;
            }
        }
        if (!aggConvertedToProjects) {
            relBuilder.aggregate(
                relBuilder.groupKey(Mappings.apply(mapping, aggregate.getGroupSet()),
                    (Iterable<? extends ImmutableBitSet>) Mappings.apply2(mapping, aggregate.getGroupSets())),
                newAggCalls);
        }
        call.transformTo(relBuilder.build());
    }

    /**
     * Returns {@code aggregateColumns} plus every column that some predicate
     * equates (one hop, via {@code $i = $j}) to one of the aggregate columns.
     */
    private static ImmutableBitSet keyColumns(ImmutableBitSet aggregateColumns, ImmutableList<RexNode> predicates) {
        final Map<Integer, BitSet> equivalence = new TreeMap<>();
        for (RexNode predicate : predicates) {
            populateEquivalences(equivalence, predicate);
        }
        ImmutableBitSet keyColumns = aggregateColumns;
        for (int aggregateColumn : aggregateColumns) {
            final BitSet equivalents = equivalence.get(aggregateColumn);
            if (equivalents != null) {
                keyColumns = keyColumns.union(equivalents);
            }
        }
        return keyColumns;
    }

    /**
     * If {@code predicate} is an EQUALS between two input references, records
     * each side as equivalent to the other; other predicates are ignored.
     */
    private static void populateEquivalences(Map<Integer, BitSet> equivalence, RexNode predicate) {
        switch (predicate.getKind()) {
            case EQUALS: {
                final List<RexNode> operands = ((RexCall) predicate).getOperands();
                if (operands.get(0) instanceof RexInputRef && operands.get(1) instanceof RexInputRef) {
                    final int index0 = ((RexInputRef) operands.get(0)).getIndex();
                    final int index1 = ((RexInputRef) operands.get(1)).getIndex();
                    populateEquivalence(equivalence, index0, index1);
                    populateEquivalence(equivalence, index1, index0);
                }
                break;
            }
            default:
                break;
        }
    }

    /** Marks column {@code i1} as equivalent to column {@code i0}. */
    private static void populateEquivalence(Map<Integer, BitSet> equivalence, int i0, int i1) {
        equivalence.computeIfAbsent(i0, k -> new BitSet()).set(i1);
    }

    /**
     * Returns a registry backed by {@code list}: registering an element returns
     * its existing index, or appends it and returns the new index.
     */
    private static <E> SqlSplittableAggFunction.Registry<E> registry(List<E> list) {
        return e -> {
            int i = list.indexOf(e);
            if (i < 0) {
                i = list.size();
                list.add(e);
            }
            return i;
        };
    }

    /** Rule configuration. */
    @Value.Immutable
    public interface Config
    extends RelRule.Config {
        Config DEFAULT = ImmutableAggregateJoinTransposeRule.Config.of().withOperandFor(LogicalAggregate.class, LogicalJoin.class, false);

        /** Like {@link #DEFAULT}, but also matches aggregates that have aggregate calls. */
        Config EXTENDED = ImmutableAggregateJoinTransposeRule.Config.of().withOperandFor(LogicalAggregate.class, LogicalJoin.class, true);

        @Override
        default AggregateJoinTransposeRule toRule() {
            return new AggregateJoinTransposeRule(this);
        }

        /**
         * Whether the rule may match aggregates with aggregate calls (and then
         * query metadata for input uniqueness); defaults to false.
         */
        @Value.Default
        default boolean isAllowFunctions() {
            return false;
        }

        /** Sets {@link #isAllowFunctions()}. */
        Config withAllowFunctions(boolean allowFunctions);

        /**
         * Sets {@link #isAllowFunctions()} and an operand tree matching an
         * {@code aggregateClass} (accepted by {@code isAggregateSupported}) whose
         * single input is a {@code joinClass}.
         */
        default Config withOperandFor(Class<? extends Aggregate> aggregateClass, Class<? extends Join> joinClass, boolean allowFunctions) {
            return this.withAllowFunctions(allowFunctions)
                .withOperandSupplier(b0 -> b0.operand(aggregateClass)
                    .predicate(agg -> isAggregateSupported(agg, allowFunctions))
                    .oneInput(b1 -> b1.operand(joinClass).anyInputs()))
                .as(Config.class);
        }
    }

    /** Per-input work space used while building the new join inputs. */
    private static class Side {
        /**
         * Maps the ordinal of a call in the original aggregate to the field
         * ordinal in {@link #newInput} that holds its partial result.
         */
        final Map<Integer, Integer> split = new HashMap<>();
        /** Replacement for the join input (a Project or an Aggregate). */
        @Nullable RelNode newInput;
        /** Whether {@link #newInput} was built as an Aggregate. */
        boolean aggregate;
    }
}

