/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.util;

import java.util.Arrays;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;

public class TimeSeriesUtils {
    private TimeSeriesUtils() {
    }

    public static INDArray movingAverage(INDArray toAvg, int n) {
        INDArray ret = Nd4j.cumsum((INDArray)toAvg);
        INDArrayIndex[] ends = new INDArrayIndex[]{NDArrayIndex.interval((int)n, (int)toAvg.columns())};
        INDArrayIndex[] begins = new INDArrayIndex[]{NDArrayIndex.interval((long)0L, (long)(toAvg.columns() - n), (boolean)false)};
        INDArrayIndex[] nMinusOne = new INDArrayIndex[]{NDArrayIndex.interval((int)(n - 1), (int)toAvg.columns())};
        ret.put(ends, ret.get(ends).sub(ret.get(begins)));
        return ret.get(nMinusOne).divi((Number)n);
    }

    public static INDArray reshapeTimeSeriesMaskToVector(INDArray timeSeriesMask) {
        if (timeSeriesMask.rank() != 2) {
            throw new IllegalArgumentException("Cannot reshape mask: rank is not 2");
        }
        if (timeSeriesMask.ordering() != 'f') {
            timeSeriesMask = timeSeriesMask.dup('f');
        }
        return timeSeriesMask.reshape('f', new long[]{timeSeriesMask.length(), 1L});
    }

    public static INDArray reshapeTimeSeriesMaskToVector(INDArray timeSeriesMask, LayerWorkspaceMgr workspaceMgr, ArrayType arrayType) {
        if (timeSeriesMask.rank() != 2) {
            throw new IllegalArgumentException("Cannot reshape mask: rank is not 2");
        }
        if (timeSeriesMask.ordering() != 'f' || !Shape.hasDefaultStridesForShape((INDArray)timeSeriesMask)) {
            timeSeriesMask = workspaceMgr.dup(arrayType, timeSeriesMask, 'f');
        }
        return workspaceMgr.leverageTo(arrayType, timeSeriesMask.reshape('f', new long[]{timeSeriesMask.length(), 1L}));
    }

    public static INDArray reshapeTimeSeriesMaskToCnn4dMask(INDArray timeSeriesMask, LayerWorkspaceMgr workspaceMgr, ArrayType arrayType) {
        if (timeSeriesMask.rank() != 2) {
            throw new IllegalArgumentException("Cannot reshape mask: rank is not 2");
        }
        if (timeSeriesMask.ordering() != 'f' || !Shape.hasDefaultStridesForShape((INDArray)timeSeriesMask)) {
            timeSeriesMask = workspaceMgr.dup(arrayType, timeSeriesMask, 'f');
        }
        return workspaceMgr.leverageTo(arrayType, timeSeriesMask.reshape('f', new long[]{timeSeriesMask.length(), 1L, 1L, 1L}));
    }

    public static INDArray reshapeVectorToTimeSeriesMask(INDArray timeSeriesMaskAsVector, int minibatchSize) {
        if (!timeSeriesMaskAsVector.isVector()) {
            throw new IllegalArgumentException("Cannot reshape mask: expected vector");
        }
        long timeSeriesLength = timeSeriesMaskAsVector.length() / (long)minibatchSize;
        return timeSeriesMaskAsVector.reshape('f', new long[]{minibatchSize, timeSeriesLength});
    }

    public static INDArray reshapeCnnMaskToTimeSeriesMask(INDArray timeSeriesMaskAsCnnMask, int minibatchSize) {
        Preconditions.checkArgument((timeSeriesMaskAsCnnMask.rank() == 4 || timeSeriesMaskAsCnnMask.size(1) != 1L || timeSeriesMaskAsCnnMask.size(2) != 1L || timeSeriesMaskAsCnnMask.size(3) != 1L ? 1 : 0) != 0, (String)"Expected rank 4 mask with shape [mb*seqLength, 1, 1, 1]. Got rank %s mask array with shape %s", (Object)timeSeriesMaskAsCnnMask.rank(), (Object)timeSeriesMaskAsCnnMask.shape());
        long timeSeriesLength = timeSeriesMaskAsCnnMask.length() / (long)minibatchSize;
        return timeSeriesMaskAsCnnMask.reshape('f', new long[]{minibatchSize, timeSeriesLength});
    }

    public static INDArray reshapePerOutputTimeSeriesMaskTo2d(INDArray perOutputTimeSeriesMask) {
        if (perOutputTimeSeriesMask.rank() != 3) {
            throw new IllegalArgumentException("Cannot reshape per output mask: rank is not 3 (is: " + perOutputTimeSeriesMask.rank() + ", shape = " + Arrays.toString(perOutputTimeSeriesMask.shape()) + ")");
        }
        return TimeSeriesUtils.reshape3dTo2d(perOutputTimeSeriesMask);
    }

    public static INDArray reshapePerOutputTimeSeriesMaskTo2d(INDArray perOutputTimeSeriesMask, LayerWorkspaceMgr workspaceMgr, ArrayType arrayType) {
        if (perOutputTimeSeriesMask.rank() != 3) {
            throw new IllegalArgumentException("Cannot reshape per output mask: rank is not 3 (is: " + perOutputTimeSeriesMask.rank() + ", shape = " + Arrays.toString(perOutputTimeSeriesMask.shape()) + ")");
        }
        return TimeSeriesUtils.reshape3dTo2d(perOutputTimeSeriesMask, workspaceMgr, arrayType);
    }

    public static INDArray reshape3dTo2d(INDArray in) {
        if (in.rank() != 3) {
            throw new IllegalArgumentException("Invalid input: expect NDArray with rank 3");
        }
        long[] shape = in.shape();
        if (shape[0] == 1L) {
            return in.tensorAlongDimension(0L, new int[]{1, 2}).permutei(new int[]{1, 0});
        }
        if (shape[2] == 1L) {
            return in.tensorAlongDimension(0L, new int[]{1, 0});
        }
        INDArray permuted = in.permute(new int[]{0, 2, 1});
        return permuted.reshape('f', new long[]{shape[0] * shape[2], shape[1]});
    }

    public static INDArray reshape3dTo2d(INDArray in, LayerWorkspaceMgr workspaceMgr, ArrayType arrayType) {
        INDArray ret;
        if (in.rank() != 3) {
            throw new IllegalArgumentException("Invalid input: expect NDArray with rank 3");
        }
        long[] shape = in.shape();
        if (shape[0] == 1L) {
            ret = in.tensorAlongDimension(0L, new int[]{1, 2}).permutei(new int[]{1, 0});
        } else if (shape[2] == 1L) {
            ret = in.tensorAlongDimension(0L, new int[]{1, 0});
        } else {
            INDArray permuted = in.permute(new int[]{0, 2, 1});
            ret = permuted.reshape('f', new long[]{shape[0] * shape[2], shape[1]});
        }
        return workspaceMgr.leverageTo(arrayType, ret);
    }

    public static INDArray reshape2dTo3d(INDArray in, int miniBatchSize) {
        if (in.rank() != 2) {
            throw new IllegalArgumentException("Invalid input: expect NDArray with rank 2");
        }
        long[] shape = in.shape();
        if (in.ordering() != 'f') {
            in = Shape.toOffsetZeroCopy((INDArray)in, (char)'f');
        }
        INDArray reshaped = in.reshape('f', new long[]{miniBatchSize, shape[0] / (long)miniBatchSize, shape[1]});
        return reshaped.permute(new int[]{0, 2, 1});
    }

    public static INDArray reshape2dTo3d(INDArray in, long miniBatchSize, LayerWorkspaceMgr workspaceMgr, ArrayType arrayType) {
        if (in.rank() != 2) {
            throw new IllegalArgumentException("Invalid input: expect NDArray with rank 2");
        }
        long[] shape = in.shape();
        if (in.ordering() != 'f') {
            in = workspaceMgr.dup(arrayType, in, 'f');
        }
        INDArray reshaped = in.reshape('f', new long[]{miniBatchSize, shape[0] / miniBatchSize, shape[1]});
        return workspaceMgr.leverageTo(arrayType, reshaped.permute(new int[]{0, 2, 1}));
    }

    public static INDArray reverseTimeSeries(INDArray in) {
        if (in == null) {
            return null;
        }
        if (in.ordering() != 'f' || in.isView() || !Shape.strideDescendingCAscendingF((INDArray)in)) {
            in = in.dup('f');
        }
        int[] idxs = new int[(int)in.size(2)];
        int j = 0;
        int i = idxs.length - 1;
        while (i >= 0) {
            idxs[j++] = i--;
        }
        INDArray inReshape = in.reshape('f', new long[]{in.size(0) * in.size(1), in.size(2)});
        INDArray outReshape = Nd4j.pullRows((INDArray)inReshape, (int)0, (int[])idxs, (char)'f');
        return outReshape.reshape('f', new long[]{in.size(0), in.size(1), in.size(2)});
    }

    public static INDArray reverseTimeSeries(INDArray in, LayerWorkspaceMgr workspaceMgr, ArrayType arrayType, RNNFormat dataFormat) {
        if (dataFormat == RNNFormat.NCW) {
            return TimeSeriesUtils.reverseTimeSeries(in, workspaceMgr, arrayType);
        }
        return TimeSeriesUtils.reverseTimeSeries(in.permute(new int[]{0, 2, 1}), workspaceMgr, arrayType).permute(new int[]{0, 2, 1});
    }

    public static INDArray reverseTimeSeries(INDArray in, LayerWorkspaceMgr workspaceMgr, ArrayType arrayType) {
        if (in == null) {
            return null;
        }
        if (in.ordering() != 'f' || in.isView() || !Shape.strideDescendingCAscendingF((INDArray)in)) {
            in = workspaceMgr.dup(arrayType, in, 'f');
        }
        if (in.size(2) > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }
        int[] idxs = new int[(int)in.size(2)];
        int j = 0;
        int i = idxs.length - 1;
        while (i >= 0) {
            idxs[j++] = i--;
        }
        INDArray inReshape = in.reshape('f', new long[]{in.size(0) * in.size(1), in.size(2)});
        INDArray outReshape = workspaceMgr.create(arrayType, in.dataType(), new long[]{inReshape.size(0), idxs.length}, 'f');
        Nd4j.pullRows((INDArray)inReshape, (INDArray)outReshape, (int)0, (int[])idxs);
        return workspaceMgr.leverageTo(arrayType, outReshape.reshape('f', new long[]{in.size(0), in.size(1), in.size(2)}));
    }

    public static INDArray reverseTimeSeriesMask(INDArray mask) {
        if (mask == null) {
            return null;
        }
        if (mask.rank() == 3) {
            return TimeSeriesUtils.reverseTimeSeries(mask);
        }
        if (mask.rank() != 2) {
            throw new IllegalArgumentException("Invalid mask rank: must be rank 2 or 3. Got rank " + mask.rank() + " with shape " + Arrays.toString(mask.shape()));
        }
        if (mask.size(1) > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }
        int[] idxs = new int[(int)mask.size(1)];
        int j = 0;
        int i = idxs.length - 1;
        while (i >= 0) {
            idxs[j++] = i--;
        }
        return Nd4j.pullRows((INDArray)mask, (int)0, (int[])idxs);
    }

    public static INDArray reverseTimeSeriesMask(INDArray mask, LayerWorkspaceMgr workspaceMgr, ArrayType arrayType) {
        if (mask == null) {
            return null;
        }
        if (mask.rank() == 3) {
            return TimeSeriesUtils.reverseTimeSeries(mask, workspaceMgr, arrayType);
        }
        if (mask.rank() != 2) {
            throw new IllegalArgumentException("Invalid mask rank: must be rank 2 or 3. Got rank " + mask.rank() + " with shape " + Arrays.toString(mask.shape()));
        }
        if (mask.size(1) > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }
        int[] idxs = new int[(int)mask.size(1)];
        int j = 0;
        int i = idxs.length - 1;
        while (i >= 0) {
            idxs[j++] = i--;
        }
        INDArray ret = workspaceMgr.createUninitialized(arrayType, mask.dataType(), new long[]{mask.size(0), idxs.length}, 'f');
        return Nd4j.pullRows((INDArray)mask, (INDArray)ret, (int)0, (int[])idxs);
    }

    public static Pair<INDArray, int[]> pullLastTimeSteps(INDArray pullFrom, INDArray mask) {
        int[] fwdPassTimeSteps;
        INDArray out;
        if (mask == null) {
            long lastTS = pullFrom.size(2) - 1L;
            out = pullFrom.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point((long)lastTS)});
            fwdPassTimeSteps = null;
        } else {
            long[] outShape = new long[]{pullFrom.size(0), pullFrom.size(1)};
            out = Nd4j.create((long[])outShape);
            INDArray lastStepArr = BooleanIndexing.lastIndex((INDArray)mask, (Condition)Conditions.epsNotEquals((Number)0.0), (int[])new int[]{1});
            fwdPassTimeSteps = lastStepArr.data().asInt();
            for (int i = 0; i < fwdPassTimeSteps.length; ++i) {
                out.putRow((long)i, pullFrom.get(new INDArrayIndex[]{NDArrayIndex.point((long)i), NDArrayIndex.all(), NDArrayIndex.point((long)fwdPassTimeSteps[i])}));
            }
        }
        return new Pair((Object)out, fwdPassTimeSteps);
    }

    public static Pair<INDArray, int[]> pullLastTimeSteps(INDArray pullFrom, INDArray mask, LayerWorkspaceMgr workspaceMgr, ArrayType arrayType) {
        int[] fwdPassTimeSteps;
        INDArray out;
        if (mask == null) {
            long lastTS = pullFrom.size(2) - 1L;
            out = pullFrom.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point((long)lastTS)});
            fwdPassTimeSteps = null;
        } else {
            long[] outShape = new long[]{pullFrom.size(0), pullFrom.size(1)};
            out = Nd4j.create((long[])outShape);
            INDArray lastStepArr = BooleanIndexing.lastIndex((INDArray)mask, (Condition)Conditions.epsNotEquals((Number)0.0), (int[])new int[]{1});
            fwdPassTimeSteps = lastStepArr.data().asInt();
            for (int i = 0; i < fwdPassTimeSteps.length; ++i) {
                int lastStepIdx = fwdPassTimeSteps[i];
                Preconditions.checkState((lastStepIdx >= 0 ? 1 : 0) != 0, (String)"Invalid last time step index: example %s in minibatch is entirely masked out (input mask is all 0s, meaning no input data is present for this example)", (int)i);
                out.putRow((long)i, pullFrom.get(new INDArrayIndex[]{NDArrayIndex.point((long)i), NDArrayIndex.all(), NDArrayIndex.point((long)lastStepIdx)}));
            }
        }
        return new Pair((Object)workspaceMgr.leverageTo(arrayType, out), (Object)fwdPassTimeSteps);
    }

    public static RNNFormat getFormatFromRnnLayer(Layer layer) {
        if (layer instanceof BaseRecurrentLayer) {
            return ((BaseRecurrentLayer)layer).getRnnDataFormat();
        }
        if (layer instanceof MaskZeroLayer) {
            return TimeSeriesUtils.getFormatFromRnnLayer(((MaskZeroLayer)layer).getUnderlying());
        }
        if (layer instanceof Bidirectional) {
            return TimeSeriesUtils.getFormatFromRnnLayer(((Bidirectional)layer).getFwd());
        }
        if (layer instanceof LastTimeStep) {
            return TimeSeriesUtils.getFormatFromRnnLayer(((LastTimeStep)layer).getUnderlying());
        }
        if (layer instanceof TimeDistributed) {
            return ((TimeDistributed)layer).getRnnDataFormat();
        }
        throw new IllegalStateException("Unable to get RNNFormat from layer of type: " + layer);
    }
}

