package com.opengamma.strata.math.impl.statistics.leastsquare;

import com.google.common.collect.Lists;
import com.google.common.primitives.Doubles;
import com.opengamma.strata.collect.ArgChecker;
import com.opengamma.strata.collect.DoubleArrayMath;
import com.opengamma.strata.collect.array.DoubleArray;
import com.opengamma.strata.collect.array.DoubleMatrix;
import com.opengamma.strata.collect.array.Matrix;
import com.opengamma.strata.math.MathUtils;
import com.opengamma.strata.math.impl.linearalgebra.SVDecompositionCommons;
import com.opengamma.strata.math.impl.matrix.CommonsMatrixAlgebra;
import com.opengamma.strata.math.impl.matrix.MatrixAlgebra;
import com.opengamma.strata.math.linearalgebra.Decomposition;
import com.opengamma.strata.math.linearalgebra.DecompositionResult;
import java.util.List;
import java.util.function.Function;
import org.apache.commons.math3.util.CombinatoricsUtils;

/* loaded from: input_file:com/opengamma/strata/math/impl/statistics/leastsquare/GeneralizedLeastSquare.class */
public class GeneralizedLeastSquare {
    private final Decomposition<?> _decomposition = new SVDecompositionCommons();
    private final MatrixAlgebra _algebra = new CommonsMatrixAlgebra();

    public <T> GeneralizedLeastSquareResults<T> solve(T[] tArr, double[] dArr, double[] dArr2, List<Function<T, Double>> list) {
        return solve(tArr, dArr, dArr2, list, 0.0d, 0);
    }

    public <T> GeneralizedLeastSquareResults<T> solve(T[] tArr, double[] dArr, double[] dArr2, List<Function<T, Double>> list, double d, int i) {
        ArgChecker.notNull(tArr, "x null");
        ArgChecker.notNull(dArr, "y null");
        ArgChecker.notNull(dArr2, "sigma null");
        ArgChecker.notEmpty(list, "empty basisFunctions");
        int length = tArr.length;
        ArgChecker.isTrue(length > 0, "no data");
        ArgChecker.isTrue(dArr.length == length, "y wrong length");
        ArgChecker.isTrue(dArr2.length == length, "sigma wrong length");
        ArgChecker.isTrue(d >= 0.0d, "negative lambda");
        ArgChecker.isTrue(i >= 0, "difference order");
        return solveImp(Lists.newArrayList(tArr), Lists.newArrayList(Doubles.asList(dArr)), Lists.newArrayList(Doubles.asList(dArr2)), list, d, i);
    }

    GeneralizedLeastSquareResults<Double> solve(double[] dArr, double[] dArr2, double[] dArr3, List<Function<Double, Double>> list, double d, int i) {
        return solve(DoubleArrayMath.toObject(dArr), dArr2, dArr3, list, d, i);
    }

    public <T> GeneralizedLeastSquareResults<T> solve(List<T> list, List<Double> list2, List<Double> list3, List<Function<T, Double>> list4) {
        return solve(list, list2, list3, list4, 0.0d, 0);
    }

    public <T> GeneralizedLeastSquareResults<T> solve(List<T> list, List<Double> list2, List<Double> list3, List<Function<T, Double>> list4, double d, int i) {
        ArgChecker.notEmpty(list, "empty measurement points");
        ArgChecker.notEmpty(list2, "empty measurement values");
        ArgChecker.notEmpty(list3, "empty measurement errors");
        ArgChecker.notEmpty(list4, "empty basisFunctions");
        int size = list.size();
        ArgChecker.isTrue(size > 0, "no data");
        ArgChecker.isTrue(list2.size() == size, "y wrong length");
        ArgChecker.isTrue(list3.size() == size, "sigma wrong length");
        ArgChecker.isTrue(d >= 0.0d, "negative lambda");
        ArgChecker.isTrue(i >= 0, "difference order");
        return solveImp(list, list2, list3, list4, d, i);
    }

    public <T> GeneralizedLeastSquareResults<T> solve(List<T> list, List<Double> list2, List<Double> list3, List<Function<T, Double>> list4, int[] iArr, double[] dArr, int[] iArr2) {
        ArgChecker.notEmpty(list, "empty measurement points");
        ArgChecker.notEmpty(list2, "empty measurement values");
        ArgChecker.notEmpty(list3, "empty measurement errors");
        ArgChecker.notEmpty(list4, "empty basisFunctions");
        int size = list.size();
        ArgChecker.isTrue(size > 0, "no data");
        ArgChecker.isTrue(list2.size() == size, "y wrong length");
        ArgChecker.isTrue(list3.size() == size, "sigma wrong length");
        int length = iArr.length;
        ArgChecker.isTrue(length == dArr.length, "number of penalty functions {} must be equal to number of directions {}", new Object[]{Integer.valueOf(dArr.length), Integer.valueOf(length)});
        ArgChecker.isTrue(length == iArr2.length, "number of difference order {} must be equal to number of directions {}", new Object[]{Integer.valueOf(iArr2.length), Integer.valueOf(length)});
        for (int i = 0; i < length; i++) {
            ArgChecker.isTrue(iArr[i] > 0, "sizes must be >= 1");
            ArgChecker.isTrue(dArr[i] >= 0.0d, "negative lambda");
            ArgChecker.isTrue(iArr2[i] >= 0, "difference order");
        }
        return solveImp(list, list2, list3, list4, iArr, dArr, iArr2);
    }

    /* JADX WARN: Type inference failed for: r0v24, types: [com.opengamma.strata.math.linearalgebra.DecompositionResult] */
    private <T> GeneralizedLeastSquareResults<T> solveImp(List<T> list, List<Double> list2, List<Double> list3, List<Function<T, Double>> list4, double d, int i) {
        int size = list.size();
        int size2 = list4.size();
        double[] dArr = new double[size2];
        double[] dArr2 = new double[size];
        double[][] dArr3 = new double[size2][size];
        for (int i2 = 0; i2 < size; i2++) {
            double doubleValue = list3.get(i2).doubleValue();
            ArgChecker.isTrue(doubleValue > 0.0d, "sigma must be greater than zero");
            dArr2[i2] = (1.0d / doubleValue) / doubleValue;
        }
        for (int i3 = 0; i3 < size2; i3++) {
            for (int i4 = 0; i4 < size; i4++) {
                dArr3[i3][i4] = list4.get(i3).apply(list.get(i4)).doubleValue();
            }
        }
        for (int i5 = 0; i5 < size2; i5++) {
            double d2 = 0.0d;
            for (int i6 = 0; i6 < size; i6++) {
                d2 += list2.get(i6).doubleValue() * dArr3[i5][i6] * dArr2[i6];
            }
            dArr[i5] = d2;
        }
        DoubleArray copyOf = DoubleArray.copyOf(dArr);
        Matrix aMatrix = getAMatrix(dArr3, dArr2);
        if (d > 0.0d) {
            aMatrix = (DoubleMatrix) this._algebra.add(aMatrix, this._algebra.scale(getDiffMatrix(size2, i), d));
        }
        ?? apply = this._decomposition.apply((DoubleMatrix) aMatrix);
        DoubleArray solve = apply.solve(copyOf);
        DoubleMatrix solve2 = apply.solve(DoubleMatrix.identity(size2));
        double d3 = 0.0d;
        for (int i7 = 0; i7 < size; i7++) {
            double d4 = 0.0d;
            for (int i8 = 0; i8 < size2; i8++) {
                d4 += solve.get(i8) * dArr3[i8][i7];
            }
            d3 += MathUtils.pow2(list2.get(i7).doubleValue() - d4) * dArr2[i7];
        }
        return new GeneralizedLeastSquareResults<>(list4, d3, solve, solve2);
    }

    /* JADX WARN: Type inference failed for: r0v26, types: [com.opengamma.strata.math.linearalgebra.DecompositionResult] */
    private <T> GeneralizedLeastSquareResults<T> solveImp(List<T> list, List<Double> list2, List<Double> list3, List<Function<T, Double>> list4, int[] iArr, double[] dArr, int[] iArr2) {
        int length = iArr.length;
        int size = list.size();
        int size2 = list4.size();
        double[] dArr2 = new double[size2];
        double[] dArr3 = new double[size];
        double[][] dArr4 = new double[size2][size];
        for (int i = 0; i < size; i++) {
            double doubleValue = list3.get(i).doubleValue();
            ArgChecker.isTrue(doubleValue > 0.0d, "sigma must be great than zero");
            dArr3[i] = (1.0d / doubleValue) / doubleValue;
        }
        for (int i2 = 0; i2 < size2; i2++) {
            for (int i3 = 0; i3 < size; i3++) {
                dArr4[i2][i3] = list4.get(i2).apply(list.get(i3)).doubleValue();
            }
        }
        for (int i4 = 0; i4 < size2; i4++) {
            double d = 0.0d;
            for (int i5 = 0; i5 < size; i5++) {
                d += list2.get(i5).doubleValue() * dArr4[i4][i5] * dArr3[i5];
            }
            dArr2[i4] = d;
        }
        DoubleArray copyOf = DoubleArray.copyOf(dArr2);
        DoubleMatrix aMatrix = getAMatrix(dArr4, dArr3);
        for (int i6 = 0; i6 < length; i6++) {
            if (dArr[i6] > 0.0d) {
                aMatrix = this._algebra.add(aMatrix, this._algebra.scale(getDiffMatrix(iArr, iArr2[i6], i6), dArr[i6]));
            }
        }
        ?? apply = this._decomposition.apply(aMatrix);
        DoubleArray solve = apply.solve(copyOf);
        DoubleMatrix solve2 = apply.solve(DoubleMatrix.identity(size2));
        double d2 = 0.0d;
        for (int i7 = 0; i7 < size; i7++) {
            double d3 = 0.0d;
            for (int i8 = 0; i8 < size2; i8++) {
                d3 += solve.get(i8) * dArr4[i8][i7];
            }
            d2 += MathUtils.pow2(list2.get(i7).doubleValue() - d3) * dArr3[i7];
        }
        return new GeneralizedLeastSquareResults<>(list4, d2, solve, solve2);
    }

    private DoubleMatrix getAMatrix(double[][] dArr, double[] dArr2) {
        int length = dArr.length;
        int length2 = dArr[0].length;
        double[][] dArr3 = new double[length][length];
        for (int i = 0; i < length; i++) {
            double d = 0.0d;
            for (int i2 = 0; i2 < length2; i2++) {
                d += MathUtils.pow2(dArr[i][i2]) * dArr2[i2];
            }
            dArr3[i][i] = d;
            for (int i3 = i + 1; i3 < length; i3++) {
                double d2 = 0.0d;
                for (int i4 = 0; i4 < length2; i4++) {
                    d2 += dArr[i][i4] * dArr[i3][i4] * dArr2[i4];
                }
                dArr3[i][i3] = d2;
                dArr3[i3][i] = d2;
            }
        }
        return DoubleMatrix.copyOf(dArr3);
    }

    private DoubleMatrix getDiffMatrix(int i, int i2) {
        ArgChecker.isTrue(i2 < i, "difference order too high");
        double[][] dArr = new double[i][i];
        if (i == 0) {
            return DoubleMatrix.copyOf(dArr);
        }
        int[] iArr = new int[i2 + 1];
        int i3 = 1;
        for (int i4 = i2; i4 >= 0; i4--) {
            iArr[i4] = (int) (i3 * CombinatoricsUtils.binomialCoefficient(i2, i4));
            i3 *= -1;
        }
        for (int i5 = i2; i5 < i; i5++) {
            for (int i6 = 0; i6 < i2 + 1; i6++) {
                dArr[i5][(i6 + i5) - i2] = iArr[i6];
            }
        }
        Matrix copyOf = DoubleMatrix.copyOf(dArr);
        return this._algebra.multiply(this._algebra.getTranspose(copyOf), copyOf);
    }

    private DoubleMatrix getDiffMatrix(int[] iArr, int i, int i2) {
        int length = iArr.length;
        Matrix diffMatrix = getDiffMatrix(iArr[i2], i);
        int i3 = 1;
        int i4 = 1;
        for (int i5 = i2 + 1; i5 < length; i5++) {
            i3 *= iArr[i5];
        }
        for (int i6 = 0; i6 < i2; i6++) {
            i4 *= iArr[i6];
        }
        Matrix matrix = diffMatrix;
        if (i3 != 1) {
            matrix = (DoubleMatrix) this._algebra.kroneckerProduct(DoubleMatrix.identity(i3), matrix);
        }
        if (i4 != 1) {
            matrix = (DoubleMatrix) this._algebra.kroneckerProduct(matrix, DoubleMatrix.identity(i4));
        }
        return matrix;
    }
}
