/*
 * Decompiled with CFR 0.152.
 */
package org.matheclipse.core.builtin;

import java.util.Collection;
import java.util.function.Consumer;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.hipparchus.analysis.ParametricUnivariateFunction;
import org.hipparchus.exception.MathIllegalArgumentException;
import org.hipparchus.fitting.PolynomialCurveFitter;
import org.hipparchus.fitting.SimpleCurveFitter;
import org.hipparchus.fitting.WeightedObservedPoints;
import org.hipparchus.stat.regression.SimpleRegression;
import org.matheclipse.core.basic.Config;
import org.matheclipse.core.builtin.IOFunctions;
import org.matheclipse.core.convert.Convert;
import org.matheclipse.core.eval.EvalEngine;
import org.matheclipse.core.eval.exception.ASTElementLimitExceeded;
import org.matheclipse.core.eval.exception.ValidateException;
import org.matheclipse.core.eval.interfaces.AbstractEvaluator;
import org.matheclipse.core.eval.interfaces.AbstractFunctionEvaluator;
import org.matheclipse.core.expression.F;
import org.matheclipse.core.expression.S;
import org.matheclipse.core.interfaces.IAST;
import org.matheclipse.core.interfaces.IASTAppendable;
import org.matheclipse.core.interfaces.IASTMutable;
import org.matheclipse.core.interfaces.IExpr;
import org.matheclipse.core.interfaces.ISignedNumber;
import org.matheclipse.core.interfaces.ISymbol;

/**
 * Built-in curve fitting functions {@code FindFit}, {@code Fit} and {@code LinearModelFit}, implemented
 * on top of the Hipparchus fitting ({@code SimpleCurveFitter}, {@code PolynomialCurveFitter}) and
 * regression ({@code SimpleRegression}) classes.
 */
public class CurveFitterFunctions {
    private static final Logger LOGGER = LogManager.getLogger();

    /** Registers the evaluators of this class with the {@code FindFit}, {@code Fit} and {@code LinearModelFit} symbols. */
    public static void initialize() {
        Initializer.init();
    }

    private CurveFitterFunctions() {
    }

    /**
     * {@code LinearModelFit(data, funs, x)} - ordinary least-squares straight-line fit.
     *
     * <p>{@code data} must be a matrix of {@code {x, y}} rows. The result is the expression
     * {@code intercept + slope*x}, where {@code x} is the third argument. The second argument is not
     * used by this implementation.
     */
    private static final class LinearModelFit
    extends AbstractEvaluator {
        private LinearModelFit() {
        }

        @Override
        public IExpr evaluate(IAST ast, EvalEngine engine) {
            IExpr arg1 = ast.arg1();
            int[] dim;
            if (!arg1.isList() || (dim = arg1.isMatrix()) == null || dim[1] != 2) {
                return F.NIL;
            }
            // toDoubleMatrix() may return null (presumably for non-numeric entries); handing null to
            // SimpleRegression.addData() would throw a NullPointerException
            double[][] data = arg1.toDoubleMatrix();
            if (data == null) {
                return F.NIL;
            }
            SimpleRegression model = new SimpleRegression();
            model.addData(data);
            double intercept = model.getIntercept();
            double slope = model.getSlope();
            // SimpleRegression reports NaN with fewer than two observations or no variation in x;
            // leave the call unevaluated rather than returning a NaN-valued line
            if (Double.isNaN(intercept) || Double.isNaN(slope)) {
                return F.NIL;
            }
            return F.Plus((IExpr)F.num(intercept), (IExpr)F.Times((IExpr)F.num(slope), ast.arg3()));
        }

        @Override
        public int[] expectedArgSize(IAST ast) {
            return ARGS_3_3;
        }
    }

    /**
     * {@code Fit(data, n, x)} - least-squares fit of a polynomial of degree {@code n} in the symbol
     * {@code x}.
     *
     * <p>{@code data} is read with {@link FindFit#addWeightedObservedPoints}. The fitted coefficients are
     * converted back into a polynomial expression in {@code x}.
     */
    private static class Fit
    extends FindFit {
        private Fit() {
        }

        @Override
        public IExpr numericEval(IAST ast, EvalEngine engine) {
            if (!ast.arg1().isList() || !ast.arg2().isReal() || !ast.arg3().isSymbol()) {
                return F.NIL;
            }
            int polynomialDegree = ast.arg2().toIntDefault();
            if (polynomialDegree <= 0) {
                return F.NIL;
            }
            // reject degrees above the configured maximum AST size
            if (Config.MAX_AST_SIZE < polynomialDegree) {
                ASTElementLimitExceeded.throwIt(polynomialDegree);
            }
            PolynomialCurveFitter fitter = PolynomialCurveFitter.create(polynomialDegree);
            WeightedObservedPoints obs = new WeightedObservedPoints();
            if (addWeightedObservedPoints((IAST)ast.arg1(), obs)) {
                try {
                    return Convert.polynomialFunction2Expr(fitter.fit(obs.toList()), (ISymbol)ast.arg3());
                }
                catch (RuntimeException rex) {
                    LOGGER.log(engine.getLogLevel(), (Object)ast.topHead(), (Throwable)rex);
                }
            }
            return F.NIL;
        }

        @Override
        public int[] expectedArgSize(IAST ast) {
            return ARGS_3_3;
        }
    }

    /**
     * {@code FindFit(data, expr, {p1, p2, ...}, x)} - nonlinear least-squares fit of the parameters
     * {@code p1, p2, ...} of {@code expr}, an expression in {@code x}, to {@code data}.
     *
     * <p>Each parameter is either a symbol (start value {@code 1.0}) or a pair {@code {symbol, start}}.
     * On success the result is a list of rules {@code {p1 -> value1, p2 -> value2, ...}}.
     */
    private static class FindFit
    extends AbstractFunctionEvaluator {
        private FindFit() {
        }

        /**
         * Adds the points of {@code data} to {@code obs}, each with weight {@code 1.0}.
         *
         * @param data either a matrix of {@code {x, y}} rows, or a vector of {@code y} values whose
         *     x-coordinates are taken as {@code 1, 2, 3, ...}
         * @param obs receives the points
         * @return {@code false} if {@code data} has neither shape, cannot be converted to doubles, or
         *     contains no points
         */
        protected static boolean addWeightedObservedPoints(IAST data, WeightedObservedPoints obs) {
            int[] dimensions = data.isMatrix();
            if (dimensions != null && dimensions[1] == 2) {
                double[][] points = data.toDoubleMatrix();
                if (points == null || points.length == 0) {
                    return false;
                }
                for (double[] point : points) {
                    obs.add(1.0, point[0], point[1]);
                }
                return true;
            }
            if (data.isVector() < 0) {
                return false;
            }
            double[] yValues = data.toDoubleVector();
            // an empty point set would only make the fitter throw later
            if (yValues == null || yValues.length == 0) {
                return false;
            }
            for (int i = 0; i < yValues.length; ++i) {
                obs.add(1.0, (double)(i + 1), yValues[i]);
            }
            return true;
        }

        /** Pairs each parameter symbol with its fitted value: {@code {p1 -> values[0], p2 -> values[1], ...}}. */
        private static IExpr convertToRulesList(IAST listOfSymbols, double[] values) {
            IASTAppendable result = F.ListAlloc(listOfSymbols.size());
            // i is the 1-based argument position, hence values[i - 1]
            listOfSymbols.forEach((arg, i) -> result.append(F.Rule(arg, (IExpr)F.num(values[i - 1]))));
            return result;
        }

        /** Delegates to {@link #numericEval}. */
        @Override
        public IExpr evaluate(IAST ast, EvalEngine engine) {
            return this.numericEval(ast, engine);
        }

        @Override
        public int[] expectedArgSize(IAST ast) {
            return ARGS_4_4;
        }

        /**
         * Extracts the parameter symbols and their start values.
         *
         * @param listOfSymbolsOrPairs list whose entries are symbols or {@code {symbol, start}} pairs
         * @param initialGuess output array with {@code listOfSymbolsOrPairs.size() - 1} slots. It
         *     receives {@code 1.0} for a bare symbol, or the real value of {@code start} for a pair.
         * @return the list of parameter symbols, or {@code F.NIL} if an entry has neither form or a
         *     start value does not evaluate to a real number
         */
        protected static IAST initialGuess(IAST listOfSymbolsOrPairs, double[] initialGuess) {
            IASTAppendable newListOfSymbols = F.ListAlloc(listOfSymbolsOrPairs.size());
            for (int i = 1; i < listOfSymbolsOrPairs.size(); ++i) {
                IExpr temp = listOfSymbolsOrPairs.get(i);
                if (temp.isSymbol()) {
                    initialGuess[i - 1] = 1.0;
                    newListOfSymbols.append(temp);
                    continue;
                }
                if (temp.isAST(S.List, 3) && temp.first().isSymbol()) {
                    ISignedNumber signedNumber = temp.second().evalReal();
                    if (signedNumber == null) {
                        return F.NIL;
                    }
                    initialGuess[i - 1] = signedNumber.doubleValue();
                    newListOfSymbols.append(temp.first());
                    continue;
                }
                return F.NIL;
            }
            return newListOfSymbols;
        }

        /** Returns a new array of {@code length} elements, all set to {@code value}. */
        protected static double[] initialGuess(int length, double value) {
            double[] initialGuess = new double[length];
            for (int i = 0; i < initialGuess.length; ++i) {
                initialGuess[i] = value;
            }
            return initialGuess;
        }

        @Override
        public IExpr numericEval(IAST ast, EvalEngine engine) {
            if (!ast.arg1().isList() || !ast.arg3().isList() || !ast.arg4().isSymbol()) {
                return F.NIL;
            }
            IAST data = (IAST)ast.arg1();
            IExpr function = ast.arg2();
            IAST listOfSymbolsOrPairs = (IAST)ast.arg3();
            ISymbol x = (ISymbol)ast.arg4();
            double[] startValues = new double[listOfSymbolsOrPairs.size() - 1];
            IAST listOfSymbols = FindFit.initialGuess(listOfSymbolsOrPairs, startValues);
            if (!listOfSymbols.isPresent()) {
                return F.NIL;
            }
            try {
                // gradient of the model with respect to the parameter symbols, computed by Grad
                IExpr gradientList = S.Grad.of(engine, function, listOfSymbols);
                if (!gradientList.isList()) {
                    return F.NIL;
                }
                SimpleCurveFitter fitter = SimpleCurveFitter.create(
                    new FindFitParametricFunction(function, (IAST)gradientList, listOfSymbols, x, engine), startValues);
                WeightedObservedPoints obs = new WeightedObservedPoints();
                if (FindFit.addWeightedObservedPoints(data, obs)) {
                    double[] values = fitter.fit(obs.toList());
                    return FindFit.convertToRulesList(listOfSymbols, values);
                }
            }
            catch (ValidateException ve) {
                return IOFunctions.printMessage(ast.topHead(), ve, engine);
            }
            catch (RuntimeException rex) {
                LOGGER.log(engine.getLogLevel(), (Object)ast.topHead(), (Throwable)rex);
            }
            return F.NIL;
        }

        /**
         * Adapts the symbolic model {@code function} and its parameter gradient to Hipparchus'
         * {@link ParametricUnivariateFunction}. Numeric values are substituted for {@code x} and the
         * parameters, and the result is then evaluated numerically with {@code engine}.
         */
        private static final class FindFitParametricFunction
        implements ParametricUnivariateFunction {
            private final EvalEngine engine;
            private final IExpr function;
            private final IAST gradientList;
            /** {@code {x -> t, p1 -> v1, ...}}; the right-hand sides are overwritten in place on every call. */
            private final IASTAppendable listOfRules;

            public FindFitParametricFunction(IExpr function, IAST gradientList, IAST listOfSymbols, ISymbol x, EvalEngine engine) {
                this.function = function;
                this.engine = engine;
                this.gradientList = gradientList;
                this.listOfRules = F.ListAlloc(gradientList.size());
                // S.Null right-hand sides are placeholders, replaced by createSubstitutionRules()
                this.listOfRules.append(F.Rule((IExpr)x, (IExpr)S.Null));
                listOfSymbols.forEach(arg -> this.listOfRules.append(F.Rule(arg, (IExpr)S.Null)));
            }

            /** Writes {@code t} into the {@code x} rule and {@code parameters[k]} into the k-th parameter rule. */
            private void createSubstitutionRules(double t, double ... parameters) {
                IASTMutable substitutionRule = (IASTMutable)this.listOfRules.arg1();
                substitutionRule.set(2, F.num(t));
                for (int i = 2; i < this.listOfRules.size(); ++i) {
                    substitutionRule = (IASTMutable)this.listOfRules.get(i);
                    substitutionRule.set(2, F.num(parameters[i - 2]));
                }
            }

            @Override
            public double[] gradient(double t, double ... parameters) {
                this.createSubstitutionRules(t, parameters);
                double[] gradient = new double[parameters.length];
                for (int i = 0; i < parameters.length; ++i) {
                    // evaluate with the captured engine, consistent with value()
                    gradient[i] = this.engine.evalDouble(F.subst(this.gradientList.get(i + 1), this.listOfRules));
                }
                return gradient;
            }

            @Override
            public double value(double t, double ... parameters) throws MathIllegalArgumentException {
                this.createSubstitutionRules(t, parameters);
                return this.engine.evalDouble(F.subst(this.function, this.listOfRules));
            }
        }
    }

    private static class Initializer {
        private Initializer() {
        }

        /** Installs the evaluator instances on their symbols. */
        private static void init() {
            S.FindFit.setEvaluator(new FindFit());
            S.Fit.setEvaluator(new Fit());
            S.LinearModelFit.setEvaluator(new LinearModelFit());
        }
    }
}

