package ai.djl.huggingface.translator;

import ai.djl.ModelException;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.inference.Predictor;
import ai.djl.modality.nlp.translator.ZeroShotClassificationInput;
import ai.djl.modality.nlp.translator.ZeroShotClassificationOutput;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.Batchifier;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.TranslatorContext;
import java.io.IOException;
import java.util.Map;
import java.util.UUID;

/* loaded from: input_file:ai/djl/huggingface/translator/ZeroShotClassificationTranslator.class */
/**
 * Translator that performs zero-shot text classification with a sequence-pair model.
 *
 * <p>For every candidate label the translator builds a hypothesis sentence from the input's
 * hypothesis template, encodes the (text, hypothesis) pair with the tokenizer, runs the model once
 * through an internal pass-through predictor and collects the first output of each run. The
 * collected outputs are then converted into one score per candidate and returned sorted from the
 * highest to the lowest score.
 *
 * <p>{@link #processInput} only stashes the input and returns an empty {@link NDList}; all model
 * invocations happen inside {@link #processOutput}. {@link #prepare} must have been called first so
 * that the internal predictor exists.
 */
public class ZeroShotClassificationTranslator implements NoBatchifyTranslator<ZeroShotClassificationInput, ZeroShotClassificationOutput> {
    /** Attachment key used to hand the input from {@code processInput} to {@code processOutput}. */
    private static final String INPUT_KEY = "input";

    /** Marker in the hypothesis template that is replaced by the candidate label. */
    private static final String PLACEHOLDER = "{}";

    private final HuggingFaceTokenizer tokenizer;
    private final boolean int32;
    /** Pass-through predictor created in {@link #prepare}; {@code null} until then. */
    private Predictor<NDList, NDList> predictor;

    /** Builder for {@link ZeroShotClassificationTranslator}. */
    public static final class Builder {
        private final HuggingFaceTokenizer tokenizer;
        private boolean int32;

        Builder(HuggingFaceTokenizer tokenizer) {
            this.tokenizer = tokenizer;
        }

        /**
         * Sets the {@code int32} flag that is forwarded to the tokenizer encoding when it is
         * converted to an {@link NDList}.
         *
         * @param int32 the flag value
         * @return this builder
         */
        public Builder optInt32(boolean int32) {
            this.int32 = int32;
            return this;
        }

        /**
         * Configures this builder from an argument map. Only the {@code "int32"} key is read.
         *
         * @param arguments the argument map
         */
        public void configure(Map<String, ?> arguments) {
            optInt32(ArgumentsUtil.booleanValue(arguments, "int32"));
        }

        /**
         * Builds the translator from the current builder state.
         *
         * @return a new {@link ZeroShotClassificationTranslator}
         */
        public ZeroShotClassificationTranslator build() {
            return new ZeroShotClassificationTranslator(this.tokenizer, this.int32);
        }
    }

    ZeroShotClassificationTranslator(HuggingFaceTokenizer tokenizer, boolean int32) {
        this.tokenizer = tokenizer;
        this.int32 = int32;
    }

    /**
     * Creates the internal pass-through predictor from the context's model and attaches it to the
     * predictor manager under a random id.
     *
     * @param translatorContext the translator context
     * @throws IOException declared by the original signature
     * @throws ModelException declared by the original signature
     */
    public void prepare(TranslatorContext translatorContext) throws IOException, ModelException {
        this.predictor = translatorContext.getModel().newPredictor(new NoopTranslator((Batchifier) null));
        // NOTE(review): attaching presumably ties the predictor's lifetime to the manager - confirm.
        translatorContext.getPredictorManager().attachInternal(UUID.randomUUID().toString(), this.predictor);
    }

    /**
     * Stores the input as a context attachment and returns an empty list; the real model inputs
     * are built per candidate in {@link #processOutput}.
     *
     * @param translatorContext the translator context
     * @param zeroShotClassificationInput the classification request
     * @return an empty {@link NDList}
     */
    public NDList processInput(TranslatorContext translatorContext, ZeroShotClassificationInput zeroShotClassificationInput) {
        translatorContext.setAttachment(INPUT_KEY, zeroShotClassificationInput);
        return new NDList();
    }

    /**
     * Runs the model once per candidate label and turns the results into ranked scores.
     *
     * @param translatorContext the translator context holding the attached input
     * @param nDList the output of the outer forward pass; not used by this method
     * @return the input text together with the candidate labels and their scores, ordered from
     *     the highest to the lowest score
     * @throws TranslateException if the input has no candidates or the inner prediction fails
     */
    public ZeroShotClassificationOutput processOutput(TranslatorContext translatorContext, NDList nDList) throws TranslateException {
        ZeroShotClassificationInput input = (ZeroShotClassificationInput) translatorContext.getAttachment(INPUT_KEY);
        String template = input.getHypothesisTemplate();
        String[] candidates = input.getCandidates();
        if (candidates == null || candidates.length == 0) {
            throw new TranslateException("Missing candidates in input");
        }

        NDManager manager = translatorContext.getNDManager();
        NDList perCandidate = new NDList(candidates.length);
        for (String candidate : candidates) {
            String hypothesis = applyTemplate(template, candidate);
            NDList encoded = this.tokenizer.encode(input.getText(), hypothesis).toNDList(manager, false, this.int32);
            // The predictor uses no batchifier, so wrap the single sample into a batch of one.
            NDList batch = Batchifier.STACK.batchify(new NDList[] {encoded});
            perCandidate.add(this.predictor.predict(batch).get(0));
        }
        NDArray logits = NDArrays.concat(perCandidate);

        NDArray scores;
        if (input.isMultiLabel()) {
            // NOTE(review): this normalises the last logit column across all candidates, so the
            // scores are not independent per label - confirm this is the intended behaviour.
            scores = logits.get(":, -1").softmax(-1);
        } else {
            // Keep columns 0 and 2 (presumably contradiction and entailment - confirm against the
            // model's label order), normalise per candidate and keep the second of the two.
            NDArray pair = logits.get(new NDIndex(":, {}", manager.create(new int[] {0, 2})));
            scores = pair.softmax(1).get(":, -1");
        }

        // Candidate indices ordered by score, highest first (ascending flag is false).
        long[] order = scores.argSort(-1, false).toLongArray();
        float[] values = scores.toFloatArray();
        String[] labels = new String[candidates.length];
        double[] probabilities = new double[candidates.length];
        for (int i = 0; i < labels.length; i++) {
            int index = (int) order[i];
            labels[i] = candidates[index];
            // The score must be read at the same sorted index as the label.
            probabilities[i] = values[index];
        }
        return new ZeroShotClassificationOutput(input.getText(), labels, probabilities);
    }

    /**
     * Alias of {@link #processOutput} kept so that code referring to the previous method name
     * continues to compile.
     *
     * @param translatorContext the translator context holding the attached input
     * @param nDList the output of the outer forward pass; not used
     * @return the classification result
     * @throws TranslateException if the input has no candidates or the inner prediction fails
     * @deprecated use {@link #processOutput(TranslatorContext, NDList)} instead
     */
    @Deprecated
    public ZeroShotClassificationOutput m187processOutput(TranslatorContext translatorContext, NDList nDList) throws TranslateException {
        return processOutput(translatorContext, nDList);
    }

    /**
     * Inserts the candidate into the template at the first {@code "{}"} placeholder, or appends it
     * to the template when no placeholder is present.
     */
    private static String applyTemplate(String template, String candidate) {
        int position = template.indexOf(PLACEHOLDER);
        if (position == -1) {
            return template + candidate;
        }
        return template.substring(0, position) + candidate + template.substring(position + PLACEHOLDER.length());
    }

    /**
     * Creates a builder for the given tokenizer.
     *
     * @param huggingFaceTokenizer the tokenizer used to encode text/hypothesis pairs
     * @return a new builder
     */
    public static Builder builder(HuggingFaceTokenizer huggingFaceTokenizer) {
        return new Builder(huggingFaceTokenizer);
    }

    /**
     * Creates a builder for the given tokenizer and configures it from an argument map.
     *
     * @param huggingFaceTokenizer the tokenizer used to encode text/hypothesis pairs
     * @param map the arguments passed to {@link Builder#configure(Map)}
     * @return a new, configured builder
     */
    public static Builder builder(HuggingFaceTokenizer huggingFaceTokenizer, Map<String, ?> map) {
        Builder builder = builder(huggingFaceTokenizer);
        builder.configure(map);
        return builder;
    }
}
