/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.params;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * Parameter initializer for {@link VariationalAutoencoder} layers.
 *
 * <p>All parameters live in a single flattened vector. It is made of a sequence of dense sub-layers,
 * each contributing an {@code nIn x nOut} weight block followed by an {@code nOut} bias block, in
 * this order:
 * <ol>
 *   <li>encoder layers {@code e0 .. e(k-1)}, the first taking the layer's {@code nIn};</li>
 *   <li>the p(z|x) mean layer ({@code pZXMean}), from the last encoder layer to {@code nOut};</li>
 *   <li>the p(z|x) log(sigma^2) layer ({@code pZXLogStd2}), same shape as the mean layer;</li>
 *   <li>decoder layers {@code d0 .. d(m-1)}, the first taking {@code nOut};</li>
 *   <li>the p(x|z) layer ({@code pXZ}), from the last decoder layer to the number of inputs required
 *       by the configured output distribution.</li>
 * </ol>
 * Parameter keys are the sub-layer prefix plus {@code "W"} (weights) or {@code "b"} (bias).
 */
public class VariationalAutoencoderParamInitializer
extends DefaultParamInitializer {
    private static final VariationalAutoencoderParamInitializer INSTANCE = new VariationalAutoencoderParamInitializer();
    public static final String WEIGHT_KEY_SUFFIX = "W";
    public static final String BIAS_KEY_SUFFIX = "b";
    public static final String PZX_PREFIX = "pZX";
    public static final String PZX_MEAN_PREFIX = "pZXMean";
    public static final String PZX_LOGSTD2_PREFIX = "pZXLogStd2";
    public static final String ENCODER_PREFIX = "e";
    public static final String DECODER_PREFIX = "d";
    public static final String PZX_MEAN_W = "pZXMeanW";
    public static final String PZX_MEAN_B = "pZXMeanb";
    public static final String PZX_LOGSTD2_W = "pZXLogStd2W";
    public static final String PZX_LOGSTD2_B = "pZXLogStd2b";
    public static final String PXZ_PREFIX = "pXZ";
    public static final String PXZ_W = "pXZW";
    public static final String PXZ_B = "pXZb";

    /**
     * Returns the shared instance of this initializer.
     *
     * @return the singleton initializer
     */
    public static VariationalAutoencoderParamInitializer getInstance() {
        return INSTANCE;
    }

    /**
     * Counts all parameters (weights and biases) of the VAE layer.
     *
     * <p>The count is accumulated in {@code long}. The previous implementation truncated partial
     * sums to {@code int}, which silently wrapped for models with more than
     * {@code Integer.MAX_VALUE} parameters.
     *
     * @param conf configuration whose layer is a {@link VariationalAutoencoder}
     * @return total number of scalars in the flattened parameter vector
     * @throws ND4JArraySizeException if the layer's {@code nIn} exceeds {@code Integer.MAX_VALUE}
     */
    @Override
    public long numParams(NeuralNetConfiguration conf) {
        VariationalAutoencoder layer = (VariationalAutoencoder)conf.getLayer();
        long total = 0L;
        for (DenseBlock block : layout(layer)) {
            total += block.paramCount();
        }
        return total;
    }

    /**
     * Lists every parameter key in flattened-layout order: for each sub-layer, the weight key
     * followed by the bias key.
     *
     * @param l a {@link VariationalAutoencoder} layer configuration
     * @return mutable list of parameter keys
     */
    @Override
    public List<String> paramKeys(Layer l) {
        VariationalAutoencoder layer = (VariationalAutoencoder)l;
        int[] encoderLayerSizes = layer.getEncoderLayerSizes();
        int[] decoderLayerSizes = layer.getDecoderLayerSizes();
        List<String> keys = new ArrayList<>(2 * (encoderLayerSizes.length + decoderLayerSizes.length) + 6);
        for (int i = 0; i < encoderLayerSizes.length; ++i) {
            keys.add(ENCODER_PREFIX + i + WEIGHT_KEY_SUFFIX);
            keys.add(ENCODER_PREFIX + i + BIAS_KEY_SUFFIX);
        }
        keys.add(PZX_MEAN_W);
        keys.add(PZX_MEAN_B);
        keys.add(PZX_LOGSTD2_W);
        keys.add(PZX_LOGSTD2_B);
        for (int i = 0; i < decoderLayerSizes.length; ++i) {
            keys.add(DECODER_PREFIX + i + WEIGHT_KEY_SUFFIX);
            keys.add(DECODER_PREFIX + i + BIAS_KEY_SUFFIX);
        }
        keys.add(PXZ_W);
        keys.add(PXZ_B);
        return keys;
    }

    /**
     * Returns the subset of {@link #paramKeys(Layer)} that {@link #isWeightParam(Layer, String)}
     * accepts, in the same order.
     */
    @Override
    public List<String> weightKeys(Layer layer) {
        List<String> out = new ArrayList<>();
        for (String key : this.paramKeys(layer)) {
            if (this.isWeightParam(layer, key)) {
                out.add(key);
            }
        }
        return out;
    }

    /**
     * Returns the subset of {@link #paramKeys(Layer)} that {@link #isBiasParam(Layer, String)}
     * accepts, in the same order.
     */
    @Override
    public List<String> biasKeys(Layer layer) {
        List<String> out = new ArrayList<>();
        for (String key : this.paramKeys(layer)) {
            if (this.isBiasParam(layer, key)) {
                out.add(key);
            }
        }
        return out;
    }

    /** A key is a weight parameter when it ends with {@link #WEIGHT_KEY_SUFFIX}. */
    @Override
    public boolean isWeightParam(Layer layer, String key) {
        return key.endsWith(WEIGHT_KEY_SUFFIX);
    }

    /** A key is a bias parameter when it ends with {@link #BIAS_KEY_SUFFIX}. */
    @Override
    public boolean isBiasParam(Layer layer, String key) {
        return key.endsWith(BIAS_KEY_SUFFIX);
    }

    /**
     * Splits {@code paramsView} into per-sub-layer weight and bias views. Optionally initializes
     * them, and registers every key as a variable on {@code conf}.
     *
     * @param conf configuration whose layer is a {@link VariationalAutoencoder}
     * @param paramsView flattened parameter array, of length {@link #numParams(NeuralNetConfiguration)}
     * @param initializeParams if true, weights are initialized with the layer's weight init
     *     function and biases are set to 0
     * @return map from parameter key to its view, in flattened-layout order
     * @throws IllegalArgumentException if {@code paramsView} has the wrong length
     * @throws ND4JArraySizeException if the layer's {@code nIn} exceeds {@code Integer.MAX_VALUE}
     */
    @Override
    public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
        long expectedLength = this.numParams(conf);
        if (paramsView.length() != expectedLength) {
            throw new IllegalArgumentException("Incorrect paramsView length: Expected length " + expectedLength + ", got length " + paramsView.length());
        }
        VariationalAutoencoder layer = (VariationalAutoencoder)conf.getLayer();
        IWeightInit weightInit = layer.getWeightInitFn();
        // Reshaped once rather than per encoder layer; offsets are long to avoid int overflow.
        INDArray flatParams = paramsView.reshape(new long[]{paramsView.length()});
        Map<String, INDArray> ret = new LinkedHashMap<>();
        long offset = 0L;
        for (DenseBlock block : layout(layer)) {
            INDArray weightView = slice(flatParams, offset, block.weightCount());
            offset += block.weightCount();
            INDArray biasView = slice(flatParams, offset, block.nOut);
            offset += block.nOut;
            INDArray weights = this.createWeightMatrix(block.nIn, block.nOut, weightInit, weightView, initializeParams);
            INDArray biases = this.createBias(block.nOut, 0.0, biasView, initializeParams);
            ret.put(block.weightKey, weights);
            ret.put(block.biasKey, biases);
            conf.addVariable(block.weightKey);
            conf.addVariable(block.biasKey);
        }
        return ret;
    }

    /**
     * Splits a flattened gradient array into per-parameter views. It uses the same layout and the
     * same weight/bias shaping helpers as {@link #init}, so each gradient lines up with its
     * parameter.
     *
     * @param conf configuration whose layer is a {@link VariationalAutoencoder}
     * @param gradientView flattened gradient array
     * @return map from parameter key to its gradient view, in flattened-layout order
     * @throws ND4JArraySizeException if the layer's {@code nIn} exceeds {@code Integer.MAX_VALUE}
     */
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        VariationalAutoencoder layer = (VariationalAutoencoder)conf.getLayer();
        INDArray flatGradients = gradientView.reshape(new long[]{gradientView.length()});
        Map<String, INDArray> ret = new LinkedHashMap<>();
        long offset = 0L;
        for (DenseBlock block : layout(layer)) {
            INDArray weightView = slice(flatGradients, offset, block.weightCount());
            offset += block.weightCount();
            INDArray biasView = slice(flatGradients, offset, block.nOut);
            offset += block.nOut;
            // With initializeParams=false these helpers only shape the existing views. Using them
            // for every block keeps the gradient layout identical to the parameter layout from init.
            ret.put(block.weightKey, this.createWeightMatrix(block.nIn, block.nOut, null, weightView, false));
            ret.put(block.biasKey, this.createBias(block.nOut, 0.0, biasView, false));
        }
        return ret;
    }

    /**
     * Returns the 1-D sub-range {@code [offset, offset + length)} of {@code flat}.
     */
    private static INDArray slice(INDArray flat, long offset, long length) {
        return flat.get(new INDArrayIndex[]{NDArrayIndex.interval(offset, offset + length)});
    }

    /**
     * Describes the layer's dense sub-layers in the order their parameters appear in the
     * flattened vector. This is the single source of truth for numParams, init and gradients.
     *
     * @throws ArrayIndexOutOfBoundsException if the encoder or decoder size array is empty
     * @throws ND4JArraySizeException if {@code nIn} exceeds {@code Integer.MAX_VALUE}
     */
    private static List<DenseBlock> layout(VariationalAutoencoder layer) {
        long nIn = layer.getNIn();
        long nOut = layer.getNOut();
        int[] encoderLayerSizes = layer.getEncoderLayerSizes();
        int[] decoderLayerSizes = layer.getDecoderLayerSizes();
        List<DenseBlock> blocks = new ArrayList<>(encoderLayerSizes.length + decoderLayerSizes.length + 3);

        long layerIn = nIn;
        for (int i = 0; i < encoderLayerSizes.length; ++i) {
            blocks.add(new DenseBlock(layerIn, encoderLayerSizes[i], ENCODER_PREFIX + i + WEIGHT_KEY_SUFFIX, ENCODER_PREFIX + i + BIAS_KEY_SUFFIX));
            layerIn = encoderLayerSizes[i];
        }
        long lastEncoderSize = encoderLayerSizes[encoderLayerSizes.length - 1];
        blocks.add(new DenseBlock(lastEncoderSize, nOut, PZX_MEAN_W, PZX_MEAN_B));
        blocks.add(new DenseBlock(lastEncoderSize, nOut, PZX_LOGSTD2_W, PZX_LOGSTD2_B));

        layerIn = nOut;
        for (int i = 0; i < decoderLayerSizes.length; ++i) {
            blocks.add(new DenseBlock(layerIn, decoderLayerSizes[i], DECODER_PREFIX + i + WEIGHT_KEY_SUFFIX, DECODER_PREFIX + i + BIAS_KEY_SUFFIX));
            layerIn = decoderLayerSizes[i];
        }

        // distributionInputSize takes an int, so nIn must fit in one.
        if (nIn > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }
        int nDistributionParams = layer.getOutputDistribution().distributionInputSize((int)nIn);
        long lastDecoderSize = decoderLayerSizes[decoderLayerSizes.length - 1];
        blocks.add(new DenseBlock(lastDecoderSize, nDistributionParams, PXZ_W, PXZ_B));
        return blocks;
    }

    /** Shape and keys of one dense sub-layer: an nIn x nOut weight block followed by an nOut bias block. */
    private static final class DenseBlock {
        final long nIn;
        final long nOut;
        final String weightKey;
        final String biasKey;

        DenseBlock(long nIn, long nOut, String weightKey, String biasKey) {
            this.nIn = nIn;
            this.nOut = nOut;
            this.weightKey = weightKey;
            this.biasKey = biasKey;
        }

        /** Number of weight scalars. */
        long weightCount() {
            return this.nIn * this.nOut;
        }

        /** Number of weight plus bias scalars. */
        long paramCount() {
            return (this.nIn + 1L) * this.nOut;
        }
    }
}

