/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.mxnet.zoo.nlp.qa;

import ai.djl.mxnet.zoo.nlp.qa.BertDataParser;
import ai.djl.mxnet.zoo.nlp.qa.QAInput;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.Utils;
import java.io.IOException;
import java.util.List;
import java.util.Locale;

public class BertQATranslator
implements Translator<QAInput, String> {
    private List<String> tokens;

    BertQATranslator() {
    }

    public Batchifier getBatchifier() {
        return null;
    }

    public NDList processInput(TranslatorContext ctx, QAInput input) throws IOException {
        BertDataParser parser = (BertDataParser)ctx.getModel().getArtifact("vocab.json", BertDataParser::parse);
        List<String> tokenQ = BertDataParser.tokenizer(input.getQuestion().toLowerCase());
        List<String> tokenA = BertDataParser.tokenizer(input.getParagraph().toLowerCase());
        int validLength = tokenQ.size() + tokenA.size();
        List<Float> tokenTypes = BertDataParser.getTokenTypes(tokenQ, tokenA, input.getSeqLength());
        this.tokens = BertDataParser.formTokens(tokenQ, tokenA, input.getSeqLength());
        List<Integer> indexes = parser.token2idx(this.tokens);
        float[] types = Utils.toFloatArray(tokenTypes);
        float[] indexesFloat = Utils.toFloatArray(indexes);
        int seqLength = input.getSeqLength();
        NDManager manager = ctx.getNDManager();
        NDArray data0 = manager.create(indexesFloat, new Shape(new long[]{1L, seqLength}));
        data0.setName("data0");
        NDArray data1 = manager.create(types, new Shape(new long[]{1L, seqLength}));
        data1.setName("data1");
        NDArray data2 = manager.create(new float[]{validLength});
        data2.setName("data2");
        return new NDList(new NDArray[]{data0, data1, data2});
    }

    public String processOutput(TranslatorContext ctx, NDList list) {
        NDArray array = list.singletonOrThrow();
        NDList output = array.split(2L, 2);
        NDArray startLogits = ((NDArray)output.get(0)).reshape(new Shape(new long[]{1L, -1L}));
        NDArray endLogits = ((NDArray)output.get(1)).reshape(new Shape(new long[]{1L, -1L}));
        NDArray startProb = startLogits.softmax(-1);
        NDArray endProb = endLogits.softmax(-1);
        int startIdx = (int)startProb.argMax(1).getLong(new long[0]);
        int endIdx = (int)endProb.argMax(1).getLong(new long[0]);
        return this.tokens.subList(startIdx, endIdx + 1).toString();
    }
}

