package ai.djl.training.tracker;

import ai.djl.TrainingDivergedException;
import java.util.Objects;

/* loaded from: input_file:ai/djl/training/tracker/WarmUpTracker.class */
/**
 * A {@link Tracker} that runs a warm-up phase before handing over to a main tracker.
 *
 * <p>For the first {@code warmUpSteps} updates the value comes from the warm-up schedule selected
 * by {@link Mode}. From update {@code warmUpSteps} onward, the main tracker is queried with the
 * update count shifted back by {@code warmUpSteps}, so the main schedule starts at its own step 0
 * right after warm-up ends.
 */
public final class WarmUpTracker implements Tracker {

    /** Tracker that supplies values once the warm-up phase is over. */
    Tracker mainTracker;
    /** Number of initial updates served by the warm-up schedule. */
    int warmUpSteps;
    /** Value at update 0 of the warm-up; the only warm-up value in {@link Mode#CONSTANT} mode. */
    float warmUpBeginValue;
    /**
     * End point of the {@link Mode#LINEAR} ramp; captured once at construction as {@code
     * mainTracker.getNewValue(0)} so the ramp meets the main schedule's first value.
     */
    float warmUpFinalValue;
    /** How warm-up values are computed. */
    Mode warmUpMode;

    /** Builder for {@link WarmUpTracker}. Obtain one via {@link WarmUpTracker#builder()}. */
    public static final class Builder {
        Tracker mainTracker;
        int warmUpSteps;
        float warmUpBeginValue;
        Mode warmUpMode;

        private Builder() {
            this.warmUpMode = Mode.LINEAR;
        }

        /**
         * Sets the tracker used after warm-up. Required.
         *
         * @param tracker the main tracker
         * @return this builder
         */
        public Builder setMainTracker(Tracker tracker) {
            this.mainTracker = tracker;
            return this;
        }

        /**
         * Sets the number of warm-up updates. Defaults to 0, which means no warm-up.
         *
         * @param i the number of warm-up updates; must be non-negative
         * @return this builder
         */
        public Builder optWarmUpSteps(int i) {
            this.warmUpSteps = i;
            return this;
        }

        /**
         * Sets the value at the start of warm-up. Defaults to 0.
         *
         * @param f the initial warm-up value
         * @return this builder
         */
        public Builder optWarmUpBeginValue(float f) {
            this.warmUpBeginValue = f;
            return this;
        }

        /**
         * Sets how warm-up values are computed. Defaults to {@link Mode#LINEAR}.
         *
         * @param mode the warm-up mode
         * @return this builder
         */
        public Builder optWarmUpMode(Mode mode) {
            this.warmUpMode = mode;
            return this;
        }

        /**
         * Builds the {@link WarmUpTracker}.
         *
         * @return a new {@link WarmUpTracker}
         * @throws NullPointerException if no main tracker was set
         * @throws IllegalArgumentException if the warm-up step count is negative
         */
        public WarmUpTracker build() {
            // Fail fast with a clear message instead of an anonymous NPE from the constructor,
            // which immediately calls mainTracker.getNewValue(0).
            Objects.requireNonNull(mainTracker, "mainTracker must be set via setMainTracker()");
            // A negative count would silently offset the main schedule instead of warming up.
            if (warmUpSteps < 0) {
                throw new IllegalArgumentException(
                        "warmUpSteps must be non-negative, got: " + warmUpSteps);
            }
            return new WarmUpTracker(this);
        }
    }

    /** Strategy used to compute values during the warm-up phase. */
    public enum Mode {
        /** Ramp linearly from the begin value toward the main tracker's first value. */
        LINEAR,
        /** Hold the begin value for the entire warm-up phase. */
        CONSTANT
    }

    WarmUpTracker(Builder builder) {
        this.mainTracker = builder.mainTracker;
        this.warmUpSteps = builder.warmUpSteps;
        this.warmUpBeginValue = builder.warmUpBeginValue;
        this.warmUpMode = builder.warmUpMode;
        this.warmUpFinalValue = this.mainTracker.getNewValue(0);
    }

    /**
     * Creates a new builder for {@link WarmUpTracker}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Computes the warm-up value for update {@code i}.
     *
     * <p>In {@link Mode#LINEAR} mode this is {@code begin + (final - begin) * i / warmUpSteps}.
     * Any other mode returns the begin value unchanged.
     *
     * @param i the update number, expected in {@code [0, warmUpSteps)}
     * @return the warm-up value
     * @throws TrainingDivergedException if the computed value is NaN
     */
    float getWarmUpValue(int i) {
        float f = this.warmUpBeginValue;
        if (this.warmUpMode == Mode.LINEAR) {
            f = this.warmUpBeginValue
                    + (((this.warmUpFinalValue - this.warmUpBeginValue) * i) / this.warmUpSteps);
        }
        checkValue(f);
        return f;
    }

    /**
     * Returns the value for update {@code i}.
     *
     * @param i the update number
     * @return the warm-up value while {@code i < warmUpSteps}; otherwise the main tracker's value
     *     for update {@code i - warmUpSteps}
     */
    @Override
    public float getNewValue(int i) {
        return i < this.warmUpSteps
                ? getWarmUpValue(i)
                : this.mainTracker.getNewValue(i - this.warmUpSteps);
    }

    /**
     * Rejects NaN values, which signal that training has diverged.
     *
     * @param f the value to check
     * @throws TrainingDivergedException if {@code f} is NaN
     */
    void checkValue(float f) {
        if (Float.isNaN(f)) {
            throw new TrainingDivergedException("Value is Nan.");
        }
    }
}
