package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Streams;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.optimizations.QueryCardinalityUtil;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.AssignUniqueId;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.CorrelatedJoinNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.UnnestNode;
import io.trino.sql.tree.BooleanLiteral;
import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/DecorrelateLeftUnnestWithGlobalAggregation.class */
/**
 * Decorrelates an INNER or LEFT {@link CorrelatedJoinNode} whose subquery is a chain of
 * projections and single-step aggregations (including at least one global aggregation) ending in a
 * correlated LEFT {@link UnnestNode}.
 *
 * <p>The rewrite:
 * <ol>
 *   <li>tags every input row with a fresh {@code unique} symbol via {@link AssignUniqueId},</li>
 *   <li>replaces the correlated join with a LEFT unnest over that input (replicating all input
 *       symbols),</li>
 *   <li>re-applies the subquery's projections and aggregations on top, passing the input symbols
 *       through every projection and prepending them to every aggregation's grouping keys, so the
 *       former global aggregations are evaluated per input row,</li>
 *   <li>restricts the result to the correlated join's original output symbols.</li>
 * </ol>
 */
public class DecorrelateLeftUnnestWithGlobalAggregation implements Rule<CorrelatedJoinNode> {
    // Correlated join with a non-empty correlation, no join filter, and INNER or LEFT type.
    private static final Pattern<CorrelatedJoinNode> PATTERN = Patterns.correlatedJoin()
            .with(Pattern.nonEmpty(Patterns.CorrelatedJoin.correlation()))
            .with(Patterns.CorrelatedJoin.filter().equalTo(BooleanLiteral.TRUE_LITERAL))
            .matching(node -> node.getType() == CorrelatedJoinNode.Type.INNER
                    || node.getType() == CorrelatedJoinNode.Type.LEFT);

    @Override
    public Pattern<CorrelatedJoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * Rewrites the matched correlated join, or returns {@link Rule.Result#empty()} when the
     * subquery has no reachable global aggregation or no supported correlated unnest.
     */
    @Override
    public Rule.Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Rule.Context context) {
        Lookup lookup = context.getLookup();
        PlanNode subquery = correlatedJoinNode.getSubquery();

        // The rule only applies when a global aggregation is reachable from the subquery root
        // through projections and grouped aggregations.
        boolean hasGlobalAggregation = PlanNodeSearcher.searchFrom(subquery, lookup)
                .where(DecorrelateLeftUnnestWithGlobalAggregation::isGlobalAggregation)
                .recurseOnlyWhen(node -> node instanceof ProjectNode || isGroupedAggregation(node))
                .findFirst()
                .isPresent();
        if (!hasGlobalAggregation) {
            return Rule.Result.empty();
        }

        // Find the correlated unnest below the chain of projections and (global or grouped)
        // aggregations; rewriteNodeSequence relies on the chain containing only those node kinds.
        Optional<UnnestNode> subqueryUnnest = PlanNodeSearcher.searchFrom(subquery, lookup)
                .where(node -> isSupportedUnnest(node, correlatedJoinNode.getCorrelation(), lookup))
                .recurseOnlyWhen(node -> node instanceof ProjectNode || isGlobalAggregation(node) || isGroupedAggregation(node))
                .findFirst();
        if (subqueryUnnest.isEmpty()) {
            return Rule.Result.empty();
        }
        UnnestNode unnestNode = subqueryUnnest.get();

        // Tag input rows with a unique id; it becomes a grouping key of every re-applied
        // aggregation so former global aggregations still yield one group per input row.
        PlanNode input = new AssignUniqueId(
                context.getIdAllocator().getNextId(),
                correlatedJoinNode.getInput(),
                context.getSymbolAllocator().newSymbol("unique", BigintType.BIGINT));

        // If the unnest consumed a projection, replay that projection on top of the input (keeping
        // all input symbols) so the unnested symbols are available to the rewritten unnest.
        // NOTE(review): when the unnest inputs are correlation symbols directly, isSupportedUnnest
        // does not check that this projection uses only correlation symbols — confirm it can never
        // reference symbols local to the subquery.
        PlanNode unnestSource = lookup.resolve(unnestNode.getSource());
        if (unnestSource instanceof ProjectNode) {
            ProjectNode sourceProjection = (ProjectNode) unnestSource;
            input = new ProjectNode(
                    sourceProjection.getId(),
                    input,
                    Assignments.builder()
                            .putIdentities(input.getOutputSymbols())
                            .putAll(sourceProjection.getAssignments())
                            .build());
        }

        // LEFT unnest keeps input rows that unnest to nothing, replicating all input symbols.
        UnnestNode rewrittenUnnest = new UnnestNode(
                context.getIdAllocator().getNextId(),
                input,
                input.getOutputSymbols(),
                unnestNode.getMappings(),
                unnestNode.getOrdinalitySymbol(),
                JoinNode.Type.LEFT,
                Optional.empty());

        // Rebuild the subquery's projection/aggregation chain on top of the rewritten unnest.
        PlanNode rewrittenPlan = rewriteNodeSequence(
                lookup.resolve(subquery),
                input.getOutputSymbols(),
                rewrittenUnnest,
                unnestNode.getId(),
                lookup);

        return Rule.Result.ofPlanNode(
                Util.restrictOutputs(context.getIdAllocator(), rewrittenPlan, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols()))
                        .orElse(rewrittenPlan));
    }

    /**
     * Returns whether {@code planNode} is a SINGLE-step aggregation for which
     * {@code hasSingleGlobalAggregation()} holds.
     */
    private static boolean isGlobalAggregation(PlanNode planNode) {
        if (!(planNode instanceof AggregationNode)) {
            return false;
        }
        AggregationNode aggregationNode = (AggregationNode) planNode;
        return aggregationNode.hasSingleGlobalAggregation()
                && aggregationNode.getStep() == AggregationNode.Step.SINGLE;
    }

    /**
     * Returns whether {@code planNode} is a SINGLE-step aggregation with exactly one grouping set,
     * and that grouping set is non-empty.
     */
    private static boolean isGroupedAggregation(PlanNode planNode) {
        if (!(planNode instanceof AggregationNode)) {
            return false;
        }
        AggregationNode aggregationNode = (AggregationNode) planNode;
        return aggregationNode.hasNonEmptyGroupingSet()
                && aggregationNode.getGroupingSetCount() == 1
                && aggregationNode.getStep() == AggregationNode.Step.SINGLE;
    }

    /**
     * Returns whether {@code planNode} is an unnest this rule can decorrelate. Requirements:
     * <ul>
     *   <li>it is a LEFT {@link UnnestNode} without replicate symbols,</li>
     *   <li>its filter is absent or the literal {@code TRUE},</li>
     *   <li>its source is scalar per {@link QueryCardinalityUtil#isScalar},</li>
     *   <li>it is based on the correlation: either every unnested input is a correlation symbol, or
     *       its source is a projection whose expressions reference only correlation symbols.</li>
     * </ul>
     *
     * @param planNode candidate node
     * @param correlation the correlated join's correlation symbols
     * @param lookup lookup used to resolve the unnest source
     * @return {@code true} if the node is a supported correlated unnest
     */
    public static boolean isSupportedUnnest(PlanNode planNode, List<Symbol> correlation, Lookup lookup) {
        if (!(planNode instanceof UnnestNode)) {
            return false;
        }
        UnnestNode unnestNode = (UnnestNode) planNode;

        boolean filterIsTrivial = unnestNode.getFilter()
                .map(filter -> filter.equals(BooleanLiteral.TRUE_LITERAL))
                .orElse(true);
        if (!unnestNode.getReplicateSymbols().isEmpty()
                || unnestNode.getJoinType() != JoinNode.Type.LEFT
                || !filterIsTrivial
                || !QueryCardinalityUtil.isScalar(unnestNode.getSource(), lookup)) {
            return false;
        }

        ImmutableSet<Symbol> correlationSymbols = ImmutableSet.copyOf(correlation);
        List<Symbol> unnestInputs = unnestNode.getMappings().stream()
                .map(mapping -> mapping.getInput())
                .collect(ImmutableList.toImmutableList());
        if (correlationSymbols.containsAll(unnestInputs)) {
            return true;
        }

        // Otherwise the unnested symbols must come from a projection computed from correlation only.
        PlanNode unnestSource = lookup.resolve(unnestNode.getSource());
        return unnestSource instanceof ProjectNode
                && correlationSymbols.containsAll(SymbolsExtractor.extractUnique(((ProjectNode) unnestSource).getAssignments().getExpressions()));
    }

    /**
     * Rebuilds the single-source chain from {@code root} down to the node with id
     * {@code lastNodeId}, substituting {@code replacementNode} for that bottom node.
     *
     * <p>Every projection on the way passes {@code leftOutputs} through as identities, and every
     * aggregation gets {@code leftOutputs} prepended to its grouping keys.
     *
     * @throws IllegalStateException if the chain contains a node that is neither a projection nor
     *         an aggregation
     */
    private static PlanNode rewriteNodeSequence(PlanNode root, List<Symbol> leftOutputs, PlanNode replacementNode, PlanNodeId lastNodeId, Lookup lookup) {
        if (root.getId().equals(lastNodeId)) {
            return replacementNode;
        }
        // Validate before descending so an unexpected node is reported as such, rather than as a
        // getOnlyElement failure on a leaf or multi-source node.
        if (!(root instanceof AggregationNode) && !(root instanceof ProjectNode)) {
            throw new IllegalStateException("unexpected node: " + root);
        }

        PlanNode sequenceSource = lookup.resolve(Iterables.getOnlyElement(root.getSources()));
        PlanNode rewrittenSource = rewriteNodeSequence(sequenceSource, leftOutputs, replacementNode, lastNodeId, lookup);

        if (root instanceof AggregationNode) {
            return withGrouping((AggregationNode) root, leftOutputs, rewrittenSource);
        }
        ProjectNode project = (ProjectNode) root;
        return new ProjectNode(
                project.getId(),
                rewrittenSource,
                Assignments.builder()
                        .putAll(project.getAssignments())
                        .putIdentities(leftOutputs)
                        .build());
    }

    /**
     * Re-creates {@code aggregationNode} as a single-step aggregation over {@code source}, keeping its
     * id and aggregations and grouping by {@code groupingSymbols} followed by its original grouping
     * keys (duplicates removed).
     */
    private static AggregationNode withGrouping(AggregationNode aggregationNode, List<Symbol> groupingSymbols, PlanNode source) {
        List<Symbol> groupingKeys = Streams.concat(groupingSymbols.stream(), aggregationNode.getGroupingKeys().stream())
                .distinct()
                .collect(ImmutableList.toImmutableList());
        return AggregationNode.singleAggregation(
                aggregationNode.getId(),
                source,
                aggregationNode.getAggregations(),
                AggregationNode.singleGroupingSet(groupingKeys));
    }
}
