package breeze.optimize;

import breeze.linalg.ImmutableNumericOps;
import breeze.linalg.NumericOps;
import breeze.linalg.norm$;
import breeze.math.MutableInnerProductModule;
import breeze.optimize.FirstOrderMinimizer;
import breeze.stats.distributions.Rand$;
import java.io.Serializable;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Product;
import scala.Tuple2;
import scala.collection.Iterator;
import scala.collection.immutable.IndexedSeq;
import scala.collection.immutable.Seq;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;
import scala.runtime.Statics;

/* compiled from: StochasticAveragedGradient.scala */
/* loaded from: input_file:breeze/optimize/StochasticAveragedGradient.class */
/**
 * Stochastic Average Gradient (SAG) minimizer over a {@link BatchDiffFunction}.
 *
 * <p>Each iteration evaluates the objective and gradient on a single example (index
 * {@code history.nextPos()}), replaces that example's stored gradient, and steps along the negated
 * average of all stored per-example gradients ({@code -currentSum / fullRange.size}). An optional
 * L2 penalty {@code l2Regularization / 2 * (x . x)} is folded into the value and gradient by
 * {@link #adjust}, and applied as multiplicative shrinkage of {@code x} in {@link #takeStep}.
 *
 * <p>This class is decompiled from Scala. Synthetic names such as {@code History$lzy1},
 * {@code copy$default$N} and {@code breeze$optimize$StochasticAveragedGradient$History$$$outer}
 * follow the Scala compiler's encoding and are kept as-is, since other compiled code may refer to
 * them.
 *
 * @param <T> the parameter-vector type, manipulated through {@code vs}
 */
public class StochasticAveragedGradient<T> extends FirstOrderMinimizer<T, BatchDiffFunction<T>> {
    /** Step size stored in the first {@link History}. */
    private final double initialStepSize;
    /** Step size is re-tuned on iterations divisible by this value; {@code <= 0} disables tuning. */
    private final int tuneStepFrequency;
    /** Coefficient {@code lambda} of the L2 penalty {@code lambda / 2 * (x . x)}. */
    private final double l2Regularization;
    /** Vector-space operations (arithmetic, dot products, zeros) for {@code T}. */
    private final MutableInnerProductModule<T, Object> vs;
    /** Companion object of {@link History} (Scala {@code object History}); created in the constructor. */
    public final StochasticAveragedGradient$History$ History$lzy1;

    /* compiled from: StochasticAveragedGradient.scala */
    /* loaded from: input_file:breeze/optimize/StochasticAveragedGradient$History.class */
    /**
     * Immutable per-iteration optimizer state (the Scala case class {@code History}).
     *
     * <ul>
     *   <li>{@code stepSize} – step size used for the next step;
     *   <li>{@code range} – example indices the next position is sampled from;
     *   <li>{@code currentSum} – running sum of the entries of {@code previousGradients};
     *   <li>{@code previousGradients} – the most recently stored gradient for each position;
     *   <li>{@code nextPos} – index of the example evaluated on the next iteration.
     * </ul>
     */
    public class History implements Product, Serializable {
        private final double stepSize;
        private final IndexedSeq<Object> range;
        private final T currentSum;
        private final IndexedSeq<T> previousGradients;
        private final int nextPos;
        private final /* synthetic */ StochasticAveragedGradient $outer;

        /**
         * @param outer the optimizer this state belongs to; must not be {@code null}
         * @throws NullPointerException if {@code outer} is {@code null}
         */
        public History(StochasticAveragedGradient outer, double stepSize, IndexedSeq<Object> range, T currentSum, IndexedSeq<T> previousGradients, int nextPos) {
            this.stepSize = stepSize;
            this.range = range;
            this.currentSum = currentSum;
            this.previousGradients = previousGradients;
            this.nextPos = nextPos;
            if (outer == null) {
                throw new NullPointerException();
            }
            this.$outer = outer;
        }

        public /* bridge */ /* synthetic */ Iterator productIterator() {
            return Product.productIterator$(this);
        }

        public /* bridge */ /* synthetic */ Iterator productElementNames() {
            return Product.productElementNames$(this);
        }

        /** Scala case-class hash: seed, product prefix, then each field in declaration order. */
        @Override
        public int hashCode() {
            int h = -889275714;
            h = Statics.mix(h, productPrefix().hashCode());
            h = Statics.mix(h, Statics.doubleHash(stepSize()));
            h = Statics.mix(h, Statics.anyHash(range()));
            h = Statics.mix(h, Statics.anyHash(currentSum()));
            h = Statics.mix(h, Statics.anyHash(previousGradients()));
            h = Statics.mix(h, nextPos());
            return Statics.finalizeHash(h, 5);
        }

        /**
         * Structural equality over all five fields. Instances that belong to different optimizers
         * (different {@code $outer}) are never equal, which mirrors Scala's path-dependent
         * case-class equality.
         */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof History)) {
                return false;
            }
            History that = (History) obj;
            if (that.breeze$optimize$StochasticAveragedGradient$History$$$outer() != this.$outer) {
                return false;
            }
            // `!=` on doubles is true for NaN, so a NaN step size never compares equal (same as Scala `==`).
            if (stepSize() != that.stepSize() || nextPos() != that.nextPos()) {
                return false;
            }
            IndexedSeq<Object> myRange = range();
            IndexedSeq<Object> otherRange = that.range();
            if (myRange == null ? otherRange != null : !myRange.equals(otherRange)) {
                return false;
            }
            if (!BoxesRunTime.equals(currentSum(), that.currentSum())) {
                return false;
            }
            IndexedSeq<T> myGradients = previousGradients();
            IndexedSeq<T> otherGradients = that.previousGradients();
            if (myGradients == null ? otherGradients != null : !myGradients.equals(otherGradients)) {
                return false;
            }
            return that.canEqual(this);
        }

        @Override
        public String toString() {
            return ScalaRunTime$.MODULE$._toString(this);
        }

        public boolean canEqual(Object obj) {
            return obj instanceof History;
        }

        public int productArity() {
            return 5;
        }

        public String productPrefix() {
            return "History";
        }

        /** Returns field {@code i} (0-based, declaration order), boxing primitives. */
        public Object productElement(int i) {
            switch (i) {
                case 0:
                    return BoxesRunTime.boxToDouble(_1());
                case 1:
                    return _2();
                case 2:
                    return _3();
                case 3:
                    return _4();
                case 4:
                    return BoxesRunTime.boxToInteger(_5());
                default:
                    throw new IndexOutOfBoundsException(BoxesRunTime.boxToInteger(i).toString());
            }
        }

        /** Returns the name of field {@code i} (0-based, declaration order). */
        public String productElementName(int i) {
            switch (i) {
                case 0:
                    return "stepSize";
                case 1:
                    return "range";
                case 2:
                    return "currentSum";
                case 3:
                    return "previousGradients";
                case 4:
                    return "nextPos";
                default:
                    throw new IndexOutOfBoundsException(BoxesRunTime.boxToInteger(i).toString());
            }
        }

        public double stepSize() {
            return this.stepSize;
        }

        public IndexedSeq<Object> range() {
            return this.range;
        }

        public T currentSum() {
            return this.currentSum;
        }

        public IndexedSeq<T> previousGradients() {
            return this.previousGradients;
        }

        public int nextPos() {
            return this.nextPos;
        }

        /** Returns a new state with the given fields, bound to the same optimizer. */
        public StochasticAveragedGradient<T>.History copy(double stepSize, IndexedSeq<Object> range, T currentSum, IndexedSeq<T> previousGradients, int nextPos) {
            return new History(this.$outer, stepSize, range, currentSum, previousGradients, nextPos);
        }

        public double copy$default$1() {
            return stepSize();
        }

        public IndexedSeq<Object> copy$default$2() {
            return range();
        }

        public T copy$default$3() {
            return currentSum();
        }

        public IndexedSeq<T> copy$default$4() {
            return previousGradients();
        }

        public int copy$default$5() {
            return nextPos();
        }

        public double _1() {
            return stepSize();
        }

        public IndexedSeq<Object> _2() {
            return range();
        }

        public T _3() {
            return currentSum();
        }

        public IndexedSeq<T> _4() {
            return previousGradients();
        }

        public int _5() {
            return nextPos();
        }

        public final /* synthetic */ StochasticAveragedGradient breeze$optimize$StochasticAveragedGradient$History$$$outer() {
            return this.$outer;
        }
    }

    /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
    /**
     * Creates the optimizer.
     *
     * @param maxIter forwarded as the first {@link FirstOrderMinimizer} constructor argument
     *     (presumably the iteration limit — confirm against FirstOrderMinimizer)
     * @param initialStepSize step size stored in the initial {@link History}
     * @param tuneStepFrequency re-tune the step size on iterations divisible by this; {@code <= 0}
     *     disables tuning
     * @param l2Regularization coefficient of the L2 penalty {@code l2Regularization / 2 * (x . x)}
     * @param mutableInnerProductModule vector-space operations for {@code T}; also passed to the superclass
     */
    public StochasticAveragedGradient(int maxIter, double initialStepSize, int tuneStepFrequency, double l2Regularization, MutableInnerProductModule<T, Object> mutableInnerProductModule) {
        super(maxIter, FirstOrderMinimizer$.MODULE$.$lessinit$greater$default$2(), FirstOrderMinimizer$.MODULE$.$lessinit$greater$default$3(), FirstOrderMinimizer$.MODULE$.$lessinit$greater$default$4(), mutableInnerProductModule);
        this.initialStepSize = initialStepSize;
        this.tuneStepFrequency = tuneStepFrequency;
        this.l2Regularization = l2Regularization;
        this.vs = mutableInnerProductModule;
        this.History$lzy1 = new StochasticAveragedGradient$History$(this);
    }

    /* JADX WARN: Incorrect inner types in method signature: ()Lbreeze/optimize/StochasticAveragedGradient<TT;>.History$; */
    /** Returns the {@link History} companion object used to construct new states. */
    public final StochasticAveragedGradient$History$ History() {
        return this.History$lzy1;
    }

    /**
     * Builds the starting state.
     *
     * <p>The state has the initial step size, the function's full index range, a zero running
     * sum, and one zero stored gradient per example. The next position is 0.
     *
     * <p>A single zero vector is shared by every {@code previousGradients} slot. This is safe in
     * this class: {@link #updateHistory} replaces slots through {@code IndexedSeq.updated} and
     * only mutates freshly allocated sums, never the stored gradients.
     *
     * @param batchDiffFunction the objective; its {@code fullRange()} determines the number of slots
     * @param init the starting point, used only as a template for zero vectors
     */
    public StochasticAveragedGradient<T>.History initialHistory(BatchDiffFunction<T> batchDiffFunction, T init) {
        T zero = this.vs.zeroLike().apply(init);
        return History().apply(this.initialStepSize, batchDiffFunction.fullRange(), this.vs.zeroLike().apply(init), (IndexedSeq) scala.package$.MODULE$.IndexedSeq().fill(batchDiffFunction.fullRange().length(), () -> zero), 0);
    }

    /**
     * Returns the SAG direction {@code -currentSum / fullRange.size}, i.e. the negated average of
     * the stored per-example gradients.
     */
    @Override // breeze.optimize.FirstOrderMinimizer
    public T chooseDescentDirection(FirstOrderMinimizer.State<T, Object, StochasticAveragedGradient<T>.History> state, BatchDiffFunction<T> batchDiffFunction) {
        return (T) ((ImmutableNumericOps) this.vs.hasOps().apply(state.history().currentSum())).$times(BoxesRunTime.boxToDouble((-1.0d) / batchDiffFunction.fullRange().size()), this.vs.mulVS_M());
    }

    /** Returns the step size stored in the current history; {@code direction} is not consulted. */
    public double determineStepSize(FirstOrderMinimizer.State<T, Object, StochasticAveragedGradient<T>.History> state, BatchDiffFunction<T> batchDiffFunction, T direction) {
        return state.history().stepSize();
    }

    /** Evaluates value and gradient on the single example {@code history.nextPos()}. */
    public Tuple2<Object, T> calculateObjective(BatchDiffFunction<T> batchDiffFunction, T x, StochasticAveragedGradient<T>.History history) {
        return batchDiffFunction.calculate(x, (IndexedSeq) scala.package$.MODULE$.IndexedSeq().apply(ScalaRunTime$.MODULE$.wrapIntArray(new int[]{history.nextPos()})));
    }

    /**
     * Adds the L2 penalty to a raw evaluation.
     *
     * @return {@code (value + l2Regularization / 2 * (x . x), grad + l2Regularization * x)}; the
     *     first element is a boxed {@code Double}
     */
    @Override // breeze.optimize.FirstOrderMinimizer
    @SuppressWarnings("unchecked")
    public Tuple2<Object, T> adjust(T x, T grad, double value) {
        double squaredNorm = BoxesRunTime.unboxToDouble(((ImmutableNumericOps) this.vs.hasOps().apply(x)).dot(x, this.vs.dotVV()));
        double adjustedValue = value + ((squaredNorm * this.l2Regularization) / 2.0d);
        Object adjustedGrad = ((NumericOps) this.vs.hasOps().apply(grad)).$plus(((ImmutableNumericOps) this.vs.hasOps().apply(x)).$times(BoxesRunTime.boxToDouble(this.l2Regularization), this.vs.mulVS_M()), this.vs.addVV());
        return new Tuple2<>(BoxesRunTime.boxToDouble(adjustedValue), (T) adjustedGrad);
    }

    /**
     * Returns {@code x * (1 - stepSize * l2Regularization) + stepSize * direction}.
     *
     * <p>The result is a newly allocated vector; {@code state.x()} is not modified.
     */
    @Override // breeze.optimize.FirstOrderMinimizer
    public T takeStep(FirstOrderMinimizer.State<T, Object, StochasticAveragedGradient<T>.History> state, T direction, double stepSize) {
        T shrunk = (T) ((ImmutableNumericOps) this.vs.hasOps().apply(state.x())).$times(BoxesRunTime.boxToDouble(1 - (stepSize * this.l2Regularization)), this.vs.mulVS_M());
        // In-place: shrunk += stepSize * direction
        breeze.linalg.package$.MODULE$.axpy(BoxesRunTime.boxToDouble(stepSize), direction, shrunk, this.vs.scaleAddVV());
        return shrunk;
    }

    /**
     * Produces the next state after evaluating example {@code state.history().nextPos()}.
     *
     * <ul>
     *   <li>The running sum becomes {@code currentSum - previousGradients(nextPos) + newGrad}.
     *   <li>{@code newGrad} replaces {@code previousGradients(nextPos)}.
     *   <li>On every {@code tuneStepFrequency}-th iteration the step size is re-tuned; otherwise
     *       it is kept unchanged.
     *   <li>The next position is {@code iter + 1} during the first pass over the data. After that
     *       it is drawn at random from {@code range}.
     * </ul>
     *
     * <p>{@code newVal} is not used here.
     */
    public StochasticAveragedGradient<T>.History updateHistory(T newX, T newGrad, double newVal, BatchDiffFunction<T> batchDiffFunction, FirstOrderMinimizer.State<T, Object, StochasticAveragedGradient<T>.History> state) {
        StochasticAveragedGradient<T>.History history = state.history();
        int pos = history.nextPos();
        // Subtract the stale gradient of this example; `$minus` allocates a fresh vector, so the
        // in-place `+=` below never touches the stored gradients.
        Object newSum = ((ImmutableNumericOps) this.vs.hasOps().apply(history.currentSum())).$minus(history.previousGradients().apply(pos), this.vs.subVV());
        double newStepSize;
        if (this.tuneStepFrequency > 0 && state.iter() % this.tuneStepFrequency == 0) {
            newStepSize = tunedStepSize(newX, batchDiffFunction, state);
        } else {
            newStepSize = history.stepSize();
        }
        ((NumericOps) this.vs.hasOps().apply(newSum)).$plus$eq(newGrad, this.vs.addIntoVV());
        int nextPos = state.iter() < history.previousGradients().length() - 1
                ? state.iter() + 1
                : BoxesRunTime.unboxToInt(Rand$.MODULE$.choose((Seq) history.range()).mo1178draw());
        return History().apply(newStepSize, history.range(), newSum, (IndexedSeq) history.previousGradients().updated(pos, newGrad), nextPos);
    }

    /**
     * Adapts the step size using a sufficient-decrease test on the regularized objective.
     *
     * <p>With {@code dx = newX - x} and {@code eta} the current step size, the test compares the
     * new regularized value {@code f_pos(newX) + l2/2 * (newX . newX)} against
     * {@code adjustedValue + adjustedGradient . dx + (dx . dx) / (2 * eta)}. Here
     * {@code f_pos} is the objective restricted to the sampled example. If the new value exceeds
     * that bound, the step size is halved; otherwise it grows by 1.5.
     *
     * <p>The penalty uses the squared norm, so it is the same term {@link #adjust} adds to
     * {@code adjustedValue}.
     */
    private double tunedStepSize(T newX, BatchDiffFunction<T> batchDiffFunction, FirstOrderMinimizer.State<T, Object, StochasticAveragedGradient<T>.History> state) {
        double stepSize = state.history().stepSize();
        Object xDiff = ((ImmutableNumericOps) this.vs.hasOps().apply(newX)).$minus(state.x(), this.vs.subVV());
        double sampleValue = batchDiffFunction.valueAt(newX, (IndexedSeq) scala.package$.MODULE$.IndexedSeq().apply(ScalaRunTime$.MODULE$.wrapIntArray(new int[]{state.history().nextPos()})));
        double squaredNorm = BoxesRunTime.unboxToDouble(((ImmutableNumericOps) this.vs.hasOps().apply(newX)).dot(newX, this.vs.dotVV()));
        double newAdjustedValue = sampleValue + (this.l2Regularization / 2.0d) * squaredNorm;
        double linearTerm = BoxesRunTime.unboxToDouble(((ImmutableNumericOps) this.vs.hasOps().apply(state.adjustedGradient())).dot(xDiff, this.vs.dotVV()));
        double quadraticTerm = BoxesRunTime.unboxToDouble(((ImmutableNumericOps) this.vs.hasOps().apply(xDiff)).dot(xDiff, this.vs.dotVV())) / (2.0d * stepSize);
        boolean insufficientDecrease = newAdjustedValue - state.adjustedValue() > linearTerm + quadraticTerm;
        return insufficientDecrease ? stepSize / 2 : stepSize * 1.5d;
    }

    /** Erased bridge for {@link FirstOrderMinimizer}; delegates to the typed overload. */
    @Override // breeze.optimize.FirstOrderMinimizer
    @SuppressWarnings("unchecked")
    public /* bridge */ /* synthetic */ Object initialHistory(StochasticDiffFunction stochasticDiffFunction, Object obj) {
        return initialHistory((BatchDiffFunction<T>) stochasticDiffFunction, (T) obj);
    }

    /** Erased bridge for {@link FirstOrderMinimizer}; delegates to the typed overload. */
    @Override // breeze.optimize.FirstOrderMinimizer
    @SuppressWarnings("unchecked")
    public /* bridge */ /* synthetic */ double determineStepSize(FirstOrderMinimizer.State state, StochasticDiffFunction stochasticDiffFunction, Object obj) {
        return determineStepSize((FirstOrderMinimizer.State<T, Object, StochasticAveragedGradient<T>.History>) state, (BatchDiffFunction<T>) stochasticDiffFunction, (T) obj);
    }

    /** Erased bridge for {@link FirstOrderMinimizer}; delegates to the typed overload. */
    @Override // breeze.optimize.FirstOrderMinimizer
    @SuppressWarnings("unchecked")
    public /* bridge */ /* synthetic */ Tuple2 calculateObjective(StochasticDiffFunction stochasticDiffFunction, Object obj, Object obj2) {
        return calculateObjective((BatchDiffFunction<T>) stochasticDiffFunction, (T) obj, (StochasticAveragedGradient<T>.History) obj2);
    }

    /** Erased bridge for {@link FirstOrderMinimizer}; delegates to the typed overload. */
    @Override // breeze.optimize.FirstOrderMinimizer
    @SuppressWarnings("unchecked")
    public /* bridge */ /* synthetic */ Object updateHistory(Object obj, Object obj2, double d, StochasticDiffFunction stochasticDiffFunction, FirstOrderMinimizer.State state) {
        return updateHistory((T) obj, (T) obj2, d, (BatchDiffFunction<T>) stochasticDiffFunction, (FirstOrderMinimizer.State<T, Object, StochasticAveragedGradient<T>.History>) state);
    }
}
