package smile.clustering;

import java.util.Arrays;
import smile.math.MathEx;

/**
 * Space-partitioning tree over a fixed data set, used to speed up the
 * assignment/update step of k-means ("filtering").
 *
 * <p>Each node covers a contiguous slice {@code index[node.index, node.index + node.size)}
 * of an internal permutation of the data rows. For that slice it stores:
 * <ul>
 *   <li>the axis-aligned bounding box, as a center and a per-dimension half-width;</li>
 *   <li>the coordinate-wise sum of the points;</li>
 *   <li>the sum of squared distances of the points to their mean.</li>
 * </ul>
 * Internal nodes split their box at the midpoint of its widest dimension.
 *
 * <p>{@link #clustering} walks the tree and discards candidate centroids that
 * cannot be nearest to any point of a node's box. Once a single candidate
 * remains, or a leaf is reached, the whole subtree is assigned at once using
 * the cached sums. The input data array is never modified; only the internal
 * index permutation is reordered, and only during construction.
 */
public class BBDTree {
    /** Root of the tree; covers every data point. */
    private final Node root;
    /**
     * Permutation of data row indices. Each node owns the contiguous range
     * {@code [node.index, node.index + node.size)} of this array.
     */
    private final int[] index;

    /** A tree node covering a contiguous slice of the index permutation. */
    public static class Node {
        /** Number of data points covered by this node. */
        int size;
        /** Start offset of this node's slice in the tree's index permutation. */
        int index;
        /** Center of the bounding box of the covered points. */
        final double[] center;
        /** Half-width of the bounding box along each dimension. */
        final double[] radius;
        /** Coordinate-wise sum of the covered points. */
        final double[] sum;
        /**
         * Sum of squared distances from the covered points to their mean.
         * Leaves use 0 because their points are treated as identical.
         */
        double cost;
        /** Child with the points below the split value; null for a leaf. */
        Node lower;
        /** Child with the points at or above the split value; null for a leaf. */
        Node upper;

        /**
         * Creates an empty node.
         *
         * @param d the data dimensionality
         */
        Node(int d) {
            this.center = new double[d];
            this.radius = new double[d];
            this.sum = new double[d];
        }
    }

    /**
     * Builds the tree over the given data.
     *
     * @param data the data points, one row per point. The dimensionality is
     *             taken from {@code data[0]}; all rows are assumed to be at
     *             least that long.
     * @throws IllegalArgumentException if {@code data} is empty
     */
    public BBDTree(double[][] data) {
        if (data.length == 0) {
            throw new IllegalArgumentException("Cannot build a BBD tree over an empty data set");
        }
        int n = data.length;
        this.index = new int[n];
        for (int i = 0; i < n; i++) {
            this.index[i] = i;
        }
        this.root = buildNode(data, 0, n);
    }

    /**
     * Recursively builds the subtree covering {@code index[begin, end)}.
     *
     * <p>The slice of {@code index} is reordered in place so that each child
     * covers a contiguous sub-slice.
     *
     * @param data  the data points
     * @param begin first slice position (inclusive)
     * @param end   last slice position (exclusive); must be greater than {@code begin}
     * @return the subtree root
     */
    private Node buildNode(double[][] data, int begin, int end) {
        int d = data[0].length;
        Node node = new Node(d);
        node.size = end - begin;
        node.index = begin;

        // Axis-aligned bounding box of the points in this slice.
        double[] first = data[index[begin]];
        double[] lowerBound = new double[d];
        double[] upperBound = new double[d];
        System.arraycopy(first, 0, lowerBound, 0, d);
        System.arraycopy(first, 0, upperBound, 0, d);
        for (int i = begin + 1; i < end; i++) {
            double[] x = data[index[i]];
            for (int j = 0; j < d; j++) {
                double v = x[j];
                if (lowerBound[j] > v) {
                    lowerBound[j] = v;
                }
                if (upperBound[j] < v) {
                    upperBound[j] = v;
                }
            }
        }

        // Box center and half-width, plus the widest dimension (the split axis).
        // Each bound is halved before adding or subtracting, so huge coordinates
        // cannot overflow to infinity. For normal-range values the result is
        // exactly (lo + hi) / 2 and (hi - lo) / 2, because halving is exact.
        double maxRadius = -1.0;
        int splitDim = -1;
        for (int j = 0; j < d; j++) {
            node.center[j] = lowerBound[j] / 2.0 + upperBound[j] / 2.0;
            node.radius[j] = upperBound[j] / 2.0 - lowerBound[j] / 2.0;
            if (node.radius[j] > maxRadius) {
                maxRadius = node.radius[j];
                splitDim = j;
            }
        }

        // Points that (nearly) coincide form a leaf. The leaf treats them as
        // identical: its sum is the first point times the count, and its cost is 0.
        if (maxRadius < 1.0E-10) {
            node.lower = null;
            node.upper = null;
            System.arraycopy(first, 0, node.sum, 0, d);
            if (node.size > 1) {
                for (int j = 0; j < d; j++) {
                    node.sum[j] *= node.size;
                }
            }
            node.cost = 0.0;
            return node;
        }

        // Split at the midpoint of the widest dimension. The midpoint can fall
        // to or below the lower bound: when the bounds are adjacent doubles it
        // can round onto lo, and with -Infinity it becomes -Infinity or NaN.
        // In that case every point would go to one child and the recursion
        // would never end. Fall back to the upper bound: the cutoff then lies
        // in (lo, hi], so the point at lo goes low, the point at hi goes high,
        // and both children are non-empty.
        double cutoff = node.center[splitDim];
        if (!(cutoff > lowerBound[splitDim])) {
            cutoff = upperBound[splitDim];
        }

        // In-place two-pointer partition (as in quicksort):
        // values < cutoff go left, values >= cutoff go right.
        int left = begin;
        int right = end - 1;
        while (left <= right) {
            boolean leftIsLow = data[index[left]][splitDim] < cutoff;
            boolean rightIsHigh = data[index[right]][splitDim] >= cutoff;
            if (!leftIsLow && !rightIsHigh) {
                int tmp = index[left];
                index[left] = index[right];
                index[right] = tmp;
                leftIsLow = true;
                rightIsHigh = true;
            }
            if (leftIsLow) {
                left++;
            }
            if (rightIsHigh) {
                right--;
            }
        }
        // left - begin equals the number of points placed in the lower half.
        int mid = left;

        node.lower = buildNode(data, begin, mid);
        node.upper = buildNode(data, mid, end);

        double[] mean = new double[d];
        for (int j = 0; j < d; j++) {
            node.sum[j] = node.lower.sum[j] + node.upper.sum[j];
            mean[j] = node.sum[j] / node.size;
        }
        node.cost = getNodeCost(node.lower, mean) + getNodeCost(node.upper, mean);
        return node;
    }

    /**
     * Returns the sum of squared distances from the points under {@code node}
     * to {@code target}.
     *
     * <p>It uses the identity
     * {@code cost(target) = node.cost + node.size * ||mean - target||^2},
     * where {@code mean = node.sum / node.size}.
     *
     * @param node   the node
     * @param target the reference point
     * @return the sum of squared distances
     */
    private double getNodeCost(Node node, double[] target) {
        double dist = 0.0;
        for (int j = 0; j < target.length; j++) {
            double diff = node.sum[j] / node.size - target[j];
            dist += diff * diff;
        }
        return node.cost + node.size * dist;
    }

    /**
     * Performs one Lloyd-style k-means iteration.
     *
     * <p>Every data point is assigned to its nearest centroid (up to the leaf
     * tolerance of the tree). Each centroid that receives at least one point
     * is then replaced by the mean of its points.
     *
     * @param k         the number of centroids, i.e. rows {@code 0..k-1} of {@code centroids}
     * @param centroids the current centroids. Rows {@code 0..k-1} are overwritten
     *                  with the new means; a centroid with no assigned points
     *                  keeps its value.
     * @param sum       output: {@code sum[i]} receives the coordinate-wise sum of
     *                  the points assigned to centroid {@code i}
     * @param size      output: {@code size[i]} receives the number of points
     *                  assigned to centroid {@code i}; the whole array is zeroed first
     * @param y         output: {@code y[j]} receives the centroid index assigned
     *                  to data point {@code j}
     * @return the sum of squared distances from each point to its assigned
     *         centroid, measured against the centroids as they were on entry,
     *         divided by {@code y.length} (presumably the number of data
     *         points; confirm against callers)
     */
    public double clustering(int k, double[][] centroids, double[][] sum, int[] size, int[] y) {
        Arrays.fill(size, 0);
        int[] candidates = new int[k];
        for (int i = 0; i < k; i++) {
            candidates[i] = i;
            Arrays.fill(sum[i], 0.0);
        }

        double distortion = filter(root, centroids, candidates, k, sum, size, y);

        int d = centroids[0].length;
        for (int i = 0; i < k; i++) {
            if (size[i] > 0) {
                for (int j = 0; j < d; j++) {
                    centroids[i][j] = sum[i][j] / size[i];
                }
            }
        }
        return distortion / y.length;
    }

    /**
     * Assigns the points under {@code node} to the nearest of the candidate
     * centroids, accumulating into {@code sum}, {@code size} and {@code y}.
     *
     * @param node       the subtree to assign
     * @param centroids  all centroids
     * @param candidates indices into {@code centroids}; only the first {@code k} are used
     * @param k          number of valid entries in {@code candidates}
     * @param sum        per-centroid coordinate sums (accumulated into)
     * @param size       per-centroid point counts (accumulated into)
     * @param y          per-point centroid labels (written)
     * @return the sum of squared distances from the subtree's points to the
     *         centroids they were assigned to
     */
    private double filter(Node node, double[][] centroids, int[] candidates, int k,
                          double[][] sum, int[] size, int[] y) {
        int d = centroids[0].length;

        // Candidate closest to the box center; on ties the earliest candidate wins.
        int closest = candidates[0];
        double minDist = MathEx.squaredDistance(node.center, centroids[closest]);
        for (int i = 1; i < k; i++) {
            double dist = MathEx.squaredDistance(node.center, centroids[candidates[i]]);
            if (dist < minDist) {
                minDist = dist;
                closest = candidates[i];
            }
        }

        // Internal node: drop candidates that cannot be nearest to any point of
        // the box, and recurse only while more than one candidate survives.
        if (node.lower != null) {
            int[] survivors = new int[k];
            int count = 0;
            for (int i = 0; i < k; i++) {
                if (!prune(node.center, node.radius, centroids, closest, candidates[i])) {
                    survivors[count++] = candidates[i];
                }
            }
            if (count > 1) {
                return filter(node.lower, centroids, survivors, count, sum, size, y)
                        + filter(node.upper, centroids, survivors, count, sum, size, y);
            }
        }

        // Leaf, or only one candidate left: assign the whole subtree to it.
        double[] target = sum[closest];
        for (int j = 0; j < d; j++) {
            target[j] += node.sum[j];
        }
        size[closest] += node.size;
        int end = node.index + node.size;
        for (int i = node.index; i < end; i++) {
            y[index[i]] = closest;
        }
        return getNodeCost(node, centroids[closest]);
    }

    /**
     * Tests whether candidate {@code test} can be discarded for a box.
     *
     * <p>A point {@code v} is no closer to {@code test} than to {@code best} iff
     * {@code ||test - best||^2 >= 2 (v - best) . (test - best)}. The check is
     * applied to the box vertex that maximizes the right-hand side, i.e. the
     * vertex farthest in the direction {@code test - best}. If it holds there,
     * it holds for every point of the box.
     *
     * @param center    box center
     * @param radius    box half-widths
     * @param centroids all centroids
     * @param best      index of the reference (closest-to-center) centroid
     * @param test      index of the candidate being tested
     * @return true if {@code test} can be pruned; always false when
     *         {@code test == best}
     */
    private boolean prune(double[] center, double[] radius, double[][] centroids, int best, int test) {
        if (best == test) {
            return false;
        }

        int d = centroids[0].length;
        double[] b = centroids[best];
        double[] t = centroids[test];
        double lhs = 0.0;
        double rhs = 0.0;
        for (int j = 0; j < d; j++) {
            double diff = t[j] - b[j];
            lhs += diff * diff;
            double vertex = diff > 0.0 ? center[j] + radius[j] : center[j] - radius[j];
            rhs += (vertex - b[j]) * diff;
        }
        return lhs >= 2.0 * rhs;
    }
}
