package smile.clustering;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Iterator;
import java.util.function.ToDoubleBiFunction;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.clustering.Clustering;
import smile.math.MathEx;
import smile.util.AlgoStatus;
import smile.util.IterativeAlgorithmController;
import smile.util.SparseArray;

/* loaded from: input_file:smile/clustering/SIB.class */
/**
 * SIB clustering of sparse vectors using the Jensen-Shannon divergence as the
 * dissimilarity between a sample and a cluster centroid.
 *
 * <p>Starting from the partition produced by {@link CentroidClustering#init},
 * each pass visits the samples one at a time. Each sample is moved to the
 * cluster whose centroid has the smallest Jensen-Shannon divergence from it,
 * and the two affected centroids are updated immediately. Passes repeat until
 * one of these happens:
 * <ul>
 *   <li>the number of reassignments in a pass is no longer greater than
 *       {@code tol};</li>
 *   <li>{@code maxIter} passes have run;</li>
 *   <li>the optional controller reports an interruption.</li>
 * </ul>
 * Each centroid is the element-wise mean of the samples assigned to it.
 *
 * <p>NOTE(review): the name presumably stands for "sequential Information
 * Bottleneck"; confirm against the project documentation.
 */
public class SIB {
    private static final Logger logger = LoggerFactory.getLogger(SIB.class);

    /**
     * Serializable distance between a dense centroid and a sparse sample.
     * Delegates to {@link MathEx#JensenShannonDivergence}.
     *
     * <p>NOTE(review): decompiler metadata indicates this class was originally
     * private. It is kept public here so the visible interface is unchanged.
     */
    public static class JSDistance implements ToDoubleBiFunction<double[], SparseArray>, Serializable {
        private static final long serialVersionUID = 1L;

        private JSDistance() {
        }

        @Override
        public double applyAsDouble(double[] x, SparseArray y) {
            return MathEx.JensenShannonDivergence(x, y);
        }
    }

    /** Static utility class; not instantiable. */
    private SIB() {
    }

    /**
     * Clusters sparse data.
     *
     * <p>The two integers are forwarded positionally to the two-argument
     * {@code Clustering.Options} constructor. That constructor is assumed to
     * take (k, maxIter); confirm against {@code Clustering.Options}.
     *
     * @param data the samples to cluster
     * @param k the number of clusters
     * @param maxIter the maximum number of passes
     * @return the fitted clustering
     */
    public static CentroidClustering<double[], SparseArray> fit(SparseArray[] data, int k, int maxIter) {
        return fit(data, new Clustering.Options(k, maxIter));
    }

    /**
     * Clusters sparse data with the given options.
     *
     * @param data the samples to cluster
     * @param options supplies the number of clusters {@code k}, the maximum
     *        number of passes {@code maxIter}, the reassignment threshold
     *        {@code tol}, and an optional (may be {@code null}) progress
     *        controller that receives one status per pass
     * @return a clustering whose centers are the per-cluster mean vectors and
     *         whose proximity holds each sample's squared Jensen-Shannon
     *         divergence to its assigned center
     */
    public static CentroidClustering<double[], SparseArray> fit(SparseArray[] data, Clustering.Options options) {
        int k = options.k();
        int maxIter = options.maxIter();
        double tol = options.tol();
        IterativeAlgorithmController<AlgoStatus> controller = options.controller();
        int n = data.length;
        // Dense dimension: one past the largest index used by any sample.
        int d = 1 + Arrays.stream(data).flatMapToInt(SparseArray::indexStream).max().orElse(0);

        CentroidClustering<?, ?> init = CentroidClustering.init("SIB", data, k, MathEx::JensenShannonDivergence);
        logger.info("Initial distortion = {}", init.distortion());

        int[] size = init.size();
        int[] group = init.group();
        double[][] centroids = new double[k][d];
        // Each task writes only its own cluster's slot of size and centroids,
        // so parallel tasks never write the same element.
        IntStream.range(0, k).parallel().forEach(cluster -> {
            // Recount from group so size[cluster] is the exact member count,
            // whatever init.size() held on entry. moveSample relies on
            // centroid * size being the member sum.
            size[cluster] = 0;
            double[] centroid = centroids[cluster];
            for (int i = 0; i < n; i++) {
                if (group[i] == cluster) {
                    size[cluster]++;
                    for (SparseArray.Entry e : data[i]) {
                        centroid[e.index()] += e.value();
                    }
                }
            }

            // Same guard as moveSample: an empty cluster keeps a zero vector
            // instead of becoming NaN through 0/0.
            if (size[cluster] > 0) {
                for (int j = 0; j < d; j++) {
                    centroid[j] /= size[cluster];
                }
            }
        });

        int reassigned = n;
        for (int iter = 1; iter <= maxIter && reassigned > tol; iter++) {
            reassigned = 0;
            for (int i = 0; i < n; i++) {
                int best = nearestCluster(data[i], centroids, group[i]);
                if (best != group[i]) {
                    moveSample(data[i], centroids, size, group[i], best);
                    group[i] = best;
                    reassigned++;
                }
            }

            logger.info("Iteration {}: assignments = {}", iter, reassigned);
            if (controller != null) {
                controller.submit(new AlgoStatus(iter, reassigned));
                if (controller.isInterrupted()) {
                    break;
                }
            }
        }

        double[] proximity = init.proximity();
        double distortion = IntStream.range(0, n).parallel().mapToDouble(i -> {
            double dist = MathEx.JensenShannonDivergence(data[i], centroids[group[i]]);
            double squared = dist * dist;
            proximity[i] = squared;
            return squared;
        }).sum() / n;
        logger.info("Final distortion: {}", distortion);
        return new CentroidClustering<>("SIB", centroids, new JSDistance(), group, proximity);
    }

    /**
     * Finds the centroid nearest to {@code x} by Jensen-Shannon divergence.
     *
     * <p>Ties go to the lowest cluster index. If no divergence is strictly
     * below {@code Double.MAX_VALUE} (for example, all are NaN),
     * {@code current} is returned unchanged.
     *
     * @param x the sample
     * @param centroids the current cluster centroids
     * @param current the sample's current cluster
     * @return the index of the nearest cluster
     */
    private static int nearestCluster(SparseArray x, double[][] centroids, int current) {
        int best = current;
        double min = Double.MAX_VALUE;
        for (int c = 0; c < centroids.length; c++) {
            double dist = MathEx.JensenShannonDivergence(x, centroids[c]);
            if (dist < min) {
                min = dist;
                best = c;
            }
        }
        return best;
    }

    /**
     * Moves sample {@code x} from cluster {@code from} to cluster {@code to},
     * updating both mean centroids and the member counts in place.
     *
     * <p>Requires {@code from != to}.
     *
     * @param x the sample being moved
     * @param centroids the cluster centroids (modified)
     * @param size the cluster member counts (modified)
     * @param from the sample's current cluster
     * @param to the destination cluster
     */
    private static void moveSample(SparseArray x, double[][] centroids, int[] size, int from, int to) {
        double[] source = centroids[from];
        double[] target = centroids[to];
        int d = target.length;

        // Turn both means back into member sums.
        for (int j = 0; j < d; j++) {
            target[j] *= size[to];
            source[j] *= size[from];
        }

        for (SparseArray.Entry e : x) {
            int j = e.index();
            double v = e.value();
            target[j] += v;
            source[j] -= v;
            // Clamp so a source entry never goes negative,
            // e.g. because of floating-point round-off.
            if (source[j] < 0.0) {
                source[j] = 0.0;
            }
        }

        size[from]--;
        size[to]++;

        for (int j = 0; j < d; j++) {
            target[j] /= size[to];
        }
        // An emptied cluster keeps its (near-zero) sum rather than dividing by zero.
        if (size[from] > 0) {
            for (int j = 0; j < d; j++) {
                source[j] /= size[from];
            }
        }
    }
}
