package io.trino.gateway.ha.router;

import com.google.common.base.Strings;
import io.trino.gateway.ha.config.ProxyBackendConfiguration;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:io/trino/gateway/ha/router/TrinoQueueLengthRoutingTable.class */
/**
 * Routing manager that distributes queries among the active backends of a routing group, with
 * probability weighted by queue length: backends with shorter queues receive larger weights.
 *
 * <p>Queue lengths (plus optional running-query counts and per-user counts) are pushed in through
 * {@link #updateRoutingTable(Map, Map, Map)}. For every routing group a cumulative-weight table
 * ({@code cumulative weight -> backend name}) is rebuilt. A backend is chosen by drawing a random
 * integer in {@code [0, total weight)} and taking the first table entry whose key is greater than
 * it. If per-user counts exist for the requesting user and differ between the group's backends,
 * the backend with the smallest count for that user is preferred instead.
 *
 * <p>Thread-safety: all reads and writes of the routing state happen while holding
 * {@code lockObject}.
 */
public class TrinoQueueLengthRoutingTable extends HaRoutingManager {
    private static final Logger log = LoggerFactory.getLogger(TrinoQueueLengthRoutingTable.class);
    private static final Random RANDOM = new Random();
    // Smallest weight a backend can receive. Keeping every weight >= 1 guarantees that cumulative
    // keys in a weighted table are strictly increasing, so no backend entry gets overwritten.
    private static final int MIN_WT = 1;
    // Largest weight a backend can receive; given to an idle backend and to a group's sole backend.
    private static final int MAX_WT = 100;

    private final Object lockObject = new Object();
    // routing group -> total weight of its table (equal to the table's largest cumulative key).
    private final ConcurrentHashMap<String, Integer> routingGroupWeightSum = new ConcurrentHashMap<>();
    // routing group -> (backend name -> queue length, or running count when falling back).
    private final ConcurrentHashMap<String, ConcurrentHashMap<String, Integer>> clusterQueueLengthMap =
            new ConcurrentHashMap<>();
    // user -> (backend name -> per-user count; presumably that user's queued queries — confirm).
    private final ConcurrentHashMap<String, ConcurrentHashMap<String, Integer>> userClusterQueueLengthMap =
            new ConcurrentHashMap<>();
    // routing group -> (cumulative weight -> backend name). Plain HashMap, so guarded by lockObject.
    private final Map<String, TreeMap<Integer, String>> weightedDistributionRouting = new HashMap<>();

    public TrinoQueueLengthRoutingTable(GatewayBackendManager gatewayBackendManager, QueryHistoryManager queryHistoryManager) {
        super(gatewayBackendManager, queryHistoryManager);
    }

    /**
     * Weight of a backend relative to the longest queue of its group: {@code MAX_WT} for an empty
     * queue, shrinking linearly (integer arithmetic) as the queue approaches {@code maxQueueLength}.
     *
     * @param queueLength the backend's queue length, expected to be below {@code maxQueueLength}
     * @param maxQueueLength the longest queue in the group; must be positive
     */
    private static int weightRelativeToMax(int queueLength, int maxQueueLength) {
        // long arithmetic so very long queues cannot overflow queueLength * MAX_WT
        return MAX_WT - (int) ((long) queueLength * MAX_WT / maxQueueLength);
    }

    /**
     * Computes the weight of the backend(s) with the longest queue in a routing group.
     *
     * <p>Other backends are weighted relative to the longest queue, so that formula would give the
     * busiest backend 0. Its weight is instead derived from the runner-up's weight, scaled by how
     * far the smallest queue is from the largest.
     *
     * @param sortedByQueueLength backend name -> queue length, ordered ascending by queue length;
     *     must contain at least two entries
     * @return a weight in {@code [MIN_WT, MAX_WT]}
     */
    private int getWeightForMaxQueueCluster(LinkedHashMap<String, Integer> sortedByQueueLength) {
        int[] queueLengths = sortedByQueueLength.values().stream().mapToInt(Integer::intValue).toArray();
        long total = 0;
        for (int queueLength : queueLengths) {
            total += queueLength;
        }
        long equalDistribution = total / queueLengths.length;

        int smallest = queueLengths[0];
        int largest = queueLengths[queueLengths.length - 1];
        int lastButOne = queueLengths.length > 2 ? queueLengths[queueLengths.length - 2] : smallest;

        if (largest == 0) {
            // every backend is idle
            return MAX_WT;
        }
        if (lastButOne == 0 || lastButOne == largest) {
            return MIN_WT;
        }

        int lastButOneWeight = weightRelativeToMax(lastButOne, largest);
        // Floating-point division: with int division this is 0 whenever smallest < largest,
        // which collapsed the weight to 0.
        double fractionOfLargest = (double) smallest / largest;
        int weight = (int) Math.ceil(fractionOfLargest * lastButOneWeight);
        if (lastButOne < equalDistribution
                || (lastButOne > equalDistribution && smallest <= equalDistribution)) {
            weight = smallest == 0
                    ? MIN_WT
                    : (int) Math.ceil(fractionOfLargest * fractionOfLargest * lastButOneWeight);
        }
        // A zero weight would give this backend the same cumulative key as the previous backend
        // and overwrite it in the weighted table.
        return Math.max(MIN_WT, weight);
    }

    /**
     * Builds the cumulative-weight table for a group with at least two backends.
     *
     * @param queueLengths backend name -> queue length for one routing group (size >= 2)
     * @return cumulative weight -> backend name, ordered from shortest to longest queue; the
     *     largest key is the group's total weight
     */
    private TreeMap<Integer, String> buildWeightedTable(Map<String, Integer> queueLengths) {
        LinkedHashMap<String, Integer> sorted = queueLengths.entrySet().stream()
                .sorted(Map.Entry.comparingByValue())
                .collect(Collectors.toMap(
                        Map.Entry::getKey, Map.Entry::getValue, (first, second) -> first, LinkedHashMap::new));
        int maxQueueLength = Collections.max(sorted.values());
        int weightForMaxQueueCluster = getWeightForMaxQueueCluster(sorted);

        TreeMap<Integer, String> table = new TreeMap<>();
        int cumulativeWeight = 0;
        for (Map.Entry<String, Integer> backend : sorted.entrySet()) {
            int queueLength = backend.getValue();
            // Backends tied with the busiest one share its special weight; this also avoids
            // dividing by a zero maxQueueLength.
            cumulativeWeight += queueLength == maxQueueLength
                    ? weightForMaxQueueCluster
                    : weightRelativeToMax(queueLength, maxQueueLength);
            table.put(cumulativeWeight, backend.getKey());
        }
        return table;
    }

    /**
     * Rebuilds {@code weightedDistributionRouting} and {@code routingGroupWeightSum} from the
     * given per-group queue lengths. Groups with no backends get no table.
     *
     * @param queueLengthsByGroup routing group -> (backend name -> queue length)
     */
    private void computeWeightsBasedOnQueueLength(final ConcurrentHashMap<String, ConcurrentHashMap<String, Integer>> queueLengthsByGroup) {
        synchronized (lockObject) {
            weightedDistributionRouting.clear();
            routingGroupWeightSum.clear();
            log.debug("Computing Weights for Queue Map :[{}] ", queueLengthsByGroup);
            for (Map.Entry<String, ConcurrentHashMap<String, Integer>> group : queueLengthsByGroup.entrySet()) {
                String routingGroup = group.getKey();
                ConcurrentHashMap<String, Integer> queueLengths = group.getValue();
                if (queueLengths.isEmpty()) {
                    log.warn("No active clusters in routingGroup : [{}]. Continue to process rest of routing table ", routingGroup);
                } else if (queueLengths.size() == 1) {
                    log.debug("Routing Group: [{}] has only 1 active backend.", routingGroup);
                    TreeMap<Integer, String> single = new TreeMap<>();
                    single.put(MAX_WT, queueLengths.keySet().iterator().next());
                    weightedDistributionRouting.put(routingGroup, single);
                    routingGroupWeightSum.put(routingGroup, MAX_WT);
                } else {
                    TreeMap<Integer, String> table = buildWeightedTable(queueLengths);
                    weightedDistributionRouting.put(routingGroup, table);
                    routingGroupWeightSum.put(routingGroup, table.lastKey());
                }
            }
            if (log.isDebugEnabled()) {
                weightedDistributionRouting.forEach(
                        (routingGroup, table) -> log.debug("Routing Table for : [{}] is [{}]", routingGroup, table));
            }
        }
    }

    /**
     * Drops backends that are no longer active from a group's queue-length map and rebuilds the
     * weighted tables.
     *
     * <p>Returns without rebuilding when every backend already known for the group is still
     * active. When the group has no queue-length entry yet, the tables are rebuilt unchanged.
     *
     * @param routingGroup the routing group being routed
     * @param activeBackends names of the group's currently active backends
     */
    public void updateRoutingTable(String routingGroup, Set<String> activeBackends) {
        synchronized (lockObject) {
            ConcurrentHashMap<String, Integer> knownBackends = clusterQueueLengthMap.get(routingGroup);
            if (knownBackends != null) {
                log.debug("Update routing table for routing group : [{}] with active backends : [{}]", routingGroup, activeBackends);
                if (activeBackends.containsAll(knownBackends.keySet())) {
                    return;
                }
                // Remove every inactive backend. The old removeAll()-based guard skipped this
                // entirely when none of the known backends was active any more.
                knownBackends.keySet().retainAll(activeBackends);
            }
            computeWeightsBasedOnQueueLength(clusterQueueLengthMap);
        }
    }

    /**
     * Replaces all routing statistics and rebuilds the weighted tables.
     *
     * <p>When every backend of a group (with more than one backend) reports the same queue
     * length, that group's running counts are used instead, if present.
     *
     * @param queueLengths routing group -> (backend name -> queue length)
     * @param runningCounts routing group -> (backend name -> running count); may be null
     * @param userQueueLengths user -> (backend name -> per-user count); may be null
     */
    public void updateRoutingTable(Map<String, Map<String, Integer>> queueLengths, Map<String, Map<String, Integer>> runningCounts, Map<String, Map<String, Integer>> userQueueLengths) {
        synchronized (lockObject) {
            log.debug("Update Routing table with new cluster queue lengths : [{}]", queueLengths);
            clusterQueueLengthMap.clear();
            userClusterQueueLengthMap.clear();
            if (userQueueLengths != null) {
                userQueueLengths.forEach(
                        (user, counts) -> userClusterQueueLengthMap.put(user, new ConcurrentHashMap<>(counts)));
            }
            for (Map.Entry<String, Map<String, Integer>> group : queueLengths.entrySet()) {
                String routingGroup = group.getKey();
                if (routingGroup == null) {
                    continue;
                }
                Map<String, Integer> groupQueueLengths = group.getValue();
                ConcurrentHashMap<String, Integer> effective = new ConcurrentHashMap<>();
                // An empty group is kept as an empty map; Collections.max would throw on it and
                // leave the tables half-rebuilt.
                if (groupQueueLengths != null && !groupQueueLengths.isEmpty()) {
                    int maxQueueLength = Collections.max(groupQueueLengths.values());
                    int minQueueLength = Collections.min(groupQueueLengths.values());
                    if (minQueueLength == maxQueueLength
                            && groupQueueLengths.size() > 1
                            && runningCounts != null
                            && runningCounts.containsKey(routingGroup)) {
                        log.info("Queue lengths equal: {} for all clusters in the group {}. Falling back to Running Counts : {}",
                                maxQueueLength, routingGroup, runningCounts.get(routingGroup));
                        effective.putAll(runningCounts.get(routingGroup));
                    } else {
                        effective.putAll(groupQueueLengths);
                    }
                }
                clusterQueueLengthMap.put(routingGroup, effective);
            }
            computeWeightsBasedOnQueueLength(clusterQueueLengthMap);
        }
    }

    /**
     * Returns a snapshot of a group's weighted table as backend name -> cumulative weight, or
     * {@code null} if the group has no table.
     */
    public Map<String, Integer> getInternalWeightedRoutingTable(String routingGroup) {
        synchronized (lockObject) {
            TreeMap<Integer, String> table = weightedDistributionRouting.get(routingGroup);
            if (table == null) {
                return null;
            }
            Map<String, Integer> weightByBackend = new HashMap<>();
            table.forEach((cumulativeWeight, backend) -> weightByBackend.put(backend, cumulativeWeight));
            return weightByBackend;
        }
    }

    /**
     * Returns the live (not copied) backend name -> queue length map of a group, or {@code null}
     * if the group is unknown.
     */
    public Map<String, Integer> getInternalClusterQueueLength(String routingGroup) {
        return clusterQueueLengthMap.get(routingGroup);
    }

    /**
     * Chooses a backend name for a query.
     *
     * <p>If the user has per-user counts and those counts differ across the group's backends, the
     * backend with the smallest count is returned. Otherwise a backend is drawn from the group's
     * weighted table.
     *
     * @param routingGroup the routing group
     * @param user the requesting user; may be null or empty
     * @return the chosen backend name, or {@code null} if the group has no weighted table
     */
    public String getEligibleBackEnd(String routingGroup, String user) {
        synchronized (lockObject) {
            if (!Strings.isNullOrEmpty(user)) {
                ConcurrentHashMap<String, Integer> userCounts = userClusterQueueLengthMap.get(user);
                ConcurrentHashMap<String, Integer> groupBackends = clusterQueueLengthMap.get(routingGroup);
                // groupBackends may be absent before the first stats update; previously this NPE'd.
                if (userCounts != null && !userCounts.isEmpty() && groupBackends != null) {
                    String leastLoaded = null;
                    int minCount = Integer.MAX_VALUE;
                    int maxCount = Integer.MIN_VALUE;
                    for (String backend : groupBackends.keySet()) {
                        int count = userCounts.getOrDefault(backend, 0);
                        if (count < minCount) {
                            leastLoaded = backend;
                            minCount = count;
                        }
                        if (count > maxCount) {
                            maxCount = count;
                        }
                    }
                    // Only prefer a backend when the user's load actually differs between backends
                    // (compared as ints, not boxed Integer references).
                    if (!Strings.isNullOrEmpty(leastLoaded) && minCount != maxCount) {
                        log.debug("{} routing to:{}. userQueueCount:{}", user, leastLoaded, minCount);
                        return leastLoaded;
                    }
                }
            }
            TreeMap<Integer, String> table = weightedDistributionRouting.get(routingGroup);
            Integer totalWeight = routingGroupWeightSum.get(routingGroup);
            if (table == null || totalWeight == null) {
                return null;
            }
            // Keys are strictly increasing and the last key equals totalWeight, so any draw in
            // [0, totalWeight) has a strictly higher entry.
            return table.higherEntry(RANDOM.nextInt(totalWeight)).getValue();
        }
    }

    @Override
    public String provideBackendForRoutingGroup(String routingGroup, String user) {
        List<ProxyBackendConfiguration> activeBackends = getGatewayBackendManager().getActiveBackends(routingGroup);
        if (activeBackends.isEmpty()) {
            return provideAdhocBackend(user);
        }
        return selectProxyTo(routingGroup, user, activeBackends);
    }

    @Override
    public String provideAdhocBackend(String user) {
        List<ProxyBackendConfiguration> activeAdhocBackends = getGatewayBackendManager().getActiveAdhocBackends();
        if (activeAdhocBackends.isEmpty()) {
            throw new IllegalStateException("Number of active backends found zero");
        }
        return selectProxyTo("adhoc", user, activeAdhocBackends);
    }

    /**
     * Refreshes the group's table against {@code candidates}, asks
     * {@link #getEligibleBackEnd(String, String)} for a backend, and returns its proxy URL.
     * Falls back to a uniformly random candidate when no eligible backend among the candidates is
     * found.
     *
     * @param candidates active backends of the group; must not be empty
     */
    private String selectProxyTo(String routingGroup, String user, List<ProxyBackendConfiguration> candidates) {
        Map<String, String> proxyToByName = new HashMap<>();
        for (ProxyBackendConfiguration backend : candidates) {
            proxyToByName.put(backend.getName(), backend.getProxyTo());
        }
        updateRoutingTable(routingGroup, proxyToByName.keySet());
        String eligibleBackEnd = getEligibleBackEnd(routingGroup, user);
        log.debug("Routing to eligible backend : [{}] for routing group: [{}]", eligibleBackEnd, routingGroup);
        if (eligibleBackEnd != null) {
            String proxyTo = proxyToByName.get(eligibleBackEnd);
            // The routing state can name a backend that is not among the current candidates;
            // fall back to random rather than returning a null URL.
            if (proxyTo != null) {
                return proxyTo;
            }
        }
        log.debug("Falling back to random distribution");
        // nextInt(bound) is always in [0, bound); Math.abs(nextInt()) % n is negative for MIN_VALUE.
        return candidates.get(RANDOM.nextInt(candidates.size())).getProxyTo();
    }
}
