package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.Iterables;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.sql.planner.DeterminismEvaluator;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.optimizations.Cardinality;
import io.trino.sql.planner.optimizations.QueryCardinalityUtil;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.ValuesNode;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.Row;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/ReplaceJoinOverConstantWithProject.class */
/**
 * Optimizer rule that rewrites an unconditional join (no equi-criteria and no filter, or a
 * literal {@code TRUE} filter) in which one side is a single-row, deterministic
 * {@link ValuesNode}, into a {@link ProjectNode} over the other side. The projection passes
 * the surviving side's output symbols through unchanged and assigns the constant row's
 * expressions to the inlined side's output symbols.
 *
 * <p>Example shape of the rewrite:
 * <pre>
 * - Join (unconditional)                 - Project (a &lt;- a, x &lt;- expr)
 *     - source (a)              ==&gt;          - source (a)
 *     - Values (x) : ROW(expr)
 * </pre>
 */
public class ReplaceJoinOverConstantWithProject implements Rule<JoinNode> {
    private static final Pattern<JoinNode> PATTERN = Patterns.join().matching(ReplaceJoinOverConstantWithProject::isUnconditional);
    private final Metadata metadata;

    /**
     * @param metadata metadata used to evaluate determinism of the constant row expression
     * @throws NullPointerException if {@code metadata} is null
     */
    public ReplaceJoinOverConstantWithProject(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    @Override
    public Pattern<JoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * Attempts the rewrite for the matched join.
     *
     * <p>The left source is preferred for inlining; the right source is inlined only when the
     * left one cannot be. For outer joins an extra guard applies: the constant side may only be
     * inlined if the join cannot be required to emit the constant row on its own, which would
     * happen when the constant side is preserved by the join and the other side produces no
     * rows. A projection over a source with no rows would drop that row, so the other side must
     * report {@code isAtLeastScalar()}.
     *
     * @param joinNode the unconditional join matched by {@link #getPattern()}
     * @param captures pattern captures (unused)
     * @param context rule context supplying the lookup and the plan node id allocator
     * @return the replacement projection, or an empty result when the rule does not apply
     * @throws IllegalStateException if the join type is not INNER, LEFT, RIGHT or FULL
     */
    @Override
    public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
        // Joins with a side whose cardinality is reported as empty are not rewritten here
        // (presumably covered by a different rule -- confirm).
        Cardinality leftCardinality = QueryCardinalityUtil.extractCardinality(joinNode.getLeft(), context.getLookup());
        if (leftCardinality.isEmpty()) {
            return Rule.Result.empty();
        }
        Cardinality rightCardinality = QueryCardinalityUtil.extractCardinality(joinNode.getRight(), context.getLookup());
        if (rightCardinality.isEmpty()) {
            return Rule.Result.empty();
        }

        PlanNode left = context.getLookup().resolve(joinNode.getLeft());
        PlanNode right = context.getLookup().resolve(joinNode.getRight());
        boolean leftIsConstant = canInlineJoinSource(left);
        boolean rightIsConstant = canInlineJoinSource(right);

        // Decide per join type which side (if any) may be inlined.
        boolean inlineLeft;
        boolean inlineRight;
        switch (joinNode.getType()) {
            case INNER:
                inlineLeft = leftIsConstant;
                inlineRight = rightIsConstant;
                break;
            case LEFT:
                // Left side is preserved: inlining it requires the right side to be non-empty.
                inlineLeft = leftIsConstant && rightCardinality.isAtLeastScalar();
                inlineRight = rightIsConstant;
                break;
            case RIGHT:
                // Right side is preserved: inlining it requires the left side to be non-empty.
                inlineLeft = leftIsConstant;
                inlineRight = rightIsConstant && leftCardinality.isAtLeastScalar();
                break;
            case FULL:
                // Both sides are preserved: whichever is inlined, the other must be non-empty.
                inlineLeft = leftIsConstant && rightCardinality.isAtLeastScalar();
                inlineRight = rightIsConstant && leftCardinality.isAtLeastScalar();
                break;
            default:
                throw new IllegalStateException("Unsupported join type: " + joinNode.getType());
        }

        if (inlineLeft) {
            return Rule.Result.ofPlanNode(appendProjection(right, joinNode.getRightOutputSymbols(), left, joinNode.getLeftOutputSymbols(), context.getIdAllocator()));
        }
        if (inlineRight) {
            return Rule.Result.ofPlanNode(appendProjection(left, joinNode.getLeftOutputSymbols(), right, joinNode.getRightOutputSymbols(), context.getIdAllocator()));
        }
        return Rule.Result.empty();
    }

    /**
     * Returns true when the join has no equi-join criteria and either no filter or a filter
     * equal to the literal {@code TRUE}.
     */
    private static boolean isUnconditional(JoinNode joinNode) {
        return joinNode.getCriteria().isEmpty() && (joinNode.getFilter().isEmpty() || joinNode.getFilter().get().equals(BooleanLiteral.TRUE_LITERAL));
    }

    /**
     * Returns true when {@code planNode} is a single constant row that exposes at least one
     * output symbol. Sources without output symbols are rejected because there would be no
     * constant expressions to project.
     */
    private boolean canInlineJoinSource(PlanNode planNode) {
        return isSingleConstantRow(planNode) && !planNode.getOutputSymbols().isEmpty();
    }

    /**
     * Returns true when {@code planNode} is a {@link ValuesNode} with exactly one row that is
     * either absent from {@code getRows()} or is a deterministic {@link Row} expression.
     */
    private boolean isSingleConstantRow(PlanNode planNode) {
        if (!(planNode instanceof ValuesNode)) {
            return false;
        }
        ValuesNode values = (ValuesNode) planNode;
        if (values.getRowCount() != 1) {
            return false;
        }
        // One row but no row expressions recorded: nothing further to validate.
        if (values.getRows().isEmpty()) {
            return true;
        }
        Expression rowExpression = Iterables.getOnlyElement(values.getRows().get());
        // Only an explicit Row constructor can be split into per-symbol expressions, and a
        // non-deterministic expression must not be moved into a projection.
        return DeterminismEvaluator.isDeterministic(rowExpression, this.metadata) && rowExpression instanceof Row;
    }

    /**
     * Builds the projection replacing the join.
     *
     * @param planNode the source that stays in the plan
     * @param list output symbols of {@code planNode} to pass through as identities
     * @param planNode2 the single-row {@link ValuesNode} being inlined; callers must have
     *        validated it with {@link #canInlineJoinSource(PlanNode)}
     * @param list2 output symbols of {@code planNode2} that the join exposes
     * @param planNodeIdAllocator allocator for the new node's id
     * @return a projection over {@code planNode} producing both sets of symbols
     */
    private ProjectNode appendProjection(PlanNode planNode, List<Symbol> list, PlanNode planNode2, List<Symbol> list2, PlanNodeIdAllocator planNodeIdAllocator) {
        ValuesNode values = (ValuesNode) planNode2;
        Row row = (Row) Iterables.getOnlyElement(values.getRows().get());

        // Pair each output symbol of the constant source with the row item at the same position.
        List<Symbol> constantSymbols = values.getOutputSymbols();
        Map<Symbol, Expression> constantBySymbol = new HashMap<>();
        for (int i = 0; i < constantSymbols.size(); i++) {
            constantBySymbol.put(constantSymbols.get(i), row.getItems().get(i));
        }

        Assignments.Builder assignments = Assignments.builder().putIdentities(list);
        for (Symbol symbol : list2) {
            assignments.put(symbol, constantBySymbol.get(symbol));
        }
        return new ProjectNode(planNodeIdAllocator.getNextId(), planNode, assignments.build());
    }
}
