/*
 * Decompiled with CFR 0.152.
 */
package hex.segments;

import hex.Model;
import hex.segments.LocalSequentialSegmentModelsBuilder;
import hex.segments.SegmentModels;
import hex.segments.SegmentModelsStats;
import hex.segments.WorkAllocator;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import water.DKV;
import water.H2O;
import water.Iced;
import water.Job;
import water.Key;
import water.MRTask;
import water.fvec.Frame;
import water.rapids.ast.prims.mungers.AstGroup;
import water.util.Log;

public class SegmentModelsBuilder {
    private static final AtomicLong nextSegmentModelsNum = new AtomicLong(0L);
    private final SegmentModelsParameters _parms;
    private final Model.Parameters _blueprint_parms;

    public SegmentModelsBuilder(SegmentModelsParameters parms, Model.Parameters blueprintParms) {
        this._parms = parms;
        this._blueprint_parms = blueprintParms;
    }

    public Job<SegmentModels> buildSegmentModels() {
        if (this._parms._parallelism <= 0) {
            throw new IllegalArgumentException("Parameter `parallelism` has to be a positive number, received=" + this._parms._parallelism);
        }
        Frame segments = this._parms._segments != null ? SegmentModelsBuilder.validateSegmentsFrame(this._parms._segments, this._parms._segment_columns) : this.makeSegmentsFrame(this._blueprint_parms._train, this._parms._segment_columns);
        Job<SegmentModels> job = new Job<SegmentModels>(this.makeDestKey(), SegmentModels.class.getName(), this._blueprint_parms.algoName());
        SegmentModelsBuilderTask segmentBuilder = new SegmentModelsBuilderTask(job, segments, this._blueprint_parms._train, this._blueprint_parms._valid, this._parms._parallelism);
        return job.start(segmentBuilder, segments.numRows());
    }

    private Frame makeSegmentsFrame(Key<Frame> trainKey, String[] segmentColumns) {
        Frame train = SegmentModelsBuilder.validateSegmentsFrame(trainKey, segmentColumns);
        return new AstGroup().performGroupingWithAggregations(train, train.find(segmentColumns), new AstGroup.AGG[0]).getFrame();
    }

    private Key<SegmentModels> makeDestKey() {
        if (this._parms._segment_models_id != null) {
            return this._parms._segment_models_id;
        }
        String id = H2O.calcNextUniqueObjectId("segment_models", nextSegmentModelsNum, this._blueprint_parms.algoName());
        return Key.make(id);
    }

    private static Frame validateSegmentsFrame(Key<Frame> segmentsKey, String[] segmentColumns) {
        Frame segments = segmentsKey.get();
        if (segments == null) {
            throw new IllegalStateException("Frame `" + segmentsKey + "` doesn't exist.");
        }
        List invalidColumns = Stream.of(segmentColumns != null ? segmentColumns : segments.names()).filter(name -> !segments.vec((String)name).isCategorical() && !segments.vec((String)name).isInt()).collect(Collectors.toList());
        if (!invalidColumns.isEmpty()) {
            throw new IllegalStateException("Columns to segment-by can only be categorical and integer of type, invalid columns: " + invalidColumns);
        }
        return segments;
    }

    public static class SegmentModelsParameters
    extends Iced<SegmentModelsParameters> {
        Key<SegmentModels> _segment_models_id;
        Key<Frame> _segments;
        String[] _segment_columns;
        int _parallelism = 1;
    }

    private static class MultiNodeRunner
    extends MRTask<MultiNodeRunner> {
        final LocalSequentialSegmentModelsBuilder _builder;
        final SegmentModels _segment_models;
        final int _parallelism;
        SegmentModelsStats _stats;

        private MultiNodeRunner(LocalSequentialSegmentModelsBuilder builder, SegmentModels segmentModels, int parallelism) {
            this._builder = builder;
            this._segment_models = segmentModels;
            this._parallelism = parallelism;
        }

        @Override
        protected void setupLocal() {
            if (this._parallelism == 1) {
                this._stats = this._builder.buildModels(this._segment_models);
            } else {
                ExecutorService executor = Executors.newFixedThreadPool(this._parallelism);
                this._stats = Stream.generate(() -> () -> ((LocalSequentialSegmentModelsBuilder)this._builder.clone()).buildModels(this._segment_models)).limit(this._parallelism).map(executor::submit).map((? super T future) -> {
                    try {
                        return (SegmentModelsStats)future.get();
                    }
                    catch (InterruptedException | ExecutionException e) {
                        throw new RuntimeException("Failed to build segment-models", e);
                    }
                }).reduce((a, b) -> {
                    a.reduce((SegmentModelsStats)b);
                    return a;
                }).get();
            }
            Log.info("Finished per-segment model building on node ", H2O.SELF, "; summary: ", this._stats);
        }

        @Override
        public void reduce(MultiNodeRunner mrt) {
            this._stats.reduce(mrt._stats);
        }
    }

    private class SegmentModelsBuilderTask
    extends H2O.H2OCountedCompleter<SegmentModelsBuilderTask> {
        private final Job<SegmentModels> _job;
        private final Frame _segments;
        private final Frame _full_train;
        private final Frame _full_valid;
        private final Key _counter_key;
        private final int _parallelism;

        private SegmentModelsBuilderTask(Job<SegmentModels> job, Frame segments, Key<Frame> train, Key<Frame> valid, int parallelism) {
            this._job = job;
            this._segments = segments;
            this._full_train = this.reorderColumns(train);
            this._full_valid = this.reorderColumns(valid);
            this._counter_key = Key.make();
            this._parallelism = parallelism;
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        @Override
        public void compute2() {
            try {
                SegmentModelsBuilder.this._blueprint_parms.read_lock_frames(this._job);
                SegmentModels segmentModels = SegmentModels.make(this._job._result, this._segments);
                WorkAllocator allocator = new WorkAllocator(this._counter_key, this._segments.numRows());
                LocalSequentialSegmentModelsBuilder localBuilder = new LocalSequentialSegmentModelsBuilder(this._job, SegmentModelsBuilder.this._blueprint_parms, this._segments, this._full_train, this._full_valid, allocator);
                SegmentModelsStats stats = ((MultiNodeRunner)new MultiNodeRunner((LocalSequentialSegmentModelsBuilder)localBuilder, (SegmentModels)segmentModels, (int)this._parallelism).doAllNodes())._stats;
                Log.info("Finished per-segment model building; summary: ", stats);
            }
            finally {
                SegmentModelsBuilder.this._blueprint_parms.read_unlock_frames(this._job);
                if (this._segments._key == null) {
                    this._segments.remove();
                }
                DKV.remove(this._counter_key);
            }
            this.tryComplete();
        }

        private Frame reorderColumns(Key<Frame> key) {
            if (key == null) {
                return null;
            }
            Frame f = key.get();
            if (f == null) {
                throw new IllegalStateException("Key " + key + " doesn't point to an existing Frame.");
            }
            Frame mutating = new Frame(f);
            Frame reordered = new Frame(this._segments.names(), mutating.vecs(this._segments.names())).add(mutating.remove(this._segments.names()));
            reordered._key = f._key;
            return reordered;
        }
    }
}

