package io.trino.cost;

import io.trino.Session;
import io.trino.matching.Pattern;
import io.trino.spi.type.BigintType;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.RowNumberNode;
import java.util.Iterator;
import java.util.Optional;

/* loaded from: input_file:io/trino/cost/RowNumberStatsRule.class */
/**
 * Derives statistics for a {@link RowNumberNode} from the statistics of its source.
 *
 * <p>The estimate passes the source rows through. When a per-partition row limit is
 * present, it caps the number of rows kept in each partition. It also adds statistics
 * for the generated row-number symbol.
 */
public class RowNumberStatsRule extends SimpleStatsRule<RowNumberNode> {
    private static final Pattern<RowNumberNode> PATTERN = Patterns.rowNumber();

    public RowNumberStatsRule(StatsNormalizer statsNormalizer) {
        super(statsNormalizer);
    }

    @Override
    public Pattern<RowNumberNode> getPattern() {
        return PATTERN;
    }

    /**
     * Estimates output statistics for a row-number node.
     *
     * @param rowNumberNode the node whose statistics are estimated
     * @param statsProvider supplies statistics of the node's source
     * @param lookup unused by this rule
     * @param session unused by this rule
     * @param typeProvider unused by this rule
     * @param tableStatsProvider unused by this rule
     * @return the estimated statistics, or empty when the source row count is unknown or
     *     the partition count cannot be estimated (NaN)
     */
    @Override
    public Optional<PlanNodeStatsEstimate> doCalculate(RowNumberNode rowNumberNode, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider typeProvider, TableStatsProvider tableStatsProvider) {
        PlanNodeStatsEstimate sourceStats = statsProvider.getStats(rowNumberNode.getSource());
        if (sourceStats.isOutputRowCountUnknown()) {
            return Optional.empty();
        }
        double sourceRowCount = sourceStats.getOutputRowCount();

        double partitionCount = estimatePartitionCount(rowNumberNode, sourceStats, sourceRowCount);
        if (Double.isNaN(partitionCount)) {
            return Optional.empty();
        }

        // Assumes rows are spread evenly across partitions (no skew). An empty source has
        // zero partitions; guard it explicitly so that 0 / 0 does not produce NaN and
        // turn the stats of an empty input into "unknown".
        double rowsPerPartition = sourceRowCount == 0.0d ? 0.0d : sourceRowCount / partitionCount;
        double outputRowCount = sourceRowCount;
        if (rowNumberNode.getMaxRowCountPerPartition().isPresent()) {
            // With a per-partition limit, each partition keeps at most that many rows.
            rowsPerPartition = Math.min(rowsPerPartition, rowNumberNode.getMaxRowCountPerPartition().get().intValue());
            outputRowCount = partitionCount * rowsPerPartition;
        }

        // Row numbers start at 1, are never null, and take (per partition) as many
        // distinct values as there are rows in the partition.
        SymbolStatsEstimate rowNumberStats = SymbolStatsEstimate.builder()
                .setLowValue(1.0d)
                .setDistinctValuesCount(rowsPerPartition)
                .setNullsFraction(0.0d)
                .setAverageRowSize(BigintType.BIGINT.getFixedSize())
                .build();
        return Optional.of(PlanNodeStatsEstimate.buildFrom(sourceStats)
                .setOutputRowCount(outputRowCount)
                .addSymbolStatistics(rowNumberNode.getRowNumberSymbol(), rowNumberStats)
                .build());
    }

    /**
     * Estimates the number of partitions as the product of the distinct-value counts of
     * the partitioning symbols, capped at the source row count.
     *
     * <p>A NaN distinct-value count on any partitioning symbol propagates to a NaN result.
     * With no partitioning symbols, the product is 1.
     */
    private static double estimatePartitionCount(RowNumberNode rowNumberNode, PlanNodeStatsEstimate sourceStats, double sourceRowCount) {
        double partitionCount = 1.0d;
        for (Symbol partitionSymbol : rowNumberNode.getPartitionBy()) {
            SymbolStatsEstimate symbolStats = sourceStats.getSymbolStatistics(partitionSymbol);
            // Rows with NULL in this symbol are counted as one additional partition.
            int nullPartition = symbolStats.getNullsFraction() == 0.0d ? 0 : 1;
            // Assumes no correlation between the partitioning symbols.
            partitionCount *= symbolStats.getDistinctValuesCount() + nullPartition;
        }
        // There cannot be more partitions than rows.
        return Math.min(sourceRowCount, partitionCount);
    }
}
