package smile.classification;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Objects;
import java.util.stream.IntStream;
import smile.data.Dataset;
import smile.data.measure.NominalScale;
import smile.data.vector.ValueVector;
import smile.math.MathEx;
import smile.util.IntSet;

/* loaded from: input_file:smile/classification/ClassLabels.class */
/**
 * Encoding of classification labels into contiguous class indices {@code 0..k-1},
 * together with per-class sample counts and prior probabilities.
 */
public class ClassLabels implements Serializable {
    private static final long serialVersionUID = 2L;

    /** The number of classes. */
    public final int k;
    /**
     * The class label values. Position {@code i} in this set corresponds to class index
     * {@code i} (see {@link #scale()} and {@link #indexOf(int[])}).
     */
    public final IntSet classes;
    /** The sample labels encoded as class indices in {@code [0, k)}. */
    public final int[] y;
    /** The number of samples in each class; {@code ni[c]} counts occurrences of index {@code c} in {@link #y}. */
    public final int[] ni;
    /** The prior probability of each class, i.e. {@code ni[c] / y.length}. */
    public final double[] priori;

    /**
     * Constructor.
     *
     * <p>Note: {@code y} is stored as-is (not copied). If {@code y} is empty, every
     * prior is {@code 0 / 0.0}, i.e. {@code NaN}.
     *
     * @param k the number of classes.
     * @param y the sample labels, already encoded as class indices in {@code [0, k)}.
     * @param classes the class label values, indexed by class index.
     * @throws IllegalArgumentException if any element of {@code y} is outside {@code [0, k)}.
     */
    public ClassLabels(int k, int[] y, IntSet classes) {
        this.k = k;
        this.y = y;
        this.classes = classes;
        this.ni = count(y, k);
        this.priori = new double[k];
        double n = y.length;
        for (int c = 0; c < k; c++) {
            this.priori[c] = this.ni[c] / n;
        }
    }

    /**
     * Returns a nominal scale whose levels are the string forms of the class label
     * values, in class-index order.
     *
     * @return the nominal scale of the class labels.
     */
    public NominalScale scale() {
        String[] levels = new String[classes.size()];
        for (int c = 0; c < levels.length; c++) {
            levels[c] = String.valueOf(classes.valueOf(c));
        }
        return new NominalScale(levels);
    }

    /**
     * Maps original label values to class indices via {@link IntSet#indexOf}.
     * NOTE(review): the result for a label not present in {@link #classes} is whatever
     * {@code IntSet.indexOf} returns for a missing value; confirm before relying on it.
     *
     * @param labels the original label values.
     * @return the corresponding class indices, in the same order.
     */
    public int[] indexOf(int[] labels) {
        return Arrays.stream(labels).map(classes::indexOf).toArray();
    }

    /**
     * Learns the class labels from the responses of a dataset.
     *
     * @param data the dataset whose sample responses are the class labels.
     * @return the class label encoding.
     * @throws IllegalArgumentException if fewer than two distinct labels are present.
     */
    public static ClassLabels fit(Dataset<?, Integer> data) {
        int n = data.size();
        int[] labels = new int[n];
        for (int i = 0; i < n; i++) {
            labels[i] = data.get(i).y();
        }
        return fit(labels);
    }

    /**
     * Learns the class labels from raw label values. If the distinct labels are
     * exactly {@code 0..k-1}, the input array is used directly as {@link #y}.
     * Otherwise the labels are remapped to their position among the sorted distinct
     * values.
     *
     * @param labels the raw class labels.
     * @return the class label encoding.
     * @throws IllegalArgumentException if fewer than two distinct labels are present.
     */
    public static ClassLabels fit(int[] labels) {
        int[] unique = MathEx.unique(labels);
        Arrays.sort(unique);
        int k = unique.length;
        if (k < 2) {
            throw new IllegalArgumentException("Only one class.");
        }

        IntSet classes = new IntSet(unique);
        // Sorted, distinct values spanning exactly 0..k-1 are already class indices.
        if (unique[0] == 0 && unique[k - 1] == k - 1) {
            return new ClassLabels(k, labels, classes);
        }

        int[] encoded = Arrays.stream(labels).map(classes::indexOf).toArray();
        return new ClassLabels(k, encoded, classes);
    }

    /**
     * Learns the class labels from a response vector. When the vector's field carries
     * a {@link NominalScale}, its integer codes are taken as class indices directly and
     * the number of classes is the scale size. Otherwise this falls back to
     * {@link #fit(int[])}.
     *
     * @param response the response vector.
     * @return the class label encoding.
     * @throws IllegalArgumentException if the labels are invalid for the chosen encoding.
     */
    public static ClassLabels fit(ValueVector response) {
        int[] labels = response.toIntArray();
        Object measure = response.field().measure();
        if (measure instanceof NominalScale) {
            int k = ((NominalScale) measure).size();
            return new ClassLabels(k, labels, new IntSet(IntStream.range(0, k).toArray()));
        }
        return fit(labels);
    }

    /**
     * Counts the samples of each class.
     *
     * @param y the class indices.
     * @param k the number of classes.
     * @return the per-class counts.
     * @throws IllegalArgumentException if any index is outside {@code [0, k)}.
     */
    private static int[] count(int[] y, int k) {
        int[] counts = new int[k];
        for (int label : y) {
            if (label < 0 || label >= k) {
                throw new IllegalArgumentException(
                        "Invalid class label " + label + ", expected a value in [0, " + k + ")");
            }
            counts[label]++;
        }
        return counts;
    }
}
