package ai.djl.training.listener;

import ai.djl.Model;
import ai.djl.training.Trainer;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.function.Consumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/djl/training/listener/SaveModelTrainingListener.class */
/**
 * A training listener that saves the trainer's model to an output directory.
 *
 * <p>The model is always saved when training ends (unless that exact epoch was already saved as a
 * checkpoint). When a positive checkpoint interval is configured, it is additionally saved after
 * every {@code checkpoint}-th epoch. Before each save, the number of epochs seen by this listener
 * is recorded on the model as the {@code "Epoch"} property.
 */
public class SaveModelTrainingListener extends TrainingListenerAdapter {
    private static final Logger logger = LoggerFactory.getLogger(SaveModelTrainingListener.class);

    /** Directory the model is written to; never {@code null} after construction. */
    private final String outputDir;
    /** Name to save the model under instead of the model's own name, or {@code null}. */
    private String overrideModelName;
    /** Optional hook invoked with the trainer right before each save, or {@code null}. */
    private Consumer<Trainer> onSaveModel;
    /** Save interval in epochs; values {@code <= 0} disable per-epoch checkpoints. */
    private int checkpoint;
    /** Number of epochs completed so far, as counted by {@link #onEpoch}. */
    private int epoch;

    /**
     * Creates a listener that saves the model only at the end of training.
     *
     * @param outputDir the directory to save the model to
     * @throws IllegalArgumentException if {@code outputDir} is {@code null}
     */
    public SaveModelTrainingListener(String outputDir) {
        this(outputDir, null, -1);
    }

    /**
     * Creates a listener that saves the model only at the end of training, under a custom name.
     *
     * @param outputDir the directory to save the model to
     * @param overrideModelName the name to save the model under, or {@code null} to use the
     *     model's own name
     * @throws IllegalArgumentException if {@code outputDir} is {@code null}
     */
    public SaveModelTrainingListener(String outputDir, String overrideModelName) {
        this(outputDir, overrideModelName, -1);
    }

    /**
     * Creates a listener that saves the model every {@code checkpoint} epochs and at the end of
     * training.
     *
     * @param outputDir the directory to save the model to
     * @param overrideModelName the name to save the model under, or {@code null} to use the
     *     model's own name
     * @param checkpoint the save interval in epochs; a value {@code <= 0} disables per-epoch saves
     *     so the model is only saved at the end of training
     * @throws IllegalArgumentException if {@code outputDir} is {@code null}
     */
    public SaveModelTrainingListener(String outputDir, String overrideModelName, int checkpoint) {
        if (outputDir == null) {
            throw new IllegalArgumentException(
                    "Can not save checkpoint without specifying an output directory");
        }
        this.outputDir = outputDir;
        this.overrideModelName = overrideModelName;
        this.checkpoint = checkpoint;
    }

    /**
     * Counts the finished epoch and saves the model if it lands on the checkpoint interval.
     *
     * @param trainer the trainer whose model may be saved
     */
    @Override // ai.djl.training.listener.TrainingListenerAdapter, ai.djl.training.listener.TrainingListener
    public void onEpoch(Trainer trainer) {
        epoch++;
        if (isCheckpointEpoch()) {
            saveModel(trainer);
        }
    }

    /**
     * Saves the final model unless {@link #onEpoch} already saved it for the last epoch.
     *
     * @param trainer the trainer whose model is saved
     */
    @Override // ai.djl.training.listener.TrainingListenerAdapter, ai.djl.training.listener.TrainingListener
    public void onTrainingEnd(Trainer trainer) {
        // isCheckpointEpoch() only takes the modulo when checkpoint > 0, so a checkpoint of 0
        // cannot cause a division by zero here, and any non-positive interval (which never saved
        // in onEpoch) always gets its final save.
        if (!isCheckpointEpoch()) {
            saveModel(trainer);
        }
    }

    /**
     * Returns the name the model is saved under instead of its own name.
     *
     * @return the override name, or {@code null} if the model's own name is used
     */
    public String getOverrideModelName() {
        return overrideModelName;
    }

    /**
     * Sets the name the model is saved under instead of its own name.
     *
     * @param overrideModelName the override name, or {@code null} to use the model's own name
     */
    public void setOverrideModelName(String overrideModelName) {
        this.overrideModelName = overrideModelName;
    }

    /**
     * Returns the checkpoint interval in epochs.
     *
     * @return the interval; a value {@code <= 0} means per-epoch saves are disabled
     */
    public int getCheckpoint() {
        return checkpoint;
    }

    /**
     * Sets the checkpoint interval in epochs.
     *
     * @param checkpoint the interval; a value {@code <= 0} disables per-epoch saves
     */
    public void setCheckpoint(int checkpoint) {
        this.checkpoint = checkpoint;
    }

    /**
     * Sets a callback invoked with the trainer right before each save.
     *
     * @param consumer the callback, or {@code null} to disable it
     */
    public void setSaveModelCallback(Consumer<Trainer> consumer) {
        this.onSaveModel = consumer;
    }

    /**
     * Records the current epoch count on the model, runs the save callback (if any), and writes the
     * model to {@code outputDir}.
     *
     * <p>An {@link IOException} from saving is logged rather than rethrown, so a failed save does
     * not abort training.
     *
     * @param trainer the trainer whose model is saved
     */
    protected void saveModel(Trainer trainer) {
        Model model = trainer.getModel();
        String name = overrideModelName != null ? overrideModelName : model.getName();
        try {
            model.setProperty("Epoch", String.valueOf(epoch));
            if (onSaveModel != null) {
                onSaveModel.accept(trainer);
            }
            model.save(Paths.get(outputDir), name);
        } catch (IOException e) {
            logger.error("Failed to save checkpoint", e);
        }
    }

    /** Returns whether the current epoch count falls on a positive checkpoint interval. */
    private boolean isCheckpointEpoch() {
        return checkpoint > 0 && epoch % checkpoint == 0;
    }
}
