package io.trino.cost;

import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.ThreadSafe;
import com.google.inject.Inject;
import io.trino.Session;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.GroupReference;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.AssignUniqueId;
import io.trino.sql.planner.plan.EnforceSingleRowNode;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.LimitNode;
import io.trino.sql.planner.plan.OutputNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanVisitor;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.RowNumberNode;
import io.trino.sql.planner.plan.SemiJoinNode;
import io.trino.sql.planner.plan.SpatialJoinNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.UnionNode;
import io.trino.sql.planner.plan.ValuesNode;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Stream;

/**
 * {@link CostCalculator} that computes the cumulative cost of a single plan node by combining
 * the node's own {@link LocalCostEstimate} with the costs of its sources, as supplied by the
 * {@link CostProvider} passed to {@link #calculateCost}.
 *
 * <p>Exchange nodes present in the plan are costed explicitly from their scope and type. Node
 * types that have no dedicated visitor method produce {@link PlanCostEstimate#unknown()}.
 *
 * <p>The only instance state is the immutable {@code taskCountEstimator} reference; a fresh
 * visitor is created for every {@link #calculateCost} call.
 */
@ThreadSafe
public class CostCalculatorUsingExchanges implements CostCalculator {
    private final TaskCountEstimator taskCountEstimator;

    /**
     * @param taskCountEstimator used to estimate the number of source-distributed tasks when
     *     costing joins and remote replicating exchanges
     * @throws NullPointerException if {@code taskCountEstimator} is null
     */
    @Inject
    public CostCalculatorUsingExchanges(TaskCountEstimator taskCountEstimator) {
        this.taskCountEstimator = Objects.requireNonNull(taskCountEstimator, "taskCountEstimator is null");
    }

    /**
     * Computes the cumulative cost estimate of {@code planNode}.
     *
     * @param planNode node to cost
     * @param statsProvider source of statistics for the node and its sources
     * @param costProvider source of the (already computed) costs of the node's sources
     * @param session session forwarded to the task count estimator
     * @param typeProvider symbol types, used when converting statistics to sizes in bytes
     * @return the estimate produced by the visitor method matching the node's type
     * @throws NullPointerException if any of the providers or the session is null
     */
    @Override
    public PlanCostEstimate calculateCost(PlanNode planNode, StatsProvider statsProvider, CostProvider costProvider, Session session, TypeProvider typeProvider) {
        CostEstimator estimator = new CostEstimator(statsProvider, costProvider, typeProvider, taskCountEstimator, session);
        return planNode.accept(estimator, null);
    }

    /**
     * Adds two estimates component by component (CPU, max memory, max memory when outputting,
     * network). The result is built without a root-node local cost.
     */
    private static PlanCostEstimate addParallelSiblingsCost(PlanCostEstimate first, PlanCostEstimate second) {
        return new PlanCostEstimate(
                first.getCpuCost() + second.getCpuCost(),
                first.getMaxMemory() + second.getMaxMemory(),
                first.getMaxMemoryWhenOutputting() + second.getMaxMemoryWhenOutputting(),
                first.getNetworkCost() + second.getNetworkCost());
    }

    /**
     * Visitor that computes the local cost of the visited node and folds it together with the
     * costs of the node's sources. The visitor context is unused.
     */
    private static class CostEstimator extends PlanVisitor<PlanCostEstimate, Void> {
        private final StatsProvider stats;
        private final CostProvider sourcesCosts;
        private final TypeProvider types;
        private final TaskCountEstimator taskCountEstimator;
        private final Session session;

        CostEstimator(StatsProvider stats, CostProvider sourcesCosts, TypeProvider types, TaskCountEstimator taskCountEstimator, Session session) {
            this.stats = Objects.requireNonNull(stats, "stats is null");
            this.sourcesCosts = Objects.requireNonNull(sourcesCosts, "sourcesCosts is null");
            this.types = Objects.requireNonNull(types, "types is null");
            this.taskCountEstimator = Objects.requireNonNull(taskCountEstimator, "taskCountEstimator is null");
            this.session = Objects.requireNonNull(session, "session is null");
        }

        /** Fallback for node types without a dedicated method: the cost is unknown. */
        @Override
        protected PlanCostEstimate visitPlan(PlanNode node, Void context) {
            return PlanCostEstimate.unknown();
        }

        /** Group references are not supported by this visitor. */
        @Override
        public PlanCostEstimate visitGroupReference(GroupReference node, Void context) {
            throw new UnsupportedOperationException();
        }

        /** Streaming node; CPU cost is the estimated size of the generated id column. */
        @Override
        public PlanCostEstimate visitAssignUniqueId(AssignUniqueId node, Void context) {
            double idColumnSize = getStats(node).getOutputSizeInBytes(ImmutableList.of(node.getIdColumn()), types);
            return costForStreaming(node, LocalCostEstimate.ofCpu(idColumnSize));
        }

        /**
         * Streaming node. The first cost component is sized over all output symbols when a
         * per-partition row limit is set, and over the partitioning symbols plus the row number
         * symbol otherwise. The second component is the size of the source's output symbols
         * (taken from this node's stats), or zero when there is no partitioning.
         */
        @Override
        public PlanCostEstimate visitRowNumber(RowNumberNode node, Void context) {
            PlanNodeStatsEstimate nodeStats = getStats(node);

            double cpuCost;
            if (node.getMaxRowCountPerPartition().isEmpty()) {
                // No per-partition limit: count only the partition keys and the row number column.
                cpuCost = nodeStats.getOutputSizeInBytes(
                        Stream.concat(node.getPartitionBy().stream(), Stream.of(node.getRowNumberSymbol()))
                                .collect(ImmutableList.toImmutableList()),
                        types);
            } else {
                cpuCost = nodeStats.getOutputSizeInBytes(node.getOutputSymbols(), types);
            }

            double memoryCost = node.getPartitionBy().isEmpty()
                    ? 0
                    : nodeStats.getOutputSizeInBytes(node.getSource().getOutputSymbols(), types);

            return costForStreaming(node, LocalCostEstimate.of(cpuCost, memoryCost, 0));
        }

        /** Streaming node with no local cost of its own. */
        @Override
        public PlanCostEstimate visitOutput(OutputNode node, Void context) {
            return costForStreaming(node, LocalCostEstimate.zero());
        }

        /** Source node; CPU cost is the estimated size of the scanned output symbols. */
        @Override
        public PlanCostEstimate visitTableScan(TableScanNode node, Void context) {
            double outputSize = getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types);
            return costForSource(node, LocalCostEstimate.ofCpu(outputSize));
        }

        /**
         * Streaming node; CPU cost is sized from the stats of the filter's <em>source</em>
         * (restricted to the filter's output symbols), not from the filter's own stats.
         */
        @Override
        public PlanCostEstimate visitFilter(FilterNode node, Void context) {
            double inputSize = getStats(node.getSource()).getOutputSizeInBytes(node.getOutputSymbols(), types);
            return costForStreaming(node, LocalCostEstimate.ofCpu(inputSize));
        }

        /** Streaming node; CPU cost is the estimated size of the projected output. */
        @Override
        public PlanCostEstimate visitProject(ProjectNode node, Void context) {
            double outputSize = getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types);
            return costForStreaming(node, LocalCostEstimate.ofCpu(outputSize));
        }

        /**
         * Accumulating node. Only {@code FINAL} and {@code SINGLE} steps are costed; any other
         * step yields an unknown estimate. The first cost component is the size of the source's
         * output, the second is the size of the aggregation's own output.
         */
        @Override
        public PlanCostEstimate visitAggregation(AggregationNode node, Void context) {
            AggregationNode.Step step = node.getStep();
            if (step != AggregationNode.Step.FINAL && step != AggregationNode.Step.SINGLE) {
                return PlanCostEstimate.unknown();
            }
            PlanNode source = node.getSource();
            double cpuCost = getStats(source).getOutputSizeInBytes(source.getOutputSymbols(), types);
            double memoryCost = getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types);
            return costForAccumulation(node, LocalCostEstimate.of(cpuCost, memoryCost, 0));
        }

        /** Lookup join; treated as replicated only when the distribution type is present and REPLICATED. */
        @Override
        public PlanCostEstimate visitJoin(JoinNode node, Void context) {
            boolean replicated = Objects.equals(node.getDistributionType(), Optional.of(JoinNode.DistributionType.REPLICATED));
            return costForLookupJoin(node, calculateJoinCost(node, node.getLeft(), node.getRight(), replicated));
        }

        /** Lookup join over source and filtering source; an absent distribution type counts as PARTITIONED. */
        @Override
        public PlanCostEstimate visitSemiJoin(SemiJoinNode node, Void context) {
            boolean replicated = node.getDistributionType().orElse(SemiJoinNode.DistributionType.PARTITIONED) == SemiJoinNode.DistributionType.REPLICATED;
            return costForLookupJoin(node, calculateJoinCost(node, node.getSource(), node.getFilteringSource(), replicated));
        }

        /** Lookup join over the left and right inputs of a spatial join. */
        @Override
        public PlanCostEstimate visitSpatialJoin(SpatialJoinNode node, Void context) {
            boolean replicated = node.getDistributionType() == SpatialJoinNode.DistributionType.REPLICATED;
            return costForLookupJoin(node, calculateJoinCost(node, node.getLeft(), node.getRight(), replicated));
        }

        /** Streaming node whose local cost depends on the exchange's scope and type. */
        @Override
        public PlanCostEstimate visitExchange(ExchangeNode node, Void context) {
            return costForStreaming(node, calculateExchangeCost(node));
        }

        /** Source node with no local cost. */
        @Override
        public PlanCostEstimate visitValues(ValuesNode node, Void context) {
            return costForSource(node, LocalCostEstimate.zero());
        }

        /** Accumulating node with no local cost. */
        @Override
        public PlanCostEstimate visitEnforceSingleRow(EnforceSingleRowNode node, Void context) {
            return costForAccumulation(node, LocalCostEstimate.zero());
        }

        /** Streaming node; CPU cost is the estimated size of the limited output. */
        @Override
        public PlanCostEstimate visitLimit(LimitNode node, Void context) {
            double outputSize = getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types);
            return costForStreaming(node, LocalCostEstimate.ofCpu(outputSize));
        }

        /**
         * Streaming node with no local cost; having a dedicated method keeps unions from falling
         * through to {@link #visitPlan}, which would make the estimate unknown.
         */
        @Override
        public PlanCostEstimate visitUnion(UnionNode node, Void context) {
            return costForStreaming(node, LocalCostEstimate.zero());
        }

        /**
         * Local cost of a join-like node: the input cost and the replicated-join local exchange
         * adjustment (both delegated to {@link CostCalculatorWithEstimatedExchanges}) plus the
         * cost of producing the join output.
         *
         * @param join node whose output is being costed
         * @param probe first join input
         * @param build second join input; the local exchange adjustment is computed for this side
         * @param replicated whether the join is executed with a replicated distribution
         */
        private LocalCostEstimate calculateJoinCost(PlanNode join, PlanNode probe, PlanNode build, boolean replicated) {
            int estimatedTaskCount = taskCountEstimator.estimateSourceDistributedTaskCount(session);
            LocalCostEstimate inputCost = CostCalculatorWithEstimatedExchanges.calculateJoinInputCost(probe, build, stats, types, replicated, estimatedTaskCount);
            LocalCostEstimate localExchangeAdjustment = CostCalculatorWithEstimatedExchanges.adjustReplicatedJoinLocalExchangeCost(build, stats, types, replicated, estimatedTaskCount);
            return LocalCostEstimate.addPartialComponents(inputCost, localExchangeAdjustment, calculateJoinOutputCost(join));
        }

        /** CPU cost sized by the join's estimated output. */
        private LocalCostEstimate calculateJoinOutputCost(PlanNode join) {
            return LocalCostEstimate.ofCpu(getStats(join).getOutputSizeInBytes(join.getOutputSymbols(), types));
        }

        /**
         * Dispatches on the exchange scope; the size is computed before dispatching.
         *
         * @throws IllegalArgumentException for an unexpected scope or exchange type
         */
        private LocalCostEstimate calculateExchangeCost(ExchangeNode node) {
            double sizeInBytes = getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types);
            switch (node.getScope()) {
                case LOCAL:
                    return calculateLocalExchangeCost(node, sizeInBytes);
                case REMOTE:
                    return calculateRemoteExchangeCost(node, sizeInBytes);
                default:
                    throw new IllegalArgumentException("Unexpected scope: " + node.getScope());
            }
        }

        /** Local exchanges: only repartitioning carries a cost; gather and replicate are free. */
        private LocalCostEstimate calculateLocalExchangeCost(ExchangeNode node, double sizeInBytes) {
            switch (node.getType()) {
                case GATHER:
                case REPLICATE:
                    return LocalCostEstimate.zero();
                case REPARTITION:
                    return CostCalculatorWithEstimatedExchanges.calculateLocalRepartitionCost(sizeInBytes);
                default:
                    throw new IllegalArgumentException("Unexpected type: " + node.getType());
            }
        }

        /** Remote exchanges: every type is costed; replication also depends on the estimated task count. */
        private LocalCostEstimate calculateRemoteExchangeCost(ExchangeNode node, double sizeInBytes) {
            switch (node.getType()) {
                case GATHER:
                    return CostCalculatorWithEstimatedExchanges.calculateRemoteGatherCost(sizeInBytes);
                case REPARTITION:
                    return CostCalculatorWithEstimatedExchanges.calculateRemoteRepartitionCost(sizeInBytes);
                case REPLICATE:
                    return CostCalculatorWithEstimatedExchanges.calculateRemoteReplicateCost(sizeInBytes, taskCountEstimator.estimateSourceDistributedTaskCount(session));
                default:
                    throw new IllegalArgumentException("Unexpected type: " + node.getType());
            }
        }

        /**
         * Cost of a node that must have no sources: the cumulative estimate equals the local
         * cost, and the local memory is used both as peak memory and as memory when outputting.
         */
        private PlanCostEstimate costForSource(PlanNode node, LocalCostEstimate localCost) {
            Verify.verify(node.getSources().isEmpty(), "Unexpected sources for %s: %s", node, node.getSources());
            return new PlanCostEstimate(
                    localCost.getCpuCost(),
                    localCost.getMaxMemory(),
                    localCost.getMaxMemory(),
                    localCost.getNetworkCost(),
                    localCost);
        }

        /**
         * Cost of a node whose memory when outputting consists of its local memory only (the
         * sources' outputting memory is not carried over). Peak memory is the larger of the
         * sources' peak and the sources' outputting memory plus the local memory.
         */
        private PlanCostEstimate costForAccumulation(PlanNode node, LocalCostEstimate localCost) {
            PlanCostEstimate sources = sumSourcesCosts(node);
            return new PlanCostEstimate(
                    sources.getCpuCost() + localCost.getCpuCost(),
                    Math.max(sources.getMaxMemory(), sources.getMaxMemoryWhenOutputting() + localCost.getMaxMemory()),
                    localCost.getMaxMemory(),
                    sources.getNetworkCost() + localCost.getNetworkCost(),
                    localCost);
        }

        /**
         * Cost of a node whose memory when outputting is the sources' outputting memory plus its
         * local memory. Peak memory is computed as in {@link #costForAccumulation}.
         */
        private PlanCostEstimate costForStreaming(PlanNode node, LocalCostEstimate localCost) {
            PlanCostEstimate sources = sumSourcesCosts(node);
            double memoryWhenOutputting = sources.getMaxMemoryWhenOutputting() + localCost.getMaxMemory();
            return new PlanCostEstimate(
                    sources.getCpuCost() + localCost.getCpuCost(),
                    Math.max(sources.getMaxMemory(), memoryWhenOutputting),
                    memoryWhenOutputting,
                    sources.getNetworkCost() + localCost.getNetworkCost(),
                    localCost);
        }

        /**
         * Cost of a node with exactly two sources. The sources are named probe (first) and build
         * (second) here, following the lookup-join convention. The formulas treat the second
         * source asymmetrically: its outputting memory counts towards the peak, but it does not
         * contribute to the memory when outputting.
         */
        private PlanCostEstimate costForLookupJoin(PlanNode node, LocalCostEstimate localCost) {
            Verify.verify(node.getSources().size() == 2, "Unexpected number of sources for %s: %s", node, node.getSources());
            List<PlanCostEstimate> estimates = getSourcesEstimations(node).collect(ImmutableList.toImmutableList());
            Verify.verify(estimates.size() == 2);
            PlanCostEstimate probeCost = estimates.get(0);
            PlanCostEstimate buildCost = estimates.get(1);

            double peakWhileBothInputsRun = probeCost.getMaxMemory() + buildCost.getMaxMemory();
            double peakWhileBuildOutputs = probeCost.getMaxMemory() + buildCost.getMaxMemoryWhenOutputting() + localCost.getMaxMemory();

            return new PlanCostEstimate(
                    probeCost.getCpuCost() + buildCost.getCpuCost() + localCost.getCpuCost(),
                    Math.max(peakWhileBothInputsRun, peakWhileBuildOutputs),
                    probeCost.getMaxMemoryWhenOutputting() + localCost.getMaxMemory(),
                    probeCost.getNetworkCost() + buildCost.getNetworkCost() + localCost.getNetworkCost(),
                    localCost);
        }

        /** Sums the costs of all sources of {@code node}; zero when the node has no sources. */
        private PlanCostEstimate sumSourcesCosts(PlanNode node) {
            return getSourcesEstimations(node)
                    .reduce(PlanCostEstimate.zero(), CostCalculatorUsingExchanges::addParallelSiblingsCost);
        }

        private PlanNodeStatsEstimate getStats(PlanNode node) {
            return stats.getStats(node);
        }

        /** Costs of the node's sources, in source order, as reported by the cost provider. */
        private Stream<PlanCostEstimate> getSourcesEstimations(PlanNode node) {
            return node.getSources().stream().map(sourcesCosts::getCost);
        }
    }
}
