package hex.ensemble;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ensemble.Metalearner;
import hex.ensemble.StackedEnsembleModel;
import hex.genmodel.utils.DistributionFamily;
import hex.grid.Grid;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Stream;
import water.DKV;
import water.Job;
import water.Key;
import water.Scope;
import water.exceptions.H2OIllegalArgumentException;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;

/* loaded from: input_file:hex/ensemble/StackedEnsemble.class */
/**
 * Model builder for {@link StackedEnsembleModel}.
 *
 * <p>Training assembles a "level one" frame from the predictions of already-trained base models
 * (plus the response and any configured weights/offset/metalearner-fold columns) and trains a
 * metalearner on it. The stacking strategy is chosen in {@link #trainModelImpl()}:
 * <ul>
 *   <li>no blending frame: {@link StackedEnsembleCVStackingDriver}, which uses each base model's
 *       stored cross-validation holdout predictions for the level one training frame;</li>
 *   <li>blending frame given: {@link StackedEnsembleBlendingDriver}, which scores each base model
 *       on the blending frame.</li>
 * </ul>
 */
public class StackedEnsemble extends ModelBuilder<StackedEnsembleModel, StackedEnsembleModel.StackedEnsembleParameters, StackedEnsembleModel.StackedEnsembleOutput> {
    /** Driver created by the most recent {@link #trainModelImpl()} call. */
    StackedEnsembleDriver _driver;
    /** The ensemble being built; created at the start of {@link StackedEnsembleDriver#computeImpl()}. */
    protected StackedEnsembleModel _model;

    /**
     * Blending driver: the level one training frame is derived from the blending frame, and base
     * models are always scored on the frame they are asked about.
     */
    public class StackedEnsembleBlendingDriver extends StackedEnsembleDriver {
        private StackedEnsembleBlendingDriver() {
            super();
        }

        @Override
        protected StackedEnsembleModel.StackingStrategy strategy() {
            return StackedEnsembleModel.StackingStrategy.blending;
        }

        /** Returns the blending frame from the ensemble's parameters. */
        @Override
        protected Frame getActualTrainingFrame() {
            return StackedEnsemble.this._model._parms.blending();
        }

        /** Scores {@code model} on {@code frame} (reusing cached predictions), whatever {@code isTraining} is. */
        @Override
        protected Frame getPredictionsForBaseModel(Model model, Frame frame, boolean isTraining) {
            return buildPredictionsForBaseModel(model, frame);
        }
    }

    /**
     * Cross-validation stacking driver: the level one training frame is built from each base model's
     * cross-validation holdout predictions, while non-training frames are scored normally.
     */
    public class StackedEnsembleCVStackingDriver extends StackedEnsembleDriver {
        private StackedEnsembleCVStackingDriver() {
            super();
        }

        @Override
        protected StackedEnsembleModel.StackingStrategy strategy() {
            return StackedEnsembleModel.StackingStrategy.cross_validation;
        }

        /** Returns the ensemble's training frame. */
        @Override
        protected Frame getActualTrainingFrame() {
            return StackedEnsemble.this._model._parms.train();
        }

        /**
         * Returns the holdout predictions stored on {@code model} when {@code isTraining}, otherwise
         * scores {@code model} on {@code frame}.
         *
         * @throws H2OIllegalArgumentException if the holdout predictions key or frame is missing
         */
        @Override
        protected Frame getPredictionsForBaseModel(Model model, Frame frame, boolean isTraining) {
            if (!isTraining) {
                return buildPredictionsForBaseModel(model, frame);
            }
            if (null == model._output._cross_validation_holdout_predictions_frame_id) {
                throw new H2OIllegalArgumentException("Failed to find the xval predictions frame id. . .  Looks like keep_cross_validation_predictions wasn't set when building the models.");
            }
            final Frame holdoutPreds = DKV.getGet(model._output._cross_validation_holdout_predictions_frame_id);
            if (null == holdoutPreds) {
                throw new H2OIllegalArgumentException("Failed to find the xval predictions frame. . .  Looks like keep_cross_validation_predictions wasn't set when building the models, or the frame was deleted.");
            }
            return holdoutPreds;
        }
    }

    /**
     * Common training driver: builds the level one training (and optional validation) frames and
     * trains the metalearner on them. Subclasses decide which frame is the level one source and
     * how base-model predictions are obtained.
     */
    public abstract class StackedEnsembleDriver extends Driver {
        private StackedEnsembleDriver() {
            super();
        }

        /**
         * Builds and publishes a level one frame from already-obtained base-model predictions.
         *
         * @param levelOneKey          key for the resulting frame; when null a name is derived from the
         *                             model key and the metalearner transform
         * @param baseModels           base models, parallel to {@code baseModelPredictions}
         * @param baseModelPredictions predictions frame for each base model
         * @param actualTrain          frame supplying the response and other non-predictor columns
         * @return the level one frame, also put into the DKV under {@code levelOneKey}
         * @throws H2OIllegalArgumentException on null/empty/mismatched inputs, or when a metalearner
         *                                     transform is requested for a non-classification ensemble
         */
        private Frame prepareLevelOneFrame(String levelOneKey, Model[] baseModels, Frame[] baseModelPredictions, Frame actualTrain) {
            if (null == baseModels) {
                throw new H2OIllegalArgumentException("Base models array is null.");
            }
            if (null == baseModelPredictions) {
                throw new H2OIllegalArgumentException("Base model predictions array is null.");
            }
            if (baseModels.length == 0) {
                throw new H2OIllegalArgumentException("Base models array is empty.");
            }
            if (baseModelPredictions.length == 0) {
                throw new H2OIllegalArgumentException("Base model predictions array is empty.");
            }
            if (baseModels.length != baseModelPredictions.length) {
                throw new H2OIllegalArgumentException("Base models and prediction arrays are different lengths.");
            }

            // NONE is treated the same as "no transform".
            StackedEnsembleModel.StackedEnsembleParameters.MetalearnerTransform transform = StackedEnsemble.this._parms._metalearner_transform;
            if (transform == StackedEnsembleModel.StackedEnsembleParameters.MetalearnerTransform.NONE) {
                transform = null;
            }
            if (transform != null) {
                final StackedEnsembleModel.StackedEnsembleOutput output = StackedEnsemble.this._model._output;
                if (!output.isBinomialClassifier() && !output.isMultinomialClassifier()) {
                    throw new H2OIllegalArgumentException("Metalearner transform is supported only for classification!");
                }
            }

            // String.valueOf: the transform parameter may be null here.
            final String frameKey = levelOneKey != null
                    ? levelOneKey
                    : "levelone_" + StackedEnsemble.this._model._key.toString() + "_" + String.valueOf(StackedEnsemble.this._parms._metalearner_transform);

            // Clear out a stale frame left under the same key. Fetch untyped so a non-Frame value under
            // this key does not throw ClassCastException.
            final Object existing = DKV.getGet(frameKey);
            if (existing instanceof Frame) {
                final Frame stale = (Frame) existing;
                stale.write_lock(StackedEnsemble.this._job);
                stale.removeAll();
                stale.update(StackedEnsemble.this._job);
                stale.unlock(StackedEnsemble.this._job);
            }

            // Without a transform the frame is built directly under its final key; with a transform it
            // starts key-less and the transform produces the keyed result below.
            Frame levelOneFrame = transform == null ? new Frame(Key.make(frameKey)) : new Frame(new Vec[0]);
            for (int i = 0; i < baseModels.length; i++) {
                final Model baseModel = baseModels[i];
                final Frame baseModelPreds = baseModelPredictions[i];
                if (null == baseModel) {
                    Log.warn("Failed to find base model; skipping: " + baseModels[i]);
                    continue;
                }
                if (null == baseModelPreds) {
                    // The predictions frame is null here, so report the model key (the old code
                    // dereferenced the null frame's key and threw NullPointerException).
                    Log.warn("Failed to find base model " + baseModel + " predictions; skipping: " + baseModel._key);
                    continue;
                }
                StackedEnsemble.addModelPredictionsToLevelOneFrame(baseModel, baseModelPreds, levelOneFrame);
                Scope.untrack(baseModelPredictions);
            }
            if (transform != null) {
                levelOneFrame = transform.transform(StackedEnsemble.this._model, levelOneFrame, Key.make(frameKey));
            }
            StackedEnsemble.addNonPredictorsToLevelOneFrame(StackedEnsemble.this._model._parms, actualTrain, levelOneFrame, true);
            Log.info("Finished creating \"level one\" frame for stacking: " + levelOneFrame.toString());
            DKV.put(levelOneFrame);
            return levelOneFrame;
        }

        /**
         * Collects predictions for each base model key and builds the level one frame from them.
         *
         * @param levelOneKey   key for the resulting frame
         * @param baseModelKeys keys of the base models
         * @param actualTrain   frame the predictions refer to
         * @param isTraining    true when building the level one training frame; enables cancellation
         *                      checks and, with {@code keep_levelone_frame}, keeping a deep copy
         * @throws Job.JobCancelledException   if a stop is requested while building the training frame
         * @throws H2OIllegalArgumentException if a base model is missing from the DKV
         */
        private Frame prepareLevelOneFrame(String levelOneKey, Key<Model>[] baseModelKeys, Frame actualTrain, boolean isTraining) {
            final StackedEnsembleModel ensemble = StackedEnsemble.this._model;
            final List<Model> baseModels = new ArrayList<>();
            final List<Frame> baseModelPredictions = new ArrayList<>();
            for (Key<Model> key : baseModelKeys) {
                if (StackedEnsemble.this.stop_requested() && isTraining) {
                    throw new Job.JobCancelledException();
                }
                // Once a metalearner exists, only base models the ensemble reports as useful are used.
                if (ensemble._output._metalearner == null || ensemble.isUsefulBaseModel(key)) {
                    final Model baseModel = DKV.getGet(key);
                    if (null == baseModel) {
                        throw new H2OIllegalArgumentException("Failed to find base model: " + key);
                    }
                    baseModelPredictions.add(getPredictionsForBaseModel(baseModel, actualTrain, isTraining));
                    baseModels.add(baseModel);
                }
            }
            final boolean keepLevelOneFrame = isTraining && StackedEnsemble.this._parms._keep_levelone_frame;
            Frame levelOneFrame = prepareLevelOneFrame(levelOneKey, baseModels.toArray(new Model[0]), baseModelPredictions.toArray(new Frame[0]), actualTrain);
            if (keepLevelOneFrame) {
                // Keep an independent copy and stop Scope from cleaning up its vecs.
                levelOneFrame = levelOneFrame.deepCopy(levelOneFrame._key.toString());
                levelOneFrame.write_lock(StackedEnsemble.this._job);
                levelOneFrame.update(StackedEnsemble.this._job);
                levelOneFrame.unlock(StackedEnsemble.this._job);
                Scope.untrack(levelOneFrame.keysList());
            }
            return levelOneFrame;
        }

        /**
         * Returns predictions of {@code model} on {@code frame}, scoring only if no frame is cached
         * under the checksum-derived key, and records that key on the ensemble output.
         */
        protected Frame buildPredictionsForBaseModel(Model model, Frame frame) {
            final Key<Frame> predsKey = buildPredsKey(model, frame);
            Frame preds = DKV.getGet(predsKey);
            if (preds == null) {
                preds = model.score(frame, predsKey.toString(), (Job) null, false);
                Scope.untrack(preds.keysList());
            }
            final StackedEnsembleModel.StackedEnsembleOutput output = StackedEnsemble.this._model._output;
            if (output._base_model_predictions_keys == null) {
                output._base_model_predictions_keys = new Key[0];
            }
            if (!ArrayUtils.contains(output._base_model_predictions_keys, predsKey)) {
                output._base_model_predictions_keys = (Key[]) ArrayUtils.append(output._base_model_predictions_keys, new Key[]{predsKey});
            }
            return preds;
        }

        /** Stacking strategy recorded on the ensemble output. */
        protected abstract StackedEnsembleModel.StackingStrategy strategy();

        /** Frame the level one training frame is derived from. */
        protected abstract Frame getActualTrainingFrame();

        /**
         * Returns predictions of {@code model} for {@code frame}.
         *
         * @param isTraining true when building the level one training frame
         */
        protected abstract Frame getPredictionsForBaseModel(Model model, Frame frame, boolean isTraining);

        /** Builds a predictions key from the checksums only; the key arguments are not used. */
        private Key<Frame> buildPredsKey(Key modelKey, long modelChecksum, Key frameKey, long frameChecksum) {
            return Key.make("preds_" + modelChecksum + "_on_" + frameChecksum);
        }

        /** Returns the cached-predictions key for {@code model} on {@code frame}, or null if either is null. */
        protected Key<Frame> buildPredsKey(Model model, Frame frame) {
            if (frame == null || model == null) {
                return null;
            }
            return buildPredsKey(model._key, model.checksum(), frame._key, frame.checksum());
        }

        /**
         * Validates the parameters, creates and locks the ensemble model, builds the level one frames,
         * and trains the metalearner.
         */
        @Override
        public void computeImpl() {
            StackedEnsemble.this.init(true);
            if (StackedEnsemble.this.error_count() > 0) {
                throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(StackedEnsemble.this);
            }
            StackedEnsemble.this._model = new StackedEnsembleModel(StackedEnsemble.this.dest(), StackedEnsemble.this._parms, new StackedEnsembleModel.StackedEnsembleOutput(StackedEnsemble.this));
            final StackedEnsembleModel ensemble = StackedEnsemble.this._model;
            ensemble._output._stacking_strategy = strategy();
            try {
                ensemble.delete_and_lock(StackedEnsemble.this._job);
                ensemble.checkAndInheritModelProperties();
                ensemble.update(StackedEnsemble.this._job);

                final StackedEnsembleModel.StackedEnsembleParameters modelParms = ensemble._parms;
                final StackedEnsembleModel.StackedEnsembleParameters builderParms = StackedEnsemble.this._parms;

                final Frame levelOneTrain = prepareLevelOneFrame("levelone_training_" + ensemble._key.toString(), modelParms._base_models, getActualTrainingFrame(), true);
                Frame levelOneValid = null;
                if (modelParms.valid() != null) {
                    levelOneValid = prepareLevelOneFrame("levelone_validation_" + ensemble._key.toString(), modelParms._base_models, modelParms.valid(), false);
                }

                final Metalearner.Algorithm requestedAlgo = modelParms._metalearner_algorithm;
                final Metalearner.Algorithm actualAlgo = Metalearners.getActualMetalearnerAlgo(requestedAlgo);
                if (actualAlgo == null) {
                    throw new H2OIllegalArgumentException("Invalid `metalearner_algorithm`. Passed in " + requestedAlgo + " but must be one of " + Arrays.toString(Metalearner.Algorithm.values()));
                }
                final Key<Model> metalearnerKey = Key.make("metalearner_" + requestedAlgo + "_" + ensemble._key);
                final Job metalearnerJob = new Job(metalearnerKey, ModelBuilder.javaName(actualAlgo.toString()), "StackingEnsemble metalearner (" + requestedAlgo + ")");
                final boolean hasMetalearnerParams = modelParms._metalearner_parameters != null;
                final long seed = modelParms._seed;
                // max_runtime_secs == 0 is passed through as 0 (presumably "no limit" — confirm);
                // otherwise the metalearner gets the builder's remaining time, at least 1 second.
                final long metalearnerMaxRuntime = builderParms._max_runtime_secs == 0.0d ? 0L : Math.max(StackedEnsemble.this.remainingTimeSecs(), 1L);

                final Metalearner metalearner = Metalearners.createInstance(requestedAlgo.name());
                metalearner.init(levelOneTrain, levelOneValid, modelParms._metalearner_parameters, ensemble, StackedEnsemble.this._job, metalearnerKey, metalearnerJob, builderParms, hasMetalearnerParams, seed, metalearnerMaxRuntime);
                metalearner.compute();

                // Record the resolved algorithm in place of AUTO when auto-parameter evaluation is enabled.
                if (ensemble.evalAutoParamsEnabled && modelParms._metalearner_algorithm == Metalearner.Algorithm.AUTO) {
                    modelParms._metalearner_algorithm = actualAlgo;
                }
            } finally {
                ensemble.unlock(StackedEnsemble.this._job);
            }
        }
    }

    /** Creates a builder for the given parameters and runs the cheap (non-expensive) validation. */
    public StackedEnsemble(StackedEnsembleModel.StackedEnsembleParameters stackedEnsembleParameters) {
        super(stackedEnsembleParameters);
        init(false);
    }

    /** Creates a builder with default parameters; {@code z} is forwarded to the superclass constructor. */
    public StackedEnsemble(boolean z) {
        super(new StackedEnsembleModel.StackedEnsembleParameters(), z);
    }

    @Override
    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.Regression, ModelCategory.Binomial, ModelCategory.Multinomial};
    }

    @Override
    public ModelBuilder.BuilderVisibility builderVisibility() {
        return ModelBuilder.BuilderVisibility.Stable;
    }

    @Override
    public boolean isSupervised() {
        return true;
    }

    /**
     * Drops training columns that no base model uses.
     *
     * <p>A column is kept when it is a response, non-predictor, or input column of any base model,
     * or a non-predictor of the ensemble itself. {@code npredictors} is not used: {@code FilterCols(0)}
     * checks every column.
     */
    @Override
    protected void ignoreBadColumns(int npredictors, boolean expensive) {
        final Set<String> usedColumns = new HashSet<>();
        if (_parms._base_models != null) {
            for (Key<Model> key : _parms._base_models) {
                final Model baseModel = requireBaseModel(key);
                usedColumns.add(baseModel._parms._response_column);
                usedColumns.addAll(Arrays.asList(baseModel._parms.getNonPredictors()));
                final String[] inputNames = baseModel._output._origNames != null ? baseModel._output._origNames : baseModel._output._names;
                usedColumns.addAll(Arrays.asList(inputNames));
            }
        }
        usedColumns.addAll(Arrays.asList(_parms.getNonPredictors()));
        new FilterCols(0) {
            @Override
            protected boolean filter(Vec vec, String name) {
                return !usedColumns.contains(name);
            }
        }.doIt(this._train, "Dropping unused columns: ", expensive);
    }

    /** Selects the blending driver when a blending frame is set, otherwise the CV stacking driver. */
    @Override
    protected StackedEnsembleDriver trainModelImpl() {
        this._driver = _parms._blending == null ? new StackedEnsembleCVStackingDriver() : new StackedEnsembleBlendingDriver();
        return this._driver;
    }

    /** Alias of {@link #trainModelImpl()} (decompiler-generated name), kept for backward compatibility. */
    public StackedEnsembleDriver m56trainModelImpl() {
        return trainModelImpl();
    }

    @Override
    public boolean haveMojo() {
        return true;
    }

    /**
     * Expands grids into base models, runs the superclass validation, and checks ensemble-specific
     * parameters.
     *
     * @throws H2OIllegalArgumentException if {@code distribution} is not AUTO
     * @throws IllegalArgumentException    if a named fold/weights/offset column is missing from a
     *                                     supplied frame, or base models disagree on offset_column
     */
    @Override
    public void init(boolean expensive) {
        expandBaseModels();
        super.init(expensive);
        if (_parms._distribution != DistributionFamily.AUTO) {
            throw new H2OIllegalArgumentException("Setting \"distribution\" to StackedEnsemble is unsupported. Please set it in \"metalearner_parameters\".");
        }
        checkColumnPresent("fold", _parms._metalearner_fold_column, train(), valid(), _parms.blending());
        checkColumnPresent("weights", _parms._weights_column, train(), valid(), _parms.blending());
        checkColumnPresent("offset", _parms._offset_column, train(), valid(), _parms.blending());
        validateBaseModels();
    }

    /**
     * Replaces every grid key in {@code _base_models} with the keys of the grid's models; model keys
     * are kept as-is.
     *
     * @throws IllegalArgumentException if a key is missing from the DKV or names an unsupported type
     */
    private void expandBaseModels() {
        if (_parms._base_models == null) {
            return;
        }
        final List<Key<Model>> expanded = new ArrayList<>();
        for (Key<Model> key : _parms._base_models) {
            // Fetch untyped: the value may be a Model or a Grid, and a typed assignment would throw
            // ClassCastException for the other kind.
            final Object value = DKV.getGet(key);
            if (value instanceof Model) {
                expanded.add(key);
            } else if (value instanceof Grid) {
                for (Key<Model> modelKey : ((Grid) value).getModelKeys()) {
                    expanded.add(modelKey);
                }
            } else if (value == null) {
                throw new IllegalArgumentException(String.format("Specified id \"%s\" does not exist.", key));
            } else {
                throw new IllegalArgumentException(String.format("Unsupported type \"%s\" as a base model.", value.getClass().toString()));
            }
        }
        _parms._base_models = expanded.toArray(new Key[0]);
    }

    /**
     * Checks base-model consistency.
     *
     * <p>If the ensemble has no offset column, it takes the first base model's offset column. All base
     * models must have that same offset column. Warns when every base model uses the same non-null
     * weights column but the ensemble does not.
     *
     * @throws IllegalArgumentException    if base models have different offset columns
     * @throws H2OIllegalArgumentException if a base model is missing from the DKV
     */
    private void validateBaseModels() {
        final Key<Model>[] baseModelKeys = _parms._base_models;
        if (baseModelKeys == null) {
            return;
        }
        boolean allShareWeights = true;
        String firstWeightsColumn = null;
        for (int i = 0; i < baseModelKeys.length; i++) {
            final Model baseModel = requireBaseModel(baseModelKeys[i]);
            if (i == 0) {
                if (_parms._offset_column == null) {
                    _parms._offset_column = baseModel._parms._offset_column;
                }
                firstWeightsColumn = baseModel._parms._weights_column;
                allShareWeights = firstWeightsColumn != null;
            }
            if (!Objects.equals(firstWeightsColumn, baseModel._parms._weights_column)) {
                allShareWeights = false;
            }
            if (!Objects.equals(_parms._offset_column, baseModel._parms._offset_column)) {
                throw new IllegalArgumentException("All base models must have the same offset_column!");
            }
        }
        if (_parms._weights_column == null && allShareWeights && baseModelKeys.length > 0) {
            warn("_weights_column", "All base models use weights_column=\"" + firstWeightsColumn + "\" but Stacked Ensemble does not. If you want to use the same weights_column for the meta learner, please specify it as an argument in the h2o.stackedEnsemble call.");
        }
    }

    /** Fetches a base model from the DKV, failing with a descriptive error instead of an NPE. */
    private static Model requireBaseModel(Key<Model> key) {
        final Model model = DKV.getGet(key);
        if (model == null) {
            throw new H2OIllegalArgumentException("Failed to find base model: " + key);
        }
        return model;
    }

    /**
     * Verifies that {@code columnName} (if non-null) exists in every non-null frame.
     *
     * @param kind       label used in the error message, e.g. "fold"
     * @param columnName column to look for; null means nothing to check
     * @throws IllegalArgumentException if a non-null frame lacks the column
     */
    private static void checkColumnPresent(String kind, String columnName, Frame... frames) {
        if (columnName == null) {
            return;
        }
        for (Frame frame : frames) {
            if (frame != null && frame.vec(columnName) == null) {
                throw new IllegalArgumentException(String.format("Specified %s column '%s' not found in one of the supplied data frames. Available column names are: %s", kind, columnName, Arrays.toString(frame.names())));
            }
        }
    }

    /**
     * Appends the columns of {@code preds} that represent {@code model} to {@code levelOneFrame}.
     *
     * <ul>
     *   <li>Binomial: the column at index 2, named after the model key. Presumably the positive-class
     *       probability in a predict/p0/p1 layout; confirm.</li>
     *   <li>Multinomial: every column except "predict", renamed to {@code <modelKey>/<column>}.</li>
     *   <li>Other supervised models: the "predict" column, named after the model key.</li>
     * </ul>
     *
     * @throws H2OIllegalArgumentException for autoencoders and unsupervised models
     */
    public static void addModelPredictionsToLevelOneFrame(Model model, Frame preds, Frame levelOneFrame) {
        if (model._output.isBinomialClassifier()) {
            levelOneFrame.add(model._key.toString(), preds.vec(2));
            return;
        }
        if (model._output.isMultinomialClassifier()) {
            final Frame classColumns = preds.subframe(ArrayUtils.remove(preds.names(), "predict"));
            final String prefix = model._key.toString().concat("/");
            final String[] prefixedNames = Stream.of(classColumns.names())
                    .map(prefix::concat)
                    .toArray(String[]::new);
            classColumns.setNames(prefixedNames);
            levelOneFrame.add(classColumns);
            return;
        }
        if (model._output.isAutoencoder()) {
            throw new H2OIllegalArgumentException("Don't yet know how to stack autoencoders: " + model._key);
        }
        if (!model._output.isSupervised()) {
            throw new H2OIllegalArgumentException("Don't yet know how to stack unsupervised models: " + model._key);
        }
        levelOneFrame.add(model._key.toString(), preds.vec("predict"));
    }

    /**
     * Copies non-predictor columns from {@code source} to {@code levelOneFrame}, in this order:
     * the metalearner fold column (only when {@code isTraining} and it is set), then the weights
     * column and the offset column (when set), and always the response column.
     */
    public static void addNonPredictorsToLevelOneFrame(StackedEnsembleModel.StackedEnsembleParameters parms, Frame source, Frame levelOneFrame, boolean isTraining) {
        if (isTraining && parms._metalearner_fold_column != null) {
            levelOneFrame.add(parms._metalearner_fold_column, source.vec(parms._metalearner_fold_column));
        }
        if (parms._weights_column != null) {
            levelOneFrame.add(parms._weights_column, source.vec(parms._weights_column));
        }
        if (parms._offset_column != null) {
            levelOneFrame.add(parms._offset_column, source.vec(parms._offset_column));
        }
        levelOneFrame.add(parms._response_column, source.vec(parms._response_column));
    }
}
