package dev.brachtendorf.clustering;

import com.github.kilianB.pcg.fast.PcgRSFast;
import dev.brachtendorf.ArrayUtil;
import dev.brachtendorf.clustering.distance.DistanceFunction;
import dev.brachtendorf.clustering.distance.EuclideanDistance;
import java.util.DoubleSummaryStatistics;
import java.util.function.Supplier;

/* loaded from: input_file:dev/brachtendorf/clustering/KMeans.class */
/**
 * Partitions data points into {@code k} clusters using Lloyd's k-means iteration.
 *
 * <p>Starting centroids are drawn uniformly at random from the per-dimension bounding box of the
 * input data. Points are then repeatedly assigned to their nearest centroid (as measured by the
 * configured {@link DistanceFunction}), and every non-empty centroid is recomputed from its
 * members, until a full pass changes no assignment.
 */
public class KMeans implements ClusterAlgorithm {

    /** Number of clusters to compute. */
    protected int k;

    /** Metric used to compare a centroid (one statistics object per dimension) with a data point. */
    protected DistanceFunction distanceFunction;

    /** Number of assignment passes performed by the most recent call to {@link #computeKMeans}. */
    protected int lastIterationCount;

    /**
     * Creates a k-means clusterer that uses {@link EuclideanDistance}.
     *
     * @param k the number of clusters to compute
     */
    public KMeans(int k) {
        this(k, new EuclideanDistance());
    }

    /**
     * Creates a k-means clusterer with a custom distance metric.
     *
     * @param k                the number of clusters to compute
     * @param distanceFunction metric used to measure centroid-to-point distances
     */
    public KMeans(int k, DistanceFunction distanceFunction) {
        this.k = k;
        this.distanceFunction = distanceFunction;
    }

    /**
     * Clusters the given data points.
     *
     * <p>The number of dimensions is taken from the first data point.
     *
     * @param data the data points, one row per point
     * @return the cluster label of every point together with the input data
     * @throws IllegalArgumentException if {@code k < 1}, or if {@code k > 1} and {@code k} is not
     *                                  smaller than the number of data points
     */
    @Override // dev.brachtendorf.clustering.ClusterAlgorithm
    public ClusterResult cluster(double[][] data) {
        if (k < 1) {
            throw new IllegalArgumentException("k must be at least 1 but was " + k);
        }
        // Java zero-initialises int arrays, so every point starts out in cluster 0.
        int[] labels = new int[data.length];
        if (k == 1) {
            return new ClusterResult(labels, data);
        }
        if (k >= data.length) {
            throw new IllegalArgumentException("Can't compute more clusters than datapoints are present");
        }
        int dimensions = data[0].length;
        computeKMeans(computeStartingClusters(data, k, dimensions), data, labels, dimensions);
        return new ClusterResult(labels, data);
    }

    /**
     * Creates {@code k} random starting centroids.
     *
     * <p>Each coordinate is drawn uniformly from the {@code [min, max)} range of that dimension
     * over all data points.
     *
     * @param data       the data points
     * @param k          the number of centroids to create
     * @param dimensions the number of dimensions per point
     * @return a {@code k x dimensions} matrix; each entry holds exactly one sampled coordinate
     */
    protected DoubleSummaryStatistics[][] computeStartingClusters(double[][] data, int k, int dimensions) {
        PcgRSFast rng = new PcgRSFast();

        // Bounding box per DIMENSION (not per point) so each coordinate is sampled from its own range.
        double[] min = new double[dimensions];
        double[] max = new double[dimensions];
        for (int d = 0; d < dimensions; d++) {
            min[d] = Double.MAX_VALUE;
            max[d] = -Double.MAX_VALUE;
        }
        for (double[] point : data) {
            for (int d = 0; d < dimensions; d++) {
                double value = point[d];
                if (value < min[d]) {
                    min[d] = value;
                }
                if (value > max[d]) {
                    max[d] = value;
                }
            }
        }

        DoubleSummaryStatistics[][] centroids = newStatistics(k, dimensions);
        for (int c = 0; c < k; c++) {
            for (int d = 0; d < dimensions; d++) {
                centroids[c][d].accept(rng.nextDouble() * (max[d] - min[d]) + min[d]);
            }
        }
        return centroids;
    }

    /**
     * Runs Lloyd's iteration until no point changes its cluster.
     *
     * <p>The method mutates {@code centroids} and {@code labels} in place and records the number
     * of passes in {@link #lastIterationCount}.
     *
     * @param centroids  the current centroids, one row of per-dimension statistics per cluster
     * @param data       the data points
     * @param labels     the cluster label of each point; updated in place
     * @param dimensions the number of dimensions per point
     */
    protected void computeKMeans(DoubleSummaryStatistics[][] centroids, double[][] data, int[] labels,
            int dimensions) {
        lastIterationCount = 0;
        boolean changed;
        do {
            changed = assignToNearest(centroids, data, labels);
            if (changed) {
                updateCentroids(centroids, data, labels, dimensions);
            }
            lastIterationCount++;
        } while (changed);
    }

    /** Assigns each point to its nearest centroid; returns whether any label changed. */
    private boolean assignToNearest(DoubleSummaryStatistics[][] centroids, double[][] data, int[] labels) {
        boolean changed = false;
        for (int p = 0; p < data.length; p++) {
            int nearest = nearestCluster(centroids, data[p]);
            if (labels[p] != nearest) {
                labels[p] = nearest;
                changed = true;
            }
        }
        return changed;
    }

    /** Returns the index of the centroid closest to {@code point}; ties go to the lowest index. */
    private int nearestCluster(DoubleSummaryStatistics[][] centroids, double[] point) {
        int best = -1;
        double bestDistance = Double.POSITIVE_INFINITY;
        for (int c = 0; c < k; c++) {
            double distance = distanceFunction.distance(centroids[c], point);
            if (distance < bestDistance) {
                best = c;
                bestDistance = distance;
            }
        }
        // Every distance was NaN or infinite: fall back to cluster 0 instead of an invalid index.
        return best == -1 ? 0 : best;
    }

    /**
     * Recomputes every non-empty centroid from the points currently assigned to it.
     *
     * <p>A cluster that lost all its members keeps its previous centroid. Resetting it would leave
     * empty statistics, whose {@code getAverage()} is 0, and so would silently move that centroid
     * to the origin.
     */
    private static void updateCentroids(DoubleSummaryStatistics[][] centroids, double[][] data,
            int[] labels, int dimensions) {
        int[] members = new int[centroids.length];
        for (int p = 0; p < data.length; p++) {
            members[labels[p]]++;
        }
        for (int c = 0; c < centroids.length; c++) {
            if (members[c] > 0) {
                resetRow(centroids[c]);
            }
        }
        for (int p = 0; p < data.length; p++) {
            double[] point = data[p];
            DoubleSummaryStatistics[] centroid = centroids[labels[p]];
            for (int d = 0; d < dimensions; d++) {
                centroid[d].accept(point[d]);
            }
        }
    }

    /** Allocates a {@code rows x columns} matrix of fresh, empty statistics objects. */
    private static DoubleSummaryStatistics[][] newStatistics(int rows, int columns) {
        DoubleSummaryStatistics[][] stats = new DoubleSummaryStatistics[rows][columns];
        for (DoubleSummaryStatistics[] row : stats) {
            resetRow(row);
        }
        return stats;
    }

    /** Replaces every entry of {@code row} with a fresh, empty statistics object. */
    private static void resetRow(DoubleSummaryStatistics[] row) {
        for (int d = 0; d < row.length; d++) {
            row[d] = new DoubleSummaryStatistics();
        }
    }

    /**
     * Returns the number of assignment passes of the last clustering run.
     *
     * @return the iteration count of the last run; 0 if no iteration was executed
     */
    public int iterations() {
        return this.lastIterationCount;
    }
}
