package breeze.optimize;

import breeze.linalg.ImmutableNumericOps;
import breeze.linalg.NumericOps;
import breeze.linalg.norm$;
import breeze.linalg.operators.HasOps$;
import breeze.math.MutableFiniteCoordinateField;
import breeze.numerics.package$signum$;
import breeze.numerics.package$signum$signumDoubleImpl$;
import breeze.numerics.package$sqrt$;
import breeze.numerics.package$sqrt$sqrtDoubleImpl$;
import breeze.optimize.FirstOrderMinimizer;
import breeze.stats.distributions.RandBasis;
import java.io.Serializable;
import scala.MatchError;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Product;
import scala.Tuple2;
import scala.Tuple2$;
import scala.collection.Iterator;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichDouble$;
import scala.runtime.ScalaRunTime$;

/* compiled from: AdaptiveGradientDescent.scala */
/* loaded from: input_file:breeze/optimize/AdaptiveGradientDescent.class */
/**
 * Java rendering (decompiled from {@code AdaptiveGradientDescent.scala}) of breeze's adaptive
 * stochastic gradient descent minimizers.
 *
 * <p>Both nested minimizers keep an element-wise history of squared gradients. They divide the
 * step element-wise by the square root of that history, so coordinates that have accumulated large
 * gradients take smaller steps. {@link L1Regularization} adds an L1 penalty and soft-thresholds
 * the iterate. {@link L2Regularization} adds a squared-L2 penalty that also enters the step
 * denominator.
 */
public final class AdaptiveGradientDescent {

    /**
     * Adaptive SGD with an L1 penalty {@code lambda * ||x||_1}.
     *
     * <p>One step, with history {@code G} and current gradient {@code g} (all operations
     * element-wise):
     * <pre>
     *   s      = sqrt(G + g*g + delta)
     *   x_half = x + (stepSize * dir) / s
     *   x_new  = 0                                            if |x_half| &lt; lambda*stepSize / s
     *            x_half - signum(x_half) * lambda*stepSize / s    otherwise
     * </pre>
     *
     * @param <T> the parameter vector type
     */
    public static class L1Regularization<T> extends StochasticGradientDescent<T> {
        /** Weight of the L1 penalty. */
        private final double lambda;
        /** Constant added under the square root of the step scale (keeps it non-zero if delta &gt; 0). */
        private final double delta;
        /** Vector-space operations for {@code T}; the same instance is handed to the superclass. */
        private final MutableFiniteCoordinateField<T, ?, Object> space;
        /** Companion object of {@link History}, created eagerly in the constructor. */
        public final AdaptiveGradientDescent$L1Regularization$History$ History$lzy2;

        /**
         * Immutable one-field value holding the squared-gradient history. This is the Java shape
         * of a Scala case class: {@code Product} with structural equality on the field and the
         * enclosing minimizer.
         */
        /* compiled from: AdaptiveGradientDescent.scala */
        public class History implements Product, Serializable {
            private final Object sumOfSquaredGradients;
            private final /* synthetic */ L1Regularization $outer;

            /**
             * @param l1Regularization the owning minimizer; must not be null
             * @param t the squared-gradient history vector
             * @throws NullPointerException if {@code l1Regularization} is null
             */
            public History(L1Regularization l1Regularization, T t) {
                this.sumOfSquaredGradients = t;
                if (l1Regularization == null) {
                    throw new NullPointerException();
                }
                this.$outer = l1Regularization;
            }

            public /* bridge */ /* synthetic */ Iterator productIterator() {
                return Product.productIterator$(this);
            }

            public /* bridge */ /* synthetic */ Iterator productElementNames() {
                return Product.productElementNames$(this);
            }

            public int hashCode() {
                return ScalaRunTime$.MODULE$._hashCode(this);
            }

            /**
             * Equal when {@code obj} is a History of the same minimizer instance with an equal
             * {@code sumOfSquaredGradients}, and {@code obj} agrees via {@link #canEqual}.
             */
            public boolean equals(Object obj) {
                if (this == obj) {
                    return true;
                }
                // Raw qualified type: the simple name denotes L1Regularization<T>.History, which is
                // not reifiable and therefore cannot be used with instanceof.
                if (!(obj instanceof L1Regularization.History)) {
                    return false;
                }
                L1Regularization.History other = (L1Regularization.History) obj;
                return other.breeze$optimize$AdaptiveGradientDescent$L1Regularization$History$$$outer() == this.$outer
                        && BoxesRunTime.equals(sumOfSquaredGradients(), other.sumOfSquaredGradients())
                        && other.canEqual(this);
            }

            public String toString() {
                return ScalaRunTime$.MODULE$._toString(this);
            }

            public boolean canEqual(Object obj) {
                return obj instanceof L1Regularization.History;
            }

            public int productArity() {
                return 1;
            }

            public String productPrefix() {
                return "History";
            }

            public Object productElement(int i) {
                if (0 == i) {
                    return _1();
                }
                throw new IndexOutOfBoundsException(BoxesRunTime.boxToInteger(i).toString());
            }

            public String productElementName(int i) {
                if (0 == i) {
                    return "sumOfSquaredGradients";
                }
                throw new IndexOutOfBoundsException(BoxesRunTime.boxToInteger(i).toString());
            }

            /** Returns the stored squared-gradient history vector. */
            @SuppressWarnings("unchecked")
            public T sumOfSquaredGradients() {
                return (T) this.sumOfSquaredGradients;
            }

            /** Returns a new History of the same minimizer holding {@code t}. */
            public L1Regularization<T>.History copy(T t) {
                return new History(this.$outer, t);
            }

            public T copy$default$1() {
                return sumOfSquaredGradients();
            }

            public T _1() {
                return sumOfSquaredGradients();
            }

            public final /* synthetic */ L1Regularization breeze$optimize$AdaptiveGradientDescent$L1Regularization$History$$$outer() {
                return this.$outer;
            }
        }

        /**
         * Creates the minimizer.
         *
         * <p>NOTE(review): the decompiler moved the {@code super} call to the top of this
         * constructor. In the compiled original the fields may have been assigned before it. If
         * the superclass constructor calls an overridden method that reads {@code space}, it sees
         * {@code null} here. Confirm that it does not.
         *
         * @param lambda weight of the L1 penalty
         * @param delta constant added under the square root of the step scale
         * @param eta forwarded as the superclass's first constructor argument (presumably the
         *     default step size)
         * @param maxIter forwarded as the superclass's second constructor argument (presumably the
         *     iteration limit)
         * @param space vector-space operations for {@code T}
         * @param randBasis not used by this constructor
         */
        public L1Regularization(double lambda, double delta, double eta, int maxIter, MutableFiniteCoordinateField<T, ?, Object> space, RandBasis randBasis) {
            super(eta, maxIter, StochasticGradientDescent$.MODULE$.$lessinit$greater$default$3(), StochasticGradientDescent$.MODULE$.$lessinit$greater$default$4(), space);
            this.lambda = lambda;
            this.delta = delta;
            this.space = space;
            this.History$lzy2 = new AdaptiveGradientDescent$L1Regularization$History$(this);
        }

        /** Returns the weight of the L1 penalty. */
        public double lambda() {
            return this.lambda;
        }

        /**
         * Returns the companion object of {@link History}. The decompiler could not express its
         * precise inner generic type in the signature.
         */
        public final AdaptiveGradientDescent$L1Regularization$History$ History() {
            return this.History$lzy2;
        }

        /** Starts from an all-zero squared-gradient history shaped like {@code init}. */
        @Override // breeze.optimize.FirstOrderMinimizer
        public L1Regularization<T>.History initialHistory(StochasticDiffFunction<T> f, T init) {
            return History().apply(this.space.zeroLike().apply(init));
        }

        /**
         * Returns a new history that folds in the squared gradient of {@code state}.
         *
         * <p>Only {@code state} is read; the other arguments are ignored. While
         * {@code state.iter() <= 200}, the history is a running sum {@code G + g*g}. After that
         * it is an exponential moving average {@code g*g / 200 + G * 199 / 200}. The previous
         * history vector is only read, never mutated.
         */
        @Override // breeze.optimize.FirstOrderMinimizer
        public L1Regularization<T>.History updateHistory(T newX, T newGrad, double newValue, StochasticDiffFunction<T> f, FirstOrderMinimizer.State<T, Object, L1Regularization<T>.History> state) {
            final double maxAge = 200.0d;
            L1Regularization<T>.History oldHistory = state.history();
            // g*g from the immutable op; the in-place updates below touch only this new value.
            @SuppressWarnings("unchecked")
            T newG = (T) ((ImmutableNumericOps) this.space.hasOps().apply(state.grad()))
                    .$times$colon$times(state.grad(), this.space.mulVV());
            if (state.iter() > maxAge) {
                // newG = g*g / maxAge + old * (maxAge - 1) / maxAge   (axpy: y += a * x)
                ((NumericOps) this.space.hasOps().apply(newG)).$times$eq(BoxesRunTime.boxToDouble(1 / maxAge), this.space.mulIntoVS());
                breeze.linalg.package$.MODULE$.axpy(BoxesRunTime.boxToDouble((maxAge - 1) / maxAge), oldHistory.sumOfSquaredGradients(), newG, this.space.scaleAddVV());
            } else {
                // newG = g*g + old
                ((NumericOps) this.space.hasOps().apply(newG)).$plus$eq(oldHistory.sumOfSquaredGradients(), this.space.addIntoVV());
            }
            return new History(this, newG);
        }

        /**
         * Takes an adaptively scaled step along {@code dir}, then soft-thresholds each coordinate.
         *
         * @param state current state; its point, gradient and history are read
         * @param dir step direction
         * @param stepSize step length
         * @return the new point (freshly allocated)
         */
        @Override // breeze.optimize.StochasticGradientDescent, breeze.optimize.FirstOrderMinimizer
        @SuppressWarnings("unchecked")
        public T takeStep(FirstOrderMinimizer.State<T, Object, L1Regularization<T>.History> state, T dir, double stepSize) {
            // s = sqrt(G + g*g + delta), element-wise
            Object squaredGrad = ((ImmutableNumericOps) this.space.hasOps().apply(state.grad())).$times$colon$times(state.grad(), this.space.mulVV());
            Object accumulated = ((ImmutableNumericOps) this.space.hasOps().apply(state.history().sumOfSquaredGradients())).$plus$colon$plus(squaredGrad, this.space.addVV());
            Object damped = ((ImmutableNumericOps) this.space.hasOps().apply(accumulated)).$plus$colon$plus(BoxesRunTime.boxToDouble(this.delta), this.space.addVS());
            Object scale = package$sqrt$.MODULE$.apply(damped, HasOps$.MODULE$.fromLowOrderCanMapActiveValues(this.space.scalarOf(), package$sqrt$sqrtDoubleImpl$.MODULE$, this.space.mapValues()));

            // x_half = x + (stepSize * dir) / s
            Object scaledDir = ((ImmutableNumericOps) this.space.hasOps().apply(dir)).$times$colon$times(BoxesRunTime.boxToDouble(stepSize), this.space.mulVS());
            Object increment = ((ImmutableNumericOps) this.space.hasOps().apply(scaledDir)).$div$colon$div(scale, this.space.divVV());
            Object halfStep = ((NumericOps) this.space.hasOps().apply(state.x())).$plus(increment, this.space.addVV());

            final double threshold = lambda() * stepSize;
            // Soft-threshold each coordinate x_i against threshold / s_i. Parameters are unboxed
            // via BoxesRunTime so the lambda works whatever the Function2 parameter type is.
            return (T) this.space.zipMapValues().map(halfStep, scale, (xBoxed, sBoxed) -> {
                double xi = BoxesRunTime.unboxToDouble(xBoxed);
                double si = BoxesRunTime.unboxToDouble(sBoxed);
                if (Math.abs(xi) < threshold / si) {
                    return 0.0d;
                }
                return xi - ((Math.signum(xi) * threshold) / si);
            });
        }

        /** Always returns the fixed {@code defaultStepSize()}; the arguments are ignored. */
        @Override // breeze.optimize.StochasticGradientDescent, breeze.optimize.FirstOrderMinimizer
        public double determineStepSize(FirstOrderMinimizer.State<T, Object, L1Regularization<T>.History> state, StochasticDiffFunction<T> f, T dir) {
            return defaultStepSize();
        }

        /**
         * Adds the L1 penalty to a value/gradient pair.
         *
         * @return {@code (newVal + lambda * ||newX||_1, newGrad + lambda * signum(newX))}
         */
        @Override // breeze.optimize.FirstOrderMinimizer
        public Tuple2<Object, T> adjust(T newX, T newGrad, double newVal) {
            double l1Norm = BoxesRunTime.unboxToDouble(norm$.MODULE$.apply(newX, BoxesRunTime.boxToDouble(1.0d), this.space.normImpl2()));
            double adjustedValue = newVal + (l1Norm * lambda());
            Object signs = package$signum$.MODULE$.apply(newX, HasOps$.MODULE$.fromLowOrderCanMapActiveValues(this.space.scalarOf(), package$signum$signumDoubleImpl$.MODULE$, this.space.mapValues()));
            Object penaltyGrad = ((ImmutableNumericOps) this.space.hasOps().apply(signs)).$times$colon$times(BoxesRunTime.boxToDouble(lambda()), this.space.mulVS());
            @SuppressWarnings("unchecked")
            T adjustedGrad = (T) ((NumericOps) this.space.hasOps().apply(newGrad)).$plus(penaltyGrad, this.space.addVV());
            return new Tuple2<>(BoxesRunTime.boxToDouble(adjustedValue), adjustedGrad);
        }

        // No explicit erasure bridges for initialHistory/updateHistory: they would have the same
        // erasure as the typed overrides above (a compile error), and javac emits them itself.
    }

    /**
     * Adaptive SGD with a squared-L2 penalty {@code regularizationConstant / 2 * (x . x)}.
     *
     * <p>One step, with history {@code G} and current gradient {@code g} (all operations
     * element-wise):
     * <pre>
     *   s     = sqrt(G + g*g)
     *   x_new = (x * s + stepSize * dir) / (s + delta + regularizationConstant * stepSize)
     * </pre>
     *
     * @param <T> the parameter vector type
     */
    public static class L2Regularization<T> extends StochasticGradientDescent<T> {
        /** Weight of the squared-L2 penalty. */
        private final double regularizationConstant;
        /** Fixed constant (1e-4) added to the step denominator. */
        private final double delta;
        /** Companion object of {@link History}, created eagerly in the constructor. */
        public final AdaptiveGradientDescent$L2Regularization$History$ History$lzy1;

        /**
         * Immutable one-field value holding the squared-gradient history. This is the Java shape
         * of a Scala case class: {@code Product} with structural equality on the field and the
         * enclosing minimizer.
         */
        /* compiled from: AdaptiveGradientDescent.scala */
        public class History implements Product, Serializable {
            private final Object sumOfSquaredGradients;
            private final /* synthetic */ L2Regularization $outer;

            /**
             * @param l2Regularization the owning minimizer; must not be null
             * @param t the squared-gradient history vector
             * @throws NullPointerException if {@code l2Regularization} is null
             */
            public History(L2Regularization l2Regularization, T t) {
                this.sumOfSquaredGradients = t;
                if (l2Regularization == null) {
                    throw new NullPointerException();
                }
                this.$outer = l2Regularization;
            }

            public /* bridge */ /* synthetic */ Iterator productIterator() {
                return Product.productIterator$(this);
            }

            public /* bridge */ /* synthetic */ Iterator productElementNames() {
                return Product.productElementNames$(this);
            }

            public int hashCode() {
                return ScalaRunTime$.MODULE$._hashCode(this);
            }

            /**
             * Equal when {@code obj} is a History of the same minimizer instance with an equal
             * {@code sumOfSquaredGradients}, and {@code obj} agrees via {@link #canEqual}.
             */
            public boolean equals(Object obj) {
                if (this == obj) {
                    return true;
                }
                // Raw qualified type: the simple name denotes L2Regularization<T>.History, which is
                // not reifiable and therefore cannot be used with instanceof.
                if (!(obj instanceof L2Regularization.History)) {
                    return false;
                }
                L2Regularization.History other = (L2Regularization.History) obj;
                return other.breeze$optimize$AdaptiveGradientDescent$L2Regularization$History$$$outer() == this.$outer
                        && BoxesRunTime.equals(sumOfSquaredGradients(), other.sumOfSquaredGradients())
                        && other.canEqual(this);
            }

            public String toString() {
                return ScalaRunTime$.MODULE$._toString(this);
            }

            public boolean canEqual(Object obj) {
                return obj instanceof L2Regularization.History;
            }

            public int productArity() {
                return 1;
            }

            public String productPrefix() {
                return "History";
            }

            public Object productElement(int i) {
                if (0 == i) {
                    return _1();
                }
                throw new IndexOutOfBoundsException(BoxesRunTime.boxToInteger(i).toString());
            }

            public String productElementName(int i) {
                if (0 == i) {
                    return "sumOfSquaredGradients";
                }
                throw new IndexOutOfBoundsException(BoxesRunTime.boxToInteger(i).toString());
            }

            /** Returns the stored squared-gradient history vector. */
            @SuppressWarnings("unchecked")
            public T sumOfSquaredGradients() {
                return (T) this.sumOfSquaredGradients;
            }

            /** Returns a new History of the same minimizer holding {@code t}. */
            public L2Regularization<T>.History copy(T t) {
                return new History(this.$outer, t);
            }

            public T copy$default$1() {
                return sumOfSquaredGradients();
            }

            public T _1() {
                return sumOfSquaredGradients();
            }

            public final /* synthetic */ L2Regularization breeze$optimize$AdaptiveGradientDescent$L2Regularization$History$$$outer() {
                return this.$outer;
            }
        }

        /**
         * Creates the minimizer.
         *
         * <p>NOTE(review): the decompiler moved the {@code super} call to the top of this
         * constructor. Fields read by overridden methods are unset while the superclass
         * constructor runs.
         *
         * @param regularizationConstant weight of the squared-L2 penalty
         * @param stepSize forwarded as the superclass's first constructor argument (presumably the
         *     default step size)
         * @param maxIter forwarded as the superclass's second constructor argument
         * @param tolerance forwarded as the superclass's third constructor argument
         * @param minImprovementWindow forwarded as the superclass's fourth constructor argument
         * @param vspace vector-space operations for {@code T}
         * @param randBasis not used by this constructor
         */
        public L2Regularization(double regularizationConstant, double stepSize, int maxIter, double tolerance, int minImprovementWindow, MutableFiniteCoordinateField<T, ?, Object> vspace, RandBasis randBasis) {
            super(stepSize, maxIter, tolerance, minImprovementWindow, vspace);
            this.regularizationConstant = regularizationConstant;
            this.History$lzy1 = new AdaptiveGradientDescent$L2Regularization$History$(this);
            this.delta = 1.0E-4d;
        }

        /** Returns the weight of the squared-L2 penalty. */
        public double regularizationConstant() {
            return this.regularizationConstant;
        }

        /** Returns the vector-space operations held by the superclass. */
        private MutableFiniteCoordinateField<T, ?, Object> vspace$accessor() {
            return (MutableFiniteCoordinateField) super.vspace();
        }

        /** Returns the fixed constant added to the step denominator. */
        public double delta() {
            return this.delta;
        }

        /**
         * Returns the companion object of {@link History}. The decompiler could not express its
         * precise inner generic type in the signature.
         */
        public final AdaptiveGradientDescent$L2Regularization$History$ History() {
            return this.History$lzy1;
        }

        /** Starts from an all-zero squared-gradient history shaped like {@code init}. */
        @Override // breeze.optimize.FirstOrderMinimizer
        public L2Regularization<T>.History initialHistory(StochasticDiffFunction<T> f, T init) {
            return History().apply(vspace$accessor().zeroLike().apply(init));
        }

        /**
         * Returns a new history that folds in the squared gradient of {@code state}.
         *
         * <p>Only {@code state} is read; the other arguments are ignored. While
         * {@code state.iter() <= 1000}, the history is a running sum {@code G + g*g}. After that
         * it is an exponential moving average {@code g*g / 1000 + G * 999 / 1000}. The previous
         * history vector is only read, never mutated.
         */
        @Override // breeze.optimize.FirstOrderMinimizer
        public L2Regularization<T>.History updateHistory(T newX, T newGrad, double newValue, StochasticDiffFunction<T> f, FirstOrderMinimizer.State<T, Object, L2Regularization<T>.History> state) {
            final double maxAge = 1000.0d;
            MutableFiniteCoordinateField<T, ?, Object> vs = vspace$accessor();
            L2Regularization<T>.History oldHistory = state.history();
            // g*g from the immutable op; the in-place updates below touch only this new value.
            @SuppressWarnings("unchecked")
            T newG = (T) ((ImmutableNumericOps) vs.hasOps().apply(state.grad()))
                    .$times$colon$times(state.grad(), vs.mulVV());
            if (state.iter() > maxAge) {
                // newG = g*g / maxAge + old * (maxAge - 1) / maxAge   (axpy: y += a * x)
                ((NumericOps) vs.hasOps().apply(newG)).$times$eq(BoxesRunTime.boxToDouble(1 / maxAge), vs.mulIntoVS());
                breeze.linalg.package$.MODULE$.axpy(BoxesRunTime.boxToDouble((maxAge - 1) / maxAge), oldHistory.sumOfSquaredGradients(), newG, vs.scaleAddVV());
            } else {
                // newG = g*g + old
                ((NumericOps) vs.hasOps().apply(newG)).$plus$eq(oldHistory.sumOfSquaredGradients(), vs.addIntoVV());
            }
            return new History(this, newG);
        }

        /**
         * Takes an adaptively scaled step along {@code dir} with the L2 penalty folded into the
         * denominator.
         *
         * @param state current state; its point, gradient and history are read
         * @param dir step direction
         * @param stepSize step length
         * @return the new point (freshly allocated)
         */
        @Override // breeze.optimize.StochasticGradientDescent, breeze.optimize.FirstOrderMinimizer
        @SuppressWarnings("unchecked")
        public T takeStep(FirstOrderMinimizer.State<T, Object, L2Regularization<T>.History> state, T dir, double stepSize) {
            MutableFiniteCoordinateField<T, ?, Object> vs = vspace$accessor();
            // s = sqrt(G + g*g), element-wise
            Object squaredGrad = ((ImmutableNumericOps) vs.hasOps().apply(state.grad())).$times$colon$times(state.grad(), vs.mulVV());
            Object accumulated = ((ImmutableNumericOps) vs.hasOps().apply(state.history().sumOfSquaredGradients())).$plus$colon$plus(squaredGrad, vs.addVV());
            Object scale = package$sqrt$.MODULE$.apply(accumulated, HasOps$.MODULE$.fromLowOrderCanMapActiveValues(vs.scalarOf(), package$sqrt$sqrtDoubleImpl$.MODULE$, vs.mapValues()));
            // numerator: x * s + stepSize * dir
            T newX = (T) ((ImmutableNumericOps) vs.hasOps().apply(state.x())).$times$colon$times(scale, vs.mulVV());
            breeze.linalg.package$.MODULE$.axpy(BoxesRunTime.boxToDouble(stepSize), dir, newX, vs.scaleAddVV());
            // denominator: s + delta + regularizationConstant * stepSize (s is a local temporary)
            ((NumericOps) vs.hasOps().apply(scale)).$plus$eq(BoxesRunTime.boxToDouble(delta() + (regularizationConstant() * stepSize)), vs.addIntoVS());
            ((NumericOps) vs.hasOps().apply(newX)).$colon$div$eq(scale, vs.divIntoVV());
            return newX;
        }

        /** Always returns the fixed {@code defaultStepSize()}; the arguments are ignored. */
        @Override // breeze.optimize.StochasticGradientDescent, breeze.optimize.FirstOrderMinimizer
        public double determineStepSize(FirstOrderMinimizer.State<T, Object, L2Regularization<T>.History> state, StochasticDiffFunction<T> f, T dir) {
            return defaultStepSize();
        }

        /**
         * Adds the squared-L2 penalty to a value/gradient pair.
         *
         * @return {@code (newVal + (newX . newX) * regularizationConstant / 2,
         *     newGrad + newX * regularizationConstant)}
         */
        @Override // breeze.optimize.FirstOrderMinimizer
        public Tuple2<Object, T> adjust(T newX, T newGrad, double newVal) {
            MutableFiniteCoordinateField<T, ?, Object> vs = vspace$accessor();
            double squaredNorm = BoxesRunTime.unboxToDouble(((ImmutableNumericOps) vs.hasOps().apply(newX)).dot(newX, vs.dotVV()));
            double adjustedValue = newVal + ((squaredNorm * regularizationConstant()) / 2.0d);
            Object penaltyGrad = ((ImmutableNumericOps) vs.hasOps().apply(newX)).$times(BoxesRunTime.boxToDouble(regularizationConstant()), vs.mulVS_M());
            @SuppressWarnings("unchecked")
            T adjustedGrad = (T) ((NumericOps) vs.hasOps().apply(newGrad)).$plus(penaltyGrad, vs.addVV());
            return new Tuple2<>(BoxesRunTime.boxToDouble(adjustedValue), adjustedGrad);
        }

        // No explicit erasure bridges for initialHistory/updateHistory: they would have the same
        // erasure as the typed overrides above (a compile error), and javac emits them itself.
    }
}
