package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.airlift.log.Logger;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher;
import com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties;
import com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

/* loaded from: input_file:com/facebook/presto/sql/planner/iterative/rule/RuntimeReorderJoinSides.class */
/**
 * Iterative rule that swaps the probe (left) and build (right) sides of a {@link JoinNode} when the
 * estimated output size of the build side is larger than that of the probe side.
 *
 * <p>After flipping the children, this rule:
 * <ul>
 *   <li>drops a single-source {@link ExchangeNode} sitting directly on the new probe side, but only if the
 *       exchange's source still satisfies the required probe-side stream properties;</li>
 *   <li>adds a LOCAL exchange above the new build side when that side does not already satisfy the
 *       build-side stream properties.</li>
 * </ul>
 * The rule does nothing when any leaf under the join is not a {@link TableScanNode}, when size estimates are
 * unavailable (NaN), or when the flipped join shape is rejected by {@link #isSwappedJoinValid(JoinNode)}.
 */
public class RuntimeReorderJoinSides implements Rule<JoinNode> {
    private static final Logger log = Logger.get(RuntimeReorderJoinSides.class);
    private static final Pattern<JoinNode> PATTERN = Patterns.join();

    private final Metadata metadata;
    private final SqlParser parser;

    public RuntimeReorderJoinSides(Metadata metadata, SqlParser sqlParser) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        this.parser = Objects.requireNonNull(sqlParser, "parser is null");
    }

    @Override
    public Pattern<JoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * Flips the join sides when the right (build) side is estimated to produce more bytes than the left
     * (probe) side and the flipped join is considered valid.
     *
     * @param joinNode the matched join
     * @param captures unused
     * @param context rule context providing lookup, statistics, session and id allocation
     * @return the rewritten join, or {@link Rule.Result#empty()} when no swap is performed
     */
    @Override
    public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
        if (containsNonTableScanLeaf(joinNode, context)) {
            return Rule.Result.empty();
        }

        StatsProvider statsProvider = context.getStatsProvider();
        double leftOutputSizeInBytes = Double.NaN;
        double rightOutputSizeInBytes = Double.NaN;
        if (isJoinOverTableScansAndExchangesOnly(joinNode, context)) {
            // Simple plan: use the whole-node output size estimate of each side.
            leftOutputSizeInBytes = statsProvider.getStats(joinNode.getLeft()).getOutputSizeInBytes();
            rightOutputSizeInBytes = statsProvider.getStats(joinNode.getRight()).getOutputSizeInBytes();
        }
        if (Double.isNaN(leftOutputSizeInBytes) || Double.isNaN(rightOutputSizeInBytes)) {
            // Complex plan or missing estimate: fall back to an estimate over each side's output variables.
            leftOutputSizeInBytes = statsProvider.getStats(joinNode.getLeft())
                    .getOutputSizeInBytes(joinNode.getLeft().getOutputVariables());
            rightOutputSizeInBytes = statsProvider.getStats(joinNode.getRight())
                    .getOutputSizeInBytes(joinNode.getRight().getOutputVariables());
        }
        if (Double.isNaN(leftOutputSizeInBytes) || Double.isNaN(rightOutputSizeInBytes)) {
            return Rule.Result.empty();
        }
        // Only swap when the current build side is strictly larger than the current probe side.
        if (rightOutputSizeInBytes <= leftOutputSizeInBytes || !isSwappedJoinValid(joinNode)) {
            return Rule.Result.empty();
        }

        JoinNode swapped = joinNode.flipChildren();

        PlanNode newLeft = swapped.getLeft();
        Optional<VariableReferenceExpression> leftHashVariable = swapped.getLeftHashVariable();
        PlanNode resolvedLeft = context.getLookup().resolve(newLeft);
        if (resolvedLeft instanceof ExchangeNode && resolvedLeft.getSources().size() == 1) {
            PlanNode exchangeSource = resolvedLeft.getSources().get(0);
            // Drop the exchange on the new probe side only if its source alone meets the probe-side requirements.
            if (checkProbeSidePropertySatisfied(exchangeSource, context)) {
                newLeft = exchangeSource;
                if (swapped.getLeftHashVariable().isPresent()) {
                    VariableReferenceExpression exchangeHashVariable = swapped.getLeftHashVariable().get();
                    // If the join exposes this hash variable in its output, removing the exchange would change the
                    // join's output layout; skip the rewrite rather than propagate that change.
                    if (swapped.getOutputVariables().contains(exchangeHashVariable)) {
                        return Rule.Result.empty();
                    }
                    // Map the hash variable positionally from the exchange's output layout to its source's layout.
                    int hashVariableIndex = resolvedLeft.getOutputVariables().indexOf(exchangeHashVariable);
                    List<VariableReferenceExpression> sourceOutputs = exchangeSource.getOutputVariables();
                    if (hashVariableIndex < 0 || hashVariableIndex >= sourceOutputs.size()) {
                        // Layouts do not line up as expected; bail out instead of failing with an index error.
                        return Rule.Result.empty();
                    }
                    leftHashVariable = Optional.of(sourceOutputs.get(hashVariableIndex));
                }
            }
        }

        List<VariableReferenceExpression> buildJoinVariables = swapped.getCriteria().stream()
                .map(clause -> clause.getRight())
                .collect(ImmutableList.toImmutableList());
        PlanNode newRight = addBuildSideExchangeIfNeeded(swapped, buildJoinVariables, context);

        JoinNode newJoinNode = new JoinNode(
                swapped.getSourceLocation(),
                swapped.getId(),
                swapped.getType(),
                newLeft,
                newRight,
                swapped.getCriteria(),
                swapped.getOutputVariables(),
                swapped.getFilter(),
                leftHashVariable,
                swapped.getRightHashVariable(),
                swapped.getDistributionType(),
                swapped.getDynamicFilters());

        // Parameterized form: the message is only formatted when debug logging is enabled.
        log.debug("Probe size: %.2f is smaller than Build size: %.2f => invoke runtime join swapping on JoinNode ID: %s.",
                leftOutputSizeInBytes, rightOutputSizeInBytes, newJoinNode.getId());
        return Rule.Result.ofPlanNode(newJoinNode);
    }

    /**
     * Returns true when some leaf (a node without sources) under {@code joinNode} is not a {@link TableScanNode}.
     */
    private static boolean containsNonTableScanLeaf(JoinNode joinNode, Rule.Context context) {
        return PlanNodeSearcher.searchFrom(joinNode, context.getLookup())
                .where(node -> node.getSources().isEmpty() && !(node instanceof TableScanNode))
                .matches();
    }

    /**
     * Returns true when the search from {@code joinNode} finds exactly one node that is neither a
     * {@link TableScanNode} nor an {@link ExchangeNode}. That one node is presumably the join itself, meaning the
     * join sits over table scans with only exchanges in between (confirm PlanNodeSearcher includes the root).
     */
    private static boolean isJoinOverTableScansAndExchangesOnly(JoinNode joinNode, Rule.Context context) {
        return PlanNodeSearcher.searchFrom(joinNode, context.getLookup())
                .where(node -> !(node instanceof TableScanNode) && !(node instanceof ExchangeNode))
                .findAll()
                .size() == 1;
    }

    /**
     * Returns the new build side, wrapped in a LOCAL exchange when it does not already satisfy the build-side
     * stream properties. The exchange is hash-partitioned on {@code buildJoinVariables} when task concurrency is
     * greater than 1, and gathering otherwise.
     */
    private PlanNode addBuildSideExchangeIfNeeded(JoinNode swapped, List<VariableReferenceExpression> buildJoinVariables, Rule.Context context) {
        PlanNode build = swapped.getRight();
        if (checkBuildSidePropertySatisfied(build, buildJoinVariables, context)) {
            return build;
        }
        if (SystemSessionProperties.getTaskConcurrency(context.getSession()) > 1) {
            return ExchangeNode.systemPartitionedExchange(
                    context.getIdAllocator().getNextId(),
                    ExchangeNode.Scope.LOCAL,
                    build,
                    buildJoinVariables,
                    swapped.getRightHashVariable());
        }
        return ExchangeNode.gatheringExchange(context.getIdAllocator().getNextId(), ExchangeNode.Scope.LOCAL, build);
    }

    /**
     * Returns false for join shapes that this rule refuses to flip:
     * <ul>
     *   <li>REPLICATED LEFT joins;</li>
     *   <li>PARTITIONED RIGHT joins with no equi-join criteria.</li>
     * </ul>
     * Also returns false when no distribution type has been chosen. That case cannot be validated, so the swap is
     * conservatively skipped instead of failing on an unchecked {@code Optional.get()}.
     */
    private boolean isSwappedJoinValid(JoinNode joinNode) {
        if (!joinNode.getDistributionType().isPresent()) {
            return false;
        }
        JoinNode.DistributionType distribution = joinNode.getDistributionType().get();
        boolean replicatedLeftJoin = distribution == JoinNode.DistributionType.REPLICATED
                && joinNode.getType() == JoinNode.Type.LEFT;
        boolean partitionedRightJoinWithoutCriteria = distribution == JoinNode.DistributionType.PARTITIONED
                && joinNode.getCriteria().isEmpty()
                && joinNode.getType() == JoinNode.Type.RIGHT;
        return !replicatedLeftJoin && !partitionedRightJoinWithoutCriteria;
    }

    /**
     * Checks whether {@code node} meets the probe-side stream requirement. The requirement is fixed parallelism
     * when both spill and join spilling are enabled, and default parallelism for the session otherwise.
     */
    private boolean checkProbeSidePropertySatisfied(PlanNode node, Rule.Context context) {
        StreamPreferredProperties required;
        if (SystemSessionProperties.isSpillEnabled(context.getSession())
                && SystemSessionProperties.isJoinSpillingEnabled(context.getSession())) {
            required = StreamPreferredProperties.fixedParallelism();
        } else {
            required = StreamPreferredProperties.defaultParallelism(context.getSession());
        }
        return required.isSatisfiedBy(derivePropertiesRecursively(node, context));
    }

    /**
     * Checks whether {@code node} meets the build-side stream requirement. The requirement is exact partitioning on
     * {@code buildJoinVariables} when task concurrency is greater than 1, and a single stream otherwise.
     */
    private boolean checkBuildSidePropertySatisfied(PlanNode node, List<VariableReferenceExpression> buildJoinVariables, Rule.Context context) {
        StreamPreferredProperties required = SystemSessionProperties.getTaskConcurrency(context.getSession()) > 1
                ? StreamPreferredProperties.exactlyPartitionedOn(buildJoinVariables)
                : StreamPreferredProperties.singleStream();
        return required.isSatisfiedBy(derivePropertiesRecursively(node, context));
    }

    /**
     * Derives stream properties for {@code node} bottom-up. Each node is first resolved through the lookup,
     * its sources' properties are derived recursively, and the node's own properties are computed from them.
     */
    private StreamPropertyDerivations.StreamProperties derivePropertiesRecursively(PlanNode node, Rule.Context context) {
        PlanNode resolved = context.getLookup().resolve(node);
        List<StreamPropertyDerivations.StreamProperties> inputProperties = resolved.getSources().stream()
                .map(source -> derivePropertiesRecursively(source, context))
                .collect(ImmutableList.toImmutableList());
        return StreamPropertyDerivations.deriveProperties(
                resolved,
                inputProperties,
                metadata,
                context.getSession(),
                context.getVariableAllocator().getTypes(),
                parser);
    }
}
