package smile.clustering;

import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.clustering.Clustering;
import smile.math.MathEx;
import smile.math.distance.Distance;
import smile.util.AlgoStatus;
import smile.util.IterativeAlgorithmController;

/**
 * K-Medoids clustering via a CLARANS-style randomized search.
 * <p>
 * Every cluster center (medoid) is one of the input samples, so only a
 * {@link Distance} between samples is needed. Each local search starts from
 * {@code CentroidClustering.init} and repeatedly replaces a randomly chosen
 * medoid with a randomly chosen non-medoid sample. It keeps the swap whenever
 * it lowers the distortion. A local search ends after {@code maxNeighbor}
 * consecutive swaps fail to improve it. The best result over all local
 * searches is returned.
 *
 * @param <T> unused by the class itself; the static methods declare their own sample type.
 */
public class KMedoids<T> {
    private static final Logger logger = LoggerFactory.getLogger(KMedoids.class);

    /** Static utility class; not instantiable. */
    private KMedoids() {
    }

    /**
     * Fits K-Medoids with default options. Equivalent to
     * {@code fit(data, distance, new Clustering.Options(k, 2, 0.0125, null))}.
     *
     * @param data the samples to cluster.
     * @param distance the distance between samples.
     * @param k the number of clusters.
     * @param <T> the sample type.
     * @return the clustering with the lowest distortion found.
     * @throws IllegalArgumentException see {@link #fit(Object[], Distance, Clustering.Options)}.
     */
    public static <T> CentroidClustering<T, T> fit(T[] data, Distance<T> distance, int k) {
        return fit(data, distance, new Clustering.Options(k, 2, 0.0125d, null));
    }

    /**
     * Fits K-Medoids.
     *
     * @param data the samples to cluster.
     * @param distance the distance between samples.
     * @param options the hyperparameters:
     *        <ul>
     *        <li>{@code k()} is the number of clusters.</li>
     *        <li>{@code min(3, maxIter())} is the number of local searches.</li>
     *        <li>{@code tol()} scales the per-search budget of consecutive unsuccessful swaps.</li>
     *        <li>{@code controller()}, if non-null, receives progress after each local search and may interrupt.</li>
     *        </ul>
     * @param <T> the sample type.
     * @return the clustering with the lowest distortion over all local searches,
     *         or {@code null} if no local search ran ({@code maxIter() <= 0}).
     * @throws IllegalArgumentException if {@code k >= data.length}, or if the
     *         derived neighbor budget exceeds {@code data.length}.
     */
    public static <T> CentroidClustering<T, T> fit(T[] data, Distance<T> distance, Clustering.Options options) {
        int n = data.length;
        int k = options.k();
        if (k >= n) {
            throw new IllegalArgumentException("Too large k: " + k);
        }

        IterativeAlgorithmController<AlgoStatus> controller = options.controller();
        int numLocal = Math.min(3, options.maxIter());

        // Swapping any of the k medoids with any of the n - k other samples gives
        // k * (n - k) neighbors. Compute it in long: the int product can overflow
        // on large inputs and make min(100, ...) negative.
        long neighbors = (long) k * (n - k);
        long budget = Math.max(Math.min(100L, neighbors), Math.round(options.tol() * k * (n - k)));
        // NOTE(review): with default options this rejects small inputs, e.g. k = 2 on
        // 20 samples gives min(100, 36) = 36 > 20. The budget counts swap trials, not
        // samples, so bounding it by n looks suspicious. Kept as-is; confirm intent.
        if (budget > n) {
            throw new IllegalArgumentException("Too large maxNeighbor: " + budget);
        }
        int maxNeighbor = (int) budget;

        double bestDistortion = Double.MAX_VALUE;
        CentroidClustering<T, T> best = null;
        for (int iter = 1; iter <= numLocal; iter++) {
            // Accepted state of this local search; overwritten in place on improvement.
            var init = CentroidClustering.init("K-Medoids", data, k, distance);
            T[] medoids = init.centers();
            double distortion = init.distortion();
            int[] group = init.group();
            double[] proximity = init.proximity();

            // Scratch copies that a trial swap mutates before being accepted or discarded.
            T[] candidateMedoids = medoids.clone();
            int[] candidateGroup = new int[n];
            double[] candidateProximity = new double[n];

            int trial = 1;
            while (trial <= maxNeighbor) {
                System.arraycopy(medoids, 0, candidateMedoids, 0, k);
                System.arraycopy(group, 0, candidateGroup, 0, n);
                System.arraycopy(proximity, 0, candidateProximity, 0, n);

                double candidateDistortion = randomSearch(data, candidateMedoids, candidateGroup, candidateProximity, distance);
                if (candidateDistortion < distortion) {
                    System.arraycopy(candidateMedoids, 0, medoids, 0, k);
                    System.arraycopy(candidateGroup, 0, group, 0, n);
                    System.arraycopy(candidateProximity, 0, proximity, 0, n);
                    distortion = candidateDistortion;
                    logger.info("Iteration {}: random search = {}, distortion = {} ", iter, trial, distortion);
                    // The budget counts consecutive failures, so an improvement restarts it.
                    trial = 1;
                } else {
                    trial++;
                }
            }

            if (distortion < bestDistortion) {
                bestDistortion = distortion;
                best = new CentroidClustering<>("K-Medoids", medoids, distance, group, proximity);
            }

            if (controller != null) {
                controller.submit(new AlgoStatus(iter, distortion));
                if (controller.isInterrupted()) {
                    break;
                }
            }
        }
        return best;
    }

    /**
     * Tries one random swap: replaces a random medoid with a random non-medoid
     * sample and incrementally updates the assignments.
     *
     * @param data the samples.
     * @param medoids the current medoids; the swapped slot is overwritten in place.
     * @param group the medoid index of each sample; updated in place.
     * @param proximity the squared distance of each sample to its assigned medoid; updated in place.
     * @param distance the distance between samples.
     * @param <T> the sample type.
     * @return the mean of {@code proximity} after the swap.
     */
    private static <T> double randomSearch(T[] data, T[] medoids, int[] group, double[] proximity, Distance<T> distance) {
        int n = data.length;
        int k = medoids.length;
        int swapped = MathEx.randomInt(k);
        // Drawn before the slot is overwritten, so it is not contained in the
        // current medoids, including the one being replaced.
        T medoid = getRandomMedoid(data, medoids);
        medoids[swapped] = medoid;

        // Each task writes only group[i] and proximity[i] for its own i, so the
        // parallel updates never touch the same element.
        IntStream.range(0, n).parallel().forEach(i -> {
            double dist = distance.applyAsDouble(data[i], medoid);
            double d = dist * dist;
            if (proximity[i] > d) {
                // The new medoid is closer than the currently assigned one.
                group[i] = swapped;
                proximity[i] = d;
            } else if (group[i] == swapped) {
                // This sample's medoid was replaced by a farther one: start from
                // the new medoid and move to the nearest of the others.
                proximity[i] = d;
                for (int j = 0; j < k; j++) {
                    if (j != swapped) {
                        double dj = distance.applyAsDouble(data[i], medoids[j]);
                        double d2 = dj * dj;
                        if (proximity[i] > d2) {
                            proximity[i] = d2;
                            group[i] = j;
                        }
                    }
                }
            }
        });
        return MathEx.mean(proximity);
    }

    /**
     * Draws random samples until one is found for which
     * {@code CentroidClustering.contains(sample, medoids)} is false.
     * <p>
     * NOTE(review): this never terminates if every sample is reported as
     * contained in {@code medoids}. An example is data holding only k distinct
     * entries. Confirm callers cannot reach that state.
     *
     * @param data the samples to draw from.
     * @param medoids the current medoids to avoid.
     * @param <T> the sample type.
     * @return a sample not contained in {@code medoids}.
     */
    private static <T> T getRandomMedoid(T[] data, T[] medoids) {
        int n = data.length;
        T candidate;
        do {
            candidate = data[MathEx.randomInt(n)];
        } while (CentroidClustering.contains(candidate, medoids));
        return candidate;
    }
}
