package com.opengamma.strata.math.impl.rootfinding.newton;

import com.google.common.primitives.Doubles;
import com.opengamma.strata.collect.ArgChecker;
import com.opengamma.strata.collect.array.DoubleArray;
import com.opengamma.strata.collect.array.DoubleMatrix;
import com.opengamma.strata.collect.array.Matrix;
import com.opengamma.strata.math.MathException;
import com.opengamma.strata.math.impl.differentiation.VectorFieldFirstOrderDifferentiator;
import com.opengamma.strata.math.impl.matrix.MatrixAlgebra;
import com.opengamma.strata.math.impl.matrix.OGMatrixAlgebra;
import com.opengamma.strata.math.impl.rootfinding.VectorRootFinder;
import com.opengamma.strata.math.rootfind.NewtonVectorRootFinder;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/opengamma/strata/math/impl/rootfinding/newton/BaseNewtonVectorRootFinder.class */
/**
 * Base implementation of a Newton-type vector root finder with a backtracking line search.
 * <p>
 * Each iteration asks the {@link NewtonRootFinderDirectionFunction} for a search direction, given the
 * current Jacobian estimate and function value. It then searches along that direction for a step that
 * sufficiently reduces the merit function {@code g(x) = <f(x), f(x)>}, the squared norm of the
 * function value.
 * <p>
 * The Jacobian estimate comes from two sources:
 * <ul>
 *   <li>the {@link NewtonRootFinderMatrixInitializationFunction}, at the start, every
 *       {@code FULL_RECALC_FREQ} iterations, and whenever the line search fails;</li>
 *   <li>the {@link NewtonRootFinderMatrixUpdateFunction}, on all other iterations.</li>
 * </ul>
 */
public class BaseNewtonVectorRootFinder extends VectorRootFinder implements NewtonVectorRootFinder {
    private static final Logger log = LoggerFactory.getLogger(BaseNewtonVectorRootFinder.class);
    /** Parameter of the sufficient-decrease test {@code g1 <= g0 / (1 + ALPHA * lambda)}. */
    private static final double ALPHA = 1.0E-4d;
    /** Growth factor for the starting step length when the previous step length was at least 1. */
    private static final double BETA = 1.5d;
    /** Number of iterations between full re-initializations of the Jacobian estimate. */
    private static final int FULL_RECALC_FREQ = 20;
    /** Factor applied to the step length while the trial merit values are not finite. */
    private static final double BISECT_FACTOR = 0.1d;
    /** Lower bound on a backtracked step length, as a fraction of the current step length. */
    private static final double MIN_STEP_FRACTION = 0.01d;
    /** Upper bound on a cubic-backtracked step length, as a fraction of the previous step length. */
    private static final double MAX_CUBIC_STEP_FRACTION = 0.75d;
    /** Number of cubic backtracking attempts allowed before the line search is declared failed. */
    private static final int MAX_CUBIC_BACKTRACKS = 6;
    private final double _absoluteTol;
    private final double _relativeTol;
    private final int _maxSteps;
    private final NewtonRootFinderDirectionFunction _directionFunction;
    private final NewtonRootFinderMatrixInitializationFunction _initializationFunction;
    private final NewtonRootFinderMatrixUpdateFunction _updateFunction;
    private final MatrixAlgebra _algebra = new OGMatrixAlgebra();

    /**
     * Mutable state of a single {@code findRoot} call.
     * <p>
     * Holds the following values:
     * <ul>
     *   <li>the accepted position {@code x} and its function value {@code y};</li>
     *   <li>the most recent trial step {@code deltaX} and the change in function value over it, {@code deltaY};</li>
     *   <li>{@code g0}, the merit value at the accepted position;</li>
     *   <li>{@code g1} and {@code g2}, the merit values at the latest and previous trial points;</li>
     *   <li>{@code lambda0} and {@code lambda1}, the current and previous step lengths.</li>
     * </ul>
     * Instances can only be created by the enclosing class.
     */
    public static final class DataBundle {
        private double _g0;
        private double _g1;
        private double _g2;
        private double _lambda0;
        private double _lambda1;
        private DoubleArray _deltaY;
        private DoubleArray _y;
        private DoubleArray _deltaX;
        private DoubleArray _x;

        private DataBundle() {
        }

        public double getG0() {
            return this._g0;
        }

        public double getG1() {
            return this._g1;
        }

        public double getG2() {
            return this._g2;
        }

        public double getLambda0() {
            return this._lambda0;
        }

        public double getLambda1() {
            return this._lambda1;
        }

        public DoubleArray getDeltaY() {
            return this._deltaY;
        }

        public DoubleArray getY() {
            return this._y;
        }

        public DoubleArray getDeltaX() {
            return this._deltaX;
        }

        public DoubleArray getX() {
            return this._x;
        }

        public void setG0(double g0) {
            this._g0 = g0;
        }

        public void setG1(double g1) {
            this._g1 = g1;
        }

        public void setG2(double g2) {
            this._g2 = g2;
        }

        public void setLambda0(double lambda0) {
            this._lambda0 = lambda0;
        }

        public void setDeltaY(DoubleArray deltaY) {
            this._deltaY = deltaY;
        }

        public void setY(DoubleArray y) {
            this._y = y;
        }

        public void setDeltaX(DoubleArray deltaX) {
            this._deltaX = deltaX;
        }

        public void setX(DoubleArray x) {
            this._x = x;
        }

        /**
         * Moves the current step length into {@code lambda1} and sets {@code lambda0} to the new value.
         *
         * @param lambda  the new current step length
         */
        public void swapLambdaAndReplace(double lambda) {
            this._lambda1 = this._lambda0;
            this._lambda0 = lambda;
        }
    }

    /**
     * Creates an instance.
     *
     * @param absoluteTol  the absolute tolerance, not negative
     * @param relativeTol  the relative tolerance, not negative
     * @param maxSteps  the maximum number of iterations, not negative
     * @param directionFunction  computes the search direction from the Jacobian estimate and function value
     * @param initializationFunction  produces a fresh Jacobian estimate at a given position
     * @param updateFunction  updates the Jacobian estimate after an accepted step
     */
    public BaseNewtonVectorRootFinder(
            double absoluteTol,
            double relativeTol,
            int maxSteps,
            NewtonRootFinderDirectionFunction directionFunction,
            NewtonRootFinderMatrixInitializationFunction initializationFunction,
            NewtonRootFinderMatrixUpdateFunction updateFunction) {
        ArgChecker.notNegative(absoluteTol, "absolute tolerance");
        ArgChecker.notNegative(relativeTol, "relative tolerance");
        ArgChecker.notNegative(maxSteps, "maxSteps");
        this._absoluteTol = absoluteTol;
        this._relativeTol = relativeTol;
        this._maxSteps = maxSteps;
        this._directionFunction = directionFunction;
        this._initializationFunction = initializationFunction;
        this._updateFunction = updateFunction;
    }

    /**
     * Finds a root of {@code function} starting from {@code startPosition}.
     * Delegates to {@link #findRoot(Function, DoubleArray)}.
     */
    @Override
    public DoubleArray getRoot(Function<DoubleArray, DoubleArray> function, DoubleArray startPosition) {
        return findRoot(function, startPosition);
    }

    /**
     * Finds a root of {@code function} starting from {@code startPosition}.
     * The Jacobian function is built with a {@link VectorFieldFirstOrderDifferentiator}.
     */
    @Override
    public DoubleArray findRoot(Function<DoubleArray, DoubleArray> function, DoubleArray startPosition) {
        return findRoot(function, new VectorFieldFirstOrderDifferentiator().differentiate(function), startPosition);
    }

    /**
     * Finds a root of {@code function} starting from {@code startPosition}, using {@code jacobianFunction}
     * when the Jacobian estimate is (re)initialized or updated.
     *
     * @param function  the function whose root is sought
     * @param jacobianFunction  the Jacobian of {@code function}
     * @param startPosition  the initial guess
     * @return the position at which the convergence criteria are met
     * @throws MathException if the line search fails from the start position, fails again after a
     *   Jacobian re-initialization, or the maximum number of iterations is exceeded
     */
    @Override
    public DoubleArray findRoot(
            Function<DoubleArray, DoubleArray> function,
            Function<DoubleArray, DoubleMatrix> jacobianFunction,
            DoubleArray startPosition) {
        DataBundle data = new DataBundle();
        DoubleArray y = checkInputsAndApplyFunction(function, startPosition);
        data.setX(startPosition);
        data.setY(y);
        data.setG0(_algebra.getInnerProduct(y, y));
        DoubleMatrix estimate = _initializationFunction.getInitializedMatrix(jacobianFunction, startPosition);

        if (!getNextPosition(function, estimate, data)) {
            if (isConverged(data)) {
                // no improving step exists, but the start position already satisfies the tolerances
                return data.getX();
            }
            throw new MathException("Cannot work with this starting position. Please choose another point");
        }

        int count = 0;
        int jacReconCount = 1;
        while (!isConverged(data)) {
            // periodically rebuild the Jacobian from scratch, even while the updates keep working
            if (jacReconCount % FULL_RECALC_FREQ == 0) {
                estimate = _initializationFunction.getInitializedMatrix(jacobianFunction, data.getX());
                jacReconCount = 1;
            } else {
                estimate = _updateFunction.getUpdatedMatrix(
                        jacobianFunction, data.getX(), data.getDeltaX(), data.getDeltaY(), estimate);
                jacReconCount++;
            }
            if (!getNextPosition(function, estimate, data)) {
                // the line search failed; the updated estimate may have drifted, so retry once with a fresh one
                estimate = _initializationFunction.getInitializedMatrix(jacobianFunction, data.getX());
                jacReconCount = 1;
                if (!getNextPosition(function, estimate, data)) {
                    if (isConverged(data)) {
                        // non-standard exit: no improving step found, but already within tolerance
                        return data.getX();
                    }
                    String msg = "Failed to converge in backtracking, even after a Jacobian recalculation."
                            + getErrorMessage(data, jacobianFunction);
                    log.info(msg);
                    throw new MathException(msg);
                }
            }
            count++;
            if (count > _maxSteps) {
                throw new MathException("Failed to converge - maximum iterations of " + _maxSteps + " reached."
                        + getErrorMessage(data, jacobianFunction));
            }
        }
        return data.getX();
    }

    /** Builds the diagnostic suffix for failure messages, including the Jacobian at the current position. */
    private String getErrorMessage(DataBundle data, Function<DoubleArray, DoubleMatrix> jacobianFunction) {
        return "Final position:" + data.getX()
                + "\nlast deltaX:" + data.getDeltaX()
                + "\n function value:" + data.getY()
                + "\nJacobian: \n" + jacobianFunction.apply(data.getX());
    }

    /**
     * Performs one line search along the direction given by {@code estimate}.
     * <p>
     * On success the trial step is accepted, so {@code x}, {@code y} and {@code g0} in {@code data} are
     * advanced. On failure they are left unchanged, while {@code deltaX} holds the last rejected trial step.
     *
     * @return true if a step satisfying the sufficient-decrease test was found
     */
    private boolean getNextPosition(
            Function<DoubleArray, DoubleArray> function,
            DoubleMatrix estimate,
            DataBundle data) {
        DoubleArray p = _directionFunction.getDirection(estimate, data.getY());
        // start from a full step, or grow the previous step if it was already at least a full step
        if (data.getLambda0() < 1.0d) {
            data.setLambda0(1.0d);
        } else {
            data.setLambda0(data.getLambda0() * BETA);
        }
        updatePosition(p, function, data);
        if (!Doubles.isFinite(data.getG1())) {
            bisectBacktrack(p, function, data);
        }
        if (!isSufficientDecrease(data)) {
            quadraticBacktrack(p, function, data);
            int attempts = 0;
            while (!isSufficientDecrease(data)) {
                if (attempts >= MAX_CUBIC_BACKTRACKS) {
                    return false;
                }
                cubicBacktrack(p, function, data);
                attempts++;
            }
        }
        // accept the trial step
        DoubleArray deltaX = data.getDeltaX();
        DoubleArray deltaY = data.getDeltaY();
        data.setG0(data.getG1());
        data.setX((DoubleArray) _algebra.add(data.getX(), deltaX));
        data.setY((DoubleArray) _algebra.add(data.getY(), deltaY));
        return true;
    }

    /**
     * Checks the sufficient-decrease test {@code g1 <= g0 / (1 + ALPHA * lambda0)} for the latest trial point.
     * A NaN merit value never passes, so a step to a point where the function is undefined is not accepted.
     */
    private boolean isSufficientDecrease(DataBundle data) {
        return data.getG1() <= data.getG0() / (1.0d + ALPHA * data.getLambda0());
    }

    /**
     * Evaluates the trial point {@code x - lambda0 * p}.
     * <p>
     * Records the trial step as {@code deltaX} and the change in function value as {@code deltaY}. The old
     * {@code g1} moves to {@code g2}, and the merit value at the trial point becomes the new {@code g1}.
     * The accepted position in {@code data} is not changed.
     *
     * @param p  the search direction
     * @param function  the function whose root is sought
     * @param data  the solver state, updated in place
     */
    protected void updatePosition(DoubleArray p, Function<DoubleArray, DoubleArray> function, DataBundle data) {
        DoubleArray deltaX = (DoubleArray) _algebra.scale(p, -data.getLambda0());
        DoubleArray trialY = function.apply((DoubleArray) _algebra.add(data.getX(), deltaX));
        data.setDeltaX(deltaX);
        data.setDeltaY((DoubleArray) _algebra.subtract(trialY, data.getY()));
        data.setG2(data.getG1());
        data.setG1(_algebra.getInnerProduct(trialY, trialY));
    }

    /**
     * Repeatedly shrinks the step length by {@code BISECT_FACTOR} until both {@code g1} and {@code g2}
     * are finite.
     *
     * @throws MathException if the step length underflows to zero
     */
    private void bisectBacktrack(DoubleArray p, Function<DoubleArray, DoubleArray> function, DataBundle data) {
        do {
            data.setLambda0(data.getLambda0() * BISECT_FACTOR);
            updatePosition(p, function, data);
            if (data.getLambda0() == 0.0d) {
                throw new MathException("Failed to converge");
            }
        } while (!Doubles.isFinite(data.getG1()) || !Doubles.isFinite(data.getG2()));
    }

    /**
     * Picks a new step length from a quadratic model of the merit function and evaluates the trial point.
     * The new step length is at least {@code MIN_STEP_FRACTION} times the current one.
     */
    private void quadraticBacktrack(DoubleArray p, Function<DoubleArray, DoubleArray> function, DataBundle data) {
        double lambda0 = data.getLambda0();
        double g0 = data.getG0();
        double lambda = (g0 * lambda0 * lambda0) / (data.getG1() + g0 * (2.0d * lambda0 - 1.0d));
        data.swapLambdaAndReplace(Math.max(MIN_STEP_FRACTION * lambda0, lambda));
        updatePosition(p, function, data);
    }

    /**
     * Picks a new step length from a cubic model built from the last two trial points, then evaluates the
     * new trial point.
     * <p>
     * The new step length is clamped between {@code MIN_STEP_FRACTION} times the current step length and
     * {@code MAX_CUBIC_STEP_FRACTION} times the previous one. If the cubic model yields NaN, half the
     * current step length is used before clamping.
     */
    private void cubicBacktrack(DoubleArray p, Function<DoubleArray, DoubleArray> function, DataBundle data) {
        double lambda0 = data.getLambda0();
        double lambda1 = data.getLambda1();
        double g0 = data.getG0();
        double invLambda0Sq = 1.0d / lambda0 / lambda0;
        double invLambda1Sq = 1.0d / lambda1 / lambda1;
        double r0 = data.getG1() + g0 * (2.0d * lambda0 - 1.0d);
        double r1 = data.getG2() + g0 * (2.0d * lambda1 - 1.0d);
        double invDiff = 1.0d / (lambda0 - lambda1);
        double a = invDiff * (invLambda0Sq * r0 - invLambda1Sq * r1);
        double b = invDiff * (-lambda1 * invLambda0Sq * r0 + lambda0 * invLambda1Sq * r1);
        double lambda = (-b + Math.sqrt(b * b + 6.0d * a * g0)) / 3.0d / a;
        if (Double.isNaN(lambda)) {
            // Negative discriminant, 0/0 when a == 0, or non-finite inputs: the cubic model has no
            // usable minimiser. Halve the step rather than letting NaN through Math.min/Math.max.
            lambda = 0.5d * lambda0;
        }
        data.swapLambdaAndReplace(
                Math.min(Math.max(lambda, MIN_STEP_FRACTION * lambda0), MAX_CUBIC_STEP_FRACTION * lambda1));
        updatePosition(p, function, data);
    }

    /**
     * Checks convergence.
     * <p>
     * Two conditions must both hold:
     * <ul>
     *   <li>every component of the last step satisfies {@code |deltaX_i| <= absTol + |x_i| * relTol};</li>
     *   <li>the residual norm {@code sqrt(g0)} is strictly below the absolute tolerance.</li>
     * </ul>
     */
    private boolean isConverged(DataBundle data) {
        DoubleArray deltaX = data.getDeltaX();
        DoubleArray x = data.getX();
        int size = deltaX.size();
        for (int i = 0; i < size; i++) {
            if (Math.abs(deltaX.get(i)) > _absoluteTol + Math.abs(x.get(i)) * _relativeTol) {
                return false;
            }
        }
        return Math.sqrt(data.getG0()) < _absoluteTol;
    }
}
