package ai.djl.examples.training;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.basicdataset.Mnist;
import ai.djl.basicmodelzoo.basic.Mlp;
import ai.djl.examples.training.util.Arguments;
import ai.djl.examples.training.util.ExampleTrainingResult;
import ai.djl.examples.training.util.TrainingUtils;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.hyperparameter.optimizer.HpORandom;
import ai.djl.training.hyperparameter.param.HpInt;
import ai.djl.training.hyperparameter.param.HpSet;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;
import ai.djl.util.Pair;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.Arrays;
import org.apache.commons.cli.ParseException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Example that searches MLP hidden-layer width and depth on MNIST with random hyperparameter
 * optimization ({@link HpORandom}), then retrains the best configuration found and saves it.
 */
public final class TrainWithHpo {
    private static final Logger logger = LoggerFactory.getLogger(TrainWithHpo.class);

    /** Number of random hyperparameter configurations evaluated before the final run. */
    private static final int HYPERPARAMETER_TESTS = 50;

    /**
     * Input feature count used both for the MLP input layer and for trainer initialization (784
     * corresponds to a flattened 28x28 MNIST image).
     */
    private static final int INPUT_SIZE = 784;

    /** Output size of the MLP (one per MNIST digit class). */
    private static final int NUM_CLASSES = 10;

    /** Name used for training checkpoints and for the final saved model. */
    private static final String MODEL_NAME = "mlp";

    private TrainWithHpo() {
    }

    /**
     * Command-line entry point; delegates to {@link #runExample(String[])}.
     *
     * @param args command-line arguments parsed by {@link Arguments#parseArgs(String[])}
     * @throws IOException if dataset preparation, training or model saving fails
     * @throws ParseException if the command-line arguments are invalid
     */
    public static void main(String[] args) throws IOException, ParseException {
        runExample(args);
    }

    /**
     * Runs {@value #HYPERPARAMETER_TESTS} random hyperparameter trials, retrains the best
     * configuration, saves that model to the output directory and returns its training result.
     *
     * @param args command-line arguments parsed by {@link Arguments#parseArgs(String[])}
     * @return the training result of the final (best-hyperparameter) run, or {@code null} if no
     *     arguments were parsed
     * @throws IOException if dataset preparation, training or model saving fails
     * @throws ParseException if the command-line arguments are invalid
     */
    public static ExampleTrainingResult runExample(String[] args)
            throws IOException, ParseException {
        Arguments arguments = Arguments.parseArgs(args);
        if (arguments == null) {
            // Defensive: presumably parseArgs yields null when only usage/help was requested;
            // confirm against Arguments. Without this check the next line would throw an NPE.
            return null;
        }

        RandomAccessDataset trainingSet = getDataset(Dataset.Usage.TRAIN, arguments);
        RandomAccessDataset validateSet = getDataset(Dataset.Usage.TEST, arguments);

        HpSet searchSpace =
                new HpSet(
                        "hp",
                        Arrays.asList(
                                new HpInt("hiddenLayersSize", 10, 100),
                                new HpInt("hiddenLayersCount", 2, 10)));
        HpORandom hpOptimizer = new HpORandom(searchSpace);

        for (int i = 0; i < HYPERPARAMETER_TESTS; i++) {
            HpSet hpVals = hpOptimizer.nextConfig();
            Pair<Model, ExampleTrainingResult> trial =
                    train(arguments, hpVals, trainingSet, validateSet);
            // Exploratory runs only contribute their loss; release the model right away.
            trial.getKey().close();
            float loss = trial.getValue().getLoss();
            hpOptimizer.update(hpVals, loss);
            logger.info(
                    "--------- hp test {}/{} - Loss {} - {}",
                    i,
                    HYPERPARAMETER_TESTS,
                    loss,
                    hpVals);
        }

        HpSet bestHpVals = (HpSet) hpOptimizer.getBest().getKey();
        Pair<Model, ExampleTrainingResult> best =
                train(arguments, bestHpVals, trainingSet, validateSet);
        ExampleTrainingResult result = best.getValue();
        try (Model model = best.getKey()) {
            logger.info("--------- FINAL_HP - Loss {} - {}", result.getLoss(), bestHpVals);
            model.save(Paths.get(arguments.getOutputDir()), MODEL_NAME);
        }
        return result;
    }

    /**
     * Builds an MLP from the given hyperparameter values and trains it.
     *
     * <p>On success the returned model is still open and the caller is responsible for closing
     * it. If anything fails, the model is closed here before the exception propagates.
     *
     * @param arguments parsed command-line arguments (batch size, epochs, output dir, ...)
     * @param hpVals hyperparameter values providing {@code hiddenLayersCount} and {@code
     *     hiddenLayersSize}
     * @param trainingSet dataset used for training
     * @param validateSet dataset used for validation
     * @return the trained (open) model paired with its training result
     * @throws IOException if training fails
     */
    private static Pair<Model, ExampleTrainingResult> train(
            Arguments arguments,
            HpSet hpVals,
            RandomAccessDataset trainingSet,
            RandomAccessDataset validateSet)
            throws IOException {
        // NOTE(review): random() is used to read the value out of a config produced by
        // HpORandom.nextConfig()/getBest(); presumably those parameters are fixed, so random()
        // returns the sampled value rather than drawing a new one — confirm against HpSet.
        int hiddenLayersCount = (Integer) hpVals.getHParam("hiddenLayersCount").random();
        int hiddenLayersSize = (Integer) hpVals.getHParam("hiddenLayersSize").random();
        int[] hidden = new int[hiddenLayersCount];
        Arrays.fill(hidden, hiddenLayersSize);

        Model model = Model.newInstance();
        try {
            model.setBlock(new Mlp(INPUT_SIZE, NUM_CLASSES, hidden));

            DefaultTrainingConfig config = setupTrainingConfig(arguments);
            config.addTrainingListeners(
                    TrainingListener.Defaults.logging(
                            TrainWithHpo.class.getSimpleName(),
                            arguments.getBatchSize(),
                            (int) trainingSet.getNumIterations(),
                            (int) validateSet.getNumIterations(),
                            arguments.getOutputDir()));

            try (Trainer trainer = model.newTrainer(config)) {
                trainer.setMetrics(new Metrics());
                trainer.initialize(new Shape(1, INPUT_SIZE));
                TrainingUtils.fit(
                        trainer,
                        arguments.getEpoch(),
                        trainingSet,
                        validateSet,
                        arguments.getOutputDir(),
                        MODEL_NAME);
                // The result is captured while the trainer is still open.
                return new Pair<>(model, new ExampleTrainingResult(trainer));
            }
        } catch (Throwable t) {
            // The caller never receives the model on failure, so close it here to avoid a leak.
            try {
                model.close();
            } catch (Throwable closeFailure) {
                t.addSuppressed(closeFailure);
            }
            throw t;
        }
    }

    /**
     * Creates the training configuration: softmax cross-entropy loss, an accuracy evaluator, the
     * requested batch size and up to {@code maxGpus} devices.
     */
    private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) {
        return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                .addEvaluator(new Accuracy())
                .setBatchSize(arguments.getBatchSize())
                .optDevices(Device.getDevices(arguments.getMaxGpus()));
    }

    /**
     * Builds and prepares the MNIST dataset split for {@code usage}, with shuffled batches of the
     * configured size and at most {@code maxIterations} iterations.
     */
    private static RandomAccessDataset getDataset(Dataset.Usage usage, Arguments arguments)
            throws IOException {
        Mnist mnist =
                Mnist.builder()
                        .optUsage(usage)
                        .setSampling(arguments.getBatchSize(), true)
                        .optMaxIteration(arguments.getMaxIterations())
                        .build();
        mnist.prepare(new ProgressBar());
        return mnist;
    }
}
