package io.trino.sql.planner.iterative.rule;

import io.airlift.units.DataSize;
import io.trino.SystemSessionProperties;
import io.trino.cost.CostCalculatorWithEstimatedExchanges;
import io.trino.cost.CostComparator;
import io.trino.cost.TaskCountEstimator;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.SemiJoinNode;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/DetermineSemiJoinDistributionType.class */
public class DetermineSemiJoinDistributionType implements Rule<SemiJoinNode> {
    private final TaskCountEstimator taskCountEstimator;
    private final CostComparator costComparator;
    private static final Pattern<SemiJoinNode> PATTERN = Patterns.semiJoin().matching(semiJoinNode -> {
        return semiJoinNode.getDistributionType().isEmpty();
    });

    public DetermineSemiJoinDistributionType(CostComparator costComparator, TaskCountEstimator taskCountEstimator) {
        this.costComparator = (CostComparator) Objects.requireNonNull(costComparator, "costComparator is null");
        this.taskCountEstimator = (TaskCountEstimator) Objects.requireNonNull(taskCountEstimator, "taskCountEstimator is null");
    }

    @Override // io.trino.sql.planner.iterative.Rule
    public Pattern<SemiJoinNode> getPattern() {
        return PATTERN;
    }

    @Override // io.trino.sql.planner.iterative.Rule
    public Rule.Result apply(SemiJoinNode semiJoinNode, Captures captures, Rule.Context context) {
        switch (SystemSessionProperties.getJoinDistributionType(context.getSession())) {
            case AUTOMATIC:
                return Rule.Result.ofPlanNode(getCostBasedDistributionType(semiJoinNode, context));
            case PARTITIONED:
                return Rule.Result.ofPlanNode(semiJoinNode.withDistributionType(SemiJoinNode.DistributionType.PARTITIONED));
            case BROADCAST:
                return Rule.Result.ofPlanNode(semiJoinNode.withDistributionType(SemiJoinNode.DistributionType.REPLICATED));
            default:
                throw new IncompatibleClassChangeError();
        }
    }

    private PlanNode getCostBasedDistributionType(SemiJoinNode semiJoinNode, Rule.Context context) {
        if (!canReplicate(semiJoinNode, context)) {
            return semiJoinNode.withDistributionType(SemiJoinNode.DistributionType.PARTITIONED);
        }
        ArrayList arrayList = new ArrayList();
        arrayList.add(getSemiJoinNodeWithCost(semiJoinNode.withDistributionType(SemiJoinNode.DistributionType.REPLICATED), context));
        arrayList.add(getSemiJoinNodeWithCost(semiJoinNode.withDistributionType(SemiJoinNode.DistributionType.PARTITIONED), context));
        return arrayList.stream().anyMatch(planNodeWithCost -> {
            return planNodeWithCost.getCost().hasUnknownComponents();
        }) ? getSizeBaseDistributionType(semiJoinNode, context) : ((PlanNodeWithCost) this.costComparator.forSession(context.getSession()).onResultOf((v0) -> {
            return v0.getCost();
        }).min(arrayList)).getPlanNode();
    }

    private PlanNode getSizeBaseDistributionType(SemiJoinNode semiJoinNode, Rule.Context context) {
        return DetermineJoinDistributionType.getSourceTablesSizeInBytes(semiJoinNode.getFilteringSource(), context) <= ((double) SystemSessionProperties.getJoinMaxBroadcastTableSize(context.getSession()).toBytes()) ? semiJoinNode.withDistributionType(SemiJoinNode.DistributionType.REPLICATED) : semiJoinNode.withDistributionType(SemiJoinNode.DistributionType.PARTITIONED);
    }

    private boolean canReplicate(SemiJoinNode semiJoinNode, Rule.Context context) {
        DataSize joinMaxBroadcastTableSize = SystemSessionProperties.getJoinMaxBroadcastTableSize(context.getSession());
        PlanNode filteringSource = semiJoinNode.getFilteringSource();
        return context.getStatsProvider().getStats(filteringSource).getOutputSizeInBytes(filteringSource.getOutputSymbols(), context.getSymbolAllocator().getTypes()) <= ((double) joinMaxBroadcastTableSize.toBytes()) || DetermineJoinDistributionType.getSourceTablesSizeInBytes(filteringSource, context) <= ((double) joinMaxBroadcastTableSize.toBytes());
    }

    private PlanNodeWithCost getSemiJoinNodeWithCost(SemiJoinNode semiJoinNode, Rule.Context context) {
        return new PlanNodeWithCost(CostCalculatorWithEstimatedExchanges.calculateJoinCostWithoutOutput(semiJoinNode.getSource(), semiJoinNode.getFilteringSource(), context.getStatsProvider(), context.getSymbolAllocator().getTypes(), semiJoinNode.getDistributionType().get() == SemiJoinNode.DistributionType.REPLICATED, this.taskCountEstimator.estimateSourceDistributedTaskCount(context.getSession())).toPlanCost(), semiJoinNode);
    }
}
