package org.apache.spark.ml.regression;

import java.io.Serializable;
import java.util.Arrays;
import org.apache.spark.ml.linalg.DenseMatrix;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.mllib.optimization.SquaredL2Updater;
import org.apache.spark.mllib.optimization.Updater;
import scala.Array$;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.ArrayOps$;
import scala.package$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;
import scala.runtime.ModuleSerializationProxy;
import scala.runtime.RichInt$;

/* compiled from: FMRegressor.scala */
/* loaded from: input_file:org/apache/spark/ml/regression/FactorizationMachines$.class */
public final class FactorizationMachines$ implements Serializable {
    public static final FactorizationMachines$ MODULE$ = new FactorizationMachines$();
    private static final String GD = "gd";
    private static final String AdamW = "adamW";
    private static final String[] supportedSolvers = {MODULE$.GD(), MODULE$.AdamW()};
    private static final String LogisticLoss = "logisticLoss";
    private static final String SquaredError = "squaredError";
    private static final String[] supportedRegressorLosses = {MODULE$.SquaredError()};
    private static final String[] supportedClassifierLosses = {MODULE$.LogisticLoss()};
    private static final String[] supportedLosses = (String[]) ArrayOps$.MODULE$.$plus$plus$extension(Predef$.MODULE$.refArrayOps(MODULE$.supportedRegressorLosses()), MODULE$.supportedClassifierLosses(), ClassTag$.MODULE$.apply(String.class));

    public String GD() {
        return GD;
    }

    public String AdamW() {
        return AdamW;
    }

    public String[] supportedSolvers() {
        return supportedSolvers;
    }

    public String LogisticLoss() {
        return LogisticLoss;
    }

    public String SquaredError() {
        return SquaredError;
    }

    public String[] supportedRegressorLosses() {
        return supportedRegressorLosses;
    }

    public String[] supportedClassifierLosses() {
        return supportedClassifierLosses;
    }

    public String[] supportedLosses() {
        return supportedLosses;
    }

    public Updater parseSolver(String str, int i) {
        Updater adamWUpdater;
        String GD2 = GD();
        if (GD2 != null ? !GD2.equals(str) : str != null) {
            String AdamW2 = AdamW();
            if (AdamW2 != null ? !AdamW2.equals(str) : str != null) {
                throw new MatchError(str);
            }
            adamWUpdater = new AdamWUpdater(i);
        } else {
            adamWUpdater = new SquaredL2Updater();
        }
        return adamWUpdater;
    }

    public BaseFactorizationMachinesGradient parseLoss(String str, int i, boolean z, boolean z2, int i2) {
        BaseFactorizationMachinesGradient mSEFactorizationMachinesGradient;
        String LogisticLoss2 = LogisticLoss();
        if (LogisticLoss2 != null ? !LogisticLoss2.equals(str) : str != null) {
            String SquaredError2 = SquaredError();
            if (SquaredError2 != null ? !SquaredError2.equals(str) : str != null) {
                throw new IllegalArgumentException(new StringBuilder(35).append("loss function type ").append(str).append(" is invalidation").toString());
            }
            mSEFactorizationMachinesGradient = new MSEFactorizationMachinesGradient(i, z, z2, i2);
        } else {
            mSEFactorizationMachinesGradient = new LogisticFactorizationMachinesGradient(i, z, z2, i2);
        }
        return mSEFactorizationMachinesGradient;
    }

    public Tuple3<Object, Vector, Matrix> splitCoefficients(Vector vector, int i, int i2, boolean z, boolean z2) {
        int i3 = (i * i2) + (z2 ? i : 0) + (z ? 1 : 0);
        Predef$.MODULE$.require(i3 == vector.size(), () -> {
            return new StringBuilder(50).append("coefficients.size did not match the excepted size ").append(i3).toString();
        });
        return new Tuple3<>(BoxesRunTime.boxToDouble(z ? vector.apply(vector.size() - 1) : 0.0d), z2 ? new DenseVector((double[]) ArrayOps$.MODULE$.slice$extension(Predef$.MODULE$.doubleArrayOps(vector.toArray()), i * i2, (i * i2) + i)) : Vectors$.MODULE$.sparse(i, package$.MODULE$.Seq().empty()), new DenseMatrix(i, i2, (double[]) ArrayOps$.MODULE$.slice$extension(Predef$.MODULE$.doubleArrayOps(vector.toArray()), 0, i * i2), true));
    }

    public Vector combineCoefficients(double d, Vector vector, Matrix matrix, boolean z, boolean z2) {
        return new DenseVector((double[]) ArrayOps$.MODULE$.$plus$plus$extension(Predef$.MODULE$.doubleArrayOps((double[]) ArrayOps$.MODULE$.$plus$plus$extension(Predef$.MODULE$.doubleArrayOps(matrix.toDense().values()), z2 ? vector.toArray() : Array$.MODULE$.emptyDoubleArray(), ClassTag$.MODULE$.Double())), z ? new double[]{d} : Array$.MODULE$.emptyDoubleArray(), ClassTag$.MODULE$.Double()));
    }

    public double getRawPrediction(Vector vector, double d, Vector vector2, Matrix matrix) {
        DoubleRef create = DoubleRef.create(d + vector.dot(vector2));
        RichInt$.MODULE$.until$extension(Predef$.MODULE$.intWrapper(0), matrix.numCols()).foreach$mVc$sp(i -> {
            DoubleRef create2 = DoubleRef.create(0.0d);
            DoubleRef create3 = DoubleRef.create(0.0d);
            vector.foreachNonZero((i, d2) -> {
                Tuple2.mcID.sp spVar = new Tuple2.mcID.sp(i, d2);
                if (spVar == null) {
                    throw new MatchError(spVar);
                }
                double apply = matrix.apply(spVar._1$mcI$sp(), i) * spVar._2$mcD$sp();
                create2.elem += apply * apply;
                create3.elem += apply;
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            });
            create.elem += 0.5d * ((create3.elem * create3.elem) - create2.elem);
        });
        return create.elem;
    }

    private Object writeReplace() {
        return new ModuleSerializationProxy(FactorizationMachines$.class);
    }

    private FactorizationMachines$() {
    }
}
