/*
 * Decompiled with CFR 0.152.
 */
package com.johnsnowlabs.nlp.embeddings;

import com.johnsnowlabs.nlp.Annotation;
import com.johnsnowlabs.nlp.Annotation$;
import com.johnsnowlabs.nlp.AnnotatorModel;
import com.johnsnowlabs.nlp.AnnotatorType$;
import com.johnsnowlabs.nlp.HasSimpleAnnotate;
import com.johnsnowlabs.nlp.annotators.common.TokenPieceEmbeddings;
import com.johnsnowlabs.nlp.annotators.common.WordpieceEmbeddingsSentence;
import com.johnsnowlabs.nlp.annotators.common.WordpieceEmbeddingsSentence$;
import com.johnsnowlabs.nlp.embeddings.ChunkEmbeddings$;
import java.io.Serializable;
import org.apache.spark.ml.param.BooleanParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamPair;
import org.apache.spark.ml.util.Identifiable;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.types.StructField;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Option$;
import scala.Predef;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.collection.Iterable;
import scala.collection.Map;
import scala.collection.Map$;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;
import scala.runtime.java8.JFunction1;

@ScalaSignature(bytes="\u0006\u0001\u0005ed\u0001\u0002\r\u001a\u0001\tB\u0001\u0002\f\u0001\u0003\u0006\u0004%\t%\f\u0005\tw\u0001\u0011\t\u0011)A\u0005]!)A\b\u0001C\u0001{!9q\b\u0001b\u0001\n\u0003\u0002\u0005BB$\u0001A\u0003%\u0011\tC\u0004I\u0001\t\u0007I\u0011I%\t\r9\u0003\u0001\u0015!\u0003K\u0011\u001dy\u0005A1A\u0005\u0002ACaa\u0018\u0001!\u0002\u0013\t\u0006b\u00021\u0001\u0005\u0004%\t!\u0019\u0005\u0007K\u0002\u0001\u000b\u0011\u00022\t\u000b\u0019\u0004A\u0011A4\t\u000b)\u0004A\u0011A6\t\u000bE\u0004A\u0011A\u0017\t\u000bI\u0004A\u0011A:\t\u000bq\u0002A\u0011\u0001;\t\u000bU\u0004A\u0011\u0002<\t\u000by\u0004A\u0011I@\t\u000f\u0005u\u0001\u0001\"\u0015\u0002 \u001d9\u0011QI\r\t\u0002\u0005\u001dcA\u0002\r\u001a\u0011\u0003\tI\u0005\u0003\u0004=+\u0011\u0005\u00111\r\u0005\n\u0003K*\u0012\u0011!C\u0005\u0003O\u0012qb\u00115v].,UNY3eI&twm\u001d\u0006\u00035m\t!\"Z7cK\u0012$\u0017N\\4t\u0015\taR$A\u0002oYBT!AH\u0010\u0002\u0019)|\u0007N\\:o_^d\u0017MY:\u000b\u0003\u0001\n1aY8n\u0007\u0001\u00192\u0001A\u0012*!\r!SeJ\u0007\u00027%\u0011ae\u0007\u0002\u000f\u0003:tw\u000e^1u_Jlu\u000eZ3m!\tA\u0003!D\u0001\u001a!\r!#fJ\u0005\u0003Wm\u0011\u0011\u0003S1t'&l\u0007\u000f\\3B]:|G/\u0019;f\u0003\r)\u0018\u000eZ\u000b\u0002]A\u0011q\u0006\u000f\b\u0003aY\u0002\"!\r\u001b\u000e\u0003IR!aM\u0011\u0002\rq\u0012xn\u001c;?\u0015\u0005)\u0014!B:dC2\f\u0017BA\u001c5\u0003\u0019\u0001&/\u001a3fM&\u0011\u0011H\u000f\u0002\u0007'R\u0014\u0018N\\4\u000b\u0005]\"\u0014\u0001B;jI\u0002\na\u0001P5oSRtDCA\u0014?\u0011\u0015a3\u00011\u0001/\u0003MyW\u000f\u001e9vi\u0006sgn\u001c;bi>\u0014H+\u001f9f+\u0005\t\u0005C\u0001\"D\u001b\u0005\u0001\u0011B\u0001#F\u00055\teN\\8uCR|'\u000fV=qK&\u0011ai\u0007\u0002\u0017\u0011\u0006\u001cx*\u001e;qkR\feN\\8uCR|'\u000fV=qK\u0006!r.\u001e;qkR\feN\\8uCR|'\u000fV=qK\u0002\n1#\u001b8qkR\feN\\8uCR|'\u000fV=qKN,\u0012A\u0013\t\u0004\u00172\u000bU\"\u0001\u001b\n\u00055#$!B!se\u0006L\u0018\u0001F5oaV$\u0018I\u001c8pi\u0006$xN\u001d+za\u0016\u001c\b%A\bq_>d\u0017N\\4TiJ\fG/Z4z+\u0005\t\u0006c\u0001*^]5\t1K\u0003\u0002U+\u0006)\u0001/\u0019:b[*\u0011akV\u0001\u0003[2T!\u0001W-\u0002\u000bM\u0004\u0018M]6\u000b\u0005i[\u0016AB1qC\u000eDWMC\u0001]\u0003\ry'oZ\u0005\u0003=N\u0013Q\u0001U1sC6\f\u0001\u0003]8pY&twm\u0015;sCR,w-\u001f\u0011\u0002\u000fM\\\u0017\u000e](P-V\t!\r\u0005\u0002SG&\u0011Am\u0015\u0002\r\u0005>|G.Z1o!\u0006\u0014\u0018-\\\u0001\tg.L\u0007oT(WA\u0005\u00112/\u001a;Q_>d\u0017N\\4TiJ\fG/Z4z)\t\u0011\u0005\u000eC\u0003j\u0019\u0001\u0007a&\u0001\u0005tiJ\fG/Z4z\u0003)\u0019X\r^*lSB|uJ\u0016\u000b\u0003\u00052DQ!\\\u0007A\u00029\fQA^1mk\u0016\u0004\"aS8\n\u0005A$$a\u0002\"p_2,\u0017M\\\u0001\u0013O\u0016$\bk\\8mS:<7\u000b\u001e:bi\u0016<\u00170\u0001\u0006hKR\u001c6.\u001b9P\u001fZ+\u0012A\u001c\u000b\u0002O\u0005A2-\u00197dk2\fG/Z\"ik:\\W)\u001c2fI\u0012LgnZ:\u0015\u0005]\\\bcA&MqB\u00111*_\u0005\u0003uR\u0012QA\u00127pCRDQ\u0001`\tA\u0002u\fa!\\1ue&D\bcA&Mo\u0006A\u0011M\u001c8pi\u0006$X\r\u0006\u0003\u0002\u0002\u0005e\u0001CBA\u0002\u0003\u001b\t\u0019B\u0004\u0003\u0002\u0006\u0005%abA\u0019\u0002\b%\tQ'C\u0002\u0002\fQ\nq\u0001]1dW\u0006<W-\u0003\u0003\u0002\u0010\u0005E!aA*fc*\u0019\u00111\u0002\u001b\u0011\u0007\u0011\n)\"C\u0002\u0002\u0018m\u0011!\"\u00118o_R\fG/[8o\u0011\u001d\tYB\u0005a\u0001\u0003\u0003\t1\"\u00198o_R\fG/[8og\u0006i\u0011M\u001a;fe\u0006sgn\u001c;bi\u0016$B!!\t\u0002BA!\u00111EA\u001e\u001d\u0011\t)#a\u000e\u000f\t\u0005\u001d\u00121\u0007\b\u0005\u0003S\t\tD\u0004\u0003\u0002,\u0005=bbA\u0019\u0002.%\tA,\u0003\u0002[7&\u0011\u0001,W\u0005\u0004\u0003k9\u0016aA:rY&!\u00111BA\u001d\u0015\r\t)dV\u0005\u0005\u0003{\tyDA\u0005ECR\fgI]1nK*!\u00111BA\u001d\u0011\u001d\t\u0019e\u0005a\u0001\u0003C\tq\u0001Z1uCN,G/A\bDQVt7.R7cK\u0012$\u0017N\\4t!\tAScE\u0004\u0016\u0003\u0017\n\t&!\u0018\u0011\u0007-\u000bi%C\u0002\u0002PQ\u0012a!\u00118z%\u00164\u0007#BA*\u00033:SBAA+\u0015\r\t9&V\u0001\u0005kRLG.\u0003\u0003\u0002\\\u0005U#!\u0006#fM\u0006,H\u000e\u001e)be\u0006l7OU3bI\u0006\u0014G.\u001a\t\u0004\u0017\u0006}\u0013bAA1i\ta1+\u001a:jC2L'0\u00192mKR\u0011\u0011qI\u0001\fe\u0016\fGMU3t_24X\r\u0006\u0002\u0002jA!\u00111NA;\u001b\t\tiG\u0003\u0003\u0002p\u0005E\u0014\u0001\u00027b]\u001eT!!a\u001d\u0002\t)\fg/Y\u0005\u0005\u0003o\niG\u0001\u0004PE*,7\r\u001e")
public class ChunkEmbeddings
extends AnnotatorModel<ChunkEmbeddings>
implements HasSimpleAnnotate<ChunkEmbeddings> {
    private final String uid;
    private final String outputAnnotatorType;
    private final String[] inputAnnotatorTypes;
    private final Param<String> poolingStrategy;
    private final BooleanParam skipOOV;

    public static MLReader<ChunkEmbeddings> read() {
        return ChunkEmbeddings$.MODULE$.read();
    }

    public static Object load(String string) {
        return ChunkEmbeddings$.MODULE$.load(string);
    }

    @Override
    public UserDefinedFunction dfAnnotate() {
        return HasSimpleAnnotate.dfAnnotate$(this);
    }

    public String uid() {
        return this.uid;
    }

    @Override
    public String outputAnnotatorType() {
        return this.outputAnnotatorType;
    }

    @Override
    public String[] inputAnnotatorTypes() {
        return this.inputAnnotatorTypes;
    }

    public Param<String> poolingStrategy() {
        return this.poolingStrategy;
    }

    public BooleanParam skipOOV() {
        return this.skipOOV;
    }

    public ChunkEmbeddings setPoolingStrategy(String strategy) {
        ChunkEmbeddings chunkEmbeddings;
        String string = strategy.toLowerCase();
        if ("average".equals(string)) {
            chunkEmbeddings = (ChunkEmbeddings)this.set(this.poolingStrategy(), "AVERAGE");
        } else if ("sum".equals(string)) {
            chunkEmbeddings = (ChunkEmbeddings)this.set(this.poolingStrategy(), "SUM");
        } else {
            throw new MatchError((Object)"poolingStrategy must be either AVERAGE or SUM");
        }
        return chunkEmbeddings;
    }

    public ChunkEmbeddings setSkipOOV(boolean value) {
        return (ChunkEmbeddings)this.set((Param)this.skipOOV(), BoxesRunTime.boxToBoolean((boolean)value));
    }

    public String getPoolingStrategy() {
        return (String)this.$(this.poolingStrategy());
    }

    public boolean getSkipOOV() {
        return BoxesRunTime.unboxToBoolean((Object)this.$((Param)this.skipOOV()));
    }

    /*
     * WARNING - void declaration
     */
    private float[] calculateChunkEmbeddings(float[][] matrix) {
        void var2_2;
        float[] res = (float[])Array$.MODULE$.ofDim(matrix[0].length, ClassTag$.MODULE$.Float());
        new ArrayOps.ofFloat(Predef$.MODULE$.floatArrayOps(matrix[0])).indices().foreach$mVc$sp((Function1)(JFunction1.mcVI.sp & Serializable & scala.Serializable)j -> {
            block0: {
                new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])matrix)).indices().foreach$mVc$sp((Function1)(JFunction1.mcVI.sp & Serializable & scala.Serializable)i -> {
                    res$1[j$1] = res[j] + matrix[i][j];
                });
                Object object = this.$(this.poolingStrategy());
                String string = "AVERAGE";
                if (object != null ? !object.equals(string) : string != null) break block0;
                res$1[j] = res[j] / (float)matrix.length;
            }
        });
        return var2_2;
    }

    @Override
    public Seq<Annotation> annotate(Seq<Annotation> annotations) {
        Seq documentsWithChunks = (Seq)annotations.filter((Function1 & Serializable & scala.Serializable)token -> BoxesRunTime.boxToBoolean((boolean)ChunkEmbeddings.$anonfun$annotate$1(token)));
        Seq<WordpieceEmbeddingsSentence> embeddingsSentences = WordpieceEmbeddingsSentence$.MODULE$.unpack(annotations);
        return (Seq)documentsWithChunks.flatMap((Function1 & Serializable & scala.Serializable)chunk -> {
            Iterable iterable;
            int sentenceIdx = new StringOps(Predef$.MODULE$.augmentString((String)chunk.metadata().getOrElse((Object)"sentence", (Function0 & Serializable & scala.Serializable)() -> "0"))).toInt();
            int chunkIdx = new StringOps(Predef$.MODULE$.augmentString((String)chunk.metadata().getOrElse((Object)"chunk", (Function0 & Serializable & scala.Serializable)() -> "0"))).toInt();
            if (sentenceIdx < embeddingsSentences.length()) {
                float[][] finalEmbeddings;
                TokenPieceEmbeddings[] tokensWithEmbeddings = (TokenPieceEmbeddings[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])((WordpieceEmbeddingsSentence)embeddingsSentences.apply(sentenceIdx)).tokens())).filter((Function1 & Serializable & scala.Serializable)token -> BoxesRunTime.boxToBoolean((boolean)ChunkEmbeddings.$anonfun$annotate$5(chunk, token)));
                float[][] allEmbeddings = (float[][])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])tokensWithEmbeddings)).flatMap((Function1 & Serializable & scala.Serializable)tokenEmbedding -> !tokenEmbedding.isOOV() || !BoxesRunTime.unboxToBoolean((Object)this.$((Param)this.skipOOV())) ? Option$.MODULE$.option2Iterable((Option)new Some((Object)tokenEmbedding.embeddings())) : Option$.MODULE$.option2Iterable((Option)None$.MODULE$), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Float.TYPE))));
                float[][] fArray = finalEmbeddings = allEmbeddings.length > 0 ? allEmbeddings : (float[][])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])tokensWithEmbeddings)).map((Function1 & Serializable & scala.Serializable)x$1 -> x$1.embeddings(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Float.TYPE))));
                iterable = new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])finalEmbeddings)).isEmpty() ? Option$.MODULE$.option2Iterable((Option)None$.MODULE$) : Option$.MODULE$.option2Iterable((Option)new Some((Object)new Annotation(this.outputAnnotatorType(), chunk.begin(), chunk.end(), chunk.result(), (Map<String, String>)((Map)Map$.MODULE$.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Tuple2[]{Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"sentence"), (Object)((Object)BoxesRunTime.boxToInteger((int)sentenceIdx)).toString()), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"chunk"), (Object)((Object)BoxesRunTime.boxToInteger((int)chunkIdx)).toString()), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"token"), (Object)chunk.result()), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"pieceId"), (Object)"-1"), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"isWordStart"), (Object)"true")}))), this.calculateChunkEmbeddings(finalEmbeddings))));
            } else {
                iterable = Option$.MODULE$.option2Iterable((Option)None$.MODULE$);
            }
            return iterable;
        }, Seq$.MODULE$.canBuildFrom());
    }

    @Override
    public Dataset<Row> afterAnnotate(Dataset<Row> dataset) {
        StructField embeddingsCol = Annotation$.MODULE$.getColumnByType(dataset, (String[])this.$((Param)this.inputCols()), AnnotatorType$.MODULE$.WORD_EMBEDDINGS());
        return dataset.withColumn(this.getOutputCol(), dataset.col(this.getOutputCol()).as(this.getOutputCol(), embeddingsCol.metadata()));
    }

    public static final /* synthetic */ boolean $anonfun$annotate$1(Annotation token) {
        String string = token.annotatorType();
        String string2 = AnnotatorType$.MODULE$.CHUNK();
        return !(string != null ? !string.equals(string2) : string2 != null);
    }

    public static final /* synthetic */ boolean $anonfun$annotate$5(Annotation chunk$1, TokenPieceEmbeddings token) {
        return token.begin() >= chunk$1.begin() && token.end() <= chunk$1.end();
    }

    public ChunkEmbeddings(String uid) {
        this.uid = uid;
        HasSimpleAnnotate.$init$(this);
        this.outputAnnotatorType = AnnotatorType$.MODULE$.WORD_EMBEDDINGS();
        this.inputAnnotatorTypes = (String[])((Object[])new String[]{AnnotatorType$.MODULE$.CHUNK(), AnnotatorType$.MODULE$.WORD_EMBEDDINGS()});
        this.poolingStrategy = new Param((Identifiable)this, "poolingStrategy", "Choose how you would like to aggregate Word Embeddings to Chunk Embeddings: AVERAGE or SUM");
        this.skipOOV = new BooleanParam((Identifiable)this, "skipOOV", "Whether to discard default vectors for OOV words from the aggregation / pooling");
        this.setDefault((Seq)Predef$.MODULE$.wrapRefArray((Object[])new ParamPair[]{this.inputCols().$minus$greater((Object)new String[]{AnnotatorType$.MODULE$.CHUNK(), AnnotatorType$.MODULE$.WORD_EMBEDDINGS()}), this.outputCol().$minus$greater((Object)"chunk_embeddings"), this.poolingStrategy().$minus$greater((Object)"AVERAGE"), this.skipOOV().$minus$greater((Object)BoxesRunTime.boxToBoolean((boolean)true))}));
    }

    public ChunkEmbeddings() {
        this(Identifiable$.MODULE$.randomUID("CHUNK_EMBEDDINGS"));
    }
}

