package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import io.trino.Session;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.Type;
import io.trino.sql.ExpressionUtils;
import io.trino.sql.analyzer.TypeSignatureTranslator;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.ApplyNode;
import io.trino.sql.planner.plan.AssignUniqueId;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanVisitor;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.Cast;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.InPredicate;
import io.trino.sql.tree.IsNotNullPredicate;
import io.trino.sql.tree.IsNullPredicate;
import io.trino.sql.tree.LongLiteral;
import io.trino.sql.tree.NotExpression;
import io.trino.sql.tree.NullLiteral;
import io.trino.sql.tree.QualifiedName;
import io.trino.sql.tree.SearchedCaseExpression;
import io.trino.sql.tree.SymbolReference;
import io.trino.sql.tree.WhenClause;
import io.trino.sql.util.AstUtils;
import jakarta.annotation.Nullable;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.class */
/**
 * Rewrites a correlated {@code IN} subquery, held as the single subquery assignment of an {@link ApplyNode},
 * into a LEFT join followed by an aggregation and a final projection.
 * <p>
 * The produced plan has this shape:
 * <pre>
 * - Project: input symbols, result := CASE WHEN countMatches &gt; 0 THEN true
 *                                          WHEN countNullMatches &gt; 0 THEN CAST(NULL AS boolean)
 *                                          ELSE false END
 *   - Aggregation grouped by "unique":
 *       countMatches     := count() filtered by matchConditionSymbol
 *       countNullMatches := count() filtered by nullMatchConditionSymbol
 *     - Project: join outputs + matchConditionSymbol + nullMatchConditionSymbol
 *       - LEFT join on (value IS NULL OR value = listValue OR listValue IS NULL) AND pulled-up predicates
 *         - AssignUniqueId("unique") over the apply input
 *         - Project: decorrelated subquery + buildSideKnownNonNull := 0
 * </pre>
 * The rule fires only for apply nodes with a non-empty correlation and exactly one subquery assignment.
 * That assignment must be an {@link InPredicate}, and the subquery must be decorrelatable by
 * {@link DecorrelatingVisitor}.
 */
public class TransformCorrelatedInPredicateToJoin implements Rule<ApplyNode> {
    private static final Pattern<ApplyNode> PATTERN = Patterns.applyNode().with(Pattern.nonEmpty(Patterns.Apply.correlation()));
    private final Metadata metadata;

    /**
     * Result of decorrelating a subquery. It holds the predicates that were pulled out of the subquery
     * (to be evaluated above it) and the remaining subquery plan with those filters removed.
     */
    public static class Decorrelated {
        private final List<Expression> correlatedPredicates;
        private final PlanNode decorrelatedNode;

        public Decorrelated(List<Expression> correlatedPredicates, PlanNode decorrelatedNode) {
            this.correlatedPredicates = ImmutableList.copyOf(Objects.requireNonNull(correlatedPredicates, "correlatedPredicates is null"));
            this.decorrelatedNode = Objects.requireNonNull(decorrelatedNode, "decorrelatedNode is null");
        }

        public List<Expression> getCorrelatedPredicates() {
            return this.correlatedPredicates;
        }

        public PlanNode getDecorrelatedNode() {
            return this.decorrelatedNode;
        }
    }

    /**
     * Walks a subquery plan and pulls every filter predicate upward. Traversal returns empty when it meets
     * a node it cannot handle that references a correlation symbol.
     * <ul>
     *   <li>Filter: its predicate is pulled up and the filter node is dropped.</li>
     *   <li>Project: allowed only if the projection itself does not reference correlation symbols.</li>
     *   <li>Any other node: kept unchanged only if neither it nor any descendant references correlation symbols.</li>
     * </ul>
     */
    public static class DecorrelatingVisitor extends PlanVisitor<Optional<Decorrelated>, PlanNode> {
        private final Lookup lookup;
        private final Set<Symbol> correlation;

        public DecorrelatingVisitor(Lookup lookup, Iterable<Symbol> correlation) {
            this.lookup = Objects.requireNonNull(lookup, "lookup is null");
            this.correlation = ImmutableSet.copyOf(Objects.requireNonNull(correlation, "correlation is null"));
        }

        /**
         * Resolves {@code reference} through the lookup and visits the resolved node, passing the
         * original reference as context.
         */
        public Optional<Decorrelated> decorrelate(PlanNode reference) {
            return this.lookup.resolve(reference).accept(this, reference);
        }

        @Override // io.trino.sql.planner.plan.PlanVisitor
        public Optional<Decorrelated> visitProject(ProjectNode node, PlanNode reference) {
            if (isCorrelatedShallowly(node)) {
                // A projection that itself references correlation symbols is not handled.
                return Optional.empty();
            }
            return decorrelate(node.getSource()).map(decorrelated -> {
                Assignments.Builder assignments = Assignments.builder().putAll(node.getAssignments());
                // The pulled-up predicates are evaluated above this projection. Every non-correlation
                // symbol they reference must therefore be passed through it unchanged.
                decorrelated.getCorrelatedPredicates().stream()
                        .flatMap(AstUtils::preOrder)
                        .filter(SymbolReference.class::isInstance)
                        .map(SymbolReference.class::cast)
                        .filter(symbolReference -> !this.correlation.contains(Symbol.from(symbolReference)))
                        .forEach(symbolReference -> assignments.putIdentity(Symbol.from(symbolReference)));
                return new Decorrelated(
                        decorrelated.getCorrelatedPredicates(),
                        new ProjectNode(node.getId(), decorrelated.getDecorrelatedNode(), assignments.build()));
            });
        }

        @Override // io.trino.sql.planner.plan.PlanVisitor
        public Optional<Decorrelated> visitFilter(FilterNode node, PlanNode reference) {
            // Remove the filter from the subquery and carry its predicate upward.
            return decorrelate(node.getSource()).map(decorrelated -> new Decorrelated(
                    ImmutableList.<Expression>builder()
                            .addAll(decorrelated.getCorrelatedPredicates())
                            .add(node.getPredicate())
                            .build(),
                    decorrelated.getDecorrelatedNode()));
        }

        @Override // io.trino.sql.planner.plan.PlanVisitor
        public Optional<Decorrelated> visitPlan(PlanNode node, PlanNode reference) {
            if (isCorrelatedRecursively(node)) {
                return Optional.empty();
            }
            // Uncorrelated subtree: keep the original (unresolved) reference as-is, with nothing pulled up.
            return Optional.of(new Decorrelated(ImmutableList.of(), reference));
        }

        /** Returns true if {@code node} or any node below it references a correlation symbol. */
        private boolean isCorrelatedRecursively(PlanNode node) {
            if (isCorrelatedShallowly(node)) {
                return true;
            }
            return node.getSources().stream()
                    .map(this.lookup::resolve)
                    .anyMatch(this::isCorrelatedRecursively);
        }

        /** Returns true if {@code node} itself, ignoring its sources, references a correlation symbol. */
        private boolean isCorrelatedShallowly(PlanNode node) {
            return SymbolsExtractor.extractUniqueNonRecursive(node).stream()
                    .anyMatch(this.correlation::contains);
        }
    }

    public TransformCorrelatedInPredicateToJoin(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    @Override // io.trino.sql.planner.iterative.Rule
    public Pattern<ApplyNode> getPattern() {
        return PATTERN;
    }

    @Override // io.trino.sql.planner.iterative.Rule
    public Rule.Result apply(ApplyNode applyNode, Captures captures, Rule.Context context) {
        Assignments subqueryAssignments = applyNode.getSubqueryAssignments();
        if (subqueryAssignments.size() != 1) {
            return Rule.Result.empty();
        }
        Expression assignmentExpression = Iterables.getOnlyElement(subqueryAssignments.getExpressions());
        if (!(assignmentExpression instanceof InPredicate)) {
            return Rule.Result.empty();
        }
        return apply(
                applyNode,
                (InPredicate) assignmentExpression,
                Iterables.getOnlyElement(subqueryAssignments.getSymbols()),
                context.getLookup(),
                context.getIdAllocator(),
                context.getSymbolAllocator(),
                context.getSession());
    }

    /**
     * Decorrelates the apply node's subquery. It returns an empty result if that is impossible, and
     * otherwise the join-based equivalent that assigns the IN result to {@code symbol}.
     */
    private Rule.Result apply(ApplyNode applyNode, InPredicate inPredicate, Symbol symbol, Lookup lookup, PlanNodeIdAllocator planNodeIdAllocator, SymbolAllocator symbolAllocator, Session session) {
        Optional<Decorrelated> decorrelated = new DecorrelatingVisitor(lookup, applyNode.getCorrelation()).decorrelate(applyNode.getSubquery());
        if (decorrelated.isEmpty()) {
            return Rule.Result.empty();
        }
        return Rule.Result.ofPlanNode(buildInPredicateEquivalent(applyNode, inPredicate, symbol, decorrelated.get(), planNodeIdAllocator, symbolAllocator, session));
    }

    /**
     * Builds the plan shown in the class documentation.
     *
     * @param applyNode the apply node being replaced; its input becomes the probe side
     * @param inPredicate the IN predicate; its value and value list are both expected to be symbol references
     * @param symbol the output symbol that receives the IN result
     * @param decorrelated the decorrelated subquery and the predicates pulled out of it
     * @return a projection that outputs the apply input's symbols plus {@code symbol}
     */
    private PlanNode buildInPredicateEquivalent(ApplyNode applyNode, InPredicate inPredicate, Symbol symbol, Decorrelated decorrelated, PlanNodeIdAllocator planNodeIdAllocator, SymbolAllocator symbolAllocator, Session session) {
        Expression correlationCondition = ExpressionUtils.and(decorrelated.getCorrelatedPredicates());
        PlanNode decorrelatedBuildSource = decorrelated.getDecorrelatedNode();

        // Tag each input row so the aggregation can collapse the join output back to one row per input row.
        AssignUniqueId probeSide = new AssignUniqueId(
                planNodeIdAllocator.getNextId(),
                applyNode.getInput(),
                symbolAllocator.newSymbol("unique", BigintType.BIGINT));

        // Constant on the build side. After the LEFT join it is NULL exactly when no build row was joined.
        Symbol buildSideKnownNonNull = symbolAllocator.newSymbol("buildSideKnownNonNull", BigintType.BIGINT);
        ProjectNode buildSide = new ProjectNode(
                planNodeIdAllocator.getNextId(),
                decorrelatedBuildSource,
                Assignments.builder()
                        .putIdentities(decorrelatedBuildSource.getOutputSymbols())
                        .put(buildSideKnownNonNull, bigint(0L))
                        .build());

        Symbol probeSideSymbol = Symbol.from(inPredicate.getValue());
        Symbol buildSideSymbol = Symbol.from(inPredicate.getValueList());

        // Join rows that compare equal, and also rows where either side is NULL. The NULL rows decide
        // between a NULL and a FALSE result.
        Expression joinExpression = ExpressionUtils.and(
                ExpressionUtils.or(
                        new IsNullPredicate(probeSideSymbol.toSymbolReference()),
                        new ComparisonExpression(ComparisonExpression.Operator.EQUAL, probeSideSymbol.toSymbolReference(), buildSideSymbol.toSymbolReference()),
                        new IsNullPredicate(buildSideSymbol.toSymbolReference())),
                correlationCondition);
        JoinNode joined = leftOuterJoin(planNodeIdAllocator, probeSide, buildSide, joinExpression);

        // If both sides are non-NULL, the join condition can only have held through the equality,
        // so the row is a real match.
        Symbol matchConditionSymbol = symbolAllocator.newSymbol("matchConditionSymbol", BooleanType.BOOLEAN);
        Expression matchCondition = ExpressionUtils.and(isNotNull(probeSideSymbol), isNotNull(buildSideSymbol));

        // A joined build row that is not a real match involved a NULL on one side.
        Symbol nullMatchConditionSymbol = symbolAllocator.newSymbol("nullMatchConditionSymbol", BooleanType.BOOLEAN);
        Expression nullMatchCondition = ExpressionUtils.and(isNotNull(buildSideKnownNonNull), not(matchCondition));

        ProjectNode preProjection = new ProjectNode(
                planNodeIdAllocator.getNextId(),
                joined,
                Assignments.builder()
                        .putIdentities(joined.getOutputSymbols())
                        .put(matchConditionSymbol, matchCondition)
                        .put(nullMatchConditionSymbol, nullMatchCondition)
                        .build());

        Symbol countMatchesSymbol = symbolAllocator.newSymbol("countMatches", BigintType.BIGINT);
        Symbol countNullMatchesSymbol = symbolAllocator.newSymbol("countNullMatches", BigintType.BIGINT);
        ImmutableMap<Symbol, AggregationNode.Aggregation> aggregations = ImmutableMap.<Symbol, AggregationNode.Aggregation>builder()
                .put(countMatchesSymbol, countWithFilter(session, matchConditionSymbol))
                .put(countNullMatchesSymbol, countWithFilter(session, nullMatchConditionSymbol))
                .buildOrThrow();

        // Three-valued IN result: TRUE if any real match, else NULL if any NULL-involved row, else FALSE.
        Expression inPredicateEquivalent = new SearchedCaseExpression(
                ImmutableList.of(
                        new WhenClause(isGreaterThan(countMatchesSymbol, 0L), booleanConstant(true)),
                        new WhenClause(isGreaterThan(countNullMatchesSymbol, 0L), booleanConstant(null))),
                Optional.of(booleanConstant(false)));
        Assignments resultAssignments = Assignments.builder()
                .putIdentities(applyNode.getInput().getOutputSymbols())
                .put(symbol, inPredicateEquivalent)
                .build();

        // Id order is kept as before: the result projection's id is allocated before the aggregation's.
        return new ProjectNode(
                planNodeIdAllocator.getNextId(),
                AggregationNode.singleAggregation(
                        planNodeIdAllocator.getNextId(),
                        preProjection,
                        aggregations,
                        AggregationNode.singleGroupingSet(probeSide.getOutputSymbols())),
                resultAssignments);
    }

    /**
     * Builds a LEFT join of {@code assignUniqueId} (left) with {@code projectNode} (right). It outputs
     * every symbol of both sides and uses {@code expression} as the join condition.
     */
    private static JoinNode leftOuterJoin(PlanNodeIdAllocator planNodeIdAllocator, AssignUniqueId assignUniqueId, ProjectNode projectNode, Expression expression) {
        // NOTE(review): no equi-join criteria are passed (empty list); the whole condition is supplied as the filter.
        return new JoinNode(planNodeIdAllocator.getNextId(), JoinNode.Type.LEFT, assignUniqueId, projectNode, ImmutableList.of(), assignUniqueId.getOutputSymbols(), projectNode.getOutputSymbols(), false, Optional.of(expression), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), ImmutableMap.of(), Optional.empty());
    }

    /** Builds an argument-less {@code count()} aggregation filtered by the boolean {@code symbol}. */
    private AggregationNode.Aggregation countWithFilter(Session session, Symbol symbol) {
        return new AggregationNode.Aggregation(this.metadata.resolveFunction(session, QualifiedName.of("count"), ImmutableList.of()), ImmutableList.of(), false, Optional.of(symbol), Optional.empty(), Optional.empty());
    }

    /** Returns {@code symbol > CAST(value AS bigint)}. */
    private static Expression isGreaterThan(Symbol symbol, long value) {
        return new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN, symbol.toSymbolReference(), bigint(value));
    }

    /** Returns {@code NOT expression}. */
    private static Expression not(Expression expression) {
        return new NotExpression(expression);
    }

    /** Returns {@code symbol IS NOT NULL}. */
    private static Expression isNotNull(Symbol symbol) {
        return new IsNotNullPredicate(symbol.toSymbolReference());
    }

    /** Returns a bigint literal as {@code CAST(value AS bigint)}. */
    private static Expression bigint(long value) {
        return new Cast(new LongLiteral(String.valueOf(value)), TypeSignatureTranslator.toSqlType(BigintType.BIGINT));
    }

    /** Returns a boolean literal, or {@code CAST(NULL AS boolean)} when {@code value} is null. */
    private static Expression booleanConstant(@Nullable Boolean value) {
        if (value == null) {
            return new Cast(new NullLiteral(), TypeSignatureTranslator.toSqlType(BooleanType.BOOLEAN));
        }
        return new BooleanLiteral(value.toString());
    }
}
