/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.rules;

import com.linkedin.coral.calcite.;
import com.linkedin.coral.calcite.$internal.com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlPostfixOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;

public class AggregateCaseToFilterRule
extends RelOptRule {
    public static final AggregateCaseToFilterRule INSTANCE = new AggregateCaseToFilterRule(RelFactories.LOGICAL_BUILDER, null);

    protected AggregateCaseToFilterRule(RelBuilderFactory relBuilderFactory, String description) {
        super(AggregateCaseToFilterRule.operand(Aggregate.class, AggregateCaseToFilterRule.operand(Project.class, AggregateCaseToFilterRule.any()), new RelOptRuleOperand[0]), relBuilderFactory, description);
    }

    @Override
    public boolean matches(RelOptRuleCall call) {
        Aggregate aggregate = (Aggregate)call.rel(0);
        Project project = (Project)call.rel(1);
        for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
            int singleArg = AggregateCaseToFilterRule.soleArgument(aggregateCall);
            if (singleArg < 0 || !AggregateCaseToFilterRule.isThreeArgCase(project.getProjects().get(singleArg))) continue;
            return true;
        }
        return false;
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        Aggregate aggregate = (Aggregate)call.rel(0);
        Project project = (Project)call.rel(1);
        RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
        ArrayList<AggregateCall> newCalls = new ArrayList<AggregateCall>(aggregate.getAggCallList().size());
        ArrayList<RexNode> newProjects = new ArrayList<RexNode>(project.getProjects());
        ArrayList<RexNode> newCasts = new ArrayList<RexNode>();
        Iterator<Object> iterator = aggregate.getGroupSet().iterator();
        while (iterator.hasNext()) {
            int fieldNumber = iterator.next();
            newCasts.add(rexBuilder.makeInputRef(project.getProjects().get(fieldNumber).getType(), fieldNumber));
        }
        for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
            AggregateCall newCall = this.transform(aggregateCall, project, newProjects);
            int i = newCasts.size();
            RelDataType oldType = aggregate.getRowType().getFieldList().get(i).getType();
            if (newCall == null) {
                newCalls.add(aggregateCall);
                newCasts.add(rexBuilder.makeInputRef(oldType, i));
                continue;
            }
            newCalls.add(newCall);
            newCasts.add(rexBuilder.makeCast(oldType, rexBuilder.makeInputRef(newCall.getType(), i)));
        }
        if (newCalls.equals(aggregate.getAggCallList())) {
            return;
        }
        RelBuilder relBuilder = call.builder().push(project.getInput()).project(newProjects);
        RelBuilder.GroupKey groupKey = relBuilder.groupKey(aggregate.getGroupSet(), aggregate.getGroupSets());
        relBuilder.aggregate(groupKey, (List<AggregateCall>)newCalls).convert(aggregate.getRowType(), false);
        call.transformTo(relBuilder.build());
        call.getPlanner().setImportance(aggregate, 0.0);
    }

    @.Nullable
    private AggregateCall transform(AggregateCall aggregateCall, Project project, List<RexNode> newProjects) {
        int singleArg = AggregateCaseToFilterRule.soleArgument(aggregateCall);
        if (singleArg < 0) {
            return null;
        }
        RexNode rexNode = project.getProjects().get(singleArg);
        if (!AggregateCaseToFilterRule.isThreeArgCase(rexNode)) {
            return null;
        }
        RelOptCluster cluster = project.getCluster();
        RexBuilder rexBuilder = cluster.getRexBuilder();
        RexCall caseCall = (RexCall)rexNode;
        boolean flip = RexLiteral.isNullLiteral((RexNode)caseCall.operands.get(1)) && !RexLiteral.isNullLiteral((RexNode)caseCall.operands.get(2));
        RexNode arg1 = (RexNode)caseCall.operands.get(flip ? 2 : 1);
        RexNode arg2 = (RexNode)caseCall.operands.get(flip ? 1 : 2);
        SqlPostfixOperator op = flip ? SqlStdOperatorTable.IS_FALSE : SqlStdOperatorTable.IS_TRUE;
        RexNode filterFromCase = rexBuilder.makeCall((SqlOperator)op, (RexNode)caseCall.operands.get(0));
        RexNode filter = aggregateCall.filterArg >= 0 ? rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.AND, project.getProjects().get(aggregateCall.filterArg), filterFromCase) : filterFromCase;
        SqlKind kind = aggregateCall.getAggregation().getKind();
        if (aggregateCall.isDistinct()) {
            if (kind == SqlKind.COUNT && RexLiteral.isNullLiteral(arg2)) {
                newProjects.add(arg1);
                newProjects.add(filter);
                return AggregateCall.create(SqlStdOperatorTable.COUNT, true, false, false, ImmutableList.of(Integer.valueOf(newProjects.size() - 2)), newProjects.size() - 1, RelCollations.EMPTY, aggregateCall.getType(), aggregateCall.getName());
            }
            return null;
        }
        if (kind == SqlKind.COUNT && arg1.isA(SqlKind.LITERAL) && !RexLiteral.isNullLiteral(arg1) && RexLiteral.isNullLiteral(arg2)) {
            newProjects.add(filter);
            return AggregateCall.create(SqlStdOperatorTable.COUNT, false, false, false, ImmutableList.of(), newProjects.size() - 1, RelCollations.EMPTY, aggregateCall.getType(), aggregateCall.getName());
        }
        if (kind == SqlKind.SUM && AggregateCaseToFilterRule.isIntLiteral(arg1) && RexLiteral.intValue(arg1) == 1 && AggregateCaseToFilterRule.isIntLiteral(arg2) && RexLiteral.intValue(arg2) == 0) {
            newProjects.add(filter);
            RelDataTypeFactory typeFactory = cluster.getTypeFactory();
            RelDataType dataType = typeFactory.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BIGINT), false);
            return AggregateCall.create(SqlStdOperatorTable.COUNT, false, false, false, ImmutableList.of(), newProjects.size() - 1, RelCollations.EMPTY, dataType, aggregateCall.getName());
        }
        if (RexLiteral.isNullLiteral(arg2) && aggregateCall.getAggregation().allowsFilter() || kind == SqlKind.SUM && AggregateCaseToFilterRule.isIntLiteral(arg2) && RexLiteral.intValue(arg2) == 0) {
            newProjects.add(arg1);
            newProjects.add(filter);
            return AggregateCall.create(aggregateCall.getAggregation(), false, false, false, ImmutableList.of(Integer.valueOf(newProjects.size() - 2)), newProjects.size() - 1, RelCollations.EMPTY, aggregateCall.getType(), aggregateCall.getName());
        }
        return null;
    }

    private static int soleArgument(AggregateCall aggregateCall) {
        return aggregateCall.getArgList().size() == 1 ? aggregateCall.getArgList().get(0) : -1;
    }

    private static boolean isThreeArgCase(RexNode rexNode) {
        return rexNode.getKind() == SqlKind.CASE && ((RexCall)rexNode).operands.size() == 3;
    }

    private static boolean isIntLiteral(RexNode rexNode) {
        return rexNode instanceof RexLiteral && SqlTypeName.INT_TYPES.contains((Object)rexNode.getType().getSqlTypeName());
    }
}

