package smile.validation;

import java.util.Arrays;
import java.util.NoSuchElementException;
import java.util.function.BiFunction;
import java.util.stream.IntStream;
import org.slf4j.LoggerFactory;
import smile.classification.Classifier;
import smile.classification.DataFrameClassifier;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.math.MathEx;
import smile.regression.DataFrameRegression;
import smile.regression.Regression;
import smile.sort.QuickSort;
import smile.stat.Sampling;
import smile.util.IntSet;

/* loaded from: input_file:smile/validation/CrossValidation.class */
/**
 * Cross-validation helpers.
 *
 * <p>The first group of methods ({@link #of(int, int)},
 * {@link #stratify(int[], int)}, {@link #nonoverlap(int[], int)}) builds the
 * per-fold index splits as {@code Bag} objects. In every bag the first
 * constructor argument holds the sample indices outside the fold and the
 * second holds the indices of the held-out fold itself.
 *
 * <p>The remaining methods build such bags and hand them, together with the
 * data and a model-building function, to {@code ClassificationValidation.of}
 * or {@code RegressionValidation.of}.
 */
public interface CrossValidation {
    /**
     * Splits the indices {@code 0 .. i-1} into {@code i2} folds.
     *
     * <p>The indices are taken in the order returned by
     * {@code MathEx.permutate(i)} (presumably a random permutation) and cut
     * into consecutive chunks of {@code i / i2} elements; the last fold also
     * receives the remainder.
     *
     * @param i  the number of samples.
     * @param i2 the number of folds; must be in {@code [1, i]}.
     * @return one bag per fold.
     * @throws IllegalArgumentException if {@code i} is negative or
     *         {@code i2} is outside {@code [1, i]}.
     */
    static Bag[] of(int i, int i2) {
        final int n = i;
        final int k = i2;
        if (n < 0) {
            throw new IllegalArgumentException("Invalid sample size: " + n);
        }
        // k == 0 must be rejected here: it would otherwise reach the
        // division n / k below and fail with an ArithmeticException.
        if (k < 1 || k > n) {
            throw new IllegalArgumentException("Invalid number of CV rounds: " + k);
        }

        int[] permutation = MathEx.permutate(n);
        int chunk = n / k;
        Bag[] bags = new Bag[k];
        for (int fold = 0; fold < k; fold++) {
            int start = chunk * fold;
            // The last fold absorbs the remainder of n / k.
            int end = (fold == k - 1) ? n : chunk * (fold + 1);

            // Held-out fold: permutation[start, end).
            int[] test = Arrays.copyOfRange(permutation, start, end);
            // Everything else: permutation[0, start) followed by permutation[end, n).
            int[] train = new int[n - test.length];
            System.arraycopy(permutation, 0, train, 0, start);
            System.arraycopy(permutation, end, train, start, n - end);

            bags[fold] = new Bag(train, test);
        }
        return bags;
    }

    /**
     * Builds {@code i} folds in which every stratum returned by
     * {@code Sampling.strata(iArr)} is split separately, so each fold draws
     * from every stratum.
     *
     * <p>Each stratum is shuffled with {@code MathEx.permutate} and cut into
     * chunks of {@code max(1, size / i)} elements; the last fold takes
     * whatever is left of the stratum. A warning is logged when the smallest
     * stratum has fewer than {@code i} members, in which case some folds
     * receive no element from it.
     *
     * @param iArr the label of each sample (the strata are derived from it).
     * @param i    the number of folds; must be at least 1.
     * @return one bag per fold, with both index arrays shuffled.
     * @throws IllegalArgumentException if {@code i < 1}.
     * @throws NoSuchElementException if {@code Sampling.strata} yields no stratum.
     */
    static Bag[] stratify(int[] iArr, int i) {
        final int[] y = iArr;
        final int k = i;
        // k == 0 must be rejected here: the chunk size below divides by k.
        if (k < 1) {
            throw new IllegalArgumentException("Invalid number of folds: " + k);
        }

        int[][] strata = Sampling.strata(y);
        int smallest = Arrays.stream(strata)
                .mapToInt(stratum -> stratum.length)
                .min()
                .orElseThrow(NoSuchElementException::new);
        if (smallest < k) {
            LoggerFactory.getLogger(CrossValidation.class).warn("The least populated class has only {} members, which is less than k={}.", smallest, k);
        }

        int n = y.length;
        int m = strata.length;
        for (int[] stratum : strata) {
            MathEx.permutate(stratum);
        }

        // Per-stratum chunk size; at least 1 so that small strata still
        // contribute to the first folds.
        int[] chunk = new int[m];
        for (int c = 0; c < m; c++) {
            chunk[c] = Math.max(1, strata[c].length / k);
        }

        Bag[] bags = new Bag[k];
        for (int fold = 0; fold < k; fold++) {
            int trainSize = 0;
            int testSize = 0;
            // Over-allocated scratch buffers, trimmed once the fold is complete.
            int[] train = new int[n];
            int[] test = new int[n];

            for (int c = 0; c < m; c++) {
                int[] stratum = strata[c];
                int size = stratum.length;
                int start = chunk[c] * fold;
                // The last fold absorbs the remainder of the stratum.
                int end = (fold == k - 1) ? size : chunk[c] * (fold + 1);

                for (int j = 0; j < size; j++) {
                    if (j >= start && j < end) {
                        test[testSize++] = stratum[j];
                    } else {
                        train[trainSize++] = stratum[j];
                    }
                }
            }

            int[] trainIndex = Arrays.copyOf(train, trainSize);
            int[] testIndex = Arrays.copyOf(test, testSize);
            // Shuffle so samples are not grouped by stratum inside a bag.
            MathEx.permutate(trainIndex);
            MathEx.permutate(testIndex);
            bags[fold] = new Bag(trainIndex, testIndex);
        }
        return bags;
    }

    /**
     * Builds {@code i} folds such that all samples sharing a group id land in
     * the same fold.
     *
     * <p>Groups are assigned greedily, largest group first, to the fold that
     * currently holds the fewest samples.
     *
     * @param iArr the group id of each sample.
     * @param i    the number of folds; must be in
     *             {@code [1, number of distinct groups]}.
     * @return one bag per fold; indices appear in ascending sample order.
     * @throws IllegalArgumentException if {@code i < 1} or {@code i} exceeds
     *         the number of distinct group ids.
     */
    static Bag[] nonoverlap(int[] iArr, int i) {
        final int[] group = iArr;
        final int k = i;
        // k == 0 must be rejected here: with k >= 1 the check below also
        // guarantees at least one group, which the code afterwards relies on.
        if (k < 1) {
            throw new IllegalArgumentException("Invalid number of folds: " + k);
        }

        int[] unique = MathEx.unique(group);
        int m = unique.length;
        if (k > m) {
            throw new IllegalArgumentException("k-fold must be not greater than the than number of groups");
        }

        Arrays.sort(unique);
        int n = group.length;

        // Re-encode group ids to 0 .. m-1 unless they already have that form.
        int[] y = group;
        if (unique[0] != 0 || unique[m - 1] != m - 1) {
            IntSet encoder = new IntSet(unique);
            y = new int[n];
            for (int j = 0; j < n; j++) {
                y[j] = encoder.indexOf(group[j]);
            }
        }

        // Number of samples per (encoded) group.
        int[] count = new int[m];
        for (int g : y) {
            count[g]++;
        }

        // NOTE(review): relies on QuickSort.sort ordering count in place
        // (ascending) and returning, per position, the original group index.
        int[] order = QuickSort.sort(count);

        int[] size = new int[k];   // samples assigned to each fold so far
        int[] foldOf = new int[m]; // fold assigned to each group
        for (int j = m - 1; j >= 0; j--) {
            int smallest = MathEx.whichMin(size);
            size[smallest] += count[j];
            foldOf[order[j]] = smallest;
        }

        Bag[] bags = new Bag[k];
        for (int fold = 0; fold < k; fold++) {
            int[] train = new int[n - size[fold]];
            int[] test = new int[size[fold]];
            int trainSize = 0;
            int testSize = 0;
            for (int j = 0; j < n; j++) {
                if (foldOf[y[j]] == fold) {
                    test[testSize++] = j;
                } else {
                    train[trainSize++] = j;
                }
            }
            bags[fold] = new Bag(train, test);
        }
        return bags;
    }

    /**
     * Runs {@code ClassificationValidation.of} on the folds produced by
     * {@link #of(int, int) of(tArr.length, i)}.
     *
     * @param i          the number of folds.
     * @param tArr       the samples.
     * @param iArr       the label of each sample.
     * @param biFunction builds a model from a sample array and a label array.
     * @param <T>        the sample type.
     * @param <M>        the model type.
     * @return the validation results.
     */
    static <T, M extends Classifier<T>> ClassificationValidations<M> classification(int i, T[] tArr, int[] iArr, BiFunction<T[], int[], M> biFunction) {
        Bag[] bags = of(tArr.length, i);
        return ClassificationValidation.of(bags, tArr, iArr, biFunction);
    }

    /**
     * Runs {@code ClassificationValidation.of} on the folds produced by
     * {@link #of(int, int) of(dataFrame.size(), i)}.
     *
     * @param i          the number of folds.
     * @param formula    the model formula.
     * @param dataFrame  the data.
     * @param biFunction builds a model from a formula and a data frame.
     * @param <M>        the model type.
     * @return the validation results.
     */
    static <M extends DataFrameClassifier> ClassificationValidations<M> classification(int i, Formula formula, DataFrame dataFrame, BiFunction<Formula, DataFrame, M> biFunction) {
        Bag[] bags = of(dataFrame.size(), i);
        return ClassificationValidation.of(bags, formula, dataFrame, biFunction);
    }

    /**
     * Repeated cross validation: calls {@link #of(int, int) of(tArr.length, i2)}
     * {@code i} times, concatenates all resulting bags and passes them to
     * {@code ClassificationValidation.of}.
     *
     * @param i          the number of repetitions; must be at least 1.
     * @param i2         the number of folds per repetition.
     * @param tArr       the samples.
     * @param iArr       the label of each sample.
     * @param biFunction builds a model from a sample array and a label array.
     * @param <T>        the sample type.
     * @param <M>        the model type.
     * @return the validation results.
     * @throws IllegalArgumentException if {@code i < 1}.
     */
    static <T, M extends Classifier<T>> ClassificationValidations<M> classification(int i, int i2, T[] tArr, int[] iArr, BiFunction<T[], int[], M> biFunction) {
        if (i < 1) {
            throw new IllegalArgumentException("Invalid round: " + i);
        }
        Bag[] bags = IntStream.range(0, i)
                .mapToObj(round -> of(tArr.length, i2))
                .flatMap(Arrays::stream)
                .toArray(Bag[]::new);
        return ClassificationValidation.of(bags, tArr, iArr, biFunction);
    }

    /**
     * Repeated cross validation: calls
     * {@link #of(int, int) of(dataFrame.size(), i2)} {@code i} times,
     * concatenates all resulting bags and passes them to
     * {@code ClassificationValidation.of}.
     *
     * @param i          the number of repetitions; must be at least 1.
     * @param i2         the number of folds per repetition.
     * @param formula    the model formula.
     * @param dataFrame  the data.
     * @param biFunction builds a model from a formula and a data frame.
     * @param <M>        the model type.
     * @return the validation results.
     * @throws IllegalArgumentException if {@code i < 1}.
     */
    static <M extends DataFrameClassifier> ClassificationValidations<M> classification(int i, int i2, Formula formula, DataFrame dataFrame, BiFunction<Formula, DataFrame, M> biFunction) {
        if (i < 1) {
            throw new IllegalArgumentException("Invalid round: " + i);
        }
        Bag[] bags = IntStream.range(0, i)
                .mapToObj(round -> of(dataFrame.size(), i2))
                .flatMap(Arrays::stream)
                .toArray(Bag[]::new);
        return ClassificationValidation.of(bags, formula, dataFrame, biFunction);
    }

    /**
     * Runs {@code ClassificationValidation.of} on the stratified folds
     * produced by {@link #stratify(int[], int) stratify(iArr, i)}.
     *
     * @param i          the number of folds.
     * @param tArr       the samples.
     * @param iArr       the label of each sample; also used for stratification.
     * @param biFunction builds a model from a sample array and a label array.
     * @param <T>        the sample type.
     * @param <M>        the model type.
     * @return the validation results.
     */
    static <T, M extends Classifier<T>> ClassificationValidations<M> stratify(int i, T[] tArr, int[] iArr, BiFunction<T[], int[], M> biFunction) {
        Bag[] bags = stratify(iArr, i);
        return ClassificationValidation.of(bags, tArr, iArr, biFunction);
    }

    /**
     * Runs {@code ClassificationValidation.of} on stratified folds; the
     * labels used for stratification are
     * {@code formula.y(dataFrame).toIntArray()}.
     *
     * @param i          the number of folds.
     * @param formula    the model formula.
     * @param dataFrame  the data.
     * @param biFunction builds a model from a formula and a data frame.
     * @param <M>        the model type.
     * @return the validation results.
     */
    static <M extends DataFrameClassifier> ClassificationValidations<M> stratify(int i, Formula formula, DataFrame dataFrame, BiFunction<Formula, DataFrame, M> biFunction) {
        int[] y = formula.y(dataFrame).toIntArray();
        Bag[] bags = stratify(y, i);
        return ClassificationValidation.of(bags, formula, dataFrame, biFunction);
    }

    /**
     * Repeated stratified cross validation: calls
     * {@link #stratify(int[], int) stratify(iArr, i2)} {@code i} times,
     * concatenates all resulting bags and passes them to
     * {@code ClassificationValidation.of}.
     *
     * @param i          the number of repetitions; must be at least 1.
     * @param i2         the number of folds per repetition.
     * @param tArr       the samples.
     * @param iArr       the label of each sample; also used for stratification.
     * @param biFunction builds a model from a sample array and a label array.
     * @param <T>        the sample type.
     * @param <M>        the model type.
     * @return the validation results.
     * @throws IllegalArgumentException if {@code i < 1}.
     */
    static <T, M extends Classifier<T>> ClassificationValidations<M> stratify(int i, int i2, T[] tArr, int[] iArr, BiFunction<T[], int[], M> biFunction) {
        if (i < 1) {
            throw new IllegalArgumentException("Invalid round: " + i);
        }
        Bag[] bags = IntStream.range(0, i)
                .mapToObj(round -> stratify(iArr, i2))
                .flatMap(Arrays::stream)
                .toArray(Bag[]::new);
        return ClassificationValidation.of(bags, tArr, iArr, biFunction);
    }

    /**
     * Repeated stratified cross validation on a data frame; the labels used
     * for stratification are {@code formula.y(dataFrame).toIntArray()},
     * extracted once and reused for every repetition.
     *
     * @param i          the number of repetitions; must be at least 1.
     * @param i2         the number of folds per repetition.
     * @param formula    the model formula.
     * @param dataFrame  the data.
     * @param biFunction builds a model from a formula and a data frame.
     * @param <M>        the model type.
     * @return the validation results.
     * @throws IllegalArgumentException if {@code i < 1}.
     */
    static <M extends DataFrameClassifier> ClassificationValidations<M> stratify(int i, int i2, Formula formula, DataFrame dataFrame, BiFunction<Formula, DataFrame, M> biFunction) {
        if (i < 1) {
            throw new IllegalArgumentException("Invalid round: " + i);
        }
        int[] y = formula.y(dataFrame).toIntArray();
        Bag[] bags = IntStream.range(0, i)
                .mapToObj(round -> stratify(y, i2))
                .flatMap(Arrays::stream)
                .toArray(Bag[]::new);
        return ClassificationValidation.of(bags, formula, dataFrame, biFunction);
    }

    /**
     * Runs {@code RegressionValidation.of} on the folds produced by
     * {@link #of(int, int) of(tArr.length, i)}.
     *
     * @param i          the number of folds.
     * @param tArr       the samples.
     * @param dArr       the response value of each sample.
     * @param biFunction builds a model from a sample array and a response array.
     * @param <T>        the sample type.
     * @param <M>        the model type.
     * @return the validation results.
     */
    static <T, M extends Regression<T>> RegressionValidations<M> regression(int i, T[] tArr, double[] dArr, BiFunction<T[], double[], M> biFunction) {
        Bag[] bags = of(tArr.length, i);
        return RegressionValidation.of(bags, tArr, dArr, biFunction);
    }

    /**
     * Runs {@code RegressionValidation.of} on the folds produced by
     * {@link #of(int, int) of(dataFrame.size(), i)}.
     *
     * @param i          the number of folds.
     * @param formula    the model formula.
     * @param dataFrame  the data.
     * @param biFunction builds a model from a formula and a data frame.
     * @param <M>        the model type.
     * @return the validation results.
     */
    static <M extends DataFrameRegression> RegressionValidations<M> regression(int i, Formula formula, DataFrame dataFrame, BiFunction<Formula, DataFrame, M> biFunction) {
        Bag[] bags = of(dataFrame.size(), i);
        return RegressionValidation.of(bags, formula, dataFrame, biFunction);
    }

    /**
     * Repeated cross validation: calls {@link #of(int, int) of(tArr.length, i2)}
     * {@code i} times, concatenates all resulting bags and passes them to
     * {@code RegressionValidation.of}.
     *
     * @param i          the number of repetitions; must be at least 1.
     * @param i2         the number of folds per repetition.
     * @param tArr       the samples.
     * @param dArr       the response value of each sample.
     * @param biFunction builds a model from a sample array and a response array.
     * @param <T>        the sample type.
     * @param <M>        the model type.
     * @return the validation results.
     * @throws IllegalArgumentException if {@code i < 1}.
     */
    static <T, M extends Regression<T>> RegressionValidations<M> regression(int i, int i2, T[] tArr, double[] dArr, BiFunction<T[], double[], M> biFunction) {
        if (i < 1) {
            throw new IllegalArgumentException("Invalid round: " + i);
        }
        Bag[] bags = IntStream.range(0, i)
                .mapToObj(round -> of(tArr.length, i2))
                .flatMap(Arrays::stream)
                .toArray(Bag[]::new);
        return RegressionValidation.of(bags, tArr, dArr, biFunction);
    }

    /**
     * Repeated cross validation: calls
     * {@link #of(int, int) of(dataFrame.size(), i2)} {@code i} times,
     * concatenates all resulting bags and passes them to
     * {@code RegressionValidation.of}.
     *
     * @param i          the number of repetitions; must be at least 1.
     * @param i2         the number of folds per repetition.
     * @param formula    the model formula.
     * @param dataFrame  the data.
     * @param biFunction builds a model from a formula and a data frame.
     * @param <M>        the model type.
     * @return the validation results.
     * @throws IllegalArgumentException if {@code i < 1}.
     */
    static <M extends DataFrameRegression> RegressionValidations<M> regression(int i, int i2, Formula formula, DataFrame dataFrame, BiFunction<Formula, DataFrame, M> biFunction) {
        if (i < 1) {
            throw new IllegalArgumentException("Invalid round: " + i);
        }
        Bag[] bags = IntStream.range(0, i)
                .mapToObj(round -> of(dataFrame.size(), i2))
                .flatMap(Arrays::stream)
                .toArray(Bag[]::new);
        return RegressionValidation.of(bags, formula, dataFrame, biFunction);
    }
}
