package ai.djl.engine.rust;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDScope;
import ai.djl.ndarray.index.NDArrayIndexer;
import ai.djl.ndarray.index.full.NDIndexFullPick;
import ai.djl.ndarray.index.full.NDIndexFullSlice;
import ai.djl.ndarray.index.full.NDIndexFullTake;
import java.util.Arrays;

/* loaded from: input_file:ai/djl/engine/rust/RsNDArrayIndexer.class */
/**
 * {@link NDArrayIndexer} implementation for the Rust engine.
 *
 * <p>Pick, take and step-1 slice reads are delegated to native {@code RustLibrary} calls. Slice
 * assignment ({@code set}) is not implemented and throws {@link UnsupportedOperationException}.
 */
public class RsNDArrayIndexer extends NDArrayIndexer {
    private final RsNDManager manager;

    /**
     * Creates an indexer bound to the given manager.
     *
     * @param rsNDManager the manager used to convert inputs and to own the arrays this indexer
     *     creates
     */
    public RsNDArrayIndexer(RsNDManager rsNDManager) {
        this.manager = rsNDManager;
    }

    /**
     * Picks elements of {@code nDArray} using the indices and axis of {@code nDIndexFullPick}.
     *
     * <p>Delegates to the native {@code RustLibrary.pick}. Inputs are converted with this
     * manager's {@code from} inside a temporary {@link NDScope}. The result is unregistered from
     * that scope so it outlives it (presumably any converted copies are released when the scope
     * closes — confirm against {@code NDScope} semantics).
     *
     * @param nDArray the array to pick from
     * @param nDIndexFullPick the pick indices and axis
     * @return a new array owned by this indexer's manager
     */
    @SuppressWarnings("try")
    public NDArray get(NDArray nDArray, NDIndexFullPick nDIndexFullPick) {
        try (NDScope ignored = new NDScope()) {
            long source = (Long) manager.mo165from(nDArray).getHandle();
            long indices = (Long) manager.mo165from(nDIndexFullPick.getIndices()).getHandle();
            RsNDArray result =
                    new RsNDArray(
                            manager,
                            RustLibrary.pick(source, indices, nDIndexFullPick.getAxis()));
            // Keep the result alive after the temporary scope closes.
            NDScope.unregister(result);
            return result;
        }
    }

    /**
     * Takes elements of {@code nDArray} at the indices held by {@code nDIndexFullTake}.
     *
     * <p>Delegates to the native {@code RustLibrary.take}. Scope handling is the same as in the
     * pick overload.
     *
     * @param nDArray the array to take from
     * @param nDIndexFullTake the take indices
     * @return a new array owned by this indexer's manager
     */
    @SuppressWarnings("try")
    public NDArray get(NDArray nDArray, NDIndexFullTake nDIndexFullTake) {
        try (NDScope ignored = new NDScope()) {
            long source = (Long) manager.mo165from(nDArray).getHandle();
            long indices = (Long) manager.mo165from(nDIndexFullTake.getIndices()).getHandle();
            RsNDArray result = new RsNDArray(manager, RustLibrary.take(source, indices));
            // Keep the result alive after the temporary scope closes.
            NDScope.unregister(result);
            return result;
        }
    }

    /**
     * Slices {@code nDArray} with the per-axis min/max bounds of {@code nDIndexFullSlice}.
     *
     * <p>The array is sliced natively and then reshaped to the squeezed shape of the slice.
     *
     * @param nDArray the array to slice
     * @param nDIndexFullSlice the slice bounds, steps and squeezed result shape
     * @return a new array with {@code nDIndexFullSlice.getSqueezedShape()}, owned by this
     *     indexer's manager
     * @throws UnsupportedOperationException if any step is not 1
     */
    @SuppressWarnings("try")
    public NDArray get(NDArray nDArray, NDIndexFullSlice nDIndexFullSlice) {
        long[] min = nDIndexFullSlice.getMin();
        long[] max = nDIndexFullSlice.getMax();
        long[] step = nDIndexFullSlice.getStep();
        // Only read below, so no defensive copy is needed.
        long[] shape = nDArray.getShape().getShape();
        if (Arrays.stream(step).anyMatch(s -> s != 1)) {
            throw new UnsupportedOperationException("only step 1 is supported");
        }
        for (int i = 0; i < min.length; i++) {
            // Presumably an empty selection along axis i (bounds look start-inclusive /
            // end-exclusive — confirm). Skip the native call and return a freshly created array
            // of the squeezed shape.
            if (min[i] >= max[i] || min[i] >= shape[i]) {
                return manager.create(
                        nDIndexFullSlice.getSqueezedShape(),
                        nDArray.getDataType(),
                        nDArray.getDevice());
            }
        }
        try (NDScope ignored = new NDScope()) {
            long source = (Long) manager.mo165from(nDArray).getHandle();
            long sliced = RustLibrary.fullSlice(source, min, max, step);
            long reshaped;
            try {
                reshaped =
                        RustLibrary.reshape(
                                sliced, nDIndexFullSlice.getSqueezedShape().getShape());
            } finally {
                // The intermediate slice tensor is not wrapped in an NDArray. Free it even if
                // reshape fails, otherwise the native handle leaks.
                RustLibrary.deleteTensor(sliced);
            }
            RsNDArray result = new RsNDArray(manager, reshaped, nDArray.getDataType());
            // Keep the result alive after the temporary scope closes.
            NDScope.unregister(result);
            return result;
        }
    }

    /**
     * Slice assignment is not supported by this engine.
     *
     * @param nDArray the array that would be written to
     * @param nDIndexFullSlice the target slice
     * @param nDArray2 the values that would be written
     * @throws UnsupportedOperationException always
     */
    public void set(NDArray nDArray, NDIndexFullSlice nDIndexFullSlice, NDArray nDArray2) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Wraps {@code number} in an array created by {@code nDArray}'s manager and delegates to the
     * {@link NDArray} overload.
     *
     * <p>In this class that overload always throws {@link UnsupportedOperationException}.
     *
     * @param nDArray the array that would be written to
     * @param nDIndexFullSlice the target slice
     * @param number the scalar value that would be written
     */
    public void set(NDArray nDArray, NDIndexFullSlice nDIndexFullSlice, Number number) {
        // NOTE(review): the temporary scalar array is not closed here. It stays attached to
        // nDArray's manager.
        set(nDArray, nDIndexFullSlice, nDArray.getManager().create(number));
    }
}
