package com.microsoft.azure.synapse.ml.lightgbm;

import com.microsoft.azure.synapse.ml.lightgbm.booster.LightGBMBooster;
import com.microsoft.azure.synapse.ml.lightgbm.dataset.LightGBMDataset;
import com.microsoft.azure.synapse.ml.lightgbm.params.BaseTrainParams;
import com.microsoft.azure.synapse.ml.lightgbm.params.FObjTrait;
import java.io.Serializable;
import org.slf4j.Logger;
import scala.Array$;
import scala.Function3;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Option$;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.collection.immutable.Map;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: TrainUtils.scala */
/* loaded from: input_file:com/microsoft/azure/synapse/ml/lightgbm/TrainUtils$.class */
public final class TrainUtils$ implements Serializable {
    /** The single shared instance of this class; assigned by the private constructor. */
    public static TrainUtils$ MODULE$;

    static {
        // Instantiating the class is enough: the private constructor stores `this` in MODULE$.
        new TrainUtils$();
    }

    /**
     * Builds a {@link LightGBMBooster} on the given dataset.
     *
     * <p>The booster is constructed from {@code lightGBMDataset} and the string form of
     * {@code baseTrainParams}. When the general parameters carry a model string it is merged
     * into the new booster, and when a validation dataset is supplied it is added to the booster.
     *
     * @param baseTrainParams training parameters; their {@code toString()} is handed to the booster
     * @param lightGBMDataset dataset the booster is created on
     * @param option optional validation dataset
     * @return the newly created booster
     */
    public LightGBMBooster createBooster(BaseTrainParams baseTrainParams, LightGBMDataset lightGBMDataset, Option<LightGBMDataset> option) {
        LightGBMBooster booster = new LightGBMBooster(lightGBMDataset, baseTrainParams.toString());

        Option<String> initialModel = baseTrainParams.generalParams().modelString();
        if (initialModel.isDefined()) {
            booster.mergeBooster(initialModel.get());
        }

        if (option.isDefined()) {
            booster.addValidationDataset(option.get());
        }

        return booster;
    }

    /**
     * Invokes the delegate's {@code beforeTrainIteration} hook, if a delegate is configured.
     *
     * <p>Does nothing when the training parameters have no delegate.
     *
     * @param partitionTaskTrainingState current training state of this partition task
     * @param logger logger handed through to the delegate
     */
    public void beforeTrainIteration(PartitionTaskTrainingState partitionTaskTrainingState, Logger logger) {
        if (!partitionTaskTrainingState.ctx().trainingParams().delegate().isDefined()) {
            return;
        }
        LightGBMDelegate hook = (LightGBMDelegate) partitionTaskTrainingState.ctx().trainingParams().delegate().get();
        hook.beforeTrainIteration(
                partitionTaskTrainingState.ctx().trainingCtx().batchIndex(),
                partitionTaskTrainingState.ctx().partitionId(),
                partitionTaskTrainingState.iteration(),
                logger,
                partitionTaskTrainingState.ctx().trainingParams(),
                partitionTaskTrainingState.booster(),
                partitionTaskTrainingState.ctx().trainingCtx().hasValidationData());
    }

    /**
     * Invokes the delegate's {@code afterTrainIteration} hook, if a delegate is configured.
     *
     * <p>Does nothing when the training parameters have no delegate.
     *
     * @param partitionTaskTrainingState current training state of this partition task
     * @param logger logger handed through to the delegate
     * @param option first optional metric map passed through unchanged (the in-file caller
     *               supplies the training evaluation results here)
     * @param option2 second optional metric map passed through unchanged (the in-file caller
     *                supplies the validation evaluation results here)
     */
    public void afterTrainIteration(PartitionTaskTrainingState partitionTaskTrainingState, Logger logger, Option<Map<String, Object>> option, Option<Map<String, Object>> option2) {
        PartitionTaskContext taskCtx = partitionTaskTrainingState.ctx();
        TrainingContext sharedCtx = taskCtx.trainingCtx();
        if (!taskCtx.trainingParams().delegate().isDefined()) {
            return;
        }
        LightGBMDelegate hook = (LightGBMDelegate) taskCtx.trainingParams().delegate().get();
        hook.afterTrainIteration(
                sharedCtx.batchIndex(),
                taskCtx.partitionId(),
                partitionTaskTrainingState.iteration(),
                logger,
                sharedCtx.trainingParams(),
                partitionTaskTrainingState.booster(),
                sharedCtx.hasValidationData(),
                partitionTaskTrainingState.isFinished(),
                option,
                option2);
    }

    /**
     * Determines the learning rate to use for the upcoming iteration.
     *
     * <p>When a delegate is configured, the delegate is asked for the rate and given the
     * state's current rate as input; otherwise the state's current rate is returned as is.
     *
     * @param partitionTaskTrainingState current training state of this partition task
     * @param logger logger handed through to the delegate
     * @return the learning rate chosen by the delegate, or the state's current learning rate
     * @throws MatchError if the delegate option is neither a {@code Some} nor {@code None}
     *                    (for example {@code null})
     */
    public double getLearningRate(PartitionTaskTrainingState partitionTaskTrainingState, Logger logger) {
        // Typed as Option (not Some): the value is legitimately None when no delegate is set.
        Option<?> delegate = partitionTaskTrainingState.ctx().trainingParams().delegate();
        if (delegate instanceof Some) {
            return ((LightGBMDelegate) delegate.get()).getLearningRate(
                    partitionTaskTrainingState.ctx().trainingCtx().batchIndex(),
                    partitionTaskTrainingState.ctx().partitionId(),
                    partitionTaskTrainingState.iteration(),
                    logger,
                    partitionTaskTrainingState.ctx().trainingParams(),
                    partitionTaskTrainingState.learningRate());
        }
        if (None$.MODULE$.equals(delegate)) {
            return partitionTaskTrainingState.learningRate();
        }
        throw new MatchError(delegate);
    }

    /**
     * Runs a single boosting iteration and records on the state whether training is finished.
     *
     * <p>When a custom objective ({@code fobj}) is configured, its gradient is computed from the
     * booster's inner prediction for data index 0 and the booster's training dataset, and the
     * booster is updated with that pair of arrays. Otherwise the booster's built-in update is used.
     *
     * <p>Any {@link Exception} raised during the update is treated as termination of this task:
     * it is logged as a warning (with the throwable attached so the stack trace is kept) and the
     * state is marked finished. The exception is deliberately not rethrown.
     *
     * @param partitionTaskTrainingState current training state of this partition task
     * @param logger logger used for the debug and warning messages
     */
    public void updateOneIteration(PartitionTaskTrainingState partitionTaskTrainingState, Logger logger) {
        try {
            logger.debug("LightGBM running iteration: " + partitionTaskTrainingState.iteration());
            Option<FObjTrait> fobj = partitionTaskTrainingState.ctx().trainingParams().objectiveParams().fobj();
            if (fobj.isDefined()) {
                Tuple2<float[], float[]> gradient = ((FObjTrait) fobj.get()).getGradient(
                        partitionTaskTrainingState.booster().innerPredict(0, partitionTaskTrainingState.ctx().trainingCtx().isClassification()),
                        (LightGBMDataset) partitionTaskTrainingState.booster().trainDataset().get());
                if (gradient == null) {
                    throw new MatchError(gradient);
                }
                float[] first = (float[]) gradient._1();
                float[] second = (float[]) gradient._2();
                partitionTaskTrainingState.isFinished_$eq(partitionTaskTrainingState.booster().updateOneIterationCustom(first, second));
            } else {
                partitionTaskTrainingState.isFinished_$eq(partitionTaskTrainingState.booster().updateOneIteration());
            }
        } catch (Exception e) {
            // Best-effort by design: a failing update ends training on this task instead of failing it.
            // The throwable is passed along so the stack trace is not lost.
            logger.warn("LightGBM reached early termination on one task, stopping training on task. This message should rarely occur. Inner exception: " + e.toString(), e);
            partitionTaskTrainingState.isFinished_$eq(true);
        }
    }

    /**
     * Runs the training iteration loop for one partition task.
     *
     * <p>Logs the start of training, marks the start and stop of the iteration phase on the
     * task's measures, and runs iterations up to the configured {@code numIterations}.
     *
     * @param partitionTaskTrainingState training state that is advanced by the loop
     * @param logger logger used for progress messages
     * @return the state's {@code bestIterationResult} once the loop has ended
     */
    public Option<Object> executeTrainingIterations(PartitionTaskTrainingState partitionTaskTrainingState, Logger logger) {
        String startMessage = "Beginning training on LightGBM Booster for task "
                + partitionTaskTrainingState.ctx().taskId()
                + ", partition "
                + partitionTaskTrainingState.ctx().partitionId();
        logger.info(startMessage);

        partitionTaskTrainingState.ctx().measures().markTrainingIterationsStart();
        int maxIterations = partitionTaskTrainingState.ctx().trainingParams().generalParams().numIterations();
        Option<Object> bestIteration = iterationLoop$1(maxIterations, partitionTaskTrainingState, logger);
        partitionTaskTrainingState.ctx().measures().markTrainingIterationsStop();

        return bestIteration;
    }

    /**
     * Fetches the booster's evaluation results for data index 0, logs each one, and returns
     * them as a map keyed by metric name.
     *
     * @param partitionTaskTrainingState state providing the booster and the evaluation names
     * @param logger logger that receives one "Train name=value" line per metric
     * @return the metric map wrapped in an {@code Option}
     */
    public Option<Map<String, Object>> getTrainEvalResults(PartitionTaskTrainingState partitionTaskTrainingState, Logger logger) {
        Tuple2<String, Object>[] trainMetrics = partitionTaskTrainingState.booster().getEvalResults(partitionTaskTrainingState.evalNames(), 0);
        for (Tuple2<String, Object> metric : trainMetrics) {
            $anonfun$getTrainEvalResults$1(logger, metric);
        }
        return Option$.MODULE$.apply(Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(trainMetrics)));
    }

    /**
     * Fetches the booster's evaluation results for data index 1, updates the per-metric
     * best-score bookkeeping on the state, and applies the early-stopping rule.
     *
     * <p>For each metric, in array order:
     * <ul>
     *   <li>the value is logged as "Valid name=value";</li>
     *   <li>for names starting with {@code auc}, {@code ndcg@}, {@code map@} or
     *       {@code average_precision}, the value counts as an improvement when
     *       {@code score - best > tolerance}; for every other name, when
     *       {@code score - best < tolerance};</li>
     *   <li>if no best scores are recorded yet for this metric, or the value is an improvement,
     *       the best score, best iteration and a snapshot of all current metric values are stored;</li>
     *   <li>otherwise, if at least {@code earlyStoppingRound} iterations have passed since the best
     *       iteration, the state is marked finished and the best iteration is recorded as result.</li>
     * </ul>
     * Remaining metrics are still processed after early stopping has been triggered.
     *
     * @param partitionTaskTrainingState state providing the booster and holding the best-score arrays
     * @param logger logger used for metric and early-stopping messages
     * @return the metric map wrapped in an {@code Option}
     * @throws MatchError if an entry of the evaluation result array is {@code null}
     */
    public Option<Map<String, Object>> getValidEvalResults(PartitionTaskTrainingState partitionTaskTrainingState, Logger logger) {
        Tuple2<String, Object>[] evalResults = partitionTaskTrainingState.booster().getEvalResults(partitionTaskTrainingState.evalNames(), 1);
        Tuple2[] results = new Tuple2[evalResults.length];

        for (int index = 0; index < evalResults.length; index++) {
            Tuple2<String, Object> metric = evalResults[index];
            if (metric == null) {
                // Same payload the pattern match would report: the (metric, index) pair.
                throw new MatchError(new Tuple2<>(metric, BoxesRunTime.boxToInteger(index)));
            }
            String name = (String) metric._1();
            double score = metric._2$mcD$sp();
            logger.info("Valid " + name + "=" + score);

            boolean largerMeansImproved = name.startsWith("auc")
                    || name.startsWith("ndcg@")
                    || name.startsWith("map@")
                    || name.startsWith("average_precision");

            boolean improved;
            if (partitionTaskTrainingState.bestScores()[index] == null) {
                // Nothing recorded yet for this metric: the first value is always taken.
                improved = true;
            } else {
                double best = partitionTaskTrainingState.bestScore()[index];
                double tolerance = partitionTaskTrainingState.ctx().trainingCtx().improvementTolerance();
                improved = largerMeansImproved ? (score - best > tolerance) : (score - best < tolerance);
            }

            if (improved) {
                partitionTaskTrainingState.bestScore()[index] = score;
                partitionTaskTrainingState.bestIteration()[index] = partitionTaskTrainingState.iteration();
                // Snapshot of every metric value of this evaluation, stored under this metric's slot.
                double[] snapshot = new double[evalResults.length];
                for (int k = 0; k < evalResults.length; k++) {
                    snapshot[k] = evalResults[k]._2$mcD$sp();
                }
                partitionTaskTrainingState.bestScores()[index] = snapshot;
            } else if (partitionTaskTrainingState.iteration() - partitionTaskTrainingState.bestIteration()[index] >= partitionTaskTrainingState.ctx().trainingCtx().earlyStoppingRound()) {
                partitionTaskTrainingState.isFinished_$eq(true);
                logger.info("Early stopping, best iteration is " + partitionTaskTrainingState.bestIteration()[index]);
                partitionTaskTrainingState.bestIterationResult_$eq(new Some(BoxesRunTime.boxToInteger(partitionTaskTrainingState.bestIteration()[index])));
            }

            results[index] = new Tuple2<>(name, BoxesRunTime.boxToDouble(score));
        }

        return Option$.MODULE$.apply(Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(results)));
    }

    /**
     * Invokes the delegate's {@code beforeGenerateTrainDataset} hook, if a delegate is configured.
     *
     * @param partitionTaskContext context of the partition task; supplies the partition id and
     *                             the shared training context
     * @param logger logger handed through to the delegate
     */
    public void beforeGenerateTrainDataset(PartitionTaskContext partitionTaskContext, Logger logger) {
        TrainingContext sharedCtx = partitionTaskContext.trainingCtx();
        if (!sharedCtx.trainingParams().delegate().isDefined()) {
            return;
        }
        LightGBMDelegate hook = (LightGBMDelegate) sharedCtx.trainingParams().delegate().get();
        hook.beforeGenerateTrainDataset(
                sharedCtx.batchIndex(),
                partitionTaskContext.partitionId(),
                sharedCtx.columnParams(),
                sharedCtx.schema(),
                logger,
                sharedCtx.trainingParams());
    }

    /**
     * Invokes the delegate's {@code afterGenerateTrainDataset} hook, if a delegate is configured.
     *
     * @param partitionTaskContext context of the partition task; supplies the partition id and
     *                             the shared training context
     * @param logger logger handed through to the delegate
     */
    public void afterGenerateTrainDataset(PartitionTaskContext partitionTaskContext, Logger logger) {
        TrainingContext sharedCtx = partitionTaskContext.trainingCtx();
        if (!sharedCtx.trainingParams().delegate().isDefined()) {
            return;
        }
        LightGBMDelegate hook = (LightGBMDelegate) sharedCtx.trainingParams().delegate().get();
        hook.afterGenerateTrainDataset(
                sharedCtx.batchIndex(),
                partitionTaskContext.partitionId(),
                sharedCtx.columnParams(),
                sharedCtx.schema(),
                logger,
                sharedCtx.trainingParams());
    }

    /**
     * Invokes the delegate's {@code beforeGenerateValidDataset} hook, if a delegate is configured.
     *
     * @param partitionTaskContext context of the partition task; supplies the partition id and
     *                             the shared training context
     * @param logger logger handed through to the delegate
     */
    public void beforeGenerateValidDataset(PartitionTaskContext partitionTaskContext, Logger logger) {
        TrainingContext sharedCtx = partitionTaskContext.trainingCtx();
        if (!sharedCtx.trainingParams().delegate().isDefined()) {
            return;
        }
        LightGBMDelegate hook = (LightGBMDelegate) sharedCtx.trainingParams().delegate().get();
        hook.beforeGenerateValidDataset(
                sharedCtx.batchIndex(),
                partitionTaskContext.partitionId(),
                sharedCtx.columnParams(),
                sharedCtx.schema(),
                logger,
                sharedCtx.trainingParams());
    }

    /**
     * Invokes the delegate's {@code afterGenerateValidDataset} hook, if a delegate is configured.
     *
     * @param partitionTaskContext context of the partition task; supplies the partition id and
     *                             the shared training context
     * @param logger logger handed through to the delegate
     */
    public void afterGenerateValidDataset(PartitionTaskContext partitionTaskContext, Logger logger) {
        TrainingContext sharedCtx = partitionTaskContext.trainingCtx();
        if (!sharedCtx.trainingParams().delegate().isDefined()) {
            return;
        }
        LightGBMDelegate hook = (LightGBMDelegate) sharedCtx.trainingParams().delegate().get();
        hook.afterGenerateValidDataset(
                sharedCtx.batchIndex(),
                partitionTaskContext.partitionId(),
                sharedCtx.columnParams(),
                sharedCtx.schema(),
                logger,
                sharedCtx.trainingParams());
    }

    /**
     * Serialization hook: returns the shared {@code MODULE$} instance so that a deserialized
     * copy is replaced by the singleton.
     */
    private Object readResolve() {
        return MODULE$;
    }

    /**
     * Runs training iterations until the state reports it is finished or the state's iteration
     * counter reaches {@code i}. At least one iteration is always executed.
     *
     * <p>Each pass: calls the before-iteration hook, resets the booster's learning rate if the
     * rate to use differs from the one stored on the state, performs one booster update,
     * gathers training/validation metrics (skipped when disabled or when the update already
     * finished training), calls the after-iteration hook and increments the iteration counter.
     *
     * @param i upper bound for the state's iteration counter
     * @param partitionTaskTrainingState training state that is advanced by the loop
     * @param logger logger used for progress messages
     * @return the state's {@code bestIterationResult} after the loop ends
     */
    private final Option iterationLoop$1(int i, PartitionTaskTrainingState partitionTaskTrainingState, Logger logger) {
        do {
            beforeTrainIteration(partitionTaskTrainingState, logger);

            double nextRate = getLearningRate(partitionTaskTrainingState, logger);
            if (nextRate != partitionTaskTrainingState.learningRate()) {
                logger.info("LightGBM task calling booster.resetParameter to reset learningRate" + " (newLearningRate: " + nextRate + ")");
                partitionTaskTrainingState.booster().resetParameter("learning_rate=" + nextRate);
                partitionTaskTrainingState.learningRate_$eq(nextRate);
            }

            logger.info("LightGBM task starting iteration " + partitionTaskTrainingState.iteration());
            updateOneIteration(partitionTaskTrainingState, logger);

            // Metrics are collected only when enabled and the update did not already end training.
            // Training metrics are gathered before validation metrics, which may mark the state finished.
            afterTrainIteration(
                    partitionTaskTrainingState,
                    logger,
                    (partitionTaskTrainingState.ctx().trainingCtx().isProvideTrainingMetric() && !partitionTaskTrainingState.isFinished())
                            ? getTrainEvalResults(partitionTaskTrainingState, logger)
                            : None$.MODULE$,
                    (partitionTaskTrainingState.ctx().trainingCtx().hasValidationData() && !partitionTaskTrainingState.isFinished())
                            ? getValidEvalResults(partitionTaskTrainingState, logger)
                            : None$.MODULE$);

            partitionTaskTrainingState.iteration_$eq(partitionTaskTrainingState.iteration() + 1);
        } while (!partitionTaskTrainingState.isFinished() && partitionTaskTrainingState.iteration() < i);

        return partitionTaskTrainingState.bestIterationResult();
    }

    /**
     * Logs a single training metric as "Train name=value".
     *
     * @param logger logger that receives the info message
     * @param tuple2 pair of metric name and metric value
     * @throws MatchError if {@code tuple2} is {@code null} or its name component is {@code null}
     */
    public static final /* synthetic */ void $anonfun$getTrainEvalResults$1(Logger logger, Tuple2 tuple2) {
        if (tuple2 == null) {
            throw new MatchError(tuple2);
        }
        String metricName = (String) tuple2._1();
        double metricValue = tuple2._2$mcD$sp();
        if (metricName == null) {
            throw new MatchError(tuple2);
        }
        logger.info("Train " + metricName + "=" + metricValue);
    }

    /**
     * Tests whether the difference {@code d - d2} is strictly greater than {@code d3}.
     *
     * @param d first value (minuend)
     * @param d2 second value (subtrahend)
     * @param d3 threshold the difference is compared against
     * @return {@code true} if {@code d - d2 > d3}
     */
    public static final /* synthetic */ boolean $anonfun$getValidEvalResults$2(double d, double d2, double d3) {
        double difference = d - d2;
        return difference > d3;
    }

    /**
     * Tests whether the difference {@code d - d2} is strictly less than {@code d3}.
     *
     * @param d first value (minuend)
     * @param d2 second value (subtrahend)
     * @param d3 threshold the difference is compared against
     * @return {@code true} if {@code d - d2 < d3}
     */
    public static final /* synthetic */ boolean $anonfun$getValidEvalResults$3(double d, double d2, double d3) {
        double difference = d - d2;
        return difference < d3;
    }

    /** Private constructor: publishes the newly created instance as the shared {@code MODULE$}. */
    private TrainUtils$() {
        MODULE$ = this;
    }
}
