package io.trino.execution.scheduler;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.collect.Ordering;
import com.google.common.net.InetAddresses;
import com.google.inject.Inject;
import io.trino.spi.HostAddress;
import java.net.Inet4Address;
import java.net.Inet6Address;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.stream.Stream;

/**
 * {@link NetworkTopology} that locates a host within a hierarchy of IP subnets.
 *
 * <p>For every configured CIDR prefix length (validated to be strictly increasing), the host's
 * address is masked down to that prefix. The resulting network addresses, followed by the full host
 * address, become the segments of the returned {@link NetworkLocation}. A host without an address of
 * the configured {@link AddressProtocol}, or whose addresses cannot be resolved, is placed at
 * {@link NetworkLocation#ROOT_LOCATION}.
 */
public class SubnetBasedTopology implements NetworkTopology {
    // One mask per configured prefix length, in increasing prefix-length order.
    private final List<byte[]> subnetMasks;
    private final AddressProtocol protocol;

    /** IP protocol whose addresses are used when computing subnet locations. */
    public enum AddressProtocol {
        IPv4(Inet4Address.class, 32),
        IPv6(Inet6Address.class, 128);

        private final Class<? extends InetAddress> addressClass;
        private final int totalBitCount;

        AddressProtocol(Class<? extends InetAddress> addressClass, int totalBitCount) {
            this.addressClass = addressClass;
            this.totalBitCount = totalBitCount;
        }

        /** Returns the number of bits in an address of this protocol (32 or 128). */
        int getTotalBitCount() {
            return this.totalBitCount;
        }

        /**
         * Builds a network mask whose leading {@code prefixLength} bits are set.
         *
         * @param prefixLength number of leading one-bits; must satisfy {@code 0 < prefixLength < getTotalBitCount()}
         * @return mask of {@code getTotalBitCount() / 8} bytes
         * @throws IllegalArgumentException if {@code prefixLength} is out of range
         */
        byte[] computeSubnetMask(int prefixLength) {
            Preconditions.checkArgument(prefixLength > 0 && prefixLength < getTotalBitCount(), "Invalid length for subnet mask");
            byte[] mask = new byte[getTotalBitCount() / Byte.SIZE];
            int remainingBits = prefixLength;
            for (int i = 0; i < mask.length && remainingBits > 0; i++) {
                if (remainingBits >= Byte.SIZE) {
                    mask[i] = (byte) 0xFF;
                    remainingBits -= Byte.SIZE;
                } else {
                    // Partial byte: keep only the top `remainingBits` bits, e.g. 3 -> 0b1110_0000.
                    mask[i] = (byte) (0xFF << (Byte.SIZE - remainingBits));
                    remainingBits = 0;
                }
            }
            // Any bytes after the prefix stay zero (Java array default).
            return mask;
        }

        /**
         * Returns the first address in {@code addresses} that belongs to this protocol, or
         * {@code null} if there is none.
         */
        InetAddress getInetAddress(List<InetAddress> addresses) {
            return addresses.stream()
                    .filter(addressClass::isInstance)
                    .findFirst()
                    .orElse(null);
        }
    }

    @Inject
    public SubnetBasedTopology(SubnetTopologyConfig subnetTopologyConfig) {
        this(subnetTopologyConfig.getCidrPrefixLengths(), subnetTopologyConfig.getAddressProtocol());
    }

    /**
     * @param cidrPrefixLengths subnet prefix lengths, strictly increasing, each in {@code (0, totalBitCount)}
     * @param protocol address protocol used to pick the host address and size the masks
     * @throws NullPointerException if either argument is null
     * @throws IllegalArgumentException if the prefix lengths are not strictly increasing or out of range
     */
    public SubnetBasedTopology(List<Integer> cidrPrefixLengths, AddressProtocol protocol) {
        Objects.requireNonNull(cidrPrefixLengths, "cidrPrefixLengths is null");
        Objects.requireNonNull(protocol, "protocol is null");
        validateHierarchy(cidrPrefixLengths, protocol);
        this.protocol = protocol;
        this.subnetMasks = cidrPrefixLengths.stream()
                .map(protocol::computeSubnetMask)
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Returns the subnet hierarchy of {@code hostAddress}. The result has one segment per configured
     * subnet (network address in textual form), followed by the host's own address.
     */
    @Override // io.trino.execution.scheduler.NetworkTopology
    public NetworkLocation locate(HostAddress hostAddress) {
        try {
            InetAddress inetAddress = this.protocol.getInetAddress(hostAddress.getAllInetAddresses());
            if (inetAddress == null) {
                // No address of the configured protocol: nothing to place the host under.
                return NetworkLocation.ROOT_LOCATION;
            }
            byte[] addressBytes = inetAddress.getAddress();
            ImmutableList.Builder<String> segments = ImmutableList.builder();
            for (byte[] subnetMask : this.subnetMasks) {
                InetAddress network = InetAddress.getByAddress(applyMask(addressBytes, subnetMask));
                segments.add(InetAddresses.toAddrString(network));
            }
            segments.add(InetAddresses.toAddrString(inetAddress));
            return new NetworkLocation(segments.build());
        } catch (UnknownHostException ignored) {
            // Best effort: a host that cannot be resolved is placed at the root of the topology.
            return NetworkLocation.ROOT_LOCATION;
        }
    }

    /** Returns a new array holding the bitwise AND of {@code address} and {@code mask}, sized to the mask. */
    private static byte[] applyMask(byte[] address, byte[] mask) {
        byte[] masked = new byte[mask.length];
        for (int i = 0; i < mask.length; i++) {
            masked[i] = (byte) (address[i] & mask[i]);
        }
        return masked;
    }

    /**
     * Checks that the prefix lengths are strictly increasing. When the list is non-empty, also checks
     * that they lie strictly between 0 and the protocol's bit count.
     */
    private static void validateHierarchy(List<Integer> cidrPrefixLengths, AddressProtocol protocol) {
        if (!Ordering.natural().isStrictlyOrdered(cidrPrefixLengths)) {
            throw new IllegalArgumentException("Subnet hierarchy should be listed in the order of increasing prefix lengths");
        }
        if (cidrPrefixLengths.isEmpty()) {
            return;
        }
        // Strict ordering means checking the first and last elements bounds the whole list.
        if (cidrPrefixLengths.get(0) <= 0 || Iterables.getLast(cidrPrefixLengths) >= protocol.getTotalBitCount()) {
            throw new IllegalArgumentException("Subnet mask prefix lengths are invalid");
        }
    }
}
