/*
 * Decompiled with CFR 0.152.
 */
package com.johnsnowlabs.nlp.embeddings;

import com.johnsnowlabs.nlp.Annotation;
import com.johnsnowlabs.nlp.AnnotatorModel;
import com.johnsnowlabs.nlp.AnnotatorType$;
import com.johnsnowlabs.nlp.HasSimpleAnnotate;
import com.johnsnowlabs.nlp.annotators.common.Sentence;
import com.johnsnowlabs.nlp.annotators.common.SentenceSplit$;
import com.johnsnowlabs.nlp.annotators.common.WordpieceEmbeddingsSentence;
import com.johnsnowlabs.nlp.annotators.common.WordpieceEmbeddingsSentence$;
import com.johnsnowlabs.nlp.embeddings.HasEmbeddingsProperties;
import com.johnsnowlabs.nlp.embeddings.SentenceEmbeddings$;
import com.johnsnowlabs.storage.Database;
import com.johnsnowlabs.storage.HasStorageRef;
import com.johnsnowlabs.storage.HasStorageRef$;
import com.johnsnowlabs.storage.RocksDBConnection;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamPair;
import org.apache.spark.ml.util.Identifiable;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import scala.Array$;
import scala.Function1;
import scala.MatchError;
import scala.Option;
import scala.Predef;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.collection.Iterator;
import scala.collection.Map;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableOnce;
import scala.collection.mutable.ArrayOps;
import scala.collection.mutable.Builder;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;
import scala.runtime.java8.JFunction1;

@ScalaSignature(bytes="\u0006\u0001\u0005=f\u0001B\f\u0019\u0001\u0005B\u0001\u0002\u000e\u0001\u0003\u0006\u0004%\t%\u000e\u0005\t\u0007\u0002\u0011\t\u0011)A\u0005m!)A\t\u0001C\u0001\u000b\"9q\t\u0001b\u0001\n\u0003B\u0005BB(\u0001A\u0003%\u0011\nC\u0004Q\u0001\t\u0007I\u0011I)\t\rY\u0003\u0001\u0015!\u0003S\u0011\u001d9\u0006A1A\u0005BaCaa\u001a\u0001!\u0002\u0013I\u0006\"\u00025\u0001\t\u0003J\u0007bB7\u0001\u0005\u0004%\tA\u001c\u0005\u0007e\u0002\u0001\u000b\u0011B8\t\u000bM\u0004A\u0011\u0001;\t\u000b\u0011\u0003A\u0011A<\t\u000ba\u0004A\u0011B=\t\u000f\u0005\r\u0001\u0001\"\u0011\u0002\u0006!9\u00111\u0005\u0001\u0005R\u0005\u0015\u0002bBA-\u0001\u0011E\u00131L\u0004\b\u0003wB\u0002\u0012AA?\r\u00199\u0002\u0004#\u0001\u0002\u0000!1A\t\u0006C\u0001\u00033C\u0011\"a'\u0015\u0003\u0003%I!!(\u0003%M+g\u000e^3oG\u0016,UNY3eI&twm\u001d\u0006\u00033i\t!\"Z7cK\u0012$\u0017N\\4t\u0015\tYB$A\u0002oYBT!!\b\u0010\u0002\u0019)|\u0007N\\:o_^d\u0017MY:\u000b\u0003}\t1aY8n\u0007\u0001\u0019R\u0001\u0001\u0012)W9\u00022a\t\u0013'\u001b\u0005Q\u0012BA\u0013\u001b\u00059\teN\\8uCR|'/T8eK2\u0004\"a\n\u0001\u000e\u0003a\u00012aI\u0015'\u0013\tQ#DA\tICN\u001c\u0016.\u001c9mK\u0006sgn\u001c;bi\u0016\u0004\"a\n\u0017\n\u00055B\"a\u0006%bg\u0016k'-\u001a3eS:<7\u000f\u0015:pa\u0016\u0014H/[3t!\ty#'D\u00011\u0015\t\tD$A\u0004ti>\u0014\u0018mZ3\n\u0005M\u0002$!\u0004%bgN#xN]1hKJ+g-A\u0002vS\u0012,\u0012A\u000e\t\u0003o\u0001s!\u0001\u000f \u0011\u0005ebT\"\u0001\u001e\u000b\u0005m\u0002\u0013A\u0002\u001fs_>$hHC\u0001>\u0003\u0015\u00198-\u00197b\u0013\tyD(\u0001\u0004Qe\u0016$WMZ\u0005\u0003\u0003\n\u0013aa\u0015;sS:<'BA 
=\u0003\u0011)\u0018\u000e\u001a\u0011\u0002\rqJg.\u001b;?)\t1c\tC\u00035\u0007\u0001\u0007a'A\npkR\u0004X\u000f^!o]>$\u0018\r^8s)f\u0004X-F\u0001J!\tQ5*D\u0001\u0001\u0013\taUJA\u0007B]:|G/\u0019;peRK\b/Z\u0005\u0003\u001dj\u0011a\u0003S1t\u001fV$\b/\u001e;B]:|G/\u0019;peRK\b/Z\u0001\u0015_V$\b/\u001e;B]:|G/\u0019;peRK\b/\u001a\u0011\u0002'%t\u0007/\u001e;B]:|G/\u0019;peRK\b/Z:\u0016\u0003I\u00032a\u0015+J\u001b\u0005a\u0014BA+=\u0005\u0015\t%O]1z\u0003QIg\u000e];u\u0003:tw\u000e^1u_J$\u0016\u0010]3tA\u0005IA-[7f]NLwN\\\u000b\u00023B\u0011!,Z\u0007\u00027*\u0011A,X\u0001\u0006a\u0006\u0014\u0018-\u001c\u0006\u0003=~\u000b!!\u001c7\u000b\u0005\u0001\f\u0017!B:qCJ\\'B\u00012d\u0003\u0019\t\u0007/Y2iK*\tA-A\u0002pe\u001eL!AZ.\u0003\u0011%sG\u000fU1sC6\f!\u0002Z5nK:\u001c\u0018n\u001c8!\u000319W\r\u001e#j[\u0016t7/[8o+\u0005Q\u0007CA*l\u0013\taGHA\u0002J]R\fq\u0002]8pY&twm\u0015;sCR,w-_\u000b\u0002_B\u0019!\f\u001d\u001c\n\u0005E\\&!\u0002)be\u0006l\u0017\u0001\u00059p_2LgnZ*ue\u0006$XmZ=!\u0003I\u0019X\r\u001e)p_2LgnZ*ue\u0006$XmZ=\u0015\u0005)+\b\"\u0002<\u000e\u0001\u00041\u0014\u0001C:ue\u0006$XmZ=\u0015\u0003\u0019\n1dY1mGVd\u0017\r^3TK:$XM\\2f\u000b6\u0014W\r\u001a3j]\u001e\u001cHC\u0001>\u007f!\r\u0019Fk\u001f\t\u0003'rL!! 
\u001f\u0003\u000b\u0019cw.\u0019;\t\r}|\u0001\u0019AA\u0001\u0003\u0019i\u0017\r\u001e:jqB\u00191\u000b\u0016>\u0002\u0011\u0005tgn\u001c;bi\u0016$B!a\u0002\u0002 A1\u0011\u0011BA\n\u00033qA!a\u0003\u0002\u00109\u0019\u0011(!\u0004\n\u0003uJ1!!\u0005=\u0003\u001d\u0001\u0018mY6bO\u0016LA!!\u0006\u0002\u0018\t\u00191+Z9\u000b\u0007\u0005EA\bE\u0002$\u00037I1!!\b\u001b\u0005)\teN\\8uCRLwN\u001c\u0005\b\u0003C\u0001\u0002\u0019AA\u0004\u0003-\tgN\\8uCRLwN\\:\u0002\u001d\t,gm\u001c:f\u0003:tw\u000e^1uKR!\u0011qEA&a\u0011\tI#!\u000f\u0011\r\u0005-\u0012\u0011GA\u001b\u001b\t\tiCC\u0002\u00020}\u000b1a]9m\u0013\u0011\t\u0019$!\f\u0003\u000f\u0011\u000bG/Y:fiB!\u0011qGA\u001d\u0019\u0001!1\"a\u000f\u0012\u0003\u0003\u0005\tQ!\u0001\u0002>\t\u0019q\f\n\u001a\u0012\t\u0005}\u0012Q\t\t\u0004'\u0006\u0005\u0013bAA\"y\t9aj\u001c;iS:<\u0007cA*\u0002H%\u0019\u0011\u0011\n\u001f\u0003\u0007\u0005s\u0017\u0010C\u0004\u0002NE\u0001\r!a\u0014\u0002\u000f\u0011\fG/Y:fiB\"\u0011\u0011KA+!\u0019\tY#!\r\u0002TA!\u0011qGA+\t1\t9&a\u0013\u0002\u0002\u0003\u0005)\u0011AA\u001f\u0005\ryF%M\u0001\u000eC\u001a$XM]!o]>$\u0018\r^3\u0015\t\u0005u\u0013\u0011\u0010\t\u0005\u0003?\n\u0019H\u0004\u0003\u0002b\u0005Ed\u0002BA2\u0003_rA!!\u001a\u0002n9!\u0011qMA6\u001d\rI\u0014\u0011N\u0005\u0002I&\u0011!mY\u0005\u0003A\u0006L1!a\f`\u0013\u0011\t\t\"!\f\n\t\u0005U\u0014q\u000f\u0002\n\t\u0006$\u0018M\u0012:b[\u0016TA!!\u0005\u0002.!9\u0011Q\n\nA\u0002\u0005u\u0013AE*f]R,gnY3F[\n,G\rZ5oON\u0004\"a\n\u000b\u0014\u000fQ\t\t)a\"\u0002\u0014B\u00191+a!\n\u0007\u0005\u0015EH\u0001\u0004B]f\u0014VM\u001a\t\u0006\u0003\u0013\u000byIJ\u0007\u0003\u0003\u0017S1!!$^\u0003\u0011)H/\u001b7\n\t\u0005E\u00151\u0012\u0002\u0016\t\u00164\u0017-\u001e7u!\u0006\u0014\u0018-\\:SK\u0006$\u0017M\u00197f!\r\u0019\u0016QS\u0005\u0004\u0003/c$\u0001D*fe&\fG.\u001b>bE2,GCAA?\u0003-\u0011X-\u00193SKN|GN^3\u0015\u0005\u0005}\u0005\u0003BAQ\u0003Wk!!a)\u000b\t\u0005\u0015\u0016qU\u0001\u0005Y\u0006twM\u0003\u0002\u0002*\u0006!!.
\u0019<b\u0013\u0011\ti+a)\u0003\r=\u0013'.Z2u\u0001")
public class SentenceEmbeddings
extends AnnotatorModel<SentenceEmbeddings>
implements HasSimpleAnnotate<SentenceEmbeddings>,
HasEmbeddingsProperties,
HasStorageRef {
    private final String uid;
    private final String outputAnnotatorType;
    private final String[] inputAnnotatorTypes;
    private final IntParam dimension;
    private final Param<String> poolingStrategy;
    private final Param<String> storageRef;

    public static MLReader<SentenceEmbeddings> read() {
        return SentenceEmbeddings$.MODULE$.read();
    }

    public static Object load(String string) {
        return SentenceEmbeddings$.MODULE$.load(string);
    }

    @Override
    public RocksDBConnection createDatabaseConnection(Database database) {
        return HasStorageRef.createDatabaseConnection$(this, database);
    }

    @Override
    public HasStorageRef setStorageRef(String value) {
        return HasStorageRef.setStorageRef$(this, value);
    }

    @Override
    public String getStorageRef() {
        return HasStorageRef.getStorageRef$(this);
    }

    @Override
    public void validateStorageRef(Dataset<?> dataset, String[] inputCols, String annotatorType) {
        HasStorageRef.validateStorageRef$(this, dataset, inputCols, annotatorType);
    }

    @Override
    public HasEmbeddingsProperties setDimension(int value) {
        return HasEmbeddingsProperties.setDimension$(this, value);
    }

    @Override
    public Column wrapEmbeddingsMetadata(Column col, int embeddingsDim, Option<String> embeddingsRef) {
        return HasEmbeddingsProperties.wrapEmbeddingsMetadata$(this, col, embeddingsDim, embeddingsRef);
    }

    @Override
    public Option<String> wrapEmbeddingsMetadata$default$3() {
        return HasEmbeddingsProperties.wrapEmbeddingsMetadata$default$3$(this);
    }

    @Override
    public Column wrapSentenceEmbeddingsMetadata(Column col, int embeddingsDim, Option<String> embeddingsRef) {
        return HasEmbeddingsProperties.wrapSentenceEmbeddingsMetadata$(this, col, embeddingsDim, embeddingsRef);
    }

    @Override
    public Option<String> wrapSentenceEmbeddingsMetadata$default$3() {
        return HasEmbeddingsProperties.wrapSentenceEmbeddingsMetadata$default$3$(this);
    }

    @Override
    public UserDefinedFunction dfAnnotate() {
        return HasSimpleAnnotate.dfAnnotate$(this);
    }

    @Override
    public Param<String> storageRef() {
        return this.storageRef;
    }

    @Override
    public void com$johnsnowlabs$storage$HasStorageRef$_setter_$storageRef_$eq(Param<String> x$1) {
        this.storageRef = x$1;
    }

    @Override
    public void com$johnsnowlabs$nlp$embeddings$HasEmbeddingsProperties$_setter_$dimension_$eq(IntParam x$1) {
    }

    public String uid() {
        return this.uid;
    }

    @Override
    public String outputAnnotatorType() {
        return this.outputAnnotatorType;
    }

    @Override
    public String[] inputAnnotatorTypes() {
        return this.inputAnnotatorTypes;
    }

    @Override
    public IntParam dimension() {
        return this.dimension;
    }

    @Override
    public int getDimension() {
        return BoxesRunTime.unboxToInt((Object)this.$((Param)this.dimension()));
    }

    public Param<String> poolingStrategy() {
        return this.poolingStrategy;
    }

    public SentenceEmbeddings setPoolingStrategy(String strategy) {
        SentenceEmbeddings sentenceEmbeddings;
        String string = strategy.toLowerCase();
        if ("average".equals(string)) {
            sentenceEmbeddings = (SentenceEmbeddings)this.set(this.poolingStrategy(), "AVERAGE");
        } else if ("sum".equals(string)) {
            sentenceEmbeddings = (SentenceEmbeddings)this.set(this.poolingStrategy(), "SUM");
        } else {
            throw new MatchError((Object)"poolingStrategy must be either AVERAGE or SUM");
        }
        return sentenceEmbeddings;
    }

    /*
     * WARNING - void declaration
     */
    private float[] calculateSentenceEmbeddings(float[][] matrix) {
        void var2_2;
        float[] res = (float[])Array$.MODULE$.ofDim(matrix[0].length, ClassTag$.MODULE$.Float());
        this.setDimension(matrix[0].length);
        new ArrayOps.ofFloat(Predef$.MODULE$.floatArrayOps(matrix[0])).indices().foreach$mVc$sp((Function1)(JFunction1.mcVI.sp & Serializable & scala.Serializable)j -> {
            block0: {
                new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])matrix)).indices().foreach$mVc$sp((Function1)(JFunction1.mcVI.sp & Serializable & scala.Serializable)i -> {
                    res$1[j$1] = res[j] + matrix[i][j];
                });
                Object object = this.$(this.poolingStrategy());
                String string = "AVERAGE";
                if (object != null ? !object.equals(string) : string != null) break block0;
                res$1[j] = res[j] / (float)matrix.length;
            }
        });
        return var2_2;
    }

    @Override
    public Seq<Annotation> annotate(Seq<Annotation> annotations) {
        Seq<Sentence> sentences = SentenceSplit$.MODULE$.unpack(annotations);
        Seq<WordpieceEmbeddingsSentence> embeddingsSentences = WordpieceEmbeddingsSentence$.MODULE$.unpack(annotations);
        return (Seq)sentences.map((Function1 & Serializable & scala.Serializable)sentence -> {
            Seq embeddings2 = (Seq)embeddingsSentences.filter((Function1 & Serializable & scala.Serializable)embeddings -> BoxesRunTime.boxToBoolean((boolean)SentenceEmbeddings.$anonfun$annotate$2(sentence, embeddings)));
            float[] sentenceEmbeddings = (float[])((TraversableOnce)embeddings2.flatMap((Function1 & Serializable & scala.Serializable)tokenEmbedding -> new ArrayOps.ofFloat(SentenceEmbeddings.$anonfun$annotate$3(this, tokenEmbedding)), Seq$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.Float());
            return new Annotation(this.outputAnnotatorType(), sentence.start(), sentence.end(), sentence.content(), (Map<String, String>)((Map)Predef$.MODULE$.Map().apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Tuple2[]{Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"sentence"), (Object)((Object)BoxesRunTime.boxToInteger((int)sentence.index())).toString()), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"token"), (Object)sentence.content()), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"pieceId"), (Object)"-1"), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"isWordStart"), (Object)"true")}))), sentenceEmbeddings);
        }, Seq$.MODULE$.canBuildFrom());
    }

    @Override
    public Dataset<?> beforeAnnotate(Dataset<?> dataset) {
        String ref = HasStorageRef$.MODULE$.getStorageRefFromInput(dataset, (String[])this.$((Param)this.inputCols()), AnnotatorType$.MODULE$.WORD_EMBEDDINGS());
        Object object = this.get(this.storageRef()).isEmpty() ? this.setStorageRef(ref) : BoxedUnit.UNIT;
        return dataset;
    }

    @Override
    public Dataset<Row> afterAnnotate(Dataset<Row> dataset) {
        return dataset.withColumn(this.getOutputCol(), this.wrapSentenceEmbeddingsMetadata(dataset.col(this.getOutputCol()), BoxesRunTime.unboxToInt((Object)this.$((Param)this.dimension())), (Option<String>)new Some(this.$(this.storageRef()))));
    }

    public static final /* synthetic */ boolean $anonfun$annotate$2(Sentence sentence$1, WordpieceEmbeddingsSentence embeddings) {
        return embeddings.sentenceId() == sentence$1.index();
    }

    public static final /* synthetic */ float[] $anonfun$annotate$3(SentenceEmbeddings $this, WordpieceEmbeddingsSentence tokenEmbedding) {
        float[][] allEmbeddings = (float[][])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])tokenEmbedding.tokens())).map((Function1 & Serializable & scala.Serializable)token -> token.embeddings(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Float.TYPE))));
        return Predef$.MODULE$.floatArrayOps($this.calculateSentenceEmbeddings(allEmbeddings));
    }

    public SentenceEmbeddings(String uid) {
        this.uid = uid;
        HasSimpleAnnotate.$init$(this);
        HasEmbeddingsProperties.$init$(this);
        HasStorageRef.$init$(this);
        this.outputAnnotatorType = AnnotatorType$.MODULE$.SENTENCE_EMBEDDINGS();
        this.inputAnnotatorTypes = (String[])((Object[])new String[]{AnnotatorType$.MODULE$.DOCUMENT(), AnnotatorType$.MODULE$.WORD_EMBEDDINGS()});
        this.dimension = new IntParam((Identifiable)this, "dimension", "Number of embedding dimensions");
        this.poolingStrategy = new Param((Identifiable)this, "poolingStrategy", "Choose how you would like to aggregate Word Embeddings to Sentence Embeddings: AVERAGE or SUM");
        this.setDefault((Seq)Predef$.MODULE$.wrapRefArray((Object[])new ParamPair[]{this.inputCols().$minus$greater((Object)new String[]{AnnotatorType$.MODULE$.DOCUMENT(), AnnotatorType$.MODULE$.WORD_EMBEDDINGS()}), this.outputCol().$minus$greater((Object)"sentence_embeddings"), this.poolingStrategy().$minus$greater((Object)"AVERAGE"), this.dimension().$minus$greater((Object)BoxesRunTime.boxToInteger((int)100))}));
    }

    public SentenceEmbeddings() {
        this(Identifiable$.MODULE$.randomUID("SENTENCE_EMBEDDINGS"));
    }
}

