package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.MarkDistinctNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.PlanVariableAllocator;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/* loaded from: input_file:com/facebook/presto/sql/planner/iterative/rule/MultipleDistinctAggregationToMarkDistinct.class */
/**
 * Planner rule that rewrites the DISTINCT aggregations of an {@link AggregationNode}
 * into non-distinct aggregations masked by a boolean marker variable produced by a
 * {@link MarkDistinctNode} placed underneath the aggregation.
 *
 * <p>The rule only matches aggregation nodes where:
 * <ul>
 *   <li>no DISTINCT aggregation carries a filter or a mask, and</li>
 *   <li>either there are DISTINCT aggregations over more than one set of arguments,
 *       or DISTINCT and non-DISTINCT aggregations are mixed in the same node.</li>
 * </ul>
 *
 * <p>Even on a match, the rewrite is skipped unless
 * {@link SystemSessionProperties#useMarkDistinct} is enabled for the session.
 */
public class MultipleDistinctAggregationToMarkDistinct implements Rule<AggregationNode> {
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation().matching(
            Predicates.and(
                    MultipleDistinctAggregationToMarkDistinct::hasNoDistinctWithFilterOrMask,
                    Predicates.or(
                            MultipleDistinctAggregationToMarkDistinct::hasMultipleDistincts,
                            MultipleDistinctAggregationToMarkDistinct::hasMixedDistinctAndNonDistincts)));

    /**
     * Returns {@code true} when no DISTINCT aggregation in the node has a filter or a mask.
     */
    private static boolean hasNoDistinctWithFilterOrMask(AggregationNode aggregationNode) {
        return aggregationNode.getAggregations().values().stream()
                .noneMatch(aggregation -> aggregation.isDistinct()
                        && (aggregation.getFilter().isPresent() || aggregation.getMask().isPresent()));
    }

    /**
     * Returns {@code true} when the DISTINCT aggregations of the node cover more than one
     * distinct set of arguments (argument order and duplicates are ignored).
     */
    private static boolean hasMultipleDistincts(AggregationNode aggregationNode) {
        return aggregationNode.getAggregations().values().stream()
                .filter(AggregationNode.Aggregation::isDistinct)
                .map(AggregationNode.Aggregation::getArguments)
                .map(HashSet::new)
                .distinct()
                .count() > 1;
    }

    /**
     * Returns {@code true} when the node has at least one DISTINCT aggregation and at least
     * one non-DISTINCT aggregation.
     */
    private static boolean hasMixedDistinctAndNonDistincts(AggregationNode aggregationNode) {
        long distincts = aggregationNode.getAggregations().values().stream()
                .filter(AggregationNode.Aggregation::isDistinct)
                .count();
        return distincts > 0 && distincts < aggregationNode.getAggregations().size();
    }

    @Override // com.facebook.presto.sql.planner.iterative.Rule
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    /**
     * Rewrites the matched aggregation node.
     *
     * <p>For every DISTINCT aggregation without filter or mask, a {@link MarkDistinctNode} is
     * stacked on top of the current source (one per distinct set of input variables), and the
     * aggregation is replaced by a non-distinct copy whose mask is the marker variable of that
     * node. All other aggregations are carried over unchanged.
     *
     * @param aggregationNode the aggregation node matched by {@link #getPattern()}
     * @param captures unused
     * @param context supplies the session, the variable allocator and the plan node id allocator
     * @return an empty result when mark-distinct is disabled for the session; otherwise a new
     *         aggregation node with the same id reading from the chain of mark-distinct nodes
     */
    @Override // com.facebook.presto.sql.planner.iterative.Rule
    public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
        if (!SystemSessionProperties.useMarkDistinct(context.getSession())) {
            return Rule.Result.empty();
        }

        // Distinct marker already created for a given set of input variables, so that several
        // DISTINCT aggregations over the same inputs share a single MarkDistinctNode.
        Map<Set<VariableReferenceExpression>, VariableReferenceExpression> markers = new HashMap<>();
        Map<VariableReferenceExpression, AggregationNode.Aggregation> newAggregations = new HashMap<>();
        PlanNode subPlan = aggregationNode.getSource();

        for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
            AggregationNode.Aggregation aggregation = entry.getValue();

            if (!aggregation.isDistinct() || aggregation.getFilter().isPresent() || aggregation.getMask().isPresent()) {
                // Not a plain DISTINCT aggregation: keep it as is.
                newAggregations.put(entry.getKey(), aggregation);
                continue;
            }

            Set<VariableReferenceExpression> inputs = aggregation.getArguments().stream()
                    .map(OriginalExpressionUtils::castToExpression)
                    .map(context.getVariableAllocator()::toVariableReference)
                    .collect(Collectors.toSet());

            VariableReferenceExpression marker = markers.get(inputs);
            if (marker == null) {
                // The name of one of the inputs is only used as a hint for the marker's name.
                marker = context.getVariableAllocator().newVariable(Iterables.getLast(inputs).getName(), BooleanType.BOOLEAN, "distinct");
                markers.put(inputs, marker);

                // Rows are marked per combination of grouping keys, aggregation inputs and,
                // when present, the group id variable.
                ImmutableSet.Builder<VariableReferenceExpression> distinctVariables = ImmutableSet.<VariableReferenceExpression>builder()
                        .addAll(aggregationNode.getGroupingKeys())
                        .addAll(inputs);
                aggregationNode.getGroupIdVariable().ifPresent(distinctVariables::add);

                subPlan = new MarkDistinctNode(
                        subPlan.getSourceLocation(),
                        context.getIdAllocator().getNextId(),
                        subPlan,
                        marker,
                        ImmutableList.copyOf(distinctVariables.build()),
                        Optional.empty());
            }

            // Drop the distinct flag and use the marker as the aggregation mask instead.
            newAggregations.put(entry.getKey(), new AggregationNode.Aggregation(
                    aggregation.getCall(),
                    aggregation.getFilter(),
                    aggregation.getOrderBy(),
                    false,
                    Optional.of(marker)));
        }

        return Rule.Result.ofPlanNode(new AggregationNode(
                aggregationNode.getSourceLocation(),
                aggregationNode.getId(),
                subPlan,
                newAggregations,
                aggregationNode.getGroupingSets(),
                ImmutableList.of(),
                aggregationNode.getStep(),
                aggregationNode.getHashVariable(),
                aggregationNode.getGroupIdVariable()));
    }
}
