/*
 * Decompiled with CFR 0.152.
 */
package com.johnsnowlabs.ml.tensorflow;

import com.johnsnowlabs.ml.tensorflow.ClassifierDatasetEncoder;
import com.johnsnowlabs.ml.tensorflow.Logging;
import com.johnsnowlabs.ml.tensorflow.TensorResources;
import com.johnsnowlabs.ml.tensorflow.TensorResources$;
import com.johnsnowlabs.ml.tensorflow.TensorflowWrapper;
import com.johnsnowlabs.nlp.Annotation;
import com.johnsnowlabs.nlp.Annotation$;
import com.johnsnowlabs.nlp.AnnotatorType$;
import com.johnsnowlabs.nlp.util.io.OutputHelper$;
import java.io.Serializable;
import java.util.List;
import org.apache.spark.ml.util.Identifiable$;
import org.slf4j.Logger;
import org.tensorflow.Tensor;
import scala.Array$;
import scala.Enumeration;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.GenIterable;
import scala.collection.GenTraversableOnce;
import scala.collection.IterableLike;
import scala.collection.Map;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.immutable.MapLike;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.ArrayOps;
import scala.math.Ordering;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.FloatRef;
import scala.runtime.IntRef;
import scala.runtime.RichInt$;
import scala.runtime.ScalaRunTime$;
import scala.runtime.java8.JFunction0;
import scala.runtime.java8.JFunction1;
import scala.util.Random$;

@ScalaSignature(bytes="\u0006\u0001\t\u0005c\u0001\u0002\u001c8\u0001\u0001C\u0001\u0002\u000f\u0001\u0003\u0006\u0004%\tA\u0014\u0005\t%\u0002\u0011\t\u0011)A\u0005\u001f\"A1\u000b\u0001BC\u0002\u0013\u0005A\u000b\u0003\u0005Y\u0001\t\u0005\t\u0015!\u0003V\u0011!I\u0006A!b\u0001\n\u0003R\u0006\u0002\u00036\u0001\u0005\u0003\u0005\u000b\u0011B.\t\u000b-\u0004A\u0011\u00017\t\u000fE\u0004!\u0019!C\u0005e\"11\u0010\u0001Q\u0001\nMDq\u0001 \u0001C\u0002\u0013%!\u000f\u0003\u0004~\u0001\u0001\u0006Ia\u001d\u0005\b}\u0002\u0011\r\u0011\"\u0003s\u0011\u0019y\b\u0001)A\u0005g\"A\u0011\u0011\u0001\u0001C\u0002\u0013%!\u000fC\u0004\u0002\u0004\u0001\u0001\u000b\u0011B:\t\u0013\u0005\u0015\u0001A1A\u0005\n\u0005\u001d\u0001\u0002CA\b\u0001\u0001\u0006I!!\u0003\t\u0013\u0005E\u0001A1A\u0005\n\u0005M\u0001\u0002CA\u0015\u0001\u0001\u0006I!!\u0006\t\u0013\u0005-\u0002A1A\u0005\n\u0005M\u0001\u0002CA\u0017\u0001\u0001\u0006I!!\u0006\t\u0013\u0005=\u0002A1A\u0005\n\u0005M\u0001\u0002CA\u0019\u0001\u0001\u0006I!!\u0006\t\u0013\u0005M\u0002A1A\u0005\n\u0005M\u0001\u0002CA\u001b\u0001\u0001\u0006I!!\u0006\t\u0013\u0005]\u0002A1A\u0005\n\u0005M\u0001\u0002CA\u001d\u0001\u0001\u0006I!!\u0006\t\u0013\u0005m\u0002A1A\u0005\n\u0005M\u0001\u0002CA\u001f\u0001\u0001\u0006I!!\u0006\t\u0013\u0005}\u0002A1A\u0005\n\u0005M\u0001\u0002CA!\u0001\u0001\u0006I!!\u0006\t\u0013\u0005\r\u0003A1A\u0005\n\u0005M\u0001\u0002CA#\u0001\u0001\u0006I!!\u0006\t\u0011\u0005\u001d\u0003A1A\u0005\nIDq!!\u0013\u0001A\u0003%1\u000fC\u0004\u0002L\u0001!\t!!\u0014\t\u000f\u0005\r\u0004\u0001\"\u0001\u0002f!I\u0011\u0011\u0018\u0001\u0012\u0002\u0013\u0005\u00111\u0018\u0005\n\u0003#\u0004\u0011\u0013!C\u0001\u0003'D\u0011\"a6\u0001#\u0003%\t!a5\t\u0013\u0005e\u0007!%A\u0005\u0002\u0005M\u0007\"CAn\u0001E\u0005I\u0011AAo\u0011%\t\t\u000fAI\u0001\n\u0003\tY\fC\u0005\u0002d\u0002\t\n\u0011\"\u0001\u0002f\"I\u0011\u0011\u001e\u0001\u0012\u0002\u0013\u0005\u0011Q\u001d\u0005\n\u0003W\u0004\u0011\u0013!C\u0001\u0003[Dq!!=
\u0001\t\u0003\t\u0019\u0010C\u0005\u0003\"\u0001\t\n\u0011\"\u0001\u0002<\"I!1\u0005\u0001\u0012\u0002\u0013\u0005\u0011Q\u001c\u0005\b\u0005K\u0001A\u0011\u0001B\u0014\u0011%\u0011y\u0003AI\u0001\n\u0003\ti\u000eC\u0004\u00032\u0001!\tAa\r\t\u0013\t}\u0002!%A\u0005\u0002\u0005M'!\u0007+f]N|'O\u001a7po6+H\u000e^5DY\u0006\u001c8/\u001b4jKJT!\u0001O\u001d\u0002\u0015Q,gn]8sM2|wO\u0003\u0002;w\u0005\u0011Q\u000e\u001c\u0006\u0003yu\nAB[8i]Ntwn\u001e7bENT\u0011AP\u0001\u0004G>l7\u0001A\n\u0005\u0001\u0005;%\n\u0005\u0002C\u000b6\t1IC\u0001E\u0003\u0015\u00198-\u00197b\u0013\t15I\u0001\u0004B]f\u0014VM\u001a\t\u0003\u0005\"K!!S\"\u0003\u0019M+'/[1mSj\f'\r\\3\u0011\u0005-cU\"A\u001c\n\u00055;$a\u0002'pO\u001eLgnZ\u000b\u0002\u001fB\u00111\nU\u0005\u0003#^\u0012\u0011\u0003V3og>\u0014h\r\\8x/J\f\u0007\u000f]3s\u0003-!XM\\:pe\u001adwn\u001e\u0011\u0002\u000f\u0015t7m\u001c3feV\tQ\u000b\u0005\u0002L-&\u0011qk\u000e\u0002\u0019\u00072\f7o]5gS\u0016\u0014H)\u0019;bg\u0016$XI\\2pI\u0016\u0014\u0018\u0001C3oG>$WM\u001d\u0011\u0002\u0019Y,'OY8tK2+g/\u001a7\u0016\u0003m\u0003\"\u0001\u00184\u000f\u0005u#W\"\u00010\u000b\u0005}\u0003\u0017a\u00018fe*\u0011\u0011MY\u0001\u000bC:tw\u000e^1u_J\u001c(BA2<\u0003\rqG\u000e]\u0005\u0003Kz\u000bqAV3sE>\u001cX-\u0003\u0002hQ\n)a+\u00197vK&\u0011\u0011n\u0011\u0002\f\u000b:,X.\u001a:bi&|g.A\u0007wKJ\u0014wn]3MKZ,G\u000eI\u0001\u0007y%t\u0017\u000e\u001e 
\u0015\t5tw\u000e\u001d\t\u0003\u0017\u0002AQ\u0001O\u0004A\u0002=CQaU\u0004A\u0002UCQ!W\u0004A\u0002m\u000b\u0001\"\u001b8qkR\\U-_\u000b\u0002gB\u0011A/_\u0007\u0002k*\u0011ao^\u0001\u0005Y\u0006twMC\u0001y\u0003\u0011Q\u0017M^1\n\u0005i,(AB*ue&tw-A\u0005j]B,HoS3zA\u0005AA.\u00192fY.+\u00170A\u0005mC\n,GnS3zA\u0005\t2/Z9vK:\u001cW\rT3oORD7*Z=\u0002%M,\u0017/^3oG\u0016dUM\\4uQ.+\u0017\u0010I\u0001\u0010Y\u0016\f'O\\5oOJ\u000bG/Z&fs\u0006\u0001B.Z1s]&twMU1uK.+\u0017\u0010I\u0001\u000b]Vl7\t\\1tg\u0016\u001cXCAA\u0005!\r\u0011\u00151B\u0005\u0004\u0003\u001b\u0019%aA%oi\u0006Ya.^7DY\u0006\u001c8/Z:!\u00035\u0001(/\u001a3jGRLwN\\&fsV\u0011\u0011Q\u0003\t\u0005\u0003/\t)C\u0004\u0003\u0002\u001a\u0005\u0005\u0002cAA\u000e\u00076\u0011\u0011Q\u0004\u0006\u0004\u0003?y\u0014A\u0002\u001fs_>$h(C\u0002\u0002$\r\u000ba\u0001\u0015:fI\u00164\u0017b\u0001>\u0002()\u0019\u00111E\"\u0002\u001dA\u0014X\rZ5di&|gnS3zA\u0005aq\u000e\u001d;j[&TXM]&fs\u0006iq\u000e\u001d;j[&TXM]&fs\u0002\nq\u0001\\8tg.+\u00170\u0001\u0005m_N\u001c8*Z=!\u0003-\t7mY;sC\u000eL8*Z=\u0002\u0019\u0005\u001c7-\u001e:bGf\\U-\u001f\u0011\u0002\u00135,GO]5dg\u001a\u000b\u0014AC7fiJL7m\u001d$2A\u0005iQ.\u001a;sS\u000e\u001c\u0018iY2LKf\fa\"\\3ue&\u001c7/Q2d\u0017\u0016L\b%\u0001\bnKR\u0014\u0018nY:M_N\u001c8*Z=\u0002\u001f5,GO]5dg2{7o]&fs\u0002\nQ\"\\3ue&\u001c7\u000f\u0016)S\u0017\u0016L\u0018AD7fiJL7m\u001d+Q%.+\u0017\u0010I\u0001\bS:LGoS3z\u0003!Ig.\u001b;LKf\u0004\u0013\u0001\u0006:fg\"\f\u0007/Z%oaV$h)Z1ukJ,7\u000f\u0006\u0003\u0002P\u0005}\u0003#\u0002\"\u0002R\u0005U\u0013bAA*\u0007\n)\u0011I\u001d:bsB)!)!\u0015\u0002XA)!)!\u0015\u0002ZA\u0019!)a\u0017\n\u0007\u0005u3IA\u0003GY>\fG\u000fC\u0004\u0002b\u0011\u0002\r!a\u0014\u0002\u000b\t\fGo\u00195\u0002\u000bQ\u0014\u0018-\u001b8\u00159\u0005\u001d\u0014QNA9\u0003s\ni(!!\u0002\u0006\u0006%\u0015QRAP\u0003G\u000bi+!-\u00026B\u0019!)!\u001b\n\u0007\u0005-4I\u0001\u0003V]&$\bbBA8K\u0001\u0007\u0011qJ\u0001\u0007S:\u0004X\u000f^:\t\u000f\u0005MT\u00051\u0001\
u0002v\u00051A.\u00192fYN\u0004RAQA)\u0003o\u0002RAQA)\u0003+Aq!a\u001f&\u0001\u0004\tI!\u0001\u0005dY\u0006\u001c8OT;n\u0011%\ty(\nI\u0001\u0002\u0004\tI&\u0001\u0002me\"I\u00111Q\u0013\u0011\u0002\u0003\u0007\u0011\u0011B\u0001\nE\u0006$8\r[*ju\u0016D\u0011\"a\"&!\u0003\u0005\r!!\u0003\u0002\u0015M$\u0018M\u001d;Fa>\u001c\u0007\u000eC\u0005\u0002\f\u0016\u0002\n\u00111\u0001\u0002\n\u0005AQM\u001c3Fa>\u001c\u0007\u000eC\u0005\u0002\u0010\u0016\u0002\n\u00111\u0001\u0002\u0012\u0006\u00012m\u001c8gS\u001e\u0004&o\u001c;p\u0005f$Xm\u001d\t\u0006\u0005\u0006M\u0015qS\u0005\u0004\u0003+\u001b%AB(qi&|g\u000eE\u0003C\u0003#\nI\nE\u0002C\u00037K1!!(D\u0005\u0011\u0011\u0015\u0010^3\t\u0013\u0005\u0005V\u0005%AA\u0002\u0005e\u0013a\u0004<bY&$\u0017\r^5p]N\u0003H.\u001b;\t\u0013\u0005\u0015V\u0005%AA\u0002\u0005\u001d\u0016\u0001D:ik\u001a4G.Z#q_\u000eD\u0007c\u0001\"\u0002*&\u0019\u00111V\"\u0003\u000f\t{w\u000e\\3b]\"I\u0011qV\u0013\u0011\u0002\u0003\u0007\u0011qU\u0001\u0011K:\f'\r\\3PkR\u0004X\u000f\u001e'pONDq!a-&\u0001\u0004\t)\"\u0001\bpkR\u0004X\u000f\u001e'pON\u0004\u0016\r\u001e5\t\u0013\u0005]V\u0005%AA\u0002\u0005U\u0011\u0001B;vS\u0012\fq\u0002\u001e:bS:$C-\u001a4bk2$H\u0005N\u000b\u0003\u0003{SC!!\u0017\u0002@.\u0012\u0011\u0011\u0019\t\u0005\u0003\u0007\fi-\u0004\u0002\u0002F*!\u0011qYAe\u0003%)hn\u00195fG.,GMC\u0002\u0002L\u000e\u000b!\"\u00198o_R\fG/[8o\u0013\u0011\ty-!2\u0003#Ut7\r[3dW\u0016$g+\u0019:jC:\u001cW-A\bue\u0006Lg\u000e\n3fM\u0006,H\u000e\u001e\u00136+\t\t)N\u000b\u0003\u0002\n\u0005}\u0016a\u0004;sC&tG\u0005Z3gCVdG\u000f\n\u001c\u0002\u001fQ\u0014\u0018-\u001b8%I\u00164\u0017-\u001e7uI]\nq\u0002\u001e:bS:$C-\u001a4bk2$H\u0005O\u000b\u0003\u0003?TC!!%\u0002@\u0006yAO]1j]\u0012\"WMZ1vYR$\u0013(\u0001\tue\u0006Lg\u000e\n3fM\u0006,H\u000e\u001e\u00132aU\u0011\u0011q\u001d\u0016\u0005\u0003O\u000by,\u0001\tue\u0006Lg\u000e\n3fM\u0006,H\u000e\u001e\u00132c\u0005\u0001BO]1j]\u0012\"WMZ1vYR$\u0013gM\u000b\u0003\u0003_TC!!\u0006\u0002@\u00069\u0001O]3e
S\u000e$H\u0003CA{\u0005\u001f\u0011YBa\b\u0011\r\u0005](\u0011\u0001B\u0004\u001d\u0011\tI0!@\u000f\t\u0005m\u00111`\u0005\u0002\t&\u0019\u0011q`\"\u0002\u000fA\f7m[1hK&!!1\u0001B\u0003\u0005\r\u0019V-\u001d\u0006\u0004\u0003\u007f\u001c\u0005\u0003\u0002B\u0005\u0005\u0017i\u0011AY\u0005\u0004\u0005\u001b\u0011'AC!o]>$\u0018\r^5p]\"9!\u0011C\u0018A\u0002\tM\u0011\u0001\u00023pGN\u0004b!a>\u0003\u0002\tU\u0001c\u0002\"\u0003\u0018\u0005%\u0011Q_\u0005\u0004\u00053\u0019%A\u0002+va2,'\u0007C\u0005\u0003\u001e=\u0002\n\u00111\u0001\u0002Z\u0005IA\u000f\u001b:fg\"|G\u000e\u001a\u0005\n\u0003\u001f{\u0003\u0013!a\u0001\u0003#\u000b\u0011\u0003\u001d:fI&\u001cG\u000f\n3fM\u0006,H\u000e\u001e\u00133\u0003E\u0001(/\u001a3jGR$C-\u001a4bk2$HeM\u0001\u0010S:$XM\u001d8bYB\u0013X\rZ5diRA\u0011q\u000bB\u0015\u0005W\u0011i\u0003C\u0004\u0002pI\u0002\r!a\u0014\t\u000f\u0005M$\u00071\u0001\u0002V!I\u0011q\u0012\u001a\u0011\u0002\u0003\u0007\u0011\u0011S\u0001\u001aS:$XM\u001d8bYB\u0013X\rZ5di\u0012\"WMZ1vYR$3'A\u0004nK\u0006\u001cXO]3\u0015\r\u0005]#Q\u0007B\u001f\u0011\u001d\u00119\u0004\u000ea\u0001\u0005s\tq\u0001\\1cK2,G\rE\u0003C\u0003#\u0012Y\u0004E\u0004C\u0005/\t)&a\u0016\t\u0013\u0005\rE\u0007%AA\u0002\u0005%\u0011!E7fCN,(/\u001a\u0013eK\u001a\fW\u000f\u001c;%e\u0001")
public class TensorflowMultiClassifier
implements scala.Serializable,
Logging {
    // Wrapper around the TensorFlow graph; sessions come from createSession()/getSession().
    private final TensorflowWrapper tensorflow;
    // Encodes labels / sentence embeddings into model arrays and decodes model outputs.
    private final ClassifierDatasetEncoder encoder;
    // Verbosity exposed to the Logging trait through verboseLevel().
    private final Enumeration.Value verboseLevel;
    // Names of graph nodes used with feed()/fetch()/addTarget(). Their values are assigned
    // outside this view (presumably in the constructor) — TODO confirm.
    private final String inputKey;
    private final String labelKey;
    private final String sequenceLengthKey;
    private final String learningRateKey;
    // Width of each prediction row: predict() groups the flat prediction output by this value.
    private final int numClasses;
    private final String predictionKey;
    private final String optimizerKey;
    private final String lossKey;
    private final String accuracyKey;
    private final String metricsF1;
    private final String metricsAccKey;
    private final String metricsLossKey;
    private final String metricsTPRKey;
    private final String initKey;
    // Assigned through the Logging trait setter (com$johnsnowlabs$ml$tensorflow$Logging$_setter_$logger_$eq).
    // It must not be final, because javac rejects writes to a final field outside a constructor/initialiser.
    private Logger logger;

    /**
     * Returns the log name computed by the {@link Logging} trait's default implementation.
     */
    @Override
    public String getLogName() {
        final String logName = Logging.getLogName$(this);
        return logName;
    }

    /**
     * Delegates to the {@link Logging} trait's {@code log} implementation.
     *
     * @param value lazily evaluated message
     * @param minLevel minimum verbosity level passed through to the trait
     */
    @Override
    public void log(Function0<String> value, Enumeration.Value minLevel) {
        final TensorflowMultiClassifier self = this;
        Logging.log$(self, value, minLevel);
    }

    /**
     * Delegates to the {@link Logging} trait's {@code outputLog} implementation.
     *
     * @param value lazily evaluated message
     * @param uuid run identifier forwarded to the trait
     * @param shouldLog flag forwarded to the trait
     * @param outputLogsPath destination path forwarded to the trait
     */
    @Override
    public void outputLog(Function0<String> value, String uuid, boolean shouldLog, String outputLogsPath) {
        final TensorflowMultiClassifier self = this;
        Logging.outputLog$(self, value, uuid, shouldLog, outputLogsPath);
    }

    /**
     * Returns the SLF4J logger injected through the Logging trait setter.
     */
    @Override
    public Logger logger() {
        return logger;
    }

    /**
     * Setter generated for the {@code logger} val of the {@link Logging} trait. It stores
     * {@code x$1} in the {@code logger} field; presumably the trait initialiser calls it — TODO confirm.
     * NOTE(review): javac rejects this assignment while the {@code logger} field is declared final.
     */
    @Override
    public void com$johnsnowlabs$ml$tensorflow$Logging$_setter_$logger_$eq(Logger x$1) {
        this.logger = x$1;
    }

    /** Returns the graph wrapper used to create or obtain TensorFlow sessions. */
    public TensorflowWrapper tensorflow() {
        return tensorflow;
    }

    /** Returns the encoder that converts labels/embeddings to arrays and decodes predictions. */
    public ClassifierDatasetEncoder encoder() {
        return encoder;
    }

    /** Returns the verbosity level exposed to the {@link Logging} trait. */
    @Override
    public Enumeration.Value verboseLevel() {
        return verboseLevel;
    }

    /** Name of the graph node fed with the padded input features. */
    private String inputKey() {
        return inputKey;
    }

    /** Name of the graph node fed with the encoded multi-label targets. */
    private String labelKey() {
        return labelKey;
    }

    /** Name of the graph node fed with each sentence's unpadded length. */
    private String sequenceLengthKey() {
        return sequenceLengthKey;
    }

    /** Name of the graph node fed with the per-epoch learning rate. */
    private String learningRateKey() {
        return learningRateKey;
    }

    /** Number of scores per example in the flat prediction output. */
    private int numClasses() {
        return numClasses;
    }

    /** Name of the graph node fetched for prediction scores. */
    private String predictionKey() {
        return predictionKey;
    }

    /** Name of the optimizer op run as a target on every training step. */
    private String optimizerKey() {
        return optimizerKey;
    }

    /** Name of the graph node fetched for the training loss. */
    private String lossKey() {
        return lossKey;
    }

    /** Name of the graph node fetched for the training accuracy. */
    private String accuracyKey() {
        return accuracyKey;
    }

    /** Name of the graph node fetched for the evaluation F1 score. */
    private String metricsF1() {
        return metricsF1;
    }

    /** Name of the graph node fetched for the evaluation accuracy. */
    private String metricsAccKey() {
        return metricsAccKey;
    }

    /** Name of the graph node fetched for the evaluation loss. */
    private String metricsLossKey() {
        return metricsLossKey;
    }

    /** Name of the graph node fetched for the evaluation true-positive rate. */
    private String metricsTPRKey() {
        return metricsTPRKey;
    }

    /** Name of the initialisation target run before the first training epoch. */
    private String initKey() {
        return initKey;
    }

    /**
     * Pads every sentence of a batch with all-zero rows so that all sentences share the length
     * of the longest one.
     *
     * <p>Existing rows are reused by reference. Each padding row is a freshly allocated
     * {@code float[dimension]}, where {@code dimension} is the row width of the first non-empty
     * sentence. An empty batch yields an empty array. A batch whose sentences are all empty
     * yields empty sentences rather than throwing.
     *
     * @param batch sentences of feature rows; presumably (sentence, token, embedding) — TODO confirm
     * @return a new array of {@code batch.length} sentences, each as long as the longest input sentence
     */
    public float[][][] reshapeInputFeatures(float[][][] batch) {
        int maxSentenceLength = 0;
        int dimension = 0;
        boolean dimensionKnown = false;
        for (float[][] sentence : batch) {
            maxSentenceLength = Math.max(maxSentenceLength, sentence.length);
            // Take the row width from the first sentence that actually has rows, so a leading
            // empty sentence does not make the whole batch fail.
            if (!dimensionKnown && sentence.length > 0) {
                dimension = sentence[0].length;
                dimensionKnown = true;
            }
        }
        float[][][] reshaped = new float[batch.length][][];
        for (int i = 0; i < batch.length; i++) {
            float[][] sentence = batch[i];
            float[][] padded = new float[maxSentenceLength][];
            // sentence.length <= maxSentenceLength by construction, so this never truncates.
            System.arraycopy(sentence, 0, padded, 0, sentence.length);
            for (int row = sentence.length; row < maxSentenceLength; row++) {
                padded[row] = new float[dimension];
            }
            reshaped[i] = padded;
        }
        return reshaped;
    }

    /**
     * Trains the multi-label classifier graph for epochs {@code startEpoch} (inclusive) to
     * {@code endEpoch} (exclusive).
     *
     * <p>Labels are encoded with the encoder's {@code encodeTagsMultiLabel}, zipped with the inputs
     * and shuffled once. When {@code validationSplit > 0}, the first
     * {@code (int) (n * validationSplit)} shuffled examples are held out for validation and the rest
     * are trained on. Each epoch runs one optimizer step per chunk of {@code batchSize} examples.
     * It then reports the mean batch loss/accuracy (plus validation metrics when enabled) on stdout
     * and through {@code outputLog}. Epoch numbers are 1-based in both outputs.
     *
     * @param inputs per-example feature rows; presumably (example, token, dimension) — TODO confirm
     * @param labels label names per example
     * @param classNum number of classes (only reported in the start-of-training message)
     * @param lr initial learning rate, decayed per epoch as {@code lr / (1 + 0.2 * epoch)}
     * @param batchSize examples per optimizer step
     * @param startEpoch first epoch index; 0 runs the graph's init target first
     * @param endEpoch exclusive upper bound on the epoch index
     * @param configProtoBytes optional session config bytes passed to the session
     * @param validationSplit fraction of examples held out for validation; 0 disables validation
     * @param shuffleEpoch whether to reshuffle the training examples at the start of every epoch
     * @param enableOutputLogs whether {@code outputLog} writes, and whether the log file is exported to S3 at the end
     * @param outputLogsPath destination forwarded to {@code outputLog}
     * @param uuid run identifier forwarded to {@code outputLog}
     */
    public void train(float[][][] inputs, String[][] labels, int classNum, float lr, int batchSize, int startEpoch, int endEpoch, Option<byte[]> configProtoBytes, float validationSplit, boolean shuffleEpoch, boolean enableOutputLogs, String outputLogsPath, String uuid) {
        if (startEpoch == 0) {
            // A fresh run must execute the graph's initialisation target before the first step.
            this.tensorflow().createSession(configProtoBytes).runner().addTarget(this.initKey()).run();
        }
        float[][] encodedLabels = this.encoder().encodeTagsMultiLabel(labels);
        Seq zippedInputsLabels = new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])inputs)).zip((GenIterable)Predef$.MODULE$.wrapRefArray((Object[])encodedLabels), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class))))).toSeq();
        Seq trainingDataset = (Seq)Random$.MODULE$.shuffle((TraversableOnce)zippedInputsLabels, Seq$.MODULE$.canBuildFrom());

        final Tuple2[] trainDatasetSeq;
        final Tuple2[] validateDatasetSample;
        if (validationSplit > 0.0f) {
            int sample = (int)((float)trainingDataset.length() * validationSplit);
            Tuple2 split = trainingDataset.splitAt(sample);
            validateDatasetSample = (Tuple2[])((Seq)split._1()).toArray(ClassTag$.MODULE$.apply(Tuple2.class));
            trainDatasetSeq = (Tuple2[])((Seq)split._2()).toArray(ClassTag$.MODULE$.apply(Tuple2.class));
        } else {
            trainDatasetSeq = (Tuple2[])trainingDataset.toArray(ClassTag$.MODULE$.apply(Tuple2.class));
            // Never read: validation metrics are only computed when validationSplit > 0.
            validateDatasetSample = new Tuple2[0];
        }

        String startMessage = new StringBuilder(94).append("Training started - epochs: ").append(endEpoch).append(" - learning_rate: ").append(lr).append(" - batch_size: ").append(batchSize).append(" - training_examples: ").append(trainDatasetSeq.length).append(" - classes: ").append(classNum).toString();
        Predef$.MODULE$.println((Object)startMessage);
        this.outputLog((Function0<String>)(Function0 & Serializable & scala.Serializable)() -> startMessage, uuid, enableOutputLogs, outputLogsPath);

        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(startEpoch), endEpoch).foreach$mVc$sp((Function1)(JFunction1.mcVI.sp & Serializable & scala.Serializable)epoch -> {
            long time = System.nanoTime();
            IntRef batches = IntRef.create(0);
            FloatRef loss = FloatRef.create(0.0f);
            FloatRef acc = FloatRef.create(0.0f);
            // Inverse-time decay of the learning rate.
            double learningRate = (double)lr / (1.0 + 0.2 * (double)epoch);
            Tuple2[] shuffledBatch = shuffleEpoch ? (Tuple2[])Random$.MODULE$.shuffle((TraversableOnce)new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])trainDatasetSeq)).toSeq(), Seq$.MODULE$.canBuildFrom()).toArray(ClassTag$.MODULE$.apply(Tuple2.class)) : trainDatasetSeq;
            new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])shuffledBatch)).grouped(batchSize).foreach((Function1 & Serializable & scala.Serializable)batch -> {
                TensorflowMultiClassifier.$anonfun$train$3(this, learningRate, configProtoBytes, loss, acc, batches, (Tuple2[])batch);
                return BoxedUnit.UNIT;
            });
            // Average over the batches actually executed. Dividing by trainSize / batchSize would
            // truncate (ignoring a trailing partial batch) and is 0 when trainSize < batchSize.
            if (batches.elem > 0) {
                loss.elem /= (float)batches.elem;
                acc.elem /= (float)batches.elem;
            }
            float[] validation = validationSplit > 0.0f ? this.measure(validateDatasetSample, this.measure$default$2()) : null;
            // Elapsed time includes the validation pass when it runs.
            double endTime = (double)(System.nanoTime() - time) / 1.0E9;
            final String summary;
            if (validation != null) {
                summary = new StringOps("Epoch %s/%s - %.2fs - loss: %s - acc: %s - val_loss: %s - val_acc: %s - val_f1: %s - val_tpr: %s - batches: %s").format((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger(epoch + 1), BoxesRunTime.boxToInteger(endEpoch), BoxesRunTime.boxToDouble(endTime), BoxesRunTime.boxToFloat(loss.elem), BoxesRunTime.boxToFloat(acc.elem), BoxesRunTime.boxToFloat(validation[0]), BoxesRunTime.boxToFloat(validation[1]), BoxesRunTime.boxToFloat(validation[2]), BoxesRunTime.boxToFloat(validation[3]), BoxesRunTime.boxToInteger(batches.elem)}));
            } else {
                summary = new StringOps("Epoch %s/%s - %.2fs - loss: %s - batches: %s").format((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger(epoch + 1), BoxesRunTime.boxToInteger(endEpoch), BoxesRunTime.boxToDouble(endTime), BoxesRunTime.boxToFloat(loss.elem), BoxesRunTime.boxToInteger(batches.elem)}));
            }
            // The same 1-based summary goes to stdout and to outputLog, so the two stay consistent.
            Predef$.MODULE$.println((Object)summary);
            this.outputLog((Function0<String>)(Function0 & Serializable & scala.Serializable)() -> summary, uuid, enableOutputLogs, outputLogsPath);
        });
        if (enableOutputLogs) {
            OutputHelper$.MODULE$.exportLogFileToS3();
        }
    }

    /** Default for {@code train}'s 4th parameter, {@code lr}. */
    public float train$default$4() {
        final float defaultLearningRate = 0.005f;
        return defaultLearningRate;
    }

    /** Default for {@code train}'s 5th parameter, {@code batchSize}. */
    public int train$default$5() {
        final int defaultBatchSize = 64;
        return defaultBatchSize;
    }

    /** Default for {@code train}'s 6th parameter, {@code startEpoch}. */
    public int train$default$6() {
        final int defaultStartEpoch = 0;
        return defaultStartEpoch;
    }

    /** Default for {@code train}'s 7th parameter, {@code endEpoch}. */
    public int train$default$7() {
        final int defaultEndEpoch = 10;
        return defaultEndEpoch;
    }

    /** Default for {@code train}'s 8th parameter, {@code configProtoBytes}: no session config. */
    public Option<byte[]> train$default$8() {
        return None$.MODULE$;
    }

    /** Default for {@code train}'s 9th parameter, {@code validationSplit}: validation disabled. */
    public float train$default$9() {
        final float noValidation = 0.0f;
        return noValidation;
    }

    /** Default for {@code train}'s 10th parameter, {@code shuffleEpoch}. */
    public boolean train$default$10() {
        final boolean shuffleEachEpoch = false;
        return shuffleEachEpoch;
    }

    /** Default for {@code train}'s 11th parameter, {@code enableOutputLogs}. */
    public boolean train$default$11() {
        final boolean writeOutputLogs = false;
        return writeOutputLogs;
    }

    /** Default for {@code train}'s 13th parameter, {@code uuid}: a fresh random UID. */
    public String train$default$13() {
        final String generatedUid = Identifiable$.MODULE$.randomUID("multiclassifierdl");
        return generatedUid;
    }

    /**
     * Predicts multi-label categories for the given documents.
     *
     * <p>Sentence embeddings are extracted by the encoder and padded with
     * {@link #reshapeInputFeatures}. They are fed together with their unpadded lengths, and the
     * prediction node is fetched. The flat scores are grouped into rows of {@code numClasses()} and
     * decoded into (label, score) pairs. Every label whose score is {@code >= threshold} becomes a
     * CATEGORY annotation, built by a synthetic helper defined elsewhere in this class.
     *
     * @param docs (id, annotations) pairs; the annotations supply the sentence embeddings
     * @param threshold minimum score for a label to be emitted
     * @param configProtoBytes optional session config bytes
     * @return the emitted category annotations
     */
    public Seq<Annotation> predict(Seq<Tuple2<Object, Seq<Annotation>>> docs, float threshold, Option<byte[]> configProtoBytes) {
        TensorResources tensors = new TensorResources();
        Tuple2<String, Object>[][] tagsName;
        try {
            float[][][] inputs = this.encoder().extractSentenceEmbeddingsMultiLabelPredict(docs);
            int[] sequenceLengthArrays = (int[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])inputs)).map((Function1 & Serializable & scala.Serializable)x -> BoxesRunTime.boxToInteger((int)TensorflowMultiClassifier.$anonfun$predict$1(x)), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Int()));
            float[][][] inputsReshaped = this.reshapeInputFeatures(inputs);
            List calculated = this.tensorflow().getSession(configProtoBytes).runner().feed(this.inputKey(), tensors.createTensor(inputsReshaped)).feed(this.sequenceLengthKey(), tensors.createTensor(sequenceLengthArrays)).fetch(this.predictionKey()).run();
            float[][] tagsId = (float[][])new ArrayOps.ofFloat(Predef$.MODULE$.floatArrayOps(TensorResources$.MODULE$.extractFloats((Tensor)calculated.get(0), TensorResources$.MODULE$.extractFloats$default$2()))).grouped(this.numClasses()).toArray(ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Float.TYPE)));
            tagsName = this.encoder().decodeOutputData(tagsId);
        } finally {
            // Release the input tensors even when the session run or decoding throws.
            tensors.clearTensors();
        }
        return (Seq)new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])tagsName)).flatMap((Function1 & Serializable & scala.Serializable)score -> new ArrayOps.ofRef(TensorflowMultiClassifier.$anonfun$predict$2(threshold, docs, score)), Array$.MODULE$.fallbackCanBuildFrom(Predef.DummyImplicit$.MODULE$.dummyImplicit()));
    }

    /** Default for {@code predict}'s 2nd parameter, {@code threshold}. */
    public float predict$default$2() {
        final float defaultThreshold = 0.5f;
        return defaultThreshold;
    }

    /** Default for {@code predict}'s 3rd parameter, {@code configProtoBytes}: no session config. */
    public Option<byte[]> predict$default$3() {
        return None$.MODULE$;
    }

    /**
     * Evaluates one batch against known labels and returns the graph's metric outputs.
     *
     * <p>Feeds the padded inputs, the labels and the unpadded sequence lengths. It then fetches the
     * loss, accuracy, F1 and true-positive-rate metric nodes and returns the first element of each.
     *
     * @param inputs per-example feature rows; presumably (example, token, dimension) — TODO confirm
     * @param labels encoded label vectors, one per example
     * @param configProtoBytes optional session config bytes
     * @return {@code {loss, accuracy, f1, tpr}}
     */
    public float[] internalPredict(float[][][] inputs, float[][] labels, Option<byte[]> configProtoBytes) {
        TensorResources tensors = new TensorResources();
        try {
            int[] sequenceLengthArrays = (int[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])inputs)).map((Function1 & Serializable & scala.Serializable)x -> BoxesRunTime.boxToInteger((int)TensorflowMultiClassifier.$anonfun$internalPredict$1(x)), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Int()));
            float[][][] inputsReshaped = this.reshapeInputFeatures(inputs);
            List calculated = this.tensorflow().getSession(configProtoBytes).runner().feed(this.inputKey(), tensors.createTensor(inputsReshaped)).feed(this.labelKey(), tensors.createTensor(labels)).feed(this.sequenceLengthKey(), tensors.createTensor(sequenceLengthArrays)).fetch(this.metricsLossKey()).fetch(this.metricsAccKey()).fetch(this.metricsF1()).fetch(this.metricsTPRKey()).run();
            // Fetch order above fixes the output indices: 0 loss, 1 accuracy, 2 F1, 3 TPR.
            float valLoss = TensorResources$.MODULE$.extractFloats((Tensor)calculated.get(0), TensorResources$.MODULE$.extractFloats$default$2())[0];
            float valAcc = TensorResources$.MODULE$.extractFloats((Tensor)calculated.get(1), TensorResources$.MODULE$.extractFloats$default$2())[0];
            float valF1 = TensorResources$.MODULE$.extractFloats((Tensor)calculated.get(2), TensorResources$.MODULE$.extractFloats$default$2())[0];
            float valTPR = TensorResources$.MODULE$.extractFloats((Tensor)calculated.get(3), TensorResources$.MODULE$.extractFloats$default$2())[0];
            return new float[]{valLoss, valAcc, valF1, valTPR};
        } finally {
            // Release the input tensors even when the session run or extraction throws.
            tensors.clearTensors();
        }
    }

    /** Default for {@code internalPredict}'s 3rd parameter, {@code configProtoBytes}: no session config. */
    public Option<byte[]> internalPredict$default$3() {
        return None$.MODULE$;
    }

    /**
     * Evaluates labeled examples in chunks of {@code batchSize} and averages the per-chunk metrics.
     *
     * <p>Each chunk goes to a synthetic helper defined elsewhere in this class, which presumably
     * adds that chunk's metrics to the running sums — TODO confirm. The sums are divided by the
     * number of chunks, {@code ceil(labeled.length / batchSize)}. An empty input therefore yields
     * {@code NaN} for every metric ({@code 0f / 0f}).
     *
     * @param labeled (features, encoded labels) pairs
     * @param batchSize examples per evaluation chunk; must be positive
     * @return {@code {loss, accuracy, f1, tpr}} averaged over chunks
     */
    public float[] measure(Tuple2<float[][], float[]>[] labeled, int batchSize) {
        FloatRef lossSum = FloatRef.create(0.0f);
        FloatRef accSum = FloatRef.create(0.0f);
        FloatRef f1Sum = FloatRef.create(0.0f);
        FloatRef tprSum = FloatRef.create(0.0f);
        new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])labeled)).grouped(batchSize).foreach((Function1 & Serializable & scala.Serializable)batch -> {
            TensorflowMultiClassifier.$anonfun$measure$1(this, lossSum, accSum, f1Sum, tprSum, batch);
            return BoxedUnit.UNIT;
        });
        // grouped(n) produces ceil(len / n) chunks; compute that directly instead of regrouping.
        int chunkCount = labeled.length / batchSize + (labeled.length % batchSize == 0 ? 0 : 1);
        float divisor = (float)chunkCount;
        return new float[]{lossSum.elem / divisor, accSum.elem / divisor, f1Sum.elem / divisor, tprSum.elem / divisor};
    }

    /** Default for {@code measure}'s 2nd parameter, {@code batchSize}. */
    public int measure$default$2() {
        final int defaultEvalBatchSize = 100;
        return defaultEvalBatchSize;
    }

    /** Lambda body: the number of rows (tokens) in one sentence. */
    public static final /* synthetic */ int $anonfun$reshapeInputFeatures$1(float[][] x) {
        final int rowCount = x.length;
        return rowCount;
    }

    /** Lambda body: the number of feature rows in the features half of a (features, labels) pair. */
    public static final /* synthetic */ int $anonfun$train$4(Tuple2 x) {
        final float[][] features = (float[][])x._1();
        return features.length;
    }

    /**
     * Lambda body for one training step on {@code batch}.
     *
     * <p>Feeds the padded inputs, labels, unpadded sequence lengths and the learning rate, and runs
     * the optimizer target. It adds the fetched loss (output 1) and accuracy (output 2) to the
     * shared accumulators and increments {@code batches$1}.
     * NOTE(review): the prediction output (index 0) is fetched but never read.
     *
     * @param $this the classifier whose graph keys and session are used
     * @param learningRate$1 learning rate for this epoch (narrowed to float when fed)
     * @param configProtoBytes$1 optional session config bytes
     * @param loss$1 running loss sum for the epoch
     * @param acc$1 running accuracy sum for the epoch
     * @param batches$1 number of steps run so far in the epoch
     * @param batch (features, encoded labels) pairs for this step
     */
    public static final /* synthetic */ void $anonfun$train$3(TensorflowMultiClassifier $this, double learningRate$1, Option configProtoBytes$1, FloatRef loss$1, FloatRef acc$1, IntRef batches$1, Tuple2[] batch) {
        TensorResources tensors = new TensorResources();
        try {
            int[] sequenceLengthArrays = (int[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])batch)).map((Function1 & Serializable & scala.Serializable)x -> BoxesRunTime.boxToInteger((int)TensorflowMultiClassifier.$anonfun$train$4(x)), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Int()));
            float[][][] inputArrays = $this.reshapeInputFeatures((float[][][])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])batch)).map((Function1 & Serializable & scala.Serializable)x -> (float[][])x._1(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(ScalaRunTime$.MODULE$.arrayClass(Float.TYPE))))));
            float[][] labelsArray = (float[][])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])batch)).map((Function1 & Serializable & scala.Serializable)x -> (float[])x._2(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Float.TYPE))));
            Tensor inputTensor = tensors.createTensor(inputArrays);
            Tensor labelTensor = tensors.createTensor(labelsArray);
            Tensor sequenceLengthTensor = tensors.createTensor(sequenceLengthArrays);
            Tensor lrTensor = tensors.createTensor(BoxesRunTime.boxToFloat((float)((float)learningRate$1)));
            List calculated = $this.tensorflow().getSession((Option<byte[]>)configProtoBytes$1).runner().feed($this.inputKey(), inputTensor).feed($this.labelKey(), labelTensor).feed($this.sequenceLengthKey(), sequenceLengthTensor).feed($this.learningRateKey(), lrTensor).fetch($this.predictionKey()).fetch($this.lossKey()).fetch($this.accuracyKey()).addTarget($this.optimizerKey()).run();
            loss$1.elem += TensorResources$.MODULE$.extractFloats((Tensor)calculated.get(1), TensorResources$.MODULE$.extractFloats$default$2())[0];
            acc$1.elem += TensorResources$.MODULE$.extractFloats((Tensor)calculated.get(2), TensorResources$.MODULE$.extractFloats$default$2())[0];
            ++batches$1.elem;
        } finally {
            // Release this batch's tensors even if the step fails; otherwise each failed step leaks them.
            tensors.clearTensors();
        }
    }

    /**
     * Returns the outer length (number of rows) of a single feature matrix.
     *
     * @param x feature matrix
     * @return {@code x.length}
     */
    public static final /* synthetic */ int $anonfun$predict$1(float[][] x) {
        final int rows = x.length;
        return rows;
    }

    /**
     * Decides whether a (label, score) pair passes the prediction threshold.
     *
     * @param threshold$1 minimum score (inclusive) for a label to be kept
     * @param x           pair whose {@code _2()} is a boxed float score
     * @return {@code true} when the score is at least {@code threshold$1}; a NaN score never passes
     */
    public static final /* synthetic */ boolean $anonfun$predict$3(float threshold$1, Tuple2 x) {
        float score = BoxesRunTime.unboxToFloat((Object)x._2());
        return threshold$1 <= score;
    }

    /**
     * Turns the per-label scores of one document into CATEGORY annotations.
     *
     * <p>Labels are kept when their score (the boxed float in {@code _2()}) is at least
     * {@code threshold$1}; see {@code $anonfun$predict$3}. One {@link Annotation} is produced
     * per kept label, with these fields:
     * <ul>
     *   <li>annotator type: {@code AnnotatorType.CATEGORY()};</li>
     *   <li>begin: the {@code begin()} of the first annotation in the first entry of {@code docs$1};</li>
     *   <li>end: the {@code end()} of the last annotation in the last entry of {@code docs$1};</li>
     *   <li>result: the label string;</li>
     *   <li>metadata: {@code "sentence" -> "0"} plus EVERY (label -> score.toString()) pair from
     *       {@code score}, including labels that did not pass the threshold.</li>
     * </ul>
     *
     * <p>NOTE(review): {@code head()}/{@code last()} are called on {@code docs$1} and on its first/last
     * entry's {@code _2()} collection. An empty document presumably throws here; callers are
     * assumed to pass non-empty input. TODO confirm.
     *
     * @param threshold$1 inclusive score cutoff for emitting a label
     * @param docs$1      sequence of pairs whose {@code _2()} is a collection of {@link Annotation}
     *                    (only the first and last entries are read)
     * @param score       (label, boxed float score) pairs for this document
     * @return the annotation array (typed {@code Object[]} by the decompiler)
     */
    public static final /* synthetic */ Object[] $anonfun$predict$2(float threshold$1, Seq docs$1, Tuple2[] score) {
        // Keep only labels whose score passes the threshold, then project to the label string.
        String[] labels = (String[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])score)).filter((Function1 & Serializable & scala.Serializable)x -> BoxesRunTime.boxToBoolean((boolean)TensorflowMultiClassifier.$anonfun$predict$3(threshold$1, x))))).map((Function1 & Serializable & scala.Serializable)x -> (String)x._1(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(String.class)));
        // The emitted annotation spans from the first annotation of the first entry ...
        int documentBegin = ((Annotation)((IterableLike)((Tuple2)docs$1.head())._2()).head()).begin();
        // ... to the last annotation of the last entry.
        int documentEnd = ((Annotation)((TraversableLike)((Tuple2)docs$1.last())._2()).last()).end();
        // One CATEGORY annotation per kept label; the metadata carries the full score map, not only kept labels.
        return Predef$.MODULE$.refArrayOps((Object[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])labels)).map((Function1 & Serializable & scala.Serializable)label -> new Annotation(AnnotatorType$.MODULE$.CATEGORY(), documentBegin, documentEnd, (String)label, (Map<String, String>)((MapLike)Predef$.MODULE$.Map().apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Tuple2[]{Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"sentence"), (Object)"0")}))).$plus$plus((GenTraversableOnce)new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])score)).flatMap((Function1 & Serializable & scala.Serializable)x -> (scala.collection.immutable.Map)Predef$.MODULE$.Map().apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Tuple2[]{Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(x._1()), (Object)x._2().toString())})), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)))))), Annotation$.MODULE$.apply$default$6()), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Annotation.class))));
    }

    /**
     * Returns the outer length (number of rows) of a single feature matrix.
     *
     * @param x feature matrix
     * @return {@code x.length}
     */
    public static final /* synthetic */ int $anonfun$internalPredict$1(float[][] x) {
        final int rowCount = x.length;
        return rowCount;
    }

    /**
     * Evaluates one batch with {@code internalPredict} and adds its metrics to the running totals.
     *
     * <p>The metrics array returned by {@code internalPredict} is read positionally:
     * index 0 is added to {@code loss$2}, 1 to {@code acc$2}, 2 to {@code f1$1}, and 3 to {@code tpr$1}.
     *
     * @param $this  classifier used to run the evaluation
     * @param loss$2 loss accumulator
     * @param acc$2  accuracy accumulator
     * @param f1$1   F1 accumulator
     * @param tpr$1  TPR accumulator
     * @param batch  pairs whose {@code _1()} is a {@code float[][]} and whose {@code _2()} is a {@code float[]}
     */
    public static final /* synthetic */ void $anonfun$measure$1(TensorflowMultiClassifier $this, FloatRef loss$2, FloatRef acc$2, FloatRef f1$1, FloatRef tpr$1, Tuple2[] batch) {
        int size = batch.length;
        // Extract all features first, then all labels, matching the original evaluation order.
        float[][][] embeddings = new float[size][][];
        for (int i = 0; i < size; i++) {
            embeddings[i] = (float[][])batch[i]._1();
        }
        float[][] labels = new float[size][];
        for (int i = 0; i < size; i++) {
            labels[i] = (float[])batch[i]._2();
        }
        float[] metrics = $this.internalPredict(embeddings, labels, $this.internalPredict$default$3());
        loss$2.elem += metrics[0];
        acc$2.elem += metrics[1];
        f1$1.elem += metrics[2];
        tpr$1.elem += metrics[3];
    }

    /**
     * Creates a multi-label classifier bound to a TensorFlow graph.
     *
     * <p>Graph node names that depend on the label count are suffixed with {@code numClasses},
     * which is taken from the length of {@code encoder.params().tags()}.
     *
     * @param tensorflow   wrapper providing the session
     * @param encoder      dataset encoder; its tag list determines {@code numClasses}
     * @param verboseLevel logging verbosity
     */
    public TensorflowMultiClassifier(TensorflowWrapper tensorflow, ClassifierDatasetEncoder encoder, Enumeration.Value verboseLevel) {
        this.tensorflow = tensorflow;
        this.encoder = encoder;
        this.verboseLevel = verboseLevel;
        Logging.$init$(this);
        // Fixed graph node names.
        this.inputKey = "inputs:0";
        this.labelKey = "labels:0";
        this.sequenceLengthKey = "sequence_length:0";
        this.learningRateKey = "lr:0";
        this.initKey = "init_all_tables";
        // Node names scoped by the number of classes.
        this.numClasses = encoder.params().tags().length;
        this.predictionKey = "sigmoid_output_" + this.numClasses() + "/Sigmoid:0";
        this.optimizerKey = "optimizer_adam_" + this.numClasses() + "/Adam";
        this.lossKey = "loss_" + this.numClasses() + "/bce_loss/weighted_loss/value:0";
        this.accuracyKey = "accuracy_" + this.numClasses() + "/mean_accuracy:0";
        this.metricsF1 = "metrics_" + this.numClasses() + "/f1/f1_score:0";
        this.metricsAccKey = "metrics_" + this.numClasses() + "/accuracy/mean_accuracy:0";
        this.metricsLossKey = "metrics_" + this.numClasses() + "/loss/bce_loss/weighted_loss/value:0";
        this.metricsTPRKey = "metrics_" + this.numClasses() + "/f1/truediv_4:0";
    }
}

