package smile.clustering;

import java.util.Arrays;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.clustering.Clustering;
import smile.math.MathEx;
import smile.math.distance.EuclideanDistance;
import smile.util.AlgoStatus;
import smile.util.IterativeAlgorithmController;

/**
 * K-Means clustering.
 *
 * <p>{@code fit} runs iterations accelerated by a {@link BBDTree} with
 * Euclidean distance. {@code lloyd} runs plain Lloyd iterations (recompute
 * centroids, then reassign) and tolerates missing values encoded as
 * {@code NaN}. Both variants stop once an iteration decreases the distortion
 * by no more than {@code options.tol()}, once {@code options.maxIter()}
 * iterations have run, or when the optional controller reports an
 * interruption.
 */
public class KMeans {
    private static final Logger logger = LoggerFactory.getLogger(KMeans.class);

    /** Static utility class; not instantiable. */
    private KMeans() {
    }

    /**
     * Fits K-Means using a BBD tree built over {@code data}.
     *
     * @param data the observations, one per row.
     * @param k forwarded as the first argument of {@code Clustering.Options}
     *          (presumably the number of clusters).
     * @param maxIter forwarded as the second argument of {@code Clustering.Options}
     *                (presumably the maximum number of iterations).
     * @return the fitted clustering.
     */
    public static CentroidClustering<double[], double[]> fit(double[][] data, int k, int maxIter) {
        return fit(data, new Clustering.Options(k, maxIter));
    }

    /**
     * Fits K-Means, first building a new {@link BBDTree} over {@code data}.
     *
     * @param data the observations, one per row.
     * @param options the hyperparameters ({@code k}, {@code maxIter},
     *                {@code tol}, optional {@code controller}).
     * @return the fitted clustering.
     */
    public static CentroidClustering<double[], double[]> fit(double[][] data, Clustering.Options options) {
        return fit(new BBDTree(data), data, options);
    }

    /**
     * Fits K-Means with Euclidean distance, using a caller-supplied BBD tree.
     *
     * @param bbd a BBD tree; expected to be built over {@code data}
     *            (as {@link #fit(double[][], Clustering.Options)} does).
     * @param data the observations, one per row.
     * @param options the hyperparameters. If {@code options.controller()}
     *                is non-null, each iteration's status is submitted to it,
     *                and the loop stops early when it reports an interruption.
     * @return a clustering named {@code "K-Means"}. Its proximity holds each
     *         observation's squared Euclidean distance to its centroid.
     */
    public static CentroidClustering<double[], double[]> fit(BBDTree bbd, double[][] data, Clustering.Options options) {
        int k = options.k();
        int maxIter = options.maxIter();
        double tol = options.tol();
        IterativeAlgorithmController<AlgoStatus> controller = options.controller();
        int n = data.length;
        int d = data[0].length;

        EuclideanDistance distance = new EuclideanDistance();
        CentroidClustering<double[], double[]> clustering = CentroidClustering.init("K-Means", data, k, distance);
        double distortion = clustering.distortion();
        logger.info("Initial distortion = {}", distortion);

        int[] size = clustering.size();
        int[] group = clustering.group();
        // updateCentroids replaces the rows of this array in place, so this
        // reference stays valid for the tree iterations below.
        double[][] centroids = clustering.centers();
        updateCentroids(clustering, data);

        // k x d scratch buffer handed to BBDTree.clustering on every iteration.
        double[][] sum = new double[k][d];
        double diff = Double.MAX_VALUE;
        for (int iter = 1; iter <= maxIter && diff > tol; iter++) {
            double wcss = bbd.clustering(k, centroids, sum, size, group);
            diff = distortion - wcss;
            distortion = wcss;
            logger.info("Iteration {}: distortion = {}", iter, distortion);
            if (controller != null) {
                controller.submit(new AlgoStatus(iter, distortion));
                if (controller.isInterrupted()) {
                    break;
                }
            }
        }

        // The loop ended without meeting the tolerance (iteration cap or
        // interruption). Recompute centroids as the means of the final groups.
        if (diff > tol) {
            updateCentroids(clustering, data);
        }

        double[] proximity = clustering.proximity();
        IntStream.range(0, n).parallel().forEach(i -> {
            double dist = distance.applyAsDouble(data[i], centroids[group[i]]);
            proximity[i] = dist * dist;
        });
        return new CentroidClustering<>("K-Means", centroids, distance, group, proximity);
    }

    /**
     * Fits K-Means with Lloyd's algorithm, allowing missing values.
     *
     * @param data the observations, one per row; {@code NaN} marks a missing value.
     * @param k forwarded as the first argument of {@code Clustering.Options}
     *          (presumably the number of clusters).
     * @param maxIter forwarded as the second argument of {@code Clustering.Options}
     *                (presumably the maximum number of iterations).
     * @return the fitted clustering.
     */
    public static CentroidClustering<double[], double[]> lloyd(double[][] data, int k, int maxIter) {
        return lloyd(data, new Clustering.Options(k, maxIter));
    }

    /**
     * Fits K-Means with Lloyd's algorithm, allowing missing values.
     *
     * <p>Distances use {@code MathEx.distanceWithMissingValues}. Centroids
     * average only the non-NaN values of each dimension.
     *
     * @param data the observations, one per row; {@code NaN} marks a missing value.
     * @param options the hyperparameters. If {@code options.controller()}
     *                is non-null, each iteration's status is submitted to it,
     *                and the loop stops early when it reports an interruption.
     * @return the clustering produced by the last reassignment.
     */
    public static CentroidClustering<double[], double[]> lloyd(double[][] data, Clustering.Options options) {
        int k = options.k();
        int maxIter = options.maxIter();
        double tol = options.tol();
        IterativeAlgorithmController<AlgoStatus> controller = options.controller();
        int d = data[0].length;

        CentroidClustering<double[], double[]> clustering = CentroidClustering.init("K-Means", data, k, MathEx::distanceWithMissingValues);
        double distortion = clustering.distortion();
        logger.info("Initial distortion = {}", distortion);

        // Per-cluster, per-dimension counts of non-NaN values; reused by
        // every centroid update to avoid reallocating.
        int[][] notNaN = new int[k][d];
        double diff = Double.MAX_VALUE;
        for (int iter = 1; iter <= maxIter && diff > tol; iter++) {
            updateCentroidsWithMissingValues(clustering, data, notNaN);
            clustering = clustering.assign(data);
            double wcss = clustering.distortion();
            diff = distortion - wcss;
            distortion = wcss;
            logger.info("Iteration {}: distortion = {}", iter, distortion);
            if (controller != null) {
                controller.submit(new AlgoStatus(iter, distortion));
                if (controller.isInterrupted()) {
                    break;
                }
            }
        }

        // The loop ended without meeting the tolerance. Refresh the centroids
        // so they match the latest assignment.
        if (diff > tol) {
            updateCentroidsWithMissingValues(clustering, data, notNaN);
        }
        return clustering;
    }

    /**
     * Recomputes each centroid as the mean of its member observations.
     * Cluster sizes are recounted from scratch.
     *
     * <p>Works in place on the arrays returned by {@code size()} and
     * {@code centers()}. Each row of {@code centers()} is replaced with a
     * fresh array. An empty cluster keeps its previous centroid, rather than
     * getting the all-NaN result of dividing by zero.
     *
     * @param clustering the clustering whose group labels drive the update.
     * @param data the observations, one per row.
     */
    static void updateCentroids(CentroidClustering<double[], double[]> clustering, double[][] data) {
        int n = data.length;
        int[] size = clustering.size();
        int[] group = clustering.group();
        double[][] centroids = clustering.centers();
        int k = centroids.length;
        int d = centroids[0].length;

        Arrays.fill(size, 0);
        // Each parallel task owns a single cluster index, so the writes to
        // size[cluster] and centroids[cluster] do not overlap between tasks.
        IntStream.range(0, k).parallel().forEach(cluster -> {
            double[] centroid = new double[d];
            for (int i = 0; i < n; i++) {
                if (group[i] == cluster) {
                    size[cluster]++;
                    double[] x = data[i];
                    for (int j = 0; j < d; j++) {
                        centroid[j] += x[j];
                    }
                }
            }

            if (size[cluster] == 0) {
                return; // empty cluster: keep the previous centroid
            }

            for (int j = 0; j < d; j++) {
                centroid[j] /= size[cluster];
            }
            centroids[cluster] = centroid;
        });
    }

    /**
     * Recomputes each centroid as the per-dimension mean of its members'
     * non-NaN values. Cluster sizes are recounted from scratch.
     *
     * <p>A dimension with no observed values in a non-empty cluster ends up
     * as {@code NaN} in that centroid (0.0 / 0). A completely empty cluster
     * keeps its previous centroid.
     *
     * @param clustering the clustering whose group labels drive the update.
     * @param data the observations, one per row; {@code NaN} marks a missing value.
     * @param notNaN a k x d scratch array. Row {@code c} is overwritten with
     *               the per-dimension counts of non-NaN values in cluster {@code c}.
     */
    static void updateCentroidsWithMissingValues(CentroidClustering<double[], double[]> clustering, double[][] data, int[][] notNaN) {
        int n = data.length;
        int[] size = clustering.size();
        int[] group = clustering.group();
        double[][] centroids = clustering.centers();
        int k = centroids.length;
        int d = centroids[0].length;

        // Without this reset, member counts would pile up on top of the sizes
        // already stored in the clustering.
        Arrays.fill(size, 0);
        // Each parallel task owns a single cluster index, so the writes to
        // size[cluster], notNaN[cluster] and centroids[cluster] do not overlap.
        IntStream.range(0, k).parallel().forEach(cluster -> {
            double[] centroid = new double[d];
            int[] count = notNaN[cluster];
            Arrays.fill(count, 0);
            for (int i = 0; i < n; i++) {
                if (group[i] == cluster) {
                    size[cluster]++;
                    double[] x = data[i];
                    for (int j = 0; j < d; j++) {
                        if (!Double.isNaN(x[j])) {
                            centroid[j] += x[j];
                            count[j]++;
                        }
                    }
                }
            }

            if (size[cluster] == 0) {
                return; // empty cluster: keep the previous centroid
            }

            for (int j = 0; j < d; j++) {
                centroid[j] /= count[j];
            }
            centroids[cluster] = centroid;
        });
    }
}
