package hex.grid;

import hex.Model;
import hex.Model.Parameters;
import hex.ModelBuilder;
import hex.ModelParametersBuilderFactory;
import hex.grid.HyperSpaceWalker;
import java.util.Map;
import java.util.Objects;
import water.DKV;
import water.H2O;
import water.Job;
import water.Key;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Frame;
import water.util.Log;
import water.util.PojoUtils;

/**
 * Job that walks a hyper-parameter space and trains one model per visited parameter
 * combination, collecting the resulting models into a {@link Grid}.
 *
 * <p>Parameter combinations come from a {@link HyperSpaceWalker}. Model builders are created by
 * a {@link ModelFactory}. Models are built one after another. A combination that fails to build
 * is recorded in the grid as a failure and does not abort the search.
 *
 * @param <MP> type of the model parameters being searched over
 */
public final class GridSearch<MP extends Model.Parameters> extends Job<Grid> {
    /** Creates a model builder for each visited parameter combination. */
    private final transient ModelFactory<MP> _modelFactory;
    /** Enumerates the hyper-parameter combinations to try. */
    private final transient HyperSpaceWalker<MP> _hyperSpaceWalker;

    /**
     * Parameters-builder factory that writes hyper-parameter values straight into the given
     * parameters object through {@link PojoUtils#setField}, using
     * {@link PojoUtils.FieldNaming#CONSISTENT} field naming.
     */
    public static class SimpleParametersBuilderFactory<MP extends Model.Parameters> implements ModelParametersBuilderFactory<MP> {

        /**
         * Builder that mutates the wrapped parameters instance in place and returns that same
         * instance from {@link #build()}; no copy is made.
         */
        public static class SimpleParamsBuilder<MP extends Model.Parameters> implements ModelParametersBuilderFactory.ModelParametersBuilder<MP> {
            private final MP params;

            public SimpleParamsBuilder(MP mp) {
                this.params = mp;
            }

            @Override
            public ModelParametersBuilderFactory.ModelParametersBuilder<MP> set(String str, Object obj) {
                PojoUtils.setField(this.params, str, obj, PojoUtils.FieldNaming.CONSISTENT);
                return this;
            }

            @Override
            public MP build() {
                return this.params;
            }
        }

        SimpleParametersBuilderFactory() {
        }

        @Override
        public ModelParametersBuilderFactory.ModelParametersBuilder<MP> get(MP mp) {
            return new SimpleParamsBuilder<MP>(mp);
        }

        @Override
        public PojoUtils.FieldNaming getFieldNamingStrategy() {
            return PojoUtils.FieldNaming.CONSISTENT;
        }
    }

    /**
     * @param key destination key of the grid
     * @param modelFactory creates a model builder for each parameter combination; must not be null
     * @param hyperSpaceWalker supplies the parameter combinations; must not be null
     * @throws NullPointerException if {@code modelFactory} or {@code hyperSpaceWalker} is null
     */
    private GridSearch(Key key, ModelFactory<MP> modelFactory, HyperSpaceWalker<MP> hyperSpaceWalker) {
        // The factory is null-checked inside the super(...) call so that a missing factory
        // fails with a descriptive message before getModelName() is invoked on it.
        super(key, Objects.requireNonNull(modelFactory, "Grid search needs to know how to build a new model!").getModelName() + " Grid Search");
        this._modelFactory = modelFactory;
        this._hyperSpaceWalker = Objects.requireNonNull(hyperSpaceWalker, "Grid search needs to know to how walk around hyper space!");
    }

    /**
     * Creates the destination grid (or re-opens an existing one), locks it for this job, and
     * hands the search to {@code Job.start(...)} as an {@link H2O.H2OCountedCompleter} task.
     *
     * @return this job
     * @throws H2OIllegalArgumentException if the destination key already holds a non-grid
     *     object, or holds a grid built on a different training frame
     */
    GridSearch start() {
        int hyperSpaceSize = this._hyperSpaceWalker.getHyperSpaceSize();
        Log.info("Starting gridsearch: estimated size of search space = " + hyperSpaceSize);
        // Fetched as Object so that a foreign object stored under the grid's key is reported as
        // a name conflict rather than surfacing as a ClassCastException.
        Object existing = DKV.getGet(dest());
        final Grid grid;
        if (existing == null) {
            grid = new Grid(dest(), this._hyperSpaceWalker.getParams(), this._hyperSpaceWalker.getHyperParamNames(),
                    this._modelFactory.getModelName(), this._hyperSpaceWalker.getParametersBuilderFactory().getFieldNamingStrategy());
            grid.delete_and_lock(jobKey());
        } else if (existing instanceof Grid) {
            grid = (Grid) existing;
            // Appending to an existing grid: its models must share our training input.
            checkSameTrainingFrame(grid);
            grid.write_lock(jobKey());
        } else {
            throw new H2OIllegalArgumentException("grid_id", "grid",
                    "Name conflict: " + dest() + " already holds a non-grid object of type " + existing.getClass().getName());
        }
        start(new H2O.H2OCountedCompleter() {
            @Override
            public void compute2() {
                GridSearch.this.gridSearch(grid);
                tryComplete();
            }
        }, hyperSpaceSize, true);
        return this;
    }

    /**
     * Rejects appending to {@code grid} when the new parameters use a training frame with a
     * different key or checksum than the frame the grid reports.
     *
     * <p>The check is skipped when the grid reports no training frame. NOTE(review): presumably
     * this happens when the original frame is no longer available; confirm against
     * {@code Grid.getTrainingFrame()}.
     */
    private void checkSameTrainingFrame(Grid grid) {
        Frame oldTrain = grid.getTrainingFrame();
        if (oldTrain == null) {
            return;
        }
        Frame newTrain = this._hyperSpaceWalker.getParams().train();
        if (newTrain == null
                || !Objects.equals(newTrain._key, oldTrain._key)
                || newTrain.checksum() != oldTrain.checksum()) {
            throw new H2OIllegalArgumentException("training_frame", "grid", "Cannot append new models to a grid with different training input");
        }
    }

    /** Returns the hyper-space size reported by the walker. */
    public int getModelCount() {
        return this._hyperSpaceWalker.getHyperSpaceSize();
    }

    /**
     * Runs the search: walks every parameter combination, builds a model for each one, and
     * records successes and failures in {@code grid}.
     *
     * <p>Progress is advanced by one and the grid is updated after every attempt, whether it
     * succeeded or failed. The grid is always unlocked on exit. If the job has been marked
     * cancelled in the DKV, an escaping exception is logged and swallowed. Otherwise the job is
     * marked failed and the exception is rethrown.
     *
     * @param grid grid to populate; expected to be locked by this job
     */
    public void gridSearch(Grid<MP> grid) {
        MP baseParams = this._hyperSpaceWalker.getParams();
        // Prefix for generated model keys; each model appends its sequence number.
        String protoModelKey = baseParams._model_id == null
                ? grid._key + "_model_"
                : baseParams._model_id.toString() + H2O.calcNextUniqueModelId("") + "_";
        Model model = null;
        try {
            HyperSpaceWalker.HyperSpaceIterator<MP> it = this._hyperSpaceWalker.iterator();
            int modelIndex = 0;
            while (it.hasNext(model)) {
                if (!isRunning()) {
                    // Stop request from the user: cancel and leave (the grid is unlocked below).
                    cancel();
                    return;
                }
                try {
                    MP params = it.nextModelParameters(model);
                    try {
                        model = buildModel(params, grid, modelIndex++, protoModelKey);
                    } catch (RuntimeException e) {
                        // A single failed model must not abort the whole search.
                        Log.warn("Grid search: model builder for parameters " + params + " failed! Exception: ", e);
                        grid.appendFailedModelParameters(params, e);
                    }
                } catch (IllegalArgumentException e) {
                    // The walker could not turn the raw hyper-parameter values into parameters.
                    Log.warn("Grid search: construction of model parameters failed! Exception: ", e);
                    grid.appendFailedModelParameters(it.getCurrentRawParameters(), e);
                } finally {
                    update(1L);
                    grid.update(jobKey());
                }
            }
            done();
        } catch (Throwable t) {
            Job current = (Job) DKV.getGet(jobKey());
            // A missing job entry is treated as "not cancelled" so the real failure is reported
            // instead of being masked by a NullPointerException.
            boolean cancelled = current != null && current._state == Job.JobState.CANCELLED;
            if (!cancelled) {
                failed(t);
                throw t;
            }
            Log.info("Job " + jobKey() + " cancelled by user.");
        } finally {
            grid.unlock(jobKey());
        }
    }

    /**
     * Returns the model for {@code mp}, reusing one already registered in {@code grid} when
     * possible, otherwise building a new one under key {@code str + i} and registering it.
     */
    private Model buildModel(MP mp, Grid<MP> grid, int i, String str) {
        // Checksum is taken before _model_id is assigned below; the same value is used for the
        // lookup here and for registration via putModel.
        long checksum = mp.checksum();
        Key<Model> modelKey = grid.getModelKey(checksum);
        if (modelKey != null) {
            return modelKey.get();
        }
        mp._model_id = Key.make(str + i);
        ModelBuilder builder = startBuildModel(mp, grid);
        if (builder == null) {
            // startBuildModel found an equivalent model already in the grid; reuse it rather
            // than dereferencing a null builder.
            return grid.getModel(mp);
        }
        Model model = (Model) builder.get();
        grid.putModel(checksum, model._key);
        return model;
    }

    /**
     * Creates a model builder for {@code mp} and calls {@code trainModel()} on it.
     *
     * @return the builder, or {@code null} if {@code grid} already holds a model for {@code mp}
     */
    private ModelBuilder startBuildModel(MP mp, Grid<MP> grid) {
        if (grid.getModel(mp) != null) {
            return null;
        }
        ModelBuilder builder = this._modelFactory.buildModel(mp);
        builder.trainModel();
        return builder;
    }

    /**
     * Generates a grid key of the form {@code Grid_<str>_<frameKey><uniqueId>}.
     *
     * @throws IllegalArgumentException if {@code frame} or its key is null
     */
    protected static Key<Grid> gridKeyName(String str, Frame frame) {
        if (frame == null || frame._key == null) {
            throw new IllegalArgumentException("The frame being grid-searched over must have a Key");
        }
        return Key.make("Grid_" + str + "_" + frame._key.toString() + H2O.calcNextUniqueModelId(""));
    }

    /**
     * Starts a Cartesian grid search over {@code map}, applying hyper-parameter values through
     * {@code modelParametersBuilderFactory}.
     *
     * @param key destination grid key, or null to generate one from the algorithm name and the
     *     training frame
     */
    public static <MP extends Model.Parameters> GridSearch startGridSearch(Key<Grid> key, MP mp, Map<String, Object[]> map, ModelFactory<MP> modelFactory, ModelParametersBuilderFactory<MP> modelParametersBuilderFactory) {
        return startGridSearch(key, modelFactory, new HyperSpaceWalker.CartesianWalker(mp, map, modelParametersBuilderFactory));
    }

    /** Like the five-argument overload, using a {@link SimpleParametersBuilderFactory}. */
    public static <MP extends Model.Parameters> GridSearch startGridSearch(Key<Grid> key, MP mp, Map<String, Object[]> map, ModelFactory<MP> modelFactory) {
        return startGridSearch(key, mp, map, modelFactory, new SimpleParametersBuilderFactory<MP>());
    }

    /** Like the four-argument overload, with a generated grid key. */
    public static <MP extends Model.Parameters> GridSearch startGridSearch(MP mp, Map<String, Object[]> map, ModelFactory<MP> modelFactory) {
        return startGridSearch(null, mp, map, modelFactory);
    }

    /**
     * Creates and starts a grid search driven by {@code hyperSpaceWalker}.
     *
     * @param key destination grid key, or null to generate one via {@link #gridKeyName}
     */
    public static <MP extends Model.Parameters> GridSearch startGridSearch(Key<Grid> key, ModelFactory<MP> modelFactory, HyperSpaceWalker<MP> hyperSpaceWalker) {
        Key<Grid> gridKey = key != null ? key : gridKeyName(modelFactory.getModelName(), hyperSpaceWalker.getParams().train());
        return new GridSearch<MP>(gridKey, modelFactory, hyperSpaceWalker).start();
    }
}
