/*
 * Decompiled with CFR 0.152.
 */
package org.hipparchus.optim.nonlinear.vector.constrained;

import java.util.ArrayList;
import java.util.List;
import org.hipparchus.exception.Localizable;
import org.hipparchus.exception.LocalizedCoreFormats;
import org.hipparchus.exception.MathIllegalArgumentException;
import org.hipparchus.linear.Array2DRowRealMatrix;
import org.hipparchus.linear.ArrayRealVector;
import org.hipparchus.linear.RealMatrix;
import org.hipparchus.linear.RealVector;
import org.hipparchus.optim.ConvergenceChecker;
import org.hipparchus.optim.LocalizedOptimFormats;
import org.hipparchus.optim.OptimizationData;
import org.hipparchus.optim.nonlinear.scalar.ObjectiveFunction;
import org.hipparchus.optim.nonlinear.vector.constrained.ADMMQPConvergenceChecker;
import org.hipparchus.optim.nonlinear.vector.constrained.ADMMQPKKT;
import org.hipparchus.optim.nonlinear.vector.constrained.ADMMQPModifiedRuizEquilibrium;
import org.hipparchus.optim.nonlinear.vector.constrained.ADMMQPOption;
import org.hipparchus.optim.nonlinear.vector.constrained.ADMMQPSolution;
import org.hipparchus.optim.nonlinear.vector.constrained.LagrangeSolution;
import org.hipparchus.optim.nonlinear.vector.constrained.LinearBoundedConstraint;
import org.hipparchus.optim.nonlinear.vector.constrained.LinearEqualityConstraint;
import org.hipparchus.optim.nonlinear.vector.constrained.LinearInequalityConstraint;
import org.hipparchus.optim.nonlinear.vector.constrained.QPOptimizer;
import org.hipparchus.optim.nonlinear.vector.constrained.QuadraticFunction;
import org.hipparchus.util.FastMath;
import org.hipparchus.util.MathUtils;

/**
 * Quadratic-programming optimizer based on the Alternating Direction Method of Multipliers (ADMM).
 * <p>
 * The objective must be a {@link QuadraticFunction} (supplied through an {@link ObjectiveFunction}).
 * Optional {@link LinearEqualityConstraint}, {@link LinearInequalityConstraint} and
 * {@link LinearBoundedConstraint} data are stacked, in that order, into a single two-sided system
 * {@code lb <= A x <= ub}. This system is solved iteratively by {@link ADMMQPKKT}, optionally
 * with Ruiz scaling, adaptive penalty ({@code rho}) updates and a final polishing step, all
 * controlled by {@link ADMMQPOption}.
 * </p>
 * <p>
 * Instances keep mutable state (the last convergence checker, the convergence flag and the
 * current penalty parameter), so an instance should not be shared between threads.
 * </p>
 */
public class ADMMQPOptimizer
extends QPOptimizer {

    /** Initial value of the ADMM penalty parameter; restored at the start of every run. */
    private static final double INITIAL_RHO = 0.1;

    /** A new rho estimate is accepted only if it differs from the current one by more than this factor. */
    private static final double RHO_UPDATE_RATIO = 5.0;

    /** Algorithm settings (defaults unless an {@link ADMMQPOption} is supplied). */
    private ADMMQPOption settings = new ADMMQPOption();

    /** Equality constraints (may be null). */
    private LinearEqualityConstraint eqConstraint;

    /** Inequality constraints (may be null). */
    private LinearInequalityConstraint iqConstraint;

    /** Bounded constraints (may be null). */
    private LinearBoundedConstraint bqConstraint;

    /** Quadratic objective function. */
    private QuadraticFunction function;

    /** KKT system solver performing the ADMM iterations. */
    private final ADMMQPKKT solver = new ADMMQPKKT();

    /** Convergence checker on the unscaled problem; created by {@link #doOptimize()}. */
    private ADMMQPConvergenceChecker checker;

    /** Whether the last run stopped because the convergence criterion was met. */
    private boolean converged = false;

    /** Current ADMM penalty parameter. */
    private double rho = INITIAL_RHO;

    /**
     * Returns the convergence checker built during the last run.
     *
     * @return the checker, or {@code null} if {@link #doOptimize()} has not run yet
     */
    @Override
    public ConvergenceChecker<LagrangeSolution> getConvergenceChecker() {
        return checker;
    }

    /** {@inheritDoc} */
    @Override
    public LagrangeSolution optimize(OptimizationData... optData) {
        return super.optimize(optData);
    }

    /**
     * Extracts the objective, constraints and settings from the optimization data, then
     * validates the equality constraints against the problem dimension.
     *
     * @param optData optimization data
     * @throws MathIllegalArgumentException if there are at least as many equality constraints
     *         as variables, if the equality constraint has zero rows, or if its column count
     *         does not match the objective dimension
     */
    @Override
    protected void parseOptimizationData(OptimizationData... optData) {
        super.parseOptimizationData(optData);
        for (OptimizationData data : optData) {
            if (data instanceof ObjectiveFunction) {
                function = (QuadraticFunction) ((ObjectiveFunction) data).getObjectiveFunction();
            } else if (data instanceof LinearEqualityConstraint) {
                eqConstraint = (LinearEqualityConstraint) data;
            } else if (data instanceof LinearInequalityConstraint) {
                iqConstraint = (LinearInequalityConstraint) data;
            } else if (data instanceof LinearBoundedConstraint) {
                bqConstraint = (LinearBoundedConstraint) data;
            } else if (data instanceof ADMMQPOption) {
                settings = (ADMMQPOption) data;
            }
        }

        final int n = function.dim();
        if (eqConstraint != null) {
            final int nDual = eqConstraint.dimY();
            if (nDual >= n) {
                throw new MathIllegalArgumentException(LocalizedOptimFormats.CONSTRAINTS_RANK, nDual, n);
            }
            if (nDual == 0) {
                throw new MathIllegalArgumentException(LocalizedCoreFormats.ZERO_NOT_ALLOWED);
            }
            MathUtils.checkDimension(eqConstraint.getA().getColumnDimension(), n);
        }
    }

    /**
     * Runs the ADMM iterations and returns the primal/dual solution.
     * <p>
     * The penalty parameter is adapted on the (possibly scaled) working problem, while
     * convergence is always assessed on the original, unscaled problem.
     * </p>
     *
     * @return solution holding x, the Lagrange multipliers and the objective value at x
     */
    @Override
    public LagrangeSolution doOptimize() {
        // Reset per-run state so that repeated calls on the same instance are independent.
        converged = false;
        rho       = INITIAL_RHO;

        final int n  = function.dim();
        final int me = eqConstraint == null ? 0 : eqConstraint.dimY();
        final int mi = iqConstraint == null ? 0 : iqConstraint.dimY();
        final int mb = bqConstraint == null ? 0 : bqConstraint.dimY();
        final int m  = me + mi + mb;

        final RealMatrix H = function.getP();
        final RealVector q = function.getQ();

        // Stack all constraints as lb <= A x <= ub: equality rows first (the solver is told
        // their count through 'me'), then inequality rows, then bounded rows.
        final RealMatrix A  = new Array2DRowRealMatrix(m, n);
        final RealVector lb = new ArrayRealVector(m);
        final RealVector ub = new ArrayRealVector(m);
        if (me > 0) {
            A.setSubMatrix(eqConstraint.jacobian(null).getData(), 0, 0);
            lb.setSubVector(0, eqConstraint.getLowerBound());
            ub.setSubVector(0, eqConstraint.getUpperBound());
        }
        if (mi > 0) {
            A.setSubMatrix(iqConstraint.jacobian(null).getData(), me, 0);
            lb.setSubVector(me, iqConstraint.getLowerBound());
            ub.setSubVector(me, iqConstraint.getUpperBound());
        }
        if (mb > 0) {
            A.setSubMatrix(bqConstraint.jacobian(null).getData(), me + mi, 0);
            lb.setSubVector(me + mi, bqConstraint.getLowerBound());
            ub.setSubVector(me + mi, bqConstraint.getUpperBound());
        }

        checker = new ADMMQPConvergenceChecker(H, A, q, settings.getEps(), settings.getEps());

        // Working copy of the problem; replaced by its scaled version when scaling is enabled.
        RealMatrix Hw  = H.copy();
        RealMatrix Aw  = A.copy();
        RealVector qw  = q.copy();
        RealVector lbw = lb.copy();
        RealVector ubw = ub.copy();
        final double[] start = getStartPoint();
        RealVector x = start != null ? new ArrayRealVector(start) : new ArrayRealVector(n);

        final ADMMQPModifiedRuizEquilibrium dec = new ADMMQPModifiedRuizEquilibrium(H, A, q);
        if (settings.isScaling()) {
            dec.normalize(settings.getEps(), settings.getScaleMaxIteration());
            Hw  = dec.getScaledH();
            Aw  = dec.getScaledA();
            qw  = dec.getScaledQ();
            lbw = dec.getScaledLUb(lb);
            ubw = dec.getScaledLUb(ub);
            x   = dec.scaleX(x.copy());
        }

        // Residuals on the working problem drive the rho adaptation.
        final ADMMQPConvergenceChecker checkerRho =
                new ADMMQPConvergenceChecker(Hw, Aw, qw, settings.getEps(), settings.getEps());
        RealVector z = Aw.operate(x);
        RealVector y = new ArrayRealVector(m);
        solver.initialize(Hw, Aw, qw, me, lbw, ubw, rho, settings.getSigma(), settings.getAlpha());

        RealVector xstar = null;
        RealVector ystar = null;
        int rhoUpdateCount = 0;
        while (iterations.getCount() <= iterations.getMaximalCount()) {
            final ADMMQPSolution sol = solver.iterate(new RealVector[] { x, y, z });
            x = sol.getX();
            y = sol.getLambda();
            z = sol.getZ();

            // Adapt rho at most settings.getMaxRhoIteration() times per run.
            if (rhoUpdateCount < settings.getMaxRhoIteration()) {
                final double rpWork   = checkerRho.residualPrime(x, z);
                final double rdWork   = checkerRho.residualDual(x, y);
                final double maxPWork = checkerRho.maxPrimal(x, z);
                final double maxDWork = checkerRho.maxDual(x, y);
                if (manageRho(me, rpWork, rdWork, maxPWork, maxDWork)) {
                    ++rhoUpdateCount;
                }
            }

            // Map the iterate back to the original problem before checking convergence.
            final RealVector zstar;
            if (settings.isScaling()) {
                xstar = dec.unscaleX(x);
                ystar = dec.unscaleY(y);
                zstar = dec.unscaleZ(z);
            } else {
                xstar = x.copy();
                ystar = y.copy();
                zstar = z.copy();
            }
            final double rp        = checker.residualPrime(xstar, zstar);
            final double rd        = checker.residualDual(xstar, ystar);
            final double maxPrimal = checker.maxPrimal(xstar, zstar);
            final double maxDual   = checker.maxDual(xstar, ystar);
            if (checker.converged(rp, rd, maxPrimal, maxDual)) {
                converged = true;
                break;
            }
            iterations.increment();
        }

        if (settings.isPolishing()) {
            final ADMMQPSolution finalSol = polish(Hw, Aw, qw, lbw, ubw, x, y, z);
            if (settings.isScaling()) {
                xstar = dec.unscaleX(finalSol.getX());
                ystar = dec.unscaleY(finalSol.getLambda());
            } else {
                xstar = finalSol.getX();
                ystar = finalSol.getLambda();
            }
        }

        // Negate the equality and inequality multipliers (bounded-constraint multipliers are
        // left unchanged). NOTE(review): presumably this matches the LagrangeSolution sign convention.
        for (int i = 0; i < me + mi; ++i) {
            ystar.setEntry(i, -ystar.getEntry(i));
        }
        return new LagrangeSolution(xstar, ystar, function.value(xstar));
    }

    /**
     * Tells whether the last run stopped because the convergence criterion was met.
     *
     * @return true if the last call to {@link #doOptimize()} converged
     */
    public boolean isConverged() {
        return converged;
    }

    /**
     * Refines the primal solution by iterating on the system restricted to the active constraints.
     * <p>
     * A row is considered active on its lower bound when {@code z - lb < -y}, and active on its
     * upper bound when {@code ub - z < y}. If no row is active, the input solution is returned
     * unchanged.
     * </p>
     *
     * @param H quadratic term of the working problem
     * @param A constraint matrix of the working problem
     * @param q linear term of the working problem
     * @param lb lower bounds of the working problem
     * @param ub upper bounds of the working problem
     * @param x current primal iterate
     * @param y current dual iterate
     * @param z current constraint-value iterate
     * @return the polished solution; its multipliers are the input {@code y}
     *         (the active-set multipliers are only used internally by the refinement)
     */
    private ADMMQPSolution polish(RealMatrix H, RealMatrix A, RealVector q, RealVector lb, RealVector ub,
                                  RealVector x, RealVector y, RealVector z) {
        final List<double[]> activeRows   = new ArrayList<>();
        final List<Double>   activeBounds = new ArrayList<>();
        final List<Double>   activeDuals  = new ArrayList<>();

        // Rows active on their lower bound.
        for (int j = 0; j < A.getRowDimension(); ++j) {
            if (z.getEntry(j) - lb.getEntry(j) < -y.getEntry(j)) {
                activeRows.add(A.getRow(j));
                activeBounds.add(lb.getEntry(j));
                activeDuals.add(y.getEntry(j));
            }
        }
        // Rows active on their upper bound.
        for (int j = 0; j < A.getRowDimension(); ++j) {
            if (-z.getEntry(j) + ub.getEntry(j) < y.getEntry(j)) {
                activeRows.add(A.getRow(j));
                activeBounds.add(ub.getEntry(j));
                activeDuals.add(y.getEntry(j));
            }
        }

        if (activeRows.isEmpty()) {
            return new ADMMQPSolution(x, null, y, z);
        }

        final RealMatrix aActive  = new Array2DRowRealMatrix(activeRows.toArray(new double[0][]));
        final RealMatrix aActiveT = aActive.transpose();
        final RealVector lub      = new ArrayRealVector(activeBounds.toArray(new Double[0]));
        final RealVector minusQ   = q.mapMultiply(-1.0);
        RealVector xstar = x.copy();
        RealVector ystar = new ArrayRealVector(activeDuals.toArray(new Double[0]));

        // The active bound vector is used as both the lower and the upper bound, and sigma is also
        // passed in the rho position.
        solver.initialize(H, aActive, q, 0, lub, lub,
                          settings.getSigma(), settings.getSigma(), settings.getAlpha());

        for (int i = 0; i < settings.getPolishIteration(); ++i) {
            final RealVector kktX = H.operate(xstar).add(aActiveT.operate(ystar));
            final RealVector kktY = aActive.operate(xstar);
            final RealVector b1   = minusQ.subtract(kktX);
            final RealVector b2   = lub.subtract(kktY);
            final ADMMQPSolution step = solver.solve(b1, b2);
            xstar = xstar.add(step.getX());
            ystar = ystar.add(step.getV());
        }

        return new ADMMQPSolution(xstar, null, y, A.operate(xstar));
    }

    /**
     * Possibly updates the penalty parameter rho, to balance the primal and dual residuals.
     * <p>
     * The candidate value is {@code rho * sqrt(rp * maxDual / (rd * maxPrimal))}, clamped to
     * [rhoMin, rhoMax]. It is adopted (and forwarded to the KKT solver) only when it differs
     * from the current rho by more than a factor {@value #RHO_UPDATE_RATIO}.
     * </p>
     *
     * @param me number of equality constraints
     * @param rp primal residual
     * @param rd dual residual
     * @param maxPrimal primal normalization term
     * @param maxDual dual normalization term
     * @return true if rho was changed
     */
    private boolean manageRho(int me, double rp, double rd, double maxPrimal, double maxDual) {
        if (!settings.updateRho()) {
            return false;
        }
        final double rhoNew = FastMath.min(FastMath.max(rho * FastMath.sqrt(rp * maxDual / (rd * maxPrimal)),
                                                         settings.getRhoMin()),
                                            settings.getRhoMax());
        if (rhoNew > rho * RHO_UPDATE_RATIO || rhoNew < rho / RHO_UPDATE_RATIO) {
            rho = rhoNew;
            solver.updateSigmaRho(settings.getSigma(), me, rho);
            return true;
        }
        return false;
    }
}

