package ai.djl.huggingface.translator;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.modality.cv.VisionLanguageInput;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.translator.BaseImagePreProcessor;
import ai.djl.modality.cv.translator.BaseImageTranslator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.TranslatorContext;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:ai/djl/huggingface/translator/ZeroShotObjectDetectionTranslator.class */
public class ZeroShotObjectDetectionTranslator implements NoBatchifyTranslator<VisionLanguageInput, DetectedObjects> {
    /** Tokenizer used to encode the candidate labels supplied with each input. */
    private final HuggingFaceTokenizer tokenizer;

    /** Image pre-processor; its first output array is appended to the model inputs. */
    private final BaseImageTranslator<?> imageProcessor;

    /** Flag forwarded to {@code Encoding.toNDList} when the encoded candidates are converted. */
    private final boolean int32;

    /** Detections are kept only when their sigmoid score is strictly greater than this value. */
    private final float threshold;

    /* loaded from: input_file:ai/djl/huggingface/translator/ZeroShotObjectDetectionTranslator$Builder.class */
    public static final class Builder extends BaseImageTranslator.BaseBuilder<Builder> {
        private HuggingFaceTokenizer tokenizer;
        private boolean int32;
        private float threshold = 0.2f;

        Builder(HuggingFaceTokenizer huggingFaceTokenizer) {
            this.tokenizer = huggingFaceTokenizer;
        }

        /* JADX INFO: Access modifiers changed from: protected */
        /* renamed from: self, reason: merged with bridge method [inline-methods] */
        public Builder m193self() {
            return this;
        }

        public Builder optThreshold(float f) {
            this.threshold = f;
            return this;
        }

        public Builder optInt32(boolean z) {
            this.int32 = z;
            return this;
        }

        public void configure(Map<String, ?> map) {
            configPreProcess(map);
            optInt32(ArgumentsUtil.booleanValue(map, "int32"));
            optThreshold(ArgumentsUtil.floatValue(map, "threshold", 0.2f));
        }

        public ZeroShotObjectDetectionTranslator build() throws IOException {
            return new ZeroShotObjectDetectionTranslator(this.tokenizer, new BaseImagePreProcessor(this), this.int32, this.threshold);
        }
    }

    /**
     * Creates a translator from already-constructed collaborators.
     *
     * @param huggingFaceTokenizer tokenizer stored for encoding candidate labels
     * @param baseImageTranslator image pre-processor stored for preparing the input image
     * @param z value stored as the {@code int32} flag
     * @param f value stored as the score threshold
     */
    ZeroShotObjectDetectionTranslator(
            HuggingFaceTokenizer huggingFaceTokenizer,
            BaseImageTranslator<?> baseImageTranslator,
            boolean z,
            float f) {
        threshold = f;
        int32 = z;
        imageProcessor = baseImageTranslator;
        tokenizer = huggingFaceTokenizer;
    }

    /**
     * Builds the model inputs from the candidate labels and the image.
     *
     * <p>The candidates are batch-encoded and converted to arrays; the first array produced by the
     * image pre-processor is given a leading dimension of size one and appended after them. The
     * candidates are stored on the context under {@code "candidates"} for use when the output is
     * processed.
     *
     * @param translatorContext the context supplying the {@link NDManager} and attachments
     * @param visionLanguageInput the image together with its candidate labels
     * @return the encoded candidates followed by the image array
     * @throws TranslateException if the input has no candidates
     */
    @Override
    public NDList processInput(TranslatorContext translatorContext, VisionLanguageInput visionLanguageInput) throws TranslateException {
        NDManager manager = translatorContext.getNDManager();
        String[] candidates = visionLanguageInput.getCandidates();
        if (candidates == null || candidates.length == 0) {
            throw new TranslateException("Missing candidates in input");
        }

        Encoding[] encodings = tokenizer.batchEncode(candidates);
        NDList inputs = Encoding.toNDList(encodings, manager, false, int32);

        NDList imageFeatures = imageProcessor.processInput(translatorContext, visionLanguageInput.getImage());
        inputs.add(imageFeatures.get(0).expandDims(0));

        translatorContext.setAttachment("candidates", candidates);
        return inputs;
    }

    /**
     * Converts the raw model outputs into labelled detections.
     *
     * <p>For each prediction the best-scoring candidate is chosen along the last axis of
     * {@code "logits"}, its logit is passed through a sigmoid, and the prediction is kept only if
     * that score is strictly greater than the configured threshold. Labels are looked up in the
     * {@code "candidates"} attachment; boxes come from {@code "pred_boxes"}.
     *
     * @param translatorContext the context holding the {@code "candidates"}, {@code "width"} and
     *     {@code "height"} attachments
     * @param nDList the model outputs; must contain {@code "logits"} and {@code "pred_boxes"}
     * @return the detections that passed the threshold
     * @throws TranslateException if a required output array is missing
     */
    @Override
    public DetectedObjects processOutput(TranslatorContext translatorContext, NDList nDList) throws TranslateException {
        NDArray logits = nDList.get("logits");
        NDArray boxes = nDList.get("pred_boxes");
        if (logits == null || boxes == null) {
            // Fail with a readable message rather than a bare NullPointerException below.
            throw new TranslateException("Missing \"logits\" or \"pred_boxes\" in model output");
        }

        NDArray labels = logits.argMax(-1);
        NDArray scores = logits.max(new int[] {-1}).getNDArrayInternal().sigmoid();
        NDArray selected = scores.gt(threshold);

        float[] keptScores = scores.get(selected).toFloatArray();
        long[] keptLabels = labels.get(selected).toLongArray();
        float[] keptBoxes = boxes.get(selected).toFloatArray();

        String[] candidates = (String[]) translatorContext.getAttachment("candidates");
        int width = (Integer) translatorContext.getAttachment("width");
        int height = (Integer) translatorContext.getAttachment("height");

        int count = keptLabels.length;
        List<String> classNames = new ArrayList<>(count);
        List<Double> probabilities = new ArrayList<>(count);
        List<BoundingBox> boundingBoxes = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            classNames.add(candidates[(int) keptLabels[i]]);
            probabilities.add((double) keptScores[i]);
            boundingBoxes.add(toRectangle(keptBoxes, i * 4, width, height));
        }
        return new DetectedObjects(classNames, probabilities, boundingBoxes);
    }

    /**
     * Decompiler-generated name for {@link #processOutput}, retained so existing callers keep
     * working.
     *
     * @param translatorContext see {@link #processOutput}
     * @param nDList see {@link #processOutput}
     * @return see {@link #processOutput}
     * @throws TranslateException see {@link #processOutput}
     * @deprecated use {@link #processOutput}
     */
    @Deprecated
    public DetectedObjects m192processOutput(TranslatorContext translatorContext, NDList nDList) throws TranslateException {
        return processOutput(translatorContext, nDList);
    }

    /**
     * Turns four consecutive floats, read as (center-x, center-y, width, height), into a
     * rectangle anchored at its top-left corner.
     *
     * <p>The coordinates along the shorter image side are stretched by the long/short side ratio.
     * NOTE(review): presumably this compensates for the image being padded to a square before
     * inference; confirm against the pre-processing pipeline.
     */
    private static Rectangle toRectangle(float[] boxes, int offset, int width, int height) {
        float w = boxes[offset + 2];
        float h = boxes[offset + 3];
        float x = boxes[offset] - w / 2.0f;
        float y = boxes[offset + 1] - h / 2.0f;
        if (width > height) {
            y = y * width / height;
            h = h * width / height;
        } else if (width < height) {
            x = x * height / width;
            w = w * height / width;
        }
        return new Rectangle(x, y, w, h);
    }

    /**
     * Creates a builder with default settings for the given tokenizer.
     *
     * @param huggingFaceTokenizer tokenizer the resulting translator will use
     * @return a new builder
     */
    public static Builder builder(HuggingFaceTokenizer huggingFaceTokenizer) {
        Builder created = new Builder(huggingFaceTokenizer);
        return created;
    }

    /**
     * Creates a builder for the given tokenizer and applies the supplied arguments to it.
     *
     * @param huggingFaceTokenizer tokenizer the resulting translator will use
     * @param map arguments passed to {@link Builder#configure}
     * @return a new, configured builder
     */
    public static Builder builder(HuggingFaceTokenizer huggingFaceTokenizer, Map<String, ?> map) {
        Builder configured = new Builder(huggingFaceTokenizer);
        configured.configure(map);
        return configured;
    }
}
