/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.preprocessors;

import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Input preprocessor used by Keras model import that flattens an activation array into a
 * 2d feed-forward array of shape {@code [batch, depth * tsLength]}, and performs the inverse
 * reshape to {@code [batch, depth, tsLength]} on the backward pass.
 *
 * <p>NOTE(review): the backward reshape suggests the forward input is expected to be laid out as
 * {@code [batch, depth, tsLength]} (i.e. a recurrent activation); this class does not validate
 * that, so confirm against the Keras import caller.
 */
public class KerasFlattenRnnPreprocessor
extends BaseInputPreProcessor {
    private static final Logger log = LoggerFactory.getLogger(KerasFlattenRnnPreprocessor.class);
    private long tsLength;
    private long depth;

    /**
     * Creates the preprocessor.
     *
     * @param depth    size of the feature/depth dimension
     * @param tsLength number of time steps; stored as its absolute value
     *                 (NOTE(review): presumably because imported configs may report a negative
     *                 length — confirm)
     */
    public KerasFlattenRnnPreprocessor(@JsonProperty(value="depth") long depth, @JsonProperty(value="tsLength") long tsLength) {
        this.tsLength = Math.abs(tsLength);
        this.depth = depth;
    }

    /**
     * Duplicates {@code input} into the activations workspace in 'c' order and reshapes it to
     * {@code [input.size(0), depth * tsLength]}.
     *
     * @param input         activations to flatten
     * @param miniBatchSize unused; the batch dimension is taken from {@code input}
     * @param workspaceMgr  workspace manager used for the duplicate
     * @return the flattened 2d activations
     */
    @Override
    public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
        // A 'c'-ordered copy guarantees the row-major reshape below is a pure view of the copy.
        INDArray output = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');
        return output.reshape(input.size(0), this.depth * this.tsLength);
    }

    /**
     * Duplicates {@code epsilons} into the activation-gradient workspace in 'c' order and
     * reshapes it back to {@code [batch, depth, tsLength]}.
     *
     * @param epsilons      gradients with respect to this preprocessor's (flattened) output
     * @param miniBatchSize retained for interface compatibility; the batch dimension is read from
     *                      {@code epsilons} so it always agrees with the array being reshaped,
     *                      mirroring {@link #preProcess}
     * @param workspaceMgr  workspace manager used for the duplicate
     * @return the gradients reshaped to {@code [batch, depth, tsLength]}
     */
    @Override
    public INDArray backprop(INDArray epsilons, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
        long batch = epsilons.size(0);
        return workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilons, 'c').reshape(new long[]{batch, this.depth, this.tsLength});
    }

    /** Returns a copy of this preprocessor as produced by the superclass {@code clone()}. */
    @Override
    public KerasFlattenRnnPreprocessor clone() {
        return (KerasFlattenRnnPreprocessor)super.clone();
    }

    /**
     * Returns a feed-forward input type of size {@code depth * tsLength}. The supplied
     * {@code inputType} is not inspected.
     */
    @Override
    public InputType getOutputType(InputType inputType) throws InvalidInputTypeException {
        return InputType.feedForward(this.depth * this.tsLength);
    }

    /** @return the number of time steps (always stored non-negative, barring {@code Long.MIN_VALUE}) */
    public long getTsLength() {
        return this.tsLength;
    }

    /** @return the depth/feature dimension size */
    public long getDepth() {
        return this.depth;
    }

    /**
     * Sets the number of time steps. The absolute value is stored, keeping the same invariant
     * the constructor establishes.
     *
     * @param tsLength number of time steps
     */
    public void setTsLength(long tsLength) {
        this.tsLength = Math.abs(tsLength);
    }

    /** @param depth the depth/feature dimension size */
    public void setDepth(long depth) {
        this.depth = depth;
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof KerasFlattenRnnPreprocessor)) {
            return false;
        }
        KerasFlattenRnnPreprocessor other = (KerasFlattenRnnPreprocessor)o;
        return other.canEqual(this)
                && this.getTsLength() == other.getTsLength()
                && this.getDepth() == other.getDepth();
    }

    /** Symmetry hook for {@link #equals(Object)} so subclasses can opt out of equality. */
    protected boolean canEqual(Object other) {
        return other instanceof KerasFlattenRnnPreprocessor;
    }

    @Override
    public int hashCode() {
        // Long.hashCode(v) == (int)(v ^ (v >>> 32)), identical to the previous manual folding.
        int result = 1;
        result = result * 59 + Long.hashCode(this.getTsLength());
        result = result * 59 + Long.hashCode(this.getDepth());
        return result;
    }

    @Override
    public String toString() {
        return "KerasFlattenRnnPreprocessor(tsLength=" + this.getTsLength() + ", depth=" + this.getDepth() + ")";
    }
}

