package ai.djl.huggingface.translator;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.modality.nlp.qa.QAInput;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.JsonUtils;
import ai.djl.util.PairList;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * A {@link Translator} that extracts an answer to a question from a context paragraph using a
 * HuggingFace question-answering model.
 *
 * <p>The question and the paragraph are tokenized together as a sequence pair. The model output
 * is read as two arrays: per-token start scores at index 0 and per-token end scores at index 1
 * (assumed output layout — confirm against the model in use). The answer is the decoded token
 * span between the arg-max start and arg-max end positions.
 *
 * <p>When {@code detail} is enabled, both score arrays are normalized with a softmax and the
 * result is a JSON object with {@code score}, {@code start}, {@code end} and {@code answer}
 * fields. Otherwise the result is the bare, trimmed answer text.
 */
public class QuestionAnsweringTranslator implements Translator<QAInput, String> {

    /** Context attachment key for the {@link Encoding} of a single input. */
    private static final String ENCODING_KEY = "encoding";

    /** Context attachment key for the {@link Encoding} array of a batched input. */
    private static final String ENCODINGS_KEY = "encodings";

    private final HuggingFaceTokenizer tokenizer;
    private final boolean includeTokenTypes;
    private final boolean int32;
    private final Batchifier batchifier;
    private final boolean detail;

    /** The builder for {@link QuestionAnsweringTranslator}. */
    public static final class Builder {

        private HuggingFaceTokenizer tokenizer;
        private boolean includeTokenTypes;
        private boolean int32;
        private Batchifier batchifier = Batchifier.STACK;
        private boolean detail;

        Builder(HuggingFaceTokenizer tokenizer) {
            this.tokenizer = tokenizer;
        }

        /**
         * Sets whether token type ids are passed to the model.
         *
         * @param includeTokenTypes {@code true} to include token type ids in the model input
         * @return this builder
         */
        public Builder optIncludeTokenTypes(boolean includeTokenTypes) {
            this.includeTokenTypes = includeTokenTypes;
            return this;
        }

        /**
         * Sets whether the model input arrays are created as int32 rather than the default type.
         *
         * @param int32 {@code true} to use int32 input arrays
         * @return this builder
         */
        public Builder optInt32(boolean int32) {
            this.int32 = int32;
            return this;
        }

        /**
         * Sets the {@link Batchifier} used for batched inference (defaults to {@link
         * Batchifier#STACK}).
         *
         * @param batchifier the batchifier to use
         * @return this builder
         */
        public Builder optBatchifier(Batchifier batchifier) {
            this.batchifier = batchifier;
            return this;
        }

        /**
         * Sets whether to return a JSON object with score and span positions instead of the
         * bare answer text.
         *
         * @param detail {@code true} to return the detailed JSON result
         * @return this builder
         */
        public Builder optDetail(boolean detail) {
            this.detail = detail;
            return this;
        }

        /**
         * Configures this builder from a string-keyed argument map.
         *
         * <p>Recognized keys: {@code includeTokenTypes}, {@code int32}, {@code batchifier}
         * (defaults to {@code "stack"}) and {@code detail}.
         *
         * @param arguments the configuration arguments
         */
        public void configure(Map<String, ?> arguments) {
            optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes"));
            optInt32(ArgumentsUtil.booleanValue(arguments, "int32"));
            String batchifierName = ArgumentsUtil.stringValue(arguments, "batchifier", "stack");
            optDetail(ArgumentsUtil.booleanValue(arguments, "detail"));
            optBatchifier(Batchifier.fromString(batchifierName));
        }

        /**
         * Builds the translator from the current builder settings.
         *
         * @return a new {@link QuestionAnsweringTranslator}
         * @throws IOException declared for API compatibility
         */
        public QuestionAnsweringTranslator build() throws IOException {
            return new QuestionAnsweringTranslator(
                    tokenizer, includeTokenTypes, int32, batchifier, detail);
        }
    }

    QuestionAnsweringTranslator(
            HuggingFaceTokenizer tokenizer,
            boolean includeTokenTypes,
            boolean int32,
            Batchifier batchifier,
            boolean detail) {
        this.tokenizer = tokenizer;
        this.includeTokenTypes = includeTokenTypes;
        this.int32 = int32;
        this.batchifier = batchifier;
        this.detail = detail;
    }

    /** {@inheritDoc} */
    @Override
    public Batchifier getBatchifier() {
        return batchifier;
    }

    /**
     * Tokenizes the question/paragraph pair and converts it to model input.
     *
     * <p>The encoding is stored on the context so {@link #processOutput} can map token positions
     * back to ids.
     */
    @Override
    public NDList processInput(TranslatorContext ctx, QAInput input) {
        Encoding encoding = tokenizer.encode(input.getQuestion(), input.getParagraph());
        ctx.setAttachment(ENCODING_KEY, encoding);
        return encoding.toNDList(ctx.getNDManager(), includeTokenTypes, int32);
    }

    /**
     * Tokenizes a batch of question/paragraph pairs and batchifies the resulting model inputs.
     *
     * <p>The encodings are stored on the context for {@link #batchProcessOutput}.
     */
    @Override
    public NDList batchProcessInput(TranslatorContext ctx, List<QAInput> inputs) {
        NDManager manager = ctx.getNDManager();
        PairList<String, String> pairs = new PairList<>(inputs.size());
        for (QAInput input : inputs) {
            pairs.add(input.getQuestion(), input.getParagraph());
        }
        Encoding[] encodings = tokenizer.batchEncode(pairs);
        ctx.setAttachment(ENCODINGS_KEY, encodings);
        NDList[] batch = new NDList[encodings.length];
        for (int i = 0; i < encodings.length; i++) {
            batch[i] = encodings[i].toNDList(manager, includeTokenTypes, int32);
        }
        return batchifier.batchify(batch);
    }

    /** Decodes the model output for a single input into the answer (or detailed JSON). */
    @Override
    public String processOutput(TranslatorContext ctx, NDList list) {
        return decode(list, (Encoding) ctx.getAttachment(ENCODING_KEY));
    }

    /**
     * Alias of {@link #processOutput(TranslatorContext, NDList)} kept for source compatibility
     * with the decompiled method name.
     *
     * @param ctx the translator context
     * @param list the model output
     * @return the answer (or detailed JSON)
     * @deprecated use {@link #processOutput(TranslatorContext, NDList)}
     */
    @Deprecated
    public String m176processOutput(TranslatorContext ctx, NDList list) {
        return processOutput(ctx, list);
    }

    /**
     * Unbatchifies the model output and decodes one answer per input encoding.
     *
     * @throws IllegalStateException if the output holds fewer entries than there were inputs
     */
    @Override
    public List<String> batchProcessOutput(TranslatorContext ctx, NDList list) {
        NDList[] outputs = batchifier.unbatchify(list);
        Encoding[] encodings = (Encoding[]) ctx.getAttachment(ENCODINGS_KEY);
        if (outputs.length < encodings.length) {
            throw new IllegalStateException(
                    "Expected at least "
                            + encodings.length
                            + " batch outputs, but got "
                            + outputs.length);
        }
        List<String> answers = new ArrayList<>(encodings.length);
        for (int i = 0; i < encodings.length; i++) {
            answers.add(decode(outputs[i], encodings[i]));
        }
        return answers;
    }

    /**
     * Turns start/end score arrays into an answer span and decodes it.
     *
     * @param list model output; index 0 is read as start scores, index 1 as end scores
     * @param encoding the encoding of the question/paragraph pair that produced {@code list}
     * @return the trimmed answer text, or a JSON object with score/start/end/answer when {@code
     *     detail} is enabled
     */
    private String decode(NDList list, Encoding encoding) {
        NDArray startLogits = list.get(0);
        NDArray endLogits = list.get(1);
        if ("PyTorch".equals(startLogits.getManager().getEngine().getEngineName())) {
            // The arrays are modified in place below. On PyTorch they are copied first,
            // presumably because engine output tensors are read-only there (confirm).
            startLogits = startLogits.duplicate();
            endLogits = endLogits.duplicate();
        }
        if (detail) {
            // Mask every position of the first sequence of the pair (the question, as passed
            // to encode/batchEncode) so the answer can only come from the paragraph.
            long[] sequenceIds = encoding.getSequenceIds();
            List<Integer> masked = new ArrayList<>();
            for (int i = 0; i < sequenceIds.length; i++) {
                if (sequenceIds[i] == 0) {
                    masked.add(i);
                }
            }
            if (!masked.isEmpty()) {
                int[] positions = masked.stream().mapToInt(Integer::intValue).toArray();
                NDIndex index = new NDIndex("{}", list.getManager().create(positions));
                startLogits.set(index, -100000f);
                endLogits.set(index, -100000f);
            }
            startLogits = softmax(startLogits);
            endLogits = softmax(endLogits);
        }

        // Exclude position 0 from the arg-max (presumably the leading special token such as
        // [CLS] — confirm for the tokenizer in use).
        startLogits.set(new NDIndex(0), -100000);
        endLogits.set(new NDIndex(0), -100000);
        int startIdx = (int) startLogits.argMax().getLong();
        int endIdx = (int) endLogits.argMax().getLong();
        if (startIdx > endIdx) {
            // Order the span. Swap the arrays as well so each index keeps reading its own scores.
            int tmpIdx = startIdx;
            startIdx = endIdx;
            endIdx = tmpIdx;
            NDArray tmpArray = startLogits;
            startLogits = endLogits;
            endLogits = tmpArray;
        }

        int length = endIdx - startIdx + 1;
        long[] ids = new long[length];
        System.arraycopy(encoding.getIds(), startIdx, ids, 0, length);
        String answer = tokenizer.decode(ids).trim();
        if (!detail) {
            return answer;
        }

        // In detail mode both arrays are softmax-normalized, so this is a probability product.
        float score = startLogits.getFloat(startIdx) * endLogits.getFloat(endIdx);
        Map<String, Object> result = new ConcurrentHashMap<>();
        result.put("score", score);
        result.put("start", startIdx);
        result.put("end", endIdx);
        result.put("answer", answer);
        return JsonUtils.toJson(result);
    }

    /** Numerically stable softmax over all elements: exp(x - max(x)) / sum(exp(x - max(x))). */
    private static NDArray softmax(NDArray logits) {
        NDArray exp = logits.sub(logits.max()).exp();
        return exp.div(exp.sum());
    }

    /**
     * Creates a builder for a {@link QuestionAnsweringTranslator}.
     *
     * @param tokenizer the tokenizer used to encode inputs and decode answers
     * @return a new builder
     */
    public static Builder builder(HuggingFaceTokenizer tokenizer) {
        return new Builder(tokenizer);
    }

    /**
     * Creates a builder pre-configured from an argument map.
     *
     * @param tokenizer the tokenizer used to encode inputs and decode answers
     * @param arguments the configuration arguments, see {@link Builder#configure(Map)}
     * @return a new configured builder
     */
    public static Builder builder(HuggingFaceTokenizer tokenizer, Map<String, ?> arguments) {
        Builder builder = builder(tokenizer);
        builder.configure(arguments);
        return builder;
    }
}
