package smile.clustering;

import java.util.Arrays;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.clustering.Clustering;
import smile.math.MathEx;
import smile.math.distance.HammingDistance;
import smile.util.AlgoStatus;
import smile.util.IntSet;
import smile.util.IterativeAlgorithmController;

/* loaded from: input_file:smile/clustering/KModes.class */
/**
 * K-Modes clustering of integer-coded data.
 *
 * <p>Centroids are built column by column from the most frequent value
 * among the members of each cluster, and samples are (re)assigned through
 * {@link CentroidClustering#assign} using a {@link HammingDistance}.
 * This class only exposes static factory methods and cannot be instantiated.
 */
public class KModes {
    private static final Logger logger = LoggerFactory.getLogger(KModes.class);

    /**
     * Per-column helper that re-encodes the values of one column into the
     * dense range {@code [0, k)} so that they can be used directly as
     * indices of a counting array.
     */
    public static class Codec {
        /** The number of distinct values in the column. */
        public final int k;
        /**
         * The column values. This is the very array passed to the
         * constructor; it holds encoded levels if re-encoding was needed.
         */
        public final int[] x;
        /** Maps between original values and their encoded levels. */
        public final IntSet encoder;

        /**
         * Builds the codec of a column.
         *
         * <p>NOTE: the argument is kept by reference and is overwritten
         * in place with the encoded levels unless its distinct values are
         * already exactly {@code 0, 1, ..., k-1}.
         *
         * @param x the values of one column; must not be empty.
         */
        public Codec(int[] x) {
            int[] levels = MathEx.unique(x);
            Arrays.sort(levels);

            this.x = x;
            this.k = levels.length;
            this.encoder = new IntSet(levels);

            // Re-encode only when the sorted distinct values are not
            // already the dense range 0..k-1.
            boolean dense = levels[0] == 0 && levels[k - 1] == k - 1;
            if (!dense) {
                for (int i = 0; i < x.length; i++) {
                    x[i] = encoder.indexOf(x[i]);
                }
            }
        }

        /**
         * Decodes an encoded level back to the original column value.
         *
         * @param i the encoded level.
         * @return the original value as reported by the encoder.
         */
        public int valueOf(int i) {
            return encoder.valueOf(i);
        }
    }

    /** Private constructor; this class only provides static methods. */
    private KModes() {
    }

    /**
     * Fits k-modes clustering with the options built from {@code k} and
     * {@code maxIter}.
     *
     * @param data the input data, one sample per row.
     * @param k the number of clusters.
     * @param maxIter the maximum number of iterations.
     * @return the fitted clustering.
     */
    public static CentroidClustering<int[], int[]> fit(int[][] data, int k, int maxIter) {
        return fit(data, new Clustering.Options(k, maxIter));
    }

    /**
     * Fits k-modes clustering.
     *
     * <p>The loop alternates centroid updates and re-assignment until the
     * decrease of the distortion is no longer greater than the tolerance,
     * the maximum number of iterations is reached, or the optional
     * controller reports an interruption.
     *
     * @param data the input data, one sample per row; must not be empty.
     * @param options the hyperparameters (k, maxIter, tol, controller).
     * @return the fitted clustering.
     * @throws IllegalArgumentException if {@code data} has no rows.
     */
    public static CentroidClustering<int[], int[]> fit(int[][] data, Clustering.Options options) {
        if (data.length == 0) {
            throw new IllegalArgumentException("Empty data: at least one sample is required");
        }

        int k = options.k();
        int maxIter = options.maxIter();
        double tol = options.tol();
        IterativeAlgorithmController<AlgoStatus> controller = options.controller();
        int n = data.length;
        int d = data[0].length;

        // One codec per column. Each codec works on a private copy of the
        // column, so the in-place re-encoding never touches the input data.
        Codec[] codec = IntStream.range(0, d).parallel().mapToObj(j -> {
            int[] column = new int[n];
            for (int i = 0; i < n; i++) {
                column[i] = data[i][j];
            }
            return new Codec(column);
        }).toArray(Codec[]::new);

        CentroidClustering<int[], int[]> clustering =
                CentroidClustering.init("K-Modes", data, k, new HammingDistance());
        double distortion = clustering.distortion();
        logger.info("Initial distortion = {}", distortion);

        // Start with a large improvement so that the first iteration runs.
        double diff = Integer.MAX_VALUE;
        for (int iter = 1; iter <= maxIter && diff > tol; iter++) {
            updateCentroids(clustering, data, codec);
            clustering = clustering.assign(data);
            diff = distortion - clustering.distortion();
            distortion = clustering.distortion();

            logger.info("Iteration {}: distortion = {}", iter, distortion);
            if (controller != null) {
                controller.submit(new AlgoStatus(iter, distortion));
                if (controller.isInterrupted()) {
                    break;
                }
            }
        }

        // The last step of the loop was a re-assignment. If the distortion
        // was still decreasing, refresh the centroids once more so that
        // they are computed from the final grouping.
        if (diff > 0.0) {
            updateCentroids(clustering, data, codec);
        }

        return clustering;
    }

    /**
     * Recomputes every centroid, in place, as the column-wise most
     * frequent value among the samples currently assigned to the cluster.
     *
     * @param clustering the clustering whose centers array is updated.
     * @param data the input data, one sample per row.
     * @param codec the per-column codecs built from {@code data}.
     */
    private static void updateCentroids(CentroidClustering<int[], int[]> clustering, int[][] data, Codec[] codec) {
        int n = data.length;
        int[] group = clustering.group();
        int[][] centroids = clustering.centers();
        int k = centroids.length;
        int d = centroids[0].length;

        // Clusters are independent: each task writes a distinct slot.
        IntStream.range(0, k).parallel().forEach(cluster -> {
            int[] centroid = new int[d];
            for (int j = 0; j < d; j++) {
                Codec column = codec[j];
                // NOTE(review): a column with a single distinct value is
                // skipped, leaving 0 in the centroid even when the constant
                // is not 0 - confirm this is intended.
                if (column.k > 1) {
                    int[] count = new int[column.k];
                    int[] x = column.x;
                    for (int i = 0; i < n; i++) {
                        if (group[i] == cluster) {
                            count[x[i]]++;
                        }
                    }
                    // Decode the most frequent level to the original value.
                    centroid[j] = column.valueOf(MathEx.whichMax(count));
                }
            }
            centroids[cluster] = centroid;
        });
    }
}
