/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.deepwater.caffe;

import deepwater.backends.BackendModel;
import deepwater.backends.BackendParams;
import deepwater.backends.BackendTrain;
import deepwater.backends.RuntimeOptions;
import deepwater.datasets.ImageDataSet;
import hex.genmodel.algos.deepwater.caffe.DeepwaterCaffeModel;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;

/**
 * {@link BackendTrain} implementation that delegates model operations to
 * {@link DeepwaterCaffeModel}.
 *
 * <p>Several operations are not implemented by this backend: some throw
 * {@link UnsupportedOperationException}, others are no-ops or return placeholder values
 * (see the individual methods).
 */
public class DeepwaterCaffeBackend
implements BackendTrain {
    public static final String CAFFE_DIR = "/opt/caffe/";
    public static final String CAFFE_H2O_DIR = "/opt/caffe-h2o/";

    /**
     * Releases the given model by calling {@link DeepwaterCaffeModel#close()}.
     *
     * @param m the model to release; must be a {@link DeepwaterCaffeModel}
     */
    public void delete(BackendModel m) {
        ((DeepwaterCaffeModel)m).close();
    }

    /**
     * Builds a new Caffe-backed model.
     *
     * <p>When {@code name} is {@code "MLP"}, a multi-layer perceptron is assembled from the
     * {@code hidden}, {@code activations}, {@code input_dropout_ratio},
     * {@code hidden_dropout_ratios} and {@code mini_batch_size} entries of {@code bparms}.
     * Otherwise the network is identified by {@code name} and given the input shape
     * {@code [mini_batch_size, channels, width, height]} taken from {@code bparms} and
     * {@code dataset}.
     *
     * @param dataset     source of the input dimensions
     * @param opts        runtime options; the seed and GPU flag are forwarded to the model
     * @param bparms      backend parameters
     * @param num_classes size of the MLP output layer
     * @param name        network name; {@code "MLP"} selects the fully-connected builder
     * @return the newly constructed model
     */
    public BackendModel buildNet(ImageDataSet dataset, RuntimeOptions opts, BackendParams bparms, int num_classes, String name) {
        if (name.equals("MLP")) {
            return buildMlp(dataset, opts, bparms, num_classes);
        }
        int[] inputShape = new int[]{(Integer)bparms.get("mini_batch_size"), dataset.getChannels(), dataset.getWidth(), dataset.getHeight()};
        return new DeepwaterCaffeModel(name, inputShape, opts.getSeed(), opts.useGPU());
    }

    /** Assembles the layer sizes, layer names and dropout ratios for an MLP network. */
    private BackendModel buildMlp(ImageDataSet dataset, RuntimeOptions opts, BackendParams bparms, int num_classes) {
        int[] hidden = (int[])bparms.get("hidden");
        // Layout: [input width, hidden..., num_classes].
        int[] sizes = new int[hidden.length + 2];
        sizes[0] = dataset.getWidth();
        System.arraycopy(hidden, 0, sizes, 1, hidden.length);
        sizes[sizes.length - 1] = num_classes;
        System.err.println("Ignoring device_id");
        // Per-layer dropout ratios: slot 0 is the input layer, hidden ratios follow; unset entries stay 0.
        double[] hdr = new double[sizes.length];
        if (bparms.get("input_dropout_ratio") != null) {
            hdr[0] = (Double)bparms.get("input_dropout_ratio");
        }
        double[] hiddenDropout = (double[])bparms.get("hidden_dropout_ratios");
        if (hiddenDropout != null) {
            System.arraycopy(hiddenDropout, 0, hdr, 1, hiddenDropout.length);
        }
        // Layer names: "data", one activation name per hidden layer, then "loss".
        String[] layers = new String[sizes.length];
        System.arraycopy(bparms.get("activations"), 0, layers, 1, hidden.length);
        layers[0] = "data";
        layers[layers.length - 1] = "loss";
        return new DeepwaterCaffeModel((Integer)bparms.get("mini_batch_size"), sizes, layers, hdr, opts.getSeed(), opts.useGPU());
    }

    /**
     * Saves the model definition via {@link DeepwaterCaffeModel#saveModel(String)}.
     *
     * @param m          the model; must be a {@link DeepwaterCaffeModel}
     * @param model_path destination path
     */
    public void saveModel(BackendModel m, String model_path) {
        ((DeepwaterCaffeModel)m).saveModel(model_path);
    }

    /**
     * Loads parameters via {@link DeepwaterCaffeModel#loadParam(String)}.
     *
     * @param m          the model; must be a {@link DeepwaterCaffeModel}
     * @param param_path source path
     */
    public void loadParam(BackendModel m, String param_path) {
        ((DeepwaterCaffeModel)m).loadParam(param_path);
    }

    /**
     * Saves parameters via {@link DeepwaterCaffeModel#saveParam(String)}.
     *
     * @param m          the model; must be a {@link DeepwaterCaffeModel}
     * @param param_path destination path
     */
    public void saveParam(BackendModel m, String param_path) {
        ((DeepwaterCaffeModel)m).saveParam(param_path);
    }

    /**
     * Not supported by this backend.
     *
     * @throws UnsupportedOperationException always
     */
    public float[] loadMeanImage(BackendModel m, String path) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this backend.
     *
     * @throws UnsupportedOperationException always
     */
    public String toJson(BackendModel m) {
        throw new UnsupportedOperationException();
    }

    /** No-op: this backend ignores parameter updates. */
    public void setParameter(BackendModel m, String name, float value) {
    }

    /**
     * Runs one training call via {@link DeepwaterCaffeModel#train(float[], float[])}.
     *
     * @param m     the model; must be a {@link DeepwaterCaffeModel}
     * @param data  input batch
     * @param label labels for the batch
     * @return always {@code null}; this backend does not return training output
     */
    public float[] train(BackendModel m, float[] data, float[] label) {
        ((DeepwaterCaffeModel)m).train(data, label);
        return null;
    }

    /**
     * Runs inference via {@link DeepwaterCaffeModel#predict(float[])}.
     *
     * @param m    the model; must be a {@link DeepwaterCaffeModel}
     * @param data input batch
     * @return the model's prediction output
     */
    public float[] predict(BackendModel m, float[] data) {
        return ((DeepwaterCaffeModel)m).predict(data);
    }

    /** No-op: this backend does not delete saved model files. */
    public void deleteSavedModel(String model_path) {
    }

    /** No-op: this backend does not delete saved parameter files. */
    public void deleteSavedParam(String param_path) {
    }

    /**
     * Not implemented by this backend.
     *
     * @return always {@code null}
     */
    public String listAllLayers(BackendModel m) {
        return null;
    }

    /**
     * Not implemented by this backend.
     *
     * @return always an empty array
     */
    public float[] extractLayer(BackendModel m, String name, float[] data) {
        return new float[0];
    }

    /**
     * Writes {@code payload} to {@code file}, replacing any existing content.
     *
     * @param file    destination file
     * @param payload bytes to write
     * @throws IOException if the file cannot be opened or written
     */
    public void writeBytes(File file, byte[] payload) throws IOException {
        // try-with-resources guarantees the stream is closed even if write() fails.
        try (FileOutputStream os = new FileOutputStream(file)) {
            os.write(payload);
        }
    }

    /**
     * Reads the entire contents of {@code file} into a new byte array.
     *
     * @param file source file
     * @return the file's bytes
     * @throws IOException if the file cannot be opened, is larger than
     *                     {@link Integer#MAX_VALUE} bytes, or ends before the
     *                     length reported by {@link File#length()} is read
     */
    public byte[] readBytes(File file) throws IOException {
        try (FileInputStream is = new FileInputStream(file)) {
            long length = file.length();
            // A plain (int) cast would silently truncate lengths >= 2 GiB.
            if (length > Integer.MAX_VALUE) {
                throw new IOException("File too large to read into a byte array: " + file + " (" + length + " bytes)");
            }
            byte[] params = new byte[(int)length];
            // A single read() may return fewer bytes than requested; keep reading until full.
            int offset = 0;
            while (offset < params.length) {
                int n = is.read(params, offset, params.length - offset);
                if (n < 0) {
                    throw new IOException("Unexpected end of file after " + offset + " of " + params.length + " bytes: " + file);
                }
                offset += n;
            }
            return params;
        }
    }
}

