package ai.djl.modality.cv.translator;

import ai.djl.modality.cv.output.Joints;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.translator.BaseImageTranslator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.TranslatorContext;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * A {@link BaseImageTranslator} that decodes the output of a YOLO pose-estimation model into an
 * array of {@link Joints}, one entry per kept detection.
 *
 * <p>Per candidate, the output holds four box values (center/size), one score, and a run of
 * keypoint triplets {@code (x, y, confidence)}. NOTE(review): this layout is inferred from the
 * slicing in {@link #processOutput}; confirm it against the exported model.
 */
public class YoloPoseTranslator extends BaseImageTranslator<Joints[]> {

    /** Upper bound on the number of detections returned after non-maximum suppression. */
    private static final int MAX_DETECTION = 300;

    /** Number of keypoints decoded for every detection. */
    private static final int NUM_KEYPOINTS = 17;

    /** Number of values stored per keypoint: x, y and confidence. */
    private static final int VALUES_PER_KEYPOINT = 3;

    private final float threshold;
    private final float nmsThreshold;

    /**
     * Constructs a translator from the given builder.
     *
     * @param builder the builder holding pre/post-processing settings
     */
    public YoloPoseTranslator(Builder builder) {
        super(builder);
        this.threshold = builder.threshold;
        this.nmsThreshold = builder.nmsThreshold;
    }

    /**
     * Decodes the raw model output into pose detections.
     *
     * <p>The steps are:
     *
     * <ol>
     *   <li>Keep candidates whose score (row 4 of the raw output) exceeds {@code threshold}.
     *   <li>Convert their boxes from center/size to corner form.
     *   <li>Run non-maximum suppression with {@code nmsThreshold}.
     *   <li>Keep at most {@value #MAX_DETECTION} survivors.
     * </ol>
     *
     * <p>Keypoint x and y are divided by this translator's {@code width} and {@code height}. The
     * third value of each keypoint is passed through unchanged as its confidence.
     *
     * @param ctx the translator context
     * @param list the model output; must contain exactly one array
     * @return one {@link Joints} per kept detection, in the order returned by NMS
     */
    @Override
    public Joints[] processOutput(TranslatorContext ctx, NDList list) {
        // Assumes pred is laid out as (channels, candidates), so row 4 is the score of every
        // candidate. TODO confirm against the model export.
        NDArray pred = list.singletonOrThrow();
        NDArray candidates = pred.get(4).gt(threshold);

        // Transpose to one row per candidate so boxes, scores and keypoints are column slices.
        NDArray rows = pred.transpose();
        NDArray corners = xywh2xyxy(rows.get("..., :4"));
        NDArray kept = corners.concat(rows.get("..., 4:"), -1).get(candidates);

        // Columns [0, 4) = box corners, [4, 5) = score, [5, end) = keypoint triplets.
        NDList split = kept.split(new long[] {4, 5}, 1);
        NDArray boxArray = split.get(0);
        NDArray keypointArray = split.get(2);
        int numBox = Math.toIntExact(boxArray.getShape().get(0));

        float[] boxValues = boxArray.toFloatArray();
        float[] scoreValues = split.get(1).toFloatArray();
        float[] keypointValues = keypointArray.toFloatArray();

        List<Rectangle> boxes = new ArrayList<>(numBox);
        List<Double> scores = new ArrayList<>(numBox);
        for (int i = 0; i < numBox; ++i) {
            int base = i * 4;
            float x = boxValues[base];
            float y = boxValues[base + 1];
            // Rectangle takes (x, y, width, height); convert back from the (x2, y2) corner.
            float w = boxValues[base + 2] - x;
            float h = boxValues[base + 3] - y;
            boxes.add(new Rectangle(x, y, w, h));
            scores.add((double) scoreValues[i]);
        }

        List<Integer> nms = Rectangle.nms(boxes, scores, nmsThreshold);
        if (nms.size() > MAX_DETECTION) {
            nms = nms.subList(0, MAX_DETECTION);
        }

        // The keypoint values are a flattened row-major matrix, so index by its real row width.
        int rowStride = Math.toIntExact(keypointArray.getShape().get(1));
        Joints[] ret = new Joints[nms.size()];
        for (int i = 0; i < ret.length; ++i) {
            int offset = nms.get(i) * rowStride;
            List<Joints.Joint> joints = new ArrayList<>(NUM_KEYPOINTS);
            for (int j = 0; j < NUM_KEYPOINTS; ++j) {
                int k = offset + j * VALUES_PER_KEYPOINT;
                joints.add(
                        new Joints.Joint(
                                keypointValues[k] / width,
                                keypointValues[k + 1] / height,
                                keypointValues[k + 2]));
            }
            ret[i] = new Joints(joints);
        }
        return ret;
    }

    /**
     * Converts boxes from center/size form {@code (cx, cy, w, h)} to corner form
     * {@code (x1, y1, x2, y2)} along the last axis.
     *
     * @param array boxes whose last axis holds {@code (cx, cy, w, h)}
     * @return boxes whose last axis holds {@code (x1, y1, x2, y2)}
     */
    private NDArray xywh2xyxy(NDArray array) {
        NDArray center = array.get("..., :2");
        NDArray halfSize = array.get("..., 2:").div(2);
        return center.sub(halfSize).concat(center.add(halfSize), -1);
    }

    /**
     * Creates a builder with default settings.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Creates a builder configured from the given arguments.
     *
     * @param arguments pre- and post-processing arguments; post-processing reads
     *     {@code "threshold"} and {@code "nmsThreshold"}
     * @return a new, configured builder
     */
    public static Builder builder(Map<String, ?> arguments) {
        Builder builder = new Builder();
        builder.configPreProcess(arguments);
        builder.configPostProcess(arguments);
        return builder;
    }

    /** The builder for {@link YoloPoseTranslator}. */
    public static class Builder extends BaseImageTranslator.BaseBuilder<Builder> {

        float threshold = 0.25f;
        float nmsThreshold = 0.7f;

        Builder() {}

        /**
         * Sets the score a candidate must exceed to be kept (default {@code 0.25}).
         *
         * @param threshold the score threshold
         * @return this builder
         */
        public Builder optThreshold(float threshold) {
            this.threshold = threshold;
            return self();
        }

        /**
         * Sets the non-maximum suppression threshold (default {@code 0.7}).
         *
         * @param nmsThreshold the NMS threshold passed to {@link Rectangle#nms}
         * @return this builder
         */
        public Builder optNmsThreshold(float nmsThreshold) {
            this.nmsThreshold = nmsThreshold;
            return self();
        }

        /** {@inheritDoc} */
        @Override
        public Builder self() {
            return this;
        }

        /** {@inheritDoc} */
        @Override
        protected void configPostProcess(Map<String, ?> arguments) {
            optThreshold(ArgumentsUtil.floatValue(arguments, "threshold", threshold));
            optNmsThreshold(ArgumentsUtil.floatValue(arguments, "nmsThreshold", nmsThreshold));
        }

        /**
         * Validates the settings and builds the translator.
         *
         * @return a new {@link YoloPoseTranslator}
         */
        public YoloPoseTranslator build() {
            validate();
            return new YoloPoseTranslator(this);
        }
    }
}
