/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.pytorch.engine;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDUtils;
import ai.djl.ndarray.index.NDArrayIndexer;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.nn.recurrent.RNN;
import ai.djl.pytorch.engine.PtNDArray;
import ai.djl.pytorch.engine.PtNDArrayIndexer;
import ai.djl.pytorch.engine.PtNDManager;
import ai.djl.pytorch.jni.JniUtils;
import java.util.List;

/**
 * PyTorch engine implementation of {@link NDArrayEx}.
 *
 * <p>Every operation acts on the wrapped {@link PtNDArray} (the {@code parent} given to the
 * constructor), and most operations delegate directly to {@link JniUtils}. Operations the
 * PyTorch engine does not provide throw {@link UnsupportedOperationException}.
 */
public class PtNDArrayEx implements NDArrayEx {

    /** The array every operation of this helper is applied to. */
    private final PtNDArray array;

    /**
     * Creates the extension helper for {@code parent}.
     *
     * @param parent the array the operations are applied to
     */
    PtNDArrayEx(PtNDArray parent) {
        this.array = parent;
    }

    /**
     * Computes {@code n / array} element-wise.
     *
     * @param n the dividend, wrapped into an array by this array's manager
     * @return the quotient
     */
    public PtNDArray rdiv(Number n) {
        return rdiv(array.getManager().create(n));
    }

    /**
     * Computes {@code b / array} element-wise.
     *
     * @param b the dividend
     * @return the quotient
     */
    public PtNDArray rdiv(NDArray b) {
        return (PtNDArray) b.div(array);
    }

    /** Not implemented for the PyTorch engine; always throws {@link UnsupportedOperationException}. */
    public PtNDArray rdivi(Number n) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not implemented for the PyTorch engine; always throws {@link UnsupportedOperationException}. */
    public PtNDArray rdivi(NDArray b) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not implemented for the PyTorch engine; always throws {@link UnsupportedOperationException}. */
    public PtNDArray rsub(Number n) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not implemented for the PyTorch engine; always throws {@link UnsupportedOperationException}. */
    public PtNDArray rsub(NDArray b) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not implemented for the PyTorch engine; always throws {@link UnsupportedOperationException}. */
    public PtNDArray rsubi(Number n) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not implemented for the PyTorch engine; always throws {@link UnsupportedOperationException}. */
    public PtNDArray rsubi(NDArray b) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not implemented for the PyTorch engine; always throws {@link UnsupportedOperationException}. */
    public PtNDArray rmod(Number n) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not implemented for the PyTorch engine; always throws {@link UnsupportedOperationException}. */
    public PtNDArray rmod(NDArray b) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not implemented for the PyTorch engine; always throws {@link UnsupportedOperationException}. */
    public PtNDArray rmodi(Number n) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not implemented for the PyTorch engine; always throws {@link UnsupportedOperationException}. */
    public PtNDArray rmodi(NDArray b) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not implemented for the PyTorch engine; always throws {@link UnsupportedOperationException}. */
    public PtNDArray rpow(Number n) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not implemented for the PyTorch engine; always throws {@link UnsupportedOperationException}. */
    public PtNDArray rpowi(Number n) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Applies {@link JniUtils#relu} to the wrapped array. */
    public PtNDArray relu() {
        return JniUtils.relu(array);
    }

    /** Applies {@link JniUtils#sigmoid} to the wrapped array. */
    public PtNDArray sigmoid() {
        return JniUtils.sigmoid(array);
    }

    /** Applies {@link JniUtils#tanh} to the wrapped array. */
    public PtNDArray tanh() {
        return JniUtils.tanh(array);
    }

    /** Applies {@link JniUtils#softPlus} to the wrapped array. */
    public PtNDArray softPlus() {
        return JniUtils.softPlus(array);
    }

    /** Applies {@link JniUtils#softSign} to the wrapped array. */
    public PtNDArray softSign() {
        return JniUtils.softSign(array);
    }

    /**
     * Applies {@link JniUtils#leakyRelu} to the wrapped array.
     *
     * @param alpha the slope parameter forwarded to the native call
     */
    public PtNDArray leakyRelu(float alpha) {
        return JniUtils.leakyRelu(array, alpha);
    }

    /**
     * Applies {@link JniUtils#elu} to the wrapped array.
     *
     * @param alpha the alpha parameter forwarded to the native call
     */
    public PtNDArray elu(float alpha) {
        return JniUtils.elu(array, alpha);
    }

    /** Applies {@link JniUtils#selu} to the wrapped array. */
    public PtNDArray selu() {
        return JniUtils.selu(array);
    }

    /** Applies {@link JniUtils#gelu} to the wrapped array. */
    public PtNDArray gelu() {
        return JniUtils.gelu(array);
    }

    /** Max-pools the wrapped array via {@link JniUtils#maxPool}. */
    public PtNDArray maxPool(Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
        return JniUtils.maxPool(array, kernelShape, stride, padding, ceilMode);
    }

    /**
     * Adaptive-max-pools every spatial dimension down to 1 and returns the result reshaped to the
     * first two dimensions of the wrapped array.
     *
     * @throws IllegalArgumentException if the wrapped array's rank is not in [3, 5]
     */
    public PtNDArray globalMaxPool() {
        return dropSpatialDims(JniUtils.adaptiveMaxPool(array, getPoolShape(array)));
    }

    /** Average-pools the wrapped array via {@link JniUtils#avgPool}. */
    public PtNDArray avgPool(Shape kernelShape, Shape stride, Shape padding, boolean ceilMode, boolean countIncludePad) {
        return JniUtils.avgPool(array, kernelShape, stride, padding, ceilMode, countIncludePad);
    }

    /**
     * Adaptive-average-pools every spatial dimension down to 1 and returns the result reshaped to
     * the first two dimensions of the wrapped array.
     *
     * @throws IllegalArgumentException if the wrapped array's rank is not in [3, 5]
     */
    public PtNDArray globalAvgPool() {
        return dropSpatialDims(JniUtils.adaptiveAvgPool(array, getPoolShape(array)));
    }

    /**
     * Lp-pools the wrapped array via {@link JniUtils#lpPool}.
     *
     * @param padding must be all zeros (or empty); the native call takes no padding argument
     * @throws IllegalArgumentException if any padding entry is non-zero
     */
    public PtNDArray lpPool(float normType, Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
        // Shape.size() is the element count (product of the dimensions), so testing it would let
        // a padding such as (1, 0) through and reject an empty padding; inspect each entry instead.
        for (long pad : padding.getShape()) {
            if (pad != 0L) {
                throw new IllegalArgumentException("padding is not supported for PyTorch engine");
            }
        }
        return JniUtils.lpPool(array, normType, kernelShape, stride, ceilMode);
    }

    /**
     * Lp-pools over the whole spatial extent and returns the result reshaped to the first two
     * dimensions of the wrapped array.
     *
     * @param normType the p of the Lp norm
     * @throws IllegalArgumentException if the wrapped array's rank is not in [3, 5]
     */
    public PtNDArray globalLpPool(float normType) {
        // Kernel = every dimension after the first two; stride = 1 in each of those dimensions.
        return dropSpatialDims(
                JniUtils.lpPool(array, normType, array.getShape().slice(2), getPoolShape(array), false));
    }

    /** Not supported by the PyTorch engine; always throws {@link UnsupportedOperationException}. */
    public void adadeltaUpdate(NDList inputs, NDList weights, float weightDecay, float rescaleGrad, float clipGrad, float rho, float epsilon) {
        throw new UnsupportedOperationException("AdaDelta optimizer is not supported for PyTorch engine!");
    }

    /** Not implemented for the PyTorch engine; always throws {@link UnsupportedOperationException}. */
    public void adagradUpdate(NDList inputs, NDList weights, float learningRate, float weightDecay, float rescaleGrad, float clipGrad, float epsilon) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Runs one Adam step via {@link JniUtils#adamUpdate}, then zeroes the gradient of the single
     * weight in {@code weights}.
     *
     * @param inputs the first four entries are forwarded in order (presumably weight, gradient and
     *     the two moment buffers — confirm against the calling optimizer)
     * @param weights must contain exactly one array
     * @param lazyUpdate ignored by this engine
     */
    public void adamUpdate(NDList inputs, NDList weights, float learningRate, float weightDecay, float rescaleGrad, float clipGrad, float beta1, float beta2, float epsilon, boolean lazyUpdate) {
        PtNDManager manager = array.getManager();
        JniUtils.adamUpdate(
                manager.from(inputs.get(0)),
                manager.from(inputs.get(1)),
                manager.from(inputs.get(2)),
                manager.from(inputs.get(3)),
                learningRate, weightDecay, rescaleGrad, clipGrad, beta1, beta2, epsilon);
        JniUtils.zeroGrad(manager.from(weights.singletonOrThrow()));
    }

    /** Not implemented for the PyTorch engine; always throws {@link UnsupportedOperationException}. */
    public void nagUpdate(NDList inputs, NDList weights, float learningRate, float weightDecay, float rescaleGrad, float clipGrad, float momentum) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not implemented for the PyTorch engine; always throws {@link UnsupportedOperationException}. */
    public void rmspropUpdate(NDList inputs, NDList weights, float learningRate, float weightDecay, float rescaleGrad, float clipGrad, float rho, float momentum, float epsilon, boolean centered) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Runs one SGD step via {@link JniUtils#sgdUpdate}, then zeroes the gradient of the single
     * weight in {@code weights}.
     *
     * @param inputs entries 0 and 1 are always forwarded; entry 2 is forwarded only when
     *     {@code momentum} is non-zero (otherwise {@code null} is passed)
     * @param weights must contain exactly one array
     * @param lazyUpdate ignored by this engine
     */
    public void sgdUpdate(NDList inputs, NDList weights, float learningRate, float weightDecay, float rescaleGrad, float clipGrad, float momentum, boolean lazyUpdate) {
        PtNDManager manager = array.getManager();
        PtNDArray first = manager.from(inputs.get(0));
        PtNDArray second = manager.from(inputs.get(1));
        // The third input (momentum state) is only read when momentum is enabled.
        PtNDArray momentumState = momentum == 0.0f ? null : manager.from(inputs.get(2));
        JniUtils.sgdUpdate(first, second, momentumState, learningRate, weightDecay, rescaleGrad, clipGrad, momentum);
        JniUtils.zeroGrad(manager.from(weights.singletonOrThrow()));
    }

    /** Runs {@link JniUtils#convolution} and wraps the single result in an {@link NDList}. */
    public NDList convolution(NDArray input, NDArray weight, NDArray bias, Shape stride, Shape padding, Shape dilation, int groups) {
        PtNDManager manager = array.getManager();
        return new NDList(JniUtils.convolution(
                manager.from(input), manager.from(weight), manager.from(bias), stride, padding, dilation, groups));
    }

    /** Not implemented for the PyTorch engine; always throws {@link UnsupportedOperationException}. */
    public NDList deconvolution(NDArray input, NDArray weight, NDArray bias, Shape stride, Shape padding, Shape outPadding, Shape dilation, int groups) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Runs {@link JniUtils#linear} and wraps the single result in an {@link NDList}. */
    public NDList linear(NDArray input, NDArray weight, NDArray bias) {
        PtNDManager manager = array.getManager();
        return new NDList(JniUtils.linear(manager.from(input), manager.from(weight), manager.from(bias)));
    }

    /**
     * Runs {@link JniUtils#embedding} and wraps the single result in an {@link NDList}.
     *
     * @param sparseFormat {@link SparseFormat#DENSE} or {@link SparseFormat#COO}; the native call
     *     receives {@code true} for COO
     * @throws IllegalArgumentException for any other format
     */
    public NDList embedding(NDArray input, NDArray weight, SparseFormat sparseFormat) {
        if (sparseFormat != SparseFormat.DENSE && sparseFormat != SparseFormat.COO) {
            throw new IllegalArgumentException(
                    "PyTorch only supports DENSE and COO sparse formats, but got " + sparseFormat);
        }
        PtNDManager manager = array.getManager();
        return new NDList(JniUtils.embedding(
                manager.from(input), manager.from(weight), sparseFormat == SparseFormat.COO));
    }

    /** Not implemented for the PyTorch engine; always throws {@link UnsupportedOperationException}. */
    public NDList prelu(NDArray input, NDArray alpha) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Runs {@link JniUtils#dropout} and wraps the single result in an {@link NDList}. */
    public NDList dropout(NDArray input, float rate, boolean training) {
        PtNDManager manager = array.getManager();
        return new NDList(JniUtils.dropout(manager.from(input), rate, training));
    }

    /** Runs {@link JniUtils#layerNorm} and wraps the single result in an {@link NDList}. */
    public NDList layerNorm(NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps) {
        PtNDManager manager = array.getManager();
        return new NDList(JniUtils.layerNorm(
                manager.from(input), normalizedShape, manager.from(gamma), manager.from(beta), eps));
    }

    /**
     * Batch-normalizes {@code input} via {@link JniUtils#batchNorm}.
     *
     * <p>For {@code axis == -1} the input is passed as is. Otherwise {@code axis} is swapped into
     * position 1 before the native call and swapped back afterwards. The intermediate arrays are
     * released through a sub-manager. {@code input} is always returned to its original manager,
     * even when an exception occurs.
     *
     * @return a list holding the single normalized array
     */
    public NDList batchNorm(NDArray input, NDArray runningMean, NDArray runningVar, NDArray gamma, NDArray beta, int axis, float momentum, float eps, boolean training) {
        PtNDManager manager = array.getManager();
        // NOTE(review): the native call receives 1 - momentum, presumably because PyTorch weights the
        // new batch statistic while the caller's momentum weights the running value — confirm.
        float nativeMomentum = 1.0f - momentum;
        if (axis == -1) {
            return new NDList(JniUtils.batchNorm(
                    manager.from(input), manager.from(runningMean), manager.from(runningVar),
                    manager.from(gamma), manager.from(beta), training, nativeMomentum, eps));
        }
        try (NDManager subManager = input.getManager().newSubManager()) {
            // Temporarily re-parent input so the intermediates created from it die with subManager.
            input.attach(subManager);
            try {
                NDArray result = input.swapAxes(1, axis);
                result = JniUtils.batchNorm(
                        manager.from(result), manager.from(runningMean), manager.from(runningVar),
                        manager.from(gamma), manager.from(beta), training, nativeMomentum, eps);
                result = result.swapAxes(1, axis);
                result.attach(subManager.getParentManager());
                return new NDList(result);
            } finally {
                // Must run on failure too, otherwise closing subManager would close the caller's input.
                input.attach(subManager.getParentManager());
            }
        }
    }

    /** Runs {@link JniUtils#rnn} on {@code input} and {@code state}. */
    public NDList rnn(NDArray input, NDArray state, NDList params, boolean hasBiases, int numLayers, RNN.Activation activation, double dropRate, boolean training, boolean bidirectional, boolean batchFirst) {
        PtNDManager manager = array.getManager();
        return JniUtils.rnn(manager.from(input), manager.from(state), params, hasBiases, numLayers, activation, dropRate, training, bidirectional, batchFirst);
    }

    /** Runs {@link JniUtils#gru} on {@code input} and {@code state}. */
    public NDList gru(NDArray input, NDArray state, NDList params, boolean hasBiases, int numLayers, double dropRate, boolean training, boolean bidirectional, boolean batchFirst) {
        PtNDManager manager = array.getManager();
        return JniUtils.gru(manager.from(input), manager.from(state), params, hasBiases, numLayers, dropRate, training, bidirectional, batchFirst);
    }

    /** Runs {@link JniUtils#lstm} on {@code input}; {@code states} is forwarded unchanged. */
    public NDList lstm(NDArray input, NDList states, NDList params, boolean hasBiases, int numLayers, double dropRate, boolean training, boolean bidirectional, boolean batchFirst) {
        PtNDManager manager = array.getManager();
        return JniUtils.lstm(manager.from(input), states, params, hasBiases, numLayers, dropRate, training, bidirectional, batchFirst);
    }

    /**
     * Resizes the wrapped image array to {@code height x width}.
     *
     * <p>Non-FLOAT32 input is converted to FLOAT32 first. A rank-3 input gets a leading batch
     * dimension, which is removed again afterwards. The array is transposed with
     * (0, 3, 1, 2) before {@link JniUtils#interpolate} and back with (0, 2, 3, 1) afterwards.
     * This presumably means channel-last input ((H, W, C) or (N, H, W, C)) — confirm.
     *
     * @param interpolation DJL interpolation flag in [0, 3]
     * @return the resized FLOAT32 array, owned by the wrapped array's manager
     * @throws IllegalArgumentException if the wrapped array is empty
     * @throws UnsupportedOperationException if {@code interpolation} is not supported
     */
    public PtNDArray resize(int width, int height, int interpolation) {
        // Validate before re-parenting: an exception thrown while the array is attached to the
        // sub-manager would otherwise close the caller's array when the sub-manager closes.
        if (array.isEmpty()) {
            throw new IllegalArgumentException("attempt to resize of an empty NDArray");
        }
        int mode = getInterpolationMode(interpolation);
        PtNDManager manager = array.getManager();
        try (NDManager subManager = manager.newSubManager()) {
            array.attach(subManager);
            try {
                PtNDArray result = array;
                if (result.getDataType() != DataType.FLOAT32) {
                    result = result.toType(DataType.FLOAT32, true);
                }
                int dim = result.getShape().dimension();
                if (dim == 3) {
                    result = result.expandDims(0);
                }
                result = result.transpose(new int[]{0, 3, 1, 2});
                result = JniUtils.interpolate(array.getManager().from(result), new long[]{height, width}, mode, false)
                        .transpose(0, 2, 3, 1);
                if (dim == 3) {
                    result = result.squeeze(0);
                }
                result.attach(subManager.getParentManager());
                return result;
            } finally {
                // Always hand the input back to its original manager, including on failure.
                array.attach(subManager.getParentManager());
            }
        }
    }

    /** Not implemented for the PyTorch engine; always throws {@link UnsupportedOperationException}. */
    public NDArray randomFlipLeftRight() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not implemented for the PyTorch engine; always throws {@link UnsupportedOperationException}. */
    public NDArray randomFlipTopBottom() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not implemented for the PyTorch engine; always throws {@link UnsupportedOperationException}. */
    public NDArray randomBrightness(float brightness) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not implemented for the PyTorch engine; always throws {@link UnsupportedOperationException}. */
    public NDArray randomHue(float hue) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not implemented for the PyTorch engine; always throws {@link UnsupportedOperationException}. */
    public NDArray randomColorJitter(float brightness, float contrast, float saturation, float hue) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Returns a {@link PtNDArrayIndexer} for {@code manager}.
     *
     * @param manager must be a {@link PtNDManager}; otherwise a {@link ClassCastException} is thrown
     */
    public NDArrayIndexer getIndexer(NDManager manager) {
        return new PtNDArrayIndexer((PtNDManager) manager);
    }

    /**
     * Delegates to {@code JniUtils.where(condition, array, other)}, presumably an element-wise
     * select between this array and {@code other} (torch.where semantics — confirm).
     *
     * @throws UnsupportedOperationException if {@code condition}'s shape differs from this array's
     */
    public PtNDArray where(NDArray condition, NDArray other) {
        if (!condition.getShape().equals(array.getShape())) {
            throw new UnsupportedOperationException("condition and self shape mismatch, broadcast is not supported");
        }
        PtNDManager manager = array.getManager();
        return JniUtils.where(manager.from(condition), array, manager.from(other));
    }

    /** Passes this array followed by every array of {@code arrays} to {@link JniUtils#stack}. */
    public PtNDArray stack(NDList arrays, int axis) {
        return JniUtils.stack(prependArray(arrays), axis);
    }

    /**
     * Validates {@code list} with {@link NDUtils#checkConcatInput}, then passes this array
     * followed by every array of {@code list} to {@link JniUtils#cat}.
     */
    public PtNDArray concat(NDList list, int axis) {
        NDUtils.checkConcatInput(list);
        return JniUtils.cat(prependArray(list), axis);
    }

    /** Not implemented for the PyTorch engine; always throws {@link UnsupportedOperationException}. */
    public NDList multiBoxTarget(NDList inputs, float iouThreshold, float ignoreLabel, float negativeMiningRatio, float negativeMiningThreshold, int minNegativeSamples) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not implemented for the PyTorch engine; always throws {@link UnsupportedOperationException}. */
    public NDList multiBoxPrior(List<Float> sizes, List<Float> ratios, List<Float> steps, List<Float> offsets, boolean clip) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not implemented for the PyTorch engine; always throws {@link UnsupportedOperationException}. */
    public NDList multiBoxDetection(NDList inputs, boolean clip, float threshold, int backgroundId, float nmsThreshold, boolean forceSuppress, int nmsTopK) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Returns the array this helper operates on. */
    public PtNDArray getArray() {
        return array;
    }

    /**
     * Reshapes a globally pooled result to the first two dimensions of the wrapped array and
     * closes the intermediate pooled array.
     */
    private PtNDArray dropSpatialDims(PtNDArray pooled) {
        try (PtNDArray temp = pooled) {
            return (PtNDArray) temp.reshape(array.getShape().slice(0, 2));
        }
    }

    /** Builds {@code [array, others...]}, converting each element through this array's manager. */
    private PtNDArray[] prependArray(NDList others) {
        PtNDManager manager = array.getManager();
        PtNDArray[] all = new PtNDArray[others.size() + 1];
        all[0] = array;
        for (int i = 0; i < others.size(); ++i) {
            all[i + 1] = manager.from(others.get(i));
        }
        return all;
    }

    /**
     * Returns an all-ones shape whose rank is {@code input}'s rank minus 2, as used for
     * global-pooling output sizes and strides.
     *
     * @throws IllegalArgumentException if {@code input}'s rank is not in [3, 5]
     */
    private Shape getPoolShape(NDArray input) {
        int spatialDims = input.getShape().dimension() - 2;
        if (spatialDims < 1 || spatialDims > 3) {
            throw new IllegalArgumentException("the input dimension should be in [3, 5]");
        }
        long[] ones = new long[spatialDims];
        for (int i = 0; i < spatialDims; ++i) {
            ones[i] = 1L;
        }
        return new Shape(ones);
    }

    /**
     * Maps a DJL interpolation flag (0-3) to the integer mode passed to
     * {@link JniUtils#interpolate}; the meaning of the returned codes is defined on the native side.
     *
     * @throws UnsupportedOperationException for any other flag
     */
    private int getInterpolationMode(int interpolation) {
        switch (interpolation) {
            case 0:
                return 0;
            case 1:
                return 2;
            case 2:
                return 5;
            case 3:
                return 3;
            default:
                throw new UnsupportedOperationException("The kind of interpolation is not supported.");
        }
    }
}

