package breeze.optimize.linear;

import breeze.generic.UFunc;
import breeze.linalg.ImmutableNumericOps;
import breeze.linalg.NumericOps;
import breeze.linalg.norm$;
import breeze.linalg.operators.OpMulMatrix$;
import breeze.math.MutableInnerProductVectorSpace;
import breeze.util.Implicits$;
import breeze.util.LazyLogger;
import breeze.util.SerializableLogging;
import java.io.Serializable;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Product;
import scala.Tuple2;
import scala.collection.Iterator;
import scala.collection.StringOps$;
import scala.package$;
import scala.runtime.BoxesRunTime;
import scala.runtime.LazyVals$;
import scala.runtime.Scala3RunTime$;
import scala.runtime.ScalaRunTime$;
import scala.runtime.Statics;

/* compiled from: ConjugateGradient.scala */
/* loaded from: input_file:breeze/optimize/linear/ConjugateGradient.class */
/**
 * Conjugate-gradient solver for the linear system {@code (A + normSquaredPenalty * I) x = b}.
 *
 * <p>The matrix {@code A} (type {@code M}) is applied through the {@code mult} implementation,
 * and all vector arithmetic on {@code T} goes through a {@link MutableInnerProductVectorSpace}.
 *
 * <p>If a tentative iterate would reach or exceed {@code maxNormValue} in norm, the solver
 * instead moves along the current search direction exactly onto the sphere of radius
 * {@code maxNormValue} and stops. Iteration also stops when
 * {@code norm(residual) <= tolerance}, or when {@code maxIterations > 0} and the iteration
 * counter exceeds it.
 *
 * <p>NOTE(review): standard CG presumes {@code A + normSquaredPenalty * I} is symmetric
 * positive (semi)definite; nothing here checks that.
 *
 * <p>The {@code $}-mangled member names are scalac naming conventions. They are part of the
 * binary interface and must not be renamed.
 */
public class ConjugateGradient<T, M> implements SerializableLogging {
    /** Cached logger slot managed by SerializableLogging; reset to null in the constructor. */
    private volatile transient LazyLogger breeze$util$SerializableLogging$$_the_logger;
    /** Upper bound on norm(x); reaching it triggers the final boundary step. */
    private final double maxNormValue;
    /** Iteration cap; values {@code <= 0} disable the cap. */
    private final int maxIterations;
    /** Coefficient of the identity term added to A. */
    private final double normSquaredPenalty;
    /** Convergence threshold on norm(residual). */
    private final double tolerance;
    /** Vector-space operations for T; public because the inner State class reads it. */
    public final MutableInnerProductVectorSpace<T, Object> breeze$optimize$linear$ConjugateGradient$$space;
    /** Matrix-vector product implementation used to compute A * v. */
    private final UFunc.UImpl2<OpMulMatrix$, M, T, T> mult;
    /** Companion object of State; {@code State().apply(...)} builds new states. */
    public final ConjugateGradient$State$ State$lzy1 = new ConjugateGradient$State$(this);

    /* compiled from: ConjugateGradient.scala */
    /* loaded from: input_file:breeze/optimize/linear/ConjugateGradient$State.class */
    /**
     * One snapshot of the CG iteration. This is the scalac case-class encoding: Product
     * members, structural equals/hashCode, copy and _N accessors.
     *
     * <p>The solver updates the x, residual and direction vectors in place and hands the same
     * objects to the successor state. A State is therefore only a faithful snapshot until the
     * next step runs.
     */
    public class State implements Product, Serializable {
        /**
         * Field offset of the lazy-val bitmap used by {@link #rtr()}.
         *
         * <p>The name must match the Java field below. scalac's original name {@code 0bitmap$1}
         * is not a legal Java identifier and was renamed to {@code f190bitmap$1}. Looking up the
         * old name by reflection would fail during class initialization.
         */
        public static final long OFFSET$0 = LazyVals$.MODULE$.getOffset(State.class, "f190bitmap$1");

        /* renamed from: 0bitmap$1, reason: not valid java name */
        /** Lazy-val state for rtr: 0 = not computed, 3 = cached, other values = being computed. */
        public long f190bitmap$1;
        private final Object x;
        private final Object residual;
        private final Object direction;
        private final int iter;
        private final boolean converged;
        /** Cached value of {@link #rtr()} once the bitmap reaches state 3. */
        public double rtr$lzy1;
        private final /* synthetic */ ConjugateGradient $outer;

        /**
         * @param outer     the owning solver (must not be null)
         * @param x         current iterate
         * @param residual  current residual
         * @param direction current search direction
         * @param iter      number of steps taken so far
         * @param converged whether the iteration should stop after this state
         * @throws NullPointerException if {@code outer} is null
         */
        public State(ConjugateGradient outer, T x, T residual, T direction, int iter, boolean converged) {
            this.x = x;
            this.residual = residual;
            this.direction = direction;
            this.iter = iter;
            this.converged = converged;
            if (outer == null) {
                throw new NullPointerException();
            }
            this.$outer = outer;
        }

        public /* bridge */ /* synthetic */ Iterator productIterator() {
            return Product.productIterator$(this);
        }

        public /* bridge */ /* synthetic */ Iterator productElementNames() {
            return Product.productElementNames$(this);
        }

        /** Case-class hash: scalac's 0xCAFEBABE seed mixed with the prefix and every field. */
        @Override
        public int hashCode() {
            int acc = -889275714;
            acc = Statics.mix(acc, productPrefix().hashCode());
            acc = Statics.mix(acc, Statics.anyHash(x()));
            acc = Statics.mix(acc, Statics.anyHash(residual()));
            acc = Statics.mix(acc, Statics.anyHash(direction()));
            acc = Statics.mix(acc, iter());
            acc = Statics.mix(acc, converged() ? 1231 : 1237);
            return Statics.finalizeHash(acc, 5);
        }

        /** Structural equality over all five fields; only states of the same solver compare equal. */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof State)
                    || ((State) obj).breeze$optimize$linear$ConjugateGradient$State$$$outer() != this.$outer) {
                return false;
            }
            State that = (State) obj;
            return iter() == that.iter()
                    && converged() == that.converged()
                    && BoxesRunTime.equals(x(), that.x())
                    && BoxesRunTime.equals(residual(), that.residual())
                    && BoxesRunTime.equals(direction(), that.direction())
                    && that.canEqual(this);
        }

        @Override
        public String toString() {
            return ScalaRunTime$.MODULE$._toString(this);
        }

        public boolean canEqual(Object obj) {
            return obj instanceof State;
        }

        public int productArity() {
            return 5;
        }

        public String productPrefix() {
            return "State";
        }

        /** Returns field {@code i} (0-based, declaration order), boxing primitives. */
        public Object productElement(int i) {
            switch (i) {
                case 0:
                    return _1();
                case 1:
                    return _2();
                case 2:
                    return _3();
                case 3:
                    return BoxesRunTime.boxToInteger(_4());
                case 4:
                    return BoxesRunTime.boxToBoolean(_5());
                default:
                    throw new IndexOutOfBoundsException(BoxesRunTime.boxToInteger(i).toString());
            }
        }

        /** Returns the name of field {@code i} (0-based, declaration order). */
        public String productElementName(int i) {
            switch (i) {
                case 0:
                    return "x";
                case 1:
                    return "residual";
                case 2:
                    return "direction";
                case 3:
                    return "iter";
                case 4:
                    return "converged";
                default:
                    throw new IndexOutOfBoundsException(BoxesRunTime.boxToInteger(i).toString());
            }
        }

        public T x() {
            return (T) this.x;
        }

        public T residual() {
            return (T) this.residual;
        }

        public T direction() {
            return (T) this.direction;
        }

        public int iter() {
            return this.iter;
        }

        public boolean converged() {
            return this.converged;
        }

        /**
         * Lazily computes and caches {@code residual . residual}.
         *
         * <p>This implements the Scala 3 lazy-val protocol on {@code f190bitmap$1}. The first
         * caller CASes the state from 0 to 1, computes the value, and publishes state 3. Other
         * callers wait for notification while another thread is computing. If the computation
         * throws, the state is reset to 0 so a later call retries.
         */
        public double rtr() {
            while (true) {
                long bits = LazyVals$.MODULE$.get(this, OFFSET$0);
                long phase = LazyVals$.MODULE$.STATE(bits, 0);
                if (phase == 3) {
                    return this.rtr$lzy1;
                }
                if (phase != 0) {
                    LazyVals$.MODULE$.wait4Notification(this, OFFSET$0, bits, 0);
                } else if (LazyVals$.MODULE$.CAS(this, OFFSET$0, bits, 1, 0)) {
                    try {
                        double value = BoxesRunTime.unboxToDouble(((ImmutableNumericOps) this.$outer.breeze$optimize$linear$ConjugateGradient$$space.hasOps().apply(residual())).dot(residual(), this.$outer.breeze$optimize$linear$ConjugateGradient$$space.dotVV()));
                        this.rtr$lzy1 = value;
                        LazyVals$.MODULE$.setFlag(this, OFFSET$0, 3, 0);
                        return value;
                    } catch (Throwable th) {
                        LazyVals$.MODULE$.setFlag(this, OFFSET$0, 0, 0);
                        throw th;
                    }
                }
            }
        }

        public ConjugateGradient<T, M>.State copy(T x, T residual, T direction, int iter, boolean converged) {
            return new State(this.$outer, x, residual, direction, iter, converged);
        }

        public T copy$default$1() {
            return x();
        }

        public T copy$default$2() {
            return residual();
        }

        public T copy$default$3() {
            return direction();
        }

        public int copy$default$4() {
            return iter();
        }

        public boolean copy$default$5() {
            return converged();
        }

        public T _1() {
            return x();
        }

        public T _2() {
            return residual();
        }

        public T _3() {
            return direction();
        }

        public int _4() {
            return iter();
        }

        public boolean _5() {
            return converged();
        }

        public final /* synthetic */ ConjugateGradient breeze$optimize$linear$ConjugateGradient$State$$$outer() {
            return this.$outer;
        }
    }

    /**
     * @param maxNormValue       upper bound on norm(x); hitting it ends the iteration on the sphere
     * @param maxIterations      iteration cap; {@code <= 0} means no cap
     * @param normSquaredPenalty coefficient of the identity term added to A
     * @param tolerance          stop once {@code norm(residual) <= tolerance}
     * @param space              vector-space operations for T
     * @param mult               matrix-vector product implementation computing A * v
     */
    public ConjugateGradient(double maxNormValue, int maxIterations, double normSquaredPenalty, double tolerance, MutableInnerProductVectorSpace<T, Object> space, UFunc.UImpl2<OpMulMatrix$, M, T, T> mult) {
        this.maxNormValue = maxNormValue;
        this.maxIterations = maxIterations;
        this.normSquaredPenalty = normSquaredPenalty;
        this.tolerance = tolerance;
        this.breeze$optimize$linear$ConjugateGradient$$space = space;
        this.mult = mult;
        // SerializableLogging trait initialization: no logger cached yet.
        breeze$util$SerializableLogging$$_the_logger_$eq(null);
    }

    @Override // breeze.util.SerializableLogging
    public LazyLogger breeze$util$SerializableLogging$$_the_logger() {
        return this.breeze$util$SerializableLogging$$_the_logger;
    }

    @Override // breeze.util.SerializableLogging
    public void breeze$util$SerializableLogging$$_the_logger_$eq(LazyLogger lazyLogger) {
        this.breeze$util$SerializableLogging$$_the_logger = lazyLogger;
    }

    /**
     * Mixin forwarder to the trait's {@code logger} implementation.
     *
     * <p>This uses scalac's static {@code method$} forwarder, the same convention as
     * {@code Product.productIterator$} in State. The decompiled form called {@code logger()} on
     * itself and would recurse until StackOverflowError on the first log call.
     */
    @Override // breeze.util.SerializableLogging
    public /* bridge */ /* synthetic */ LazyLogger logger() {
        return SerializableLogging.logger$(this);
    }

    /**
     * Runs the solver from the zero vector and returns the final iterate.
     *
     * @param rhs    right-hand side b
     * @param matrix the matrix A
     * @return the final iterate x (a fresh vector, since the start is a new zero vector)
     */
    public T minimize(T rhs, M matrix) {
        return minimize(rhs, matrix, this.breeze$optimize$linear$ConjugateGradient$$space.zeroLike().apply(rhs));
    }

    /**
     * Runs the solver from {@code initX} and returns the final iterate.
     *
     * <p>Note: {@code initX} is updated in place and is itself the returned object.
     *
     * @param rhs    right-hand side b
     * @param matrix the matrix A
     * @param initX  starting point; mutated by the solver
     * @return the final iterate x
     */
    public T minimize(T rhs, M matrix, T initX) {
        return (T) minimizeAndReturnResidual(rhs, matrix, initX)._1();
    }

    /* JADX WARN: Incorrect inner types in method signature: ()Lbreeze/optimize/linear/ConjugateGradient<TT;TM;>.State$; */
    /** Returns the State companion object. */
    public final ConjugateGradient$State$ State() {
        return this.State$lzy1;
    }

    /**
     * Runs the iteration to completion.
     *
     * @param rhs    right-hand side b
     * @param matrix the matrix A
     * @param initX  starting point; mutated in place and returned as the first tuple element
     * @return the pair (final x, final residual)
     */
    public Tuple2<T, T> minimizeAndReturnResidual(T rhs, M matrix, T initX) {
        State last = (State) Implicits$.MODULE$.scEnrichIterator(iterations(rhs, matrix, initX)).last();
        return new Tuple2<>(last.x(), last.residual());
    }

    /**
     * Lazily yields the sequence of solver states, starting with the initial state.
     *
     * <p>The sequence presumably ends after the first state whose {@code converged()} is true
     * (per {@code takeUpToWhere}; TODO confirm inclusivity). All yielded states share the same
     * x, residual and direction objects, which each step mutates in place. Copy them if an
     * intermediate iterate must be kept.
     *
     * @param rhs    right-hand side b
     * @param matrix the matrix A
     * @param initX  starting point; mutated in place
     * @return iterator over solver states
     */
    public Iterator<ConjugateGradient<T, M>.State> iterations(T rhs, M matrix, T initX) {
        return Implicits$.MODULE$.scEnrichIterator(package$.MODULE$.Iterator().iterate(initialState(rhs, matrix, initX), state -> {
            return step(state, matrix);
        })).takeUpToWhere(state2 -> {
            return state2.converged();
        });
    }

    /**
     * Performs one CG step from {@code state}, mutating its vectors in place.
     *
     * <p>Let {@code B = A + normSquaredPenalty * I}. The step length is
     * {@code alpha = |r|^2 / (d . B d)}. If {@code |x + alpha d| >= maxNormValue}, x is moved
     * along d onto the sphere of radius maxNormValue and the returned state is marked
     * converged. Otherwise x, r and d receive the standard CG updates.
     *
     * @param state  current state (its vectors are updated in place)
     * @param matrix the matrix A
     * @return the successor state, sharing the same vector objects
     */
    private ConjugateGradient<T, M>.State step(ConjugateGradient<T, M>.State state, M matrix) {
        MutableInnerProductVectorSpace<T, Object> space = this.breeze$optimize$linear$ConjugateGradient$$space;
        T x = state.x();
        T residual = state.residual();
        T direction = state.direction();
        // Read the old r.r before the residual is mutated below; State.rtr() caches it.
        double rtr = state.rtr();
        Object ad = this.mult.mo263apply(matrix, direction);
        double dtd = BoxesRunTime.unboxToDouble(((ImmutableNumericOps) space.hasOps().apply(direction)).dot(direction, space.dotVV()));
        double residualNormSq = scala.math.package$.MODULE$.pow(BoxesRunTime.unboxToDouble(norm$.MODULE$.apply(residual, space.normImpl())), 2.0d);
        double dBd = BoxesRunTime.unboxToDouble(((ImmutableNumericOps) space.hasOps().apply(direction)).dot(ad, space.dotVV())) + (this.normSquaredPenalty * dtd);
        double alpha = residualNormSq / dBd;
        Object nextX = ((NumericOps) space.hasOps().apply(x)).$plus(((ImmutableNumericOps) space.hasOps().apply(direction)).$times(BoxesRunTime.boxToDouble(alpha), space.mulVS_M()), space.addVV());
        double normNextX = BoxesRunTime.unboxToDouble(norm$.MODULE$.apply(nextX, space.normImpl()));
        // B d = A d + penalty * d, used by the residual update in both branches.
        Object penalizedAd = ((NumericOps) space.hasOps().apply(ad)).$plus(((ImmutableNumericOps) space.hasOps().apply(direction)).$times$colon$times(BoxesRunTime.boxToDouble(this.normSquaredPenalty), space.mulVS()), space.addVV());

        if (normNextX >= this.maxNormValue) {
            logger().info(() -> boundaryReachedMessage(state, normNextX));
            // Solve |x + tau d|^2 = maxNormValue^2 for the non-negative root:
            //   tau^2 dtd + 2 tau xtd + (xtx - radiusSq) = 0.
            // When xtd >= 0 the rationalized form avoids cancellation between -xtd and disc.
            // Assumes the current x lies inside the sphere (xtx <= radiusSq) — TODO confirm.
            double xtd = BoxesRunTime.unboxToDouble(((ImmutableNumericOps) space.hasOps().apply(x)).dot(direction, space.dotVV()));
            double xtx = BoxesRunTime.unboxToDouble(((ImmutableNumericOps) space.hasOps().apply(x)).dot(x, space.dotVV()));
            double radiusSq = this.maxNormValue * this.maxNormValue;
            double disc = scala.math.package$.MODULE$.sqrt((xtd * xtd) + (dtd * (radiusSq - xtx)));
            double tau = xtd >= 0.0 ? (radiusSq - xtx) / (xtd + disc) : (disc - xtd) / dtd;
            if (Double.isNaN(tau)) {
                throw Scala3RunTime$.MODULE$.assertFailed(new StringBuilder(6).append(xtd).append(" ").append(radiusSq).append(" ").append(xtx).append("  ").append(xtd).append(" ").append(disc).append(" ").append(dtd).toString());
            }
            breeze.linalg.package$.MODULE$.axpy(BoxesRunTime.boxToDouble(tau), direction, x, space.scaleAddVV());
            breeze.linalg.package$.MODULE$.axpy(BoxesRunTime.boxToDouble(-tau), penalizedAd, residual, space.scaleAddVV());
            return State().apply(x, residual, direction, state.iter() + 1, true);
        }

        // Interior step: x <- x + alpha d; r <- r - alpha B d; d <- r + (r.r / old r.r) d.
        ((NumericOps) space.hasOps().apply(x)).$colon$eq(nextX, space.setIntoVV());
        ((NumericOps) space.hasOps().apply(residual)).$minus$eq(((ImmutableNumericOps) space.hasOps().apply(penalizedAd)).$times$colon$times(BoxesRunTime.boxToDouble(alpha), space.mulVS()), space.subIntoVV());
        double newRtr = BoxesRunTime.unboxToDouble(((ImmutableNumericOps) space.hasOps().apply(residual)).dot(residual, space.dotVV()));
        ((NumericOps) space.hasOps().apply(direction)).$colon$times$eq(BoxesRunTime.boxToDouble(newRtr / rtr), space.mulIntoVS());
        ((NumericOps) space.hasOps().apply(direction)).$plus$eq(residual, space.addIntoVV());

        double normResidual = BoxesRunTime.unboxToDouble(norm$.MODULE$.apply(residual, space.normImpl()));
        boolean hitIterationCap = state.iter() > this.maxIterations && this.maxIterations > 0;
        boolean converged = normResidual <= this.tolerance || hitIterationCap;
        if (converged) {
            if (hitIterationCap) {
                logger().info(() -> maxIterationsMessage(state, normResidual));
            } else {
                logger().info(() -> convergedMessage(state, normResidual));
            }
        } else {
            logger().info(() -> progressMessage(state, normResidual));
        }
        return State().apply(x, residual, direction, state.iter() + 1, converged);
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Builds the starting state.
     *
     * <p>The residual is {@code b - A x0 - normSquaredPenalty * x0}, and the first direction is a
     * copy of the residual. {@code initX} is used directly as the iterate, without copying, so
     * later steps mutate it.
     */
    private ConjugateGradient<T, M>.State initialState(T rhs, M matrix, T initX) {
        MutableInnerProductVectorSpace<T, Object> space = this.breeze$optimize$linear$ConjugateGradient$$space;
        Object bMinusAx = ((ImmutableNumericOps) space.hasOps().apply(rhs)).$minus(this.mult.mo263apply(matrix, initX), space.subVV());
        Object residual = ((ImmutableNumericOps) space.hasOps().apply(bMinusAx)).$minus(((ImmutableNumericOps) space.hasOps().apply(initX)).$times$colon$times(BoxesRunTime.boxToDouble(this.normSquaredPenalty), space.mulVS()), space.subVV());
        Object direction = space.copy().apply(residual);
        boolean converged = BoxesRunTime.unboxToDouble(norm$.MODULE$.apply(residual, space.normImpl())) <= this.tolerance;
        return State().apply(initX, residual, direction, 0, converged);
    }

    /** Log message for a step that hit the norm boundary; {@code normX} is norm(x + alpha d). */
    private String boundaryReachedMessage(State state, double normX) {
        return StringOps$.MODULE$.format$extension("%s boundary reached! norm(x): %.3f >= maxNormValue %s", ScalaRunTime$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToInteger(state.iter()), BoxesRunTime.boxToDouble(normX), BoxesRunTime.boxToDouble(this.maxNormValue)}));
    }

    /** Log message for stopping because the iteration cap was exceeded. */
    private String maxIterationsMessage(State state, double normResidual) {
        return StringOps$.MODULE$.format$extension("max iteration %s reached! norm(residual): %.3f > tolerance %s.", ScalaRunTime$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToInteger(state.iter()), BoxesRunTime.boxToDouble(normResidual), BoxesRunTime.boxToDouble(this.tolerance)}));
    }

    /** Log message for stopping because the residual is within tolerance. */
    private String convergedMessage(State state, double normResidual) {
        return StringOps$.MODULE$.format$extension("%s converged! norm(residual): %.3f <= tolerance %s.", ScalaRunTime$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToInteger(state.iter()), BoxesRunTime.boxToDouble(normResidual), BoxesRunTime.boxToDouble(this.tolerance)}));
    }

    /** Per-step progress log message while not yet converged. */
    private String progressMessage(State state, double normResidual) {
        return StringOps$.MODULE$.format$extension("%s: norm(residual): %.3f > tolerance %s.", ScalaRunTime$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToInteger(state.iter()), BoxesRunTime.boxToDouble(normResidual), BoxesRunTime.boxToDouble(this.tolerance)}));
    }
}
