package smile.clustering;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.clustering.Clustering;
import smile.math.MathEx;
import smile.math.distance.EuclideanDistance;
import smile.sort.QuickSort;
import smile.util.AlgoStatus;
import smile.util.IterativeAlgorithmController;

/* loaded from: input_file:smile/clustering/XMeans.class */
public class XMeans {
    private static final Logger logger = LoggerFactory.getLogger(XMeans.class);
    private static final double LOG2PI = Math.log(6.283185307179586d);

    private XMeans() {
    }

    public static CentroidClustering<double[], double[]> fit(double[][] dArr, int i, int i2) {
        return fit(dArr, new Clustering.Options(i, i2));
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v105, types: [double[], double[][]] */
    /* JADX WARN: Type inference failed for: r0v22, types: [double[], double[][], java.lang.Object[]] */
    public static CentroidClustering<double[], double[]> fit(double[][] dArr, Clustering.Options options) {
        int k = options.k();
        int maxIter = options.maxIter();
        double lVar = options.tol();
        IterativeAlgorithmController<AlgoStatus> controller = options.controller();
        int length = dArr.length;
        int length2 = dArr[0].length;
        int[] iArr = new int[length];
        double[] dArr2 = new double[length];
        double[][] dArr3 = new double[k][length2];
        double[] colMeans = MathEx.colMeans(dArr);
        ?? r0 = new double[k];
        int[] iArr2 = new int[k];
        r0[0] = colMeans;
        iArr2[0] = length;
        double[] dArr4 = new double[k];
        dArr4[0] = ((Stream) Arrays.stream(dArr).parallel()).mapToDouble(dArr5 -> {
            return MathEx.squaredDistance(dArr5, colMeans);
        }).sum() / length;
        BBDTree bBDTree = new BBDTree(dArr);
        ArrayList arrayList = new ArrayList(k);
        ArrayList arrayList2 = new ArrayList();
        int i = 1;
        while (true) {
            if (i >= k) {
                break;
            }
            arrayList.clear();
            arrayList2.clear();
            double[] dArr6 = new double[i];
            for (int i2 = 0; i2 < i; i2++) {
                int i3 = iArr2[i2];
                if (i3 < 25) {
                    logger.info("Cluster {} too small to split: {} observations", Integer.valueOf(i2), Integer.valueOf(i3));
                    dArr6[i2] = 0.0d;
                    arrayList.add(null);
                } else {
                    ?? r02 = new double[i3];
                    int i4 = 0;
                    for (int i5 = 0; i5 < length; i5++) {
                        if (iArr[i5] == i2) {
                            int i6 = i4;
                            i4++;
                            r02[i6] = dArr[i5];
                        }
                    }
                    CentroidClustering<double[], double[]> fit = KMeans.fit(r02, new Clustering.Options(2, maxIter, lVar, null));
                    arrayList.add(fit);
                    double bic = bic(2, i3, length2, fit.distortion(), fit.size());
                    double bic2 = bic(i3, length2, dArr4[i2]);
                    dArr6[i2] = bic - bic2;
                    logger.info("Cluster {} BIC: {}, BIC after split: {}, improvement: {}", new Object[]{Integer.valueOf(i2), Double.valueOf(bic2), Double.valueOf(bic), Double.valueOf(dArr6[i2])});
                }
            }
            int[] sort = QuickSort.sort(dArr6);
            for (int i7 = 0; i7 < i; i7++) {
                if (dArr6[i7] <= 0.0d) {
                    arrayList2.add(r0[sort[i7]]);
                }
            }
            int size = arrayList2.size();
            int i8 = i;
            while (true) {
                i8--;
                if (i8 < 0) {
                    break;
                }
                if (dArr6[i8] > 0.0d) {
                    if (((arrayList2.size() + i8) - size) + 1 < k) {
                        logger.info("Split cluster {}", Integer.valueOf(sort[i8]));
                        arrayList2.add((double[]) ((CentroidClustering) arrayList.get(sort[i8])).center(0));
                        arrayList2.add((double[]) ((CentroidClustering) arrayList.get(sort[i8])).center(1));
                    } else {
                        arrayList2.add(r0[sort[i8]]);
                    }
                }
            }
            if (arrayList2.size() == i) {
                logger.info("No more split. Finish with {} clusters", Integer.valueOf(i));
                break;
            }
            i = arrayList2.size();
            arrayList2.toArray((Object[]) r0);
            double d = Double.MAX_VALUE;
            double d2 = Double.MAX_VALUE;
            for (int i9 = 1; i9 <= maxIter && d > lVar; i9++) {
                double clustering = bBDTree.clustering(i, r0, dArr3, iArr2, iArr);
                d = d2 - clustering;
                d2 = clustering;
                logger.info("Iteration {}: {}-cluster distortion = {}", new Object[]{Integer.valueOf(i9), Integer.valueOf(i), Double.valueOf(d2)});
            }
            Arrays.fill(dArr4, 0.0d);
            IntStream.range(0, i).parallel().forEach(i10 -> {
                double[] dArr7 = r0[i10];
                for (int i10 = 0; i10 < length; i10++) {
                    if (iArr[i10] == i10) {
                        double squaredDistance = MathEx.squaredDistance(dArr[i10], dArr7);
                        dArr2[i10] = squaredDistance;
                        dArr4[i10] = dArr4[i10] + squaredDistance;
                    }
                }
            });
            for (int i11 = 0; i11 < i; i11++) {
                int i12 = i11;
                dArr4[i12] = dArr4[i12] / iArr2[i11];
            }
            if (controller != null) {
                controller.submit(new AlgoStatus(i, d2));
                if (controller.isInterrupted()) {
                    break;
                }
            }
        }
        return new CentroidClustering<>("X-Means", (double[][]) Arrays.copyOf((Object[]) r0, i), new EuclideanDistance(), iArr, dArr2);
    }

    private static double bic(int i, int i2, double d) {
        return (((((-i) * LOG2PI) + (((-i) * i2) * Math.log((i * d) / (i - 1)))) + (-(i - 1))) / 2.0d) - ((0.5d * (i2 + 1)) * Math.log(i));
    }

    private static double bic(int i, int i2, int i3, double d, int[] iArr) {
        double d2 = (i2 * d) / (i2 - i);
        double d3 = 0.0d;
        for (int i4 = 0; i4 < i; i4++) {
            d3 += logLikelihood(i, i2, iArr[i4], i3, d2);
        }
        return d3 - ((0.5d * (i + (i * i3))) * Math.log(i2));
    }

    private static double logLikelihood(int i, int i2, int i3, int i4, double d) {
        double d2 = (-i3) * LOG2PI;
        double log = (-i3) * i4 * Math.log(d);
        double d3 = -(i3 - i);
        return (((d2 + log) + d3) / 2.0d) + (i3 * Math.log(i3)) + ((-i3) * Math.log(i2));
    }
}
