/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.evaluation;

import org.apache.spark.SparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.evaluation.SquaredEuclideanSilhouette;
import org.apache.spark.ml.evaluation.SquaredEuclideanSilhouette$;
import org.apache.spark.ml.linalg.BLAS$;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.rdd.RDD$;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DoubleType$;
import scala.Function0;
import scala.Function1;
import scala.Function2;
import scala.Function3;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.MapLike;
import scala.collection.Seq;
import scala.collection.immutable.Map;
import scala.math.Ordering;
import scala.reflect.ClassTag$;
import scala.reflect.api.JavaUniverse;
import scala.reflect.api.Mirror;
import scala.reflect.api.TypeCreator;
import scala.reflect.api.TypeTags;
import scala.reflect.api.Types;
import scala.reflect.api.Universe;
import scala.reflect.runtime.package$;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;
import scala.runtime.RichDouble;
import scala.runtime.RichInt$;

/**
 * Scala companion-object singleton for {@code SquaredEuclideanSilhouette}, as decompiled by CFR.
 * It computes the mean Silhouette coefficient of a clustering using squared Euclidean distance.
 *
 * <p>Rather than comparing every pair of points, it first aggregates, per cluster, the sum of the
 * feature vectors, the sum of the squared L2 norms and the number of points
 * ({@code ClusterStats}). The mean squared distance from a point {@code x} to the points of a
 * cluster {@code C} is then
 * {@code ||x||^2 + sum_C(||y||^2) / |C| - 2 * (x . sum_C(y)) / |C|},
 * which needs only those aggregates plus the point's own squared norm.
 *
 * <p>NOTE: this is decompiler output of Scala bytecode. Several constructs are not valid Java
 * source as written: anonymous {@code Serializable} subclasses taking constructor arguments,
 * {@code public} local classes, and {@code MODULE$} being assigned in the constructor.
 */
public final class SquaredEuclideanSilhouette$ {
    /** The singleton instance; assigned by the private constructor. */
    public static final SquaredEuclideanSilhouette$ MODULE$;
    /**
     * Set after the first call to {@link #registerKryoClasses}. This is a plain (non-volatile)
     * field, so concurrent first calls may both perform the registration. That is presumably
     * harmless; confirm.
     */
    private boolean kryoRegistrationPerformed;

    static {
        // Constructing the module assigns MODULE$ (see the private constructor).
        new SquaredEuclideanSilhouette$();
    }

    /**
     * Registers {@code ClusterStats} for Kryo serialization, only on the first call.
     *
     * <p>NOTE(review): in Spark, {@code SparkContext.getConf()} returns a copy of the
     * configuration. If that holds for the deployed version, this registration does not reach the
     * running context's serializer. Confirm before relying on it.
     *
     * @param sc the active Spark context
     */
    public void registerKryoClasses(SparkContext sc) {
        if (!this.kryoRegistrationPerformed) {
            sc.getConf().registerKryoClasses((Class[])((Object[])new Class[]{SquaredEuclideanSilhouette.ClusterStats.class}));
            this.kryoRegistrationPerformed = true;
        }
    }

    /**
     * Aggregates the per-cluster statistics that the silhouette computation needs.
     *
     * <p>Rows are grouped by {@code predictionCol} (cast to double). For each group this method:
     * <ul>
     *   <li>sums the feature vectors into a dense vector,</li>
     *   <li>sums the {@code "squaredNorm"} column,</li>
     *   <li>counts the rows.</li>
     * </ul>
     * The result is then collected to the driver.
     *
     * <p>The vector length is taken from the first row's features. Every feature vector is
     * assumed to have that size. This presumably fails on an empty dataset because of
     * {@code first()}; confirm.
     *
     * @param df dataset holding the prediction and features columns plus a precomputed
     *     {@code "squaredNorm"} column (added by {@link #computeSilhouetteScore})
     * @param predictionCol name of the cluster-id column; cast to double
     * @param featuresCol name of the {@code Vector} features column
     * @return immutable map from cluster id (boxed {@code Double}) to its {@code ClusterStats}
     */
    public Map<Object, SquaredEuclideanSilhouette.ClusterStats> computeClusterStats(Dataset<Row> df, String predictionCol, String featuresCol) {
        int numFeatures = ((Vector)((Row)df.select((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Column[]{functions$.MODULE$.col(featuresCol)})).first()).getAs(0)).size();
        // Key each row by cluster id and fold (featureSum, squaredNormSum, numOfPoints) per key.
        RDD clustersStatsRDD = RDD$.MODULE$.rddToPairRDDFunctions(df.select((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Column[]{functions$.MODULE$.col(predictionCol).cast((DataType)DoubleType$.MODULE$), functions$.MODULE$.col(featuresCol), functions$.MODULE$.col("squaredNorm")})).rdd().map((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final Tuple2<Object, Tuple2<Vector, Object>> apply(Row row) {
                return new Tuple2((Object)BoxesRunTime.boxToDouble((double)row.getDouble(0)), (Object)new Tuple2(row.getAs(1), (Object)BoxesRunTime.boxToDouble((double)row.getDouble(2))));
            }
        }, ClassTag$.MODULE$.apply(Tuple2.class)), ClassTag$.MODULE$.Double(), ClassTag$.MODULE$.apply(Tuple2.class), (Ordering)Ordering.Double$.MODULE$).aggregateByKey((Object)new Tuple3((Object)Vectors$.MODULE$.zeros(numFeatures).toDense(), (Object)BoxesRunTime.boxToDouble((double)0.0), (Object)BoxesRunTime.boxToLong((long)0L)), (Function2)new Serializable(){
            public static final long serialVersionUID = 0L;

            // seqOp: add one (features, squaredNorm) record to a partition-local accumulator.
            public final Tuple3<DenseVector, Object, Object> apply(Tuple3<DenseVector, Object, Object> x0$1, Tuple2<Vector, Object> x1$1) {
                Tuple2 tuple2 = new Tuple2(x0$1, x1$1);
                if (tuple2 != null) {
                    Tuple3 tuple3 = (Tuple3)tuple2._1();
                    Tuple2 tuple22 = (Tuple2)tuple2._2();
                    if (tuple3 != null) {
                        DenseVector featureSum = (DenseVector)tuple3._1();
                        double squaredNormSum = BoxesRunTime.unboxToDouble((Object)tuple3._2());
                        long numOfPoints = BoxesRunTime.unboxToLong((Object)tuple3._3());
                        if (featureSum != null) {
                            DenseVector denseVector = featureSum;
                            double d = squaredNormSum;
                            long l = numOfPoints;
                            if (tuple22 != null) {
                                Vector features = (Vector)tuple22._1();
                                double squaredNorm = tuple22._2$mcD$sp();
                                // axpy updates denseVector in place (y += 1.0 * x), so the accumulator vector is reused.
                                BLAS$.MODULE$.axpy(1.0, features, (Vector)denseVector);
                                Tuple3 tuple32 = new Tuple3((Object)denseVector, (Object)BoxesRunTime.boxToDouble((double)(d + squaredNorm)), (Object)BoxesRunTime.boxToLong((long)(l + 1L)));
                                return tuple32;
                            }
                        }
                    }
                }
                throw new MatchError((Object)tuple2);
            }
        }, (Function2)new Serializable(){
            public static final long serialVersionUID = 0L;

            // combOp: merge the second accumulator into the first (feature sum updated in place).
            public final Tuple3<DenseVector, Object, Object> apply(Tuple3<DenseVector, Object, Object> x0$2, Tuple3<DenseVector, Object, Object> x1$2) {
                Tuple2 tuple2 = new Tuple2(x0$2, x1$2);
                if (tuple2 != null) {
                    Tuple3 tuple3 = (Tuple3)tuple2._1();
                    Tuple3 tuple32 = (Tuple3)tuple2._2();
                    if (tuple3 != null) {
                        DenseVector featureSum1 = (DenseVector)tuple3._1();
                        double squaredNormSum1 = BoxesRunTime.unboxToDouble((Object)tuple3._2());
                        long numOfPoints1 = BoxesRunTime.unboxToLong((Object)tuple3._3());
                        if (tuple32 != null) {
                            DenseVector featureSum2 = (DenseVector)tuple32._1();
                            double squaredNormSum2 = BoxesRunTime.unboxToDouble((Object)tuple32._2());
                            long numOfPoints2 = BoxesRunTime.unboxToLong((Object)tuple32._3());
                            BLAS$.MODULE$.axpy(1.0, (Vector)featureSum2, (Vector)featureSum1);
                            Tuple3 tuple33 = new Tuple3((Object)featureSum1, (Object)BoxesRunTime.boxToDouble((double)(squaredNormSum1 + squaredNormSum2)), (Object)BoxesRunTime.boxToLong((long)(numOfPoints1 + numOfPoints2)));
                            return tuple33;
                        }
                    }
                }
                throw new MatchError((Object)tuple2);
            }
        }, ClassTag$.MODULE$.apply(Tuple3.class));
        // Scala's mapValues returns a lazily evaluated view; the trailing toMap materializes it so
        // each ClusterStats is built once.
        return RDD$.MODULE$.rddToPairRDDFunctions(clustersStatsRDD, ClassTag$.MODULE$.Double(), ClassTag$.MODULE$.apply(Tuple3.class), (Ordering)Ordering.Double$.MODULE$).collectAsMap().mapValues((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final SquaredEuclideanSilhouette.ClusterStats apply(Tuple3<DenseVector, Object, Object> x0$3) {
                Tuple3<DenseVector, Object, Object> tuple3 = x0$3;
                if (tuple3 != null) {
                    DenseVector featureSum = (DenseVector)tuple3._1();
                    double squaredNormSum = BoxesRunTime.unboxToDouble((Object)tuple3._2());
                    long numOfPoints = BoxesRunTime.unboxToLong((Object)tuple3._3());
                    if (featureSum != null) {
                        DenseVector denseVector = featureSum;
                        double d = squaredNormSum;
                        long l = numOfPoints;
                        SquaredEuclideanSilhouette.ClusterStats clusterStats = new SquaredEuclideanSilhouette.ClusterStats((Vector)denseVector, d, l);
                        return clusterStats;
                    }
                }
                throw new MatchError(tuple3);
            }
        }).toMap(Predef$.MODULE$.$conforms());
    }

    /**
     * Computes the silhouette coefficient of a single point.
     *
     * <p>Let {@code a} be the mean squared distance to the other members of the point's own
     * cluster. Let {@code b} be the smallest mean squared distance to the members of any other
     * cluster. The result is:
     * <ul>
     *   <li>{@code 1 - a/b} when {@code a < b},</li>
     *   <li>{@code b/a - 1} when {@code a > b},</li>
     *   <li>0 when they are equal.</li>
     * </ul>
     * A point that is the only member of its cluster gets 0, as in the standard silhouette
     * definition.
     *
     * @param broadcastedClustersMap broadcast map from cluster id to {@code ClusterStats}; must
     *     contain {@code clusterId}
     * @param features the point's feature vector
     * @param clusterId id of the cluster the point was assigned to
     * @param squaredNorm the point's squared L2 norm
     * @return the point's silhouette coefficient
     */
    public double computeSilhouetteCoefficient(Broadcast<Map<Object, SquaredEuclideanSilhouette.ClusterStats>> broadcastedClustersMap, Vector features, double clusterId, double squaredNorm) {
        double d;
        SquaredEuclideanSilhouette.ClusterStats currentCluster = (SquaredEuclideanSilhouette.ClusterStats)((MapLike)broadcastedClustersMap.value()).apply((Object)BoxesRunTime.boxToDouble((double)clusterId));
        // A singleton point has no other members to be close to, so its silhouette is defined as
        // 0. Treating a as 0 instead would score every singleton 1 (whenever b > 0) and inflate
        // the overall mean.
        if (currentCluster.numOfPoints() == 1L) {
            return 0.0;
        }
        // b: smallest mean distance to any cluster other than the point's own.
        // If the map holds only clusterId, b stays at Double.MAX_VALUE. computeSilhouetteScore
        // rules this out by asserting at least two clusters.
        DoubleRef neighboringClusterDissimilarity = DoubleRef.create((double)Double.MAX_VALUE);
        ((scala.collection.immutable.MapLike)broadcastedClustersMap.value()).keySet().foreach((Function1)new Serializable(broadcastedClustersMap, features, clusterId, squaredNorm, neighboringClusterDissimilarity){
            public static final long serialVersionUID = 0L;
            private final Broadcast broadcastedClustersMap$1;
            private final Vector features$1;
            private final double clusterId$1;
            private final double squaredNorm$1;
            private final DoubleRef neighboringClusterDissimilarity$1;

            public final void apply(double c) {
                this.apply$mcVD$sp(c);
            }

            public void apply$mcVD$sp(double c) {
                double dissimilarity;
                if (c != this.clusterId$1 && (dissimilarity = SquaredEuclideanSilhouette$.MODULE$.org$apache$spark$ml$evaluation$SquaredEuclideanSilhouette$$compute$1(this.squaredNorm$1, this.features$1, (SquaredEuclideanSilhouette.ClusterStats)((MapLike)this.broadcastedClustersMap$1.value()).apply((Object)BoxesRunTime.boxToDouble((double)c)))) < this.neighboringClusterDissimilarity$1.elem) {
                    this.neighboringClusterDissimilarity$1.elem = dissimilarity;
                }
            }
            {
                this.broadcastedClustersMap$1 = broadcastedClustersMap$1;
                this.features$1 = features$1;
                this.clusterId$1 = clusterId$1;
                this.squaredNorm$1 = squaredNorm$1;
                this.neighboringClusterDissimilarity$1 = neighboringClusterDissimilarity$1;
            }
        });
        // a: the compute helper averages over all n members of the cluster, including this point
        // at distance 0. Scaling by n / (n - 1) turns that into the mean over the other members.
        double currentClusterDissimilarity = this.org$apache$spark$ml$evaluation$SquaredEuclideanSilhouette$$compute$1(squaredNorm, features, currentCluster) * (double)currentCluster.numOfPoints() / (double)(currentCluster.numOfPoints() - 1L);
        int n = RichInt$.MODULE$.signum$extension(Predef$.MODULE$.intWrapper(new RichDouble(Predef$.MODULE$.doubleWrapper(currentClusterDissimilarity)).compare((Object)BoxesRunTime.boxToDouble((double)neighboringClusterDissimilarity.elem))));
        switch (n) {
            default: {
                throw new MatchError((Object)BoxesRunTime.boxToInteger((int)n));
            }
            case 0: {
                d = 0.0;
                break;
            }
            case 1: {
                d = neighboringClusterDissimilarity.elem / currentClusterDissimilarity - 1.0;
                break;
            }
            case -1: {
                d = 1.0 - currentClusterDissimilarity / neighboringClusterDissimilarity.elem;
            }
        }
        return d;
    }

    /**
     * Computes the mean silhouette coefficient over all rows of {@code dataset}.
     *
     * <p>Steps:
     * <ol>
     *   <li>Add a {@code "squaredNorm"} column; this replaces any existing column of that name.</li>
     *   <li>Aggregate per-cluster statistics on the driver and broadcast them.</li>
     *   <li>Average the per-point coefficients with a UDF.</li>
     * </ol>
     * The broadcast is destroyed afterwards, including when the Spark job fails.
     *
     * @param dataset input containing the prediction and features columns
     * @param predictionCol name of the cluster-id column (cast to double)
     * @param featuresCol name of the {@code Vector} features column
     * @return the average silhouette coefficient
     * @throws AssertionError (raised by Scala's {@code Predef.assert}) if fewer than two
     *     clusters are present
     */
    public double computeSilhouetteScore(Dataset<?> dataset, String predictionCol, String featuresCol) {
        this.registerKryoClasses(dataset.sparkSession().sparkContext());
        JavaUniverse $u = package$.MODULE$.universe();
        JavaUniverse.JavaMirror $m = package$.MODULE$.universe().runtimeMirror(this.getClass().getClassLoader());
        // Compiler-generated TypeTag factory for org.apache.spark.ml.linalg.Vector (UDF input type).
        public final class Org_apache_spark_ml_evaluation_SquaredEuclideanSilhouette$$typecreator1$1
        extends TypeCreator {
            public <U extends Universe> Types.TypeApi apply(Mirror<U> $m$untyped) {
                Universe $u = $m$untyped.universe();
                Mirror<U> $m = $m$untyped;
                return $m.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor();
            }

            public Org_apache_spark_ml_evaluation_SquaredEuclideanSilhouette$$typecreator1$1() {
            }
        }
        // squaredNorm(features) = ||features||_2 ^ 2
        UserDefinedFunction squaredNormUDF = functions$.MODULE$.udf((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final double apply(Vector features) {
                return scala.math.package$.MODULE$.pow(Vectors$.MODULE$.norm(features, 2.0), 2.0);
            }
        }, ((TypeTags)package$.MODULE$.universe()).TypeTag().Double(), ((TypeTags)$u).TypeTag().apply((Mirror)$m, (TypeCreator)new Org_apache_spark_ml_evaluation_SquaredEuclideanSilhouette$$typecreator1$1()));
        Dataset dfWithSquaredNorm = dataset.withColumn("squaredNorm", squaredNormUDF.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Column[]{functions$.MODULE$.col(featuresCol)})));
        Map<Object, SquaredEuclideanSilhouette.ClusterStats> clustersStatsMap = this.computeClusterStats((Dataset<Row>)dfWithSquaredNorm, predictionCol, featuresCol);
        Predef$.MODULE$.assert(clustersStatsMap.size() > 1, (Function0)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final String apply() {
                return "Number of clusters must be greater than one.";
            }
        });
        Broadcast bClustersStatsMap = dataset.sparkSession().sparkContext().broadcast(clustersStatsMap, ClassTag$.MODULE$.apply(Map.class));
        double silhouetteScore;
        try {
            JavaUniverse $u2 = package$.MODULE$.universe();
            JavaUniverse.JavaMirror $m2 = package$.MODULE$.universe().runtimeMirror(this.getClass().getClassLoader());
            // Compiler-generated TypeTag factory for org.apache.spark.ml.linalg.Vector (UDF input type).
            public final class Org_apache_spark_ml_evaluation_SquaredEuclideanSilhouette$$typecreator2$1
            extends TypeCreator {
                public <U extends Universe> Types.TypeApi apply(Mirror<U> $m$untyped) {
                    Universe $u = $m$untyped.universe();
                    Mirror<U> $m = $m$untyped;
                    return $m.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor();
                }

                public Org_apache_spark_ml_evaluation_SquaredEuclideanSilhouette$$typecreator2$1() {
                }
            }
            // (features, clusterId, squaredNorm) -> per-point silhouette coefficient.
            UserDefinedFunction computeSilhouetteCoefficientUDF = functions$.MODULE$.udf((Function3)new Serializable(bClustersStatsMap){
                public static final long serialVersionUID = 0L;
                private final Broadcast bClustersStatsMap$1;

                public final double apply(Vector x$1, double x$2, double x$3) {
                    return SquaredEuclideanSilhouette$.MODULE$.computeSilhouetteCoefficient((Broadcast<Map<Object, SquaredEuclideanSilhouette.ClusterStats>>)this.bClustersStatsMap$1, x$1, x$2, x$3);
                }
                {
                    this.bClustersStatsMap$1 = bClustersStatsMap$1;
                }
            }, ((TypeTags)package$.MODULE$.universe()).TypeTag().Double(), ((TypeTags)$u2).TypeTag().apply((Mirror)$m2, (TypeCreator)new Org_apache_spark_ml_evaluation_SquaredEuclideanSilhouette$$typecreator2$1()), ((TypeTags)package$.MODULE$.universe()).TypeTag().Double(), ((TypeTags)package$.MODULE$.universe()).TypeTag().Double());
            silhouetteScore = ((Row[])dfWithSquaredNorm.select((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Column[]{functions$.MODULE$.avg(computeSilhouetteCoefficientUDF.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Column[]{functions$.MODULE$.col(featuresCol), functions$.MODULE$.col(predictionCol).cast((DataType)DoubleType$.MODULE$), functions$.MODULE$.col("squaredNorm")})))})).collect())[0].getDouble(0);
        } finally {
            // Free the broadcast even when UDF setup or the Spark job throws; otherwise it leaks.
            bClustersStatsMap.destroy();
        }
        return silhouetteScore;
    }

    /**
     * Mean squared Euclidean distance from {@code point} to the points summarized by
     * {@code clusterStats}. Judging by its mangled name, this is apparently the lifted form of a
     * local Scala {@code compute} function.
     *
     * <p>Uses the expansion {@code mean ||x - y||^2 = ||x||^2 + mean ||y||^2 - 2 x . mean(y)}.
     * When the point belongs to the cluster, its own zero distance is included in the mean.
     *
     * <p>NOTE(review): the subtraction can produce small negative values through floating-point
     * cancellation when the point is very close to every cluster member.
     *
     * @param squaredNorm {@code ||point||^2}
     * @param point the point's feature vector
     * @param clusterStats aggregated sums and count of the target cluster
     * @return the mean squared distance to the cluster's points
     */
    public final double org$apache$spark$ml$evaluation$SquaredEuclideanSilhouette$$compute$1(double squaredNorm, Vector point, SquaredEuclideanSilhouette.ClusterStats clusterStats) {
        double pointDotClusterFeaturesSum = BLAS$.MODULE$.dot(point, clusterStats.featureSum());
        return squaredNorm + clusterStats.squaredNormSum() / (double)clusterStats.numOfPoints() - (double)2 * pointDotClusterFeaturesSum / (double)clusterStats.numOfPoints();
    }

    /** Initializes the singleton and publishes it as {@link #MODULE$}. */
    private SquaredEuclideanSilhouette$() {
        MODULE$ = this;
        this.kryoRegistrationPerformed = false;
    }
}

