package org.apache.druid.server.coordinator.balancer;

import com.google.common.base.Stopwatch;
import com.google.common.collect.Lists;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.commons.math3.util.FastMath;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.java.util.emitter.EmittingLogger;
import org.apache.druid.server.coordinator.SegmentCountsPerInterval;
import org.apache.druid.server.coordinator.ServerHolder;
import org.apache.druid.server.coordinator.loading.SegmentAction;
import org.apache.druid.server.coordinator.stats.CoordinatorRunStats;
import org.apache.druid.server.coordinator.stats.Dimension;
import org.apache.druid.server.coordinator.stats.RowKey;
import org.apache.druid.server.coordinator.stats.Stats;
import org.apache.druid.timeline.DataSegment;
import org.joda.time.Interval;

/**
 * {@link BalancerStrategy} that ranks candidate servers by the "joint cost" of placing a segment
 * next to the segments each server is projected to hold.
 *
 * <p>The cost between two intervals X and Y is
 * {@code cost(X, Y) = \int_X \int_Y e^{-LAMBDA |x - y|} dx dy}, with x and y measured in hours and
 * {@code LAMBDA = ln(2) / 24}. The contribution of a pair of instants therefore halves for every
 * 24 hours separating them. Segments of the same datasource count twice.
 *
 * <p>Loads and moves prefer the cheapest servers, and drops walk the servers from most to least
 * expensive. Per-server costs are computed in parallel on the executor passed to the constructor.
 */
public class CostBalancerStrategy implements BalancerStrategy {
    private static final EmittingLogger log = new EmittingLogger(CostBalancerStrategy.class);

    /** Half-life of the cost decay, in hours. */
    private static final double HALF_LIFE = 24.0d;

    /** Decay rate per hour, chosen so that {@code e^(-LAMBDA * HALF_LIFE) == 1/2}. */
    static final double LAMBDA = Math.log(2.0d) / HALF_LIFE;

    /**
     * Scales a double integral computed over dimensionless times ({@code LAMBDA * hours}) back to
     * an integral over times measured in hours.
     */
    static final double INV_LAMBDA_SQUARE = 1.0d / (LAMBDA * LAMBDA);

    private static final double MILLIS_IN_HOUR = 3600000.0d;

    /** Dividing a duration in milliseconds by this factor yields {@code LAMBDA * hours}. */
    private static final double MILLIS_FACTOR = MILLIS_IN_HOUR / LAMBDA;

    /** Orders by ascending cost, breaking ties by the natural ordering of {@link ServerHolder}. */
    private static final Comparator<Pair<Double, ServerHolder>> CHEAPEST_SERVERS_FIRST =
        Comparator.<Pair<Double, ServerHolder>, Double>comparing(pair -> pair.lhs)
                  .thenComparing(pair -> pair.rhs);

    private final CoordinatorRunStats stats = new CoordinatorRunStats();

    /** Nanoseconds spent ordering servers since the last call to {@link #getStats()}. */
    private final AtomicLong computeTimeNanos = new AtomicLong(0);

    private final ListeningExecutorService exec;

    /**
     * @param exec executor on which the per-server placement costs are computed
     */
    public CostBalancerStrategy(ListeningExecutorService exec) {
        this.exec = exec;
    }

    /**
     * Computes the total joint cost between {@code segment} and every segment in {@code others}.
     * Segments whose interval does not overlap the cost-compute window of {@code segment} are
     * skipped. That window is the segment's interval widened by 45 days on each side, or the
     * interval itself for an eternity segment.
     *
     * @param segment the segment being placed
     * @param others  segments to compute the joint cost against
     * @return the sum of {@link #computeJointSegmentsCost(DataSegment, DataSegment)} over the
     *         overlapping segments
     */
    public static double computeJointSegmentsCost(DataSegment segment, Iterable<DataSegment> others) {
        final Interval costComputeInterval = getCostComputeInterval(segment);
        double totalCost = 0.0d;
        for (DataSegment other : others) {
            if (costComputeInterval.overlaps(other.getInterval())) {
                totalCost += computeJointSegmentsCost(segment, other);
            }
        }
        return totalCost;
    }

    /**
     * Computes the joint cost of two segments. This is the cost of their intervals, doubled when
     * both segments belong to the same datasource.
     *
     * @param segmentA first segment
     * @param segmentB second segment
     * @return the joint cost of the two segments
     */
    public static double computeJointSegmentsCost(DataSegment segmentA, DataSegment segmentB) {
        final double multiplier = segmentA.getDataSource().equals(segmentB.getDataSource()) ? 2.0d : 1.0d;
        return intervalCost(segmentA.getInterval(), segmentB.getInterval()) * multiplier;
    }

    /**
     * Computes {@code \int_A \int_B e^{-LAMBDA |x - y|} dx dy} for intervals A and B, with x and y
     * measured in hours.
     *
     * <p>Both intervals are expressed relative to the start of A, in units of
     * {@code LAMBDA * hours}, before delegating to {@link #intervalCost(double, double, double)}.
     *
     * @param intervalA first interval
     * @param intervalB second interval
     * @return the joint cost of the two intervals
     */
    public static double intervalCost(Interval intervalA, Interval intervalB) {
        final double t0 = intervalA.getStartMillis();
        final double x1 = (intervalA.getEndMillis() - t0) / MILLIS_FACTOR;
        final double y0 = (intervalB.getStartMillis() - t0) / MILLIS_FACTOR;
        final double y1 = (intervalB.getEndMillis() - t0) / MILLIS_FACTOR;
        return INV_LAMBDA_SQUARE * intervalCost(x1, y0, y1);
    }

    /**
     * Computes {@code cost(X, Y) = \int_0^{x1} \int_{y0}^{y1} e^{-|x - y|} dy dx} for
     * {@code X = [0, x1)} and {@code Y = [y0, y1)}. All bounds are in dimensionless units.
     *
     * @param x1 end of X (X always starts at 0)
     * @param y0 start of Y, relative to the start of X
     * @param y1 end of Y, relative to the start of X
     * @return the joint cost, or 0 if either interval is empty
     */
    public static double intervalCost(double x1, double y0, double y1) {
        if (x1 == 0.0d || y1 == y0) {
            return 0.0d;
        }

        // cost(X, Y) == cost(Y, X). If Y starts before X, swap them so that 0 <= y0 below.
        // The old x1 must be kept, because the new y1 is computed from it after x1 is overwritten.
        if (y0 < 0.0d) {
            final double oldX1 = x1;
            x1 = y1 - y0;
            y1 = oldX1 - y0;
            y0 = -y0;
        }

        if (y0 >= x1) {
            // No overlap: x <= y for all x in X and y in Y, so |x - y| = y - x and
            //   cost = (e^{-y1} - e^{-y0}) - (e^{x1 - y1} - e^{x1 - y0}).
            // The differences x1 - y0 and x1 - y1 stay in the exponent instead of factoring out
            // (e^{x1} - 1), which could overflow for large x1.
            final double expX1MinusY0 = FastMath.exp(x1 - y0);
            final double expX1MinusY1 = FastMath.exp(x1 - y1);
            return (FastMath.exp(-y1) - FastMath.exp(-y0)) - (expX1MinusY1 - expX1MinusY0);
        }

        // Overlap (0 <= y0 < x1). Split the intervals into A = [0, y0), the common part B, and the
        // remainder C of whichever interval ends later:
        //   X = [ A )[ B )[ C )    or    X = [ A )[ B )
        //   Y =      [ B )               Y =      [ B )[ C )
        // cost(X, Y) = cost(A, Y) + cost(B, C) + cost(B, B).
        // The first two terms are non-overlapping cases, and
        //   cost(B, B) = \int_0^beta \int_0^beta e^{-|x - y|} = 2 * (beta + e^{-beta} - 1).
        final double beta;   // length of B
        final double gamma;  // end of C, relative to the start of B
        if (y1 <= x1) {
            beta = y1 - y0;
            gamma = x1 - y0;
        } else {
            beta = x1 - y0;
            gamma = y1 - y0;
        }
        return intervalCost(y0, y0, y1)                    // cost(A, Y)
            + intervalCost(beta, beta, gamma)              // cost(B, C)
            + 2.0d * (beta + FastMath.exp(-beta) - 1.0d);  // cost(B, B)
    }

    /**
     * Returns the servers that can load {@code segment}, cheapest placement first.
     */
    @Override
    public Iterator<ServerHolder> findServersToLoadSegment(DataSegment segment, List<ServerHolder> serverHolders) {
        return orderServersByPlacementCost(segment, serverHolders, SegmentAction.LOAD)
            .stream()
            .filter(server -> server.canLoadSegment(segment))
            .iterator();
    }

    /**
     * Returns the cheapest destination for {@code segment}. Returns {@code null} if there are no
     * candidates, or if the cheapest candidate is {@code sourceServer} itself.
     */
    @Override
    public ServerHolder findDestinationServerToMoveSegment(
        DataSegment segment,
        ServerHolder sourceServer,
        List<ServerHolder> serverHolders
    ) {
        final List<ServerHolder> ordered = orderServersByPlacementCost(segment, serverHolders, SegmentAction.MOVE_TO);
        if (ordered.isEmpty()) {
            return null;
        }
        final ServerHolder cheapest = ordered.get(0);
        return cheapest.equals(sourceServer) ? null : cheapest;
    }

    /**
     * Returns the given servers ordered by decreasing placement cost of {@code segment}, most
     * expensive first.
     */
    @Override
    public Iterator<ServerHolder> findServersToDropSegment(DataSegment segment, List<ServerHolder> serverHolders) {
        return Lists.reverse(orderServersByPlacementCost(segment, serverHolders, SegmentAction.DROP)).iterator();
    }

    /**
     * Returns the accumulated run stats. Before returning, this adds the computation time (in
     * milliseconds) accumulated since the previous call and resets that accumulator.
     */
    @Override
    public CoordinatorRunStats getStats() {
        stats.add(Stats.Balancer.COMPUTATION_TIME, TimeUnit.NANOSECONDS.toMillis(computeTimeNanos.getAndSet(0L)));
        return stats;
    }

    /**
     * Computes the cost of placing {@code segment} on {@code server}.
     *
     * <p>The cost is the sum of interval costs between the segment and every segment projected on
     * the server whose interval overlaps the segment's cost-compute window. Segments of the same
     * datasource are counted twice. If {@code segment} is itself projected on the server, its
     * self-cost is subtracted.
     *
     * @param segment segment being placed
     * @param server  candidate server
     * @return the placement cost
     */
    public double computePlacementCost(DataSegment segment, ServerHolder server) {
        final Interval costComputeInterval = getCostComputeInterval(segment);

        // Number of projected segments per interval within the window. Segments of the same
        // datasource are added once from the totals and once more below, giving them twice the cost.
        final Object2IntOpenHashMap<Interval> intervalToSegmentCount = new Object2IntOpenHashMap<>();
        final SegmentCountsPerInterval projectedSegments = server.getProjectedSegments();
        projectedSegments.getIntervalToTotalSegmentCount().object2IntEntrySet().forEach(entry -> {
            final Interval interval = entry.getKey();
            if (costComputeInterval.overlaps(interval)) {
                intervalToSegmentCount.addTo(interval, entry.getIntValue());
            }
        });
        projectedSegments.getIntervalToSegmentCount(segment.getDataSource()).object2IntEntrySet().forEach(entry -> {
            final Interval interval = entry.getKey();
            if (costComputeInterval.overlaps(interval)) {
                intervalToSegmentCount.addTo(interval, entry.getIntValue());
            }
        });

        final Interval segmentInterval = segment.getInterval();
        double cost = intervalToSegmentCount.object2IntEntrySet().stream()
            .mapToDouble(entry -> intervalCost(segmentInterval, entry.getKey()) * entry.getIntValue())
            .sum();

        // Remove the segment's cost against itself (same interval, same datasource => doubled).
        if (server.isProjectedSegment(segment)) {
            cost -= intervalCost(segmentInterval, segmentInterval) * 2.0d;
        }
        return cost;
    }

    /**
     * Computes the placement cost of {@code segment} on every server in parallel. Returns the
     * servers sorted by {@link #CHEAPEST_SERVERS_FIRST}.
     *
     * <p>If the computation fails or does not complete within one minute, the error is counted and
     * logged and an empty list is returned. Computation count and time are recorded either way.
     */
    private List<ServerHolder> orderServersByPlacementCost(
        DataSegment segment,
        List<ServerHolder> serverHolders,
        SegmentAction action
    ) {
        final Stopwatch computeTime = Stopwatch.createStarted();
        final List<ListenableFuture<Pair<Double, ServerHolder>>> futures = new ArrayList<>(serverHolders.size());
        for (ServerHolder server : serverHolders) {
            futures.add(exec.submit(() -> Pair.of(computePlacementCost(segment, server), server)));
        }

        final String tier = serverHolders.isEmpty() ? null : serverHolders.get(0).getServer().getTier();
        final RowKey metricKey = RowKey.with(Dimension.TIER, tier)
                                       .with(Dimension.DATASOURCE, segment.getDataSource())
                                       .and(Dimension.DESCRIPTION, action.name());

        final List<Pair<Double, ServerHolder>> costs = new ArrayList<>(serverHolders.size());
        try {
            // Bounded wait, so a stuck computation cannot block the caller indefinitely.
            costs.addAll(Futures.allAsList(futures).get(1L, TimeUnit.MINUTES));
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Preserve the interrupt for the owning thread.
                Thread.currentThread().interrupt();
            }
            // Cancel computations that have not finished so they do not keep occupying the executor.
            for (ListenableFuture<Pair<Double, ServerHolder>> future : futures) {
                future.cancel(true);
            }
            stats.add(Stats.Balancer.COMPUTATION_ERRORS, metricKey, 1L);
            handleFailure(e, segment, action);
        }

        // Sort explicitly: callers rely on the full order (drops walk it in reverse). A
        // PriorityQueue's iterator/stream order is unspecified beyond the head element.
        costs.sort(CHEAPEST_SERVERS_FIRST);

        computeTime.stop();
        stats.add(Stats.Balancer.COMPUTATION_COUNT, 1L);
        computeTimeNanos.addAndGet(computeTime.elapsed(TimeUnit.NANOSECONDS));

        return costs.stream().map(pair -> pair.rhs).collect(Collectors.toList());
    }

    /**
     * Logs, without a stack trace, why the cost computation for {@code segment} failed. Adds a
     * tuning suggestion when the failure was a timeout.
     */
    private void handleFailure(Exception e, DataSegment segment, SegmentAction action) {
        final String reason;
        String suggestion = "";
        if (exec.isShutdown()) {
            reason = "Executor shutdown";
        } else if (e instanceof TimeoutException) {
            reason = "Timed out";
            suggestion = " Try setting a higher value for 'balancerComputeThreads'.";
        } else {
            reason = e.getMessage();
        }
        log.noStackTrace().warn(
            e,
            "Cost strategy computations failed for action[%s] on segment[%s] due to reason[%s].[%s]",
            action,
            segment.getId(),
            reason,
            suggestion
        );
    }

    /**
     * Returns the window of intervals considered when computing costs for {@code segment}. This is
     * the segment's interval widened by 45 days on each side, or the interval itself if it is
     * eternity. With a 24-hour half-life, 45 days is 45 half-lives (decay factor 2^-45). This
     * presumably makes contributions from farther away negligible.
     */
    private static Interval getCostComputeInterval(DataSegment segment) {
        final Interval interval = segment.getInterval();
        if (Intervals.isEternity(interval)) {
            return interval;
        }
        final long maxGapMillis = TimeUnit.DAYS.toMillis(45L);
        return Intervals.utc(interval.getStartMillis() - maxGapMillis, interval.getEndMillis() + maxGapMillis);
    }
}
