package io.trino.plugin.base.aggregation;

import com.google.common.collect.ImmutableSet;
import io.trino.matching.Match;
import io.trino.plugin.base.aggregation.AggregateFunctionRule;
import io.trino.plugin.base.expression.ConnectorExpressionRewriter;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.expression.ConnectorExpression;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/* loaded from: input_file:io/trino/plugin/base/aggregation/AggregateFunctionRewriter.class */
/**
 * Rewrites an {@link AggregateFunction} into a connector-specific result by
 * consulting a fixed set of {@link AggregateFunctionRule}s.
 *
 * <p>Nested {@link ConnectorExpression}s encountered by a rule are delegated to the
 * {@link ConnectorExpressionRewriter} supplied at construction time.
 *
 * @param <AggregationResult> type produced when an aggregate function is rewritten
 * @param <ExpressionResult> type produced when a connector expression is rewritten
 */
public final class AggregateFunctionRewriter<AggregationResult, ExpressionResult> {
    private final ConnectorExpressionRewriter<ExpressionResult> connectorExpressionRewriter;
    private final Set<AggregateFunctionRule<AggregationResult, ExpressionResult>> rules;

    /**
     * Creates a rewriter backed by the given expression rewriter and rules.
     *
     * @param connectorExpressionRewriter rewriter used for expressions nested in an aggregation
     * @param set the rewrite rules; copied into an immutable set, so later changes to the
     *        argument are not observed
     * @throws NullPointerException if either argument is {@code null}
     */
    public AggregateFunctionRewriter(ConnectorExpressionRewriter<ExpressionResult> connectorExpressionRewriter, Set<AggregateFunctionRule<AggregationResult, ExpressionResult>> set) {
        // requireNonNull is generic and returns its argument's static type, so no cast is needed.
        this.connectorExpressionRewriter = Objects.requireNonNull(connectorExpressionRewriter, "connectorExpressionRewriter is null");
        this.rules = ImmutableSet.copyOf(Objects.requireNonNull(set, "rules is null"));
    }

    /**
     * Attempts to rewrite the given aggregate function.
     *
     * <p>Rules are tried in the iteration order of the stored rule set. For each rule, every
     * match of the rule's pattern is handed to the rule in turn, and the first non-empty
     * result is returned immediately.
     *
     * @param connectorSession session exposed to rules and forwarded to the expression rewriter
     * @param aggregateFunction the aggregate function to rewrite
     * @param map column assignments exposed to rules and forwarded to the expression rewriter
     * @return the first non-empty rewrite result, or {@link Optional#empty()} when no rule
     *         produces one
     * @throws NullPointerException if {@code aggregateFunction} or {@code map} is {@code null}
     */
    public Optional<AggregationResult> rewrite(final ConnectorSession connectorSession, AggregateFunction aggregateFunction, final Map<String, ColumnHandle> map) {
        Objects.requireNonNull(aggregateFunction, "aggregateFunction is null");
        Objects.requireNonNull(map, "assignments is null");

        // The context gives rules access to the call's session and assignments, plus a way
        // to rewrite nested expressions through this instance's expression rewriter.
        AggregateFunctionRule.RewriteContext<ExpressionResult> context = new AggregateFunctionRule.RewriteContext<ExpressionResult>() {
            @Override
            public Map<String, ColumnHandle> getAssignments() {
                return map;
            }

            @Override
            public ConnectorSession getSession() {
                return connectorSession;
            }

            @Override
            public Optional<ExpressionResult> rewriteExpression(ConnectorExpression connectorExpression) {
                return connectorExpressionRewriter.rewrite(connectorSession, connectorExpression, map);
            }
        };

        for (AggregateFunctionRule<AggregationResult, ExpressionResult> rule : rules) {
            // Typed iterator: avoids the raw Iterator and the per-element downcast to Match.
            Iterator<Match> matches = rule.getPattern().match(aggregateFunction, context).iterator();
            while (matches.hasNext()) {
                Match match = matches.next();
                Optional<AggregationResult> rewritten = rule.rewrite(aggregateFunction, match.captures(), context);
                if (rewritten.isPresent()) {
                    return rewritten;
                }
            }
        }
        return Optional.empty();
    }
}
