package smile.classification;

import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.ObjectMethods;
import java.util.Arrays;
import java.util.Locale;
import java.util.Properties;
import smile.base.svm.KernelMachine;
import smile.base.svm.LASVM;
import smile.base.svm.LinearKernelMachine;
import smile.math.MathEx;
import smile.math.kernel.BinarySparseLinearKernel;
import smile.math.kernel.LinearKernel;
import smile.math.kernel.MercerKernel;
import smile.math.kernel.SparseLinearKernel;
import smile.util.IntSet;
import smile.util.SparseArray;

/* loaded from: input_file:smile/classification/SVM.class */
public class SVM<T> extends KernelMachine<T> implements Classifier<T> {

    /* loaded from: input_file:smile/classification/SVM$Options.class */
    /**
     * Immutable bundle of SVM training hyperparameters.
     *
     * <p>The values are validated on construction and can be round-tripped
     * through {@link Properties} via {@link #toProperties()} and
     * {@link #of(Properties)}.
     *
     * @param C      the {@code C} value handed to the trainer (stored under
     *               {@code smile.svm.C}); must not be negative.
     * @param tol    the tolerance handed to the trainer (stored under
     *               {@code smile.svm.tolerance}); must be strictly positive.
     * @param epochs the number of training epochs (stored under
     *               {@code smile.svm.epochs}); must be at least 1.
     */
    public record Options(double C, double tol, int epochs) {
        /** Property key for {@link #C()}. */
        private static final String KEY_C = "smile.svm.C";
        /** Property key for {@link #tol()}. */
        private static final String KEY_TOLERANCE = "smile.svm.tolerance";
        /** Property key for {@link #epochs()}. */
        private static final String KEY_EPOCHS = "smile.svm.epochs";

        /**
         * Validates the hyperparameters.
         *
         * @throws IllegalArgumentException if {@code C < 0}, {@code tol <= 0}
         *         or {@code epochs < 1}.
         */
        public Options {
            if (C < 0.0) {
                // The message previously (and wrongly) talked about a
                // "maximum number of iterations".
                throw new IllegalArgumentException("Invalid soft margin penalty C: " + C);
            }
            if (tol <= 0.0) {
                throw new IllegalArgumentException("Invalid tolerance: " + tol);
            }
            if (epochs < 1) {
                throw new IllegalArgumentException("Invalid epochs: " + epochs);
            }
        }

        /**
         * Creates options with the given {@code C}, a tolerance of
         * {@code 0.001} and a single epoch.
         *
         * @param C the {@code C} value; must not be negative.
         */
        public Options(double C) {
            this(C, 0.001, 1);
        }

        /**
         * Serializes the options.
         *
         * @return a new {@code Properties} holding {@code smile.svm.C},
         *         {@code smile.svm.tolerance} and {@code smile.svm.epochs}.
         */
        public Properties toProperties() {
            Properties properties = new Properties();
            properties.setProperty(KEY_C, Double.toString(C));
            properties.setProperty(KEY_TOLERANCE, Double.toString(tol));
            properties.setProperty(KEY_EPOCHS, Integer.toString(epochs));
            return properties;
        }

        /**
         * Reads the options from properties, falling back to {@code C = 1.0},
         * {@code tolerance = 1E-3} and {@code epochs = 1} for missing keys.
         *
         * @param properties the source properties.
         * @return the parsed options.
         * @throws NumberFormatException if a present value is not a valid number.
         * @throws IllegalArgumentException if a parsed value fails validation.
         */
        public static Options of(Properties properties) {
            double c = Double.parseDouble(properties.getProperty(KEY_C, "1.0"));
            double tolerance = Double.parseDouble(properties.getProperty(KEY_TOLERANCE, "1E-3"));
            int epochs = Integer.parseInt(properties.getProperty(KEY_EPOCHS, "1"));
            return new Options(c, tolerance, epochs);
        }
    }

    /**
     * Creates an SVM by forwarding every argument unchanged to the
     * {@code KernelMachine} super-constructor.
     *
     * <p>The generic {@code fit} factory in this class calls this constructor
     * with {@code kernel()}, {@code vectors()}, {@code weights()} and
     * {@code intercept()} of a trained kernel machine, in that order.
     *
     * @param mercerKernel the kernel function.
     * @param tArr the instances paired with {@code dArr} (the factory passes
     *             {@code KernelMachine.vectors()}).
     * @param dArr the weights (the factory passes {@code KernelMachine.weights()}).
     * @param d the intercept (the factory passes {@code KernelMachine.intercept()}).
     */
    public SVM(MercerKernel<T> mercerKernel, T[] tArr, double[] dArr, double d) {
        super(mercerKernel, tArr, dArr, d);
    }

    /**
     * Returns the number of classes, which is always 2 for this binary model.
     *
     * @return the constant {@code 2}.
     */
    @Override // smile.classification.Classifier
    public int numClasses() {
        return 2;
    }

    /**
     * Returns the class labels this model can predict.
     *
     * @return a freshly allocated array {@code {-1, 1}}; callers may modify
     *         it without affecting the model.
     */
    @Override // smile.classification.Classifier
    public int[] classes() {
        return new int[]{-1, 1};
    }

    /**
     * Classifies an instance by the sign of its decision value.
     *
     * @param t the instance to classify.
     * @return {@code 1} when {@code score(t)} is strictly positive,
     *         {@code -1} otherwise (including a score of exactly zero).
     */
    @Override // smile.classification.Classifier
    public int predict(T t) {
        final double decision = score(t);
        if (decision > 0.0d) {
            return 1;
        }
        return -1;
    }

    /**
     * Trains a binary linear SVM on dense data.
     *
     * <p>A {@link LASVM} trainer with a {@link LinearKernel} is run for
     * {@code options.epochs()} epochs; the resulting kernel machine is then
     * converted with {@link LinearKernelMachine#of} and wrapped in a classifier
     * over the labels {@code {-1, 1}}.
     *
     * @param dArr the training instances.
     * @param iArr the training labels, handed to the trainer unchanged
     *             (the returned classifier predicts {@code -1} or {@code 1}).
     * @param options the training hyperparameters.
     * @return a classifier that answers {@code 1} when the linear model's
     *         decision value is strictly positive and {@code -1} otherwise.
     */
    public static Classifier<double[]> fit(double[][] dArr, int[] iArr, Options options) {
        // Typed explicitly: the class type parameter T is not in scope in a static method.
        LASVM<double[]> trainer = new LASVM<>(new LinearKernel(), options.C(), options.tol());
        KernelMachine<double[]> machine = trainer.fit(dArr, iArr, options.epochs());

        return new AbstractClassifier<double[]>(new IntSet(new int[]{-1, 1})) {
            /** Linear form of the trained kernel machine. */
            final LinearKernelMachine model = LinearKernelMachine.of(machine);

            @Override // smile.classification.Classifier
            public int predict(double[] x) {
                return model.f(x) > 0.0 ? 1 : -1;
            }
        };
    }

    /**
     * Trains a binary linear SVM on binary sparse data, where each instance is
     * given as an {@code int[]}.
     *
     * <p>A {@link LASVM} trainer with a {@link BinarySparseLinearKernel} is run
     * for {@code options.epochs()} epochs; the resulting kernel machine is then
     * converted with {@link LinearKernelMachine#binary} and wrapped in a
     * classifier over the labels {@code {-1, 1}}.
     *
     * @param iArr the training instances.
     * @param iArr2 the training labels, handed to the trainer unchanged
     *              (the returned classifier predicts {@code -1} or {@code 1}).
     * @param i forwarded as the first argument of
     *          {@code LinearKernelMachine.binary}; presumably the dimension of
     *          the feature space (TODO confirm against that method's contract).
     * @param options the training hyperparameters.
     * @return a classifier that answers {@code 1} when the linear model's
     *         decision value is strictly positive and {@code -1} otherwise.
     */
    public static Classifier<int[]> fit(int[][] iArr, int[] iArr2, final int i, Options options) {
        // Typed explicitly: the class type parameter T is not in scope in a static method.
        LASVM<int[]> trainer = new LASVM<>(new BinarySparseLinearKernel(), options.C(), options.tol());
        KernelMachine<int[]> machine = trainer.fit(iArr, iArr2, options.epochs());

        return new AbstractClassifier<int[]>(new IntSet(new int[]{-1, 1})) {
            /** Linear form of the trained kernel machine. */
            final LinearKernelMachine model = LinearKernelMachine.binary(i, machine);

            @Override // smile.classification.Classifier
            public int predict(int[] x) {
                return model.f(x) > 0.0 ? 1 : -1;
            }
        };
    }

    /**
     * Trains a binary linear SVM on sparse data.
     *
     * <p>A {@link LASVM} trainer with a {@link SparseLinearKernel} is run for
     * {@code options.epochs()} epochs; the resulting kernel machine is then
     * converted with {@link LinearKernelMachine#sparse} and wrapped in a
     * classifier over the labels {@code {-1, 1}}.
     *
     * @param sparseArrayArr the training instances.
     * @param iArr the training labels, handed to the trainer unchanged
     *             (the returned classifier predicts {@code -1} or {@code 1}).
     * @param i forwarded as the first argument of
     *          {@code LinearKernelMachine.sparse}; presumably the dimension of
     *          the feature space (TODO confirm against that method's contract).
     * @param options the training hyperparameters.
     * @return a classifier that answers {@code 1} when the linear model's
     *         decision value is strictly positive and {@code -1} otherwise.
     */
    public static Classifier<SparseArray> fit(SparseArray[] sparseArrayArr, int[] iArr, final int i, Options options) {
        // Typed explicitly: the class type parameter T is not in scope in a static method.
        LASVM<SparseArray> trainer = new LASVM<>(new SparseLinearKernel(), options.C(), options.tol());
        KernelMachine<SparseArray> machine = trainer.fit(sparseArrayArr, iArr, options.epochs());

        return new AbstractClassifier<SparseArray>(new IntSet(new int[]{-1, 1})) {
            /** Linear form of the trained kernel machine. */
            final LinearKernelMachine model = LinearKernelMachine.sparse(i, machine);

            @Override // smile.classification.Classifier
            public int predict(SparseArray x) {
                return model.f(x) > 0.0 ? 1 : -1;
            }
        };
    }

    /**
     * Trains a binary SVM with an arbitrary Mercer kernel.
     *
     * <p>A {@link LASVM} trainer is run for {@code options.epochs()} epochs and
     * the kernel, vectors, weights and intercept of the resulting kernel
     * machine are copied into a new {@code SVM}.
     *
     * @param tArr the training instances.
     * @param iArr the training labels, handed to the trainer unchanged
     *             (the returned model predicts {@code -1} or {@code 1}).
     * @param mercerKernel the kernel function.
     * @param options the training hyperparameters.
     * @param <T> the type of the instances.
     * @return the trained model.
     */
    public static <T> SVM<T> fit(T[] tArr, int[] iArr, MercerKernel<T> mercerKernel, Options options) {
        // Diamond instead of the raw LASVM type, so the result is a checked KernelMachine<T>.
        LASVM<T> trainer = new LASVM<>(mercerKernel, options.C(), options.tol());
        KernelMachine<T> machine = trainer.fit(tArr, iArr, options.epochs());
        return new SVM<>(machine.kernel(), machine.vectors(), machine.weights(), machine.intercept());
    }

    /**
     * Trains an SVM classifier configured entirely through properties.
     *
     * <p>Recognized keys, in addition to those read by
     * {@link Options#of(Properties)}:
     * <ul>
     *   <li>{@code smile.svm.kernel}: kernel specification parsed by
     *       {@link MercerKernel#of}; defaults to {@code "linear"}.</li>
     *   <li>{@code smile.svm.type}: one of {@code "binary"}, {@code "ovr"} or
     *       {@code "ovo"} (case-insensitive). Defaults to {@code "binary"}
     *       when the labels take exactly two distinct values and to
     *       {@code "ovr"} otherwise.</li>
     * </ul>
     *
     * <p>Whenever the parsed kernel is a {@link LinearKernel}, the dedicated
     * linear trainer is used; otherwise the generic kernel trainer is used.
     *
     * <p>In {@code "binary"} mode the labels are remapped on a copy (the
     * caller's array is never modified) unless the two smallest distinct
     * labels are already {@code -1} and {@code 1}: the smallest label becomes
     * {@code -1} and every other label becomes {@code 1}. The returned
     * classifier predicts in that {@code -1}/{@code 1} space.
     *
     * @param dArr the training instances.
     * @param iArr the training labels.
     * @param properties the training configuration.
     * @return the trained classifier.
     * @throws IllegalArgumentException if {@code smile.svm.type} is not one of
     *         the supported values.
     */
    public static Classifier<double[]> fit(double[][] dArr, int[] iArr, Properties properties) {
        MercerKernel<double[]> kernel = MercerKernel.of(properties.getProperty("smile.svm.kernel", "linear"));
        Options options = Options.of(properties);
        boolean linear = kernel instanceof LinearKernel;

        int[] unique = MathEx.unique(iArr);
        String defaultType = unique.length == 2 ? "binary" : "ovr";
        // Locale.ROOT keeps the match independent of the default locale
        // (e.g. Turkish lower-cases "BINARY" to "b\u0131nary").
        String type = properties.getProperty("smile.svm.type", defaultType).toLowerCase(Locale.ROOT);

        switch (type) {
            case "ovr":
                if (linear) {
                    return OneVersusRest.fit(dArr, iArr, (x, y) -> SVM.fit(x, y, options));
                }
                return OneVersusRest.fit(dArr, iArr, (x, y) -> SVM.fit(x, y, kernel, options));

            case "ovo":
                if (linear) {
                    return OneVersusOne.fit(dArr, iArr, (x, y) -> SVM.fit(x, y, options));
                }
                return OneVersusOne.fit(dArr, iArr, (x, y) -> SVM.fit(x, y, kernel, options));

            case "binary": {
                // NOTE(review): this indexes unique[0] and unique[1], so it
                // assumes at least two distinct labels are present.
                Arrays.sort(unique);
                int[] labels = iArr;
                if (unique[0] != -1 || unique[1] != 1) {
                    labels = iArr.clone();
                    for (int k = 0; k < labels.length; k++) {
                        labels[k] = labels[k] == unique[0] ? -1 : 1;
                    }
                }
                if (linear) {
                    return fit(dArr, labels, options);
                }
                return fit(dArr, labels, kernel, options);
            }

            default:
                throw new IllegalArgumentException("Unknown SVM type: " + type);
        }
    }
}
