package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.ChildReplacer;
import io.trino.sql.planner.plan.ExceptNode;
import io.trino.sql.planner.plan.IntersectNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanVisitor;
import io.trino.sql.planner.plan.UnionNode;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Iterative-optimizer rule that removes distinct aggregations (aggregations with no aggregate
 * functions) nested beneath another distinct aggregation.
 *
 * <p>Starting from a matched distinct aggregation, the plan below it is walked. A nested distinct
 * aggregation is dropped when every node between it and the enclosing distinct aggregation is a
 * {@link UnionNode}, or a distinct {@link IntersectNode} / {@link ExceptNode}. Duplicates it would
 * remove are removed again higher up anyway. Any other node type stops the pruning for its
 * subtree.
 */
public class PruneDistinctAggregation implements Rule<AggregationNode> {
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation().matching(PruneDistinctAggregation::isDistinctOperator);

    /**
     * Plan visitor that performs the pruning below a distinct aggregation.
     *
     * <p>The visitor context flag is {@code true} while the current node is reachable from the
     * nearest enclosing distinct aggregation through pass-through nodes only (UNION, distinct
     * INTERSECT, distinct EXCEPT). Any other node resets it to {@code false} for its children.
     *
     * <p>Instances are single-use: {@link #isRewritten()} reports whether any aggregation was
     * removed during the traversal performed so far.
     */
    public static class DistinctAggregationRewriter extends PlanVisitor<PlanNode, Boolean> {
        private final Lookup lookup;
        private boolean rewritten;

        /**
         * @param lookup used to resolve group references to concrete plan nodes; must not be null
         */
        public DistinctAggregationRewriter(Lookup lookup) {
            this.lookup = Objects.requireNonNull(lookup, "lookup is null");
        }

        /**
         * @return {@code true} if at least one nested distinct aggregation has been removed
         */
        public boolean isRewritten() {
            return rewritten;
        }

        /** Visits every (resolved) source of {@code node} with {@code context} and re-links the results. */
        private PlanNode rewriteChildren(PlanNode node, Boolean context) {
            List<PlanNode> newSources = resolveSources(lookup, node)
                    .map(source -> source.accept(this, context))
                    .collect(Collectors.toList());
            return ChildReplacer.replaceChildren(node, newSources);
        }

        /** Default case: an arbitrary node blocks pruning for everything beneath it. */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public PlanNode visitPlan(PlanNode node, Boolean context) {
            return rewriteChildren(node, false);
        }

        /** UNION passes the pruning flag through unchanged to all of its inputs. */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public PlanNode visitUnion(UnionNode node, Boolean context) {
            return rewriteChildren(node, context);
        }

        /** Distinct INTERSECT passes the flag through; INTERSECT ALL is handled like any other node. */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public PlanNode visitIntersect(IntersectNode node, Boolean context) {
            if (node.isDistinct()) {
                return rewriteChildren(node, context);
            }
            return visitPlan(node, context);
        }

        /** Distinct EXCEPT passes the flag through; EXCEPT ALL is handled like any other node. */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public PlanNode visitExcept(ExceptNode node, Boolean context) {
            if (node.isDistinct()) {
                return rewriteChildren(node, context);
            }
            return visitPlan(node, context);
        }

        /**
         * Rewrites the subtree below {@code node}, then drops {@code node} itself when it is a
         * distinct aggregation reached with the flag set.
         */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public PlanNode visitAggregation(AggregationNode node, Boolean context) {
            boolean distinct = isDistinctOperator(node);

            // The child's flag is determined by this aggregation alone: a distinct aggregation
            // starts a new pruning scope, a non-distinct one ends it.
            PlanNode rewrittenSource = Iterables.getOnlyElement(
                    lookup.resolveGroup(node.getSource())
                            .map(source -> source.accept(this, distinct))
                            .collect(Collectors.toList()));

            if (context && distinct) {
                rewritten = true;
                // NOTE(review): assumes the child produces the same output symbols as the
                // distinct aggregation being removed.
                return rewrittenSource;
            }

            // Pre-grouped symbols are cleared because the source may have changed, so any
            // previously recorded grouping property is not carried over.
            return AggregationNode.builderFrom(node)
                    .setSource(rewrittenSource)
                    .setPreGroupedSymbols(ImmutableList.of())
                    .build();
        }
    }

    @Override // io.trino.sql.planner.iterative.Rule
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    /**
     * Prunes distinct aggregations nested under {@code aggregationNode}.
     *
     * @return a plan with {@code aggregationNode}'s sources rewritten if anything was pruned,
     *         otherwise {@link Rule.Result#empty()}
     */
    @Override // io.trino.sql.planner.iterative.Rule
    public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
        Lookup lookup = context.getLookup();
        DistinctAggregationRewriter rewriter = new DistinctAggregationRewriter(lookup);

        // The traversal must complete BEFORE isRewritten() is consulted; checking the flag
        // first (as a ternary condition would) always sees false and makes the rule a no-op.
        List<PlanNode> newSources = resolveSources(lookup, aggregationNode)
                .map(source -> source.accept(rewriter, true))
                .collect(Collectors.toList());

        if (!rewriter.isRewritten()) {
            return Rule.Result.empty();
        }
        return Rule.Result.ofPlanNode(ChildReplacer.replaceChildren(aggregationNode, newSources));
    }

    /** Streams the sources of {@code node}, each resolved through {@code lookup}. */
    private static Stream<PlanNode> resolveSources(Lookup lookup, PlanNode node) {
        return node.getSources().stream().flatMap(lookup::resolveGroup);
    }

    /** An aggregation with no aggregate functions only de-duplicates its input rows. */
    private static boolean isDistinctOperator(AggregationNode aggregationNode) {
        return aggregationNode.getAggregations().isEmpty();
    }
}
