/*
 * Decompiled with CFR 0.152.
 */
package org.vesalainen.math;

import java.io.Serializable;
import org.ejml.data.D1Matrix64F;
import org.ejml.data.DenseMatrix64F;
import org.ejml.data.RowD1Matrix64F;
import org.ejml.ops.CommonOps;
import org.ejml.ops.SpecializedOps;

/**
 * Levenberg-Marquardt non-linear least-squares optimizer built on EJML dense matrices.
 *
 * <p>Given a model {@link Function} and observations {@code Y} (a column vector with the same
 * number of rows as the input {@code X}), {@link #optimize} searches for the parameter vector that
 * minimizes {@code ||f(param, X) - Y||^2 / numRows(X)}. The Jacobian comes from an optional
 * {@link JacobianFactory}; when none is supplied it is approximated with forward finite
 * differences.
 *
 * <p>An instance reuses internal work matrices across calls, so it should not be shared between
 * threads without external synchronization.
 *
 * <p>NOTE(review): the class is {@link Serializable}, but the {@link Function} and
 * {@link JacobianFactory} fields are not required to be; serialization will fail unless the
 * supplied implementations are serializable.
 */
public class LevenbergMarquardt
implements Serializable {
    private static final long serialVersionUID = 1L;
    /** Forward-difference step used by {@link #computeNumericalJacobian}. */
    private static final double DELTA = 1.0E-8;
    /** Maximum number of outer iterations (each recomputes the gradient and approximate Hessian). */
    private int iter1 = 25;
    /** Number of damping (lambda) trials per outer iteration. */
    private int iter2 = 5;
    /** The outer loop stops once an accepted step lowers the cost by no more than this. */
    private double maxDifference = 1.0E-8;
    /** Damping factor used at the start of every {@link #optimize} call. */
    private double initialLambda = 1.0;
    private Function func;
    /** Optional analytic Jacobian; {@code null} selects the numerical approximation. */
    private JacobianFactory jacobianFactory;
    /** Current parameter estimate (numParam x 1). */
    private DenseMatrix64F param;
    private double initialCost;
    private double finalCost;
    /** Scaled gradient term (numParam x 1), see {@link #computeDandH}. */
    private DenseMatrix64F d;
    /** Gauss-Newton Hessian approximation (numParam x numParam). */
    private DenseMatrix64F H;
    /** Solution of {@code A * negDelta = d}; subtracted from {@code param} to form a candidate. */
    private DenseMatrix64F negDelta;
    /** Candidate parameters evaluated during a damping trial. */
    private DenseMatrix64F tempParam;
    /** Damped system matrix {@code H + lambda * I}. */
    private DenseMatrix64F A;
    /** Work buffers sized to the number of data points by {@link #configure}. */
    private DenseMatrix64F temp0;
    private DenseMatrix64F temp1;
    private DenseMatrix64F tempDH;
    /** Jacobian, numParam x numPoints: row i holds the derivatives with respect to parameter i. */
    private DenseMatrix64F jacobian;

    /**
     * Creates an optimizer that approximates the Jacobian numerically.
     *
     * @param func the model to fit
     */
    public LevenbergMarquardt(Function func) {
        this(func, null);
    }

    /**
     * Creates an optimizer.
     *
     * @param funcCost the model to fit
     * @param jacobianFactory analytic Jacobian provider, or {@code null} to use finite differences
     */
    public LevenbergMarquardt(Function funcCost, JacobianFactory jacobianFactory) {
        // Minimal placeholder sizes; configure() reshapes everything to the actual problem.
        int maxElements = 1;
        int numParam = 1;
        this.temp0 = new DenseMatrix64F(maxElements, 1);
        this.temp1 = new DenseMatrix64F(maxElements, 1);
        this.tempDH = new DenseMatrix64F(maxElements, 1);
        this.jacobian = new DenseMatrix64F(numParam, maxElements);
        this.func = funcCost;
        this.jacobianFactory = jacobianFactory;
        this.param = new DenseMatrix64F(numParam, 1);
        this.d = new DenseMatrix64F(numParam, 1);
        this.H = new DenseMatrix64F(numParam, numParam);
        this.negDelta = new DenseMatrix64F(numParam, 1);
        this.tempParam = new DenseMatrix64F(numParam, 1);
        this.A = new DenseMatrix64F(numParam, numParam);
    }

    /** @return the cost at the initial parameters of the last {@link #optimize} call, or NaN if it had no data */
    public double getInitialCost() {
        return this.initialCost;
    }

    /** @return the cost at the final parameters of the last {@link #optimize} call, or NaN if it failed */
    public double getFinalCost() {
        return this.finalCost;
    }

    /**
     * Returns the optimized parameters.
     *
     * @return the internal parameter matrix (a live reference, overwritten by later calls)
     */
    public DenseMatrix64F getParameters() {
        return this.param;
    }

    /**
     * Fits the model parameters to the observations.
     *
     * @param initParam starting parameter values; its element count defines the number of parameters
     * @param X input data; must have the same number of rows as {@code Y}
     * @param Y observations, a column vector
     * @return {@code true} if the optimization ran; {@code false} if {@code X} has no rows or the
     *     damped linear system could not be solved (then {@link #getFinalCost()} is NaN)
     * @throws IllegalArgumentException if {@code Y} is not a column vector or its row count differs from {@code X}
     */
    public boolean optimize(DenseMatrix64F initParam, DenseMatrix64F X, DenseMatrix64F Y) {
        if (X.numRows == 0) {
            // Nothing to fit; don't leave results from a previous run visible through the getters.
            this.initialCost = Double.NaN;
            this.finalCost = Double.NaN;
            return false;
        }
        this.configure(initParam, X, Y);
        this.initialCost = this.cost(this.param, X, Y);
        if (!this.adjustParam(X, Y, this.initialCost)) {
            this.finalCost = Double.NaN;
            return false;
        }
        return true;
    }

    /**
     * Runs the damped iterations, updating {@link #param} in place and storing {@link #finalCost}.
     *
     * @return {@code false} if the damped system {@code A * negDelta = d} could not be solved
     */
    private boolean adjustParam(DenseMatrix64F X, DenseMatrix64F Y, double prevCost) {
        double lambda = this.initialLambda;
        // Large sentinel so the first outer iteration always runs.
        double difference = 1000.0;
        for (int iter = 0; iter < this.iter1 && difference > this.maxDifference; ++iter) {
            this.computeDandH(this.param, X, Y);
            boolean foundBetter = false;
            // NOTE(review): after an accepted step the remaining trials reuse d and H from the
            // previous parameters instead of re-linearizing; kept as-is to preserve behavior.
            for (int i = 0; i < this.iter2; ++i) {
                this.computeA(this.A, this.H, lambda);
                if (!CommonOps.solve((DenseMatrix64F)this.A, (DenseMatrix64F)this.d, (DenseMatrix64F)this.negDelta)) {
                    return false;
                }
                CommonOps.subtract((D1Matrix64F)this.param, (D1Matrix64F)this.negDelta, (D1Matrix64F)this.tempParam);
                double cost = this.cost(this.tempParam, X, Y);
                if (cost < prevCost) {
                    // Accept the step and move toward Gauss-Newton (less damping).
                    foundBetter = true;
                    this.param.set((D1Matrix64F)this.tempParam);
                    difference = prevCost - cost;
                    prevCost = cost;
                    lambda /= 10.0;
                } else {
                    // Reject the step and move toward gradient descent (more damping).
                    lambda *= 10.0;
                }
            }
            if (!foundBetter) {
                break;
            }
        }
        this.finalCost = prevCost;
        return true;
    }

    /**
     * Validates the inputs and sizes all work matrices for the problem.
     *
     * @throws IllegalArgumentException if {@code Y} is not a column vector or its row count differs from {@code X}
     */
    protected void configure(DenseMatrix64F initParam, DenseMatrix64F X, DenseMatrix64F Y) {
        if (Y.getNumRows() != X.getNumRows()) {
            throw new IllegalArgumentException("Different vector lengths");
        }
        if (Y.getNumCols() != 1) {
            throw new IllegalArgumentException("Inputs must be a column vector");
        }
        int numParam = initParam.getNumElements();
        int numPoints = Y.getNumRows();
        if (this.param.getNumElements() != numParam) {
            this.param.reshape(numParam, 1, false);
            this.d.reshape(numParam, 1, false);
            this.H.reshape(numParam, numParam, false);
            this.negDelta.reshape(numParam, 1, false);
            this.tempParam.reshape(numParam, 1, false);
            this.A.reshape(numParam, numParam, false);
        }
        this.param.set((D1Matrix64F)initParam);
        this.temp0.reshape(numPoints, 1, false);
        this.temp1.reshape(numPoints, 1, false);
        this.tempDH.reshape(numPoints, 1, false);
        this.jacobian.reshape(numParam, numPoints, false);
    }

    /**
     * Computes {@code d = (1/n) * J * r} and {@code H = (1/n) * J * J^T}, where
     * {@code r = f(param, x) - y} is the residual, {@code J} the numParam x n Jacobian, and
     * {@code n} the number of observations.
     */
    private void computeDandH(DenseMatrix64F param, DenseMatrix64F x, DenseMatrix64F y) {
        this.func.compute(param, x, this.tempDH);
        CommonOps.subtractEquals((D1Matrix64F)this.tempDH, (D1Matrix64F)y);
        if (this.jacobianFactory != null) {
            this.jacobianFactory.computeJacobian(param, x, this.jacobian);
        } else {
            this.computeNumericalJacobian(param, x, this.jacobian);
        }
        int numParam = param.getNumElements();
        int length = y.getNumElements();
        for (int i = 0; i < numParam; ++i) {
            double total = 0.0;
            for (int j = 0; j < length; ++j) {
                total += this.tempDH.get(j, 0) * this.jacobian.get(i, j);
            }
            this.d.set(i, 0, total / (double)length);
        }
        CommonOps.multTransB((RowD1Matrix64F)this.jacobian, (RowD1Matrix64F)this.jacobian, (RowD1Matrix64F)this.H);
        CommonOps.scale(1.0 / (double)length, (D1Matrix64F)this.H);
    }

    /** Writes the damped system matrix {@code A = H + lambda * I}. */
    private void computeA(DenseMatrix64F A, DenseMatrix64F H, double lambda) {
        int numParam = this.param.getNumElements();
        A.set((D1Matrix64F)H);
        for (int i = 0; i < numParam; ++i) {
            A.set(i, i, A.get(i, i) + lambda);
        }
    }

    /**
     * Computes the mean squared error {@code ||f(param, X) - Y||^2 / numRows(X)}.
     *
     * <p>Writes the model output into an internal buffer whose size is set by the last
     * {@link #configure} call.
     */
    public double cost(DenseMatrix64F param, DenseMatrix64F X, DenseMatrix64F Y) {
        this.func.compute(param, X, this.temp0);
        double error = SpecializedOps.diffNormF((D1Matrix64F)this.temp0, (D1Matrix64F)Y);
        return error * error / (double)X.numRows;
    }

    /**
     * Approximates the Jacobian with forward differences of step {@link #DELTA}. Row {@code i} of
     * {@code deriv} receives {@code (f(param + DELTA * e_i) - f(param)) / DELTA}.
     *
     * <p>{@code param} is perturbed temporarily, one element at a time. Each element is restored
     * to its exact original value, even if the model throws.
     */
    protected void computeNumericalJacobian(DenseMatrix64F param, DenseMatrix64F pt, DenseMatrix64F deriv) {
        double invDelta = 1.0 / DELTA;
        int numParam = param.getNumElements();
        int numPoints = pt.numRows;
        // Baseline model output.
        this.func.compute(param, pt, this.temp0);
        for (int i = 0; i < numParam; ++i) {
            double original = param.data[i];
            param.data[i] = original + DELTA;
            try {
                this.func.compute(param, pt, this.temp1);
            } finally {
                // Restore exactly: (p + DELTA) - DELTA is not always p in floating point.
                param.data[i] = original;
            }
            CommonOps.add(invDelta, (D1Matrix64F)this.temp1, -invDelta, (D1Matrix64F)this.temp0, (D1Matrix64F)this.temp1);
            System.arraycopy(this.temp1.data, 0, deriv.data, i * numPoints, numPoints);
        }
    }

    /** Sets the maximum number of outer iterations. */
    public void setIter1(int iter1) {
        this.iter1 = iter1;
    }

    /** Sets the number of damping trials per outer iteration. */
    public void setIter2(int iter2) {
        this.iter2 = iter2;
    }

    /** Sets the cost-improvement threshold below which the outer loop stops. */
    public void setMaxDifference(double maxDifference) {
        this.maxDifference = maxDifference;
    }

    /**
     * Sets the damping factor used at the start of each {@link #optimize} call (default 1.0).
     *
     * @param initialLambda a positive, finite damping factor
     * @throws IllegalArgumentException if {@code initialLambda} is not positive and finite
     */
    public void setInitialLambda(double initialLambda) {
        if (!(initialLambda > 0.0) || Double.isInfinite(initialLambda)) {
            throw new IllegalArgumentException("initialLambda must be positive and finite: " + initialLambda);
        }
        this.initialLambda = initialLambda;
    }

    /** Supplies an analytic Jacobian for a {@link Function}. */
    public static interface JacobianFactory {
        /**
         * @param param current parameter values
         * @param x input data
         * @param jacobian output, pre-shaped numParam x numPoints; row i holds derivatives with
         *     respect to parameter i
         */
        public void computeJacobian(DenseMatrix64F param, DenseMatrix64F x, DenseMatrix64F jacobian);
    }

    /** The model being fitted. */
    public static interface Function {
        /**
         * @param param current parameter values
         * @param x input data
         * @param y output, one model value per row of {@code x}; compared element-wise against the
         *     observations
         */
        public void compute(DenseMatrix64F param, DenseMatrix64F x, DenseMatrix64F y);
    }
}

