package com.johnsnowlabs.ml.ai;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;
import com.johnsnowlabs.ml.ai.util.PrepareEmbeddings$;
import com.johnsnowlabs.ml.onnx.OnnxWrapper;
import com.johnsnowlabs.ml.tensorflow.TensorResources;
import com.johnsnowlabs.ml.tensorflow.TensorResources$;
import com.johnsnowlabs.ml.tensorflow.TensorflowWrapper;
import com.johnsnowlabs.ml.tensorflow.sign.ModelSignatureConstants$AttentionMask$;
import com.johnsnowlabs.ml.tensorflow.sign.ModelSignatureConstants$InputIds$;
import com.johnsnowlabs.ml.tensorflow.sign.ModelSignatureConstants$LastHiddenState$;
import com.johnsnowlabs.ml.tensorflow.sign.ModelSignatureConstants$PoolerOutput$;
import com.johnsnowlabs.ml.tensorflow.sign.ModelSignatureManager$;
import com.johnsnowlabs.ml.util.ModelArch$;
import com.johnsnowlabs.ml.util.ONNX$;
import com.johnsnowlabs.ml.util.TensorFlow$;
import com.johnsnowlabs.nlp.Annotation;
import com.johnsnowlabs.nlp.AnnotatorType$;
import com.johnsnowlabs.nlp.annotators.common.IndexedToken;
import com.johnsnowlabs.nlp.annotators.common.Sentence;
import com.johnsnowlabs.nlp.annotators.common.TokenPiece;
import com.johnsnowlabs.nlp.annotators.common.TokenPieceEmbeddings;
import com.johnsnowlabs.nlp.annotators.common.TokenPieceEmbeddings$;
import com.johnsnowlabs.nlp.annotators.common.TokenizedSentence;
import com.johnsnowlabs.nlp.annotators.common.WordpieceEmbeddingsSentence;
import com.johnsnowlabs.nlp.annotators.common.WordpieceTokenizedSentence;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import scala.Array$;
import scala.MatchError;
import scala.Option;
import scala.Option$;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Serializable;
import scala.Tuple2;
import scala.collection.IterableLike;
import scala.collection.JavaConverters$;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.immutable.$colon;
import scala.collection.immutable.Map;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayOps;
import scala.collection.mutable.Buffer;
import scala.math.Ordering$Int$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;

/* compiled from: RoBerta.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005-g!\u0002\u000f\u001e\u0001\u0005*\u0003\u0002C\u0018\u0001\u0005\u000b\u0007I\u0011A\u0019\t\u0011m\u0002!\u0011!Q\u0001\nIB\u0001\u0002\u0010\u0001\u0003\u0006\u0004%\t!\u0010\u0005\t\u000b\u0002\u0011\t\u0011)A\u0005}!Aa\t\u0001B\u0001B\u0003%q\t\u0003\u0005K\u0001\t\u0005\t\u0015!\u0003H\u0011!Y\u0005A!A!\u0002\u00139\u0005\u0002\u0003'\u0001\u0005\u0003\u0005\u000b\u0011B'\t\u0011Q\u0003!\u0011!Q\u0001\nUC\u0001\u0002\u001a\u0001\u0003\u0002\u0003\u0006I!\u0019\u0005\u0006K\u0002!\tA\u001a\u0005\bc\u0002\u0011\r\u0011\"\u0001s\u0011\u0019\u0019\b\u0001)A\u0005-\"9A\u000f\u0001b\u0001\n\u0003)\bB\u0002<\u0001A\u0003%\u0011\rC\u0003x\u0001\u0011%\u0001\u0010C\u0003}\u0001\u0011\u0005Q\u0010C\u0004\u0002\"\u0001!\t!a\t\t\u000f\u0005\u001d\u0002\u0001\"\u0001\u0002*!9\u00111\u000e\u0001\u0005\u0002\u00055tACAF;\u0005\u0005\t\u0012A\u0011\u0002\u000e\u001aIA$HA\u0001\u0012\u0003\t\u0013q\u0012\u0005\u0007KZ!\t!!%\t\u0013\u0005Me#%A\u0005\u0002\u0005U\u0005\"CAV-E\u0005I\u0011AAW\u0011%\t\tLFI\u0001\n\u0003\t\u0019\fC\u0005\u00028Z\t\t\u0011\"\u0003\u0002:\n9!k\u001c\"feR\f'B\u0001\u0010 
\u0003\t\t\u0017N\u0003\u0002!C\u0005\u0011Q\u000e\u001c\u0006\u0003E\r\nAB[8i]Ntwn\u001e7bENT\u0011\u0001J\u0001\u0004G>l7c\u0001\u0001'YA\u0011qEK\u0007\u0002Q)\t\u0011&A\u0003tG\u0006d\u0017-\u0003\u0002,Q\t1\u0011I\\=SK\u001a\u0004\"aJ\u0017\n\u00059B#\u0001D*fe&\fG.\u001b>bE2,\u0017!\u0005;f]N|'O\u001a7po^\u0013\u0018\r\u001d9fe\u000e\u0001Q#\u0001\u001a\u0011\u0007\u001d\u001aT'\u0003\u00025Q\t1q\n\u001d;j_:\u0004\"AN\u001d\u000e\u0003]R!\u0001O\u0010\u0002\u0015Q,gn]8sM2|w/\u0003\u0002;o\t\tB+\u001a8t_J4Gn\\<Xe\u0006\u0004\b/\u001a:\u0002%Q,gn]8sM2|wo\u0016:baB,'\u000fI\u0001\f_:t\u0007p\u0016:baB,'/F\u0001?!\r93g\u0010\t\u0003\u0001\u000ek\u0011!\u0011\u0006\u0003\u0005~\tAa\u001c8oq&\u0011A)\u0011\u0002\f\u001f:t\u0007p\u0016:baB,'/\u0001\u0007p]:DxK]1qa\u0016\u0014\b%\u0001\u000btK:$XM\\2f'R\f'\u000f\u001e+pW\u0016t\u0017\n\u001a\t\u0003O!K!!\u0013\u0015\u0003\u0007%sG/\u0001\ntK:$XM\\2f\u000b:$Gk\\6f]&#\u0017A\u00039bIR{7.\u001a8JI\u0006\u00012m\u001c8gS\u001e\u0004&o\u001c;p\u0005f$Xm\u001d\t\u0004OMr\u0005cA\u0014P#&\u0011\u0001\u000b\u000b\u0002\u0006\u0003J\u0014\u0018-\u001f\t\u0003OIK!a\u0015\u0015\u0003\t\tKH/Z\u0001\u000bg&<g.\u0019;ve\u0016\u001c\bcA\u00144-B!qKX1b\u001d\tAF\f\u0005\u0002ZQ5\t!L\u0003\u0002\\a\u00051AH]8pizJ!!\u0018\u0015\u0002\rA\u0013X\rZ3g\u0013\ty\u0006MA\u0002NCBT!!\u0018\u0015\u0011\u0005]\u0013\u0017BA2a\u0005\u0019\u0019FO]5oO\u0006IQn\u001c3fY\u0006\u00138\r[\u0001\u0007y%t\u0017\u000e\u001e 
\u0015\u0013\u001dL'n\u001b7n]>\u0004\bC\u00015\u0001\u001b\u0005i\u0002\"B\u0018\f\u0001\u0004\u0011\u0004\"\u0002\u001f\f\u0001\u0004q\u0004\"\u0002$\f\u0001\u00049\u0005\"\u0002&\f\u0001\u00049\u0005\"B&\f\u0001\u00049\u0005b\u0002'\f!\u0003\u0005\r!\u0014\u0005\b).\u0001\n\u00111\u0001V\u0011\u001d!7\u0002%AA\u0002\u0005\fAc\u0018;g%>\u0014UM\u001d;b'&<g.\u0019;ve\u0016\u001cX#\u0001,\u0002+}#hMU8CKJ$\u0018mU5h]\u0006$XO]3tA\u0005qA-\u001a;fGR,G-\u00128hS:,W#A1\u0002\u001f\u0011,G/Z2uK\u0012,enZ5oK\u0002\nQb]3tg&|gnV1s[V\u0004H#A=\u0011\u0005\u001dR\u0018BA>)\u0005\u0011)f.\u001b;\u0002\u0007Q\fw\rF\u0002\u007f\u00033\u0001Ra`A\u0005\u0003\u001fqA!!\u0001\u0002\u00069\u0019\u0011,a\u0001\n\u0003%J1!a\u0002)\u0003\u001d\u0001\u0018mY6bO\u0016LA!a\u0003\u0002\u000e\t\u00191+Z9\u000b\u0007\u0005\u001d\u0001\u0006\u0005\u0003(\u001f\u0006E\u0001\u0003B\u0014P\u0003'\u00012aJA\u000b\u0013\r\t9\u0002\u000b\u0002\u0006\r2|\u0017\r\u001e\u0005\b\u00037\t\u0002\u0019AA\u000f\u0003\u0015\u0011\u0017\r^2i!\u0015y\u0018\u0011BA\u0010!\r9sjR\u0001\fi\u0006<7+Z9vK:\u001cW\r\u0006\u0003\u0002\u0010\u0005\u0015\u0002bBA\u000e%\u0001\u0007\u0011QD\u0001\baJ,G-[2u)1\tY#!\u0011\u0002N\u0005e\u0013QLA1!\u0015y\u0018\u0011BA\u0017!\u0011\ty#!\u0010\u000e\u0005\u0005E\"\u0002BA\u001a\u0003k\taaY8n[>t'\u0002BA\u001c\u0003s\t!\"\u00198o_R\fGo\u001c:t\u0015\r\tY$I\u0001\u0004]2\u0004\u0018\u0002BA 
\u0003c\u00111dV8sIBLWmY3F[\n,G\rZ5oON\u001cVM\u001c;f]\u000e,\u0007bBA\"'\u0001\u0007\u0011QI\u0001\ng\u0016tG/\u001a8dKN\u0004Ra`A\u0005\u0003\u000f\u0002B!a\f\u0002J%!\u00111JA\u0019\u0005i9vN\u001d3qS\u0016\u001cW\rV8lK:L'0\u001a3TK:$XM\\2f\u0011\u001d\tye\u0005a\u0001\u0003#\nac\u001c:jO&t\u0017\r\u001c+pW\u0016t7+\u001a8uK:\u001cWm\u001d\t\u0006\u007f\u0006%\u00111\u000b\t\u0005\u0003_\t)&\u0003\u0003\u0002X\u0005E\"!\u0005+pW\u0016t\u0017N_3e'\u0016tG/\u001a8dK\"1\u00111L\nA\u0002\u001d\u000b\u0011BY1uG\"\u001c\u0016N_3\t\r\u0005}3\u00031\u0001H\u0003Ei\u0017\r_*f]R,gnY3MK:<G\u000f\u001b\u0005\b\u0003G\u001a\u0002\u0019AA3\u00035\u0019\u0017m]3TK:\u001c\u0018\u000e^5wKB\u0019q%a\u001a\n\u0007\u0005%\u0004FA\u0004C_>dW-\u00198\u0002\u001fA\u0014X\rZ5diN+\u0017/^3oG\u0016$\"\"a\u001c\u0002z\u0005u\u0014qQAE!\u0015y\u0018\u0011BA9!\u0011\t\u0019(!\u001e\u000e\u0005\u0005e\u0012\u0002BA<\u0003s\u0011!\"\u00118o_R\fG/[8o\u0011\u001d\tY\b\u0006a\u0001\u0003\u000b\na\u0001^8lK:\u001c\bbBA\")\u0001\u0007\u0011q\u0010\t\u0006\u007f\u0006%\u0011\u0011\u0011\t\u0005\u0003_\t\u0019)\u0003\u0003\u0002\u0006\u0006E\"\u0001C*f]R,gnY3\t\r\u0005mC\u00031\u0001H\u0011\u0019\ty\u0006\u0006a\u0001\u000f\u00069!k\u001c\"feR\f\u0007C\u00015\u0017'\r1b\u0005\f\u000b\u0003\u0003\u001b\u000b1\u0004\n7fgNLg.\u001b;%OJ,\u0017\r^3sI\u0011,g-Y;mi\u00122TCAALU\ri\u0015\u0011T\u0016\u0003\u00037\u0003B!!(\u0002(6\u0011\u0011q\u0014\u0006\u0005\u0003C\u000b\u0019+A\u0005v]\u000eDWmY6fI*\u0019\u0011Q\u0015\u0015\u0002\u0015\u0005tgn\u001c;bi&|g.\u0003\u0003\u0002*\u0006}%!E;oG\",7m[3e-\u0006\u0014\u0018.\u00198dK\u0006YB\u0005\\3tg&t\u0017\u000e\u001e\u0013he\u0016\fG/\u001a:%I\u00164\u0017-\u001e7uI]*\"!a,+\u0007U\u000bI*A\u000e%Y\u0016\u001c8/\u001b8ji\u0012:'/Z1uKJ$C-\u001a4bk2$H\u0005O\u000b\u0003\u0003kS3!YAM\u0003-\u0011X-\u00193SKN|GN^3\u0015\u0005\u0005m\u0006\u0003BA_\u0003\u000fl!!a0\u000b\t\u0005\u0005\u00171Y\u0001\u0005Y\u0006twM\u0003\u0002\u0002F\u0006!!.\u0019<b\u0013\u0011\t
I-a0\u0003\r=\u0013'.Z2u\u0001")
/* loaded from: input_file:com/johnsnowlabs/ml/ai/RoBerta.class */
/**
 * Runtime wrapper around a RoBERTa model backed either by TensorFlow or by ONNX Runtime.
 *
 * <p>{@link #tag} produces token-level embeddings (the model's last hidden state) and
 * {@link #tagSequence} produces one pooled vector per sentence. {@link #predict} and
 * {@link #predictSequence} batch annotator input through those two methods.
 *
 * <p>The engine is chosen once, in the constructor:
 * <ul>
 *   <li>TensorFlow when a TensorFlow wrapper is supplied;</li>
 *   <li>otherwise ONNX when an ONNX wrapper is supplied;</li>
 *   <li>otherwise TensorFlow.</li>
 * </ul>
 *
 * <p>NOTE(review): this class is decompiled Scala. Several lambda bodies are decompiler output
 * (raw types, boxed lambda parameters) and may not compile as plain Java.
 */
public class RoBerta implements Serializable {
    private final Option<TensorflowWrapper> tensorflowWrapper;
    private final Option<OnnxWrapper> onnxWrapper;
    private final int sentenceStartTokenId;
    private final int sentenceEndTokenId;
    private final int padTokenId;
    private final Option<byte[]> configProtoBytes;
    private final Option<Map<String, String>> signatures;
    private final String modelArch;
    // Signature-key -> graph tensor name mapping used by the TensorFlow engine.
    private final Map<String, String> _tfRoBertaSignatures;
    // Either TensorFlow$.name() or ONNX$.name(); fixed at construction time.
    private final String detectedEngine;

    public Option<TensorflowWrapper> tensorflowWrapper() {
        return this.tensorflowWrapper;
    }

    public Option<OnnxWrapper> onnxWrapper() {
        return this.onnxWrapper;
    }

    public Map<String, String> _tfRoBertaSignatures() {
        return this._tfRoBertaSignatures;
    }

    public String detectedEngine() {
        return this.detectedEngine;
    }

    /**
     * Returns true when {@code modelArch} equals {@code arch}.
     *
     * <p>Two nulls count as equal, mirroring the Scala {@code ==} semantics of the original code.
     */
    private boolean isModelArch(String arch) {
        return java.util.Objects.equals(this.modelArch, arch);
    }

    /**
     * Runs one inference on a fixed dummy token sequence so the first real call does not pay
     * session start-up cost.
     *
     * <p>The path that is exercised depends on {@code modelArch}:
     * <ul>
     *   <li>word-embeddings models warm up {@link #tag};</li>
     *   <li>sentence-embeddings models warm up {@link #tagSequence};</li>
     *   <li>any other architecture skips the warmup.</li>
     * </ul>
     */
    private void sessionWarmup() {
        int[] iArr = {0, 7939, 18, 3279, 658, 5, 19374, 13, 5, 78, 42752, 4, 2};
        if (isModelArch(ModelArch$.MODULE$.wordEmbeddings())) {
            tag((Seq) new $colon.colon(iArr, Nil$.MODULE$));
        } else if (isModelArch(ModelArch$.MODULE$.sentenceEmbeddings())) {
            tagSequence((Seq) new $colon.colon(iArr, Nil$.MODULE$));
        }
    }

    /**
     * Computes token-level embeddings for a batch of token-id sequences.
     *
     * <p>Uses the TensorFlow graph's last-hidden-state output, or the ONNX
     * {@code last_hidden_state} output, depending on {@link #detectedEngine()}.
     *
     * @param seq token ids per sentence. The longest sequence defines the padded length.
     *            NOTE(review): the ONNX path builds a rectangular tensor directly from
     *            {@code seq}, so inputs presumably must already be padded to equal length
     *            (e.g. by prepareBatchInputsWithPadding) — confirm. {@code seq} must be
     *            non-empty, because the {@code max} over lengths fails on empty input.
     * @return one embedding matrix per input sentence, as produced by
     *         {@code PrepareEmbeddings.prepareBatchWordEmbeddings}
     */
    public Seq<float[][]> tag(Seq<int[]> seq) {
        int maxSentenceLength = BoxesRunTime.unboxToInt(((TraversableOnce) seq.map(iArr -> {
            return BoxesRunTime.boxToInteger($anonfun$tag$1(iArr));
        }, Seq$.MODULE$.canBuildFrom())).max(Ordering$Int$.MODULE$));
        int batchLength = seq.length();
        String engine = detectedEngine();
        String onnxName = ONNX$.MODULE$.name();
        boolean useOnnx = onnxName != null ? onnxName.equals(engine) : engine == null;
        float[] flatEmbeddings;
        if (useOnnx) {
            flatEmbeddings = runOnnx(seq);
        } else {
            flatEmbeddings = runTensorflow(seq, maxSentenceLength, batchLength,
                    ModelSignatureConstants$LastHiddenState$.MODULE$.key(), "missing_sequence_output_key");
        }
        return PrepareEmbeddings$.MODULE$.prepareBatchWordEmbeddings(seq, flatEmbeddings, maxSentenceLength, batchLength);
    }

    /**
     * Computes one pooled embedding vector per sentence.
     *
     * <p>NOTE(review): this always runs the TensorFlow graph, regardless of
     * {@link #detectedEngine()}. With only an ONNX wrapper, {@code tensorflowWrapper().get()}
     * fails. That includes the constructor warmup for sentence-embeddings models.
     *
     * @param seq token ids per sentence (non-empty)
     * @return the flat pooler output split into {@code seq.length()} equally sized rows
     */
    public float[][] tagSequence(Seq<int[]> seq) {
        int maxSentenceLength = BoxesRunTime.unboxToInt(((TraversableOnce) seq.map(iArr -> {
            return BoxesRunTime.boxToInteger($anonfun$tagSequence$1(iArr));
        }, Seq$.MODULE$.canBuildFrom())).max(Ordering$Int$.MODULE$));
        int batchLength = seq.length();
        float[] pooled = runTensorflow(seq, maxSentenceLength, batchLength,
                ModelSignatureConstants$PoolerOutput$.MODULE$.key(), "missing_pooled_output_key");
        // One row per sentence: the flat output is batchLength * hiddenSize floats.
        return (float[][]) new ArrayOps.ofFloat(Predef$.MODULE$.floatArrayOps(pooled)).grouped(pooled.length / batchLength).toArray(ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Float.TYPE)));
    }

    /**
     * Feeds a padded batch to the TensorFlow session and returns the requested output as a flat
     * float array.
     *
     * <p>Input tensors, output tensors and all tensors tracked by the local
     * {@link TensorResources} are released even when the session run fails.
     *
     * @param seq               token ids per sentence
     * @param maxSentenceLength padded sequence length
     * @param batchLength       number of sentences in {@code seq}
     * @param outputKey         signature key of the output to fetch
     * @param missingOutputKey  tensor name used when {@code outputKey} is not in the signatures
     */
    private float[] runTensorflow(Seq<int[]> seq, int maxSentenceLength, int batchLength, String outputKey, String missingOutputKey) {
        TensorResources tensorResources = new TensorResources();
        Tuple2<Tensor, Tensor> prepared = PrepareEmbeddings$.MODULE$.prepareBatchTensors(tensorResources, seq, maxSentenceLength, batchLength, this.padTokenId);
        if (prepared == null) {
            throw new MatchError(prepared);
        }
        Tensor tokenTensor = (Tensor) prepared._1();
        Tensor maskTensor = (Tensor) prepared._2();
        try {
            TensorflowWrapper tensorflowWrapper = (TensorflowWrapper) tensorflowWrapper().get();
            Session.Runner runner = tensorflowWrapper.getTFSessionWithSignature(this.configProtoBytes, false, tensorflowWrapper.getTFSessionWithSignature$default$3(), this.signatures).runner();
            runner.feed((String) _tfRoBertaSignatures().getOrElse(ModelSignatureConstants$InputIds$.MODULE$.key(), () -> {
                return "missing_input_id_key";
            }), tokenTensor).feed((String) _tfRoBertaSignatures().getOrElse(ModelSignatureConstants$AttentionMask$.MODULE$.key(), () -> {
                return "missing_input_mask_key";
            }), maskTensor).fetch((String) _tfRoBertaSignatures().getOrElse(outputKey, () -> {
                return missingOutputKey;
            }));
            Buffer<Tensor> outputs = (Buffer) JavaConverters$.MODULE$.asScalaBufferConverter(runner.run()).asScala();
            try {
                return TensorResources$.MODULE$.extractFloats((Tensor) outputs.head(), TensorResources$.MODULE$.extractFloats$default$2());
            } finally {
                tensorResources.clearSession(outputs);
            }
        } finally {
            tokenTensor.close();
            maskTensor.close();
            tensorResources.clearTensors();
        }
    }

    /**
     * Runs the ONNX model on a batch and returns its {@code last_hidden_state} output as a flat
     * float array.
     *
     * <p>The attention mask is 0 exactly where a token equals {@code padTokenId}, matching the
     * TensorFlow path (which passes {@code padTokenId} to {@code prepareBatchTensors}).
     * Previously the mask was keyed on token id 0. The sentence-start id is 0 in the warmup
     * sequence, so whenever {@code padTokenId != 0} the start token was masked out and padding
     * was attended to.
     *
     * <p>Input tensors and the session result are closed even when inference fails.
     */
    private float[] runOnnx(Seq<int[]> seq) {
        OnnxWrapper onnxWrapper = (OnnxWrapper) onnxWrapper().get();
        Tuple2<OrtSession, OrtEnvironment> session = onnxWrapper.getSession(onnxWrapper.getSession$default$1());
        if (session == null) {
            throw new MatchError(session);
        }
        OrtSession ortSession = (OrtSession) session._1();
        OrtEnvironment ortEnvironment = (OrtEnvironment) session._2();
        // Token ids, widened to long per sentence.
        OnnxTensor tokenTensor = OnnxTensor.createTensor(ortEnvironment, ((TraversableOnce) seq.map(ids -> {
            return (long[]) new ArrayOps.ofInt(Predef$.MODULE$.intArrayOps(ids)).map(i -> {
                return i;
            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Long()));
        }, Seq$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Long.TYPE))));
        OnnxTensor maskTensor = null;
        try {
            maskTensor = OnnxTensor.createTensor(ortEnvironment, ((TraversableOnce) seq.map(ids -> {
                return (long[]) new ArrayOps.ofInt(Predef$.MODULE$.intArrayOps(ids)).map(i -> {
                    return ((long) i) == ((long) this.padTokenId) ? 0L : 1L;
                }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Long()));
            }, Seq$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Long.TYPE))));
            OrtSession.Result run = ortSession.run((java.util.Map) JavaConverters$.MODULE$.mapAsJavaMapConverter(Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("input_ids"), tokenTensor), Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("attention_mask"), maskTensor)}))).asJava());
            try {
                return ((OnnxTensor) run.get("last_hidden_state").get()).getFloatBuffer().array();
            } finally {
                run.close();
            }
        } finally {
            tokenTensor.close();
            if (maskTensor != null) {
                maskTensor.close();
            }
        }
    }

    /**
     * Produces token-piece embeddings for word-piece tokenized sentences.
     *
     * <p>Each piece is aligned back to the original token that begins at the same offset.
     *
     * @param seq  word-piece tokenized sentences
     * @param seq2 original tokenized sentences, indexed in parallel with {@code seq}
     * @param i    batch size
     * @param i2   maximum sentence length passed to prepareBatchInputsWithPadding
     * @param z    case sensitive: when false, aligned token text is lower-cased
     * @return one {@link WordpieceEmbeddingsSentence} per input sentence
     */
    public Seq<WordpieceEmbeddingsSentence> predict(Seq<WordpieceTokenizedSentence> seq, Seq<TokenizedSentence> seq2, int i, int i2, boolean z) {
        return ((IterableLike) seq.zipWithIndex(Seq$.MODULE$.canBuildFrom())).grouped(i).flatMap(batch -> {
            Seq<float[][]> batchEmbeddings = this.tag(PrepareEmbeddings$.MODULE$.prepareBatchInputsWithPadding(batch, i2, this.sentenceStartTokenId, this.sentenceEndTokenId, this.padTokenId));
            return (Seq) ((TraversableLike) batch.zip(batchEmbeddings, Seq$.MODULE$.canBuildFrom())).map(pair -> {
                if (pair == null) {
                    throw new MatchError(pair);
                }
                // pair = ((wordpiece sentence, sentence index), embeddings for that sentence)
                Tuple2 indexedSentence = (Tuple2) pair._1();
                WordpieceTokenizedSentence wpSentence = (WordpieceTokenizedSentence) indexedSentence._1();
                // Skip row 0 (presumably the start token added during padding) and keep one row per word piece.
                float[][] fArr = (float[][]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((float[][]) pair._2())).slice(1, wpSentence.tokens().length + 1);
                TokenizedSentence tokenizedSentence = (TokenizedSentence) seq2.apply(indexedSentence._2$mcI$sp());
                return new WordpieceEmbeddingsSentence((TokenPieceEmbeddings[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(wpSentence.tokens())).zip(Predef$.MODULE$.wrapRefArray(fArr), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class))))).flatMap(tuple22 -> {
                    if (tuple22 == null) {
                        throw new MatchError(tuple22);
                    }
                    TokenPiece tokenPiece = (TokenPiece) tuple22._1();
                    float[] fArr2 = (float[]) tuple22._2();
                    TokenPieceEmbeddings apply = TokenPieceEmbeddings$.MODULE$.apply(tokenPiece, fArr2);
                    // Only pieces whose begin offset matches an original token are kept.
                    return Option$.MODULE$.option2Iterable(new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(tokenizedSentence.indexedTokens())).find(indexedToken -> {
                        return BoxesRunTime.boxToBoolean($anonfun$predict$4(apply, indexedToken));
                    }).map(indexedToken2 -> {
                        return TokenPieceEmbeddings$.MODULE$.apply(new TokenPiece(apply.wordpiece(), z ? indexedToken2.token() : indexedToken2.token().toLowerCase(), apply.pieceId(), apply.isWordStart(), indexedToken2.begin(), indexedToken2.end()), fArr2);
                    }));
                }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(TokenPieceEmbeddings.class))), tokenizedSentence.sentenceIndex());
            }, Seq$.MODULE$.canBuildFrom());
        }).toSeq();
    }

    /**
     * Produces one SENTENCE_EMBEDDINGS annotation per sentence from the pooled model output.
     *
     * @param seq  word-piece tokenized sentences
     * @param seq2 the corresponding sentences, in the same order as {@code seq}
     * @param i    batch size
     * @param i2   maximum sentence length passed to prepareBatchInputsWithPadding
     * @return one {@link Annotation} per sentence carrying its pooled embedding
     */
    public Seq<Annotation> predictSequence(Seq<WordpieceTokenizedSentence> seq, Seq<Sentence> seq2, int i, int i2) {
        return ((IterableLike) ((IterableLike) seq.zip(seq2, Seq$.MODULE$.canBuildFrom())).zipWithIndex(Seq$.MODULE$.canBuildFrom())).grouped(i).flatMap(batch -> {
            // batch elements are ((wordpiece sentence, sentence), index).
            // Re-pair each word-piece sentence with its index for the padding helper.
            Seq<Tuple2<WordpieceTokenizedSentence, Object>> wordpieceBatch = (Seq) batch.map(entry -> {
                return new Tuple2(((Tuple2) entry._1())._1(), BoxesRunTime.boxToInteger(entry._2$mcI$sp()));
            }, Seq$.MODULE$.canBuildFrom());
            Seq<Sentence> sentences = (Seq) batch.map(entry -> {
                return (Sentence) ((Tuple2) entry._1())._2();
            }, Seq$.MODULE$.canBuildFrom());
            float[][] embeddings = this.tagSequence(PrepareEmbeddings$.MODULE$.prepareBatchInputsWithPadding(wordpieceBatch, i2, this.sentenceStartTokenId, this.sentenceEndTokenId, this.padTokenId));
            return (Seq) ((TraversableLike) sentences.zip(Predef$.MODULE$.wrapRefArray(embeddings), Seq$.MODULE$.canBuildFrom())).map(pair -> {
                if (pair == null) {
                    throw new MatchError(pair);
                }
                Sentence sentence = (Sentence) pair._1();
                return new Annotation(AnnotatorType$.MODULE$.SENTENCE_EMBEDDINGS(), sentence.start(), sentence.end(), sentence.content(), Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("sentence"), Integer.toString(sentence.index())), Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("token"), sentence.content()), Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("pieceId"), "-1"), Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("isWordStart"), "true")})), (float[]) pair._2());
            }, Seq$.MODULE$.canBuildFrom());
        }).toSeq();
    }

    public static final /* synthetic */ int $anonfun$tag$1(int[] iArr) {
        return iArr.length;
    }

    public static final /* synthetic */ int $anonfun$tagSequence$1(int[] iArr) {
        return iArr.length;
    }

    /** True when the original token starts at the same character offset as the token piece. */
    public static final /* synthetic */ boolean $anonfun$predict$4(TokenPieceEmbeddings tokenPieceEmbeddings, IndexedToken indexedToken) {
        return indexedToken.begin() == tokenPieceEmbeddings.begin();
    }

    /**
     * Creates the wrapper, selects the engine and runs a warmup inference.
     *
     * @param option  TensorFlow model wrapper, if the model was loaded into TensorFlow
     * @param option2 ONNX model wrapper, if the model was loaded into ONNX Runtime
     * @param i       sentence-start token id
     * @param i2      sentence-end token id
     * @param i3      padding token id (also used to build attention masks)
     * @param option3 optional TensorFlow session config proto bytes
     * @param option4 optional signature-key to tensor-name mapping. When absent, the
     *                ModelSignatureManager defaults are used.
     * @param str     model architecture; selects which path {@code sessionWarmup} exercises
     */
    public RoBerta(Option<TensorflowWrapper> option, Option<OnnxWrapper> option2, int i, int i2, int i3, Option<byte[]> option3, Option<Map<String, String>> option4, String str) {
        this.tensorflowWrapper = option;
        this.onnxWrapper = option2;
        this.sentenceStartTokenId = i;
        this.sentenceEndTokenId = i2;
        this.padTokenId = i3;
        this.configProtoBytes = option3;
        this.signatures = option4;
        this.modelArch = str;
        this._tfRoBertaSignatures = (Map) option4.getOrElse(() -> {
            return ModelSignatureManager$.MODULE$.apply(ModelSignatureManager$.MODULE$.apply$default$1(), ModelSignatureManager$.MODULE$.apply$default$2(), ModelSignatureManager$.MODULE$.apply$default$3(), ModelSignatureManager$.MODULE$.apply$default$4(), ModelSignatureManager$.MODULE$.apply$default$5(), ModelSignatureManager$.MODULE$.apply$default$6());
        });
        // TensorFlow wins when both wrappers are present; TensorFlow is also the fallback.
        this.detectedEngine = option.isDefined() ? TensorFlow$.MODULE$.name() : option2.isDefined() ? ONNX$.MODULE$.name() : TensorFlow$.MODULE$.name();
        sessionWarmup();
    }
}
