package io.trino.sql.planner.optimizations;

import com.google.common.base.MoreObjects;
import com.google.common.collect.ImmutableList;
import io.trino.Session;
import io.trino.cost.TableStatsProvider;
import io.trino.execution.querystats.PlanOptimizersStatsCollector;
import io.trino.execution.warnings.WarningCollector;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.DistinctLimitNode;
import io.trino.sql.planner.plan.LimitNode;
import io.trino.sql.planner.plan.MarkDistinctNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.SemiJoinNode;
import io.trino.sql.planner.plan.SimplePlanRewriter;
import io.trino.sql.planner.plan.SortNode;
import io.trino.sql.planner.plan.TopNNode;
import io.trino.sql.planner.plan.UnionNode;
import io.trino.sql.planner.plan.ValuesNode;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

/**
 * Plan optimizer that pushes row limits ({@link LimitNode}) down the plan as far as it can.
 * <p>
 * The plan is rewritten top-down while carrying an optional {@link LimitContext}: the limit that is
 * currently being pushed ({@code null} when there is none). Depending on the node it reaches, the
 * pending limit is forwarded to the node's source (projections, mark-distinct nodes, the source side
 * of semi joins, plain limits), folded into the node (sorts become top-N, top-N counts are reduced,
 * grouping-only aggregations become {@link DistinctLimitNode}s), or materialized as an explicit
 * {@link LimitNode} directly above the node where the push-down has to stop.
 */
public class LimitPushDown implements PlanOptimizer {

    /**
     * The limit that is being pushed down the plan.
     */
    public static class LimitContext {
        // Maximum number of rows allowed by the pending limit.
        private final long count;
        // Forwarded as the "partial" flag of any LimitNode created from this context.
        private final boolean partial;

        public LimitContext(long count, boolean partial) {
            this.count = count;
            this.partial = partial;
        }

        public long getCount() {
            return count;
        }

        public boolean isPartial() {
            return partial;
        }

        @Override
        public String toString() {
            return MoreObjects.toStringHelper(this).add("count", count).add("partial", partial).toString();
        }
    }

    /**
     * Plan rewriter that performs the push-down. The rewrite context value is the pending limit, or
     * {@code null} when no limit is being pushed.
     */
    private static class Rewriter extends SimplePlanRewriter<LimitContext> {
        private final PlanNodeIdAllocator idAllocator;

        private Rewriter(PlanNodeIdAllocator idAllocator) {
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
        }

        /**
         * Places the pending limit directly above {@code node}.
         *
         * @return {@code node} itself when {@code limit} is null, otherwise a new LimitNode over it
         */
        private PlanNode applyLimit(PlanNode node, LimitContext limit) {
            if (limit == null) {
                return node;
            }
            return new LimitNode(idAllocator.getNextId(), node, limit.getCount(), limit.isPartial());
        }

        /**
         * Returns true when the aggregation has grouping keys, computes no aggregate functions, and
         * outputs exactly its grouping keys (i.e. it only removes duplicates).
         */
        private static boolean isGroupingOnly(AggregationNode node) {
            return node.getAggregations().isEmpty()
                    && !node.getGroupingKeys().isEmpty()
                    && node.getOutputSymbols().size() == node.getGroupingKeys().size()
                    && node.getOutputSymbols().containsAll(node.getGroupingKeys());
        }

        /**
         * Fallback for nodes the limit cannot be pushed through. The node's children are rewritten
         * with no pending limit, and the pending limit (if any) is materialized above the node.
         */
        @Override
        public PlanNode visitPlan(PlanNode node, SimplePlanRewriter.RewriteContext<LimitContext> context) {
            PlanNode rewritten = context.defaultRewrite(node);
            return applyLimit(rewritten, context.get());
        }

        /**
         * Merges this limit with the pending one (keeping the smaller count) and pushes the result
         * into the source when that is safe.
         */
        @Override
        public PlanNode visitLimit(LimitNode node, SimplePlanRewriter.RewriteContext<LimitContext> context) {
            LimitContext outer = context.get();
            long count = (outer == null) ? node.getCount() : Math.min(node.getCount(), outer.getCount());

            // After merging, a count of zero means no rows can be produced at all.
            if (count == 0) {
                return new ValuesNode(idAllocator.getNextId(), node.getOutputSymbols(), ImmutableList.of());
            }

            // A WITH TIES limit can only be replaced by a plain limit when the outer limit is no larger
            // than its own count: the outer limit then cuts the result before any extra tied rows.
            // Limits that require pre-sorted inputs are never dissolved.
            boolean keepNode = node.requiresPreSortedInputs()
                    || (node.isWithTies() && (outer == null || node.getCount() < outer.getCount()));
            if (keepNode) {
                // Keep this node and push the outer limit (if any) into its source.
                return context.defaultRewrite(node, outer);
            }

            // Dissolve this node into the pending limit; the merged limit is partial only when both
            // this limit and the outer one are partial.
            boolean partial = node.isPartial() && (outer == null || outer.isPartial());
            return context.rewrite(node.getSource(), new LimitContext(count, partial));
        }

        /**
         * Turns a grouping-only aggregation under a pending limit into a DistinctLimitNode. Any other
         * aggregation stops the push-down.
         */
        @Override
        @Deprecated
        public PlanNode visitAggregation(AggregationNode node, SimplePlanRewriter.RewriteContext<LimitContext> context) {
            LimitContext limit = context.get();
            if (limit != null && isGroupingOnly(node)) {
                PlanNode source = context.rewrite(node.getSource());
                // NOTE(review): the distinct symbols are the source's output symbols, which assumes the
                // source outputs only the grouping keys -- TODO confirm against callers.
                return new DistinctLimitNode(idAllocator.getNextId(), source, limit.getCount(), false, source.getOutputSymbols(), Optional.empty());
            }

            PlanNode rewritten = context.defaultRewrite(node);
            return applyLimit(rewritten, limit);
        }

        /** Forwards the pending limit unchanged to the mark-distinct node's source. */
        @Override
        public PlanNode visitMarkDistinct(MarkDistinctNode node, SimplePlanRewriter.RewriteContext<LimitContext> context) {
            return context.defaultRewrite(node, context.get());
        }

        /** Forwards the pending limit unchanged through the projection to its source. */
        @Override
        public PlanNode visitProject(ProjectNode node, SimplePlanRewriter.RewriteContext<LimitContext> context) {
            return context.defaultRewrite(node, context.get());
        }

        /**
         * Folds the pending limit into the TopN count, keeping the smaller of the two. The TopN's
         * source is rewritten with no pending limit.
         */
        @Override
        public PlanNode visitTopN(TopNNode node, SimplePlanRewriter.RewriteContext<LimitContext> context) {
            LimitContext limit = context.get();
            PlanNode source = context.rewrite(node.getSource());
            if (limit == null && source == node.getSource()) {
                return node;
            }

            long count = (limit == null) ? node.getCount() : Math.min(node.getCount(), limit.getCount());
            return new TopNNode(node.getId(), source, count, node.getOrderingScheme(), node.getStep());
        }

        /**
         * Replaces a sort under a pending limit with a single-step TopNNode. The TopNNode reuses the
         * sort's id and ordering scheme.
         */
        @Override
        @Deprecated
        public PlanNode visitSort(SortNode node, SimplePlanRewriter.RewriteContext<LimitContext> context) {
            LimitContext limit = context.get();
            PlanNode source = context.rewrite(node.getSource());

            if (limit != null) {
                return new TopNNode(node.getId(), source, limit.getCount(), node.getOrderingScheme(), TopNNode.Step.SINGLE);
            }
            if (source == node.getSource()) {
                return node;
            }
            return new SortNode(node.getId(), source, node.getOrderingScheme(), node.isPartial());
        }

        /**
         * Pushes a partial copy of the pending limit into every branch and re-applies the original
         * limit above the union.
         */
        @Override
        public PlanNode visitUnion(UnionNode node, SimplePlanRewriter.RewriteContext<LimitContext> context) {
            LimitContext limit = context.get();
            // Each branch may be cut to the limit, but the branches together can still exceed it,
            // so branch limits are partial and the full limit stays above the union.
            LimitContext branchLimit = (limit == null) ? null : new LimitContext(limit.getCount(), true);

            List<PlanNode> sources = new ArrayList<>(node.getSources().size());
            boolean changed = false;
            for (PlanNode source : node.getSources()) {
                PlanNode rewritten = context.rewrite(source, branchLimit);
                changed |= rewritten != source;
                sources.add(rewritten);
            }

            PlanNode union = changed
                    ? new UnionNode(node.getId(), sources, node.getSymbolMapping(), node.getOutputSymbols())
                    : node;
            return applyLimit(union, limit);
        }

        /**
         * Forwards the pending limit into the semi join's source. This relies on a semi join emitting
         * one row per source row.
         * <p>
         * The filtering source is rewritten with no pending limit, so limits inside it are optimized
         * just like those under any other node.
         */
        @Override
        public PlanNode visitSemiJoin(SemiJoinNode node, SimplePlanRewriter.RewriteContext<LimitContext> context) {
            PlanNode source = context.rewrite(node.getSource(), context.get());
            PlanNode filteringSource = context.rewrite(node.getFilteringSource());
            if (source == node.getSource() && filteringSource == node.getFilteringSource()) {
                return node;
            }
            return new SemiJoinNode(
                    node.getId(),
                    source,
                    filteringSource,
                    node.getSourceJoinSymbol(),
                    node.getFilteringSourceJoinSymbol(),
                    node.getSemiJoinOutput(),
                    node.getSourceHashSymbol(),
                    node.getFilteringSourceHashSymbol(),
                    node.getDistributionType(),
                    node.getDynamicFilterId());
        }
    }

    /**
     * Rewrites {@code planNode}, starting with no pending limit.
     *
     * @throws NullPointerException if the plan, session, types, symbol allocator or id allocator is null
     */
    @Override
    public PlanNode optimize(PlanNode planNode, Session session, TypeProvider typeProvider, SymbolAllocator symbolAllocator, PlanNodeIdAllocator planNodeIdAllocator, WarningCollector warningCollector, PlanOptimizersStatsCollector planOptimizersStatsCollector, TableStatsProvider tableStatsProvider) {
        Objects.requireNonNull(planNode, "plan is null");
        Objects.requireNonNull(session, "session is null");
        Objects.requireNonNull(typeProvider, "types is null");
        Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        Objects.requireNonNull(planNodeIdAllocator, "idAllocator is null");
        return SimplePlanRewriter.rewriteWith(new Rewriter(planNodeIdAllocator), planNode, null);
    }
}
