package com.redis.om.spring.vectorize.face;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Landmark;
import ai.djl.modality.cv.output.Point;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/* loaded from: input_file:com/redis/om/spring/vectorize/face/FaceDetectionTranslator.class */
public class FaceDetectionTranslator implements Translator<Image, DetectedObjects> {
    private double confThresh;
    private double nmsThresh;
    private int topK;
    private double[] variance;
    private int[][] scales;
    private int[] steps;
    private int width;
    private int height;

    public FaceDetectionTranslator(double d, double d2, double[] dArr, int i, int[][] iArr, int[] iArr2) {
        this.confThresh = d;
        this.nmsThresh = d2;
        this.variance = dArr;
        this.topK = i;
        this.scales = iArr;
        this.steps = iArr2;
    }

    public NDList processInput(TranslatorContext translatorContext, Image image) {
        this.width = image.getWidth();
        this.height = image.getHeight();
        NDArray flip = image.toNDArray(translatorContext.getNDManager(), Image.Flag.COLOR).transpose(new int[]{2, 0, 1}).flip(new int[]{0});
        if (!flip.getDataType().equals(DataType.FLOAT32)) {
            flip = flip.toType(DataType.FLOAT32, false);
        }
        return new NDList(new NDArray[]{flip.sub(translatorContext.getNDManager().create(new float[]{104.0f, 117.0f, 123.0f}, new Shape(new long[]{3, 1, 1})))});
    }

    /* renamed from: processOutput, reason: merged with bridge method [inline-methods] */
    public DetectedObjects m61processOutput(TranslatorContext translatorContext, NDList nDList) {
        NDManager nDManager = translatorContext.getNDManager();
        double d = this.variance[0];
        double d2 = this.variance[1];
        NDArray nDArray = ((NDArray) nDList.get(1)).get(":, 1:", new Object[0]);
        NDArray stack = NDArrays.stack(new NDList(new NDArray[]{nDArray.argMax(1).toType(DataType.FLOAT32, false), nDArray.max(new int[]{1})}));
        NDArray boxRecover = boxRecover(nDManager, this.width, this.height, this.scales, this.steps);
        NDArray nDArray2 = (NDArray) nDList.get(0);
        NDArray mul = nDArray2.get(":, 2:", new Object[0]).mul(Double.valueOf(d2)).exp().mul(boxRecover.get(":, 2:", new Object[0]));
        NDArray concat = NDArrays.concat(new NDList(new NDArray[]{nDArray2.get(":, :2", new Object[0]).mul(Double.valueOf(d)).mul(boxRecover.get(":, 2:", new Object[0])).add(boxRecover.get(":, :2", new Object[0])).sub(mul.mul(Float.valueOf(0.5f))), mul}), 1);
        NDArray decodeLandm = decodeLandm((NDArray) nDList.get(2), boxRecover, d);
        NDArray gt = stack.get(new long[]{1}).gt(Double.valueOf(this.confThresh));
        NDArray transpose = concat.transpose().booleanMask(gt, 1).transpose();
        NDArray transpose2 = decodeLandm.transpose().booleanMask(gt, 1).transpose();
        NDArray booleanMask = stack.booleanMask(gt, 1);
        long[] longArray = booleanMask.get(new long[]{1}).argSort().get(":" + this.topK, new Object[0]).toLongArray();
        NDArray transpose3 = booleanMask.transpose();
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        ArrayList arrayList3 = new ArrayList();
        ConcurrentHashMap concurrentHashMap = new ConcurrentHashMap();
        for (int length = longArray.length - 1; length >= 0; length--) {
            long j = longArray[length];
            float[] floatArray = transpose3.get(new long[]{j}).toFloatArray();
            int i = (int) floatArray[0];
            double d3 = floatArray[1];
            double[] doubleArray = transpose.get(new long[]{j}).toDoubleArray();
            double[] doubleArray2 = transpose2.get(new long[]{j}).toDoubleArray();
            Rectangle rectangle = new Rectangle(doubleArray[0], doubleArray[1], doubleArray[2], doubleArray[3]);
            List list = (List) concurrentHashMap.getOrDefault(Integer.valueOf(i), new ArrayList());
            boolean z = true;
            Iterator it = list.iterator();
            while (true) {
                if (!it.hasNext()) {
                    break;
                }
                if (((BoundingBox) it.next()).getIoU(rectangle) > this.nmsThresh) {
                    z = false;
                    break;
                }
            }
            if (z) {
                ArrayList arrayList4 = new ArrayList();
                for (int i2 = 0; i2 < 5; i2++) {
                    arrayList4.add(new Point(doubleArray2[i2 * 2] * this.width, doubleArray2[(i2 * 2) + 1] * this.height));
                }
                Landmark landmark = new Landmark(doubleArray[0], doubleArray[1], doubleArray[2], doubleArray[3], arrayList4);
                list.add(landmark);
                concurrentHashMap.put(Integer.valueOf(i), list);
                arrayList.add("Face");
                arrayList2.add(Double.valueOf(d3));
                arrayList3.add(landmark);
            }
        }
        return new DetectedObjects(arrayList, arrayList2, arrayList3);
    }

    private NDArray boxRecover(NDManager nDManager, int i, int i2, int[][] iArr, int[] iArr2) {
        int[][] iArr3 = new int[iArr2.length][2];
        for (int i3 = 0; i3 < iArr2.length; i3++) {
            int ceil = (int) Math.ceil(i / iArr2[i3]);
            int[] iArr4 = new int[2];
            iArr4[0] = (int) Math.ceil(i2 / iArr2[i3]);
            iArr4[1] = ceil;
            iArr3[i3] = iArr4;
        }
        ArrayList arrayList = new ArrayList();
        for (int i4 = 0; i4 < iArr2.length; i4++) {
            int[] iArr5 = iArr[i4];
            for (int i5 = 0; i5 < iArr3[i4][0]; i5++) {
                for (int i6 = 0; i6 < iArr3[i4][1]; i6++) {
                    for (int i7 : iArr5) {
                        arrayList.add(new double[]{((i6 + 0.5d) * iArr2[i4]) / i, ((i5 + 0.5d) * iArr2[i4]) / i2, (i7 * 1.0d) / i, (i7 * 1.0d) / i2});
                    }
                }
            }
        }
        double[][] dArr = new double[arrayList.size()][((double[]) arrayList.get(0)).length];
        for (int i8 = 0; i8 < arrayList.size(); i8++) {
            dArr[i8] = (double[]) arrayList.get(i8);
        }
        return nDManager.create(dArr).clip(Double.valueOf(0.0d), Double.valueOf(1.0d));
    }

    private NDArray decodeLandm(NDArray nDArray, NDArray nDArray2, double d) {
        return NDArrays.concat(new NDList(new NDArray[]{nDArray.get(":, :2", new Object[0]).mul(Double.valueOf(d)).mul(nDArray2.get(":, 2:", new Object[0])).add(nDArray2.get(":, :2", new Object[0])), nDArray.get(":, 2:4", new Object[0]).mul(Double.valueOf(d)).mul(nDArray2.get(":, 2:", new Object[0])).add(nDArray2.get(":, :2", new Object[0])), nDArray.get(":, 4:6", new Object[0]).mul(Double.valueOf(d)).mul(nDArray2.get(":, 2:", new Object[0])).add(nDArray2.get(":, :2", new Object[0])), nDArray.get(":, 6:8", new Object[0]).mul(Double.valueOf(d)).mul(nDArray2.get(":, 2:", new Object[0])).add(nDArray2.get(":, :2", new Object[0])), nDArray.get(":, 8:10", new Object[0]).mul(Double.valueOf(d)).mul(nDArray2.get(":, 2:", new Object[0])).add(nDArray2.get(":, :2", new Object[0]))}), 1);
    }
}
