package ai.djl.examples.inference.benchmark.util;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.ModelException;
import ai.djl.engine.Engine;
import ai.djl.examples.inference.benchmark.MultithreadedBenchmark;
import ai.djl.metric.Metric;
import ai.djl.metric.Metrics;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.listener.MemoryTrainingListener;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Properties;
import java.util.concurrent.CompletableFuture;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/djl/examples/inference/benchmark/util/AbstractBenchmark.class */
/**
 * Base class for model-zoo inference benchmarks.
 *
 * <p>A run parses command-line {@link Arguments}, loads a {@link ZooModel}, then calls {@link
 * #predict} repeatedly. It runs either for a fixed number of iterations or, when the iteration
 * budget is {@code -1}, until a wall-clock duration has elapsed. Finally it logs timing and
 * latency percentiles.
 *
 * @param <I> the model input type
 * @param <O> the model output type
 */
public abstract class AbstractBenchmark<I, O> {

    private static final Logger logger = LoggerFactory.getLogger(AbstractBenchmark.class);

    private final Class<I> input;
    private final Class<O> output;
    private O lastResult;

    protected ProgressBar progressBar;
    /** Iteration budget for the current run; {@code -1} means "run for a duration instead". */
    protected int maxIterations;
    /** Number of predictions submitted so far in the current run. */
    protected int iterationCount;

    /**
     * Creates a benchmark for a model with the given input and output types.
     *
     * @param inputClass the class of the model input
     * @param outputClass the class of the model output
     */
    public AbstractBenchmark(Class<I> inputClass, Class<O> outputClass) {
        this.input = inputClass;
        this.output = outputClass;
    }

    /**
     * Returns the result of the most recently awaited prediction.
     *
     * @return the last prediction result, or {@code null} if no prediction has completed yet
     */
    public O getPredictResult() {
        return lastResult;
    }

    /**
     * Parses {@code args}, runs the benchmark and logs its results.
     *
     * <p>Failures are logged, not thrown.
     *
     * @param args the command-line arguments
     * @return {@code true} if the benchmark completed; {@code false} if argument parsing failed
     *     (usage help is printed) or the run failed
     */
    public final boolean runBenchmark(String[] args) {
        Options options = getOptions();
        try {
            CommandLine cmd = new DefaultParser().parse(options, args, (Properties) null, false);
            Arguments arguments = parseArguments(cmd);
            logEngineLoadTime();
            Duration duration = prepareRun(arguments);
            Metrics metrics = new Metrics();
            try {
                long elapsed = executeRun(arguments, metrics, duration);
                recordResults(arguments, metrics, elapsed);
                return true;
            } catch (InterruptedException e) {
                // Restore the flag so code further up the stack can still observe the interrupt.
                Thread.currentThread().interrupt();
                logger.error("Failed to run benchmark", e);
                return false;
            } catch (Exception e) {
                logger.error("Failed to run benchmark", e);
                return false;
            }
        } catch (ParseException e) {
            HelpFormatter formatter = new HelpFormatter();
            formatter.setLeftPadding(1);
            formatter.setWidth(120);
            formatter.printHelp(e.getMessage(), options);
            return false;
        } catch (Throwable t) {
            logger.error("Unexpected error", t);
            return false;
        }
    }

    /**
     * Called once after the model is loaded and before the first prediction.
     *
     * @param zooModel the loaded model
     * @param arguments the parsed command-line arguments
     * @param metrics the metrics collector for this run
     * @throws IOException if initialization fails
     */
    protected abstract void initialize(ZooModel<I, O> zooModel, Arguments arguments, Metrics metrics)
            throws IOException;

    /**
     * Submits one prediction. The returned future is awaited after the prediction loop ends.
     *
     * @param zooModel the loaded model
     * @param arguments the parsed command-line arguments
     * @param metrics the metrics collector for this run
     * @return a future that completes with the prediction result
     * @throws TranslateException if the prediction cannot be submitted
     */
    protected abstract CompletableFuture<O> predict(
            ZooModel<I, O> zooModel, Arguments arguments, Metrics metrics)
            throws TranslateException;

    /** Releases per-run resources. Invoked after every run, whether it succeeded or failed. */
    protected abstract void clean();

    /**
     * Returns the command-line options understood by this benchmark.
     *
     * @return the supported options
     */
    public Options getOptions() {
        return Arguments.getOptions();
    }

    /**
     * Converts a parsed command line into benchmark {@link Arguments}.
     *
     * @param commandLine the parsed command line
     * @return the benchmark arguments
     */
    protected Arguments parseArguments(CommandLine commandLine) {
        return new Arguments(commandLine);
    }

    /**
     * Loads an image-classification model from the model zoo and records its load time.
     *
     * <p>The load time (in nanoseconds) is stored under the {@code "LoadModel"} metric. The
     * model-loader name defaults to {@code "resnet"} when none is given.
     *
     * @param arguments the parsed arguments (criteria filters and model name)
     * @param metrics the metrics collector that receives the load time
     * @param inputClass the model input class
     * @param outputClass the model output class
     * @return the loaded model; the caller is responsible for closing it
     * @throws ModelException if the model cannot be found or loaded
     * @throws IOException if reading the model artifacts fails
     */
    protected ZooModel<I, O> loadModel(
            Arguments arguments, Metrics metrics, Class<I> inputClass, Class<O> outputClass)
            throws ModelException, IOException {
        long begin = System.nanoTime();
        Criteria.Builder<I, O> builder =
                Criteria.builder()
                        .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                        .setTypes(inputClass, outputClass)
                        .optFilters(arguments.getCriteria())
                        .optProgress(new ProgressBar());
        String modelName = arguments.getModelName();
        builder.optModelLoaderName(modelName != null ? modelName : "resnet");

        ZooModel<I, O> model = ModelZoo.loadModel(builder.build());
        long loadNanos = System.nanoTime() - begin;
        logger.info(
                "Model {} loaded in: {} ms.",
                model.getName(),
                String.format("%.3f", loadNanos / 1_000_000f));
        metrics.addMetric("LoadModel", loadNanos);
        return model;
    }

    /** Obtains the default engine (presumably triggering its initialization) and logs the cost. */
    private static void logEngineLoadTime() {
        // The start timestamp must be taken before Engine.getInstance() so the delta covers it.
        long begin = System.nanoTime();
        String version = Engine.getInstance().getVersion();
        long elapsedNanos = System.nanoTime() - begin;
        logger.info(
                String.format("Load library %s in %.3f ms.", version, elapsedNanos / 1_000_000f));
    }

    /**
     * Resets per-run counters, resolves the iteration budget and creates the progress bar.
     *
     * @param arguments the parsed arguments
     * @return the time budget; it only matters when the run is not iteration-bound
     */
    private Duration prepareRun(Arguments arguments) {
        // Reset so that calling runBenchmark again on the same instance starts from zero.
        iterationCount = 0;
        maxIterations = arguments.getIteration();
        if (this instanceof MultithreadedBenchmark) {
            // Raise the budget to at least twice the thread count (presumably so every worker
            // gets work).
            // NOTE(review): if getIteration() returns -1 to request duration mode, this max()
            // silently switches multithreaded runs to iteration mode; confirm this is intended.
            maxIterations = Math.max(maxIterations, arguments.getThreads() * 2);
        }

        Duration duration = Duration.ofMinutes(arguments.getDuration());
        String benchmarkName = getClass().getSimpleName();
        if (runByIterations()) {
            logger.info(
                    "Running {} on: {}, iterations: {}.",
                    benchmarkName,
                    Device.defaultDevice(),
                    maxIterations);
            progressBar = new ProgressBar("Iteration", maxIterations);
        } else {
            logger.info(
                    "Running {} on: {}, duration: {} minutes.",
                    benchmarkName,
                    Device.defaultDevice(),
                    duration.toMinutes());
            progressBar = new ProgressBar("Iteration", duration.toMillis());
        }
        return duration;
    }

    /**
     * Loads the model, runs the prediction loop and waits for every submitted prediction.
     *
     * <p>The model is always closed, and {@link #clean()} always runs afterwards, even when
     * loading the model fails.
     *
     * @param arguments the parsed arguments
     * @param metrics the metrics collector for this run
     * @param duration the time budget used when the run is not iteration-bound
     * @return wall-clock milliseconds from just before model loading until all results are in
     * @throws Exception if loading, initialization or any prediction fails
     */
    private long executeRun(Arguments arguments, Metrics metrics, Duration duration)
            throws Exception {
        long begin = System.currentTimeMillis();
        try (ZooModel<I, O> model = loadModel(arguments, metrics, input, output)) {
            initialize(model, arguments, metrics);
            List<CompletableFuture<O>> futures = new ArrayList<>();
            while (keepPredicting(duration, begin)) {
                iterationCount++;
                futures.add(predict(model, arguments, metrics));
                updateProgress(progressBar, begin);
            }
            for (CompletableFuture<O> future : futures) {
                lastResult = future.get();
            }
            return System.currentTimeMillis() - begin;
        } finally {
            clean();
        }
    }

    /** Returns whether the run is bounded by an iteration count rather than by a duration. */
    private boolean runByIterations() {
        return maxIterations != -1;
    }

    /** Returns whether another prediction should be submitted. */
    private boolean keepPredicting(Duration duration, long begin) {
        if (runByIterations()) {
            return iterationCount < maxIterations;
        }
        return System.currentTimeMillis() - begin < duration.toMillis();
    }

    /** Advances the progress bar by iteration count or by elapsed milliseconds. */
    private void updateProgress(ProgressBar bar, long begin) {
        if (runByIterations()) {
            bar.update(iterationCount);
        } else {
            bar.update(System.currentTimeMillis() - begin);
        }
    }

    /**
     * Logs the run summary, latency percentiles and (optionally) memory percentiles, then dumps
     * memory info to the output directory.
     *
     * <p>Percentiles are only reported when an {@code "Inference"} metric exists and the
     * iteration budget exceeds one. Memory P90s are only reported when the {@code
     * collect-memory} system property is {@code true}.
     */
    private void recordResults(Arguments arguments, Metrics metrics, long elapsedMillis) {
        logger.info("Last inference result: {}", lastResult);
        logger.info(
                String.format(
                        "total time: %d ms, total runs: %d iterations",
                        elapsedMillis, iterationCount));
        if (metrics.hasMetric("LoadModel")) {
            Metric loadMetric = metrics.getMetric("LoadModel").get(0);
            logger.info(
                    "Model loading time: {} ms.",
                    String.format("%.3f", nanosToMillis(loadMetric)));
        }
        if (!metrics.hasMetric("Inference") || maxIterations <= 1) {
            return;
        }

        logLatencyPercentiles(metrics, "Inference", "inference");
        logLatencyPercentiles(metrics, "Preprocess", "preprocess");
        logLatencyPercentiles(metrics, "Postprocess", "postprocess");

        if (Boolean.getBoolean("collect-memory")) {
            logP90(metrics, "Heap", "heap");
            logP90(metrics, "NonHeap", "nonHeap");
            logP90(metrics, "cpu", "cpu");
            logP90(metrics, "rss", "rss");
        }
        MemoryTrainingListener.dumpMemoryInfo(metrics, arguments.getOutputDir());
    }

    /** Logs P50/P90/P99 (in ms) of a nanosecond metric, skipping it if it was never recorded. */
    private static void logLatencyPercentiles(Metrics metrics, String metricName, String label) {
        if (!metrics.hasMetric(metricName)) {
            return;
        }
        float p50 = nanosToMillis(metrics.percentile(metricName, 50));
        float p90 = nanosToMillis(metrics.percentile(metricName, 90));
        float p99 = nanosToMillis(metrics.percentile(metricName, 99));
        logger.info(
                String.format(
                        "%s P50: %.3f ms, P90: %.3f ms, P99: %.3f ms", label, p50, p90, p99));
    }

    /** Logs the raw P90 value of a metric, skipping it if it was never recorded. */
    private static void logP90(Metrics metrics, String metricName, String label) {
        if (!metrics.hasMetric(metricName)) {
            return;
        }
        float p90 = metrics.percentile(metricName, 90).getValue().longValue();
        logger.info(String.format("%s P90: %.3f", label, p90));
    }

    /** Converts a metric recorded in nanoseconds into fractional milliseconds. */
    private static float nanosToMillis(Metric metric) {
        return metric.getValue().longValue() / 1_000_000f;
    }
}
