package opennlp.dl.vectors;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.io.File;
import java.io.IOException;
import java.nio.LongBuffer;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import opennlp.dl.AbstractDL;
import opennlp.dl.Tokens;
import opennlp.tools.tokenize.Tokenizer;
import opennlp.tools.tokenize.WordpieceTokenizer;

/* loaded from: input_file:opennlp/dl/vectors/SentenceVectorsDL.class */
/**
 * Produces a vector for a piece of text by running an ONNX model through ONNX Runtime.
 *
 * <p>The text is split with a {@link WordpieceTokenizer} built from the keys of the loaded
 * vocabulary, each token is mapped to its vocabulary id, and the ids are fed to the model
 * together with a mask and a token-type tensor.
 */
public class SentenceVectorsDL extends AbstractDL {

    /**
     * Creates the ONNX Runtime session and loads the vocabulary.
     *
     * @param file  the ONNX model file; its path is handed to ONNX Runtime
     * @param file2 the vocabulary file, read via {@code loadVocab}
     * @throws OrtException if ONNX Runtime cannot create the session
     * @throws IOException  if the vocabulary cannot be loaded
     */
    public SentenceVectorsDL(File file, File file2) throws OrtException, IOException {
        this.env = OrtEnvironment.getEnvironment();
        // NOTE(review): the SessionOptions instance is intentionally not closed here; confirm
        // against the ONNX Runtime docs whether it must outlive the session.
        this.session = this.env.createSession(file.getPath(), new OrtSession.SessionOptions());
        this.vocab = loadVocab(file2);
        this.tokenizer = new WordpieceTokenizer(this.vocab.keySet());
    }

    /**
     * Runs the model on the given text and returns one vector from its first output.
     *
     * <p>The first output is read as a {@code float[][][]} and element {@code [0][0]} is
     * returned. Presumably this is the embedding at the first token position of the single
     * batch entry; confirm against the model being used.
     *
     * @param str the text to embed
     * @return element {@code [0][0]} of the model's first output
     * @throws OrtException             if tensor creation or inference fails
     * @throws IllegalArgumentException if the tokenizer yields a token that is missing from
     *                                  the vocabulary
     */
    public float[] getVectors(String str) throws OrtException {
        final Tokens tokens = tokenize(str, this.tokenizer, this.vocab);
        final Map<String, OnnxTensor> inputs = new HashMap<>();
        try {
            inputs.put(AbstractDL.INPUT_IDS, wrapAsSingleRowTensor(tokens.ids()));
            inputs.put(AbstractDL.ATTENTION_MASK, wrapAsSingleRowTensor(tokens.mask()));
            inputs.put(AbstractDL.TOKEN_TYPE_IDS, wrapAsSingleRowTensor(tokens.types()));

            // The result owns native memory; getValue() yields a Java array, so the result
            // can be released as soon as the value has been extracted.
            try (OrtSession.Result result = this.session.run(inputs)) {
                final float[][][] output = (float[][][]) result.get(0).getValue();
                return output[0][0];
            }
        } finally {
            // Input tensors are backed by native memory as well. Closing them here also
            // covers the case where creating a later tensor (or the run itself) failed.
            for (OnnxTensor tensor : inputs.values()) {
                tensor.close();
            }
        }
    }

    /** Wraps the values in a tensor of shape {@code [1, values.length]}. */
    private OnnxTensor wrapAsSingleRowTensor(long[] values) throws OrtException {
        return OnnxTensor.createTensor(this.env, LongBuffer.wrap(values), new long[] {1, values.length});
    }

    /**
     * Tokenizes the text and maps every token to its vocabulary id.
     *
     * @param text       the text to tokenize
     * @param tokenizer  the tokenizer to split the text with
     * @param vocabulary mapping from token to id
     * @return the tokens, their ids and two auxiliary arrays of the same length
     * @throws IllegalArgumentException if a token has no entry in {@code vocabulary}
     */
    private Tokens tokenize(String text, Tokenizer tokenizer, Map<String, Integer> vocabulary) {
        final String[] pieces = tokenizer.tokenize(text);

        final long[] ids = new long[pieces.length];
        for (int i = 0; i < pieces.length; i++) {
            final Integer id = vocabulary.get(pieces[i]);
            if (id == null) {
                // Fail with the offending token instead of an anonymous NullPointerException.
                throw new IllegalArgumentException(
                        "Token '" + pieces[i] + "' is not present in the vocabulary");
            }
            ids[i] = id;
        }

        final long[] zeros = new long[pieces.length];
        final long[] ones = new long[pieces.length];
        Arrays.fill(ones, 1L);

        // NOTE(review): the third constructor argument is all zeros and the fourth is all
        // ones. getVectors feeds tokens.mask() as the attention mask and tokens.types() as
        // the token type ids, so confirm against the Tokens constructor's parameter order
        // that the all-ones array really ends up as the mask.
        return new Tokens(pieces, ids, zeros, ones);
    }
}
