/*
 * Decompiled with CFR 0.152.
 */
package com.redis.om.spring.vectorize.face;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Landmark;
import ai.djl.modality.cv.output.Point;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class FaceDetectionTranslator
implements Translator<Image, DetectedObjects> {
    private double confThresh;
    private double nmsThresh;
    private int topK;
    private double[] variance;
    private int[][] scales;
    private int[] steps;
    private int width;
    private int height;

    public FaceDetectionTranslator(double confThresh, double nmsThresh, double[] variance, int topK, int[][] scales, int[] steps) {
        this.confThresh = confThresh;
        this.nmsThresh = nmsThresh;
        this.variance = variance;
        this.topK = topK;
        this.scales = scales;
        this.steps = steps;
    }

    public NDList processInput(TranslatorContext ctx, Image input) {
        this.width = input.getWidth();
        this.height = input.getHeight();
        NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.COLOR);
        if (!(array = array.transpose(new int[]{2, 0, 1}).flip(new int[]{0})).getDataType().equals((Object)DataType.FLOAT32)) {
            array = array.toType(DataType.FLOAT32, false);
        }
        NDArray mean = ctx.getNDManager().create(new float[]{104.0f, 117.0f, 123.0f}, new Shape(new long[]{3L, 1L, 1L}));
        array = array.sub(mean);
        return new NDList(new NDArray[]{array});
    }

    public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
        NDManager manager = ctx.getNDManager();
        double scaleXY = this.variance[0];
        double scaleWH = this.variance[1];
        NDArray prob = ((NDArray)list.get(1)).get(":, 1:", new Object[0]);
        prob = NDArrays.stack((NDList)new NDList(new NDArray[]{prob.argMax(1).toType(DataType.FLOAT32, false), prob.max(new int[]{1})}));
        NDArray boxRecover = this.boxRecover(manager, this.width, this.height, this.scales, this.steps);
        NDArray boundingBoxes = (NDArray)list.get(0);
        NDArray bbWH = boundingBoxes.get(":, 2:", new Object[0]).mul((Number)scaleWH).exp().mul(boxRecover.get(":, 2:", new Object[0]));
        NDArray bbXY = boundingBoxes.get(":, :2", new Object[0]).mul((Number)scaleXY).mul(boxRecover.get(":, 2:", new Object[0])).add(boxRecover.get(":, :2", new Object[0])).sub(bbWH.mul((Number)Float.valueOf(0.5f)));
        boundingBoxes = NDArrays.concat((NDList)new NDList(new NDArray[]{bbXY, bbWH}), (int)1);
        NDArray landms = (NDArray)list.get(2);
        landms = this.decodeLandm(landms, boxRecover, scaleXY);
        NDArray cutOff = prob.get(new long[]{1L}).gt((Number)this.confThresh);
        boundingBoxes = boundingBoxes.transpose().booleanMask(cutOff, 1).transpose();
        landms = landms.transpose().booleanMask(cutOff, 1).transpose();
        prob = prob.booleanMask(cutOff, 1);
        long[] order = prob.get(new long[]{1L}).argSort().get(":" + this.topK, new Object[0]).toLongArray();
        prob = prob.transpose();
        ArrayList<String> retNames = new ArrayList<String>();
        ArrayList<Double> retProbs = new ArrayList<Double>();
        ArrayList<Landmark> retBB = new ArrayList<Landmark>();
        ConcurrentHashMap<Integer, List> recorder = new ConcurrentHashMap<Integer, List>();
        for (int i = order.length - 1; i >= 0; --i) {
            long currMaxLoc = order[i];
            float[] classProb = prob.get(new long[]{currMaxLoc}).toFloatArray();
            int classId = (int)classProb[0];
            double probability = classProb[1];
            double[] boxArr = boundingBoxes.get(new long[]{currMaxLoc}).toDoubleArray();
            double[] landmsArr = landms.get(new long[]{currMaxLoc}).toDoubleArray();
            Rectangle rect = new Rectangle(boxArr[0], boxArr[1], boxArr[2], boxArr[3]);
            List boxes = recorder.getOrDefault(classId, new ArrayList());
            boolean belowIoU = true;
            for (BoundingBox box : boxes) {
                if (!(box.getIoU((BoundingBox)rect) > this.nmsThresh)) continue;
                belowIoU = false;
                break;
            }
            if (!belowIoU) continue;
            ArrayList<Point> keyPoints = new ArrayList<Point>();
            for (int j = 0; j < 5; ++j) {
                double x = landmsArr[j * 2];
                double y = landmsArr[j * 2 + 1];
                keyPoints.add(new Point(x * (double)this.width, y * (double)this.height));
            }
            Landmark landmark = new Landmark(boxArr[0], boxArr[1], boxArr[2], boxArr[3], keyPoints);
            boxes.add(landmark);
            recorder.put(classId, boxes);
            String className = "Face";
            retNames.add(className);
            retProbs.add(probability);
            retBB.add(landmark);
        }
        return new DetectedObjects(retNames, retProbs, retBB);
    }

    private NDArray boxRecover(NDManager manager, int width, int height, int[][] scales, int[] steps) {
        int[][] aspectRatio = new int[steps.length][2];
        for (int i = 0; i < steps.length; ++i) {
            int wRatio = (int)Math.ceil((float)width / (float)steps[i]);
            int hRatio = (int)Math.ceil((float)height / (float)steps[i]);
            aspectRatio[i] = new int[]{hRatio, wRatio};
        }
        ArrayList<double[]> defaultBoxes = new ArrayList<double[]>();
        for (int idx = 0; idx < steps.length; ++idx) {
            int[] scale = scales[idx];
            for (int h = 0; h < aspectRatio[idx][0]; ++h) {
                for (int w = 0; w < aspectRatio[idx][1]; ++w) {
                    for (int i : scale) {
                        double skx = (double)i * 1.0 / (double)width;
                        double sky = (double)i * 1.0 / (double)height;
                        double cx = ((double)w + 0.5) * (double)steps[idx] / (double)width;
                        double cy = ((double)h + 0.5) * (double)steps[idx] / (double)height;
                        defaultBoxes.add(new double[]{cx, cy, skx, sky});
                    }
                }
            }
        }
        double[][] boxes = new double[defaultBoxes.size()][((double[])defaultBoxes.get(0)).length];
        for (int i = 0; i < defaultBoxes.size(); ++i) {
            boxes[i] = (double[])defaultBoxes.get(i);
        }
        return manager.create(boxes).clip((Number)0.0, (Number)1.0);
    }

    private NDArray decodeLandm(NDArray pre, NDArray priors, double scaleXY) {
        NDArray point1 = pre.get(":, :2", new Object[0]).mul((Number)scaleXY).mul(priors.get(":, 2:", new Object[0])).add(priors.get(":, :2", new Object[0]));
        NDArray point2 = pre.get(":, 2:4", new Object[0]).mul((Number)scaleXY).mul(priors.get(":, 2:", new Object[0])).add(priors.get(":, :2", new Object[0]));
        NDArray point3 = pre.get(":, 4:6", new Object[0]).mul((Number)scaleXY).mul(priors.get(":, 2:", new Object[0])).add(priors.get(":, :2", new Object[0]));
        NDArray point4 = pre.get(":, 6:8", new Object[0]).mul((Number)scaleXY).mul(priors.get(":, 2:", new Object[0])).add(priors.get(":, :2", new Object[0]));
        NDArray point5 = pre.get(":, 8:10", new Object[0]).mul((Number)scaleXY).mul(priors.get(":, 2:", new Object[0])).add(priors.get(":, :2", new Object[0]));
        return NDArrays.concat((NDList)new NDList(new NDArray[]{point1, point2, point3, point4, point5}), (int)1);
    }
}

