package ml.dmlc.xgboost4j.scala.spark;

import java.nio.file.Files;
import java.nio.file.attribute.FileAttribute;
import java.util.ServiceLoader;
import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.Communicator;
import ml.dmlc.xgboost4j.scala.Booster;
import ml.dmlc.xgboost4j.scala.DMatrix;
import ml.dmlc.xgboost4j.scala.spark.params.GeneralParams;
import ml.dmlc.xgboost4j.scala.spark.params.HasBaseMarginCol;
import ml.dmlc.xgboost4j.scala.spark.params.NonParamVariables;
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon;
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils;
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils$;
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils$MLVectorToXGBLabeledPoint$;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.TaskContext$;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.param.shared.HasLabelCol;
import org.apache.spark.ml.param.shared.HasWeightCol;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.FloatType$;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructField$;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.storage.StorageLevel$;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.Function3;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Option$;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Some;
import scala.Tuple2;
import scala.Tuple3;
import scala.Tuple4;
import scala.Tuple9;
import scala.collection.AbstractIterator;
import scala.collection.GenTraversableOnce;
import scala.collection.Iterable;
import scala.collection.Iterator;
import scala.collection.Iterator$;
import scala.collection.JavaConverters$;
import scala.collection.Seq$;
import scala.collection.SeqLike;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.immutable.$colon;
import scala.collection.immutable.Map;
import scala.collection.immutable.Map$;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayOps;
import scala.math.Ordering$Int$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;
import scala.util.Either;
import scala.util.Left;
import scala.util.Right;

/* compiled from: PreXGBoost.scala */
/* loaded from: input_file:ml/dmlc/xgboost4j/scala/spark/PreXGBoost$.class */
public final class PreXGBoost$ implements PreXGBoostProvider {
    // Singleton instance of this Scala object; presumably assigned by the private constructor
    // (not visible in this chunk) when the static initializer runs.
    public static PreXGBoost$ MODULE$;
    // Lazily-built literal fallback columns (lit(NaN), lit(1.0), lit(-1)) used when the estimator
    // leaves the base-margin / weight / group column unset or empty.
    private Column defaultBaseMarginColumn;
    private Column defaultWeightColumn;
    private Column defaultGroupColumn;
    // Initialised outside this view (presumably in the constructor).
    private final Log logger;
    // Optional pluggable provider; when defined and enabled, calls are delegated to it.
    // Initialised outside this view (presumably in the constructor).
    private final Option<PreXGBoostProvider> optionProvider;
    // Lazy-init flags: bit 1 = defaultBaseMarginColumn, bit 2 = defaultWeightColumn,
    // bit 4 = defaultGroupColumn have been computed. Volatile so the unsynchronized fast-path
    // read in the accessors sees writes made under the lock.
    private volatile byte bitmap$0;

    // Creates the singleton when the class is loaded. The instance is not stored here;
    // presumably the constructor (not visible in this chunk) assigns MODULE$ = this.
    static {
        new PreXGBoost$();
    }

    /**
     * Reports whether this provider handles the given dataset by forwarding to the default
     * implementation inherited from {@link PreXGBoostProvider}.
     *
     * @param option the dataset being processed, or {@code None} when no dataset is available
     * @return whatever the interface's default {@code providerEnabled} returns
     */
    @Override // ml.dmlc.xgboost4j.scala.spark.PreXGBoostProvider
    public boolean providerEnabled(Option<Dataset<?>> option) {
        // The previous body called providerEnabled(option) on this very method, which recursed
        // until StackOverflowError. Dispatch explicitly to the interface's default method instead.
        return PreXGBoostProvider.super.providerEnabled(option);
    }

    /** Returns the commons-logging logger held by this object. */
    private Log logger() {
        final Log log = logger;
        return log;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v0 */
    /* JADX WARN: Type inference failed for: r0v1, types: [java.lang.Throwable] */
    /* JADX WARN: Type inference failed for: r0v10, types: [ml.dmlc.xgboost4j.scala.spark.PreXGBoost$] */
    private Column defaultBaseMarginColumn$lzycompute() {
        ?? r0 = this;
        synchronized (r0) {
            if (((byte) (this.bitmap$0 & 1)) == 0) {
                this.defaultBaseMarginColumn = functions$.MODULE$.lit(BoxesRunTime.boxToFloat(Float.NaN));
                r0 = this;
                r0.bitmap$0 = (byte) (this.bitmap$0 | 1);
            }
        }
        return this.defaultBaseMarginColumn;
    }

    /** Returns the fallback base-margin column ({@code lit(NaN)}), building it on first use. */
    private Column defaultBaseMarginColumn() {
        // Bit 1 of bitmap$0 is set once the lazy value has been computed.
        if ((this.bitmap$0 & 1) != 0) {
            return this.defaultBaseMarginColumn;
        }
        return defaultBaseMarginColumn$lzycompute();
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v0 */
    /* JADX WARN: Type inference failed for: r0v1, types: [java.lang.Throwable] */
    /* JADX WARN: Type inference failed for: r0v10, types: [ml.dmlc.xgboost4j.scala.spark.PreXGBoost$] */
    private Column defaultWeightColumn$lzycompute() {
        ?? r0 = this;
        synchronized (r0) {
            if (((byte) (this.bitmap$0 & 2)) == 0) {
                this.defaultWeightColumn = functions$.MODULE$.lit(BoxesRunTime.boxToDouble(1.0d));
                r0 = this;
                r0.bitmap$0 = (byte) (this.bitmap$0 | 2);
            }
        }
        return this.defaultWeightColumn;
    }

    /** Returns the fallback weight column ({@code lit(1.0)}), building it on first use. */
    private Column defaultWeightColumn() {
        // Bit 2 of bitmap$0 is set once the lazy value has been computed.
        if ((this.bitmap$0 & 2) != 0) {
            return this.defaultWeightColumn;
        }
        return defaultWeightColumn$lzycompute();
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v0 */
    /* JADX WARN: Type inference failed for: r0v1, types: [java.lang.Throwable] */
    /* JADX WARN: Type inference failed for: r0v10, types: [ml.dmlc.xgboost4j.scala.spark.PreXGBoost$] */
    private Column defaultGroupColumn$lzycompute() {
        ?? r0 = this;
        synchronized (r0) {
            if (((byte) (this.bitmap$0 & 4)) == 0) {
                this.defaultGroupColumn = functions$.MODULE$.lit(BoxesRunTime.boxToInteger(-1));
                r0 = this;
                r0.bitmap$0 = (byte) (this.bitmap$0 | 4);
            }
        }
        return this.defaultGroupColumn;
    }

    /** Returns the fallback group column ({@code lit(-1)}), building it on first use. */
    private Column defaultGroupColumn() {
        // Bit 4 of bitmap$0 is set once the lazy value has been computed.
        if ((this.bitmap$0 & 4) != 0) {
            return this.defaultGroupColumn;
        }
        return defaultGroupColumn$lzycompute();
    }

    /** Returns the optional pluggable provider that, when enabled, handles calls instead of this object. */
    private Option<PreXGBoostProvider> optionProvider() {
        final Option<PreXGBoostProvider> provider = optionProvider;
        return provider;
    }

    /**
     * Computes the output schema for an XGBoost estimator or model.
     *
     * <p>Delegates to the pluggable provider when one is defined and enabled (queried with
     * {@code None}, since no dataset is available here). Otherwise it dispatches to the
     * {@code transformSchemaInternal} of the four supported concrete types.
     *
     * @param xGBoostEstimatorCommon classifier, regressor, or one of their fitted models
     * @param structType the input schema
     * @return the transformed schema
     * @throws RuntimeException if the estimator/model type is not supported
     */
    @Override // ml.dmlc.xgboost4j.scala.spark.PreXGBoostProvider
    public StructType transformSchema(XGBoostEstimatorCommon xGBoostEstimatorCommon, StructType structType) {
        boolean useProvider = optionProvider().isDefined()
                && ((PreXGBoostProvider) optionProvider().get()).providerEnabled(None$.MODULE$);
        if (useProvider) {
            return ((PreXGBoostProvider) optionProvider().get()).transformSchema(xGBoostEstimatorCommon, structType);
        }
        if (xGBoostEstimatorCommon instanceof XGBoostClassifier) {
            return ((XGBoostClassifier) xGBoostEstimatorCommon).transformSchemaInternal(structType);
        }
        if (xGBoostEstimatorCommon instanceof XGBoostClassificationModel) {
            return ((XGBoostClassificationModel) xGBoostEstimatorCommon).transformSchemaInternal(structType);
        }
        if (xGBoostEstimatorCommon instanceof XGBoostRegressor) {
            return ((XGBoostRegressor) xGBoostEstimatorCommon).transformSchemaInternal(structType);
        }
        if (xGBoostEstimatorCommon instanceof XGBoostRegressionModel) {
            return ((XGBoostRegressionModel) xGBoostEstimatorCommon).transformSchemaInternal(structType);
        }
        throw new RuntimeException("Unsupporting " + xGBoostEstimatorCommon);
    }

    /**
     * Prepares the training (and evaluation) data for an XGBoost estimator and returns a factory
     * that, given execution params, produces the per-partition {@code Watches} suppliers.
     *
     * <p>If a pluggable provider is defined and enabled for {@code dataset}, the whole call is
     * delegated to it. Otherwise the estimator must be an {@link XGBoostEstimatorCommon}.
     *
     * <p>The label, features, weight and base-margin columns are packed into
     * {@code DataUtils.PackedParams}. So is the group column, which is present only for
     * {@link XGBoostRegressor}. Weight, base margin and group fall back to literal default columns
     * when unset or empty.
     *
     * <p>The training set and every evaluation set are vectorized and converted to
     * {@code LabeledPoint} RDDs eagerly, in this method.
     *
     * @param estimator the estimator being fitted
     * @param dataset   the training dataset
     * @param map       estimator params, passed to {@code getEvalSets} to obtain evaluation sets
     * @return a function mapping execution params to (watches RDD, persisted training RDD when
     *         {@code cacheTrainingSet()} is true, else {@code None})
     * @throws RuntimeException if {@code estimator} is not an {@link XGBoostEstimatorCommon}
     */
    @Override // ml.dmlc.xgboost4j.scala.spark.PreXGBoostProvider
    public Function1<XGBoostExecutionParams, Tuple2<RDD<Function0<Watches>>, Option<RDD<?>>>> buildDatasetToRDD(Estimator<?> estimator, Dataset<?> dataset, Map<String, Object> map) {
        // Delegate to the pluggable provider when one is registered and enabled for this dataset.
        if (optionProvider().isDefined() && ((PreXGBoostProvider) optionProvider().get()).providerEnabled(new Some(dataset))) {
            return ((PreXGBoostProvider) optionProvider().get()).buildDatasetToRDD(estimator, dataset, map);
        }
        if (!(estimator instanceof XGBoostEstimatorCommon)) {
            throw new RuntimeException(new StringBuilder(13).append("Unsupporting ").append(estimator).toString());
        }
        // Weight falls back to lit(1.0) and base margin to lit(NaN) when the param is unset or empty.
        Column weightColumn = (!estimator.isDefined(((HasWeightCol) estimator).weightCol()) || ((HasWeightCol) estimator).getWeightCol().isEmpty()) ? defaultWeightColumn() : functions$.MODULE$.col(((HasWeightCol) estimator).getWeightCol());
        Column baseMarginColumn = (!estimator.isDefined(((HasBaseMarginCol) estimator).baseMarginCol()) || ((HasBaseMarginCol) estimator).getBaseMarginCol().isEmpty()) ? defaultBaseMarginColumn() : functions$.MODULE$.col(((HasBaseMarginCol) estimator).getBaseMarginCol());
        // Only regressors carry a group column (fallback lit(-1)). Declared as Option rather than
        // Some so that None can be assigned; the decompiled "Some some = None$" did not type-check.
        Option groupColumn;
        if (estimator instanceof XGBoostRegressor) {
            XGBoostRegressor regressor = (XGBoostRegressor) estimator;
            groupColumn = new Some((!regressor.isDefined(regressor.groupCol()) || regressor.getGroupCol().isEmpty()) ? defaultGroupColumn() : functions$.MODULE$.col(regressor.getGroupCol()));
        } else {
            groupColumn = None$.MODULE$;
        }
        Tuple2<Dataset<?>, String> vectorized = ((XGBoostEstimatorCommon) estimator).vectorize(dataset);
        if (vectorized == null) {
            throw new MatchError(vectorized);
        }
        Dataset trainingSet = (Dataset) vectorized._1();
        String featuresName = (String) vectorized._2();
        DataUtils.PackedParams packedParams = new DataUtils.PackedParams(functions$.MODULE$.col(((HasLabelCol) estimator).getLabelCol()), functions$.MODULE$.col(featuresName), weightColumn, baseMarginColumn, groupColumn, ((GeneralParams) estimator).getNumWorkers(), ((XGBoostEstimatorCommon) estimator).needDeterministicRepartitioning());
        // Vectorize every evaluation set the same way as the training set (feature name discarded).
        Map evalSets = (Map) ((NonParamVariables) estimator).getEvalSets(map).transform((evalName, evalSet) -> {
            Tuple2<Dataset<?>, String> evalVectorized = ((XGBoostEstimatorCommon) estimator).vectorize(evalSet);
            if (evalVectorized != null) {
                return (Dataset) evalVectorized._1();
            }
            throw new MatchError(evalVectorized);
        }, Map$.MODULE$.canBuildFrom());
        // Convert the training and evaluation DataFrames into LabeledPoint RDDs.
        RDD trainingRdd = (RDD) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(DataUtils$.MODULE$.convertDataFrameToXGBLabeledPointRDDs(packedParams, Predef$.MODULE$.wrapRefArray(new Dataset[]{trainingSet})))).head();
        Map evalRdds = (Map) evalSets.map(evalEntry -> {
            // evalSets is a raw Map, so the lambda parameter is an Object and must be cast.
            Tuple2 entry = (Tuple2) evalEntry;
            if (entry == null) {
                throw new MatchError(entry);
            }
            return new Tuple2((String) entry._1(), new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(DataUtils$.MODULE$.convertDataFrameToXGBLabeledPointRDDs(packedParams, Predef$.MODULE$.wrapRefArray(new Dataset[]{(Dataset) entry._2()})))).head());
        }, Map$.MODULE$.canBuildFrom());
        // Selects the group-aware path. Decided by $anonfun$buildDatasetToRDD$3 (defined outside this
        // view) on the packed group column; false when there is no group column.
        boolean hasGroup = BoxesRunTime.unboxToBoolean(packedParams.group().map(column -> {
            return BoxesRunTime.boxToBoolean($anonfun$buildDatasetToRDD$3(column));
        }).getOrElse(() -> {
            return false;
        }));
        return xGBoostExecutionParams -> {
            // composeInputData returns an Either; the decompiled code stored it in a Left-typed
            // variable, which did not type-check and made the instanceof test meaningless.
            Either inputData = MODULE$.composeInputData(trainingRdd, hasGroup, packedParams.numWorkers());
            if (inputData instanceof Left) {
                RDD<LabeledPoint[]> groupedData = (RDD) ((Left) inputData).value();
                return new Tuple2(MODULE$.trainForRanking(groupedData, xGBoostExecutionParams, evalRdds), xGBoostExecutionParams.cacheTrainingSet() ? new Some(groupedData.persist(StorageLevel$.MODULE$.MEMORY_AND_DISK())) : None$.MODULE$);
            }
            if (!(inputData instanceof Right)) {
                throw new MatchError(inputData);
            }
            RDD<LabeledPoint> flatData = (RDD) ((Right) inputData).value();
            return new Tuple2(MODULE$.trainForNonRanking(flatData, xGBoostExecutionParams, evalRdds), xGBoostExecutionParams.cacheTrainingSet() ? new Some(flatData.persist(StorageLevel$.MODULE$.MEMORY_AND_DISK())) : None$.MODULE$);
        };
    }

    /**
     * Runs batched prediction of a fitted XGBoost model over {@code dataset}.
     *
     * <p>If a pluggable provider is defined and enabled for the dataset, the call is delegated.
     * Otherwise the model must be an {@link XGBoostClassificationModel} or an
     * {@link XGBoostRegressionModel}; any other type raises {@code MatchError}.
     *
     * <p>The input is vectorized and the booster is broadcast. Each partition's rows are grouped
     * into batches of {@code getInferBatchSize()}, and every batch is turned into a {@code DMatrix}
     * and predicted with the model-specific prediction function.
     *
     * @param model   fitted classification or regression model
     * @param dataset input rows
     * @return a DataFrame of the input columns plus the model's prediction columns
     */
    @Override // ml.dmlc.xgboost4j.scala.spark.PreXGBoostProvider
    public Dataset<Row> transformDataset(Model<?> model, Dataset<?> dataset) {
        Tuple9 tuple9;
        // Delegate to the pluggable provider when one is registered and enabled for this dataset.
        if (optionProvider().isDefined() && ((PreXGBoostProvider) optionProvider().get()).providerEnabled(new Some(dataset))) {
            return ((PreXGBoostProvider) optionProvider().get()).transformDataset(model, dataset);
        }
        if (model instanceof XGBoostClassificationModel) {
            XGBoostClassificationModel xGBoostClassificationModel = (XGBoostClassificationModel) model;
            Tuple2<Dataset<?>, String> vectorize = xGBoostClassificationModel.vectorize(dataset);
            if (vectorize == null) {
                throw new MatchError(vectorize);
            }
            Tuple2 tuple2 = new Tuple2((Dataset) vectorize._1(), (String) vectorize._2());
            Dataset dataset2 = (Dataset) tuple2._1();
            String str = (String) tuple2._2();
            // Classifier prediction: producePredictionItrs must yield exactly four iterators
            // (otherwise MatchError); they are combined with the batch's original rows.
            Function3 function3 = (booster, dMatrix, iterator) -> {
                Iterator<Row>[] producePredictionItrs = xGBoostClassificationModel.producePredictionItrs(booster, dMatrix);
                Option unapplySeq = Array$.MODULE$.unapplySeq(producePredictionItrs);
                if (unapplySeq.isEmpty() || unapplySeq.get() == null || ((SeqLike) unapplySeq.get()).lengthCompare(4) != 0) {
                    throw new MatchError(producePredictionItrs);
                }
                Tuple4 tuple4 = new Tuple4((Iterator) ((SeqLike) unapplySeq.get()).apply(0), (Iterator) ((SeqLike) unapplySeq.get()).apply(1), (Iterator) ((SeqLike) unapplySeq.get()).apply(2), (Iterator) ((SeqLike) unapplySeq.get()).apply(3));
                return xGBoostClassificationModel.produceResultIterator(iterator, (Iterator) tuple4._1(), (Iterator) tuple4._2(), (Iterator) tuple4._3(), (Iterator) tuple4._4());
            };
            // Output schema: input fields + raw-prediction + probability columns (non-nullable
            // array<float>), plus leaf / contrib columns when those params are defined.
            StructType structType = new StructType((StructField[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(dataset2.schema().fields())).$plus$plus(Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new StructField[]{new StructField(XGBoostClassificationModel$.MODULE$._rawPredictionCol(), new ArrayType(FloatType$.MODULE$, false), false, StructField$.MODULE$.apply$default$4())})), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(StructField.class))))).$plus$plus(Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new StructField[]{new StructField(XGBoostClassificationModel$.MODULE$._probabilityCol(), new ArrayType(FloatType$.MODULE$, false), false, StructField$.MODULE$.apply$default$4())})), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(StructField.class))));
            if (xGBoostClassificationModel.isDefined(xGBoostClassificationModel.leafPredictionCol())) {
                structType = structType.add(new StructField(xGBoostClassificationModel.getLeafPredictionCol(), new ArrayType(FloatType$.MODULE$, false), false, StructField$.MODULE$.apply$default$4()));
            }
            if (xGBoostClassificationModel.isDefined(xGBoostClassificationModel.contribPredictionCol())) {
                structType = structType.add(new StructField(xGBoostClassificationModel.getContribPredictionCol(), new ArrayType(FloatType$.MODULE$, false), false, StructField$.MODULE$.apply$default$4()));
            }
            // (booster, inferBatchSize, vectorized dataset, features column, useExternalMemory,
            //  missing, allowNonZeroForMissingValue, prediction function, output schema)
            tuple9 = new Tuple9(xGBoostClassificationModel._booster(), BoxesRunTime.boxToInteger(xGBoostClassificationModel.getInferBatchSize()), dataset2, str, BoxesRunTime.boxToBoolean(xGBoostClassificationModel.getUseExternalMemory()), BoxesRunTime.boxToFloat(xGBoostClassificationModel.getMissing()), BoxesRunTime.boxToBoolean(xGBoostClassificationModel.getAllowNonZeroForMissingValue()), function3, structType);
        } else {
            // Any model other than the two supported types ends up here and fails with MatchError.
            if (!(model instanceof XGBoostRegressionModel)) {
                throw new MatchError(model);
            }
            XGBoostRegressionModel xGBoostRegressionModel = (XGBoostRegressionModel) model;
            Tuple2<Dataset<?>, String> vectorize2 = xGBoostRegressionModel.vectorize(dataset);
            if (vectorize2 == null) {
                throw new MatchError(vectorize2);
            }
            Tuple2 tuple22 = new Tuple2((Dataset) vectorize2._1(), (String) vectorize2._2());
            Dataset dataset3 = (Dataset) tuple22._1();
            String str2 = (String) tuple22._2();
            // Regressor prediction: producePredictionItrs must yield exactly three iterators.
            Function3 function32 = (booster2, dMatrix2, iterator2) -> {
                Iterator<Row>[] producePredictionItrs = xGBoostRegressionModel.producePredictionItrs(booster2, dMatrix2);
                Option unapplySeq = Array$.MODULE$.unapplySeq(producePredictionItrs);
                if (unapplySeq.isEmpty() || unapplySeq.get() == null || ((SeqLike) unapplySeq.get()).lengthCompare(3) != 0) {
                    throw new MatchError(producePredictionItrs);
                }
                Tuple3 tuple3 = new Tuple3((Iterator) ((SeqLike) unapplySeq.get()).apply(0), (Iterator) ((SeqLike) unapplySeq.get()).apply(1), (Iterator) ((SeqLike) unapplySeq.get()).apply(2));
                return xGBoostRegressionModel.produceResultIterator(iterator2, (Iterator) tuple3._1(), (Iterator) tuple3._2(), (Iterator) tuple3._3());
            };
            // Output schema: input fields + original-prediction column, plus leaf / contrib
            // columns when those params are defined.
            StructType structType2 = new StructType((StructField[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(dataset3.schema().fields())).$plus$plus(Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new StructField[]{new StructField(XGBoostRegressionModel$.MODULE$._originalPredictionCol(), new ArrayType(FloatType$.MODULE$, false), false, StructField$.MODULE$.apply$default$4())})), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(StructField.class))));
            if (xGBoostRegressionModel.isDefined(xGBoostRegressionModel.leafPredictionCol())) {
                structType2 = structType2.add(new StructField(xGBoostRegressionModel.getLeafPredictionCol(), new ArrayType(FloatType$.MODULE$, false), false, StructField$.MODULE$.apply$default$4()));
            }
            if (xGBoostRegressionModel.isDefined(xGBoostRegressionModel.contribPredictionCol())) {
                structType2 = structType2.add(new StructField(xGBoostRegressionModel.getContribPredictionCol(), new ArrayType(FloatType$.MODULE$, false), false, StructField$.MODULE$.apply$default$4()));
            }
            tuple9 = new Tuple9(xGBoostRegressionModel._booster(), BoxesRunTime.boxToInteger(xGBoostRegressionModel.getInferBatchSize()), dataset3, str2, BoxesRunTime.boxToBoolean(xGBoostRegressionModel.getUseExternalMemory()), BoxesRunTime.boxToFloat(xGBoostRegressionModel.getMissing()), BoxesRunTime.boxToBoolean(xGBoostRegressionModel.getAllowNonZeroForMissingValue()), function32, structType2);
        }
        // Unpack the common 9-tuple. The decompiled re-packing into tuple93 is a no-op copy.
        Tuple9 tuple92 = tuple9;
        if (tuple92 == null) {
            throw new MatchError(tuple92);
        }
        Booster booster3 = (Booster) tuple92._1();
        int unboxToInt = BoxesRunTime.unboxToInt(tuple92._2());
        Tuple9 tuple93 = new Tuple9(booster3, BoxesRunTime.boxToInteger(unboxToInt), (Dataset) tuple92._3(), (String) tuple92._4(), BoxesRunTime.boxToBoolean(BoxesRunTime.unboxToBoolean(tuple92._5())), BoxesRunTime.boxToFloat(BoxesRunTime.unboxToFloat(tuple92._6())), BoxesRunTime.boxToBoolean(BoxesRunTime.unboxToBoolean(tuple92._7())), (Function3) tuple92._8(), (StructType) tuple92._9());
        Booster booster4 = (Booster) tuple93._1();
        int unboxToInt2 = BoxesRunTime.unboxToInt(tuple93._2());
        Dataset dataset4 = (Dataset) tuple93._3();
        String str3 = (String) tuple93._4();
        boolean unboxToBoolean = BoxesRunTime.unboxToBoolean(tuple93._5());
        float unboxToFloat = BoxesRunTime.unboxToFloat(tuple93._6());
        boolean unboxToBoolean2 = BoxesRunTime.unboxToBoolean(tuple93._7());
        Function3 function33 = (Function3) tuple93._8();
        StructType structType3 = (StructType) tuple93._9();
        // Ship the booster to executors once via a broadcast rather than in every task closure.
        Broadcast broadcast = dataset4.sparkSession().sparkContext().broadcast(booster4, ClassTag$.MODULE$.apply(Booster.class));
        String appName = dataset4.sparkSession().sparkContext().appName();
        RDD rdd = dataset4.rdd();
        // Per partition: wrap the row iterator in a batching iterator that predicts lazily.
        // NOTE(review): decompiler artifact - the anonymous AbstractIterator is constructed with
        // captured values as arguments; confirm a matching constructor exists before compiling.
        RDD mapPartitions = rdd.mapPartitions(iterator3 -> {
            return new AbstractIterator<Row>(iterator3, unboxToInt2, str3, unboxToBoolean, appName, unboxToFloat, unboxToBoolean2, function33, broadcast) { // from class: ml.dmlc.xgboost4j.scala.spark.PreXGBoost$$anon$1
                // Number of batches processed so far in this partition.
                private int batchCnt = 0;
                private final Iterator<Row> batchIterImpl;
                private final String featuresCol$1;
                private final boolean useExternalMemory$1;
                private final String appName$1;
                private final float missing$1;
                private final boolean allowNonZeroForMissing$1;
                private final Function3 predictFunc$1;
                private final Broadcast bBooster$1;

                private int batchCnt() {
                    return this.batchCnt;
                }

                private void batchCnt_$eq(int i) {
                    this.batchCnt = i;
                }

                private Iterator<Row> batchIterImpl() {
                    return this.batchIterImpl;
                }

                public boolean hasNext() {
                    return batchIterImpl().hasNext();
                }

                /* renamed from: next, reason: merged with bridge method [inline-methods] */
                public Row m34next() {
                    Row row = (Row) batchIterImpl().next();
                    // Shut the communicator down once the last row of the partition is handed out.
                    // NOTE(review): not reached if the consumer stops early or prediction throws.
                    if (!batchIterImpl().hasNext()) {
                        Communicator.shutdown();
                    }
                    return row;
                }

                {
                    this.featuresCol$1 = str3;
                    this.useExternalMemory$1 = unboxToBoolean;
                    this.appName$1 = appName;
                    this.missing$1 = unboxToFloat;
                    this.allowNonZeroForMissing$1 = unboxToBoolean2;
                    this.predictFunc$1 = function33;
                    this.bBooster$1 = broadcast;
                    // Rows are grouped into batches of inferBatchSize; each batch gets its own DMatrix.
                    // An empty partition yields no batches, so init/shutdown never run for it.
                    this.batchIterImpl = iterator3.grouped(unboxToInt2).flatMap(seq -> {
                        // First batch of this partition: init the communicator with
                        // DMLC_TASK_ID set to the Spark partition id.
                        if (this.batchCnt() == 0) {
                            Communicator.init((java.util.Map) JavaConverters$.MODULE$.mapAsJavaMapConverter(new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("DMLC_TASK_ID"), BoxesRunTime.boxToInteger(TaskContext$.MODULE$.getPartitionId()).toString())})).toMap(Predef$.MODULE$.$conforms())).asJava());
                        }
                        // Build the DMatrix from the features column after missing-value processing.
                        // With external memory, a cache name unique per app/stage/partition/batch
                        // is passed; otherwise null.
                        DMatrix dMatrix3 = new DMatrix(DataUtils$.MODULE$.processMissingValues(seq.iterator().map(row -> {
                            return (Vector) row.getAs(this.featuresCol$1);
                        }).map(vector -> {
                            return DataUtils$MLVectorToXGBLabeledPoint$.MODULE$.asXGB$extension(DataUtils$.MODULE$.MLVectorToXGBLabeledPoint(vector));
                        }), this.missing$1, this.allowNonZeroForMissing$1), this.useExternalMemory$1 ? new StringBuilder(14).append(this.appName$1).append("-").append(TaskContext$.MODULE$.get().stageId()).append("-dtest_cache-").append(new StringBuilder(7).append(TaskContext$.MODULE$.getPartitionId()).append("-batch-").append(this.batchCnt()).toString()).toString() : null);
                        // The native DMatrix is freed in finally, even if prediction throws.
                        // NOTE(review): assumes predictFunc has finished reading dMatrix3 before it
                        // returns, since the returned iterator is consumed after delete().
                        try {
                            return (Iterator) this.predictFunc$1.apply(this.bBooster$1.value(), dMatrix3, seq.iterator());
                        } finally {
                            this.batchCnt_$eq(this.batchCnt() + 1);
                            dMatrix3.delete();
                        }
                    });
                }
            };
        }, rdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Row.class));
        // NOTE(review): mapPartitions is lazy, so this runs before any partition is computed.
        // unpersist(false) only drops executor-side cached copies; tasks will re-fetch the booster.
        broadcast.unpersist(false);
        return dataset4.sparkSession().createDataFrame(mapPartitions, structType3);
    }

    /**
     * Returns a factory that, given execution params, turns an already-built {@code LabeledPoint}
     * training RDD (plus named evaluation RDDs) into per-partition {@code Watches} suppliers.
     *
     * @param rdd training points
     * @param map evaluation sets by name (default: empty)
     * @param z   when true, the data is regrouped with {@link #repartitionForTrainingGroup} and
     *            trained through the group-aware path; otherwise the flat path is used (default: false)
     * @return a function mapping execution params to (watches RDD, persisted training RDD when
     *         {@code cacheTrainingSet()} is true, else {@code None})
     */
    public Function1<XGBoostExecutionParams, Tuple2<RDD<Function0<Watches>>, Option<RDD<?>>>> buildRDDLabeledPointToRDDWatches(RDD<LabeledPoint> rdd, Map<String, RDD<LabeledPoint>> map, boolean z) {
        return xGBoostExecutionParams -> {
            // composeInputData returns an Either; the decompiled code stored it in a Left-typed
            // variable, which did not type-check and made the instanceof test meaningless.
            Either inputData = MODULE$.composeInputData(rdd, z, xGBoostExecutionParams.numWorkers());
            if (inputData instanceof Left) {
                RDD<LabeledPoint[]> groupedData = (RDD) ((Left) inputData).value();
                return new Tuple2(MODULE$.trainForRanking(groupedData, xGBoostExecutionParams, map), xGBoostExecutionParams.cacheTrainingSet() ? new Some(groupedData.persist(StorageLevel$.MODULE$.MEMORY_AND_DISK())) : None$.MODULE$);
            }
            if (!(inputData instanceof Right)) {
                throw new MatchError(inputData);
            }
            RDD<LabeledPoint> flatData = (RDD) ((Right) inputData).value();
            return new Tuple2(MODULE$.trainForNonRanking(flatData, xGBoostExecutionParams, map), xGBoostExecutionParams.cacheTrainingSet() ? new Some(flatData.persist(StorageLevel$.MODULE$.MEMORY_AND_DISK())) : None$.MODULE$);
        };
    }

    /**
     * Default value of the evaluation-sets argument of {@code buildRDDLabeledPointToRDDWatches}:
     * an empty immutable map, i.e. no evaluation sets.
     */
    public Map<String, RDD<LabeledPoint>> buildRDDLabeledPointToRDDWatches$default$2() {
        Map<String, RDD<LabeledPoint>> noEvalSets = Predef$.MODULE$.Map().empty();
        return noEvalSets;
    }

    /**
     * Default value of the group flag of {@code buildRDDLabeledPointToRDDWatches}: {@code false},
     * so training goes through the non-grouped path.
     */
    public boolean buildRDDLabeledPointToRDDWatches$default$3() {
        final boolean useGroups = false;
        return useGroups;
    }

    /**
     * Chooses the training-data shape.
     *
     * @param rdd training points
     * @param z   whether the data carries groups
     * @param i   number of partitions to use for grouped data
     * @return {@code Left} of group arrays repartitioned to {@code i} partitions when {@code z},
     *         otherwise {@code Right} of the untouched points
     */
    private Either<RDD<LabeledPoint[]>, RDD<LabeledPoint>> composeInputData(RDD<LabeledPoint> rdd, boolean z, int i) {
        if (!z) {
            return scala.package$.MODULE$.Right().apply(rdd);
        }
        return scala.package$.MODULE$.Left().apply(repartitionForTrainingGroup(rdd, i));
    }

    /**
     * Aggregates points into group arrays via {@link #aggByGroupInfo} and repartitions the result.
     *
     * @param rdd training points
     * @param i   target number of partitions
     * @return the group arrays spread over {@code i} partitions
     */
    public RDD<LabeledPoint[]> repartitionForTrainingGroup(RDD<LabeledPoint> rdd, int i) {
        RDD<LabeledPoint[]> grouped = aggByGroupInfo(rdd);
        String message = "repartitioning training group set to " + i + " partitions";
        logger().info(message);
        return grouped.repartition(i, grouped.repartition$default$2(i));
    }

    /**
     * Builds one deferred {@code Watches} supplier per partition of grouped training data.
     *
     * <p>With no evaluation sets, each partition's group arrays are run through
     * {@code DataUtils.processMissingValuesWithGroup} and {@code Watches.buildWatchesWithGroup}.
     * Otherwise the training groups are co-partitioned with the evaluation sets via
     * {@link #coPartitionGroupSets}, and each partition's named iterators are processed the same
     * way. The resulting RDD is marked with {@code cache()}.
     *
     * @param rdd training data already aggregated into group arrays
     * @param xGBoostExecutionParams source of missing-value settings, worker count and the
     *        external-memory flag
     * @param map evaluation sets by name
     * @return an RDD with a single {@code Function0<Watches>} per partition
     */
    private RDD<Function0<Watches>> trainForRanking(RDD<LabeledPoint[]> rdd, XGBoostExecutionParams xGBoostExecutionParams, Map<String, RDD<LabeledPoint>> map) {
        if (map.isEmpty()) {
            // No eval sets: one supplier per partition over that partition's group arrays. The
            // Watches (and the cache dir name) are only built when the supplier is invoked.
            return rdd.mapPartitions(iterator -> {
                return Iterator$.MODULE$.single(() -> {
                    return Watches$.MODULE$.buildWatchesWithGroup(xGBoostExecutionParams, DataUtils$.MODULE$.processMissingValuesWithGroup(iterator, xGBoostExecutionParams.missing(), xGBoostExecutionParams.allowNonZeroForMissing()), MODULE$.getCacheDirName(xGBoostExecutionParams.useExternalMemory()));
                });
            }, rdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Function0.class)).cache();
        }
        // With eval sets: each partition holds (name, group-array iterator) pairs for the training
        // set and every eval set; missing values are processed per named iterator.
        RDD<Tuple2<String, Iterator<LabeledPoint[]>>> coPartitionGroupSets = coPartitionGroupSets(rdd, map, xGBoostExecutionParams.numWorkers());
        return coPartitionGroupSets.mapPartitions(iterator2 -> {
            return Iterator$.MODULE$.single(() -> {
                return Watches$.MODULE$.buildWatchesWithGroup(iterator2.map(tuple2 -> {
                    if (tuple2 == null) {
                        throw new MatchError(tuple2);
                    }
                    return new Tuple2((String) tuple2._1(), DataUtils$.MODULE$.processMissingValuesWithGroup((Iterator) tuple2._2(), xGBoostExecutionParams.missing(), xGBoostExecutionParams.allowNonZeroForMissing()));
                }), MODULE$.getCacheDirName(xGBoostExecutionParams.useExternalMemory()));
            });
        }, coPartitionGroupSets.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Function0.class)).cache();
    }

    /**
     * Co-partitions the grouped training set with every evaluation set, so that each output
     * partition carries one {@code (name, iterator)} pair per dataset.
     *
     * <p>Each evaluation set is aggregated by group with {@link #aggByGroupInfo} and repartitioned
     * to {@code i} partitions when its partition count differs. The training set, keyed
     * {@code "train"}, is used as-is. NOTE(review): zipPartitions requires equal partition counts,
     * so this assumes {@code rdd} already has {@code i} partitions - confirm at call sites.
     *
     * <p>All datasets are folded onto a seed RDD of {@code i} null placeholders with
     * {@code zipPartitions}. The first zip replaces the placeholder; later zips append.
     *
     * @param rdd training data aggregated into group arrays
     * @param map evaluation sets by name
     * @param i   number of workers / partitions
     * @return an RDD whose partitions are {@code IteratorWrapper}s over the named iterators
     * @throws IllegalStateException at task execution time if any dataset partition is empty
     */
    private RDD<Tuple2<String, Iterator<LabeledPoint[]>>> coPartitionGroupSets(RDD<LabeledPoint[]> rdd, Map<String, RDD<LabeledPoint>> map, int i) {
        return (RDD) Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("train"), rdd)})).$plus$plus((GenTraversableOnce) map.map(evalEntry -> {
            if (evalEntry == null) {
                throw new MatchError(evalEntry);
            }
            String evalName = (String) evalEntry._1();
            RDD<LabeledPoint[]> evalGroups = MODULE$.aggByGroupInfo((RDD) evalEntry._2());
            // Align the partition count with the worker count so zipPartitions below lines up.
            RDD<LabeledPoint[]> aligned = evalGroups.getNumPartitions() != i ? evalGroups.repartition(i, evalGroups.repartition$default$2(i)) : evalGroups;
            return Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(evalName), aligned);
        }, Map$.MODULE$.canBuildFrom())).foldLeft(rdd.sparkContext().parallelize(Predef$.MODULE$.wrapRefArray((Object[]) Array$.MODULE$.fill(i, () -> {
            return null;
        }, ClassTag$.MODULE$.apply(Tuple2.class))), i, ClassTag$.MODULE$.apply(Tuple2.class)), (accRdd, namedSet) -> {
            // The decompiled lambda redeclared its own parameters (rdd2 / tuple22) as locals, which
            // does not compile; distinct names are used here.
            Tuple2 pair = new Tuple2(accRdd, namedSet);
            RDD wrapperRdd = (RDD) pair._1();
            Tuple2 entry = (Tuple2) pair._2();
            if (entry == null) {
                throw new MatchError(pair);
            }
            String name = (String) entry._1();
            return wrapperRdd.zipPartitions((RDD) entry._2(), (wrapperObj, setObj) -> {
                Iterator wrapperIter = (Iterator) wrapperObj;
                Iterator setIter = (Iterator) setObj;
                if (!setIter.hasNext()) {
                    MODULE$.logger().error("when specifying eval sets as dataframes, you have to ensure that the number of elements in each dataframe is larger than the number of workers");
                    // Unchecked: a checked Exception cannot be thrown from this lambda in Java.
                    throw new IllegalStateException("too few elements in evaluation sets");
                }
                Tuple2[] wrapped = (Tuple2[]) wrapperIter.toArray(ClassTag$.MODULE$.apply(Tuple2.class));
                Tuple2 named = Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(name), setIter);
                if (new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(wrapped)).head() == null) {
                    // First dataset zipped in: drop the null placeholder from the seed RDD.
                    return new IteratorWrapper(new Tuple2[]{named});
                }
                return new IteratorWrapper((Tuple2[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(wrapped)).$colon$plus(named, ClassTag$.MODULE$.apply(Tuple2.class)));
            }, ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(LabeledPoint.class)), ClassTag$.MODULE$.apply(Tuple2.class));
        });
    }

    /**
     * Aggregates labeled points into one array per group.
     *
     * <p>Each partition is scanned with a {@code LabeledPointGroupIterator}. Groups the iterator
     * marks as edge groups, i.e. those that may be split across partitions, are tagged with their
     * partition id and gathered together. The chunks are then sorted by partition id and
     * concatenated. All other groups are kept as-is, and the two results are unioned.
     *
     * @param rdd labeled points
     * @return an RDD of per-group point arrays
     */
    private RDD<LabeledPoint[]> aggByGroupInfo(RDD<LabeledPoint> rdd) {
        // Non-edge groups: selected by $anonfun$aggByGroupInfo$2, defined outside this view
        // (presumably !isEdgeGroup - confirm).
        RDD normalGroups = rdd.mapPartitions(iterator -> {
            return new LabeledPointGroupIterator(iterator);
        }, rdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(XGBLabeledPointGroup.class)).filter(xGBLabeledPointGroup -> {
            return BoxesRunTime.boxToBoolean($anonfun$aggByGroupInfo$2(xGBLabeledPointGroup));
        }).map(xGBLabeledPointGroup2 -> {
            return xGBLabeledPointGroup2.points();
        }, ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(LabeledPoint.class)));
        // Edge groups: tagged (partitionId, group) and grouped by the key from
        // $anonfun$aggByGroupInfo$7 (presumably the group id). Each key's chunks are sorted by
        // partition id and flattened via $anonfun$aggByGroupInfo$10 (presumably the chunk's points).
        RDD stitchedGroups = rdd.mapPartitions(iterator2 -> {
            return new LabeledPointGroupIterator(iterator2);
        }, rdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(XGBLabeledPointGroup.class)).filter(xGBLabeledPointGroup3 -> {
            return BoxesRunTime.boxToBoolean(xGBLabeledPointGroup3.isEdgeGroup());
        }).map(xGBLabeledPointGroup4 -> {
            return new Tuple2(BoxesRunTime.boxToInteger(TaskContext$.MODULE$.getPartitionId()), xGBLabeledPointGroup4);
        }, ClassTag$.MODULE$.apply(Tuple2.class)).groupBy(tuple2 -> {
            return BoxesRunTime.boxToInteger($anonfun$aggByGroupInfo$7(tuple2));
        }, ClassTag$.MODULE$.Int()).map(keyedChunks -> {
            // The inner lambda parameter previously reused the enclosing name "tuple22", which
            // does not compile in Java; it is now "chunk".
            return (LabeledPoint[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) ((Iterable) keyedChunks._2()).toArray(ClassTag$.MODULE$.apply(Tuple2.class)))).sortBy(chunk -> {
                return BoxesRunTime.boxToInteger(((Tuple2) chunk)._1$mcI$sp());
            }, Ordering$Int$.MODULE$))).flatMap(tuple23 -> {
                return new ArrayOps.ofRef($anonfun$aggByGroupInfo$10(tuple23));
            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(LabeledPoint.class)));
        }, ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(LabeledPoint.class)));
        return normalGroups.union(stitchedGroups);
    }

    /**
     * Produces, for every partition, a deferred factory of {@code Watches} for non-ranking
     * training. The returned RDD is cached.
     *
     * <p>With no evaluation sets, each partition's training points are passed through
     * {@code processMissingValues} and handed to {@code buildWatches}. With evaluation sets, the
     * training and evaluation RDDs are first co-partitioned into (name, iterator) pairs. Each
     * iterator is then passed through {@code processMissingValues} before {@code buildWatches}.
     *
     * @param rdd                    training points
     * @param xGBoostExecutionParams supplies missing-value handling, worker count and the
     *                               external-memory flag
     * @param map                    evaluation data sets keyed by name; may be empty
     * @return a cached RDD holding one {@code Function0<Watches>} per partition
     */
    private RDD<Function0<Watches>> trainForNonRanking(RDD<LabeledPoint> rdd, XGBoostExecutionParams xGBoostExecutionParams, Map<String, RDD<LabeledPoint>> map) {
        if (!map.isEmpty()) {
            RDD<Tuple2<String, Iterator<LabeledPoint>>> coPartitioned = coPartitionNoGroupSets(rdd, map, xGBoostExecutionParams.numWorkers());
            return coPartitioned.mapPartitions(namedPartitions -> {
                return Iterator$.MODULE$.single(() -> {
                    Iterator cleanedSets = namedPartitions.map(named -> {
                        if (named == null) {
                            throw new MatchError(named);
                        }
                        String setName = (String) named._1();
                        return new Tuple2(setName, DataUtils$.MODULE$.processMissingValues((Iterator) named._2(), xGBoostExecutionParams.missing(), xGBoostExecutionParams.allowNonZeroForMissing()));
                    });
                    // Computed after the mapping above to keep the original evaluation order.
                    Option<String> cacheDir = MODULE$.getCacheDirName(xGBoostExecutionParams.useExternalMemory());
                    return Watches$.MODULE$.buildWatches(cleanedSets, cacheDir);
                });
            }, coPartitioned.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Function0.class)).cache();
        }
        return rdd.mapPartitions(trainPoints -> {
            return Iterator$.MODULE$.single(() -> {
                Iterator cleanedPoints = DataUtils$.MODULE$.processMissingValues(trainPoints, xGBoostExecutionParams.missing(), xGBoostExecutionParams.allowNonZeroForMissing());
                Option<String> cacheDir = MODULE$.getCacheDirName(xGBoostExecutionParams.useExternalMemory());
                return Watches$.MODULE$.buildWatches(xGBoostExecutionParams, cleanedPoints, cacheDir);
            });
        }, rdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Function0.class)).cache();
    }

    private RDD<Tuple2<String, Iterator<LabeledPoint>>> coPartitionNoGroupSets(RDD<LabeledPoint> rdd, Map<String, RDD<LabeledPoint>> map, int i) {
        return (RDD) ((Map) Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("train"), rdd)})).$plus$plus(map).map(tuple2 -> {
            if (tuple2 == null) {
                throw new MatchError(tuple2);
            }
            String str = (String) tuple2._1();
            RDD rdd2 = (RDD) tuple2._2();
            return rdd2.getNumPartitions() != i ? new Tuple2(str, rdd2.repartition(i, rdd2.repartition$default$2(i))) : new Tuple2(str, rdd2);
        }, Map$.MODULE$.canBuildFrom())).foldLeft(rdd.sparkContext().parallelize(Predef$.MODULE$.wrapRefArray((Object[]) Array$.MODULE$.fill(i, () -> {
            return null;
        }, ClassTag$.MODULE$.apply(Tuple2.class))), i, ClassTag$.MODULE$.apply(Tuple2.class)), (rdd2, tuple22) -> {
            Tuple2 tuple22 = new Tuple2(rdd2, tuple22);
            if (tuple22 != null) {
                RDD rdd2 = (RDD) tuple22._1();
                Tuple2 tuple23 = (Tuple2) tuple22._2();
                if (tuple23 != null) {
                    String str = (String) tuple23._1();
                    return rdd2.zipPartitions((RDD) tuple23._2(), (iterator, iterator2) -> {
                        if (iterator2.hasNext()) {
                            Tuple2[] tuple2Arr = (Tuple2[]) iterator.toArray(ClassTag$.MODULE$.apply(Tuple2.class));
                            return new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(tuple2Arr)).head() != null ? new IteratorWrapper((Tuple2[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(tuple2Arr)).$colon$plus(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(str), iterator2), ClassTag$.MODULE$.apply(Tuple2.class))) : new IteratorWrapper(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(str), iterator2)});
                        }
                        MODULE$.logger().error("when specifying eval sets as dataframes, you have to ensure that the number of elements in each dataframe is larger than the number of workers");
                        throw new Exception("too few elements in evaluation sets");
                    }, ClassTag$.MODULE$.apply(LabeledPoint.class), ClassTag$.MODULE$.apply(Tuple2.class));
                }
            }
            throw new MatchError(tuple22);
        });
    }

    /**
     * Creates a fresh temporary directory for external-memory caching when requested.
     *
     * <p>The directory name is prefixed with {@code "<stageId>-cache-<partitionId>"} of the
     * current task. NOTE(review): {@code TaskContext.get()} is dereferenced, so this presumably
     * must run inside a Spark task — confirm. {@code Files.createTempDirectory} can throw
     * {@code IOException}, which this method does not declare — TODO confirm intended handling.
     *
     * @param z whether external memory is used; when false nothing is created
     * @return {@code Some(absolute path)} of the new directory, or {@code None}
     */
    public Option<String> getCacheDirName(boolean z) {
        if (!z) {
            return None$.MODULE$;
        }
        int stageId = TaskContext$.MODULE$.get().stageId();
        String partitionId = String.valueOf(TaskContext$.MODULE$.getPartitionId());
        String prefix = stageId + "-cache-" + partitionId;
        String absoluteDir = Files.createTempDirectory(prefix, new FileAttribute[0]).toAbsolutePath().toString();
        return new Some(absoluteDir);
    }

    /**
     * Tells whether the given provider is the GPU implementation.
     *
     * @param preXGBoostProvider a provider discovered by the service loader
     * @return true when its runtime class is
     *         {@code ml.dmlc.xgboost4j.scala.rapids.spark.GpuPreXGBoost}
     */
    public static final /* synthetic */ boolean $anonfun$optionProvider$2(PreXGBoostProvider preXGBoostProvider) {
        String implementationName = preXGBoostProvider.getClass().getName();
        return "ml.dmlc.xgboost4j.scala.rapids.spark.GpuPreXGBoost".equals(implementationName);
    }

    /**
     * Tells whether {@code column} differs from {@code defaultGroupColumn()}, treating two
     * nulls as equal.
     *
     * @param column the column to compare; may be null
     * @return true when the column is not the default group column
     */
    public static final /* synthetic */ boolean $anonfun$buildDatasetToRDD$3(Column column) {
        Column groupDefault = MODULE$.defaultGroupColumn();
        if (column == null) {
            return groupDefault != null;
        }
        return !column.equals(groupDefault);
    }

    /**
     * Filter predicate keeping groups that are not flagged as edge groups.
     *
     * @param xGBLabeledPointGroup the group to test
     * @return the negation of {@code isEdgeGroup()}
     */
    public static final /* synthetic */ boolean $anonfun$aggByGroupInfo$2(XGBLabeledPointGroup xGBLabeledPointGroup) {
        boolean edge = xGBLabeledPointGroup.isEdgeGroup();
        return !edge;
    }

    /**
     * Grouping key for (partition id, group) pairs.
     *
     * @param tuple2 a pair whose second element is an {@code XGBLabeledPointGroup}
     * @return the {@code groupId()} of that group
     */
    public static final /* synthetic */ int $anonfun$aggByGroupInfo$7(Tuple2 tuple2) {
        XGBLabeledPointGroup group = (XGBLabeledPointGroup) tuple2._2();
        return group.groupId();
    }

    /**
     * Extracts the points of the group stored in a (partition id, group) pair.
     *
     * @param tuple2 a pair whose second element is an {@code XGBLabeledPointGroup}
     * @return that group's {@code points()}, passed through {@code Predef.refArrayOps}
     */
    public static final /* synthetic */ Object[] $anonfun$aggByGroupInfo$10(Tuple2 tuple2) {
        XGBLabeledPointGroup group = (XGBLabeledPointGroup) tuple2._2();
        return Predef$.MODULE$.refArrayOps(group.points());
    }

    /**
     * Initializes the singleton.
     *
     * <p>It publishes {@code MODULE$}, runs the {@code PreXGBoostProvider} trait initializer and
     * creates the {@code "XGBoostSpark"} logger. It then looks up {@code PreXGBoostProvider}
     * services whose class is {@code ml.dmlc.xgboost4j.scala.rapids.spark.GpuPreXGBoost}.
     * {@code optionProvider} is {@code Some(provider)} only when exactly one such service is
     * found. When none or several are found, it is {@code None}.
     */
    private PreXGBoost$() {
        MODULE$ = this;
        PreXGBoostProvider.$init$(this);
        this.logger = LogFactory.getLog("XGBoostSpark");
        // Use the thread's context class loader when set, otherwise this class's own loader.
        ClassLoader classLoader = (ClassLoader) Option$.MODULE$.apply(Thread.currentThread().getContextClassLoader()).getOrElse(() -> {
            return MODULE$.getClass().getClassLoader();
        });
        scala.collection.immutable.List gpuProviders = ((TraversableOnce) ((TraversableLike) JavaConverters$.MODULE$.iterableAsScalaIterableConverter(ServiceLoader.load(PreXGBoostProvider.class, classLoader)).asScala()).filter(provider -> {
            return BoxesRunTime.boxToBoolean($anonfun$optionProvider$2(provider));
        })).toList();
        // Exactly one match selects it; an empty list or several matches leave None.
        Option selected = None$.MODULE$;
        if (gpuProviders instanceof $colon.colon) {
            $colon.colon cons = ($colon.colon) gpuProviders;
            if (Nil$.MODULE$.equals(cons.tl$access$1())) {
                selected = new Some((PreXGBoostProvider) cons.head());
            }
        }
        this.optionProvider = selected;
    }
}
