package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import io.trino.SystemSessionProperties;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeManager;
import io.trino.sql.PlannerContext;
import io.trino.sql.planner.PartitioningScheme;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.optimizations.SymbolMapper;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.LambdaExpression;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.class */
/**
 * Iterative-optimizer rule for an {@link AggregationNode} whose direct source is an {@link ExchangeNode}
 * that carries no ordering scheme.
 *
 * <p>Depending on the aggregation step, the rule either
 * <ul>
 *   <li>splits a {@code SINGLE} aggregation into a {@code FINAL} aggregation on top of a new
 *       {@code PARTIAL} aggregation (the new {@code PARTIAL} keeps the exchange as its source, so it
 *       can itself be matched by this rule's pattern), or</li>
 *   <li>pushes a {@code PARTIAL} aggregation beneath the exchange, so that a copy of it runs on
 *       every exchange source and the exchange carries the aggregation's output symbols.</li>
 * </ul>
 */
public class PushPartialAggregationThroughExchange implements Rule<AggregationNode> {
    private static final Capture<ExchangeNode> EXCHANGE_NODE = Capture.newCapture();

    // Aggregation directly on top of an exchange that does not carry an ordering scheme.
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation()
            .with(Patterns.source().matching(
                    Patterns.exchange()
                            .matching(exchange -> exchange.getOrderingScheme().isEmpty())
                            .capturedAs(EXCHANGE_NODE)));

    private final PlannerContext plannerContext;

    /**
     * @param plannerContext supplies metadata (decomposability, aggregation function metadata) and
     *     the type manager used to resolve intermediate types; must not be null
     * @throws NullPointerException if {@code plannerContext} is null
     */
    public PushPartialAggregationThroughExchange(PlannerContext plannerContext) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
    }

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    /**
     * Splits or pushes down the matched aggregation when it is safe and allowed.
     *
     * @return a rewritten plan, or {@link Rule.Result#empty()} when the rule does not apply
     * @throws IllegalStateException if a SINGLE aggregation with both empty and non-empty grouping
     *     sets sits on a REPARTITION exchange but its functions are not decomposable
     */
    @Override
    public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
        ExchangeNode exchangeNode = captures.get(EXCHANGE_NODE);
        boolean decomposable = aggregationNode.isDecomposable(context.getSession(), plannerContext.getMetadata());

        // This shape is always split, independent of the session preference checked below;
        // non-decomposable functions here are treated as a broken invariant rather than "rule not applicable".
        if (aggregationNode.getStep() == AggregationNode.Step.SINGLE
                && aggregationNode.hasEmptyGroupingSet()
                && aggregationNode.hasNonEmptyGroupingSet()
                && exchangeNode.getType() == ExchangeNode.Type.REPARTITION) {
            Preconditions.checkState(decomposable, "Distributed aggregation with empty grouping set requires partial but functions are not decomposable");
            return Rule.Result.ofPlanNode(split(aggregationNode, context));
        }

        if (!decomposable || !SystemSessionProperties.preferPartialAggregation(context.getSession())) {
            return Rule.Result.empty();
        }

        // Only GATHER and REPARTITION exchanges are handled. Exchanges that replicate nulls-and-any are
        // rejected (presumably because replicated rows would be aggregated more than once — TODO confirm).
        ExchangeNode.Type exchangeType = exchangeNode.getType();
        PartitioningScheme partitioningScheme = exchangeNode.getPartitioningScheme();
        if ((exchangeType != ExchangeNode.Type.GATHER && exchangeType != ExchangeNode.Type.REPARTITION)
                || partitioningScheme.isReplicateNullsAndAny()) {
            return Rule.Result.empty();
        }

        // A repartitioning exchange can only be crossed if every symbol it partitions on is a grouping key.
        if (exchangeType == ExchangeNode.Type.REPARTITION
                && !groupingKeysCoverPartitioning(aggregationNode, partitioningScheme)) {
            return Rule.Result.empty();
        }

        // Plans using a pre-computed hash symbol/column are not handled by this rule.
        if (aggregationNode.getHashSymbol().isPresent() || partitioningScheme.getHashColumn().isPresent()) {
            return Rule.Result.empty();
        }

        switch (aggregationNode.getStep()) {
            case SINGLE:
                return Rule.Result.ofPlanNode(split(aggregationNode, context));
            case PARTIAL:
                return Rule.Result.ofPlanNode(pushPartial(aggregationNode, exchangeNode, context));
            default:
                return Rule.Result.empty();
        }
    }

    /**
     * Returns whether all variable (non-constant) partitioning arguments of the scheme are among the
     * aggregation's grouping keys.
     */
    private static boolean groupingKeysCoverPartitioning(AggregationNode aggregationNode, PartitioningScheme partitioningScheme) {
        List<Symbol> partitioningColumns = partitioningScheme.getPartitioning().getArguments().stream()
                .filter(argument -> argument.isVariable())
                .map(argument -> argument.getColumn())
                .collect(Collectors.toList());
        return aggregationNode.getGroupingKeys().containsAll(partitioningColumns);
    }

    /**
     * Places a copy of the PARTIAL aggregation under each exchange source and rebuilds the exchange on top.
     *
     * <p>For every source, exchange output symbols are remapped to that source's input symbols. The
     * remapped aggregation is then wrapped in a projection that restores the original aggregation
     * output symbol names, so every new exchange input has exactly the aggregation's output layout.
     */
    private PlanNode pushPartial(AggregationNode aggregationNode, ExchangeNode exchangeNode, Rule.Context context) {
        List<Symbol> aggregationOutputs = aggregationNode.getOutputSymbols();
        List<Symbol> exchangeOutputs = exchangeNode.getOutputSymbols();
        List<PlanNode> partials = new ArrayList<>();

        for (int sourceIndex = 0; sourceIndex < exchangeNode.getSources().size(); sourceIndex++) {
            PlanNode source = exchangeNode.getSources().get(sourceIndex);
            List<Symbol> sourceInputs = exchangeNode.getInputs().get(sourceIndex);

            // Map exchange output symbol -> this source's corresponding input symbol (identity entries skipped).
            SymbolMapper.Builder mappings = SymbolMapper.builder();
            for (int column = 0; column < exchangeOutputs.size(); column++) {
                Symbol output = exchangeOutputs.get(column);
                Symbol input = sourceInputs.get(column);
                if (!output.equals(input)) {
                    mappings.put(output, input);
                }
            }
            SymbolMapper symbolMapper = mappings.build();

            AggregationNode mappedPartial = symbolMapper.map(aggregationNode, source, context.getIdAllocator().getNextId());

            // Rename the mapped outputs back to the original aggregation output symbols.
            Assignments.Builder assignments = Assignments.builder();
            for (Symbol output : aggregationOutputs) {
                assignments.put(output, symbolMapper.map(output).toSymbolReference());
            }
            partials.add(new ProjectNode(context.getIdAllocator().getNextId(), mappedPartial, assignments.build()));
        }

        for (PlanNode partial : partials) {
            Verify.verify(
                    aggregationOutputs.equals(partial.getOutputSymbols()),
                    "partial aggregation outputs %s do not match aggregation outputs %s",
                    partial.getOutputSymbols(),
                    aggregationOutputs);
        }

        // Every new source now produces the aggregation's output symbols, so the partitioning is reused as-is
        // with the aggregation outputs as the exchange layout.
        PartitioningScheme originalScheme = exchangeNode.getPartitioningScheme();
        PartitioningScheme partitioningScheme = new PartitioningScheme(
                originalScheme.getPartitioning(),
                aggregationOutputs,
                originalScheme.getHashColumn(),
                originalScheme.isReplicateNullsAndAny(),
                originalScheme.getBucketToPartition(),
                originalScheme.getPartitionCount());

        return new ExchangeNode(
                context.getIdAllocator().getNextId(),
                exchangeNode.getType(),
                exchangeNode.getScope(),
                partitioningScheme,
                partials,
                ImmutableList.copyOf(Collections.nCopies(partials.size(), aggregationOutputs)),
                Optional.empty());
    }

    /**
     * Replaces the aggregation with a FINAL aggregation (reusing the original node id) on top of a new
     * PARTIAL aggregation over the original source.
     *
     * <p>Each PARTIAL aggregation writes a fresh symbol whose type is the function's single intermediate
     * type, or an anonymous row of its intermediate types when there are several. The FINAL aggregation
     * reads that symbol plus any lambda arguments of the original call. Filter, ordering scheme and mask
     * are only applied on the PARTIAL step.
     *
     * @throws IllegalStateException if any aggregation has an ORDER BY
     */
    private PlanNode split(AggregationNode aggregationNode, Rule.Context context) {
        Map<Symbol, AggregationNode.Aggregation> partialAggregations = new HashMap<>();
        Map<Symbol, AggregationNode.Aggregation> finalAggregations = new HashMap<>();

        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
            AggregationNode.Aggregation original = entry.getValue();
            ResolvedFunction resolvedFunction = original.getResolvedFunction();
            Preconditions.checkState(original.getOrderingScheme().isEmpty(), "Aggregate with ORDER BY does not support partial aggregation");

            TypeManager typeManager = plannerContext.getTypeManager();
            List<Type> intermediateTypes = plannerContext.getMetadata()
                    .getAggregationFunctionMetadata(context.getSession(), resolvedFunction)
                    .getIntermediateTypes().stream()
                    .map(typeManager::getType)
                    .collect(ImmutableList.toImmutableList());
            Type intermediateType = intermediateTypes.size() == 1 ? intermediateTypes.get(0) : RowType.anonymous(intermediateTypes);
            Symbol intermediateSymbol = context.getSymbolAllocator().newSymbol(resolvedFunction.getSignature().getName(), intermediateType);

            partialAggregations.put(intermediateSymbol, new AggregationNode.Aggregation(
                    resolvedFunction,
                    original.getArguments(),
                    original.isDistinct(),
                    original.getFilter(),
                    original.getOrderingScheme(),
                    original.getMask()));

            // FINAL consumes the intermediate state; lambda arguments of the original call are passed along too.
            List<Expression> finalArguments = ImmutableList.<Expression>builder()
                    .add(intermediateSymbol.toSymbolReference())
                    .addAll(original.getArguments().stream()
                            .filter(LambdaExpression.class::isInstance)
                            .collect(ImmutableList.toImmutableList()))
                    .build();
            finalAggregations.put(entry.getKey(), new AggregationNode.Aggregation(
                    resolvedFunction,
                    finalArguments,
                    false,
                    Optional.empty(),
                    Optional.empty(),
                    Optional.empty()));
        }

        // Both new nodes are given an empty list as their fifth constructor argument (presumably the
        // pre-grouped symbols, which splitting may not preserve — TODO confirm against AggregationNode).
        AggregationNode partial = new AggregationNode(
                context.getIdAllocator().getNextId(),
                aggregationNode.getSource(),
                partialAggregations,
                aggregationNode.getGroupingSets(),
                ImmutableList.of(),
                AggregationNode.Step.PARTIAL,
                aggregationNode.getHashSymbol(),
                aggregationNode.getGroupIdSymbol());

        return new AggregationNode(
                aggregationNode.getId(),
                partial,
                finalAggregations,
                aggregationNode.getGroupingSets(),
                ImmutableList.of(),
                AggregationNode.Step.FINAL,
                aggregationNode.getHashSymbol(),
                aggregationNode.getGroupIdSymbol());
    }
}
