package ai.djl.huggingface.translator;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.VisionLanguageInput;
import ai.djl.modality.cv.translator.BaseImagePreProcessor;
import ai.djl.modality.cv.translator.BaseImageTranslator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.TranslatorContext;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Map;

/* loaded from: input_file:ai/djl/huggingface/translator/ZeroShotImageClassificationTranslator.class */
/**
 * Translator for zero-shot image classification.
 *
 * <p>Pre-processing expands every candidate label through the hypothesis template, tokenizes the
 * resulting sentences as one batch, and appends the pre-processed image (with a leading axis
 * added) to the model input. Post-processing reads the {@code logits_per_image} output and turns
 * it into a softmax distribution over the candidate labels.
 */
public class ZeroShotImageClassificationTranslator implements NoBatchifyTranslator<VisionLanguageInput, Classifications> {
    /** Attachment key used to carry the candidate labels from pre- to post-processing. */
    private static final String CANDIDATES_KEY = "candidates";

    /** Marker inside the hypothesis template that is replaced by the candidate label. */
    private static final String PLACEHOLDER = "{}";

    private final HuggingFaceTokenizer tokenizer;
    private final BaseImageTranslator<?> imageProcessor;
    private final boolean int32;

    /** Builder for {@link ZeroShotImageClassificationTranslator}. */
    public static final class Builder extends BaseImageTranslator.BaseBuilder<Builder> {
        private final HuggingFaceTokenizer tokenizer;
        private boolean int32;

        /**
         * Creates a builder bound to the given tokenizer.
         *
         * @param huggingFaceTokenizer the tokenizer used to encode the candidate sentences
         */
        Builder(HuggingFaceTokenizer huggingFaceTokenizer) {
            this.tokenizer = huggingFaceTokenizer;
        }

        /**
         * Returns this builder.
         *
         * <p>The decompiler had renamed this method to {@code m190self}; the original name and
         * {@code protected} visibility are restored here so it matches the method it was merged
         * with.
         *
         * @return this builder
         */
        protected Builder self() {
            return this;
        }

        /**
         * Decompiler-generated alias of {@link #self()}, kept so existing references still
         * resolve.
         *
         * @return this builder
         * @deprecated use {@link #self()}
         */
        @Deprecated
        public Builder m190self() {
            return self();
        }

        /**
         * Sets the flag that is forwarded to {@code Encoding.toNDList} when the token encodings
         * are converted to arrays.
         *
         * @param z the int32 flag
         * @return this builder
         */
        public Builder optInt32(boolean z) {
            this.int32 = z;
            return this;
        }

        /**
         * Configures the builder from an argument map: applies the image pre-processing
         * arguments and reads the {@code "int32"} entry.
         *
         * @param map the arguments
         */
        public void configure(Map<String, ?> map) {
            configPreProcess(map);
            optInt32(ArgumentsUtil.booleanValue(map, "int32"));
        }

        /**
         * Builds the translator, wrapping this builder's image settings in a
         * {@link BaseImagePreProcessor}.
         *
         * @return the new translator
         * @throws IOException declared by the original signature
         */
        public ZeroShotImageClassificationTranslator build() throws IOException {
            return new ZeroShotImageClassificationTranslator(this.tokenizer, new BaseImagePreProcessor(this), this.int32);
        }
    }

    /**
     * Creates the translator.
     *
     * @param huggingFaceTokenizer tokenizer used for the candidate sentences
     * @param baseImageTranslator image pre-processor
     * @param z int32 flag forwarded to {@code Encoding.toNDList}
     */
    ZeroShotImageClassificationTranslator(HuggingFaceTokenizer huggingFaceTokenizer, BaseImageTranslator<?> baseImageTranslator, boolean z) {
        this.tokenizer = huggingFaceTokenizer;
        this.imageProcessor = baseImageTranslator;
        this.int32 = z;
    }

    /**
     * Builds the model input from the candidate labels and the image.
     *
     * @param translatorContext the translator context; the candidates are stored on it as an
     *     attachment for use during post-processing
     * @param visionLanguageInput the image, candidate labels and hypothesis template
     * @return the token arrays followed by the image array
     * @throws TranslateException if no candidates are supplied
     */
    public NDList processInput(TranslatorContext translatorContext, VisionLanguageInput visionLanguageInput) throws TranslateException {
        String[] candidates = visionLanguageInput.getCandidates();
        if (candidates == null || candidates.length == 0) {
            throw new TranslateException("Missing candidates in input");
        }

        String template = visionLanguageInput.getHypothesisTemplate();
        ArrayList<String> prompts = new ArrayList<>(candidates.length);
        for (String candidate : candidates) {
            prompts.add(applyTemplate(template, candidate));
        }

        NDManager manager = translatorContext.getNDManager();
        NDList inputs = Encoding.toNDList(this.tokenizer.batchEncode(prompts), manager, false, this.int32);

        // Only the first array produced by the image pre-processor is used; a leading axis is added.
        NDArray imageArray = this.imageProcessor.processInput(translatorContext, visionLanguageInput.getImage()).get(0);
        inputs.add(imageArray.expandDims(0));

        translatorContext.setAttachment(CANDIDATES_KEY, candidates);
        return inputs;
    }

    /**
     * Converts the model output into classifications over the candidate labels.
     *
     * <p>The decompiler had renamed this method to {@code m189processOutput}; the original name
     * is restored here.
     *
     * @param translatorContext the translator context holding the candidates attachment
     * @param nDList the model output, expected to contain {@code logits_per_image}
     * @return the candidates paired with their softmax probabilities
     * @throws TranslateException if the output has no {@code logits_per_image} entry
     */
    public Classifications processOutput(TranslatorContext translatorContext, NDList nDList) throws TranslateException {
        NDArray logits = nDList.get("logits_per_image");
        if (logits == null) {
            throw new TranslateException("Missing logits_per_image in model output");
        }
        // NOTE(review): squeeze() presumably drops every size-1 axis, so a single candidate may
        // leave a scalar here - confirm softmax(0) copes with that case.
        NDArray probabilities = logits.squeeze().softmax(0);
        String[] candidates = (String[]) translatorContext.getAttachment(CANDIDATES_KEY);
        return new Classifications(Arrays.asList(candidates), probabilities, candidates.length);
    }

    /**
     * Decompiler-generated alias of {@link #processOutput(TranslatorContext, NDList)}, kept so
     * existing references still resolve.
     *
     * @param translatorContext the translator context holding the candidates attachment
     * @param nDList the model output
     * @return the candidates paired with their softmax probabilities
     * @throws TranslateException if the output has no {@code logits_per_image} entry
     * @deprecated use {@link #processOutput(TranslatorContext, NDList)}
     */
    @Deprecated
    public Classifications m189processOutput(TranslatorContext translatorContext, NDList nDList) throws TranslateException {
        return processOutput(translatorContext, nDList);
    }

    /**
     * Inserts a candidate label into the hypothesis template.
     *
     * <p>The first <code>{}</code> in the template is replaced by the candidate. A template
     * without the marker gets the candidate appended, and a {@code null} template yields the
     * candidate unchanged.
     *
     * @param template the hypothesis template, may be {@code null}
     * @param candidate the candidate label
     * @return the sentence to tokenize
     */
    private static String applyTemplate(String template, String candidate) {
        if (template == null) {
            return candidate;
        }
        int markerAt = template.indexOf(PLACEHOLDER);
        if (markerAt < 0) {
            return template + candidate;
        }
        return template.substring(0, markerAt) + candidate + template.substring(markerAt + PLACEHOLDER.length());
    }

    /**
     * Creates a builder for the given tokenizer.
     *
     * @param huggingFaceTokenizer the tokenizer
     * @return a new builder
     */
    public static Builder builder(HuggingFaceTokenizer huggingFaceTokenizer) {
        return new Builder(huggingFaceTokenizer);
    }

    /**
     * Creates a builder for the given tokenizer and configures it from an argument map.
     *
     * @param huggingFaceTokenizer the tokenizer
     * @param map the arguments passed to {@link Builder#configure(Map)}
     * @return a configured builder
     */
    public static Builder builder(HuggingFaceTokenizer huggingFaceTokenizer, Map<String, ?> map) {
        Builder configured = builder(huggingFaceTokenizer);
        configured.configure(map);
        return configured;
    }
}
