package com.opengamma.strata.pricer.impl.option;

import com.opengamma.strata.basics.value.ValueDerivatives;
import com.opengamma.strata.collect.ArgChecker;
import com.opengamma.strata.collect.array.DoubleArray;
import com.opengamma.strata.collect.array.DoubleMatrix;
import com.opengamma.strata.collect.tuple.Pair;
import com.opengamma.strata.math.impl.linearalgebra.SVDecompositionCommons;
import com.opengamma.strata.math.impl.linearalgebra.SVDecompositionResult;
import com.opengamma.strata.math.impl.rootfinding.BracketRoot;
import com.opengamma.strata.math.impl.rootfinding.RidderSingleRootFinder;
import com.opengamma.strata.pricer.impl.volatility.smile.SabrFormulaData;
import com.opengamma.strata.pricer.impl.volatility.smile.SabrHaganVolatilityFunctionProvider;
import com.opengamma.strata.pricer.impl.volatility.smile.VolatilityFunctionProvider;
import com.opengamma.strata.product.common.PutCall;
import java.util.Arrays;
import java.util.function.Function;

/* loaded from: input_file:com/opengamma/strata/pricer/impl/option/SabrExtrapolationRightFunction.class */
/**
 * Pricing function using the SABR implied volatility below a cut-off strike and a parametric
 * extrapolation of the call price above it.
 * <p>
 * For strikes up to and including the cut-off strike, prices are obtained from
 * {@code BlackFormulaRepository} with the volatility given by the SABR volatility function.
 * Above the cut-off strike the call price is
 * <pre>
 *   C(K) = K^(-mu) * exp(a + b / K + c / K^2)
 * </pre>
 * where {@code (a, b, c)} are fitted so that the extrapolated call price and its first and second
 * strike derivatives match the SABR/Black call price at the cut-off strike. Put prices above the
 * cut-off are derived from the call through {@code put = call - (forward - strike)}.
 * <p>
 * If the time to expiry is not larger than {@code 1.0E-6}, no fitting is done: the parameters are
 * {@code (-10000, 0, 0)} and all parameter sensitivities are zero.
 * <p>
 * The sensitivities of the fitted parameters to the forward and to the SABR parameters are computed
 * lazily on first use and cached in volatile fields. Concurrent first calls may each compute them.
 * The cached arrays are never modified after publication; public getters return copies.
 */
public final class SabrExtrapolationRightFunction {
    /** Decomposition used to solve the 3x3 linear systems giving the parameter sensitivities. */
    private static final SVDecompositionCommons SVD = new SVDecompositionCommons();
    /** Expiries at or below this value are not fitted. */
    private static final double SMALL_EXPIRY = 1.0E-6d;
    /** First extrapolation parameter used when the expiry is at or below {@link #SMALL_EXPIRY}. */
    private static final double SMALL_PARAMETER = -10000.0d;
    /** Threshold below which the cut-off price and its strike derivatives are treated as zero. */
    private static final double SMALL_PRICE = 1.0E-15d;
    /** First extrapolation parameter used when the cut-off price and its derivatives are negligible. */
    private static final double NEGLIGIBLE_PRICE_PARAMETER = -100.0d;
    /** Shift used in the finite-difference approximations (relative for strike, volatility and alpha). */
    private static final double FD_SHIFT = 1.0E-5d;
    /** Accuracy of the root-finder used to solve for the third extrapolation parameter. */
    private static final double ROOT_ACCURACY = 1.0E-5d;

    private final VolatilityFunctionProvider<SabrFormulaData> sabrFunction;
    private final double forward;
    private final double timeToExpiry;
    private final SabrFormulaData sabrData;
    private final double cutOffStrike;
    private final double mu;
    /** Extrapolation parameters {@code (a, b, c)}. Never exposed directly. */
    private final double[] parameter;
    /** Lazily computed derivatives of {@code (a, b, c)} with respect to the forward. */
    private volatile double[] parameterDerivativeForward;
    /** Lazily computed derivatives of {@code (a, b, c)}; one row per SABR parameter (alpha, beta, rho, nu). */
    private volatile double[][] parameterDerivativeSabr;
    /** SABR volatility at the cut-off strike; set during fitting. */
    private volatile double volatilityK;
    /** Call price at the cut-off strike and its first and second strike derivatives. */
    private final double[] priceK = new double[3];

    /**
     * Obtains an instance using the default Hagan SABR volatility function.
     *
     * @param forward  the forward
     * @param timeToExpiry  the time to expiry
     * @param sabrData  the SABR parameters, not null
     * @param cutOffStrike  the strike above which the price is extrapolated
     * @param mu  the tail exponent of the extrapolation
     * @return the function
     */
    public static SabrExtrapolationRightFunction of(
            double forward,
            double timeToExpiry,
            SabrFormulaData sabrData,
            double cutOffStrike,
            double mu) {
        return new SabrExtrapolationRightFunction(
                forward, sabrData, cutOffStrike, timeToExpiry, mu, SabrHaganVolatilityFunctionProvider.DEFAULT);
    }

    /**
     * Obtains an instance with an explicit SABR volatility function.
     *
     * @param forward  the forward
     * @param sabrData  the SABR parameters, not null
     * @param cutOffStrike  the strike above which the price is extrapolated
     * @param timeToExpiry  the time to expiry
     * @param mu  the tail exponent of the extrapolation
     * @param volatilityFunction  the SABR volatility function, not null
     * @return the function
     */
    public static SabrExtrapolationRightFunction of(
            double forward,
            SabrFormulaData sabrData,
            double cutOffStrike,
            double timeToExpiry,
            double mu,
            VolatilityFunctionProvider<SabrFormulaData> volatilityFunction) {
        return new SabrExtrapolationRightFunction(forward, sabrData, cutOffStrike, timeToExpiry, mu, volatilityFunction);
    }

    private SabrExtrapolationRightFunction(
            double forward,
            SabrFormulaData sabrData,
            double cutOffStrike,
            double timeToExpiry,
            double mu,
            VolatilityFunctionProvider<SabrFormulaData> volatilityFunction) {
        ArgChecker.notNull(sabrData, "sabrData");
        ArgChecker.notNull(volatilityFunction, "volatilityFunction");
        this.sabrFunction = volatilityFunction;
        this.forward = forward;
        this.sabrData = sabrData;
        this.cutOffStrike = cutOffStrike;
        this.timeToExpiry = timeToExpiry;
        this.mu = mu;
        if (timeToExpiry > SMALL_EXPIRY) {
            // fitting reads the fields assigned above
            this.parameter = computesFittingParameters();
        } else {
            this.parameter = new double[] {SMALL_PARAMETER, 0.0d, 0.0d};
            this.parameterDerivativeForward = new double[3];
            this.parameterDerivativeSabr = new double[4][3];
        }
    }

    //-------------------------------------------------------------------------
    /**
     * Computes the option price.
     * <p>
     * Below or at the cut-off strike this is the Black price with the SABR volatility.
     * Above it, it is the extrapolated call price, adjusted by put-call parity for puts.
     *
     * @param strike  the strike
     * @param putCall  whether the option is a put or a call
     * @return the price
     */
    public double price(double strike, PutCall putCall) {
        if (strike <= cutOffStrike) {
            double volatility = sabrFunction.volatility(forward, strike, timeToExpiry, sabrData);
            return BlackFormulaRepository.price(forward, strike, timeToExpiry, volatility, putCall.isCall());
        }
        double price = extrapolation(strike);
        if (putCall.isPut()) {
            // put = call - (forward - strike)
            price -= forward - strike;
        }
        return price;
    }

    /**
     * Computes the derivative of the option price with respect to the strike.
     * <p>
     * Below or at the cut-off, the Black strike derivative includes the SABR volatility's strike
     * dependence (chain rule).
     *
     * @param strike  the strike
     * @param putCall  whether the option is a put or a call
     * @return the strike derivative of the price
     */
    public double priceDerivativeStrike(double strike, PutCall putCall) {
        if (strike <= cutOffStrike) {
            ValueDerivatives volatilityAdjoint = sabrFunction.volatilityAdjoint(forward, strike, timeToExpiry, sabrData);
            ValueDerivatives priceAdjoint = BlackFormulaRepository.priceAdjoint(
                    forward, strike, timeToExpiry, volatilityAdjoint.getValue(), putCall.isCall());
            // price index 1 = strike, index 3 = volatility; volatility index 1 = strike
            return priceAdjoint.getDerivative(1) + (priceAdjoint.getDerivative(3) * volatilityAdjoint.getDerivative(1));
        }
        double derivative = extrapolationDerivative(strike);
        if (putCall.isPut()) {
            // d/dK of (call - forward + strike)
            derivative += 1.0d;
        }
        return derivative;
    }

    /**
     * Computes the derivative of the option price with respect to the forward.
     * <p>
     * Above the cut-off, the forward enters only through the fitted parameters, plus the
     * {@code -forward} term of put-call parity for puts.
     *
     * @param strike  the strike
     * @param putCall  whether the option is a put or a call
     * @return the forward derivative of the price
     */
    public double priceDerivativeForward(double strike, PutCall putCall) {
        if (strike <= cutOffStrike) {
            ValueDerivatives volatilityAdjoint = sabrFunction.volatilityAdjoint(forward, strike, timeToExpiry, sabrData);
            ValueDerivatives priceAdjoint = BlackFormulaRepository.priceAdjoint(
                    forward, strike, timeToExpiry, volatilityAdjoint.getValue(), putCall.isCall());
            // price index 0 = forward, index 3 = volatility; volatility index 0 = forward
            return priceAdjoint.getDerivative(0) + (priceAdjoint.getDerivative(3) * volatilityAdjoint.getDerivative(0));
        }
        double[] paramDerivForward = lazyParameterDerivativeForward();
        double value = extrapolation(strike);
        double valueOverStrike = value / strike;
        // dC/dp = C * (da/dp + db/dp / K + dc/dp / K^2)
        double derivative = (value * paramDerivForward[0])
                + (valueOverStrike * paramDerivForward[1])
                + ((valueOverStrike / strike) * paramDerivForward[2]);
        if (putCall.isPut()) {
            derivative -= 1.0d;
        }
        return derivative;
    }

    /**
     * Computes the option price and its derivatives with respect to the SABR parameters.
     * <p>
     * The derivatives are returned in the order alpha, beta, rho, nu.
     *
     * @param strike  the strike
     * @param putCall  whether the option is a put or a call
     * @return the price and its SABR parameter derivatives
     */
    public ValueDerivatives priceAdjointSabr(double strike, PutCall putCall) {
        double price;
        double[] priceDerivSabr = new double[4];
        if (strike <= cutOffStrike) {
            ValueDerivatives volatilityAdjoint = sabrFunction.volatilityAdjoint(forward, strike, timeToExpiry, sabrData);
            ValueDerivatives priceAdjoint = BlackFormulaRepository.priceAdjoint(
                    forward, strike, timeToExpiry, volatilityAdjoint.getValue(), putCall.isCall());
            price = priceAdjoint.getValue();
            for (int i = 0; i < 4; i++) {
                // volatility derivatives 2..5 are with respect to alpha, beta, rho, nu
                priceDerivSabr[i] = priceAdjoint.getDerivative(3) * volatilityAdjoint.getDerivative(i + 2);
            }
        } else {
            double[][] paramDerivSabr = lazyParameterDerivativeSabr();
            double value = extrapolation(strike);
            double valueOverStrike = value / strike;
            double valueOverStrike2 = valueOverStrike / strike;
            price = putCall.isCall() ? value : (value - forward) + strike;
            for (int i = 0; i < 4; i++) {
                priceDerivSabr[i] = (value * paramDerivSabr[i][0])
                        + (valueOverStrike * paramDerivSabr[i][1])
                        + (valueOverStrike2 * paramDerivSabr[i][2]);
            }
        }
        return ValueDerivatives.of(price, DoubleArray.ofUnsafe(priceDerivSabr));
    }

    //-------------------------------------------------------------------------
    /**
     * Gets the SABR parameters.
     *
     * @return the SABR data
     */
    public SabrFormulaData getSabrData() {
        return sabrData;
    }

    /**
     * Gets the cut-off strike above which the price is extrapolated.
     *
     * @return the cut-off strike
     */
    public double getCutOffStrike() {
        return cutOffStrike;
    }

    /**
     * Gets the tail exponent of the extrapolation.
     *
     * @return mu
     */
    public double getMu() {
        return mu;
    }

    /**
     * Gets the time to expiry.
     *
     * @return the time to expiry
     */
    public double getTimeToExpiry() {
        return timeToExpiry;
    }

    /**
     * Gets the extrapolation parameters {@code (a, b, c)}.
     *
     * @return a copy of the parameters
     */
    public double[] getParameter() {
        return parameter.clone();
    }

    /**
     * Gets the derivatives of the extrapolation parameters with respect to the forward.
     *
     * @return a copy of the derivatives, computed on first use
     */
    public double[] getParameterDerivativeForward() {
        return lazyParameterDerivativeForward().clone();
    }

    /**
     * Gets the derivatives of the extrapolation parameters with respect to the SABR parameters.
     * <p>
     * Row {@code i} holds the derivatives of {@code (a, b, c)} with respect to alpha, beta, rho, nu
     * for {@code i = 0..3}.
     *
     * @return a deep copy of the derivatives, computed on first use
     */
    public double[][] getParameterDerivativeSabr() {
        double[][] source = lazyParameterDerivativeSabr();
        double[][] copy = new double[source.length][];
        for (int i = 0; i < source.length; i++) {
            copy[i] = source[i].clone();
        }
        return copy;
    }

    //-------------------------------------------------------------------------
    // Returns the cached forward sensitivities, computing them on first use (racy but idempotent).
    private double[] lazyParameterDerivativeForward() {
        double[] derivatives = parameterDerivativeForward;
        if (derivatives == null) {
            derivatives = computesParametersDerivativeForward();
            parameterDerivativeForward = derivatives;
        }
        return derivatives;
    }

    // Returns the cached SABR sensitivities, computing them on first use (racy but idempotent).
    private double[][] lazyParameterDerivativeSabr() {
        double[][] derivatives = parameterDerivativeSabr;
        if (derivatives == null) {
            derivatives = computesParametersDerivativeSabr();
            parameterDerivativeSabr = derivatives;
        }
        return derivatives;
    }

    // True when the cut-off call price and its first two strike derivatives are all below SMALL_PRICE.
    private boolean isPriceKNegligible() {
        return Math.abs(priceK[0]) < SMALL_PRICE
                && Math.abs(priceK[1]) < SMALL_PRICE
                && Math.abs(priceK[2]) < SMALL_PRICE;
    }

    /**
     * Fits {@code (a, b, c)} to the call price and its first two strike derivatives at the cut-off.
     * <p>
     * Side effects: stores the cut-off volatility in {@code volatilityK} and the cut-off price and
     * strike derivatives in {@code priceK}.
     */
    private double[] computesFittingParameters() {
        double[] param = new double[3];
        // volD: first-order volatility derivatives (index 0 forward, 1 strike, 2..5 alpha, beta, rho, nu)
        // volD2: second-order volatility derivatives in forward/strike
        double[] volD = new double[6];
        double[][] volD2 = new double[2][2];
        volatilityK = sabrFunction.volatilityAdjoint2(forward, cutOffStrike, timeToExpiry, sabrData, volD, volD2);
        Pair<ValueDerivatives, double[][]> pa2 =
                BlackFormulaRepository.priceAdjoint2(forward, cutOffStrike, timeToExpiry, volatilityK, true);
        double[] bsD = pa2.getFirst().getDerivatives().toArrayUnsafe();
        double[][] bsD2 = pa2.getSecond();
        priceK[0] = pa2.getFirst().getValue();
        priceK[1] = bsD[1] + (bsD[3] * volD[1]);
        priceK[2] = bsD2[1][1] + (bsD2[1][2] * volD[1])
                + ((bsD2[2][1] + (bsD2[2][2] * volD[1])) * volD[1])
                + (bsD[3] * volD2[1][1]);
        if (isPriceKNegligible()) {
            // extrapolated price will be "very small"
            return new double[] {NEGLIGIBLE_PRICE_PARAMETER, 0.0d, 0.0d};
        }
        Function<Double, Double> toSolveC = getCFunction(priceK, cutOffStrike, mu);
        BracketRoot bracketer = new BracketRoot();
        RidderSingleRootFinder rootFinder = new RidderSingleRootFinder(ROOT_ACCURACY);
        double[] range = bracketer.getBracketedPoints(toSolveC, -1.0d, 1.0d);
        param[2] = rootFinder.getRoot(toSolveC, Double.valueOf(range[0]), Double.valueOf(range[1])).doubleValue();
        // b from the first-derivative condition, a from the value condition
        param[1] = (((-2.0d) * param[2]) / cutOffStrike)
                - ((((priceK[1] / priceK[0]) * cutOffStrike) + mu) * cutOffStrike);
        param[0] = (Math.log(priceK[0] / Math.pow(cutOffStrike, -mu)) - (param[1] / cutOffStrike))
                - (param[2] / (cutOffStrike * cutOffStrike));
        return param;
    }

    /**
     * Computes the derivatives of {@code (a, b, c)} with respect to the forward.
     * <p>
     * Differentiates the three fitting conditions with respect to the forward and solves the resulting
     * linear system. Third-order Black and SABR derivatives are approximated by one-sided finite differences.
     */
    private double[] computesParametersDerivativeForward() {
        if (isPriceKNegligible()) {
            return new double[] {0.0d, 0.0d, 0.0d};
        }
        double[] volD = new double[6];
        double[][] volD2 = new double[2][2];
        sabrFunction.volatilityAdjoint2(forward, cutOffStrike, timeToExpiry, sabrData, volD, volD2);
        Pair<ValueDerivatives, double[][]> pa2 =
                BlackFormulaRepository.priceAdjoint2(forward, cutOffStrike, timeToExpiry, volatilityK, true);
        double[] bsD = pa2.getFirst().getDerivatives().toArrayUnsafe();
        double[][] bsD2 = pa2.getSecond();
        // third-order Black derivatives: strike bump
        double strikeShift = cutOffStrike * FD_SHIFT;
        double[][] bsD2StrikeUp = BlackFormulaRepository.priceAdjoint2(
                forward, cutOffStrike * (1.0d + FD_SHIFT), timeToExpiry, volatilityK, true).getSecond();
        double bsD3KFK = (bsD2StrikeUp[1][0] - bsD2[1][0]) / strikeShift;
        // third-order Black derivatives: volatility bump
        double volShift = volatilityK * FD_SHIFT;
        double[][] bsD2VolUp = BlackFormulaRepository.priceAdjoint2(
                forward, cutOffStrike, timeToExpiry, volatilityK * (1.0d + FD_SHIFT), true).getSecond();
        double bsD3VVV = (bsD2VolUp[2][2] - bsD2[2][2]) / volShift;
        double bsD3FKV = (bsD2VolUp[0][1] - bsD2[0][1]) / volShift;
        double bsD3FVV = (bsD2VolUp[0][2] - bsD2[0][2]) / volShift;
        double bsD3KKV = (bsD2VolUp[1][1] - bsD2[1][1]) / volShift;
        double bsD3VKV = (bsD2VolUp[2][1] - bsD2[2][1]) / volShift;
        // third-order SABR volatility derivative (strike, forward, strike): strike bump
        double[][] volD2StrikeUp = new double[2][2];
        sabrFunction.volatilityAdjoint2(
                forward, cutOffStrike * (1.0d + FD_SHIFT), timeToExpiry, sabrData, new double[6], volD2StrikeUp);
        double volD3KFK = (volD2StrikeUp[1][0] - volD2[1][0]) / strikeShift;
        // forward derivatives of the cut-off price, first and second strike derivatives
        double[] priceDerivForward = new double[3];
        priceDerivForward[0] = bsD[0] + (bsD[3] * volD[0]);
        priceDerivForward[1] = bsD2[0][1] + (bsD2[2][0] * volD[1])
                + ((bsD2[1][2] + (bsD2[2][2] * volD[1])) * volD[0])
                + (bsD[3] * volD2[1][0]);
        priceDerivForward[2] = bsD3KFK + (bsD3FKV * volD[1])
                + ((bsD3FKV + (bsD3FVV * volD[1])) * volD[1])
                + (bsD2[2][0] * volD2[1][1])
                + ((bsD3KKV + (bsD3VKV * volD[1]) + ((bsD3VKV + (bsD3VVV * volD[1])) * volD[1])
                        + (bsD2[2][2] * volD2[1][1])) * volD[0])
                + (2.0d * (bsD2[1][2] + (bsD2[2][2] * volD[1])) * volD2[1][0])
                + (bsD[3] * volD3KFK);
        return SVD.apply(DoubleMatrix.ofUnsafe(fittingJacobian())).solve(priceDerivForward);
    }

    /**
     * Computes the derivatives of {@code (a, b, c)} with respect to alpha, beta, rho and nu (one row each).
     * <p>
     * SABR volatility sensitivities of the strike derivatives are obtained by bumping each SABR parameter:
     * alpha relatively, beta, rho and nu absolutely, all by {@code FD_SHIFT}.
     * NOTE(review): a zero alpha gives a zero bump and hence a division by zero.
     */
    private double[][] computesParametersDerivativeSabr() {
        double[][] result = new double[4][3];
        if (isPriceKNegligible()) {
            return result;
        }
        double[] volD = new double[6];
        double[][] volD2 = new double[2][2];
        sabrFunction.volatilityAdjoint2(forward, cutOffStrike, timeToExpiry, sabrData, volD, volD2);
        // Black quantities at the cut-off do not depend on which SABR parameter is bumped: compute them once
        Pair<ValueDerivatives, double[][]> pa2 =
                BlackFormulaRepository.priceAdjoint2(forward, cutOffStrike, timeToExpiry, volatilityK, true);
        double[] bsD = pa2.getFirst().getDerivatives().toArrayUnsafe();
        double[][] bsD2 = pa2.getSecond();
        double volShift = volatilityK * FD_SHIFT;
        double[][] bsD2VolUp = BlackFormulaRepository.priceAdjoint2(
                forward, cutOffStrike, timeToExpiry, volatilityK * (1.0d + FD_SHIFT), true).getSecond();
        double bsD3VVV = (bsD2VolUp[2][2] - bsD2[2][2]) / volShift;
        double bsD3KKV = (bsD2VolUp[1][1] - bsD2[1][1]) / volShift;
        double bsD3VKV = (bsD2VolUp[2][1] - bsD2[2][1]) / volShift;
        // row i: derivatives of the cut-off price, first and second strike derivatives w.r.t. SABR parameter i
        double[][] priceDerivSabr = new double[4][3];
        for (int i = 0; i < 4; i++) {
            double volDParam = volD[2 + i];
            double shift;
            SabrFormulaData bumped;
            switch (i) {
                case 0:
                    double alpha = sabrData.getAlpha();
                    shift = alpha * FD_SHIFT;
                    bumped = sabrData.withAlpha(alpha + shift);
                    break;
                case 1:
                    shift = FD_SHIFT;
                    bumped = sabrData.withBeta(sabrData.getBeta() + FD_SHIFT);
                    break;
                case 2:
                    shift = FD_SHIFT;
                    bumped = sabrData.withRho(sabrData.getRho() + FD_SHIFT);
                    break;
                default:
                    shift = FD_SHIFT;
                    bumped = sabrData.withNu(sabrData.getNu() + FD_SHIFT);
                    break;
            }
            double[] volDBumped = new double[6];
            double[][] volD2Bumped = new double[2][2];
            sabrFunction.volatilityAdjoint2(forward, cutOffStrike, timeToExpiry, bumped, volDBumped, volD2Bumped);
            double volD2KParam = (volDBumped[1] - volD[1]) / shift;
            double volD3KKParam = (volD2Bumped[1][1] - volD2[1][1]) / shift;
            priceDerivSabr[i][0] = bsD[3] * volDParam;
            priceDerivSabr[i][1] = ((bsD2[1][2] + (bsD2[2][2] * volD[1])) * volDParam) + (bsD[3] * volD2KParam);
            priceDerivSabr[i][2] = ((bsD3KKV + (bsD3VKV * volD[1]) + ((bsD3VKV + (bsD3VVV * volD[1])) * volD[1])
                    + (bsD2[2][2] * volD2[1][1])) * volDParam)
                    + (2.0d * (bsD2[2][1] + (bsD2[2][2] * volD[1])) * volD2KParam)
                    + (bsD[3] * volD3KKParam);
        }
        SVDecompositionResult decomposition = SVD.apply(DoubleMatrix.ofUnsafe(fittingJacobian()));
        for (int i = 0; i < 4; i++) {
            result[i] = decomposition.solve(priceDerivSabr[i]);
        }
        return result;
    }

    /**
     * Builds the Jacobian of the three fitting conditions (extrapolated value, first and second strike
     * derivatives at the cut-off) with respect to {@code (a, b, c)}.
     * <p>
     * Rows are the conditions, columns the parameters. Requires {@code parameter} to be set.
     */
    private double[][] fittingJacobian() {
        double[][] jacobian = new double[3][3];
        double p0 = priceK[0];
        double p1 = priceK[1];
        double p2 = priceK[2];
        double strike2 = cutOffStrike * cutOffStrike;
        jacobian[0][0] = p0;
        jacobian[0][1] = p0 / cutOffStrike;
        jacobian[0][2] = jacobian[0][1] / cutOffStrike;
        jacobian[1][0] = p1;
        jacobian[1][1] = (p1 - jacobian[0][1]) / cutOffStrike;
        jacobian[1][2] = (p1 - (2.0d * jacobian[0][1])) / strike2;
        jacobian[2][0] = p2;
        jacobian[2][1] = (p2 + (jacobian[0][2] * (((2.0d * (mu + 1.0d)) + ((2.0d * parameter[1]) / cutOffStrike))
                + ((4.0d * parameter[2]) / strike2)))) / cutOffStrike;
        jacobian[2][2] = (p2 + (jacobian[0][2] * (((2.0d * ((2.0d * mu) + 3.0d)) + ((4.0d * parameter[1]) / cutOffStrike))
                + ((8.0d * parameter[2]) / strike2)))) / strike2;
        return jacobian;
    }

    // Extrapolated call price: K^(-mu) * exp(a + b / K + c / K^2).
    private double extrapolation(double strike) {
        return Math.pow(strike, -mu) * Math.exp(parameter[0] + (parameter[1] / strike) + (parameter[2] / (strike * strike)));
    }

    // Strike derivative of the extrapolated call price.
    private double extrapolationDerivative(double strike) {
        return ((-extrapolation(strike)) * (mu + ((parameter[1] + ((2.0d * parameter[2]) / strike)) / strike))) / strike;
    }

    /**
     * Returns the function of {@code c} whose root gives the third extrapolation parameter.
     * <p>
     * {@code b} is eliminated with the first-derivative condition; the function is the second-derivative
     * condition multiplied by {@code K^2 / price}.
     *
     * @param prices  the cut-off price and its first two strike derivatives (copied)
     * @param strike  the cut-off strike
     * @param mu  the tail exponent
     * @return the function to solve
     */
    private static Function<Double, Double> getCFunction(double[] prices, double strike, double mu) {
        double[] p = Arrays.copyOf(prices, prices.length);
        return c -> {
            double cValue = c.doubleValue();
            double b = (((-2.0d) * cValue) / strike) - ((((p[1] / p[0]) * strike) + mu) * strike);
            double strike2 = strike * strike;
            return Double.valueOf((((-p[2]) / p[0]) * strike2)
                    + (mu * (mu + 1.0d))
                    + (((2.0d * b) * (mu + 1.0d)) / strike)
                    + ((((2.0d * cValue) * ((2.0d * mu) + 3.0d)) + (b * b)) / strike2)
                    + (((4.0d * b) * cValue) / (strike2 * strike))
                    + (((4.0d * cValue) * cValue) / (strike2 * strike2)));
        };
    }
}
