package ml.dmlc.xgboost4j.scala.spark;

import java.nio.file.Files;
import java.nio.file.attribute.FileAttribute;
import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.IRabitTracker;
import ml.dmlc.xgboost4j.java.Rabit;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.scala.Booster;
import ml.dmlc.xgboost4j.scala.DMatrix;
import ml.dmlc.xgboost4j.scala.EvalTrait;
import ml.dmlc.xgboost4j.scala.ObjectiveTrait;
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker;
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker$;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.TaskContext$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.storage.StorageLevel$;
import scala.Array$;
import scala.Function1;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Serializable;
import scala.Some;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.GenTraversableOnce;
import scala.collection.Iterable$;
import scala.collection.Iterator;
import scala.collection.TraversableOnce;
import scala.collection.immutable.Map;
import scala.collection.immutable.Map$;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;
import scala.util.Either;

/* compiled from: XGBoost.scala */
/* loaded from: input_file:ml/dmlc/xgboost4j/scala/spark/XGBoost$.class */
/**
 * Java view of the Scala singleton object {@code ml.dmlc.xgboost4j.scala.spark.XGBoost}, which
 * drives distributed XGBoost training on Spark RDDs.
 *
 * <p>This source was decompiled and keeps the Scala compiler's encoding: name-mangled members
 * (e.g. {@code ml$dmlc$xgboost4j$scala$spark$XGBoost$$logger}), synthetic closure classes
 * ({@code XGBoost$$anonfun$...}) whose bodies are not visible here, and explicit calls into
 * {@code Predef$}/{@code ClassTag$}. NOTE(review): the mangled names are presumably referenced
 * from those closure classes, so do not rename them.
 */
public final class XGBoost$ implements Serializable {
    /**
     * The singleton instance (Scala {@code object} encoding). Initialised directly here: the
     * previous pattern ({@code final} field set to null, then reassigned in the constructor)
     * is rejected by javac.
     */
    public static final XGBoost$ MODULE$ = new XGBoost$();

    /** Logger named "XGBoostSpark"; exposed through the mangled accessor below. */
    private final Log ml$dmlc$xgboost4j$scala$spark$XGBoost$$logger;

    /** Returns the shared "XGBoostSpark" logger. */
    public Log ml$dmlc$xgboost4j$scala$spark$XGBoost$$logger() {
        return this.ml$dmlc$xgboost4j$scala$spark$XGBoost$$logger;
    }

    /**
     * Optionally wraps each point in a validation step for the configured missing value.
     *
     * <p>The iterator is returned untouched when {@code f} is 0 or {@code z} is true. Otherwise
     * every element is mapped through {@code XGBoost$$anonfun$verifyMissingSetting$1(f)}.
     * NOTE(review): that closure is not visible here; presumably it rejects inputs that are
     * inconsistent with a non-zero missing value. Confirm against the closure class.
     *
     * @param iterator points of one partition
     * @param f the configured "missing" value
     * @param z when true, skip the check entirely
     * @return the original or the wrapped iterator
     */
    private Iterator<LabeledPoint> verifyMissingSetting(Iterator<LabeledPoint> iterator, float f, boolean z) {
        return (f == 0.0f || z) ? iterator : iterator.map(new XGBoost$$anonfun$verifyMissingSetting$1(f));
    }

    /**
     * Maps every point through {@code XGBoost$$anonfun$removeMissingValues$1(function1)}.
     * Presumably that closure drops the feature values for which {@code function1} is true
     * (TODO confirm).
     *
     * @param iterator points of one partition
     * @param f unused by this body; kept for the existing signature
     * @param function1 predicate over a single feature value
     * @return the mapped iterator
     */
    private Iterator<LabeledPoint> removeMissingValues(Iterator<LabeledPoint> iterator, float f, Function1<Object, Object> function1) {
        return iterator.map(new XGBoost$$anonfun$removeMissingValues$1(function1));
    }

    /**
     * Validates the missing-value setting, then strips missing values from each point.
     *
     * <p>A NaN {@code f} uses predicate {@code processMissingValues$2}; any other value uses
     * {@code processMissingValues$1(f)}. Presumably these are "is NaN" and "equals f"
     * respectively (TODO confirm).
     *
     * @param iterator points of one partition
     * @param f the configured "missing" value
     * @param z forwarded to {@link #verifyMissingSetting}; when true the validation is skipped
     * @return the processed iterator
     */
    public Iterator<LabeledPoint> processMissingValues(Iterator<LabeledPoint> iterator, float f, boolean z) {
        return Predef$.MODULE$.float2Float(f).isNaN() ? removeMissingValues(verifyMissingSetting(iterator, f, z), f, new XGBoost$$anonfun$processMissingValues$2()) : removeMissingValues(verifyMissingSetting(iterator, f, z), f, new XGBoost$$anonfun$processMissingValues$1(f));
    }

    /**
     * Group-aware counterpart of {@link #processMissingValues}.
     *
     * <p>A NaN {@code f} returns the iterator untouched. Otherwise each group array is mapped
     * through a closure capturing {@code f} and {@code z}, whose body is not visible here.
     *
     * @param iterator groups of points of one partition
     * @param f the configured "missing" value
     * @param z flag captured by the closure
     * @return the original or the mapped iterator
     */
    public Iterator<LabeledPoint[]> ml$dmlc$xgboost4j$scala$spark$XGBoost$$processMissingValuesWithGroup(Iterator<LabeledPoint[]> iterator, float f, boolean z) {
        return Predef$.MODULE$.float2Float(f).isNaN() ? iterator : iterator.map(new XGBoost$$anonfun$ml$dmlc$xgboost4j$scala$spark$XGBoost$$processMissingValuesWithGroup$1(f, z));
    }

    /**
     * When {@code z} is true, creates a fresh temporary directory and returns its absolute path.
     * The directory prefix is {@code "<stageId>-cache-<partitionId>"}, taken from the current
     * Spark {@code TaskContext}. Otherwise returns {@code None}.
     *
     * <p>Side effect: the directory is created on disk. Nothing in this method deletes it.
     *
     * @param z whether a cache directory is wanted
     * @return {@code Some(path)} or {@code None}
     */
    public Option<String> ml$dmlc$xgboost4j$scala$spark$XGBoost$$getCacheDirName(boolean z) {
        return z ? new Some(Files.createTempDirectory(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"", "-cache-", ""})).s(Predef$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToInteger(TaskContext$.MODULE$.get().stageId()), BoxesRunTime.boxToInteger(TaskContext$.MODULE$.getPartitionId()).toString()})), new FileAttribute[0]).toAbsolutePath().toString()) : None$.MODULE$;
    }

    /**
     * Trains one worker's booster inside a Spark task.
     *
     * <p>Steps:
     * <ul>
     *   <li>Fails fast if the "train" DMatrix of this partition has no rows.</li>
     *   <li>Adds {@code DMLC_TASK_ID} (partition id), {@code DMLC_NUM_ATTEMPT} (task attempt
     *       number) and {@code DMLC_WORKER_STOP_PROCESS_ON_ERROR=false} to {@code map}. Note
     *       that this mutates the caller's map.</li>
     *   <li>Joins the Rabit ring and trains.</li>
     *   <li>Rabit is shut down and {@code watches} deleted on both the success and failure
     *       paths.</li>
     * </ul>
     *
     * <p>Checkpointing ({@code trainAndSaveCheckpoint}) is used only when a checkpoint parameter
     * is defined AND this is partition 0. Every other partition calls plain {@code train}.
     *
     * @param watches named DMatrices; must contain "train"
     * @param xGBoostExecutionParams training configuration
     * @param map Rabit environment (mutated, see above)
     * @param objectiveTrait custom objective, possibly null
     * @param evalTrait custom evaluation metric, possibly null
     * @param booster booster to continue training from, possibly null
     * @return a single-element iterator of (booster, metric history keyed by watch name)
     * @throws XGBoostError on an empty training partition or any Rabit/training failure
     */
    public Iterator<Tuple2<Booster, Map<String, float[]>>> ml$dmlc$xgboost4j$scala$spark$XGBoost$$buildDistributedBooster(Watches watches, XGBoostExecutionParams xGBoostExecutionParams, java.util.Map<String, String> map, ObjectiveTrait objectiveTrait, EvalTrait evalTrait, Booster booster) throws XGBoostError {
        if (((DMatrix) watches.toMap().apply("train")).rowNum() == 0) {
            throw new XGBoostError("detected an empty partition in the training data, partition ID: " + TaskContext$.MODULE$.getPartitionId());
        }
        String taskId = Integer.toString(TaskContext$.MODULE$.getPartitionId());
        String attemptNum = Integer.toString(TaskContext$.MODULE$.get().attemptNumber());
        map.put("DMLC_TASK_ID", taskId);
        map.put("DMLC_NUM_ATTEMPT", attemptNum);
        map.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false");
        int numRounds = xGBoostExecutionParams.numRounds();
        // Only the partition-0 worker writes checkpoints.
        boolean z = xGBoostExecutionParams.checkpointParam().isDefined() && Integer.parseInt(taskId) == 0;
        // The cleanup is duplicated on the success path and in the outer catch rather than put
        // in a finally. This keeps the original semantics: if the success-path shutdown throws,
        // the failure is logged and cleanup runs again.
        try {
            try {
                Rabit.init(map);
                int numEarlyStoppingRounds = xGBoostExecutionParams.earlyStoppingParams().numEarlyStoppingRounds();
                // One metric array per watch, filled in by training.
                float[][] fArr = (float[][]) Array$.MODULE$.tabulate(watches.size(), new XGBoost$$anonfun$14(numRounds), ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Float.TYPE)));
                DMatrix trainMatrix = (DMatrix) watches.toMap().apply("train");
                Booster trainedBooster = z
                        ? ml.dmlc.xgboost4j.scala.XGBoost$.MODULE$.trainAndSaveCheckpoint(trainMatrix, xGBoostExecutionParams.toMap(), numRounds, watches.toMap(), fArr, objectiveTrait, evalTrait, numEarlyStoppingRounds, booster, xGBoostExecutionParams.checkpointParam())
                        : ml.dmlc.xgboost4j.scala.XGBoost$.MODULE$.train(trainMatrix, xGBoostExecutionParams.toMap(), numRounds, watches.toMap(), fArr, objectiveTrait, evalTrait, numEarlyStoppingRounds, booster);
                // Pairs watch names with metric arrays by iteration order.
                // NOTE(review): this assumes watches.toMap() iterates in the same order the
                // metric arrays were filled; confirm.
                Map<String, float[]> metrics = ((TraversableOnce) watches.toMap().keys().zip(Predef$.MODULE$.wrapRefArray(fArr), Iterable$.MODULE$.canBuildFrom())).toMap(Predef$.MODULE$.$conforms());
                Iterator<Tuple2<Booster, Map<String, float[]>>> result = scala.package$.MODULE$.Iterator().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{new Tuple2<>(trainedBooster, metrics)}));
                Rabit.shutdown();
                watches.delete();
                return result;
            } catch (XGBoostError e) {
                ml$dmlc$xgboost4j$scala$spark$XGBoost$$logger().error("XGBooster worker " + taskId + " has failed " + attemptNum + " times due to ", e);
                throw e;
            }
        } catch (Throwable th) {
            Rabit.shutdown();
            watches.delete();
            throw th;
        }
    }

    /**
     * Creates and starts the Rabit tracker for {@code i} workers.
     *
     * <p>"scala" selects the Scala tracker. Every other value, including "python", uses the Java
     * tracker (the original "python" and default branches were identical). The variable is now
     * typed as {@link IRabitTracker}; before, a Java tracker was assigned to a Scala-tracker
     * variable, which does not compile.
     *
     * @param i number of workers
     * @param trackerConf tracker implementation name and worker connection timeout
     * @return the started tracker
     * @throws IllegalArgumentException (via {@code require}) if the tracker fails to start
     */
    private IRabitTracker startTracker(int i, TrackerConf trackerConf) {
        IRabitTracker rabitTracker;
        if ("scala".equals(trackerConf.trackerImpl())) {
            rabitTracker = new RabitTracker(i, RabitTracker$.MODULE$.$lessinit$greater$default$2(), RabitTracker$.MODULE$.$lessinit$greater$default$3());
        } else {
            rabitTracker = new ml.dmlc.xgboost4j.java.RabitTracker(i);
        }
        Predef$.MODULE$.require(rabitTracker.start(trackerConf.workerConnectionTimeout()), new XGBoost$$anonfun$startTracker$1());
        return rabitTracker;
    }

    /**
     * Co-partitions the training RDD with the evaluation RDDs into {@code i} partitions.
     *
     * <p>Builds {@code Map("train" -> rdd) ++ map} (an entry named "train" in {@code map} would
     * replace the training set). Each entry goes through {@code XGBoost$$anonfun$15(i)}. The
     * results are folded, via {@code coPartitionNoGroupSets$2}, onto a seed RDD of {@code i}
     * placeholder tuples spread over {@code i} partitions. NOTE(review): the closures are not
     * visible; presumably they repartition to {@code i} and zip partitions into
     * (name, iterator) pairs. Confirm.
     */
    private RDD<Tuple2<String, Iterator<LabeledPoint>>> coPartitionNoGroupSets(RDD<LabeledPoint> rdd, Map<String, RDD<LabeledPoint>> map, int i) {
        return (RDD) ((Map) Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("train"), rdd)})).$plus$plus(map).map(new XGBoost$$anonfun$15(i), Map$.MODULE$.canBuildFrom())).foldLeft(rdd.sparkContext().parallelize(Predef$.MODULE$.wrapRefArray((Object[]) Array$.MODULE$.fill(i, new XGBoost$$anonfun$coPartitionNoGroupSets$1(), ClassTag$.MODULE$.apply(Tuple2.class))), i, ClassTag$.MODULE$.apply(Tuple2.class)), new XGBoost$$anonfun$coPartitionNoGroupSets$2());
    }

    /**
     * Builds the cached per-partition training RDD for non-ranking tasks.
     *
     * <p>With no evaluation sets, the training RDD is mapped directly. Otherwise the
     * co-partitioned RDD from {@link #coPartitionNoGroupSets} is mapped. Either way the
     * result is {@code cache()}d.
     */
    private RDD<Tuple2<Booster, Map<String, float[]>>> trainForNonRanking(RDD<LabeledPoint> rdd, XGBoostExecutionParams xGBoostExecutionParams, java.util.Map<String, String> map, Booster booster, Map<String, RDD<LabeledPoint>> map2) {
        if (map2.isEmpty()) {
            return rdd.mapPartitions(new XGBoost$$anonfun$trainForNonRanking$1(xGBoostExecutionParams, map, booster), rdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Tuple2.class)).cache();
        }
        RDD<Tuple2<String, Iterator<LabeledPoint>>> coPartitionNoGroupSets = coPartitionNoGroupSets(rdd, map2, xGBoostExecutionParams.numWorkers());
        return coPartitionNoGroupSets.mapPartitions(new XGBoost$$anonfun$16(xGBoostExecutionParams, map, booster), coPartitionNoGroupSets.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Tuple2.class)).cache();
    }

    /**
     * Ranking counterpart of {@link #trainForNonRanking}. The input is an RDD of point groups,
     * and evaluation sets are co-partitioned with {@link #coPartitionGroupSets}.
     */
    private RDD<Tuple2<Booster, Map<String, float[]>>> trainForRanking(RDD<LabeledPoint[]> rdd, XGBoostExecutionParams xGBoostExecutionParams, java.util.Map<String, String> map, Booster booster, Map<String, RDD<LabeledPoint>> map2) {
        if (map2.isEmpty()) {
            return rdd.mapPartitions(new XGBoost$$anonfun$trainForRanking$1(xGBoostExecutionParams, map, booster), rdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Tuple2.class)).cache();
        }
        RDD<Tuple2<String, Iterator<LabeledPoint[]>>> coPartitionGroupSets = coPartitionGroupSets(rdd, map2, xGBoostExecutionParams.numWorkers());
        return coPartitionGroupSets.mapPartitions(new XGBoost$$anonfun$18(xGBoostExecutionParams, map, booster), coPartitionGroupSets.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Tuple2.class)).cache();
    }

    /** Persists {@code rdd} with MEMORY_AND_DISK when {@code z} is true; otherwise returns it as-is. */
    private RDD<?> cacheData(boolean z, RDD<?> rdd) {
        return z ? rdd.persist(StorageLevel$.MODULE$.MEMORY_AND_DISK()) : rdd;
    }

    /**
     * Prepares the training input, optionally cached via {@link #cacheData}.
     *
     * @param z forwarded to {@link #cacheData} (persist when true)
     * @param z2 when true, group the data with {@link #repartitionForTrainingGroup} and return
     *     {@code Left}; otherwise return the points as {@code Right}
     * @param i target partition count for the grouped path
     */
    private Either<RDD<LabeledPoint[]>, RDD<LabeledPoint>> composeInputData(RDD<LabeledPoint> rdd, boolean z, boolean z2, int i) {
        if (!z2) {
            return scala.package$.MODULE$.Right().apply(cacheData(z, rdd));
        }
        return scala.package$.MODULE$.Left().apply(cacheData(z, repartitionForTrainingGroup(rdd, i)));
    }

    /*
     * NOTE(review): the decompiler (JADX) crashed on this method with a NullPointerException in
     * AttachTryCatchVisitor, so its original body (634 bytecode instructions) is NOT present in
     * this source. As written, every call throws UnsupportedOperationException. Recover the
     * body from the compiled class or the original Scala source before relying on this file.
     */
    public scala.Tuple2<ml.dmlc.xgboost4j.scala.Booster, scala.collection.immutable.Map<java.lang.String, float[]>> trainDistributed(org.apache.spark.rdd.RDD<ml.dmlc.xgboost4j.LabeledPoint> r10, scala.collection.immutable.Map<java.lang.String, java.lang.Object> r11, boolean r12, scala.collection.immutable.Map<java.lang.String, org.apache.spark.rdd.RDD<ml.dmlc.xgboost4j.LabeledPoint>> r13) throws ml.dmlc.xgboost4j.java.XGBoostError {
        throw new UnsupportedOperationException("Method not decompiled: ml.dmlc.xgboost4j.scala.spark.XGBoost$.trainDistributed(org.apache.spark.rdd.RDD, scala.collection.immutable.Map, boolean, scala.collection.immutable.Map):scala.Tuple2");
    }

    /** Scala default for the third argument of {@code trainDistributed}: {@code false}. */
    public boolean trainDistributed$default$3() {
        return false;
    }

    /** Scala default for the fourth argument of {@code trainDistributed}: an empty map. */
    public Map<String, RDD<LabeledPoint>> trainDistributed$default$4() {
        return Predef$.MODULE$.Map().apply(Nil$.MODULE$);
    }

    /**
     * When {@code z} is true, unpersists whichever side of {@code either} holds the training
     * RDD. Uses the RDD's default blocking flag.
     */
    private void uncacheTrainingData(boolean z, Either<RDD<LabeledPoint[]>, RDD<LabeledPoint>> either) {
        if (z) {
            if (either.isLeft()) {
                RDD rdd = (RDD) either.left().get();
                rdd.unpersist(rdd.unpersist$default$1());
            } else {
                RDD rdd2 = (RDD) either.right().get();
                rdd2.unpersist(rdd2.unpersist$default$1());
            }
        }
    }

    /**
     * Turns a point RDD into an RDD of point groups by unioning two passes.
     *
     * <ul>
     *   <li>Pass 1: {@code $$anonfun$21}, filtered by {@code $$anonfun$22}, then mapped to
     *       arrays.</li>
     *   <li>Pass 2: {@code $$anonfun$24}, filtered by {@code $$anonfun$25}, keyed by
     *       {@code $$anonfun$26}, grouped by an Int key ({@code $$anonfun$27}), then merged into
     *       arrays ({@code $$anonfun$28}).</li>
     * </ul>
     *
     * NOTE(review): presumably pass 1 keeps groups that are complete inside a partition and
     * pass 2 stitches groups that span partitions. The closures are not visible; confirm.
     */
    public RDD<LabeledPoint[]> ml$dmlc$xgboost4j$scala$spark$XGBoost$$aggByGroupInfo(RDD<LabeledPoint> rdd) {
        return rdd.mapPartitions(new XGBoost$$anonfun$21(), rdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(XGBLabeledPointGroup.class)).filter(new XGBoost$$anonfun$22()).map(new XGBoost$$anonfun$23(), ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(LabeledPoint.class))).union(rdd.mapPartitions(new XGBoost$$anonfun$24(), rdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(XGBLabeledPointGroup.class)).filter(new XGBoost$$anonfun$25()).map(new XGBoost$$anonfun$26(), ClassTag$.MODULE$.apply(Tuple2.class)).groupBy(new XGBoost$$anonfun$27(), ClassTag$.MODULE$.Int()).map(new XGBoost$$anonfun$28(), ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(LabeledPoint.class))));
    }

    /** Groups {@code rdd} via {@code aggByGroupInfo}, logs the target count, and repartitions to {@code i}. */
    public RDD<LabeledPoint[]> repartitionForTrainingGroup(RDD<LabeledPoint> rdd, int i) {
        RDD<LabeledPoint[]> ml$dmlc$xgboost4j$scala$spark$XGBoost$$aggByGroupInfo = ml$dmlc$xgboost4j$scala$spark$XGBoost$$aggByGroupInfo(rdd);
        ml$dmlc$xgboost4j$scala$spark$XGBoost$$logger().info(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"repartitioning training group set to ", " partitions"})).s(Predef$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToInteger(i)})));
        return ml$dmlc$xgboost4j$scala$spark$XGBoost$$aggByGroupInfo.repartition(i, ml$dmlc$xgboost4j$scala$spark$XGBoost$$aggByGroupInfo.repartition$default$2(i));
    }

    /**
     * Group-aware counterpart of {@link #coPartitionNoGroupSets}.
     *
     * <p>Unlike that method, only the evaluation entries are passed through the per-entry
     * closure ({@code $$anonfun$29(i)}). The grouped training RDD is added as-is under "train",
     * and a "train" key in {@code map} would replace it.
     */
    private RDD<Tuple2<String, Iterator<LabeledPoint[]>>> coPartitionGroupSets(RDD<LabeledPoint[]> rdd, Map<String, RDD<LabeledPoint>> map, int i) {
        return (RDD) Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("train"), rdd)})).$plus$plus((GenTraversableOnce) map.map(new XGBoost$$anonfun$29(i), Map$.MODULE$.canBuildFrom())).foldLeft(rdd.sparkContext().parallelize(Predef$.MODULE$.wrapRefArray((Object[]) Array$.MODULE$.fill(i, new XGBoost$$anonfun$coPartitionGroupSets$1(), ClassTag$.MODULE$.apply(Tuple2.class))), i, ClassTag$.MODULE$.apply(Tuple2.class)), new XGBoost$$anonfun$coPartitionGroupSets$2());
    }

    /**
     * Turns the tracker's exit code into the final (booster, metrics) result.
     *
     * <p>A non-zero code means failure: the Spark job thread is interrupted if still alive, and
     * training fails. On zero, waits for the job thread, takes the first (booster, metrics)
     * element and unpersists the RDD without blocking.
     *
     * <p>Previously, {@code isAlive()/interrupt()} was wrapped in
     * {@code catch (InterruptedException)}. Neither method throws it, so the handler (and its
     * log line) was unreachable and rejected by javac. The log now follows the interrupt.
     *
     * @param i tracker return value; 0 means success
     * @param rdd cached per-partition training results
     * @param thread thread that forces the Spark training job
     * @return the booster and its metric history
     * @throws XGBoostError if the tracker reported failure
     * @throws InterruptedException if interrupted while joining {@code thread}
     * @throws MatchError if the first RDD element is null
     */
    private Tuple2<Booster, Map<String, float[]>> postTrackerReturnProcessing(int i, RDD<Tuple2<Booster, Map<String, float[]>>> rdd, Thread thread) throws XGBoostError, InterruptedException {
        if (i != 0) {
            if (thread.isAlive()) {
                thread.interrupt();
                ml$dmlc$xgboost4j$scala$spark$XGBoost$$logger().info("spark job thread is interrupted");
            }
            throw new XGBoostError("XGBoostModel training failed");
        }
        thread.join();
        Tuple2<Booster, Map<String, float[]>> first = rdd.first();
        if (first == null) {
            throw new MatchError(first);
        }
        Booster booster = first._1();
        Map<String, float[]> metrics = first._2();
        rdd.unpersist(false);
        return new Tuple2<>(booster, metrics);
    }

    /** Keeps the singleton unique across Java deserialization. */
    private Object readResolve() {
        return MODULE$;
    }

    /** Creates the singleton; only called while initialising {@link #MODULE$}. */
    private XGBoost$() {
        this.ml$dmlc$xgboost4j$scala$spark$XGBoost$$logger = LogFactory.getLog("XGBoostSpark");
    }
}
