/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.gradient;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.gradient.Gradient;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Default {@link Gradient} implementation: a name-to-array map of per-variable gradients plus an
 * optional cached flattened (single-array) form of all of them.
 *
 * <p>Variables are kept in insertion order (backed by a {@link LinkedHashMap}), and that order is
 * used when {@link #gradient()} builds the flattened array. Each variable may optionally carry its
 * own flattening order. A variable with no order, or with an order equal to
 * {@link #DEFAULT_FLATTENING_ORDER}, is used as-is. Any other order causes that variable to be
 * flattened in its own order before being concatenated.
 *
 * <p>This class performs no synchronization. NOTE(review): treat instances as not safe for
 * concurrent mutation.
 */
public class DefaultGradient
implements Gradient {
    /** Order used for the final concatenation and for variables without an explicit order. */
    public static final char DEFAULT_FLATTENING_ORDER = 'f';

    /** Per-variable gradients in insertion order; exposed live via {@link #gradientForVariable()}. */
    private final Map<String, INDArray> gradients = new LinkedHashMap<>();
    /** Lazily created; only ever holds non-null orders (see the three-argument setGradientFor). */
    private Map<String, Character> flatteningOrders;
    /** Cached result of {@link #gradient()}; may also be supplied up front by the caller. */
    private INDArray flattenedGradient;

    /** Creates an empty gradient with no cached flattened array. */
    public DefaultGradient() {
    }

    /**
     * Creates a gradient whose flattened form is already known.
     *
     * @param flattenedGradient array returned by {@link #gradient()} until replaced or cleared
     */
    public DefaultGradient(INDArray flattenedGradient) {
        this.flattenedGradient = flattenedGradient;
    }

    /**
     * Returns the internal variable-to-gradient map. The map is live: changes made through it
     * affect this object.
     */
    @Override
    public Map<String, INDArray> gradientForVariable() {
        return this.gradients;
    }

    /**
     * Flattens the gradients of the given variables, in the given order, into one array.
     *
     * <p>Names with no entry in this gradient are skipped. This method neither reads nor updates
     * the cached flattened array.
     *
     * @param order variable names in the order their gradients should be concatenated
     * @return the concatenation, flattened in {@link #DEFAULT_FLATTENING_ORDER}
     */
    @Override
    public INDArray gradient(List<String> order) {
        List<INDArray> toFlatten = new ArrayList<>(order.size());
        for (String variable : order) {
            if (!this.gradients.containsKey(variable)) {
                continue;
            }
            toFlatten.add(this.orderedForFlattening(variable, this.gradients.get(variable)));
        }
        return Nd4j.toFlattened(DEFAULT_FLATTENING_ORDER, toFlatten);
    }

    /**
     * Returns the array to contribute for one variable when flattening.
     *
     * <p>If the variable has an explicit flattening order different from the default, the result
     * is that gradient flattened in its own order. Otherwise the gradient itself is returned.
     */
    private INDArray orderedForFlattening(String variable, INDArray grad) {
        if (this.flatteningOrders != null) {
            Character order = this.flatteningOrders.get(variable);
            if (order != null && order.charValue() != DEFAULT_FLATTENING_ORDER) {
                return Nd4j.toFlattened(order.charValue(), new INDArray[]{grad});
            }
        }
        return grad;
    }

    /**
     * Rebuilds the cached flattened array from all variables, in insertion order.
     *
     * <p>When there are no variables, the cache is left unchanged.
     */
    private void flattenGradient() {
        if (this.gradients.isEmpty()) {
            // Nothing to concatenate. Previously only the no-orders path guarded this case.
            return;
        }
        List<INDArray> toFlatten = new ArrayList<>(this.gradients.size());
        for (Map.Entry<String, INDArray> entry : this.gradients.entrySet()) {
            toFlatten.add(this.orderedForFlattening(entry.getKey(), entry.getValue()));
        }
        this.flattenedGradient = Nd4j.toFlattened(DEFAULT_FLATTENING_ORDER, toFlatten);
    }

    /**
     * Returns the flattened gradient, building and caching it on first use.
     *
     * @return the cached flattened array, or {@code null} if none was supplied and there are no
     *     variables to flatten
     */
    @Override
    public INDArray gradient() {
        if (this.flattenedGradient == null) {
            this.flattenGradient();
        }
        return this.flattenedGradient;
    }

    /**
     * Removes all variables and also drops the per-variable flattening orders and the cached
     * flattened array.
     *
     * <p>Dropping all three keeps {@link #gradient()} consistent with the now-empty map, and
     * releases the reference to the old flattened array.
     */
    @Override
    public void clear() {
        this.gradients.clear();
        this.flatteningOrders = null;
        this.flattenedGradient = null;
    }

    /**
     * Returns the gradient stored for {@code variable}.
     *
     * @return the gradient, or {@code null} if none is stored
     */
    @Override
    public INDArray getGradientFor(String variable) {
        return this.gradients.get(variable);
    }

    /**
     * Stores the gradient for {@code variable}.
     *
     * <p>This does NOT invalidate the cached flattened array. NOTE(review): presumably the
     * per-variable arrays are expected to be views into an externally supplied flattened
     * gradient. Confirm against callers before changing this.
     *
     * @return the previously stored gradient, or {@code null}
     */
    @Override
    public INDArray setGradientFor(String variable, INDArray newGradient) {
        return this.gradients.put(variable, newGradient);
    }

    /**
     * Stores the gradient for {@code variable}, and, when {@code flatteningOrder} is non-null,
     * records the order to use for it when flattening.
     *
     * <p>A {@code null} order leaves any previously recorded order for the variable unchanged.
     *
     * @return the previously stored gradient, or {@code null}
     */
    @Override
    public INDArray setGradientFor(String variable, INDArray gradient, Character flatteningOrder) {
        INDArray previous = this.setGradientFor(variable, gradient);
        if (flatteningOrder != null) {
            if (this.flatteningOrders == null) {
                this.flatteningOrders = new LinkedHashMap<>();
            }
            this.flatteningOrders.put(variable, flatteningOrder);
        }
        return previous;
    }

    /**
     * Returns the flattening order recorded for {@code variable}.
     *
     * @return the recorded order, or {@code null} if none was recorded
     */
    @Override
    public Character flatteningOrderForVariable(String variable) {
        if (this.flatteningOrders == null) {
            return null;
        }
        return this.flatteningOrders.get(variable);
    }

    @Override
    public String toString() {
        return "DefaultGradient{gradients=" + this.gradients
                + (this.flatteningOrders != null ? ", flatteningOrders=" + this.flatteningOrders : "")
                + '}';
    }

    /**
     * Replaces the cached flattened gradient that {@link #gradient()} returns.
     *
     * <p>The per-variable map is not modified.
     */
    public void setFlattenedGradient(INDArray flattenedGradient) {
        this.flattenedGradient = flattenedGradient;
    }
}

