package ai.djl.examples.training;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.PikachuDetection;
import ai.djl.basicmodelzoo.cv.object_detection.ssd.SingleShotDetection;
import ai.djl.examples.training.util.Arguments;
import ai.djl.examples.training.util.ExampleTrainingResult;
import ai.djl.examples.training.util.TrainingUtils;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.modality.cv.DetectedObjects;
import ai.djl.modality.cv.ImageVisualization;
import ai.djl.modality.cv.MultiBoxDetection;
import ai.djl.modality.cv.SingleShotDetectionTranslator;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.util.BufferedImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.LambdaBlock;
import ai.djl.nn.SequentialBlock;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.BoundingBoxError;
import ai.djl.training.evaluator.SingleShotDetectionAccuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.SingleShotDetectionLoss;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Transform;
import ai.djl.translate.TranslateException;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import javax.imageio.ImageIO;
import org.apache.commons.cli.ParseException;

/**
 * Example that trains a Single Shot Detection (SSD) model on the {@link PikachuDetection}
 * dataset, saves its parameters, and runs inference on a single image with the saved model.
 */
public final class TrainPikachu {

    /** Name under which the model parameters are saved by training and loaded by prediction. */
    private static final String MODEL_NAME = "ssd";

    private TrainPikachu() {}

    /**
     * Command-line entry point; delegates to {@link #runExample(String[])}.
     *
     * @param args command-line arguments, parsed by {@link Arguments#parseArgs(String[])}
     * @throws IOException on I/O failure while preparing data, training, or saving the model
     * @throws ParseException if the command-line arguments cannot be parsed
     */
    public static void main(String[] args) throws IOException, ParseException {
        runExample(args);
    }

    /**
     * Trains the SSD model and saves it to the configured output directory.
     *
     * <p>The trainer is closed before the model is saved. Both the trainer and the model are
     * closed exactly once, even when training or saving fails.
     *
     * @param args command-line arguments, parsed by {@link Arguments#parseArgs(String[])}
     * @return the training result captured from the trainer after fitting
     * @throws IOException on I/O failure while preparing data, training, or saving the model
     * @throws ParseException if the command-line arguments cannot be parsed
     */
    public static ExampleTrainingResult runExample(String[] args)
            throws IOException, ParseException {
        Arguments arguments = Arguments.parseArgs(args);

        try (Model model = Model.newInstance()) {
            model.setBlock(getSsdTrainBlock());
            RandomAccessDataset trainingSet = getDataset(Dataset.Usage.TRAIN, arguments);
            RandomAccessDataset validateSet = getDataset(Dataset.Usage.TEST, arguments);

            DefaultTrainingConfig config = setupTrainingConfig(arguments);
            config.addTrainingListeners(
                    TrainingListener.Defaults.logging(
                            TrainPikachu.class.getSimpleName(),
                            arguments.getBatchSize(),
                            (int) trainingSet.getNumIterations(),
                            (int) validateSet.getNumIterations(),
                            arguments.getOutputDir()));

            ExampleTrainingResult result;
            try (Trainer trainer = model.newTrainer(config)) {
                trainer.setMetrics(new Metrics());
                // Input shape is presumably (batch, channels, height, width) for 256x256 RGB
                // images — TODO confirm against the dataset pipeline.
                trainer.initialize(new Shape(arguments.getBatchSize(), 3, 256, 256));
                TrainingUtils.fit(
                        trainer,
                        arguments.getEpoch(),
                        trainingSet,
                        validateSet,
                        arguments.getOutputDir(),
                        MODEL_NAME);
                result = new ExampleTrainingResult(trainer);
            }
            // Save only after the trainer has been closed, as the original ordering did.
            model.save(Paths.get(arguments.getOutputDir()), MODEL_NAME);
            return result;
        }
    }

    /**
     * Loads the trained SSD model and runs detection on a single image.
     *
     * <p>The input image, annotated with the detected bounding boxes, is written as
     * {@code pikachu_output.png} into {@code outputDir}.
     *
     * @param outputDir directory holding the saved model; also receives the annotated image
     * @param imageFile path of the image to run detection on
     * @return the number of objects the translator reported (built with threshold 0.6)
     * @throws IOException if the model or image cannot be read or the output cannot be written
     * @throws MalformedModelException if the saved model parameters cannot be loaded
     * @throws TranslateException if pre- or post-processing of the prediction fails
     */
    public static int predict(String outputDir, String imageFile)
            throws IOException, MalformedModelException, TranslateException {
        try (Model model = Model.newInstance()) {
            // Rebuild the training topology so the saved parameters can be loaded into it.
            model.setBlock(getSsdTrainBlock());
            model.load(Paths.get(outputDir), MODEL_NAME);
            // Wrap the trained network with the inference-time post-processing.
            model.setBlock(getSsdPredictBlock(model.getBlock()));

            List<String> classes = new ArrayList<>();
            classes.add("pikachu");

            try (Predictor<BufferedImage, DetectedObjects> predictor =
                    model.newPredictor(
                            SingleShotDetectionTranslator.builder()
                                    .setPipeline(new Pipeline(new ToTensor()))
                                    .setClasses(classes)
                                    .optThreshold(0.6f)
                                    .build())) {
                BufferedImage image = BufferedImageUtils.fromFile(Paths.get(imageFile));
                DetectedObjects detectedObjects = predictor.predict(image);
                ImageVisualization.drawBoundingBoxes(image, detectedObjects);
                Path outputImage = Paths.get(outputDir).resolve("pikachu_output.png");
                ImageIO.write(image, "png", outputImage.toFile());
                return detectedObjects.getNumberOfObjects();
            }
        }
    }

    /**
     * Builds the Pikachu dataset split for the given usage and prepares it.
     *
     * @param usage which split to load (for example TRAIN or TEST)
     * @param arguments supplies the batch size and the maximum iteration count
     * @return the prepared dataset, sampled in batches with shuffling enabled
     * @throws IOException if preparing the dataset fails
     */
    private static RandomAccessDataset getDataset(Dataset.Usage usage, Arguments arguments)
            throws IOException {
        PikachuDetection dataset =
                PikachuDetection.builder()
                        .optUsage(usage)
                        .optMaxIteration(arguments.getMaxIterations())
                        .optPipeline(new Pipeline(new ToTensor()))
                        .setSampling(arguments.getBatchSize(), true)
                        .build();
        dataset.prepare(new ProgressBar());
        return dataset;
    }

    /**
     * Creates the training configuration.
     *
     * <p>It uses the SSD loss, class-accuracy and bounding-box-error evaluators, and up to
     * {@code arguments.getMaxGpus()} devices.
     *
     * @param arguments supplies the batch size and the GPU limit
     * @return the configured training setup
     */
    private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) {
        return new DefaultTrainingConfig(new SingleShotDetectionLoss())
                .setBatchSize(arguments.getBatchSize())
                .addEvaluator(new SingleShotDetectionAccuracy("classAccuracy"))
                .addEvaluator(new BoundingBoxError("boundingBoxError"))
                .optDevices(Device.getDevices(arguments.getMaxGpus()));
    }

    /**
     * Builds the SSD network used for training, for a single object class.
     *
     * <p>The base network is three down-sampling blocks (16, 32, 64 filters). Anchor sizes and
     * aspect ratios are given as five parallel lists.
     *
     * @return the untrained SSD block
     */
    public static Block getSsdTrainBlock() {
        int[] numFilters = {16, 32, 64};
        SequentialBlock baseBlock = new SequentialBlock();
        for (int filters : numFilters) {
            baseBlock.add(SingleShotDetection.getDownSamplingBlock(filters));
        }

        List<List<Float>> sizes = new ArrayList<>();
        sizes.add(Arrays.asList(0.2f, 0.272f));
        sizes.add(Arrays.asList(0.37f, 0.447f));
        sizes.add(Arrays.asList(0.54f, 0.619f));
        sizes.add(Arrays.asList(0.71f, 0.79f));
        sizes.add(Arrays.asList(0.88f, 0.961f));

        // Same aspect ratios for every entry in `sizes`, keeping both lists the same length.
        List<List<Float>> ratios = new ArrayList<>();
        for (int i = 0; i < sizes.size(); i++) {
            ratios.add(Arrays.asList(1f, 2f, 0.5f));
        }

        return SingleShotDetection.builder()
                .setNumClasses(1)
                .setNumFeatures(3)
                .optGlobalPool(true)
                .setRatios(ratios)
                .setSizes(sizes)
                .setBaseNetwork(baseBlock)
                .build();
    }

    /**
     * Wraps a trained SSD block with multi-box detection post-processing for inference.
     *
     * @param ssdTrain the (trained) block returned by {@link #getSsdTrainBlock()}
     * @return a block that runs {@code ssdTrain} and then decodes its outputs into detections
     */
    public static Block getSsdPredictBlock(Block ssdTrain) {
        SequentialBlock ssdPredict = new SequentialBlock();
        ssdPredict.add(ssdTrain);
        ssdPredict.add(
                new LambdaBlock(
                        output -> {
                            // Output order as consumed below: 0 = anchors,
                            // 1 = class predictions, 2 = bounding-box predictions.
                            NDArray anchors = output.get(0);
                            NDArray classPredictions =
                                    output.get(1).softmax(-1).transpose(0, 2, 1);
                            NDArray boundingBoxPredictions = output.get(2);
                            NDList detections =
                                    MultiBoxDetection.builder()
                                            .build()
                                            .detection(
                                                    new NDList(
                                                            classPredictions,
                                                            boundingBoxPredictions,
                                                            anchors));
                            // Split axis 2 at indices 1 and 2 into three parts, presumably
                            // class id, score and box coordinates — confirm against
                            // MultiBoxDetection's output layout.
                            return detections.singletonOrThrow().split(new long[] {1, 2}, 2);
                        }));
        return ssdPredict;
    }
}
