/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.mxnet.zoo.cv.poseestimation;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.modality.cv.Joints;
import ai.djl.modality.cv.SimplePoseTranslator;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
import ai.djl.repository.zoo.BaseModelLoader;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Transform;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.Progress;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class SimplePoseModelLoader
extends BaseModelLoader<BufferedImage, Joints> {
    private static final Application APPLICATION = Application.CV.POSE_ESTIMATION;
    private static final String GROUP_ID = "ai.djl.mxnet";
    private static final String ARTIFACT_ID = "simple_pose";
    private static final String VERSION = "0.0.1";

    public SimplePoseModelLoader(Repository repository) {
        super(repository, MRL.model((Application)APPLICATION, (String)GROUP_ID, (String)ARTIFACT_ID), VERSION);
        ConcurrentHashMap<Class<Joints>, FactoryImpl> map = new ConcurrentHashMap<Class<Joints>, FactoryImpl>();
        map.put(Joints.class, new FactoryImpl());
        this.factories.put(BufferedImage.class, map);
    }

    public Application getApplication() {
        return APPLICATION;
    }

    public ZooModel<BufferedImage, Joints> loadModel(Map<String, String> filters, Device device, Progress progress) throws IOException, ModelNotFoundException, MalformedModelException {
        Criteria criteria = Criteria.builder().setTypes(BufferedImage.class, Joints.class).optFilters(filters).optDevice(device).optProgress(progress).build();
        return this.loadModel(criteria);
    }

    private static final class FactoryImpl
    implements TranslatorFactory<BufferedImage, Joints> {
        private FactoryImpl() {
        }

        public Translator<BufferedImage, Joints> newInstance(Map<String, Object> arguments) {
            int width = ((Double)arguments.getOrDefault("width", 192.0)).intValue();
            int height = ((Double)arguments.getOrDefault("height", 256.0)).intValue();
            double threshold = (Double)arguments.getOrDefault("threshold", 0.2);
            Pipeline pipeline = new Pipeline();
            pipeline.add((Transform)new Resize(width, height)).add((Transform)new ToTensor()).add((Transform)new Normalize(new float[]{0.485f, 0.456f, 0.406f}, new float[]{0.229f, 0.224f, 0.225f}));
            return ((SimplePoseTranslator.Builder)SimplePoseTranslator.builder().setPipeline(pipeline)).optThreshold((float)threshold).build();
        }
    }
}

