package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.spi.type.Type;
import io.trino.sql.analyzer.TypeSignatureTranslator;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.optimizations.DistinctOutputQueryUtil;
import io.trino.sql.planner.optimizations.SymbolMapper;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.ValuesNode;
import io.trino.sql.tree.Cast;
import io.trino.sql.tree.CoalesceExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.NullLiteral;
import io.trino.sql.tree.Row;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.class */
/**
 * Optimizer rule that pushes an aggregation below a LEFT or RIGHT outer join.
 *
 * <p>The rule fires only when all of the following hold:
 * <ul>
 *   <li>the join has no filter;</li>
 *   <li>the aggregation has a single grouping set whose keys are exactly the outer side's output
 *   symbols;</li>
 *   <li>the outer side is reported distinct by {@link DistinctOutputQueryUtil#isDistinct};</li>
 *   <li>every aggregate reads only symbols produced by the inner side.</li>
 * </ul>
 * The aggregation is then re-grouped on the inner side's join-criteria symbols and moved onto the
 * inner side of the join.
 *
 * <p>After the rewrite, outer rows with no inner match receive NULL for every aggregate output.
 * The original plan, by contrast, aggregated one all-NULL inner row for such outer rows, and some
 * aggregates (e.g. {@code count}) turn that row into a non-NULL value. To keep results identical,
 * the rewritten join is cross-joined with the same aggregates evaluated over a single row of typed
 * NULLs. Each aggregate output then becomes {@code COALESCE(pushed_value, value_over_null_row)}.
 */
public class PushAggregationThroughOuterJoin implements Rule<AggregationNode> {
    private static final Capture<JoinNode> JOIN = Capture.newCapture();
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation().with(Patterns.source().matching(Patterns.join().capturedAs(JOIN)));

    /**
     * Result of {@code createAggregationOverNull}. It holds:
     * <ul>
     *   <li>the aggregation built over a single NULL row;</li>
     *   <li>a map from each original aggregate output symbol to the matching output symbol of that
     *   aggregation.</li>
     * </ul>
     */
    public static class MappedAggregationInfo {
        private final AggregationNode aggregationNode;
        private final Map<Symbol, Symbol> symbolMapping;

        public MappedAggregationInfo(AggregationNode aggregationNode, Map<Symbol, Symbol> map) {
            this.aggregationNode = Objects.requireNonNull(aggregationNode, "aggregationNode is null");
            this.symbolMapping = Objects.requireNonNull(map, "map is null");
        }

        /** Returns the map from original aggregate output symbols to over-NULL-row output symbols. */
        public Map<Symbol, Symbol> getSymbolMapping() {
            return this.symbolMapping;
        }

        /** Returns the aggregation whose source is the single NULL row. */
        public AggregationNode getAggregation() {
            return this.aggregationNode;
        }
    }

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.isPushAggregationThroughOuterJoin(session);
    }

    /**
     * Rewrites {@code aggregation(outer join)} into
     * {@code project(coalesce)(cross join(outer join(pushed aggregation), aggregation over NULL row))}.
     *
     * @param aggregationNode the matched aggregation; must not carry a hash symbol
     * @param captures captures holding the matched join under key {@code JOIN}
     * @param context rule context supplying lookup, symbol allocator and id allocator
     * @return the rewritten plan, or {@link Rule.Result#empty()} when the rule does not apply
     * @throws IllegalArgumentException if the aggregation has a hash symbol
     */
    @Override
    public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
        Preconditions.checkArgument(aggregationNode.getHashSymbol().isEmpty(), "unexpected hash symbol");
        JoinNode joinNode = captures.get(JOIN);
        if (!isApplicable(aggregationNode, joinNode, context.getLookup())) {
            return Rule.Result.empty();
        }

        // The pushed aggregation groups on the inner side's half of each equi-join clause:
        // the left symbols for a RIGHT join, the right symbols for a LEFT join.
        boolean rightJoin = joinNode.getType() == JoinNode.Type.RIGHT;
        List<Symbol> innerJoinKeys = joinNode.getCriteria().stream()
                .map(clause -> rightJoin ? clause.getLeft() : clause.getRight())
                .collect(ImmutableList.toImmutableList());

        AggregationNode pushedAggregation = AggregationNode.builderFrom(aggregationNode)
                .setSource(getInnerTable(joinNode))
                .setGroupingSets(AggregationNode.singleGroupingSet(innerJoinKeys))
                .setPreGroupedSymbols(ImmutableList.of())
                .build();

        JoinNode rewrittenJoin = rewriteJoin(joinNode, pushedAggregation);
        Optional<PlanNode> result = coalesceWithNullAggregation(pushedAggregation, rewrittenJoin, context.getSymbolAllocator(), context.getIdAllocator());
        if (result.isEmpty()) {
            return Rule.Result.empty();
        }
        return Rule.Result.ofPlanNode(result.get());
    }

    /**
     * Checks every precondition of the rewrite. The order is significant: the join type is
     * verified before {@link #getOuterTable}/{@link #getInnerTable} are called, because both of
     * them {@code checkState} on the join type.
     */
    private static boolean isApplicable(AggregationNode aggregationNode, JoinNode joinNode, Lookup lookup) {
        if (joinNode.getFilter().isPresent()) {
            return false;
        }
        if (joinNode.getType() != JoinNode.Type.LEFT && joinNode.getType() != JoinNode.Type.RIGHT) {
            return false;
        }
        PlanNode outerTable = getOuterTable(joinNode);
        return groupsOnAllColumns(aggregationNode, outerTable.getOutputSymbols())
                && DistinctOutputQueryUtil.isDistinct(lookup.resolve(outerTable), lookup::resolve)
                && isAggregationOnSymbols(aggregationNode, getInnerTable(joinNode));
    }

    /**
     * Rebuilds {@code joinNode} with its inner side replaced by {@code pushedAggregation}.
     *
     * <p>The rebuilt join has these properties:
     * <ul>
     *   <li>the outer side keeps all of its output symbols;</li>
     *   <li>the inner side outputs only the aggregate symbols;</li>
     *   <li>every other join property (id, type, criteria, filter, hash symbols, distribution,
     *   spillability, dynamic filters, stats) is copied unchanged.</li>
     * </ul>
     */
    private static JoinNode rewriteJoin(JoinNode joinNode, AggregationNode pushedAggregation) {
        ImmutableList<Symbol> aggregationOutputs = ImmutableList.copyOf(pushedAggregation.getAggregations().keySet());
        if (joinNode.getType() == JoinNode.Type.LEFT) {
            return new JoinNode(
                    joinNode.getId(),
                    joinNode.getType(),
                    joinNode.getLeft(),
                    pushedAggregation,
                    joinNode.getCriteria(),
                    joinNode.getLeft().getOutputSymbols(),
                    aggregationOutputs,
                    false,
                    joinNode.getFilter(),
                    joinNode.getLeftHashSymbol(),
                    joinNode.getRightHashSymbol(),
                    joinNode.getDistributionType(),
                    joinNode.isSpillable(),
                    joinNode.getDynamicFilters(),
                    joinNode.getReorderJoinStatsAndCost());
        }
        return new JoinNode(
                joinNode.getId(),
                joinNode.getType(),
                pushedAggregation,
                joinNode.getRight(),
                joinNode.getCriteria(),
                aggregationOutputs,
                joinNode.getRight().getOutputSymbols(),
                false,
                joinNode.getFilter(),
                joinNode.getLeftHashSymbol(),
                joinNode.getRightHashSymbol(),
                joinNode.getDistributionType(),
                joinNode.isSpillable(),
                joinNode.getDynamicFilters(),
                joinNode.getReorderJoinStatsAndCost());
    }

    /** Returns the side of a LEFT/RIGHT join that may be NULL-extended (right for LEFT, left for RIGHT). */
    private static PlanNode getInnerTable(JoinNode joinNode) {
        Preconditions.checkState(joinNode.getType() == JoinNode.Type.LEFT || joinNode.getType() == JoinNode.Type.RIGHT, "expected LEFT or RIGHT JOIN");
        return joinNode.getType() == JoinNode.Type.LEFT ? joinNode.getRight() : joinNode.getLeft();
    }

    /** Returns the side of a LEFT/RIGHT join whose rows are all preserved (left for LEFT, right for RIGHT). */
    private static PlanNode getOuterTable(JoinNode joinNode) {
        Preconditions.checkState(joinNode.getType() == JoinNode.Type.LEFT || joinNode.getType() == JoinNode.Type.RIGHT, "expected LEFT or RIGHT JOIN");
        return joinNode.getType() == JoinNode.Type.LEFT ? joinNode.getLeft() : joinNode.getRight();
    }

    /**
     * Returns true when the aggregation has exactly one grouping set and its grouping keys equal
     * {@code symbols} as a set (order and duplicates ignored).
     */
    private static boolean groupsOnAllColumns(AggregationNode aggregationNode, List<Symbol> symbols) {
        return aggregationNode.getGroupingSetCount() == 1
                && new HashSet<>(aggregationNode.getGroupingKeys()).equals(new HashSet<>(symbols));
    }

    /**
     * Wraps {@code planNode} so that each aggregate output coming from {@code aggregationNode}
     * becomes {@code COALESCE(original, same aggregate evaluated over a single NULL row)}.
     *
     * <p>The over-NULL-row aggregation is a global aggregation, so it produces one row. The
     * cross join with it therefore does not change the row count of {@code planNode}.
     *
     * @return always present; the {@link Optional} is kept for the caller's contract
     */
    private static Optional<PlanNode> coalesceWithNullAggregation(AggregationNode aggregationNode, PlanNode planNode, SymbolAllocator symbolAllocator, PlanNodeIdAllocator planNodeIdAllocator) {
        MappedAggregationInfo overNull = createAggregationOverNull(aggregationNode, symbolAllocator, planNodeIdAllocator);
        AggregationNode aggregationOverNull = overNull.getAggregation();
        Map<Symbol, Symbol> toOverNullSymbol = overNull.getSymbolMapping();

        JoinNode crossJoin = new JoinNode(
                planNodeIdAllocator.getNextId(),
                JoinNode.Type.INNER,
                planNode,
                aggregationOverNull,
                ImmutableList.of(),
                planNode.getOutputSymbols(),
                aggregationOverNull.getOutputSymbols(),
                false,
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                ImmutableMap.of(),
                Optional.empty());

        Assignments.Builder assignments = Assignments.builder();
        for (Symbol symbol : planNode.getOutputSymbols()) {
            if (aggregationNode.getAggregations().containsKey(symbol)) {
                assignments.put(symbol, new CoalesceExpression(symbol.toSymbolReference(), toOverNullSymbol.get(symbol).toSymbolReference()));
            } else {
                assignments.putIdentity(symbol);
            }
        }
        return Optional.of(new ProjectNode(planNodeIdAllocator.getNextId(), crossJoin, assignments.build()));
    }

    /**
     * Builds the aggregates of {@code aggregationNode}, re-targeted at a {@link ValuesNode} that
     * holds a single row with one {@code CAST(NULL AS type)} per source output symbol.
     *
     * <p>The result is a global aggregation. Each new aggregate output symbol is freshly allocated,
     * and the returned mapping links every original aggregate symbol to its counterpart.
     */
    private static MappedAggregationInfo createAggregationOverNull(AggregationNode aggregationNode, SymbolAllocator symbolAllocator, PlanNodeIdAllocator planNodeIdAllocator) {
        ImmutableList.Builder<Symbol> nullSymbols = ImmutableList.builder();
        ImmutableList.Builder<Expression> nullValues = ImmutableList.builder();
        ImmutableMap.Builder<Symbol, Symbol> sourceToNullSymbol = ImmutableMap.builder();
        for (Symbol sourceSymbol : aggregationNode.getSource().getOutputSymbols()) {
            Type type = symbolAllocator.getTypes().get(sourceSymbol);
            nullValues.add(new Cast(new NullLiteral(), TypeSignatureTranslator.toSqlType(type)));
            Symbol nullSymbol = symbolAllocator.newSymbol("null", type);
            nullSymbols.add(nullSymbol);
            sourceToNullSymbol.put(sourceSymbol, nullSymbol);
        }
        ValuesNode nullRow = new ValuesNode(planNodeIdAllocator.getNextId(), nullSymbols.build(), ImmutableList.of(new Row(nullValues.build())));

        ImmutableMap.Builder<Symbol, Symbol> aggregationToOverNull = ImmutableMap.builder();
        ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregationsOverNull = ImmutableMap.builder();
        SymbolMapper symbolMapper = SymbolMapper.symbolMapper(sourceToNullSymbol.buildOrThrow());
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
            Symbol aggregationSymbol = entry.getKey();
            AggregationNode.Aggregation overNullAggregation = symbolMapper.map(entry.getValue());
            Symbol overNullSymbol = symbolAllocator.newSymbol(overNullAggregation.getResolvedFunction().getSignature().getName(), symbolAllocator.getTypes().get(aggregationSymbol));
            aggregationsOverNull.put(overNullSymbol, overNullAggregation);
            aggregationToOverNull.put(aggregationSymbol, overNullSymbol);
        }
        AggregationNode aggregationOverNullRow = AggregationNode.singleAggregation(planNodeIdAllocator.getNextId(), nullRow, aggregationsOverNull.buildOrThrow(), AggregationNode.globalAggregation());
        return new MappedAggregationInfo(aggregationOverNullRow, aggregationToOverNull.buildOrThrow());
    }

    /** Returns true when every aggregate reads only symbols that {@code planNode} outputs. */
    private static boolean isAggregationOnSymbols(AggregationNode aggregationNode, PlanNode planNode) {
        ImmutableSet<Symbol> availableSymbols = ImmutableSet.copyOf(planNode.getOutputSymbols());
        return aggregationNode.getAggregations().values().stream()
                .allMatch(aggregation -> availableSymbols.containsAll(SymbolsExtractor.extractUnique(aggregation)));
    }
}
