/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.params;

import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * {@link ParamInitializer} for {@link GravesBidirectionalLSTM} layers.
 *
 * <p>All parameters of the layer live in one flattened array that is carved into six
 * consecutive slices, in this order (with {@code nIn = getNIn()} and {@code nOut = getNOut()}):
 *
 * <pre>
 *   WF   forward input weights       nIn  * (4 * nOut)       reshaped to [nIn,  4 * nOut]
 *   RWF  forward recurrent weights   nOut * (4 * nOut + 3)   reshaped to [nOut, 4 * nOut + 3]
 *   bF   forward bias                4 * nOut
 *   WB   backward input weights      nIn  * (4 * nOut)       reshaped to [nIn,  4 * nOut]
 *   RWB  backward recurrent weights  nOut * (4 * nOut + 3)   reshaped to [nOut, 4 * nOut + 3]
 *   bB   backward bias               4 * nOut
 * </pre>
 *
 * <p>NOTE(review): the three extra recurrent-weight columns are presumably peephole
 * connections - confirm against the layer implementation.
 */
public class GravesBidirectionalLSTMParamInitializer implements ParamInitializer {
    private static final GravesBidirectionalLSTMParamInitializer INSTANCE = new GravesBidirectionalLSTMParamInitializer();
    public static final String RECURRENT_WEIGHT_KEY_FORWARDS = "RWF";
    public static final String BIAS_KEY_FORWARDS = "bF";
    public static final String INPUT_WEIGHT_KEY_FORWARDS = "WF";
    public static final String RECURRENT_WEIGHT_KEY_BACKWARDS = "RWB";
    public static final String BIAS_KEY_BACKWARDS = "bB";
    public static final String INPUT_WEIGHT_KEY_BACKWARDS = "WB";
    private static final List<String> WEIGHT_KEYS = Collections.unmodifiableList(Arrays.asList(
            INPUT_WEIGHT_KEY_FORWARDS, INPUT_WEIGHT_KEY_BACKWARDS,
            RECURRENT_WEIGHT_KEY_FORWARDS, RECURRENT_WEIGHT_KEY_BACKWARDS));
    private static final List<String> BIAS_KEYS = Collections.unmodifiableList(Arrays.asList(
            BIAS_KEY_FORWARDS, BIAS_KEY_BACKWARDS));
    private static final List<String> ALL_PARAM_KEYS = Collections.unmodifiableList(Arrays.asList(
            INPUT_WEIGHT_KEY_FORWARDS, INPUT_WEIGHT_KEY_BACKWARDS,
            RECURRENT_WEIGHT_KEY_FORWARDS, RECURRENT_WEIGHT_KEY_BACKWARDS,
            BIAS_KEY_FORWARDS, BIAS_KEY_BACKWARDS));

    // Positions of each parameter inside the array returned by sliceFlattened(...).
    // They follow the physical order of the slices in the flattened buffer.
    private static final int SLICE_WF = 0;
    private static final int SLICE_RWF = 1;
    private static final int SLICE_BF = 2;
    private static final int SLICE_WB = 3;
    private static final int SLICE_RWB = 4;
    private static final int SLICE_BB = 5;

    /**
     * @return the shared instance of this initializer
     */
    public static GravesBidirectionalLSTMParamInitializer getInstance() {
        return INSTANCE;
    }

    /**
     * Counts the parameters of the layer held by the given configuration.
     *
     * @param conf configuration whose layer is a {@link GravesBidirectionalLSTM}
     * @return total number of parameters (both directions)
     */
    @Override
    public long numParams(NeuralNetConfiguration conf) {
        return this.numParams(conf.getLayer());
    }

    /**
     * Counts the parameters of the given layer: input weights, recurrent weights and bias
     * for one direction, doubled to cover both directions.
     *
     * @param l layer configuration; must be a {@link GravesBidirectionalLSTM}
     * @return total number of parameters (both directions)
     * @throws ClassCastException if {@code l} is not a {@link GravesBidirectionalLSTM}
     */
    @Override
    public long numParams(Layer l) {
        GravesBidirectionalLSTM layerConf = (GravesBidirectionalLSTM)l;
        long nL = layerConf.getNOut();
        long nLast = layerConf.getNIn();
        long nParamsForward = nLast * (4L * nL) + nL * (4L * nL + 3L) + 4L * nL;
        return 2L * nParamsForward;
    }

    /**
     * @param layer ignored; the key set is the same for every layer
     * @return unmodifiable list of all parameter keys (weights first, then biases)
     */
    @Override
    public List<String> paramKeys(Layer layer) {
        return ALL_PARAM_KEYS;
    }

    /**
     * @param layer ignored; the key set is the same for every layer
     * @return unmodifiable list of the four weight parameter keys
     */
    @Override
    public List<String> weightKeys(Layer layer) {
        return WEIGHT_KEYS;
    }

    /**
     * @param layer ignored; the key set is the same for every layer
     * @return unmodifiable list of the two bias parameter keys
     */
    @Override
    public List<String> biasKeys(Layer layer) {
        return BIAS_KEYS;
    }

    /**
     * @param layer ignored
     * @param key   parameter key to classify; may be {@code null}
     * @return {@code true} if {@code key} is one of the input or recurrent weight keys
     */
    @Override
    public boolean isWeightParam(Layer layer, String key) {
        return RECURRENT_WEIGHT_KEY_FORWARDS.equals(key) || INPUT_WEIGHT_KEY_FORWARDS.equals(key) || RECURRENT_WEIGHT_KEY_BACKWARDS.equals(key) || INPUT_WEIGHT_KEY_BACKWARDS.equals(key);
    }

    /**
     * @param layer ignored
     * @param key   parameter key to classify; may be {@code null}
     * @return {@code true} if {@code key} is the forward or backward bias key
     */
    @Override
    public boolean isBiasParam(Layer layer, String key) {
        return BIAS_KEY_FORWARDS.equals(key) || BIAS_KEY_BACKWARDS.equals(key);
    }

    /**
     * Builds the parameter table for the layer as views over {@code paramsView}.
     *
     * <p>The six parameter keys are registered on {@code conf} via {@code addVariable}
     * before the view is sliced. When {@code initializeParams} is {@code true}, the
     * elements {@code [nOut, 2 * nOut)} of each bias slice are set to the layer's
     * {@code getForgetGateBiasInit()} value and each weight slice is filled by the layer's
     * weight-init function; otherwise the weight slices are only reshaped.
     *
     * @param conf             configuration whose layer is a {@link GravesBidirectionalLSTM}
     * @param paramsView       flattened parameter array to slice
     * @param initializeParams whether to write initial values into the view
     * @return synchronized, insertion-ordered map from parameter key to array
     *         (order: WF, RWF, bF, WB, RWB, bB)
     */
    @Override
    public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
        Map<String, INDArray> params = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>());
        GravesBidirectionalLSTM layerConf = (GravesBidirectionalLSTM)conf.getLayer();
        double forgetGateInit = layerConf.getForgetGateBiasInit();
        long nL = layerConf.getNOut();
        long nLast = layerConf.getNIn();

        conf.addVariable(INPUT_WEIGHT_KEY_FORWARDS);
        conf.addVariable(RECURRENT_WEIGHT_KEY_FORWARDS);
        conf.addVariable(BIAS_KEY_FORWARDS);
        conf.addVariable(INPUT_WEIGHT_KEY_BACKWARDS);
        conf.addVariable(RECURRENT_WEIGHT_KEY_BACKWARDS);
        conf.addVariable(BIAS_KEY_BACKWARDS);

        INDArray[] slices = sliceFlattened(paramsView, nLast, nL);
        INDArray iwF = slices[SLICE_WF];
        INDArray rwF = slices[SLICE_RWF];
        INDArray bF = slices[SLICE_BF];
        INDArray iwR = slices[SLICE_WB];
        INDArray rwR = slices[SLICE_RWB];
        INDArray bR = slices[SLICE_BB];

        long[] inputWShape = new long[]{nLast, 4L * nL};
        long[] recurrentWShape = new long[]{nL, 4L * nL + 3L};

        if (initializeParams) {
            // Bias elements [nOut, 2 * nOut) receive the configured forget gate bias init value.
            bF.put(new INDArrayIndex[]{NDArrayIndex.interval(nL, 2L * nL)}, Nd4j.ones(new long[]{1L, nL}).muli(forgetGateInit));
            bR.put(new INDArrayIndex[]{NDArrayIndex.interval(nL, 2L * nL)}, Nd4j.ones(new long[]{1L, nL}).muli(forgetGateInit));

            // The same fan-in / fan-out pair is used for input and recurrent weights.
            long fanIn = nL;
            long fanOut = nLast + nL;
            params.put(INPUT_WEIGHT_KEY_FORWARDS, layerConf.getWeightInitFn().init(fanIn, fanOut, inputWShape, 'f', iwF));
            params.put(RECURRENT_WEIGHT_KEY_FORWARDS, layerConf.getWeightInitFn().init(fanIn, fanOut, recurrentWShape, 'f', rwF));
            params.put(BIAS_KEY_FORWARDS, bF);
            params.put(INPUT_WEIGHT_KEY_BACKWARDS, layerConf.getWeightInitFn().init(fanIn, fanOut, inputWShape, 'f', iwR));
            params.put(RECURRENT_WEIGHT_KEY_BACKWARDS, layerConf.getWeightInitFn().init(fanIn, fanOut, recurrentWShape, 'f', rwR));
            params.put(BIAS_KEY_BACKWARDS, bR);
        } else {
            params.put(INPUT_WEIGHT_KEY_FORWARDS, WeightInitUtil.reshapeWeights(inputWShape, iwF));
            params.put(RECURRENT_WEIGHT_KEY_FORWARDS, WeightInitUtil.reshapeWeights(recurrentWShape, rwF));
            params.put(BIAS_KEY_FORWARDS, bF);
            params.put(INPUT_WEIGHT_KEY_BACKWARDS, WeightInitUtil.reshapeWeights(inputWShape, iwR));
            params.put(RECURRENT_WEIGHT_KEY_BACKWARDS, WeightInitUtil.reshapeWeights(recurrentWShape, rwR));
            params.put(BIAS_KEY_BACKWARDS, bR);
        }
        return params;
    }

    /**
     * Splits a flattened gradient array into per-parameter arrays using the same layout
     * as {@link #init(NeuralNetConfiguration, INDArray, boolean)}. Weight gradients are
     * reshaped in {@code 'f'} order; bias gradients are returned as sliced.
     *
     * @param conf         configuration whose layer is a {@link GravesBidirectionalLSTM}
     * @param gradientView flattened gradient array to slice
     * @return insertion-ordered map from parameter key to gradient array
     *         (order: WF, RWF, bF, WB, RWB, bB)
     */
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        GravesBidirectionalLSTM layerConf = (GravesBidirectionalLSTM)conf.getLayer();
        long nL = layerConf.getNOut();
        long nLast = layerConf.getNIn();

        INDArray[] slices = sliceFlattened(gradientView, nLast, nL);
        long[] inputWShape = new long[]{nLast, 4L * nL};
        long[] recurrentWShape = new long[]{nL, 4L * nL + 3L};

        LinkedHashMap<String, INDArray> out = new LinkedHashMap<String, INDArray>();
        out.put(INPUT_WEIGHT_KEY_FORWARDS, slices[SLICE_WF].reshape('f', inputWShape));
        out.put(RECURRENT_WEIGHT_KEY_FORWARDS, slices[SLICE_RWF].reshape('f', recurrentWShape));
        out.put(BIAS_KEY_FORWARDS, slices[SLICE_BF]);
        out.put(INPUT_WEIGHT_KEY_BACKWARDS, slices[SLICE_WB].reshape('f', inputWShape));
        out.put(RECURRENT_WEIGHT_KEY_BACKWARDS, slices[SLICE_RWB].reshape('f', recurrentWShape));
        out.put(BIAS_KEY_BACKWARDS, slices[SLICE_BB]);
        return out;
    }

    /**
     * Reshapes {@code flat} to a single dimension of its own length and cuts it into the
     * six consecutive parameter slices WF, RWF, bF, WB, RWB, bB (indexed by the
     * {@code SLICE_*} constants). This is the single definition of the flattened layout
     * shared by parameters and gradients.
     *
     * @param flat flattened parameter or gradient array
     * @param nIn  layer input size
     * @param nOut layer output size
     * @return the six slices, in buffer order
     */
    private static INDArray[] sliceFlattened(INDArray flat, long nIn, long nOut) {
        long nParamsInput = nIn * (4L * nOut);
        long nParamsRecurrent = nOut * (4L * nOut + 3L);
        long nBias = 4L * nOut;
        long[] sizes = new long[]{nParamsInput, nParamsRecurrent, nBias, nParamsInput, nParamsRecurrent, nBias};

        INDArray flattened = flat.reshape(new long[]{flat.length()});
        INDArray[] slices = new INDArray[sizes.length];
        long start = 0L;
        for (int i = 0; i < sizes.length; i++) {
            long end = start + sizes[i];
            slices[i] = flattened.get(new INDArrayIndex[]{NDArrayIndex.interval(start, end)});
            start = end;
        }
        return slices;
    }
}

