package smile.regression;

import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.ObjectMethods;
import java.util.Arrays;
import java.util.Properties;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.data.type.StructType;
import smile.math.MathEx;
import smile.math.blas.Transpose;
import smile.math.matrix.IMatrix;
import smile.math.matrix.Matrix;

/**
 * Lasso (least absolute shrinkage and selection operator) regression.
 * <p>
 * Coefficients minimize the L1-penalized residual sum of squares
 * {@code ||Y - Xw||^2 + lambda * ||w||_1} on a centered response {@code Y}.
 * The problem is solved with a log-barrier interior-point method. Each outer
 * iteration does the following:
 * <ul>
 * <li>computes a Newton direction with preconditioned conjugate gradients (see {@link PCG});</li>
 * <li>takes a backtracking line search step;</li>
 * <li>stops once the relative duality gap falls below {@link Options#tol()}.</li>
 * </ul>
 * The structure appears to follow the truncated-Newton "l1_ls" method of Kim,
 * Koh, Lustig, Boyd and Gorinevsky.
 */
public class LASSO {
    private static final Logger logger = LoggerFactory.getLogger(LASSO.class);

    /**
     * LASSO hyperparameters.
     *
     * @param lambda the shrinkage/regularization parameter; must be non-negative.
     * @param tol the stopping tolerance on the relative duality gap ({@code gap / dual objective}).
     * @param maxIter the maximum number of interior-point (outer) iterations.
     * @param alpha the sufficient-decrease fraction used by the backtracking line search.
     * @param beta the factor by which the line search shrinks its step size.
     * @param eta the factor that scales the duality gap into the relative PCG tolerance.
     * @param lsMaxIter the maximum number of line search iterations per outer iteration.
     * @param pcgMaxIter the maximum number of PCG iterations per Newton step.
     */
    public record Options(double lambda, double tol, int maxIter, double alpha, double beta,
                          double eta, int lsMaxIter, int pcgMaxIter) {
        /**
         * Validates the hyperparameters.
         * <p>
         * The floating-point checks are written as negated comparisons so that NaN
         * (for which every comparison is false) is rejected as well.
         *
         * @throws IllegalArgumentException if any hyperparameter is out of range.
         */
        public Options {
            if (!(lambda >= 0.0)) {
                throw new IllegalArgumentException("Invalid shrinkage/regularization parameter lambda = " + lambda);
            }
            if (!(tol > 0.0)) {
                throw new IllegalArgumentException("Invalid tolerance: " + tol);
            }
            if (maxIter <= 0) {
                throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
            }
            if (!(alpha > 0.0)) {
                throw new IllegalArgumentException("Invalid alpha: " + alpha);
            }
            if (!(beta > 0.0)) {
                throw new IllegalArgumentException("Invalid beta: " + beta);
            }
            if (!(eta > 0.0)) {
                throw new IllegalArgumentException("Invalid eta: " + eta);
            }
            if (lsMaxIter <= 0) {
                throw new IllegalArgumentException("Invalid maximum number of line search iterations: " + lsMaxIter);
            }
            if (pcgMaxIter <= 0) {
                throw new IllegalArgumentException("Invalid maximum number of PCG iterations: " + pcgMaxIter);
            }
        }

        /**
         * Creates options with tolerance 1E-4, at most 1000 iterations, and the
         * default line search / PCG settings.
         *
         * @param lambda the shrinkage/regularization parameter.
         */
        public Options(double lambda) {
            this(lambda, 1E-4, 1000);
        }

        /**
         * Creates options with the default line search and PCG settings
         * (alpha = 0.01, beta = 0.5, eta = 1E-3, 100 line search iterations,
         * 5000 PCG iterations).
         *
         * @param lambda the shrinkage/regularization parameter.
         * @param tol the stopping tolerance on the relative duality gap.
         * @param maxIter the maximum number of interior-point iterations.
         */
        public Options(double lambda, double tol, int maxIter) {
            this(lambda, tol, maxIter, 0.01, 0.5, 1E-3, 100, 5000);
        }

        /**
         * Returns the persistent set of hyperparameters.
         *
         * @return the hyperparameters keyed by {@code smile.lasso.*} property names.
         */
        public Properties toProperties() {
            Properties props = new Properties();
            props.setProperty("smile.lasso.lambda", Double.toString(lambda));
            props.setProperty("smile.lasso.tolerance", Double.toString(tol));
            props.setProperty("smile.lasso.iterations", Integer.toString(maxIter));
            props.setProperty("smile.lasso.alpha", Double.toString(alpha));
            props.setProperty("smile.lasso.beta", Double.toString(beta));
            props.setProperty("smile.lasso.eta", Double.toString(eta));
            props.setProperty("smile.lasso.line_search_iterations", Integer.toString(lsMaxIter));
            props.setProperty("smile.lasso.pcg_iterations", Integer.toString(pcgMaxIter));
            return props;
        }

        /**
         * Returns the options from properties. Missing keys take their default values.
         *
         * @param props the hyperparameters keyed by {@code smile.lasso.*} property names.
         * @return the options.
         * @throws NumberFormatException if a value cannot be parsed.
         * @throws IllegalArgumentException if a parsed value is out of range.
         */
        public static Options of(Properties props) {
            double lambda = Double.parseDouble(props.getProperty("smile.lasso.lambda", "1"));
            double tol = Double.parseDouble(props.getProperty("smile.lasso.tolerance", "1E-4"));
            int maxIter = Integer.parseInt(props.getProperty("smile.lasso.iterations", "1000"));
            double alpha = Double.parseDouble(props.getProperty("smile.lasso.alpha", "0.01"));
            double beta = Double.parseDouble(props.getProperty("smile.lasso.beta", "0.5"));
            double eta = Double.parseDouble(props.getProperty("smile.lasso.eta", "1E-3"));
            int lsMaxIter = Integer.parseInt(props.getProperty("smile.lasso.line_search_iterations", "100"));
            int pcgMaxIter = Integer.parseInt(props.getProperty("smile.lasso.pcg_iterations", "5000"));
            return new Options(lambda, tol, maxIter, alpha, beta, eta, lsMaxIter, pcgMaxIter);
        }
    }

    /**
     * The Newton-system operator and its preconditioner used by {@link #train}.
     * <p>
     * As an operator of size {@code 2p x 2p} it applies
     * {@code H = [[2A'A + D1, D2], [D2, D1]]}, where {@code D1 = diag(d1)} and
     * {@code D2 = diag(d2)}. As a preconditioner it solves with the block matrix
     * {@code [[diag(prb), D2], [D2, D1]]}.
     * <p>
     * The arrays {@code d1}, {@code d2}, {@code prb} and {@code prs} are shared
     * with the caller. They are refilled in place before each solve.
     */
    public static class PCG extends IMatrix implements IMatrix.Preconditioner {
        /** The design matrix. */
        final Matrix A;
        /** Cached {@code A'A}; only computed when A has fewer than 10000 columns. */
        Matrix AtA;
        /** The number of columns of A. */
        final int p;
        /** Diagonal of the D1 block. */
        final double[] d1;
        /** Diagonal of the D2 block. */
        final double[] d2;
        /** Top-left diagonal of the preconditioner. */
        final double[] prb;
        /** Per-coordinate determinant of the 2x2 preconditioner blocks; expected to be prb*d1 - d2^2. */
        final double[] prs;
        /** Workspace for {@code A x} (length nrow). */
        final double[] ax;
        /** Workspace for {@code A'A x} (length p). */
        final double[] atax;

        PCG(Matrix A, double[] d1, double[] d2, double[] prb, double[] prs) {
            this.A = A;
            this.d1 = d1;
            this.d2 = d2;
            this.prb = prb;
            this.prs = prs;
            this.p = A.ncol();
            this.ax = new double[A.nrow()];
            this.atax = new double[p];
            // Materializing A'A costs O(p^2) memory, so it is only done for moderate p.
            if (p < 10000) {
                this.AtA = A.ata();
            }
        }

        public int nrow() {
            return 2 * p;
        }

        public int ncol() {
            return 2 * p;
        }

        /** Delegates to the design matrix. */
        public long size() {
            return A.size();
        }

        /**
         * Computes {@code y = H x}; only the first {@code 2p} entries of x and y are used.
         */
        public void mv(double[] x, double[] y) {
            if (AtA != null) {
                AtA.mv(x, atax);
            } else {
                A.mv(x, ax);
                A.tv(ax, atax);
            }

            for (int i = 0; i < p; i++) {
                y[i] = 2.0 * atax[i] + d1[i] * x[i] + d2[i] * x[i + p];
                y[i + p] = d2[i] * x[i] + d1[i] * x[i + p];
            }
        }

        /** H is symmetric, so the transpose product equals {@link #mv(double[], double[])}. */
        public void tv(double[] x, double[] y) {
            mv(x, y);
        }

        /**
         * Applies the inverse of the block preconditioner to b by inverting each
         * 2x2 block {@code [[prb_i, d2_i], [d2_i, d1_i]]} explicitly with the
         * precomputed determinant {@code prs_i}.
         */
        public void asolve(double[] b, double[] x) {
            for (int i = 0; i < p; i++) {
                x[i] = (d1[i] * b[i] - d2[i] * b[i + p]) / prs[i];
                x[i + p] = (-d2[i] * b[i] + prb[i] * b[i + p]) / prs[i];
            }
        }

        public void mv(Transpose trans, double alpha, double[] x, double beta, double[] y) {
            throw new UnsupportedOperationException();
        }

        public void mv(double[] work, int inputOffset, int outputOffset) {
            throw new UnsupportedOperationException();
        }

        public void tv(double[] work, int inputOffset, int outputOffset) {
            throw new UnsupportedOperationException();
        }
    }

    /** Static-only class. */
    private LASSO() {
    }

    /**
     * Fits a LASSO model with {@code lambda = 1} and default options.
     *
     * @param formula a symbolic description of the model to be fitted.
     * @param data the data frame of the explanatory and response variables.
     * @return the model.
     */
    public static LinearModel fit(Formula formula, DataFrame data) {
        return fit(formula, data, new Options(1.0));
    }

    /**
     * Fits a LASSO model.
     * <p>
     * The design matrix is standardized (column means and standard deviations)
     * before solving. The coefficients are then mapped back to the original scale,
     * and the intercept is {@code mean(y) - w . colMeans}.
     *
     * @param formula a symbolic description of the model to be fitted.
     * @param data the data frame of the explanatory and response variables.
     * @param options the hyperparameters.
     * @return the model.
     * @throws IllegalArgumentException if a column of the design matrix is constant.
     */
    public static LinearModel fit(Formula formula, DataFrame data, Options options) {
        Formula expanded = formula.expand(data.schema());
        StructType schema = expanded.bind(data.schema());
        Matrix X = expanded.matrix(data, false);
        double[] y = expanded.y(data).toDoubleArray();

        double[] center = X.colMeans();
        double[] scale = X.colSds();
        for (int j = 0; j < scale.length; j++) {
            // A constant column cannot be standardized (division by zero).
            if (MathEx.isZero(scale[j])) {
                throw new IllegalArgumentException(String.format("The column '%s' is constant", X.colName(j)));
            }
        }

        double[] w = train(X.scale(center, scale), y, options);
        // Undo the column scaling so that w applies to the raw design matrix.
        for (int j = 0; j < w.length; j++) {
            w[j] /= scale[j];
        }

        double intercept = MathEx.mean(y) - MathEx.dot(w, center);
        return new LinearModel(expanded, schema, X, y, w, intercept);
    }

    /**
     * Solves the L1-regularized least squares problem
     * {@code min ||Y - Xw||^2 + lambda * ||w||_1}, where {@code Y} is {@code y}
     * shifted to zero mean, with a log-barrier interior-point method.
     *
     * @param X the design matrix (standardized by {@link #fit}).
     * @param y the response vector.
     * @param options the hyperparameters.
     * @return the coefficients {@code w}; no intercept is included.
     */
    public static double[] train(Matrix X, double[] y, Options options) {
        final double tol = options.tol();
        final double lambda = options.lambda();
        final int maxIter = options.maxIter();
        final double alpha = options.alpha();
        final double beta = options.beta();
        final double eta = options.eta();
        final int lsMaxIter = options.lsMaxIter();
        final int pcgMaxIter = options.pcgMaxIter();
        // Multiplicative growth factor for the barrier parameter t.
        final double mu = 2.0;

        final int n = X.nrow();
        final int p = X.ncol();

        // Work with the centered response; the intercept is recovered by the caller.
        double[] Y = new double[n];
        double ym = MathEx.mean(y);
        for (int i = 0; i < n; i++) {
            Y[i] = y[i] - ym;
        }

        // The barrier term enters the objective weighted by 1/t.
        double t = Math.min(Math.max(1.0, 1.0 / lambda), 2 * p / 1E-3);
        double dobj = Double.NEGATIVE_INFINITY; // best dual objective seen so far
        double s = Double.POSITIVE_INFINITY;   // last line search step size

        // Primal variables: coefficients w and bounds u. The constraints are -u < w < u.
        // f holds the constraint values w - u and -w - u, which must stay negative.
        double[] w = new double[p];
        double[] u = new double[p];
        double[] z = new double[n];
        double[][] f = new double[2][p];
        Arrays.fill(u, 1.0);
        for (int i = 0; i < p; i++) {
            f[0][i] = w[i] - u[i];
            f[1][i] = -w[i] - u[i];
        }

        // Line search candidates.
        double[] neww = new double[p];
        double[] newu = new double[p];
        double[] newz = new double[n];
        double[][] newf = new double[2][p];

        // Newton direction: dxu = [dx; du]; grad is the negated barrier gradient.
        double[] dx = new double[p];
        double[] du = new double[p];
        double[] dxu = new double[2 * p];
        double[] grad = new double[2 * p];

        // A constant 2 is used in place of the exact diagonal of 2X'X.
        // It only enters the preconditioner (prb/prs), not the Newton system itself.
        double[] diagxtx = new double[p];
        Arrays.fill(diagxtx, 2.0);

        // Dual point and X' * nu.
        double[] nu = new double[n];
        double[] xnu = new double[p];

        double[] q1 = new double[p];
        double[] q2 = new double[p];
        double[] d1 = new double[p];
        double[] d2 = new double[p];
        double[][] gradphi = new double[2][p];
        double[] prb = new double[p];
        double[] prs = new double[p];

        // pcg shares d1/d2/prb/prs, so refilling them below updates the operator.
        PCG pcg = new PCG(X, d1, d2, prb, prs);

        // Becomes pcgMaxIter once a PCG solve misses its tolerance; until then the
        // PCG tolerance is tightened by 10x.
        int pitr = 0;

        int iter = 1;
        for (; iter <= maxIter; iter++) {
            // Residual z = Xw - Y and the candidate dual point nu = 2z.
            X.mv(w, z);
            for (int i = 0; i < n; i++) {
                z[i] -= Y[i];
                nu[i] = 2.0 * z[i];
            }

            // Scale nu so that ||X'nu||_inf <= lambda, i.e. nu is dual feasible.
            X.tv(nu, xnu);
            double maxXnu = MathEx.normInf(xnu);
            if (maxXnu > lambda) {
                double ratio = lambda / maxXnu;
                for (int i = 0; i < n; i++) {
                    nu[i] *= ratio;
                }
            }

            // Duality gap.
            double pobj = MathEx.dot(z, z) + lambda * MathEx.norm1(w);
            dobj = Math.max(-0.25 * MathEx.dot(nu, nu) - MathEx.dot(nu, Y), dobj);
            double gap = pobj - dobj;

            // A non-positive gap means the current point is already optimal. It must be
            // tested separately because gap / dobj is NaN when both are zero, e.g. for a
            // constant response.
            boolean converged = gap <= 0.0 || gap / dobj < tol;
            if (iter % 10 == 0 || converged) {
                logger.info("Iteration {}: primal objective = {}, dual objective = {}", iter, pobj, dobj);
            }
            if (converged) {
                break;
            }

            // Increase t only after a reasonably long step in the previous iteration.
            if (s >= 0.5) {
                t = Math.max(Math.min(2 * p * mu / gap, mu * t), t);
            }

            // Hessian diagonals of the barrier terms.
            for (int i = 0; i < p; i++) {
                double q1i = 1.0 / (u[i] + w[i]);
                double q2i = 1.0 / (u[i] - w[i]);
                q1[i] = q1i;
                q2[i] = q2i;
                d1[i] = (q1i * q1i + q2i * q2i) / t;
                d2[i] = (q1i * q1i - q2i * q2i) / t;
            }

            // Gradient of the barrier-augmented objective with respect to (w, u).
            X.tv(z, gradphi[0]);
            for (int i = 0; i < p; i++) {
                gradphi[0][i] = 2.0 * gradphi[0][i] - (q1[i] - q2[i]) / t;
                gradphi[1][i] = lambda - (q1[i] + q2[i]) / t;
                grad[i] = -gradphi[0][i];
                grad[i + p] = -gradphi[1][i];
            }

            // Preconditioner blocks; prs is the determinant of each 2x2 block.
            for (int i = 0; i < p; i++) {
                prb[i] = diagxtx[i] + d1[i];
                prs[i] = prb[i] * d1[i] - d2[i] * d2[i];
            }

            // Relative PCG tolerance proportional to the duality gap.
            double pcgtol = Math.min(0.1, eta * gap / Math.min(1.0, MathEx.norm(grad)));
            if (pitr == 0) {
                pcgtol *= 0.1;
            }

            // dxu keeps the previous direction, so each solve starts from it.
            if (pcg.solve(grad, dxu, pcg, pcgtol, 1, pcgMaxIter) > pcgtol) {
                pitr = pcgMaxIter;
            }

            System.arraycopy(dxu, 0, dx, 0, p);
            System.arraycopy(dxu, p, du, 0, p);

            // Backtracking line search with a sufficient-decrease test.
            double phi = MathEx.dot(z, z) + lambda * MathEx.sum(u) - sumlogneg(f) / t;
            s = 1.0;
            double gdx = MathEx.dot(grad, dxu);

            int lsiter = 0;
            for (; lsiter < lsMaxIter; lsiter++) {
                for (int i = 0; i < p; i++) {
                    neww[i] = w[i] + s * dx[i];
                    newu[i] = u[i] + s * du[i];
                    newf[0][i] = neww[i] - newu[i];
                    newf[1][i] = -neww[i] - newu[i];
                }

                // Only evaluate the objective at strictly feasible points.
                if (MathEx.max(newf) < 0.0) {
                    X.mv(neww, newz);
                    for (int i = 0; i < n; i++) {
                        newz[i] -= Y[i];
                    }

                    double newphi = MathEx.dot(newz, newz) + lambda * MathEx.sum(newu) - sumlogneg(newf) / t;
                    if (newphi - phi <= alpha * s * gdx) {
                        break;
                    }
                }
                s = beta * s;
            }

            if (lsiter == lsMaxIter) {
                logger.warn("Linear search reaches maximum number of iterations: {}", lsMaxIter);
                break;
            }

            System.arraycopy(neww, 0, w, 0, p);
            System.arraycopy(newu, 0, u, 0, p);
            System.arraycopy(newf[0], 0, f[0], 0, p);
            System.arraycopy(newf[1], 0, f[1], 0, p);
        }

        // Every early exit is a break that leaves iter <= maxIter, so iter only exceeds
        // maxIter when the iteration budget was exhausted without convergence.
        if (iter > maxIter) {
            logger.warn("IPM reaches maximum number of iterations: {}", maxIter);
        }

        return w;
    }

    /**
     * Returns the sum of {@code log(-v)} over every entry {@code v} of {@code f},
     * i.e. the log-barrier term for the constraints {@code f < 0}.
     * <p>
     * Entries are expected to be strictly negative. A zero entry contributes
     * {@code -Infinity} and a positive entry makes the result NaN.
     *
     * @param f the constraint values.
     * @return the barrier sum.
     */
    private static double sumlogneg(double[][] f) {
        double sum = 0.0;
        for (double[] row : f) {
            for (double v : row) {
                sum += Math.log(-v);
            }
        }
        return sum;
    }
}
