package io.trino.execution.scheduler.faulttolerant;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Suppliers;
import com.google.common.base.Verify;
import com.google.common.base.VerifyException;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSetMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Table;
import com.google.common.math.Quantiles;
import com.google.common.math.Stats;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import com.google.inject.Inject;
import io.airlift.json.JsonCodec;
import io.airlift.log.Logger;
import io.airlift.units.DataSize;
import io.trino.annotation.NotThreadSafe;
import io.trino.execution.QueryManagerConfig;
import io.trino.execution.StageId;
import io.trino.metadata.Split;
import io.trino.spi.QueryId;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.sql.planner.plan.PlanNodeId;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.ObjectMethods;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier;
import java.util.stream.Stream;
import org.weakref.jmx.Managed;
import org.weakref.jmx.Nested;

/* loaded from: input_file:io/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage.class */
public class TaskDescriptorStorage {
    private static final Logger log = Logger.get(TaskDescriptorStorage.class);
    // Upper bound on the combined retained size of all stored task descriptors; when the running
    // total exceeds it, updateMemoryReservation fails the query holding the most bytes.
    private final long maxMemoryInBytes;
    // Only used to render the biggest splits in the debug info logged when a query is failed.
    private final JsonCodec<Split> splitJsonCodec;
    // JMX-exposed statistics; backed by a supplier memoized for one second (see constructor).
    private final StorageStats storageStats;

    // Task descriptors per query; an entry exists from initialize(queryId) until destroy(queryId).
    @GuardedBy("this")
    private final Map<QueryId, TaskDescriptors> storages;

    // Running total of reserved bytes, maintained by updateMemoryReservation as the per-query
    // containers in storages grow, shrink, fail, or are destroyed.
    @GuardedBy("this")
    private long reservedBytes;

    /**
     * JMX view over the storage statistics.
     *
     * <p>Every getter reads a {@link StorageStatsValue} from {@code statsSupplier}; the enclosing
     * storage passes a supplier memoized for one second, so getters read within that window observe
     * the same snapshot.
     */
    public static class StorageStats {
        private final Supplier<StorageStatsValue> statsSupplier;

        StorageStats(Supplier<StorageStatsValue> supplier) {
            this.statsSupplier = Objects.requireNonNull(supplier, "statsSupplier is null");
        }

        @Managed
        public long getQueriesCount() {
            return statsSupplier.get().queriesCount();
        }

        @Managed
        public long getStagesCount() {
            return statsSupplier.get().stagesCount();
        }

        @Managed
        public long getReservedBytes() {
            return statsSupplier.get().reservedBytes();
        }

        @Managed
        public long getQueryReservedBytesAvg() {
            return statsSupplier.get().queryReservedBytesAvg();
        }

        @Managed
        public long getQueryReservedBytesP50() {
            return statsSupplier.get().queryReservedBytesP50();
        }

        @Managed
        public long getQueryReservedBytesP90() {
            return statsSupplier.get().queryReservedBytesP90();
        }

        @Managed
        public long getQueryReservedBytesP95() {
            return statsSupplier.get().queryReservedBytesP95();
        }

        @Managed
        public long getStageReservedBytesAvg() {
            // Previously returned the P50 value by mistake; report the actual average.
            return statsSupplier.get().stageReservedBytesAvg();
        }

        @Managed
        public long getStageReservedBytesP50() {
            return statsSupplier.get().stageReservedBytesP50();
        }

        @Managed
        public long getStageReservedBytesP90() {
            return statsSupplier.get().stageReservedBytesP90();
        }

        @Managed
        public long getStageReservedBytesP95() {
            return statsSupplier.get().stageReservedBytesP95();
        }
    }

    /* loaded from: input_file:io/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue.class */
    private static final class StorageStatsValue extends Record {
        private final long queriesCount;
        private final long stagesCount;
        private final long reservedBytes;
        private final long queryReservedBytesAvg;
        private final long queryReservedBytesP50;
        private final long queryReservedBytesP90;
        private final long queryReservedBytesP95;
        private final long stageReservedBytesAvg;
        private final long stageReservedBytesP50;
        private final long stageReservedBytesP90;
        private final long stageReservedBytesP95;

        private StorageStatsValue(long j, long j2, long j3, long j4, long j5, long j6, long j7, long j8, long j9, long j10, long j11) {
            this.queriesCount = j;
            this.stagesCount = j2;
            this.reservedBytes = j3;
            this.queryReservedBytesAvg = j4;
            this.queryReservedBytesP50 = j5;
            this.queryReservedBytesP90 = j6;
            this.queryReservedBytesP95 = j7;
            this.stageReservedBytesAvg = j8;
            this.stageReservedBytesP50 = j9;
            this.stageReservedBytesP90 = j10;
            this.stageReservedBytesP95 = j11;
        }

        @Override // java.lang.Record
        public final String toString() {
            return (String) ObjectMethods.bootstrap(MethodHandles.lookup(), "toString", MethodType.methodType(String.class, StorageStatsValue.class), StorageStatsValue.class, "queriesCount;stagesCount;reservedBytes;queryReservedBytesAvg;queryReservedBytesP50;queryReservedBytesP90;queryReservedBytesP95;stageReservedBytesAvg;stageReservedBytesP50;stageReservedBytesP90;stageReservedBytesP95", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->queriesCount:J", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->stagesCount:J", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->reservedBytes:J", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->queryReservedBytesAvg:J", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->queryReservedBytesP50:J", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->queryReservedBytesP90:J", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->queryReservedBytesP95:J", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->stageReservedBytesAvg:J", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->stageReservedBytesP50:J", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->stageReservedBytesP90:J", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->stageReservedBytesP95:J").dynamicInvoker().invoke(this) /* invoke-custom */;
        }

        @Override // java.lang.Record
        public final int hashCode() {
            return (int) ObjectMethods.bootstrap(MethodHandles.lookup(), "hashCode", MethodType.methodType(Integer.TYPE, StorageStatsValue.class), StorageStatsValue.class, "queriesCount;stagesCount;reservedBytes;queryReservedBytesAvg;queryReservedBytesP50;queryReservedBytesP90;queryReservedBytesP95;stageReservedBytesAvg;stageReservedBytesP50;stageReservedBytesP90;stageReservedBytesP95", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->queriesCount:J", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->stagesCount:J", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->reservedBytes:J", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->queryReservedBytesAvg:J", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->queryReservedBytesP50:J", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->queryReservedBytesP90:J", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->queryReservedBytesP95:J", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->stageReservedBytesAvg:J", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->stageReservedBytesP50:J", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->stageReservedBytesP90:J", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->stageReservedBytesP95:J").dynamicInvoker().invoke(this) /* invoke-custom */;
        }

        @Override // java.lang.Record
        public final boolean equals(Object obj) {
            return (boolean) ObjectMethods.bootstrap(MethodHandles.lookup(), "equals", MethodType.methodType(Boolean.TYPE, StorageStatsValue.class, Object.class), StorageStatsValue.class, "queriesCount;stagesCount;reservedBytes;queryReservedBytesAvg;queryReservedBytesP50;queryReservedBytesP90;queryReservedBytesP95;stageReservedBytesAvg;stageReservedBytesP50;stageReservedBytesP90;stageReservedBytesP95", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->queriesCount:J", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->stagesCount:J", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->reservedBytes:J", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->queryReservedBytesAvg:J", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->queryReservedBytesP50:J", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->queryReservedBytesP90:J", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->queryReservedBytesP95:J", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->stageReservedBytesAvg:J", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->stageReservedBytesP50:J", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->stageReservedBytesP90:J", "FIELD:Lio/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage$StorageStatsValue;->stageReservedBytesP95:J").dynamicInvoker().invoke(this, obj) /* invoke-custom */;
        }

        public long queriesCount() {
            return this.queriesCount;
        }

        public long stagesCount() {
            return this.stagesCount;
        }

        public long reservedBytes() {
            return this.reservedBytes;
        }

        public long queryReservedBytesAvg() {
            return this.queryReservedBytesAvg;
        }

        public long queryReservedBytesP50() {
            return this.queryReservedBytesP50;
        }

        public long queryReservedBytesP90() {
            return this.queryReservedBytesP90;
        }

        public long queryReservedBytesP95() {
            return this.queryReservedBytesP95;
        }

        public long stageReservedBytesAvg() {
            return this.stageReservedBytesAvg;
        }

        public long stageReservedBytesP50() {
            return this.stageReservedBytesP50;
        }

        public long stageReservedBytesP90() {
            return this.stageReservedBytesP90;
        }

        public long stageReservedBytesP95() {
            return this.stageReservedBytesP95;
        }
    }

    /**
     * Per-query container of task descriptors, keyed by stage id and partition id.
     *
     * <p>Not thread-safe by itself; the enclosing storage only accesses instances from its
     * {@code synchronized} methods. Once {@link #fail(RuntimeException)} has been called, all
     * descriptors are dropped and {@code put}/{@code get}/{@code remove} rethrow the recorded failure.
     */
    @NotThreadSafe
    public class TaskDescriptors {
        // Sum of getRetainedSizeInBytes() over all descriptors currently held.
        private long reservedBytes;
        // Non-null once this query has been failed to reclaim memory.
        private RuntimeException failure;
        private final Table<StageId, Integer, TaskDescriptor> descriptors = HashBasedTable.create();
        // Per-stage share of reservedBytes; holds an entry exactly for the stages that still have
        // descriptors, so it agrees with getStagesCount().
        private final Map<StageId, AtomicLong> stagesReservedBytes = new HashMap<>();

        private TaskDescriptors() {
        }

        /**
         * Stores {@code descriptor} under {@code stageId}/{@code partitionId} and accounts for its size.
         *
         * @throws IllegalStateException if a descriptor is already stored under that key
         */
        public void put(StageId stageId, int partitionId, TaskDescriptor descriptor) {
            throwIfFailed();
            Preconditions.checkState(!descriptors.contains(stageId, partitionId), "task descriptor is already present for key %s/%s ", stageId, partitionId);
            descriptors.put(stageId, partitionId, descriptor);
            long retainedSizeInBytes = descriptor.getRetainedSizeInBytes();
            reservedBytes += retainedSizeInBytes;
            stagesReservedBytes.computeIfAbsent(stageId, ignored -> new AtomicLong()).addAndGet(retainedSizeInBytes);
        }

        /**
         * Returns the descriptor stored under {@code stageId}/{@code partitionId}.
         *
         * @throws NoSuchElementException if no descriptor is stored under that key
         */
        public TaskDescriptor get(StageId stageId, int partitionId) {
            throwIfFailed();
            TaskDescriptor descriptor = descriptors.get(stageId, partitionId);
            if (descriptor == null) {
                throw new NoSuchElementException(String.format("descriptor not found for key %s/%s", stageId, partitionId));
            }
            return descriptor;
        }

        /**
         * Removes the descriptor stored under {@code stageId}/{@code partitionId} and releases its size.
         *
         * @throws NoSuchElementException if no descriptor is stored under that key
         */
        public void remove(StageId stageId, int partitionId) {
            throwIfFailed();
            TaskDescriptor descriptor = descriptors.remove(stageId, partitionId);
            if (descriptor == null) {
                throw new NoSuchElementException(String.format("descriptor not found for key %s/%s", stageId, partitionId));
            }
            long retainedSizeInBytes = descriptor.getRetainedSizeInBytes();
            reservedBytes -= retainedSizeInBytes;
            AtomicLong stageReservedBytes = Objects.requireNonNull(stagesReservedBytes.get(stageId), () -> String.format("no entry for stage %s", stageId));
            stageReservedBytes.addAndGet(-retainedSizeInBytes);
            if (!descriptors.containsRow(stageId)) {
                // Last descriptor of the stage is gone: drop its entry so stage statistics cover
                // the same stages that getStagesCount() reports.
                stagesReservedBytes.remove(stageId);
            }
        }

        public long getReservedBytes() {
            return reservedBytes;
        }

        /**
         * Builds a human-readable summary used when logging the failure of this query: per-stage
         * descriptor statistics plus the three largest splits, labelled {@code stageId/planNodeId}.
         */
        private String getDebugInfo() {
            Multimap<StageId, TaskDescriptor> descriptorsByStage = descriptors.cellSet().stream()
                    .collect(ImmutableSetMultimap.toImmutableSetMultimap(cell -> cell.getRowKey(), cell -> cell.getValue()));

            Map<StageId, String> stagesInfo = descriptorsByStage.asMap().entrySet().stream()
                    .collect(ImmutableMap.toImmutableMap(entry -> entry.getKey(), entry -> getDebugInfo(entry.getValue())));

            List<String> biggestSplits = descriptorsByStage.entries().stream()
                    // Label each split with its stage AND plan node (the previous code printed the
                    // plan node twice because an inner lambda parameter shadowed the outer one).
                    .flatMap(stageEntry -> stageEntry.getValue().getSplits().getSplitsFlat().entries().stream()
                            .map(splitEntry -> Map.entry("%s/%s".formatted(stageEntry.getKey(), splitEntry.getKey()), splitEntry.getValue())))
                    .sorted(Comparator.comparingLong((Map.Entry<String, Split> entry) -> entry.getValue().getRetainedSizeInBytes()).reversed())
                    .limit(3)
                    .map(entry -> "{nodeId=%s, size=%s, split=%s}".formatted(entry.getKey(), entry.getValue().getRetainedSizeInBytes(), splitJsonCodec.toJson(entry.getValue())))
                    .toList();

            return "stagesInfo=%s; biggestSplits=%s".formatted(stagesInfo, biggestSplits);
        }

        /**
         * Summarizes one stage's descriptors: count, retained-size mean/stddev, and per plan node
         * the mean/stddev of split counts and split sizes.
         */
        private String getDebugInfo(Collection<TaskDescriptor> taskDescriptors) {
            Stats retainedSizeStats = Stats.of(taskDescriptors.stream().mapToLong(descriptor -> descriptor.getRetainedSizeInBytes()));
            Set<PlanNodeId> planNodeIds = taskDescriptors.stream()
                    .flatMap(descriptor -> descriptor.getSplits().getSplitsFlat().keySet().stream())
                    .collect(ImmutableSet.toImmutableSet());

            Map<PlanNodeId, String> splitsInfo = new HashMap<>();
            for (PlanNodeId planNodeId : planNodeIds) {
                // Multimap.get yields an empty collection for descriptors without splits for this
                // plan node; asMap().get would return null and throw NPE while logging.
                Stats splitCountStats = Stats.of(taskDescriptors.stream()
                        .mapToLong(descriptor -> descriptor.getSplits().getSplitsFlat().get(planNodeId).size()));
                Stats splitSizeStats = Stats.of(taskDescriptors.stream()
                        .flatMap(descriptor -> descriptor.getSplits().getSplitsFlat().get(planNodeId).stream())
                        .mapToLong(split -> split.getRetainedSizeInBytes()));
                splitsInfo.put(planNodeId, "{splitCountMean=%s, splitCountStdDev=%s, splitSizeMean=%s, splitSizeStdDev=%s}".formatted(
                        splitCountStats.mean(),
                        splitCountStats.populationStandardDeviation(),
                        splitSizeStats.mean(),
                        splitSizeStats.populationStandardDeviation()));
            }
            return "[taskDescriptorsCount=%s, taskDescriptorsRetainedSizeMean=%s, taskDescriptorsRetainedSizeStdDev=%s, splits=%s]".formatted(
                    taskDescriptors.size(),
                    retainedSizeStats.mean(),
                    retainedSizeStats.populationStandardDeviation(),
                    splitsInfo);
        }

        /**
         * Drops all descriptors and records {@code failure}; later accessor calls rethrow it.
         * Only the first failure is recorded.
         */
        private void fail(RuntimeException failure) {
            if (this.failure == null) {
                descriptors.clear();
                // Also reset per-stage accounting so failed queries stop contributing stale stage bytes.
                stagesReservedBytes.clear();
                reservedBytes = 0;
                this.failure = failure;
            }
        }

        private void throwIfFailed() {
            if (failure != null) {
                throw failure;
            }
        }

        /** Returns the number of stages that currently hold at least one descriptor. */
        public int getStagesCount() {
            return descriptors.rowMap().size();
        }

        /** Returns the reserved bytes of each tracked stage. */
        public Stream<Long> getStagesReservedBytes() {
            return stagesReservedBytes.values().stream().map(AtomicLong::get);
        }
    }

    /**
     * Injection constructor: takes the storage capacity from the fault-tolerant execution
     * task-descriptor storage setting of {@code queryManagerConfig}.
     */
    @Inject
    public TaskDescriptorStorage(QueryManagerConfig queryManagerConfig, JsonCodec<Split> jsonCodec) {
        this(
                queryManagerConfig.getFaultTolerantExecutionTaskDescriptorStorageMaxMemory(),
                jsonCodec);
    }

    public TaskDescriptorStorage(DataSize dataSize, JsonCodec<Split> jsonCodec) {
        this.storages = new HashMap();
        this.maxMemoryInBytes = dataSize.toBytes();
        this.splitJsonCodec = (JsonCodec) Objects.requireNonNull(jsonCodec, "splitJsonCodec is null");
        this.storageStats = new StorageStats(Suppliers.memoizeWithExpiration(this::computeStats, 1L, TimeUnit.SECONDS));
    }

    /**
     * Registers an empty descriptor container for {@code queryId}.
     *
     * @throws VerifyException if storage was already initialized for the query
     */
    public synchronized void initialize(QueryId queryId) {
        TaskDescriptors fresh = new TaskDescriptors();
        TaskDescriptors previous = storages.putIfAbsent(queryId, fresh);
        Verify.verify(previous == null, "storage is already initialized for query: %s", queryId);
        updateMemoryReservation(fresh.getReservedBytes());
    }

    /**
     * Stores {@code taskDescriptor} under its stage and partition id. The call is ignored when
     * no container is registered for the stage's query. The size delta is charged to the
     * storage, which may fail the largest query if the capacity is exceeded.
     */
    public synchronized void put(StageId stageId, TaskDescriptor taskDescriptor) {
        TaskDescriptors queryDescriptors = storages.get(stageId.getQueryId());
        if (queryDescriptors != null) {
            long bytesBefore = queryDescriptors.getReservedBytes();
            queryDescriptors.put(stageId, taskDescriptor.getPartitionId(), taskDescriptor);
            long bytesAfter = queryDescriptors.getReservedBytes();
            updateMemoryReservation(bytesAfter - bytesBefore);
        }
    }

    /**
     * Looks up a stored descriptor.
     *
     * @return empty if no container is registered for the stage's query; otherwise the descriptor
     * @throws NoSuchElementException if the query is registered but holds no such descriptor
     */
    public synchronized Optional<TaskDescriptor> get(StageId stageId, int i) {
        return Optional.ofNullable(storages.get(stageId.getQueryId()))
                .map(queryDescriptors -> queryDescriptors.get(stageId, i));
    }

    /**
     * Removes a stored descriptor and credits its size back to the storage. The call is ignored
     * when no container is registered for the stage's query.
     */
    public synchronized void remove(StageId stageId, int i) {
        TaskDescriptors queryDescriptors = storages.get(stageId.getQueryId());
        if (queryDescriptors == null) {
            // unknown or already destroyed query: nothing to release
            return;
        }
        long bytesBefore = queryDescriptors.getReservedBytes();
        queryDescriptors.remove(stageId, i);
        long bytesAfter = queryDescriptors.getReservedBytes();
        updateMemoryReservation(bytesAfter - bytesBefore);
    }

    /** Drops the container for {@code queryId}, if any, and releases all bytes it still reserves. */
    public synchronized void destroy(QueryId queryId) {
        Optional.ofNullable(storages.remove(queryId))
                .ifPresent(removed -> updateMemoryReservation(-removed.getReservedBytes()));
    }

    /**
     * Applies {@code j} to the total reserved bytes. When the reservation grew and the total
     * exceeds {@code maxMemoryInBytes}, repeatedly fails the query holding the most bytes until
     * the total fits again.
     *
     * @param j signed change in reserved bytes
     * @throws VerifyException if the total exceeds the limit but no query holds reclaimable bytes
     */
    private synchronized void updateMemoryReservation(long j) {
        reservedBytes += j;
        if (j <= 0) {
            // Releasing memory can never push the storage over its limit.
            return;
        }
        while (reservedBytes > maxMemoryInBytes) {
            QueryId queryId = storages.entrySet().stream()
                    .max(Comparator.comparingLong((Map.Entry<QueryId, TaskDescriptors> entry) -> entry.getValue().getReservedBytes()))
                    .map(Map.Entry::getKey)
                    .orElseThrow(() -> new VerifyException(String.format("storage is empty but reservedBytes (%s) is still greater than maxMemoryInBytes (%s)", reservedBytes, maxMemoryInBytes)));
            TaskDescriptors queryDescriptors = storages.get(queryId);
            long queryReservedBytes = queryDescriptors.getReservedBytes();
            // If even the largest query holds nothing, failing it reclaims nothing and this loop
            // would spin forever while holding the monitor; the accounting must be inconsistent.
            Verify.verify(queryReservedBytes > 0, "no query holds reclaimable memory but reservedBytes (%s) is still greater than maxMemoryInBytes (%s)", reservedBytes, maxMemoryInBytes);
            if (log.isInfoEnabled()) {
                log.info("Failing query %s; reclaiming %s of %s task descriptor memory from %s queries; extraStorageInfo=%s",
                        queryId,
                        queryReservedBytes,
                        DataSize.succinctBytes(reservedBytes),
                        storages.size(),
                        queryDescriptors.getDebugInfo());
            }
            // Report "reserved > limit" (the previous message had the operands swapped).
            queryDescriptors.fail(new TrinoException(
                    StandardErrorCode.EXCEEDED_TASK_DESCRIPTOR_STORAGE_CAPACITY,
                    String.format("Task descriptor storage capacity has been exceeded: %s > %s", DataSize.succinctBytes(reservedBytes), DataSize.succinctBytes(maxMemoryInBytes))));
            reservedBytes += queryDescriptors.getReservedBytes() - queryReservedBytes;
        }
    }

    /** Returns the current total of reserved bytes across all queries (test hook). */
    @VisibleForTesting
    synchronized long getReservedBytes() {
        final long total = reservedBytes;
        return total;
    }

    /** Exposes the memoized storage statistics as nested JMX attributes. */
    @Managed
    @Nested
    public StorageStats getStats() {
        final StorageStats stats = storageStats;
        return stats;
    }

    /**
     * Computes a snapshot of query- and stage-level reservation statistics. Percentiles are
     * truncated to {@code long}; averages use integer division. All values are 0 when no data
     * is available.
     */
    private synchronized StorageStatsValue computeStats() {
        int queriesCount = storages.size();
        long stagesCount = storages.values().stream().mapToLong(TaskDescriptors::getStagesCount).sum();
        Quantiles.ScaleAndIndexes percentiles = Quantiles.percentiles().indexes(50, 90, 95);

        long queryReservedBytesAvg = 0;
        long queryReservedBytesP50 = 0;
        long queryReservedBytesP90 = 0;
        long queryReservedBytesP95 = 0;
        long stageReservedBytesAvg = 0;
        long stageReservedBytesP50 = 0;
        long stageReservedBytesP90 = 0;
        long stageReservedBytesP95 = 0;

        if (queriesCount > 0) {
            List<Long> queriesReservedBytes = storages.values().stream()
                    .map(TaskDescriptors::getReservedBytes)
                    .collect(ImmutableList.toImmutableList());
            Map<Integer, Double> queryPercentiles = percentiles.compute(queriesReservedBytes);
            queryReservedBytesP50 = queryPercentiles.get(50).longValue();
            queryReservedBytesP90 = queryPercentiles.get(90).longValue();
            queryReservedBytesP95 = queryPercentiles.get(95).longValue();
            queryReservedBytesAvg = reservedBytes / queriesCount;

            List<Long> stagesReservedBytes = storages.values().stream()
                    .flatMap(TaskDescriptors::getStagesReservedBytes)
                    .collect(ImmutableList.toImmutableList());
            if (!stagesReservedBytes.isEmpty()) {
                Map<Integer, Double> stagePercentiles = percentiles.compute(stagesReservedBytes);
                stageReservedBytesP50 = stagePercentiles.get(50).longValue();
                stageReservedBytesP90 = stagePercentiles.get(90).longValue();
                stageReservedBytesP95 = stagePercentiles.get(95).longValue();
                // getStagesCount() only counts stages that still hold descriptors, while the
                // per-stage byte list can also contain drained stages, so the count may be 0 here.
                if (stagesCount > 0) {
                    stageReservedBytesAvg = reservedBytes / stagesCount;
                }
            }
        }

        return new StorageStatsValue(
                queriesCount,
                stagesCount,
                reservedBytes,
                queryReservedBytesAvg,
                queryReservedBytesP50,
                queryReservedBytesP90,
                queryReservedBytesP95,
                stageReservedBytesAvg,
                stageReservedBytesP50,
                stageReservedBytesP90,
                stageReservedBytesP95);
    }
}
