package ai.djl.examples.training.transferlearning;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.basicdataset.Cifar10;
import ai.djl.basicmodelzoo.cv.classification.ResNetV1;
import ai.djl.examples.training.util.Arguments;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.SymbolBlock;
import ai.djl.nn.core.Linear;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingResult;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Transform;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.net.URL;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Map;
import org.apache.commons.cli.ParseException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/djl/examples/training/transferlearning/TrainResnetWithCifar10.class */
/**
 * Trains a ResNet-50 image classifier on the CIFAR-10 dataset.
 *
 * <p>Depending on the command-line {@link Arguments}, the model is either:
 *
 * <ul>
 *   <li>a symbolic ResNet loaded from the {@code ai.djl.mxnet} model-zoo group, whose last block is
 *       replaced by a flatten + 10-way {@link Linear} head (weights cleared unless pre-trained), or
 *   <li>an imperative {@link ResNetV1}, loaded pre-trained from the {@code ai.djl.zoo} group or
 *       freshly constructed.
 * </ul>
 *
 * <p>After training, the model is saved to {@code build/model} and reloaded once to classify a
 * sample image, which exercises the saved parameters end to end.
 */
public final class TrainResnetWithCifar10 {
    private static final Logger logger = LoggerFactory.getLogger(TrainResnetWithCifar10.class);

    /** Height and width, in pixels, of the images fed to the network. */
    private static final long IMAGE_SIZE = 32;

    /** Number of color channels of the images fed to the network. */
    private static final long CHANNELS = 3;

    /** Number of output classes of the classifier head. */
    private static final long NUM_CLASSES = 10;

    /** Name used both when creating the model and when saving/reloading its parameters. */
    private static final String MODEL_NAME = "resnetv1";

    /** Class-name list used by the translator when the saved model is reloaded. */
    private static final String SYNSET_URL =
            "https://mlrepo.djl.ai/model/cv/image_classification/"
                    + "ai/djl/mxnet/synset_cifar10.txt";

    private TrainResnetWithCifar10() {}

    /**
     * Command-line entry point; see {@link #runExample(String[])}.
     *
     * @param args command-line arguments parsed by {@link Arguments#parseArgs(String[])}
     */
    public static void main(String[] args)
            throws ParseException, ModelException, IOException, TranslateException {
        runExample(args);
    }

    /**
     * Builds the model, trains it on CIFAR-10, saves it and runs one prediction on the saved copy.
     *
     * @param args command-line arguments parsed by {@link Arguments#parseArgs(String[])}
     * @return the training result reported by the trainer
     * @throws ParseException if the arguments cannot be parsed
     * @throws ModelException if a model cannot be loaded
     * @throws IOException if the dataset, model files or sample image cannot be read or written
     * @throws TranslateException if the verification prediction fails
     */
    public static TrainingResult runExample(String[] args)
            throws IOException, ParseException, ModelException, TranslateException {
        Arguments arguments = Arguments.parseArgs(args);

        // try-with-resources closes the trainer before the model and records any close() failure
        // as a suppressed exception instead of letting it mask the original error.
        try (Model model = getModel(arguments)) {
            RandomAccessDataset trainDataset = getDataset(Dataset.Usage.TRAIN, arguments);
            RandomAccessDataset validationDataset = getDataset(Dataset.Usage.TEST, arguments);
            DefaultTrainingConfig config = setupTrainingConfig(arguments);

            try (Trainer trainer = model.newTrainer(config)) {
                trainer.setMetrics(new Metrics());

                // A batch of one image is enough to let the trainer infer parameter shapes.
                Shape inputShape = new Shape(1, CHANNELS, IMAGE_SIZE, IMAGE_SIZE);
                trainer.initialize(inputShape);
                EasyTrain.fit(trainer, arguments.getEpoch(), trainDataset, validationDataset);

                TrainingResult result = trainer.getTrainingResult();
                model.setProperty("Epoch", String.valueOf(result.getEpoch()));
                model.setProperty(
                        "Accuracy",
                        String.format("%.5f", result.getValidateEvaluation("Accuracy")));
                model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));

                Path modelDir = Paths.get("build/model");
                model.save(modelDir, MODEL_NAME);

                Classifications classifications = testSaveParameters(model.getBlock(), modelDir);
                logger.info("Predict result: {}", classifications.topK(3));
                return result;
            }
        }
    }

    /**
     * Creates the model to train according to the symbolic / pre-trained flags.
     *
     * @param arguments parsed command-line arguments; {@code getCriteria()} may override the
     *     default model-zoo filters
     * @return a model whose block outputs {@link #NUM_CLASSES} scores
     */
    private static Model getModel(Arguments arguments)
            throws IOException, ModelNotFoundException, MalformedModelException {
        boolean isSymbolic = arguments.isSymbolic();
        boolean isPreTrained = arguments.isPreTrained();
        Map<String, String> filters = arguments.getCriteria();

        Criteria.Builder<Image, Classifications> builder =
                Criteria.builder()
                        .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                        .setTypes(Image.class, Classifications.class)
                        .optProgress(new ProgressBar())
                        .optArtifactId("resnet");

        if (isSymbolic) {
            builder.optGroupId("ai.djl.mxnet");
            if (filters == null) {
                builder.optFilter("layers", "50");
                builder.optFilter("flavor", "v1");
            } else {
                builder.optFilters(filters);
            }
            ZooModel<Image, Classifications> model = ModelZoo.loadModel(builder.build());

            // Drop the zoo model's original classifier and attach a NUM_CLASSES-way head; the
            // flatten block adapts the backbone output for the Linear layer.
            SymbolBlock backbone = (SymbolBlock) model.getBlock();
            backbone.removeLastBlock();
            SequentialBlock newBlock = new SequentialBlock();
            newBlock.add(backbone);
            newBlock.add(Blocks.batchFlattenBlock());
            newBlock.add(Linear.builder().setOutChannels(NUM_CLASSES).build());
            model.setBlock(newBlock);
            if (!isPreTrained) {
                model.getBlock().clear();
            }
            return model;
        }

        if (isPreTrained) {
            builder.optGroupId("ai.djl.zoo");
            if (filters == null) {
                builder.optFilter("layers", "50");
                builder.optFilter("flavor", "v1");
                builder.optFilter("dataset", "cifar10");
            } else {
                builder.optFilters(filters);
            }
            return ModelZoo.loadModel(builder.build());
        }

        // Freshly constructed imperative ResNet-50 without pre-trained weights.
        Model model = Model.newInstance(MODEL_NAME);
        Block resNet50 =
                ResNetV1.builder()
                        .setImageShape(new Shape(CHANNELS, IMAGE_SIZE, IMAGE_SIZE))
                        .setNumLayers(50)
                        .setOutSize(NUM_CLASSES)
                        .build();
        model.setBlock(resNet50);
        return model;
    }

    /**
     * Reloads the saved parameters into {@code block} and classifies a bundled sample image.
     *
     * @param block network architecture to load the saved parameters into
     * @param path directory the model was saved to
     * @return classification of {@code src/test/resources/airplane1.png}
     */
    private static Classifications testSaveParameters(Block block, Path path)
            throws IOException, ModelException, TranslateException {
        ImageClassificationTranslator translator =
                ImageClassificationTranslator.builder()
                        .addTransform(new ToTensor())
                        .addTransform(new Normalize(Cifar10.NORMALIZE_MEAN, Cifar10.NORMALIZE_STD))
                        .optSynsetUrl(new URL(SYNSET_URL))
                        .optApplySoftmax(true)
                        .build();
        Image image = ImageFactory.getInstance().fromUrl("src/test/resources/airplane1.png");

        Criteria<Image, Classifications> criteria =
                Criteria.builder()
                        .setTypes(Image.class, Classifications.class)
                        .optModelUrls(path.toUri().toString())
                        .optTranslator(translator)
                        .optBlock(block)
                        .optModelName(MODEL_NAME)
                        .build();

        try (ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
                Predictor<Image, Classifications> predictor = model.newPredictor()) {
            return predictor.predict(image);
        }
    }

    /**
     * Builds the training configuration: softmax cross-entropy loss, accuracy evaluator, up to
     * {@code getMaxGpus()} devices and the default logging listeners writing to the output dir.
     */
    private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) {
        return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                .addEvaluator(new Accuracy())
                .optDevices(Device.getDevices(arguments.getMaxGpus()))
                .addTrainingListeners(TrainingListener.Defaults.logging(arguments.getOutputDir()));
    }

    /**
     * Builds and prepares a shuffled, normalized CIFAR-10 split.
     *
     * @param usage which split to load (train or test)
     * @param arguments supplies batch size and an optional sample limit
     * @return the prepared dataset
     */
    private static RandomAccessDataset getDataset(Dataset.Usage usage, Arguments arguments)
            throws IOException {
        Cifar10 cifar10 =
                Cifar10.builder()
                        .optUsage(usage)
                        .setSampling(arguments.getBatchSize(), true)
                        .optLimit(arguments.getLimit())
                        .optPipeline(
                                new Pipeline(
                                        new ToTensor(),
                                        new Normalize(
                                                Cifar10.NORMALIZE_MEAN, Cifar10.NORMALIZE_STD)))
                        .build();
        cifar10.prepare(new ProgressBar());
        return cifar10;
    }
}
