package com.opengamma.strata.pricer.model;

import com.opengamma.strata.basics.date.DayCounts;
import com.opengamma.strata.collect.TestHelper;
import com.opengamma.strata.collect.array.DoubleArray;
import com.opengamma.strata.market.ValueType;
import com.opengamma.strata.market.curve.ConstantCurve;
import com.opengamma.strata.market.curve.CurveName;
import com.opengamma.strata.market.curve.Curves;
import com.opengamma.strata.market.curve.InterpolatedNodalCurve;
import com.opengamma.strata.market.curve.interpolator.CurveInterpolators;
import com.opengamma.strata.market.param.ParameterMetadata;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Tests {@link SabrParameters}.
 * <p>
 * The alpha, beta, rho and nu curves used here each have two nodes (at 0 and 10) carrying the same value,
 * so with linear interpolation they are flat between those nodes.
 */
public class SabrParametersTest {
    private static final InterpolatedNodalCurve ALPHA_CURVE = InterpolatedNodalCurve.of(
            Curves.sabrParameterByExpiry("SabrAlpha", DayCounts.ACT_ACT_ISDA, ValueType.SABR_ALPHA),
            DoubleArray.of(0.0d, 10.0d), DoubleArray.of(0.2d, 0.2d), CurveInterpolators.LINEAR);
    private static final InterpolatedNodalCurve BETA_CURVE = InterpolatedNodalCurve.of(
            Curves.sabrParameterByExpiry("SabrBeta", DayCounts.ACT_ACT_ISDA, ValueType.SABR_BETA),
            DoubleArray.of(0.0d, 10.0d), DoubleArray.of(1.0d, 1.0d), CurveInterpolators.LINEAR);
    private static final InterpolatedNodalCurve RHO_CURVE = InterpolatedNodalCurve.of(
            Curves.sabrParameterByExpiry("SabrRho", DayCounts.ACT_ACT_ISDA, ValueType.SABR_RHO),
            DoubleArray.of(0.0d, 10.0d), DoubleArray.of(-0.5d, -0.5d), CurveInterpolators.LINEAR);
    private static final InterpolatedNodalCurve NU_CURVE = InterpolatedNodalCurve.of(
            Curves.sabrParameterByExpiry("SabrNu", DayCounts.ACT_ACT_ISDA, ValueType.SABR_NU),
            DoubleArray.of(0.0d, 10.0d), DoubleArray.of(0.5d, 0.5d), CurveInterpolators.LINEAR);
    private static final SabrVolatilityFormula FORMULA = SabrVolatilityFormula.hagan();
    /** Parameters built without an explicit shift curve. */
    private static final SabrParameters PARAMETERS = SabrParameters.of(ALPHA_CURVE, BETA_CURVE, RHO_CURVE, NU_CURVE, FORMULA);

    @Test
    public void getter() {
        Assertions.assertThat(PARAMETERS.getAlphaCurve()).isEqualTo(ALPHA_CURVE);
        Assertions.assertThat(PARAMETERS.getBetaCurve()).isEqualTo(BETA_CURVE);
        Assertions.assertThat(PARAMETERS.getRhoCurve()).isEqualTo(RHO_CURVE);
        Assertions.assertThat(PARAMETERS.getNuCurve()).isEqualTo(NU_CURVE);
        Assertions.assertThat(PARAMETERS.getSabrVolatilityFormula()).isEqualTo(FORMULA);
        Assertions.assertThat(PARAMETERS.getShiftCurve().getName()).isEqualTo(CurveName.of("Zero shift"));
        Assertions.assertThat(PARAMETERS.getDayCount()).isEqualTo(DayCounts.ACT_ACT_ISDA);
        Assertions.assertThat(PARAMETERS.getParameterCount()).isEqualTo(9);

        double expiry = 2.0d;
        double strike = 1.1d;
        double forward = 1.05d;
        double alpha = ALPHA_CURVE.yValue(expiry);
        double beta = BETA_CURVE.yValue(expiry);
        double rho = RHO_CURVE.yValue(expiry);
        double nu = NU_CURVE.yValue(expiry);
        Assertions.assertThat(PARAMETERS.alpha(expiry)).isEqualTo(alpha);
        Assertions.assertThat(PARAMETERS.beta(expiry)).isEqualTo(beta);
        Assertions.assertThat(PARAMETERS.rho(expiry)).isEqualTo(rho);
        Assertions.assertThat(PARAMETERS.nu(expiry)).isEqualTo(nu);

        // Note the argument order differs: SabrParameters takes (expiry, strike, forward),
        // the formula takes (forward, strike, expiry, ...).
        Assertions.assertThat(PARAMETERS.volatility(expiry, strike, forward))
                .isEqualTo(FORMULA.volatility(forward, strike, expiry, alpha, beta, rho, nu));
        double[] computed = PARAMETERS.volatilityAdjoint(expiry, strike, forward).getDerivatives().toArray();
        double[] expected = FORMULA.volatilityAdjoint(forward, strike, expiry, alpha, beta, rho, nu)
                .getDerivatives().toArray();
        // Compare every derivative, including the array length, rather than a hard-coded prefix.
        Assertions.assertThat(computed).containsExactly(expected);

        // Parameters are laid out curve by curve: alpha, beta, rho, nu, followed by the shift curve.
        int offset = 0;
        for (InterpolatedNodalCurve curve : new InterpolatedNodalCurve[] {ALPHA_CURVE, BETA_CURVE, RHO_CURVE, NU_CURVE}) {
            for (int j = 0; j < curve.getParameterCount(); j++) {
                Assertions.assertThat(PARAMETERS.getParameterMetadata(offset + j)).isEqualTo(curve.getParameterMetadata(j));
                Assertions.assertThat(PARAMETERS.getParameter(offset + j)).isEqualTo(curve.getParameter(j));
            }
            offset += curve.getParameterCount();
        }
        // The last parameter comes from the default "Zero shift" curve: empty metadata and a value of zero.
        Assertions.assertThat(offset).isEqualTo(8);
        Assertions.assertThat(PARAMETERS.getParameterMetadata(offset)).isEqualTo(ParameterMetadata.empty());
        Assertions.assertThat(PARAMETERS.getParameter(offset)).isEqualTo(0.0d);
    }

    @Test
    public void negativeRates() {
        double shift = 0.05d;
        SabrParameters params = SabrParameters.of(
                ALPHA_CURVE, BETA_CURVE, RHO_CURVE, NU_CURVE, ConstantCurve.of("shfit", shift), FORMULA);
        double expiry = 2.0d;
        double strike = -0.02d;
        double forward = 0.015d;
        double alpha = ALPHA_CURVE.yValue(expiry);
        double beta = BETA_CURVE.yValue(expiry);
        double rho = RHO_CURVE.yValue(expiry);
        double nu = NU_CURVE.yValue(expiry);
        Assertions.assertThat(params.alpha(expiry)).isEqualTo(alpha);
        Assertions.assertThat(params.beta(expiry)).isEqualTo(beta);
        Assertions.assertThat(params.rho(expiry)).isEqualTo(rho);
        Assertions.assertThat(params.nu(expiry)).isEqualTo(nu);

        // The expected values evaluate the formula at the shifted forward and strike,
        // so the negative strike above is evaluated at a positive shifted strike.
        Assertions.assertThat(params.volatility(expiry, strike, forward))
                .isEqualTo(FORMULA.volatility(forward + shift, strike + shift, expiry, alpha, beta, rho, nu));
        double[] computed = params.volatilityAdjoint(expiry, strike, forward).getDerivatives().toArray();
        double[] expected = FORMULA.volatilityAdjoint(forward + shift, strike + shift, expiry, alpha, beta, rho, nu)
                .getDerivatives().toArray();
        // Check all derivatives, not only the first four, so the remaining sensitivities under a shift are covered too.
        Assertions.assertThat(computed).containsExactly(expected);
    }

    @Test
    public void perturbation() {
        // Scale parameter i by (2 + i) in one go, and compare against applying the same scaling one parameter at a time.
        SabrParameters perturbed = PARAMETERS.withPerturbation((index, value, metadata) -> (2.0d + index) * value);
        SabrParameters expected = PARAMETERS;
        for (int i = 0; i < PARAMETERS.getParameterCount(); i++) {
            expected = expected.withParameter(i, (2.0d + i) * expected.getParameter(i));
        }
        Assertions.assertThat(perturbed).isEqualTo(expected);
    }

    @Test
    public void coverage() {
        TestHelper.coverImmutableBean(PARAMETERS);
    }
}
