/*
 * Decompiled with CFR 0.152.
 */
package org.matheclipse.core.builtin;

import java.util.ArrayList;
import java.util.List;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.hipparchus.clustering.Cluster;
import org.hipparchus.clustering.DBSCANClusterer;
import org.hipparchus.clustering.DoublePoint;
import org.hipparchus.clustering.KMeansPlusPlusClusterer;
import org.hipparchus.clustering.MultiKMeansPlusPlusClusterer;
import org.hipparchus.clustering.distance.DistanceMeasure;
import org.hipparchus.exception.MathIllegalArgumentException;
import org.hipparchus.exception.MathRuntimeException;
import org.hipparchus.util.MathArrays;
import org.matheclipse.core.eval.EvalEngine;
import org.matheclipse.core.eval.interfaces.AbstractEvaluator;
import org.matheclipse.core.eval.util.OptionArgs;
import org.matheclipse.core.expression.ASTRealVector;
import org.matheclipse.core.expression.F;
import org.matheclipse.core.expression.S;
import org.matheclipse.core.interfaces.IAST;
import org.matheclipse.core.interfaces.IASTAppendable;
import org.matheclipse.core.interfaces.IBuiltInSymbol;
import org.matheclipse.core.interfaces.IEvaluator;
import org.matheclipse.core.interfaces.IExpr;

public class ClusteringFunctions {
    /** Logger used to report clustering failures raised by Hipparchus. */
    private static final Logger LOGGER = LogManager.getLogger();

    /** Static holder for the clustering built-ins; never instantiated. */
    private ClusteringFunctions() {
    }

    /**
     * Installs the evaluators of all distance functions and {@code FindClusters} on their
     * built-in symbols.
     */
    public static void initialize() {
        Initializer.init();
    }

    /**
     * {@code SquaredEuclideanDistance(u, v)}: the sum of squared component differences.
     */
    static final class SquaredEuclideanDistance extends AbstractDistance {
        private static final long serialVersionUID = -34208439139174441L;

        SquaredEuclideanDistance() {
        }

        /**
         * Numeric distance: sum over the indices of {@code p1} of {@code (p1[i] - p2[i])^2}.
         * Squaring a double already discards its sign, so no explicit absolute value is needed.
         */
        @Override
        public double compute(double[] p1, double[] p2) throws MathIllegalArgumentException {
            double total = 0.0;
            int idx = 0;
            while (idx < p1.length) {
                final double delta = p1[idx] - p2[idx];
                total += delta * delta;
                idx++;
            }
            return total;
        }

        /** Symbolic distance: {@code Plus(Sqr(Abs(u_i - v_i)), ...)}. */
        @Override
        public IExpr distance(IExpr a, IExpr b) {
            final IAST u = (IAST) a.normal(false);
            final IAST v = (IAST) b.normal(false);
            final int end = a.size();
            final IASTAppendable sum = F.PlusAlloc(end);
            return sum.appendArgs(end, n -> F.Sqr(F.Abs(F.Subtract(u.get(n), v.get(n)))));
        }

        @Override
        public int[] expectedArgSize(IAST ast) {
            return ARGS_2_2;
        }
    }

    /**
     * {@code ManhattanDistance(u, v)}: the sum of absolute component differences.
     * The numeric path delegates to Hipparchus' ManhattanDistance.
     */
    private static final class ManhattanDistance extends AbstractDistance {
        private static final long serialVersionUID = -3203931866584067444L;
        static final DistanceMeasure distance = new org.hipparchus.clustering.distance.ManhattanDistance();

        private ManhattanDistance() {
        }

        @Override
        public double compute(double[] arg0, double[] arg1) throws MathIllegalArgumentException {
            final DistanceMeasure delegate = distance;
            return delegate.compute(arg0, arg1);
        }

        /** Symbolic distance: {@code Plus(Abs(u_i - v_i), ...)}. */
        @Override
        public IExpr distance(IExpr a, IExpr b) {
            final IAST left = (IAST) a.normal(false);
            final IAST right = (IAST) b.normal(false);
            final int end = a.size();
            final IASTAppendable terms = F.PlusAlloc(end);
            return terms.appendArgs(end, n -> F.Abs(F.Subtract(left.get(n), right.get(n))));
        }
    }

    /** Binds the evaluators of this class to their built-in symbols. */
    private static class Initializer {
        private Initializer() {
        }

        /** Registers every distance function and {@code FindClusters}. */
        private static void init() {
            register(S.BinaryDistance, new BinaryDistance());
            register(S.BrayCurtisDistance, new BrayCurtisDistance());
            register(S.CanberraDistance, new CanberraDistance());
            register(S.ChessboardDistance, new ChessboardDistance());
            register(S.CosineDistance, new CosineDistance());
            register(S.EuclideanDistance, new EuclideanDistance());
            register(S.FindClusters, new FindClusters());
            register(S.ManhattanDistance, new ManhattanDistance());
            register(S.SquaredEuclideanDistance, new SquaredEuclideanDistance());
        }

        /** Installs {@code evaluator} on the built-in {@code symbol}. */
        private static void register(IBuiltInSymbol symbol, AbstractEvaluator evaluator) {
            symbol.setEvaluator(evaluator);
        }
    }

    /**
     * Evaluator for {@code FindClusters}.
     *
     * <p>Forms handled below:
     * <ul>
     *   <li>{@code FindClusters(list)} and {@code FindClusters(list, k)}: without a {@code Method}
     *       option, a Hipparchus {@code MultiKMeansPlusPlusClusterer} with 10 trials of k-means++
     *       (k defaults to 3).</li>
     *   <li>{@code Method -> "KMeans"}: a single {@code KMeansPlusPlusClusterer} run.</li>
     *   <li>{@code FindClusters(list, eps, minPts, ..., Method -> "DBSCAN")}: {@code DBSCANClusterer};
     *       at least four arguments are required.</li>
     * </ul>
     * {@code DistanceFunction -> f} selects any built-in whose evaluator implements
     * {@link DistanceMeasure}; the default is {@code EuclideanDistance}.
     *
     * <p>{@code list} is either a flat list of numbers (1-D points) or a list of equally long
     * numeric vectors. The result is a list of clusters in the same representation.
     */
    private static class FindClusters
    extends AbstractEvaluator {
        /** Default number of clusters for the k-means methods. */
        private static final int DEFAULT_CLUSTER_COUNT = 3;
        /** Maximum iterations of a single k-means++ run. */
        private static final int KMEANS_MAX_ITERATIONS = 100;
        /** Number of k-means++ trials used by the default (multi-start) method. */
        private static final int MULTI_KMEANS_TRIALS = 10;

        private FindClusters() {
        }

        @Override
        public IExpr evaluate(IAST ast, EvalEngine engine) {
            if (!ast.arg1().isList() || ast.arg1().size() <= 1) {
                return F.NIL;
            }
            final IAST listArg1 = (IAST) ast.arg1();
            String method = "";
            DistanceMeasure measure = new EuclideanDistance();
            int k = DEFAULT_CLUSTER_COUNT;
            double eps = 3.0;
            int minPts = 1;
            try {
                if (ast.size() > 2) {
                    OptionArgs options = new OptionArgs(ast.topHead(), ast, 2, engine);
                    IExpr option = options.getOption(S.Method);
                    if (option.isPresent()) {
                        method = option.toString();
                    }
                    option = options.getOption(S.DistanceFunction);
                    if (option.isPresent()) {
                        measure = toDistanceMeasure(option);
                        if (measure == null) {
                            return F.NIL;
                        }
                    }
                    if ("KMeans".equals(method) || method.isEmpty()) {
                        // a non-integer second argument (e.g. an option rule) keeps the default k
                        k = ast.arg2().toIntDefault();
                        if (k == Integer.MIN_VALUE) {
                            k = DEFAULT_CLUSTER_COUNT;
                        }
                    }
                }
                if ("DBSCAN".equals(method)) {
                    if (ast.size() < 5) {
                        return F.NIL;
                    }
                    eps = engine.evalDouble(ast.arg2());
                    minPts = ast.arg3().toIntDefault();
                    if (minPts <= 0) {
                        return F.NIL;
                    }
                }
                if (k <= 0) {
                    return F.NIL;
                }
                final List<DoublePoint> points = toPoints(listArg1);
                if (points == null) {
                    return F.NIL;
                }
                final List<? extends Cluster<DoublePoint>> clusters;
                if ("KMeans".equals(method)) {
                    clusters = new KMeansPlusPlusClusterer<DoublePoint>(k, KMEANS_MAX_ITERATIONS, measure)
                            .cluster(points);
                } else if ("DBSCAN".equals(method)) {
                    clusters = new DBSCANClusterer<DoublePoint>(eps, minPts, measure).cluster(points);
                } else {
                    KMeansPlusPlusClusterer<DoublePoint> kMeans =
                            new KMeansPlusPlusClusterer<>(k, KMEANS_MAX_ITERATIONS, measure);
                    clusters = new MultiKMeansPlusPlusClusterer<DoublePoint>(kMeans, MULTI_KMEANS_TRIALS)
                            .cluster(points);
                }
                return toResult(clusters, listArg1.isListOfLists());
            } catch (MathRuntimeException mrex) {
                LOGGER.log(engine.getLogLevel(), (Object) ast.topHead(), (Throwable) mrex);
            }
            return F.NIL;
        }

        /**
         * Resolves the value of the {@code DistanceFunction} option.
         *
         * @return the built-in's evaluator if it implements {@link DistanceMeasure}, otherwise {@code null}
         */
        private static DistanceMeasure toDistanceMeasure(IExpr option) {
            if (option.isBuiltInSymbol()) {
                IEvaluator evaluator = ((IBuiltInSymbol) option).getEvaluator();
                if (evaluator instanceof DistanceMeasure) {
                    return (DistanceMeasure) evaluator;
                }
            }
            return null;
        }

        /**
         * Converts the data argument to Hipparchus points. A list of lists yields one point per row;
         * a flat list yields one 1-D point per entry.
         *
         * @return the points, or {@code null} if an entry is not convertible to doubles or the rows
         *     do not all have the same length
         */
        private static List<DoublePoint> toPoints(IAST listArg1) {
            final List<DoublePoint> points = new ArrayList<>(listArg1.argSize());
            if (listArg1.isListOfLists()) {
                int dimension = -1;
                for (int j = 1; j < listArg1.size(); ++j) {
                    double[] values = listArg1.get(j).toDoubleVector();
                    if (values == null) {
                        return null;
                    }
                    if (dimension < 0) {
                        dimension = values.length;
                    } else if (values.length != dimension) {
                        // ragged rows: some distance measures only loop over the first array's
                        // length and would throw ArrayIndexOutOfBoundsException (not caught above)
                        // or ignore coordinates
                        return null;
                    }
                    points.add(new DoublePoint(values));
                }
            } else {
                double[] values = listArg1.toDoubleVector();
                if (values == null) {
                    return null;
                }
                for (double value : values) {
                    points.add(new DoublePoint(new double[] {value}));
                }
            }
            return points;
        }

        /**
         * Converts Hipparchus clusters back to a list of lists of vectors (or of numbers when the
         * input was a flat list).
         */
        private static IExpr toResult(List<? extends Cluster<DoublePoint>> clusters, boolean multiDimensional) {
            final IASTAppendable result = F.ListAlloc(clusters.size());
            for (Cluster<DoublePoint> cluster : clusters) {
                final List<DoublePoint> clusterPoints = cluster.getPoints();
                final IASTAppendable list = F.ListAlloc(clusterPoints.size());
                for (DoublePoint point : clusterPoints) {
                    if (multiDimensional) {
                        // copy: the vector is constructed with flag false, presumably wrapping the
                        // array directly, while getPoint() may expose the point's internal array
                        list.append(new ASTRealVector(point.getPoint().clone(), false));
                    } else {
                        list.append(point.getPoint()[0]);
                    }
                }
                result.append(list);
            }
            return result;
        }

        @Override
        public int[] expectedArgSize(IAST ast) {
            return ARGS_1_5;
        }
    }

    /**
     * {@code EuclideanDistance(u, v)}: the square root of the sum of squared component
     * differences. The numeric path delegates to Hipparchus' EuclideanDistance.
     */
    private static final class EuclideanDistance extends AbstractDistance {
        private static final long serialVersionUID = 2872848600632425591L;
        static final DistanceMeasure distance = new org.hipparchus.clustering.distance.EuclideanDistance();

        private EuclideanDistance() {
        }

        @Override
        public double compute(double[] arg0, double[] arg1) throws MathIllegalArgumentException {
            final DistanceMeasure delegate = distance;
            return delegate.compute(arg0, arg1);
        }

        /** Symbolic distance: {@code Sqrt(Plus(Sqr(Abs(u_i - v_i)), ...))}. */
        @Override
        public IExpr distance(IExpr a, IExpr b) {
            final IAST left = (IAST) a.normal(false);
            final IAST right = (IAST) b.normal(false);
            final int end = a.size();
            final IASTAppendable squares = F.PlusAlloc(end);
            squares.appendArgs(end, n -> F.Sqr(F.Abs(F.Subtract(left.get(n), right.get(n)))));
            final IExpr root = F.Sqrt(squares);
            return root;
        }
    }

    /**
     * {@code CosineDistance(u, v)}: {@code 1 - (u . v) / (Norm(u) * Norm(v))}.
     */
    private static final class CosineDistance extends AbstractDistance {
        private static final long serialVersionUID = -108468814401695919L;

        private CosineDistance() {
        }

        /** Numeric distance: one minus Hipparchus' {@code MathArrays.cosAngle(a, b)}. */
        @Override
        public double compute(double[] a, double[] b) throws MathIllegalArgumentException {
            final double cosine = MathArrays.cosAngle(a, b);
            return 1.0 - cosine;
        }

        /** Symbolic distance built from {@code Dot} and {@code Norm}. */
        @Override
        public IExpr distance(IExpr arg1, IExpr arg2) {
            final IExpr dot = F.Dot(arg1, arg2);
            final IExpr norm1 = F.Norm(arg1);
            final IExpr norm2 = F.Norm(arg2);
            final IExpr cosine = F.Divide(dot, F.Times(norm1, norm2));
            return F.Subtract(F.C1, cosine);
        }

        @Override
        public int[] expectedArgSize(IAST ast) {
            return ARGS_2_2;
        }
    }

    /**
     * {@code ChessboardDistance(u, v)}: the largest absolute component difference.
     */
    private static final class ChessboardDistance extends AbstractDistance {
        private static final long serialVersionUID = 6473415254245961676L;

        private ChessboardDistance() {
        }

        /** Numeric distance via Hipparchus' {@code MathArrays.distanceInf}. */
        @Override
        public double compute(double[] a, double[] b) throws MathIllegalArgumentException {
            final double maxDelta = MathArrays.distanceInf(a, b);
            return maxDelta;
        }

        /** Symbolic distance: {@code Max(Abs(u_i - v_i), ...)}. */
        @Override
        public IExpr distance(IExpr a, IExpr b) {
            final IAST left = (IAST) a.normal(false);
            final IAST right = (IAST) b.normal(false);
            final IASTAppendable max = F.Max();
            final int end = a.size();
            return max.appendArgs(end, n -> F.Abs(F.Subtract(left.get(n), right.get(n))));
        }

        @Override
        public int[] expectedArgSize(IAST ast) {
            return ARGS_2_2;
        }
    }

    /**
     * {@code CanberraDistance(u, v)}: {@code Total(Abs(u - v) / (Abs(u) + Abs(v)))}.
     * The numeric path delegates to Hipparchus' CanberraDistance.
     *
     * <p>NOTE(review): for components where both entries are exact zeros, the symbolic form
     * produces a 0/0 term. Hipparchus' numeric implementation is believed to treat such terms
     * as 0, so the two paths may disagree. Verify against the intended semantics.
     */
    private static final class CanberraDistance extends AbstractDistance {
        private static final long serialVersionUID = 6257588266259496269L;
        static final DistanceMeasure distance = new org.hipparchus.clustering.distance.CanberraDistance();

        private CanberraDistance() {
        }

        @Override
        public double compute(double[] arg0, double[] arg1) throws MathIllegalArgumentException {
            final DistanceMeasure delegate = distance;
            return delegate.compute(arg0, arg1);
        }

        @Override
        public IAST distance(IExpr a, IExpr b) {
            final IExpr absA = F.Abs(a);
            final IExpr absB = F.Abs(b);
            final IExpr numerator = F.Abs(F.Subtract(a, b));
            final IExpr denominator = F.Plus(absA, absB);
            return F.Total(F.Divide(numerator, denominator));
        }

        @Override
        public int[] expectedArgSize(IAST ast) {
            return ARGS_2_2;
        }
    }

    /**
     * {@code BrayCurtisDistance(u, v)}: {@code Total(Abs(u - v)) / Total(Abs(u + v))}.
     */
    private static final class BrayCurtisDistance extends AbstractDistance {
        private static final long serialVersionUID = 669052613809997063L;

        private BrayCurtisDistance() {
        }

        /** Sum over the indices of {@code p1} of {@code |p1[i] + p2[i]|}. */
        private static double absSumOfSums(double[] p1, double[] p2) throws MathIllegalArgumentException {
            double acc = 0.0;
            int idx = 0;
            while (idx < p1.length) {
                acc += Math.abs(p1[idx] + p2[idx]);
                idx++;
            }
            return acc;
        }

        /**
         * Numeric distance. The L1 numerator is evaluated first, so Hipparchus'
         * {@code distance1} dimension check runs before the denominator loop.
         */
        @Override
        public double compute(double[] a, double[] b) throws MathIllegalArgumentException {
            final double numerator = MathArrays.distance1(a, b);
            return numerator / absSumOfSums(a, b);
        }

        @Override
        public IExpr distance(IExpr a, IExpr b) {
            final IExpr diffTotal = F.Total(F.Abs(F.Subtract(a, b)));
            final IExpr sumTotal = F.Total(F.Abs(F.Plus(a, b)));
            return F.Divide(diffTotal, sumTotal);
        }
    }

    /**
     * {@code BinaryDistance(u, v)}: {@code 0} if {@code u} and {@code v} are identical and
     * {@code 1} otherwise.
     *
     * <p>Both paths previously returned the inverse (1 for identical, 0 for different). That
     * contradicts the definition and breaks the d(x, x) = 0 requirement when this evaluator is
     * used as a {@link DistanceMeasure} in {@code FindClusters}.
     */
    private static final class BinaryDistance
    extends AbstractDistance {
        private static final long serialVersionUID = 6407163419470076191L;

        private BinaryDistance() {
        }

        /** Numeric form: 0.0 if the arrays are element-wise equal (per {@code F.isEqual}), else 1.0. */
        @Override
        public double compute(double[] a, double[] b) throws MathIllegalArgumentException {
            if (a == b) {
                return 0.0;
            }
            if (a.length != b.length) {
                return 1.0;
            }
            for (int i = 0; i < a.length; ++i) {
                if (!F.isEqual(a[i], b[i])) {
                    return 1.0;
                }
            }
            return 0.0;
        }

        /** Always compares structurally; the double-vector path of the base class is bypassed. */
        @Override
        protected IExpr vectorDistance(IExpr arg1, IExpr arg2, EvalEngine engine) {
            return this.distance(arg1, arg2);
        }

        /** Symbolic form: {@code 0} if {@code a.equals(b)}, otherwise {@code 1}. */
        @Override
        public IExpr distance(IExpr a, IExpr b) {
            return a.equals(b) ? F.C0 : F.C1;
        }
    }

    /**
     * Common base for the distance built-ins. Each subclass is both a Symja evaluator for
     * {@code f(u, v)} and a Hipparchus {@link DistanceMeasure}, so it can be used as the
     * {@code DistanceFunction} of {@code FindClusters}.
     */
    private static abstract class AbstractDistance extends AbstractEvaluator implements DistanceMeasure {
        private static final long serialVersionUID = -295980120043414467L;

        private AbstractDistance() {
        }

        /** Builds the symbolic distance between two vectors of equal length. */
        public abstract IExpr distance(IExpr u, IExpr v);

        /**
         * Evaluates {@code f(u, v)}. Returns {@code F.NIL} unless both arguments are vectors of
         * the same, non-zero length.
         */
        @Override
        public IExpr evaluate(IAST ast, EvalEngine engine) {
            final IExpr first = ast.arg1();
            final IExpr second = ast.arg2();
            final int length = first.isVector();
            if (length <= 0 || second.isVector() != length) {
                return F.NIL;
            }
            return this.vectorDistance(first, second, engine);
        }

        /**
         * Uses the numeric {@code compute(double[], double[])} when the engine is in double mode
         * or either argument is a numeric AST, and both arguments convert to double vectors.
         * Otherwise falls back to the symbolic {@link #distance(IExpr, IExpr)}.
         */
        protected IExpr vectorDistance(IExpr arg1, IExpr arg2, EvalEngine engine) {
            final boolean wantNumeric = engine.isDoubleMode() || arg1.isNumericAST() || arg2.isNumericAST();
            if (wantNumeric) {
                final double[] u = arg1.toDoubleVector();
                if (u != null) {
                    final double[] v = arg2.toDoubleVector();
                    if (v != null) {
                        return F.num(this.compute(u, v));
                    }
                }
            }
            return this.distance(arg1, arg2);
        }

        @Override
        public int[] expectedArgSize(IAST ast) {
            return ARGS_2_2;
        }
    }
}

