/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.impl.multilayer.scoring;

import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.spark.impl.common.score.BaseVaeReconstructionProbWithKeyFunction;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * {@link BaseVaeReconstructionProbWithKeyFunction} variant that obtains its
 * {@link VariationalAutoencoder} from the first layer of a {@link MultiLayerNetwork}.
 *
 * <p>The network is rebuilt from the broadcast JSON configuration and the broadcast
 * parameter array each time {@link #getVaeLayer()} is called.
 *
 * @param <K> type parameter handed through to the base class (presumably the key type of
 *            the scored examples; confirm against the base class)
 */
public class VaeReconstructionProbWithKeyFunction<K> extends BaseVaeReconstructionProbWithKeyFunction<K> {

    /**
     * Creates the function; every argument is forwarded unchanged to the superclass constructor.
     *
     * @param params            broadcast network parameters
     * @param jsonConfig        broadcast network configuration, parsed as a
     *                          {@link MultiLayerConfiguration} JSON string in {@link #getVaeLayer()}
     * @param useLogProbability forwarded to the superclass
     * @param batchSize         forwarded to the superclass
     * @param numSamples        forwarded to the superclass
     */
    public VaeReconstructionProbWithKeyFunction(Broadcast<INDArray> params, Broadcast<String> jsonConfig,
            boolean useLogProbability, int batchSize, int numSamples) {
        super(params, jsonConfig, useLogProbability, batchSize, numSamples);
    }

    /**
     * Builds a network from the broadcast configuration, loads a copy of the broadcast
     * parameters into it, and returns its layer 0.
     *
     * @return layer 0 of the rebuilt network
     * @throws IllegalStateException if the broadcast parameter array's length differs from
     *                               the network's parameter count
     * @throws RuntimeException      if layer 0 is not a {@link VariationalAutoencoder}
     */
    @Override
    public VariationalAutoencoder getVaeLayer() {
        String json = (String) this.jsonConfig.getValue();
        MultiLayerConfiguration configuration = MultiLayerConfiguration.fromJson(json);
        MultiLayerNetwork net = new MultiLayerNetwork(configuration);
        net.init();

        // Work on a duplicate so the broadcast array itself is never handed to the network.
        INDArray broadcastParams = (INDArray) this.params.value();
        INDArray paramsCopy = broadcastParams.dup();
        if (paramsCopy.length() != net.numParams(false)) {
            throw new IllegalStateException(
                    "Network did not have same number of parameters as the broadcast set parameters");
        }
        net.setParameters(paramsCopy);

        Layer firstLayer = net.getLayer(0);
        if (firstLayer instanceof VariationalAutoencoder) {
            return (VariationalAutoencoder) firstLayer;
        }
        throw new RuntimeException(
                "Cannot use VaeReconstructionProbWithKeyFunction on network that doesn't have a VAE layer as layer 0. Layer type: "
                        + firstLayer.getClass());
    }
}

