package io.trino.execution.scheduler.faulttolerant;

import com.google.common.collect.Ordering;
import com.google.inject.Inject;
import io.airlift.log.Logger;
import io.airlift.stats.TDigest;
import io.airlift.units.DataSize;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.execution.scheduler.ErrorCodes;
import io.trino.execution.scheduler.faulttolerant.PartitionMemoryEstimator;
import io.trino.memory.ClusterMemoryManager;
import io.trino.memory.MemoryInfo;
import io.trino.memory.MemoryManagerConfig;
import io.trino.spi.ErrorCode;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.SystemPartitioningHandle;
import io.trino.sql.planner.plan.PlanFragmentId;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.PreDestroy;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.assertj.core.util.VisibleForTesting;

/* loaded from: input_file:io/trino/execution/scheduler/faulttolerant/ExponentialGrowthPartitionMemoryEstimator.class */
/**
 * {@link PartitionMemoryEstimator} that starts each partition at a default memory limit and, when an attempt fails
 * with an error that calls for more memory, multiplies the requested memory by a configurable growth factor.
 * <p>
 * Requirements are raised to at least a quantile of the memory usage recorded through
 * {@link #registerPartitionFinished}. They are capped at the largest memory pool reported by any worker, when known.
 * <p>
 * The recorded usage distribution is only accessed while holding this instance's monitor.
 */
public class ExponentialGrowthPartitionMemoryEstimator implements PartitionMemoryEstimator {
    private final DataSize defaultInitialMemoryLimit;
    private final boolean memoryRequirementIncreaseOnWorkerCrashEnabled;
    private final double growthFactor;
    private final double estimationQuantile;
    private final Supplier<Optional<DataSize>> maxNodePoolSizeSupplier;
    // Guarded by "this": every read and write below happens in a synchronized method or block.
    private final TDigest memoryUsageDistribution = new TDigest();

    /**
     * Creates one {@link ExponentialGrowthPartitionMemoryEstimator} per plan fragment. It also keeps the largest
     * worker memory-pool size up to date in the background, and the estimators use that value as an upper bound.
     */
    public static class Factory implements PartitionMemoryEstimatorFactory {
        private static final Logger log = Logger.get(Factory.class);

        private final Supplier<Map<String, Optional<MemoryInfo>>> workerMemoryInfoSupplier;
        private final boolean memoryRequirementIncreaseOnWorkerCrashEnabled;
        private final ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor();
        // Largest memory pool max size reported by any worker; empty while no worker has reported memory info.
        private final AtomicReference<Optional<DataSize>> maxNodePoolSize = new AtomicReference<>(Optional.empty());

        @Inject
        public Factory(ClusterMemoryManager clusterMemoryManager, MemoryManagerConfig memoryManagerConfig) {
            // The null check has to wrap the receiver itself: evaluating "manager::method" on a null manager
            // already throws, so a requireNonNull placed after this(...) could never run.
            this(
                    Objects.requireNonNull(clusterMemoryManager, "clusterMemoryManager is null")::getWorkerMemoryInfo,
                    memoryManagerConfig.isFaultTolerantExecutionMemoryRequirementIncreaseOnWorkerCrashEnabled());
        }

        /**
         * @param workerMemoryInfoSupplier supplies per-worker memory info, keyed by a worker identifier
         * @param memoryRequirementIncreaseOnWorkerCrashEnabled whether worker-crash errors also trigger memory growth
         */
        // NOTE(review): VisibleForTesting resolves to org.assertj.core.util (a test library) via the file's imports;
        // Guava's com.google.common.annotations.VisibleForTesting is presumably intended.
        @VisibleForTesting
        Factory(Supplier<Map<String, Optional<MemoryInfo>>> workerMemoryInfoSupplier, boolean memoryRequirementIncreaseOnWorkerCrashEnabled) {
            this.workerMemoryInfoSupplier = Objects.requireNonNull(workerMemoryInfoSupplier, "workerMemoryInfoSupplier is null");
            this.memoryRequirementIncreaseOnWorkerCrashEnabled = memoryRequirementIncreaseOnWorkerCrashEnabled;
        }

        /**
         * Refreshes the maximum node pool size immediately. It then refreshes it again every second
         * (1 second initial delay, 1 second fixed delay).
         */
        @PostConstruct
        public void start() {
            refreshNodePoolMemoryInfos();
            executor.scheduleWithFixedDelay(() -> {
                try {
                    refreshNodePoolMemoryInfos();
                } catch (Throwable t) {
                    // Swallowed on purpose: an exception escaping a periodic task cancels all its later runs.
                    log.error(t, "Unexpected error while refreshing node pool memory infos");
                }
            }, 1L, 1L, TimeUnit.SECONDS);
        }

        /** Stops the background refresh by shutting the executor down immediately. */
        @PreDestroy
        public void stop() {
            executor.shutdownNow();
        }

        /**
         * Recomputes the largest memory pool max size across all workers that currently report memory info.
         * Workers without memory info are skipped. If none report, the cached value becomes empty.
         */
        @VisibleForTesting
        void refreshNodePoolMemoryInfos() {
            long maxPoolBytes = -1;
            for (Optional<MemoryInfo> memoryInfo : workerMemoryInfoSupplier.get().values()) {
                if (memoryInfo.isPresent()) {
                    maxPoolBytes = Math.max(maxPoolBytes, memoryInfo.get().getPool().getMaxBytes());
                }
            }
            maxNodePoolSize.set(maxPoolBytes == -1 ? Optional.empty() : Optional.of(DataSize.ofBytes(maxPoolBytes)));
        }

        /**
         * Creates an estimator for {@code planFragment}. Coordinator-distributed fragments start from the
         * coordinator default task memory; all other fragments start from the regular default task memory.
         * The growth factor and estimation quantile are read from the session.
         */
        @Override
        public PartitionMemoryEstimator createPartitionMemoryEstimator(Session session, PlanFragment planFragment, Function<PlanFragmentId, PlanFragment> sourceFragmentLookup) {
            DataSize defaultInitialMemoryLimit = planFragment.getPartitioning().equals(SystemPartitioningHandle.COORDINATOR_DISTRIBUTION)
                    ? SystemSessionProperties.getFaultTolerantExecutionDefaultCoordinatorTaskMemory(session)
                    : SystemSessionProperties.getFaultTolerantExecutionDefaultTaskMemory(session);
            return new ExponentialGrowthPartitionMemoryEstimator(
                    defaultInitialMemoryLimit,
                    memoryRequirementIncreaseOnWorkerCrashEnabled,
                    SystemSessionProperties.getFaultTolerantExecutionTaskMemoryGrowthFactor(session),
                    SystemSessionProperties.getFaultTolerantExecutionTaskMemoryEstimationQuantile(session),
                    // Reads the factory's live value, so estimators see later background refreshes.
                    maxNodePoolSize::get);
        }
    }

    private ExponentialGrowthPartitionMemoryEstimator(
            DataSize defaultInitialMemoryLimit,
            boolean memoryRequirementIncreaseOnWorkerCrashEnabled,
            double growthFactor,
            double estimationQuantile,
            Supplier<Optional<DataSize>> maxNodePoolSizeSupplier) {
        this.defaultInitialMemoryLimit = Objects.requireNonNull(defaultInitialMemoryLimit, "defaultInitialMemoryLimit is null");
        this.memoryRequirementIncreaseOnWorkerCrashEnabled = memoryRequirementIncreaseOnWorkerCrashEnabled;
        this.growthFactor = growthFactor;
        this.estimationQuantile = estimationQuantile;
        this.maxNodePoolSizeSupplier = Objects.requireNonNull(maxNodePoolSizeSupplier, "maxNodePoolSizeSupplier is null");
    }

    /**
     * Returns the larger of the default initial limit and the current usage estimate, capped at the maximum
     * node pool size.
     */
    @Override
    public PartitionMemoryEstimator.MemoryRequirements getInitialMemoryRequirements() {
        DataSize memory = Ordering.natural().max(defaultInitialMemoryLimit, getEstimatedMemoryUsage());
        return new PartitionMemoryEstimator.MemoryRequirements(capMemoryToMaxNodeSize(memory));
    }

    /**
     * Computes the memory to request for the next attempt of a failed partition.
     * <p>
     * The base is the larger of the previous requirement and the observed usage. The base is multiplied by the
     * growth factor when {@code errorCode} warrants more memory, then raised to at least the current usage
     * estimate, and finally capped at the maximum node pool size.
     *
     * @param previousMemoryRequirements requirements used for the failed attempt
     * @param peakMemoryUsage memory usage observed for the failed attempt
     * @param errorCode error the attempt failed with
     */
    @Override
    public PartitionMemoryEstimator.MemoryRequirements getNextRetryMemoryRequirements(PartitionMemoryEstimator.MemoryRequirements previousMemoryRequirements, DataSize peakMemoryUsage, ErrorCode errorCode) {
        DataSize memory = Ordering.natural().max(peakMemoryUsage, previousMemoryRequirements.getRequiredMemory());
        if (shouldIncreaseMemoryRequirement(errorCode)) {
            memory = DataSize.ofBytes((long) (memory.toBytes() * growthFactor));
        }
        return new PartitionMemoryEstimator.MemoryRequirements(capMemoryToMaxNodeSize(Ordering.natural().max(memory, getEstimatedMemoryUsage())));
    }

    /** Returns {@code memory}, lowered to the maximum node pool size when that size is known. */
    private DataSize capMemoryToMaxNodeSize(DataSize memory) {
        return maxNodePoolSizeSupplier.get()
                .map(maxPoolSize -> Ordering.natural().min(memory, maxPoolSize))
                .orElse(memory);
    }

    /**
     * Records the outcome of a partition attempt in the usage distribution.
     * <p>
     * Successful attempts record their observed usage. Failed attempts whose error warrants more memory record
     * {@code max(required, observed) * growthFactor}. All other failures are ignored.
     */
    @Override
    public synchronized void registerPartitionFinished(PartitionMemoryEstimator.MemoryRequirements previousMemoryRequirements, DataSize peakMemoryUsage, boolean success, Optional<ErrorCode> errorCode) {
        if (success) {
            memoryUsageDistribution.add(peakMemoryUsage.toBytes());
        } else if (errorCode.isPresent() && shouldIncreaseMemoryRequirement(errorCode.get())) {
            memoryUsageDistribution.add(Math.max(previousMemoryRequirements.getRequiredMemory().toBytes(), peakMemoryUsage.toBytes()) * growthFactor);
        }
    }

    /** Returns the configured quantile of recorded usage, or 0 bytes when the digest yields NaN (e.g. no data). */
    private synchronized DataSize getEstimatedMemoryUsage() {
        double estimate = memoryUsageDistribution.valueAt(estimationQuantile);
        return Double.isNaN(estimate) ? DataSize.ofBytes(0L) : DataSize.ofBytes((long) estimate);
    }

    /** Renders selected quantiles of the usage distribution as {@code [quantile=value, ...]}. */
    private String memoryUsageDistributionInfo() {
        double[] quantiles = {0.01d, 0.05d, 0.1d, 0.2d, 0.5d, 0.8d, 0.9d, 0.95d, 0.99d};
        double[] quantileValues;
        synchronized (this) {
            quantileValues = memoryUsageDistribution.valuesAt(quantiles);
        }
        // Pair each quantile with its value; the previous code printed the quantile twice and dropped the value.
        return IntStream.range(0, quantiles.length)
                .mapToObj(i -> quantiles[i] + "=" + quantileValues[i])
                .collect(Collectors.joining(", ", "[", "]"));
    }

    @Override
    public String toString() {
        return "memoryUsageDistribution=" + memoryUsageDistributionInfo();
    }

    /**
     * Returns true for out-of-memory errors. Also returns true for worker-crash errors when
     * {@code memoryRequirementIncreaseOnWorkerCrashEnabled} is set.
     */
    private boolean shouldIncreaseMemoryRequirement(ErrorCode errorCode) {
        return ErrorCodes.isOutOfMemoryError(errorCode)
                || (memoryRequirementIncreaseOnWorkerCrashEnabled && ErrorCodes.isWorkerCrashAssociatedError(errorCode));
    }
}
