package com.opengamma.strata.pricer.impl.volatility.smile;

import com.opengamma.strata.collect.ArgChecker;
import com.opengamma.strata.collect.array.DoubleArray;
import com.opengamma.strata.collect.array.DoubleMatrix;
import com.opengamma.strata.math.impl.cern.MersenneTwister;
import com.opengamma.strata.math.impl.cern.RandomEngine;
import com.opengamma.strata.math.impl.differentiation.VectorFieldFirstOrderDifferentiator;
import com.opengamma.strata.math.impl.statistics.leastsquare.LeastSquareResults;
import com.opengamma.strata.math.impl.statistics.leastsquare.LeastSquareResultsWithTransform;
import com.opengamma.strata.pricer.impl.volatility.smile.SmileModelData;
import java.util.Arrays;
import java.util.BitSet;
import java.util.function.Function;
import org.assertj.core.api.Assertions;
import org.assertj.core.data.Offset;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;

/* loaded from: input_file:com/opengamma/strata/pricer/impl/volatility/smile/SmileModelFitterTest.class */
/**
 * Base class for tests of {@link SmileModelFitter} implementations.
 * <p>
 * Subclasses supply a volatility model, a reference parameter set ({@link #getModelData()}) and
 * start values. The constructor generates model volatilities at {@link #STRIKES} from the reference
 * parameters, plus a noisy copy of them, and builds one fitter for each. The tests then check that
 * fitting recovers the reference parameters, and that the analytic model Jacobian agrees with a
 * finite-difference one.
 *
 * @param <T>  the type of the smile model data
 */
public abstract class SmileModelFitterTest<T extends SmileModelData> {
    /** Time to expiry used when generating the test volatilities. */
    protected static double TIME_TO_EXPIRY = 7.0d;
    /** Forward used when generating the test volatilities. */
    protected static double F = 0.03d;
    /** Random source for the noise added to the clean volatilities. */
    private static final RandomEngine UNIFORM = new MersenneTwister();
    /** Strikes at which test volatilities are generated and fitted. */
    protected static double[] STRIKES = {0.005d, 0.01d, 0.02d, 0.03d, 0.04d, 0.05d, 0.07d, 0.1d};
    /** Volatilities produced by the model from {@link #getModelData()}, one per strike. */
    protected double[] _cleanVols;
    /** {@link #_cleanVols} with random noise added, one per strike. */
    protected double[] _noisyVols;
    /** Per-strike error passed to the fitters. */
    protected double[] _errors;
    /** The model returned by {@link #getModel()} during construction. */
    protected VolatilityFunctionProvider<T> _model;
    /** Fitter for the clean volatilities. */
    protected SmileModelFitter<T> _fitter;
    /** Fitter for the noisy volatilities (field name kept for compatibility with subclasses). */
    protected SmileModelFitter<T> _nosiyFitter;
    /** Tolerance on chi-squared for the exact fit. */
    protected double _chiSqEps = 1.0E-6d;
    /** Tolerance on each recovered parameter for the exact fit. */
    protected double _paramValueEps = 1.0E-6d;

    /**
     * Gets the logger used for timing and diagnostic output.
     *
     * @return the logger
     */
    abstract Logger getlogger();

    /**
     * Gets the volatility model under test.
     *
     * @return the model
     */
    abstract VolatilityFunctionProvider<T> getModel();

    /**
     * Gets the reference model parameters used to generate, and then compare against, the fitted data.
     *
     * @return the reference model data
     */
    abstract T getModelData();

    /**
     * Creates the fitter under test.
     *
     * @param forward  the forward
     * @param strikes  the strikes
     * @param timeToExpiry  the time to expiry
     * @param volatilities  the volatilities to fit, one per strike
     * @param errors  the errors, one per strike
     * @param model  the volatility model
     * @return the fitter
     */
    abstract SmileModelFitter<T> getFitter(
            double forward,
            double[] strikes,
            double timeToExpiry,
            double[] volatilities,
            double[] errors,
            VolatilityFunctionProvider<T> model);

    /**
     * Gets the start points for the fits, one row per fit.
     *
     * @return the start values
     */
    abstract double[][] getStartValues();

    /**
     * Gets a random start point for a fit.
     *
     * @return the start values
     */
    abstract double[] getRandomStartValues();

    /**
     * Gets the fixed-parameter flags passed to the fitter, one per row of {@link #getStartValues()}.
     *
     * @return the fixed-parameter flags
     */
    abstract BitSet[] getFixedValues();

    /**
     * Generates the clean and noisy test volatilities and builds a fitter for each.
     */
    public SmileModelFitterTest() {
        // NOTE(review): abstract methods are called from this constructor, so subclass instance fields
        // are not yet initialized when getModel()/getModelData()/getFitter() run. Implementations must
        // not depend on them.
        VolatilityFunctionProvider<T> model = getModel();
        T modelData = getModelData();
        int nStrikes = STRIKES.length;
        _model = model;
        _cleanVols = new double[nStrikes];
        _noisyVols = new double[nStrikes];
        _errors = new double[nStrikes];
        Arrays.fill(_errors, 1.0E-4d);
        for (int i = 0; i < nStrikes; i++) {
            _cleanVols[i] = model.volatility(F, STRIKES[i], TIME_TO_EXPIRY, modelData);
            // noise is the error scaled by a draw from UNIFORM (presumably in (0, 1), so one-sided)
            _noisyVols[i] = _cleanVols[i] + UNIFORM.nextDouble() * _errors[i];
        }
        _fitter = getFitter(F, STRIKES, TIME_TO_EXPIRY, _cleanVols, _errors, model);
        _nosiyFitter = getFitter(F, STRIKES, TIME_TO_EXPIRY, _noisyVols, _errors, model);
    }

    /**
     * Fitting the clean volatilities from every start point should give (near) zero chi-squared and
     * recover the reference parameters to within {@link #_paramValueEps}.
     */
    @Test
    public void testExactFit() {
        double[][] startValues = getStartValues();
        BitSet[] fixedValues = getFixedValues();
        int nStarts = startValues.length;
        ArgChecker.isTrue(fixedValues.length == nStarts, "Number of fixed-value sets must match number of start values");
        for (int i = 0; i < nStarts; i++) {
            LeastSquareResultsWithTransform result = _fitter.solve(DoubleArray.copyOf(startValues[i]), fixedValues[i]);
            Assertions.assertThat(result.getChiSq()).isCloseTo(0.0d, Offset.offset(_chiSqEps));
            assertParametersClose(result.getModelParameters(), _paramValueEps);
        }
    }

    /**
     * Fitting the noisy volatilities from every start point should give a small chi-squared and
     * recover the reference parameters to within 0.01.
     */
    @Test
    public void testNoisyFit() {
        double[][] startValues = getStartValues();
        BitSet[] fixedValues = getFixedValues();
        int nStarts = startValues.length;
        ArgChecker.isTrue(fixedValues.length == nStarts, "Number of fixed-value sets must match number of start values");
        for (int i = 0; i < nStarts; i++) {
            LeastSquareResultsWithTransform result = _nosiyFitter.solve(DoubleArray.copyOf(startValues[i]), fixedValues[i]);
            Assertions.assertThat(result.getChiSq()).isLessThan(7.0d);
            assertParametersClose(result.getModelParameters(), 0.01d);
        }
    }

    /**
     * Asserts that the fitted parameters match the reference parameters from {@link #getModelData()}.
     *
     * @param fitted  the fitted model parameters
     * @param tolerance  the absolute tolerance per parameter
     */
    private void assertParametersClose(DoubleArray fitted, double tolerance) {
        T expected = getModelData();
        int nParams = fitted.size();
        Assertions.assertThat(nParams).isEqualTo(expected.getNumberOfParameters());
        for (int j = 0; j < nParams; j++) {
            Assertions.assertThat(fitted.get(j)).isCloseTo(expected.getParameter(j), Offset.offset(tolerance));
        }
    }

    /**
     * Rough benchmark of the noisy fit. Runs a warm-up phase, then logs the mean wall-clock time per
     * individual fit in milliseconds.
     */
    @Test
    public void timeTest() {
        int warmupCycles = 200;
        int benchmarkCycles = 1000;
        int nStarts = getStartValues().length;
        for (int i = 0; i < warmupCycles; i++) {
            testNoisyFit();
        }
        long start = System.nanoTime();
        for (int i = 0; i < benchmarkCycles; i++) {
            testNoisyFit();
        }
        long elapsedNanos = System.nanoTime() - start;
        // Each cycle performs nStarts fits. Convert ns -> ms in floating point so that
        // sub-millisecond timings are not truncated to zero.
        double msPerFit = elapsedNanos / 1.0e6d / benchmarkCycles / Math.max(nStarts, 1);
        getlogger().info("time per fit: {}ms", msPerFit);
    }

    /**
     * Fits an awkward, short-dated market smile from several random start points. The best
     * chi-squared found must be below a loose bound.
     */
    @Test
    public void horribleMarketDataTest() {
        double forward = 0.0059875d;
        double timeToExpiry = 0.09041095890410959d;
        double[] strikes = {0.0012499999999999734d, 0.0024999999999999467d, 0.003750000000000031d, 0.0050000000000000044d, 0.006249999999999978d, 0.007499999999999951d, 0.008750000000000036d, 0.010000000000000009d, 0.011249999999999982d, 0.012499999999999956d, 0.01375000000000004d, 0.015000000000000013d, 0.016249999999999987d, 0.01749999999999996d, 0.018750000000000044d, 0.020000000000000018d, 0.02124999999999999d, 0.022499999999999964d, 0.02375000000000005d, 0.025000000000000022d, 0.026249999999999996d, 0.02749999999999997d, 0.028750000000000053d, 0.030000000000000027d};
        double[] vols = {2.7100433855959642d, 1.5506135190088546d, 0.9083977239618538d, 0.738416513934868d, 0.8806973450124451d, 1.0906290439592792d, 1.2461975189027226d, 1.496275983572826d, 1.5885915338673156d, 1.4842142974195722d, 1.7667347426399058d, 1.4550288621444052d, 1.0651798188736166d, 1.143318270172714d, 1.216215092528441d, 1.2845258218014657d, 1.3488224665755535d, 1.9259326343836376d, 1.9868728791190922d, 2.0441767092857317d, 2.0982583238541026d, 2.1494622372820675d, 2.198020785622251d, 2.244237863291375d};
        double[] errors = new double[strikes.length];
        Arrays.fill(errors, 0.01d);
        SmileModelFitter<T> fitter = getFitter(forward, strikes, timeToExpiry, vols, errors, getModel());

        LeastSquareResults best = null;
        BitSet noneFixed = new BitSet();
        for (int i = 0; i < 5; i++) {
            LeastSquareResults result = fitter.solve(DoubleArray.copyOf(getRandomStartValues()), noneFixed);
            if (best == null || result.getChiSq() < best.getChiSq()) {
                best = result;
            }
        }
        Assertions.assertThat(best).isNotNull();
        Assertions.assertThat(best.getChiSq()).isLessThan(24000.0d);
    }

    /**
     * Checks the analytic model Jacobian against a finite-difference one at the reference parameters.
     */
    @Test
    public void testJacobian() {
        T modelData = getModelData();
        DoubleArray params = DoubleArray.of(modelData.getNumberOfParameters(), modelData::getParameter);
        testJacobian(params);
    }

    /**
     * Checks the analytic model Jacobian against a finite-difference one at random parameters.
     */
    @Test
    @Disabled
    public void testRandomJacobian() {
        for (int i = 0; i < 10; i++) {
            DoubleArray params = DoubleArray.copyOf(getRandomStartValues());
            try {
                testJacobian(params);
            } catch (AssertionError e) {
                getlogger().error("Jacobian test failed at {}", params);
                throw e;
            }
        }
    }

    /**
     * Compares the fitter's analytic model Jacobian with a finite-difference Jacobian of its model
     * value function, element by element, with an absolute tolerance of 0.02.
     *
     * @param params  the model parameters at which to compare
     */
    private void testJacobian(DoubleArray params) {
        int nParams = params.size();
        Function<DoubleArray, DoubleArray> valueFunction = _fitter.getModelValueFunction();
        Function<DoubleArray, DoubleMatrix> analyticJacobianFunction = _fitter.getModelJacobianFunction();
        Function<DoubleArray, DoubleMatrix> fdJacobianFunction =
                new VectorFieldFirstOrderDifferentiator().differentiate(valueFunction);
        DoubleMatrix analytic = analyticJacobianFunction.apply(params);
        DoubleMatrix fd = fdJacobianFunction.apply(params);
        int rows = fd.rowCount();
        int cols = fd.columnCount();
        Assertions.assertThat(rows).isEqualTo(_cleanVols.length);
        Assertions.assertThat(cols).isEqualTo(nParams);
        Assertions.assertThat(analytic.rowCount()).isEqualTo(rows);
        Assertions.assertThat(analytic.columnCount()).isEqualTo(cols);
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                Assertions.assertThat(analytic.get(r, c)).isCloseTo(fd.get(r, c), Offset.offset(0.02d));
            }
        }
    }
}
