package ai.djl.examples.inference.benchmark;

import ai.djl.ModelException;
import ai.djl.examples.inference.benchmark.util.AbstractBenchmark;
import ai.djl.examples.inference.benchmark.util.Arguments;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.listener.MemoryTrainingListener;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Benchmark that runs inference against a single loaded model from several threads at once.
 *
 * <p>The requested number of predictions is shared between all workers through one {@link
 * AtomicInteger}: every worker repeatedly claims a single request from it until none remain.
 */
public class MultithreadedBenchmark extends AbstractBenchmark {

    private static final Logger logger = LoggerFactory.getLogger(MultithreadedBenchmark.class);

    /**
     * Worker task that owns a private {@link Predictor} created from the shared model and runs
     * predictions until the shared request counter is exhausted.
     *
     * <p>NOTE(review): every worker is handed the same {@link Metrics} instance (and registers it
     * on its predictor); confirm that {@code Metrics} tolerates concurrent updates.
     */
    private static final class PredictorCallable implements Callable<Object> {

        private final Predictor<Object, Object> predictor;
        private final Object inputData;
        private final Metrics metrics;
        private final String workerId;
        private final boolean collectMemory;
        private final AtomicInteger counter;
        private final int total;
        private final int steps;

        /**
         * Creates a worker with its own predictor.
         *
         * @param model the model shared by all workers; a new predictor is created from it
         * @param inputData the input fed to every prediction
         * @param metrics metrics sink registered on the predictor
         * @param counter shared count of requests still to be claimed; its current value is
         *     taken as the total number of requests
         * @param workerIndex index of this worker, used (zero-padded) in log messages
         * @param collectMemory whether to record memory information after each prediction
         */
        @SuppressWarnings("unchecked") // input is an opaque Object for the benchmark
        PredictorCallable(
                ZooModel<?, ?> model,
                Object inputData,
                Metrics metrics,
                AtomicInteger counter,
                int workerIndex,
                boolean collectMemory) {
            this.predictor = (Predictor<Object, Object>) model.newPredictor();
            this.predictor.setMetrics(metrics);
            this.inputData = inputData;
            this.metrics = metrics;
            this.counter = counter;
            this.workerId = String.format("%02d", workerIndex);
            this.collectMemory = collectMemory;
            // All workers are constructed before any of them starts, so this is the full count.
            this.total = counter.get();
            this.steps = progressStep(total);
        }

        /**
         * Returns how many completed requests should pass between two progress log lines: 1 for
         * fewer than 10 requests, otherwise one tenth of the largest power of ten not exceeding
         * {@code total} (e.g. 10 for 150 requests).
         */
        private static int progressStep(int total) {
            if (total < 10) {
                return 1;
            }
            return (int) Math.pow(10, (int) Math.log10(total) - 1);
        }

        /**
         * Claims requests from the shared counter and runs one prediction per claim until all
         * requests have been claimed, then closes this worker's predictor.
         *
         * @return the result of this worker's last prediction, or {@code null} if it ran none
         * @throws TranslateException if a prediction fails
         */
        @Override
        public Object call() throws TranslateException {
            Object result = null;
            int completed = 0;
            try {
                int remaining;
                // Across all workers decrementAndGet() yields total-1 down to 0 for the `total`
                // successful claims, so ">= 0" runs exactly `total` predictions.
                while ((remaining = counter.decrementAndGet()) >= 0) {
                    result = predictor.predict(inputData);
                    if (collectMemory) {
                        MemoryTrainingListener.collectMemoryInfo(metrics);
                    }
                    int processed = total - remaining;
                    logger.trace("Worker-{}: {} iteration finished.", workerId, ++completed);
                    if (processed % steps == 0) {
                        logger.info("Completed {} requests", processed);
                    }
                }
            } finally {
                // The predictor is private to this worker, so it is released here.
                predictor.close();
            }
            logger.debug("Worker-{}: finished.", workerId);
            return result;
        }
    }

    /**
     * Command-line entry point; exits with status 0 if the benchmark succeeds and -1 otherwise.
     *
     * @param args benchmark command-line arguments
     */
    public static void main(String[] args) {
        boolean success = new MultithreadedBenchmark().runBenchmark(args);
        System.exit(success ? 0 : -1);
    }

    /**
     * Runs {@code iteration} predictions spread across {@link Arguments#getThreads()} worker
     * threads that share one loaded model.
     *
     * @param arguments benchmark arguments supplying the input data and the thread count
     * @param metrics metrics sink shared by all workers
     * @param iteration total number of predictions to run across all workers
     * @return the value returned by the highest-numbered worker that completed without error
     *     (that worker's last prediction result), or {@code null} if no worker succeeded
     */
    @Override
    public Object predict(Arguments arguments, Metrics metrics, int iteration)
            throws IOException, ModelException, ClassNotFoundException {
        Object inputData = arguments.getInputData();
        // NOTE(review): the model is not closed here; confirm who is expected to release it.
        ZooModel<?, ?> model = loadModel(arguments, metrics);

        int numOfThreads = arguments.getThreads();
        AtomicInteger counter = new AtomicInteger(iteration);
        logger.info("Multithreaded inference with {} threads.", numOfThreads);

        List<PredictorCallable> callables = new ArrayList<>(numOfThreads);
        for (int i = 0; i < numOfThreads; ++i) {
            // Only the first worker records memory information.
            callables.add(new PredictorCallable(model, inputData, metrics, counter, i, i == 0));
        }

        Object result = null;
        int successThreads = 0;
        ExecutorService executorService = Executors.newFixedThreadPool(numOfThreads);
        try {
            metrics.addMetric("mt_start", System.currentTimeMillis(), "mills");
            // invokeAll() blocks until every task is done and returns futures in task order.
            List<Future<Object>> futures = executorService.invokeAll(callables);
            for (Future<Object> future : futures) {
                try {
                    result = future.get();
                    ++successThreads;
                } catch (ExecutionException e) {
                    logger.error("", e);
                }
            }
        } catch (InterruptedException e) {
            // Restore the interrupt flag rather than swallowing it.
            Thread.currentThread().interrupt();
            logger.error("", e);
        } finally {
            executorService.shutdown();
        }

        if (successThreads != numOfThreads) {
            logger.error("Only {}/{} threads finished.", successThreads, numOfThreads);
        }
        return result;
    }
}
