package com.opengamma.strata.math.impl.statistics.leastsquare;

import com.opengamma.strata.collect.ArgChecker;
import com.opengamma.strata.collect.array.DoubleArray;
import com.opengamma.strata.collect.array.DoubleMatrix;
import com.opengamma.strata.math.MathException;
import com.opengamma.strata.math.impl.differentiation.VectorFieldFirstOrderDifferentiator;
import com.opengamma.strata.math.impl.linearalgebra.DecompositionFactory;
import com.opengamma.strata.math.impl.matrix.MatrixAlgebra;
import com.opengamma.strata.math.impl.matrix.OGMatrixAlgebra;
import com.opengamma.strata.math.linearalgebra.Decomposition;
import com.opengamma.strata.math.linearalgebra.DecompositionResult;
import java.util.function.Function;

/* loaded from: input_file:com/opengamma/strata/math/impl/statistics/leastsquare/NonLinearLeastSquareWithPenalty.class */
/**
 * Iterative non-linear least-squares solver whose objective is the sum of a
 * sigma-weighted chi-square term and a quadratic penalty term
 * {@code theta^T * P * theta}, where {@code P} is a caller-supplied penalty matrix.
 * <p>
 * Each iteration solves a linear system built from the weighted Jacobian and the
 * penalty matrix, with the diagonal scaled by {@code (1 + lambda)}. The damping
 * factor {@code lambda} starts at zero, is divided by ten after a step that lowers
 * the objective, and is raised (to 0.1, then by factors of ten) after a step that
 * is rejected. This is damping in the style of the Levenberg-Marquardt method.
 */
public class NonLinearLeastSquareWithPenalty {

    /** Upper bound on the number of trial steps before {@code solve} gives up. */
    private static final int MAX_ATTEMPTS = 100000;

    /** Convergence tolerance used when none is supplied. */
    private static final double EPS = 1.0E-8d;

    /** Decomposition used when none is supplied. */
    private static final Decomposition<?> DEFAULT_DECOMP = DecompositionFactory.SV_COMMONS;

    /** Shared matrix algebra; also the default for the instance algebra. */
    private static final OGMatrixAlgebra MA = new OGMatrixAlgebra();

    /** Constraint function that accepts every position. */
    public static final Function<DoubleArray, Boolean> UNCONSTRAINED = position -> true;

    /** Tolerance on the relative change of the objective between two steps. */
    private final double _eps;

    /** Decomposition used to solve the linear system of each step. */
    private final Decomposition<?> _decomposition;

    /** Algebra used for vector/matrix arithmetic on the parameters and gradient. */
    private final MatrixAlgebra _algebra;

    /**
     * Creates a solver with the default decomposition, algebra and tolerance.
     */
    public NonLinearLeastSquareWithPenalty() {
        this(DEFAULT_DECOMP, MA, EPS);
    }

    /**
     * Creates a solver with the default algebra and tolerance.
     *
     * @param decomposition the decomposition used to solve each linear system, not null
     */
    public NonLinearLeastSquareWithPenalty(Decomposition<?> decomposition) {
        this(decomposition, MA, EPS);
    }

    /**
     * Creates a solver with the default decomposition and algebra.
     *
     * @param d the convergence tolerance, must be positive
     */
    public NonLinearLeastSquareWithPenalty(double d) {
        this(DEFAULT_DECOMP, MA, d);
    }

    /**
     * Creates a solver with the default algebra.
     *
     * @param decomposition the decomposition used to solve each linear system, not null
     * @param d the convergence tolerance, must be positive
     */
    public NonLinearLeastSquareWithPenalty(Decomposition<?> decomposition, double d) {
        this(decomposition, MA, d);
    }

    /**
     * Creates a fully specified solver.
     *
     * @param decomposition the decomposition used to solve each linear system, not null
     * @param matrixAlgebra the algebra used for vector/matrix arithmetic, not null
     * @param d the convergence tolerance, must be positive
     * @throws IllegalArgumentException if an argument is null or {@code d} is not positive
     */
    public NonLinearLeastSquareWithPenalty(Decomposition<?> decomposition, MatrixAlgebra matrixAlgebra, double d) {
        ArgChecker.notNull(decomposition, "decomposition");
        ArgChecker.notNull(matrixAlgebra, "algebra");
        ArgChecker.isTrue(d > 0.0d, "must have positive eps");
        this._decomposition = decomposition;
        this._algebra = matrixAlgebra;
        this._eps = d;
    }

    /**
     * Solves with every sigma equal to one, a Jacobian obtained from
     * {@link VectorFieldFirstOrderDifferentiator} and no constraint on the position.
     *
     * @param doubleArray the observed values
     * @param function the model, mapping a parameter vector to model values
     * @param doubleArray2 the start position
     * @param doubleMatrix the penalty matrix
     * @return the fit results
     */
    public LeastSquareWithPenaltyResults solve(DoubleArray doubleArray, Function<DoubleArray, DoubleArray> function, DoubleArray doubleArray2, DoubleMatrix doubleMatrix) {
        DoubleArray unitSigma = DoubleArray.filled(doubleArray.size(), 1.0d);
        Function<DoubleArray, DoubleMatrix> jac = new VectorFieldFirstOrderDifferentiator().differentiate(function);
        return solve(doubleArray, unitSigma, function, jac, doubleArray2, doubleMatrix);
    }

    /**
     * Solves with a Jacobian obtained from {@link VectorFieldFirstOrderDifferentiator}
     * and no constraint on the position.
     *
     * @param doubleArray the observed values
     * @param doubleArray2 the sigma (weight divisor) of each observation
     * @param function the model, mapping a parameter vector to model values
     * @param doubleArray3 the start position
     * @param doubleMatrix the penalty matrix
     * @return the fit results
     */
    public LeastSquareWithPenaltyResults solve(DoubleArray doubleArray, DoubleArray doubleArray2, Function<DoubleArray, DoubleArray> function, DoubleArray doubleArray3, DoubleMatrix doubleMatrix) {
        Function<DoubleArray, DoubleMatrix> jac = new VectorFieldFirstOrderDifferentiator().differentiate(function);
        return solve(doubleArray, doubleArray2, function, jac, doubleArray3, doubleMatrix);
    }

    /**
     * Solves with a Jacobian obtained from {@link VectorFieldFirstOrderDifferentiator}
     * and a caller-supplied constraint on the position.
     *
     * @param doubleArray the observed values
     * @param doubleArray2 the sigma (weight divisor) of each observation
     * @param function the model, mapping a parameter vector to model values
     * @param doubleArray3 the start position
     * @param doubleMatrix the penalty matrix
     * @param function2 the constraint; returns true for positions that are allowed
     * @return the fit results
     */
    public LeastSquareWithPenaltyResults solve(DoubleArray doubleArray, DoubleArray doubleArray2, Function<DoubleArray, DoubleArray> function, DoubleArray doubleArray3, DoubleMatrix doubleMatrix, Function<DoubleArray, Boolean> function2) {
        Function<DoubleArray, DoubleMatrix> jac = new VectorFieldFirstOrderDifferentiator().differentiate(function);
        return solve(doubleArray, doubleArray2, function, jac, doubleArray3, doubleMatrix, function2);
    }

    /**
     * Solves with a caller-supplied Jacobian and no constraint on the position.
     *
     * @param doubleArray the observed values
     * @param doubleArray2 the sigma (weight divisor) of each observation
     * @param function the model, mapping a parameter vector to model values
     * @param function2 the Jacobian of the model with respect to the parameters
     * @param doubleArray3 the start position
     * @param doubleMatrix the penalty matrix
     * @return the fit results
     */
    public LeastSquareWithPenaltyResults solve(DoubleArray doubleArray, DoubleArray doubleArray2, Function<DoubleArray, DoubleArray> function, Function<DoubleArray, DoubleMatrix> function2, DoubleArray doubleArray3, DoubleMatrix doubleMatrix) {
        return solve(doubleArray, doubleArray2, function, function2, doubleArray3, doubleMatrix, UNCONSTRAINED);
    }

    /**
     * Runs the damped iteration from the start position until the relative change
     * of the objective (chi-square plus penalty) between two consecutive trial
     * steps falls below the tolerance.
     *
     * @param doubleArray the observed values, not null
     * @param doubleArray2 the sigma (weight divisor) of each observation, not null,
     *   same length as the observed values
     * @param function the model, mapping a parameter vector to model values, not null
     * @param function2 the Jacobian of the model with respect to the parameters, not null
     * @param doubleArray3 the start position, not null; must satisfy the constraint
     * @param doubleMatrix the penalty matrix, not null
     * @param function3 the constraint; returns true for positions that are allowed, not null
     * @return the fit results
     * @throws IllegalArgumentException if an argument is null, the lengths disagree,
     *   or the start position is rejected by the constraint
     * @throws MathException if any exception is raised while evaluating a trial step,
     *   or if no convergence is reached within the attempt limit
     */
    public LeastSquareWithPenaltyResults solve(DoubleArray doubleArray, DoubleArray doubleArray2, Function<DoubleArray, DoubleArray> function, Function<DoubleArray, DoubleMatrix> function2, DoubleArray doubleArray3, DoubleMatrix doubleMatrix, Function<DoubleArray, Boolean> function3) {
        ArgChecker.notNull(doubleArray, "observedValues");
        ArgChecker.notNull(doubleArray2, " sigma");
        ArgChecker.notNull(function, " func");
        ArgChecker.notNull(function2, " jac");
        ArgChecker.notNull(doubleArray3, "startPos");
        // Validate the remaining two inputs up front as well, so that a null is
        // reported the same way as for the other arguments rather than as a bare
        // NullPointerException from deep inside the iteration.
        ArgChecker.notNull(doubleMatrix, "penalty");
        ArgChecker.notNull(function3, "allowedValue");
        ArgChecker.isTrue(doubleArray.size() == doubleArray2.size(), "observedValues and sigma must be same length");
        ArgChecker.isTrue(function3.apply(doubleArray3).booleanValue(), "The start position {} is not valid for this model. Please choose a valid start position", new Object[]{doubleArray3});

        DoubleArray theta = doubleArray3;
        double lambda = 0.0d;
        DoubleArray error = getError(function, doubleArray, doubleArray2, theta);
        DoubleMatrix jacobian = getJacobian(function2, doubleArray2, theta);
        double oldChiSqr = getChiSqr(error) + getANorm(doubleMatrix, theta);
        DoubleArray beta = getBeta(error, jacobian, doubleMatrix, theta);

        for (int attempt = 0; attempt < MAX_ATTEMPTS; attempt++) {
            DoubleMatrix alpha = getModifiedCurvatureMatrix(jacobian, lambda, doubleMatrix);
            try {
                DecompositionResult decomp = this._decomposition.apply(alpha);
                DoubleArray trialTheta = (DoubleArray) this._algebra.add(theta, decomp.solve(beta));

                // A trial position outside the allowed region is rejected outright;
                // more damping is applied for the next attempt.
                if (!function3.apply(trialTheta).booleanValue()) {
                    lambda = increaseLambda(lambda);
                    continue;
                }

                DoubleArray trialError = getError(function, doubleArray, doubleArray2, trialTheta);
                double trialPenalty = getANorm(doubleMatrix, trialTheta);
                double newChiSqr = getChiSqr(trialError) + trialPenalty;

                // Converged: the objective barely moved between the two positions.
                if (Math.abs(newChiSqr - oldChiSqr) / (1.0d + oldChiSqr) < this._eps) {
                    // The reported matrices are based on the undamped system, so the
                    // decomposition is redone with lambda = 0 when damping was active.
                    DoubleMatrix alpha0 = lambda == 0.0d ? alpha : getModifiedCurvatureMatrix(jacobian, 0.0d, doubleMatrix);
                    if (lambda > 0.0d) {
                        decomp = this._decomposition.apply(alpha0);
                    }
                    // Note: 'jacobian' is the one evaluated at the last accepted
                    // position, not re-evaluated at trialTheta.
                    return finish(alpha0, decomp, newChiSqr - trialPenalty, trialPenalty, jacobian, trialTheta, doubleArray2);
                }

                if (newChiSqr < oldChiSqr) {
                    // Step accepted: relax the damping and re-linearise at the new position.
                    lambda = decreaseLambda(lambda);
                    theta = trialTheta;
                    jacobian = getJacobian(function2, doubleArray2, trialTheta);
                    beta = getBeta(trialError, jacobian, doubleMatrix, theta);
                    oldChiSqr = newChiSqr;
                } else {
                    // Step made things worse: stay put and damp harder.
                    lambda = increaseLambda(lambda);
                }
            } catch (Exception e) {
                throw new MathException(e);
            }
        }
        throw new MathException("Could not converge in " + MAX_ATTEMPTS + " attempts");
    }

    /**
     * Builds the right-hand side of the step equation:
     * {@code multiply(error, jacobian) - penalty * theta}.
     */
    private DoubleArray getBeta(DoubleArray error, DoubleMatrix jacobian, DoubleMatrix penalty, DoubleArray theta) {
        DoubleArray chiSqrGrad = getChiSqrGrad(error, jacobian);
        DoubleArray penaltyTerm = (DoubleArray) this._algebra.multiply(penalty, theta);
        return (DoubleArray) this._algebra.subtract(chiSqrGrad, penaltyTerm);
    }

    /** Relaxes the damping by a factor of ten. */
    private double decreaseLambda(double lambda) {
        return lambda / 10.0d;
    }

    /** Strengthens the damping: zero becomes 0.1, anything else grows tenfold. */
    private double increaseLambda(double lambda) {
        return lambda == 0.0d ? 0.1d : lambda * 10.0d;
    }

    /**
     * Packages the converged state into a result object.
     *
     * @param alpha the undamped curvature matrix; only its row count is used here
     * @param decomp the decomposition of {@code alpha}
     * @param chiSqr the chi-square part of the objective
     * @param penalty the penalty part of the objective
     * @param jacobian the sigma-weighted Jacobian
     * @param theta the fitted parameters
     * @param sigma the sigma of each observation
     */
    private LeastSquareWithPenaltyResults finish(DoubleMatrix alpha, DecompositionResult decomp, double chiSqr, double penalty, DoubleMatrix jacobian, DoubleArray theta, DoubleArray sigma) {
        // Solving against the identity yields the inverse of alpha.
        DoubleMatrix alphaInverse = decomp.solve(DoubleMatrix.identity(alpha.rowCount()));
        DoubleMatrix solvedBT = decomp.solve(getBTranspose(jacobian, sigma));
        return new LeastSquareWithPenaltyResults(chiSqr, penalty, theta, alphaInverse, solvedBT);
    }

    /**
     * Evaluates the model at {@code theta} and returns the sigma-weighted residuals
     * {@code (observed - model) / sigma}.
     *
     * @throws IllegalArgumentException if the model returns a different number of
     *   values than there are observations
     */
    private DoubleArray getError(Function<DoubleArray, DoubleArray> func, DoubleArray observed, DoubleArray sigma, DoubleArray theta) {
        int n = observed.size();
        DoubleArray modelValues = func.apply(theta);
        ArgChecker.isTrue(n == modelValues.size(), "Number of data points different between model (" + modelValues.size() + ") and observed (" + n + ")");
        return DoubleArray.of(n, i -> (observed.get(i) - modelValues.get(i)) / sigma.get(i));
    }

    /**
     * Returns the transpose of {@code jacobian} with each original row {@code i}
     * divided by {@code sigma[i]}.
     */
    private DoubleMatrix getBTranspose(DoubleMatrix jacobian, DoubleArray sigma) {
        int rows = jacobian.rowCount();
        int cols = jacobian.columnCount();
        double[][] transposed = DoubleMatrix.filled(cols, rows).toArray();
        for (int r = 0; r < rows; r++) {
            double invSigma = 1.0d / sigma.get(r);
            for (int c = 0; c < cols; c++) {
                transposed[c][r] = jacobian.get(r, c) * invSigma;
            }
        }
        return DoubleMatrix.ofUnsafe(transposed);
    }

    /**
     * Evaluates the Jacobian at {@code theta} and divides each row {@code i} by
     * {@code sigma[i]}.
     *
     * @throws IllegalArgumentException if the Jacobian does not have one row per
     *   sigma entry and one column per parameter
     */
    private DoubleMatrix getJacobian(Function<DoubleArray, DoubleMatrix> jac, DoubleArray sigma, DoubleArray theta) {
        DoubleMatrix raw = jac.apply(theta);
        double[][] scaled = raw.toArray();
        int rows = raw.rowCount();
        int cols = raw.columnCount();
        ArgChecker.isTrue(theta.size() == cols, "Jacobian is wrong size");
        ArgChecker.isTrue(sigma.size() == rows, "Jacobian is wrong size");
        for (int r = 0; r < rows; r++) {
            double invSigma = 1.0d / sigma.get(r);
            double[] row = scaled[r];
            for (int c = 0; c < cols; c++) {
                row[c] *= invSigma;
            }
        }
        return DoubleMatrix.ofUnsafe(scaled);
    }

    /** Returns the inner product of the weighted residual vector with itself. */
    private double getChiSqr(DoubleArray error) {
        return this._algebra.getInnerProduct(error, error);
    }

    /**
     * Returns the algebra's product of the residual vector with the Jacobian.
     * The result is cast explicitly, as at every other use of the algebra's
     * arithmetic methods in this class.
     */
    private DoubleArray getChiSqrGrad(DoubleArray error, DoubleMatrix jacobian) {
        return (DoubleArray) this._algebra.multiply(error, jacobian);
    }

    /**
     * Returns {@code J^T * J + penalty} with each diagonal entry multiplied by
     * {@code (1 + lambda)}.
     */
    private DoubleMatrix getModifiedCurvatureMatrix(DoubleMatrix jacobian, double lambda, DoubleMatrix penalty) {
        double diagonalScale = 1.0d + lambda;
        int n = jacobian.columnCount();
        DoubleMatrix alpha = (DoubleMatrix) MA.add(MA.matrixTransposeMultiplyMatrix(jacobian), penalty);
        double[][] data = alpha.toArray();
        for (int i = 0; i < n; i++) {
            data[i][i] *= diagonalScale;
        }
        return DoubleMatrix.ofUnsafe(data);
    }

    /**
     * Returns the quadratic form {@code theta^T * penalty * theta}, summed over
     * the first {@code theta.size()} rows and columns of the penalty matrix.
     */
    private double getANorm(DoubleMatrix penalty, DoubleArray theta) {
        int n = theta.size();
        double sum = 0.0d;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                sum += penalty.get(i, j) * theta.get(i) * theta.get(j);
            }
        }
        return sum;
    }
}
