package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.type.BigintType;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.SetOperationNodeTranslator;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.IntersectNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionCall;

import java.util.Objects;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/ImplementIntersectAll.class */
public class ImplementIntersectAll implements Rule<IntersectNode> {
    private static final Pattern<IntersectNode> PATTERN = Patterns.intersect().with(Patterns.Intersect.distinct().equalTo(false));
    private final Metadata metadata;

    public ImplementIntersectAll(Metadata metadata) {
        this.metadata = (Metadata) Objects.requireNonNull(metadata, "metadata is null");
    }

    @Override // io.trino.sql.planner.iterative.Rule
    public Pattern<IntersectNode> getPattern() {
        return PATTERN;
    }

    @Override // io.trino.sql.planner.iterative.Rule
    public Rule.Result apply(IntersectNode intersectNode, Captures captures, Rule.Context context) {
        SetOperationNodeTranslator.TranslationResult makeSetContainmentPlanForAll = new SetOperationNodeTranslator(context.getSession(), this.metadata, context.getSymbolAllocator(), context.getIdAllocator()).makeSetContainmentPlanForAll(intersectNode);
        Preconditions.checkState(makeSetContainmentPlanForAll.getCountSymbols().size() > 0, "IntersectNode translation result has no count symbols");
        ResolvedFunction resolveBuiltinFunction = this.metadata.resolveBuiltinFunction("least", TypeSignatureProvider.fromTypes(BigintType.BIGINT, BigintType.BIGINT));
        FunctionCall symbolReference = makeSetContainmentPlanForAll.getCountSymbols().get(0).toSymbolReference();
        for (int i = 1; i < makeSetContainmentPlanForAll.getCountSymbols().size(); i++) {
            symbolReference = new FunctionCall(resolveBuiltinFunction.toQualifiedName(), ImmutableList.of(symbolReference, makeSetContainmentPlanForAll.getCountSymbols().get(i).toSymbolReference()));
        }
        return Rule.Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), new FilterNode(context.getIdAllocator().getNextId(), makeSetContainmentPlanForAll.getPlanNode(), new ComparisonExpression(ComparisonExpression.Operator.LESS_THAN_OR_EQUAL, makeSetContainmentPlanForAll.getRowNumberSymbol().toSymbolReference(), symbolReference)), Assignments.identity(intersectNode.getOutputSymbols())));
    }
}
