package io.trino.cost;

import com.google.common.collect.MoreCollectors;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.matching.Pattern;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.GroupReference;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import java.util.Objects;
import java.util.Optional;

/**
 * Stats rule for a {@link FilterNode} whose source is an {@link AggregationNode}, either directly or
 * through an identity {@link ProjectNode}.
 *
 * <p>The rule is active only when
 * {@link SystemSessionProperties#isNonEstimatablePredicateApproximationEnabled(Session)} is true.
 * It first asks the {@link FilterStatsCalculator} for an estimate. If that estimate has an unknown
 * output row count but the aggregation's row count is known, it falls back to the aggregation's
 * statistics. The fallback scales the aggregation's row count by {@link #UNKNOWN_FILTER_COEFFICIENT}.
 */
public class FilterProjectAggregationStatsRule extends SimpleStatsRule<FilterNode> {
    private static final Pattern<FilterNode> PATTERN = Patterns.filter();

    /**
     * Factor applied to the aggregation's output row count when the filter predicate itself could
     * not be estimated.
     */
    // NOTE(review): presumably mirrors the "unknown filter" coefficient used elsewhere in the cost
    // model (e.g. FilterStatsCalculator); confirm and consider sharing a single constant.
    private static final double UNKNOWN_FILTER_COEFFICIENT = 0.9d;

    private final FilterStatsCalculator filterStatsCalculator;

    /**
     * @param statsNormalizer normalizer passed to {@link SimpleStatsRule}
     * @param filterStatsCalculator calculator used to estimate the filter predicate; must not be null
     * @throws NullPointerException if {@code filterStatsCalculator} is null
     */
    public FilterProjectAggregationStatsRule(StatsNormalizer statsNormalizer, FilterStatsCalculator filterStatsCalculator) {
        super(statsNormalizer);
        this.filterStatsCalculator = Objects.requireNonNull(filterStatsCalculator, "filterStatsCalculator cannot be null");
    }

    @Override
    public Pattern<FilterNode> getPattern() {
        return PATTERN;
    }

    /**
     * Estimates statistics for {@code filterNode} when its source resolves to an aggregation.
     *
     * @return an empty {@code Optional} if the session property is disabled or the source shape does
     *     not match; otherwise the estimate produced by {@link #calculate}
     */
    // Kept public: the decompiled original was public, although the overridden method was
    // apparently protected. Narrowing it now could break existing callers.
    @Override
    public Optional<PlanNodeStatsEstimate> doCalculate(FilterNode filterNode, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider typeProvider, TableStatsProvider tableStatsProvider) {
        if (!SystemSessionProperties.isNonEstimatablePredicateApproximationEnabled(session)) {
            return Optional.empty();
        }
        return findAggregationSource(filterNode, lookup)
                .flatMap(aggregationNode -> calculate(filterNode, aggregationNode, statsProvider, session, typeProvider));
    }

    /**
     * Returns the aggregation feeding {@code filterNode}, looking through one identity projection.
     * Returns empty for any other source shape.
     */
    private static Optional<AggregationNode> findAggregationSource(FilterNode filterNode, Lookup lookup) {
        PlanNode source = resolveGroup(lookup, filterNode.getSource());
        if (source instanceof ProjectNode) {
            ProjectNode projectNode = (ProjectNode) source;
            // A non-identity projection may change the output symbols, so the aggregation's
            // statistics would not describe the filter's input.
            if (!projectNode.isIdentity()) {
                return Optional.empty();
            }
            source = resolveGroup(lookup, projectNode.getSource());
        }
        if (source instanceof AggregationNode) {
            return Optional.of((AggregationNode) source);
        }
        return Optional.empty();
    }

    /**
     * Computes the filter estimate, with a fallback to the scaled aggregation statistics.
     *
     * @return the filter estimate if its output row count is known; otherwise the aggregation's
     *     statistics with the row count scaled by {@link #UNKNOWN_FILTER_COEFFICIENT}, provided the
     *     aggregation's row count is known; otherwise the (unknown) filter estimate
     */
    private Optional<PlanNodeStatsEstimate> calculate(FilterNode filterNode, AggregationNode aggregationNode, StatsProvider statsProvider, Session session, TypeProvider typeProvider) {
        PlanNodeStatsEstimate filterStats = filterStatsCalculator.filterStats(
                statsProvider.getStats(filterNode.getSource()),
                filterNode.getPredicate(),
                session,
                typeProvider);
        if (!filterStats.isOutputRowCountUnknown()) {
            return Optional.of(filterStats);
        }

        PlanNodeStatsEstimate aggregationStats = statsProvider.getStats(aggregationNode);
        if (aggregationStats.isOutputRowCountUnknown()) {
            return Optional.of(filterStats);
        }
        return Optional.of(aggregationStats.mapOutputRowCount(rowCount -> rowCount * UNKNOWN_FILTER_COEFFICIENT));
    }

    /** Resolves {@code planNode} through {@code lookup} if it is a {@link GroupReference}. */
    private static PlanNode resolveGroup(Lookup lookup, PlanNode planNode) {
        if (!(planNode instanceof GroupReference)) {
            return planNode;
        }
        // onlyElement() fails fast if the group does not resolve to exactly one node.
        return lookup.resolveGroup(planNode).collect(MoreCollectors.onlyElement());
    }
}
