package smile.regression;

import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.ObjectMethods;
import java.util.Properties;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.data.type.StructType;
import smile.math.matrix.Matrix;
import smile.math.special.Beta;

/**
 * Ordinary least squares (OLS) linear regression.
 *
 * <p>The coefficients are obtained by minimizing the residual sum of squares
 * of an over-determined system, using either a QR decomposition (the default)
 * or a singular value decomposition of the design matrix. Optionally, the
 * inverse of X'X and the per-coefficient t-test table are also computed.
 */
public class OLS {
    private static final Logger logger = LoggerFactory.getLogger(OLS.class);

    /** The decomposition used to solve the least squares problem. */
    public enum Method {
        /** QR decomposition of the design matrix. */
        QR,
        /** Singular value decomposition of the design matrix. */
        SVD
    }

    /**
     * Ordinary least squares hyperparameters.
     *
     * @param method    the decomposition used to solve the problem. A value other
     *                  than {@link Method#SVD} makes {@code fit} use QR.
     * @param stderr    if true, compute the standard error, t-statistic and
     *                  p-value of every coefficient.
     * @param recursive if true, compute and keep the inverse of X'X in the model
     *                  even when {@code stderr} is false (presumably needed for
     *                  recursive/online updates of the model — confirm against
     *                  {@code LinearModel}).
     */
    public record Options(Method method, boolean stderr, boolean recursive) {
        /** Creates the default options: QR, with standard errors, recursive. */
        public Options() {
            this(Method.QR, true, true);
        }

        /**
         * Returns these options as persistent properties.
         *
         * @return the properties {@code smile.ols.method},
         *         {@code smile.ols.standard_error} and {@code smile.ols.recursive}.
         */
        public Properties toProperties() {
            Properties props = new Properties();
            props.setProperty("smile.ols.method", method.toString());
            props.setProperty("smile.ols.standard_error", Boolean.toString(stderr));
            props.setProperty("smile.ols.recursive", Boolean.toString(recursive));
            return props;
        }

        /**
         * Builds options from properties; missing keys fall back to the defaults.
         *
         * @param props the hyperparameters.
         * @return the options.
         * @throws IllegalArgumentException if {@code smile.ols.method} is not a
         *         {@link Method} constant name.
         */
        public static Options of(Properties props) {
            Method method = Method.valueOf(props.getProperty("smile.ols.method", "QR"));
            boolean stderr = Boolean.parseBoolean(props.getProperty("smile.ols.standard_error", "true"));
            boolean recursive = Boolean.parseBoolean(props.getProperty("smile.ols.recursive", "true"));
            return new Options(method, stderr, recursive);
        }
    }

    /** Private constructor to prevent instantiation of this utility class. */
    private OLS() {
    }

    /**
     * Fits an ordinary least squares model with the default options.
     *
     * @param formula a symbolic description of the model to be fitted.
     * @param data the data frame of the explanatory and response variables.
     * @return the model.
     */
    public static LinearModel fit(Formula formula, DataFrame data) {
        return fit(formula, data, new Options());
    }

    /**
     * Fits an ordinary least squares model.
     *
     * @param formula a symbolic description of the model to be fitted.
     * @param data the data frame of the explanatory and response variables.
     * @param options the hyperparameters.
     * @return the model.
     * @throws IllegalArgumentException if the design matrix does not have more
     *         rows than columns.
     */
    public static LinearModel fit(Formula formula, DataFrame data, Options options) {
        Formula expanded = formula.expand(data.schema());
        StructType schema = expanded.bind(data.schema());
        Matrix X = expanded.matrix(data);
        double[] y = expanded.y(data).toDoubleArray();

        int n = X.nrow();
        int p = X.ncol();
        if (n <= p) {
            throw new IllegalArgumentException(String.format(
                    "The input matrix is not over determined: %d rows, %d columns", n, p));
        }

        // qr stays null on the SVD path; the covariance step below then
        // factorizes X'X directly instead of reusing the QR factor.
        Matrix.QR qr = null;
        double[] w;
        if (options.method() == Method.SVD) {
            w = X.svd().solve(y);
        } else {
            try {
                qr = X.qr();
                w = qr.solve(y);
            } catch (RuntimeException e) {
                // NOTE(review): qr remains non-null here, so the covariance
                // below is derived from the QR factor whose solve just failed.
                logger.warn("Matrix is not of full rank, try SVD instead");
                w = X.svd().solve(y);
            }
        }

        LinearModel model = new LinearModel(expanded, schema, X, y, w, 0.0);

        // Inverse of X'X, via a Cholesky factor taken from QR when available.
        Matrix inv = null;
        if (options.stderr() || options.recursive()) {
            inv = (qr == null ? X.ata().cholesky(true) : qr.CholeskyOfAtA()).inverse();
            model.V = inv;
        }

        if (options.stderr()) {
            // Columns: estimate, standard error, t-statistic, p-value.
            double[][] ttest = new double[p][4];
            model.ttest = ttest;
            for (int i = 0; i < p; i++) {
                double se = model.error * Math.sqrt(inv.get(i, i));
                double t = w[i] / se;
                ttest[i][0] = w[i];
                ttest[i][1] = se;
                ttest[i][2] = t;
                // Two-sided p-value of Student's t with df degrees of freedom:
                // P(|T| > |t|) = I_{df / (df + t^2)}(df / 2, 1 / 2).
                ttest[i][3] = Beta.regularizedIncompleteBetaFunction(
                        0.5 * model.df, 0.5, model.df / (model.df + t * t));
            }
        }

        return model;
    }
}
