package ai.djl.ndarray.index;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.index.dim.NDIndexBooleans;
import ai.djl.ndarray.index.dim.NDIndexElement;
import ai.djl.ndarray.index.full.NDIndexFullPick;
import ai.djl.ndarray.index.full.NDIndexFullSlice;
import ai.djl.ndarray.index.full.NDIndexFullTake;
import java.util.List;
import java.util.Optional;

/* loaded from: input_file:ai/djl/ndarray/index/NDArrayIndexer.class */
/**
 * Base class that resolves a general {@link NDIndex} into one of the specialized index forms
 * (boolean mask, full take, full pick, full slice) and dispatches to the matching overload.
 *
 * <p>Engine-specific subclasses implement the abstract {@code get}/{@code set} overloads for
 * the specialized forms; the resolution logic lives here.
 */
public abstract class NDArrayIndexer {

    /**
     * Gets the values of {@code nDArray} selected by a full pick index.
     *
     * @param nDArray the array to index into
     * @param nDIndexFullPick the pick index to apply
     * @return the selected values
     */
    public abstract NDArray get(NDArray nDArray, NDIndexFullPick nDIndexFullPick);

    /**
     * Gets the values of {@code nDArray} selected by a full take index.
     *
     * @param nDArray the array to index into
     * @param nDIndexFullTake the take index to apply
     * @return the selected values
     */
    public abstract NDArray get(NDArray nDArray, NDIndexFullTake nDIndexFullTake);

    /**
     * Gets the values of {@code nDArray} selected by a full slice index.
     *
     * @param nDArray the array to index into
     * @param nDIndexFullSlice the slice index to apply
     * @return the selected values
     */
    public abstract NDArray get(NDArray nDArray, NDIndexFullSlice nDIndexFullSlice);

    /**
     * Gets the part of {@code nDArray} selected by a general {@link NDIndex}.
     *
     * <p>Resolution order:
     *
     * <ol>
     *   <li>a rank-0 index on a scalar array returns {@code nDArray.duplicate()};
     *   <li>an index whose first element is a boolean NDArray is applied with {@code
     *       booleanMask} (only a single boolean element is supported);
     *   <li>otherwise the index is converted, in order, to a full take, a full pick, or a full
     *       slice, and the first successful conversion is dispatched to its overload.
     * </ol>
     *
     * @param nDArray the array to index into
     * @param nDIndex the index to apply
     * @return the indexed result
     * @throws IllegalArgumentException if the index starts with a boolean NDArray but has more
     *     than one element
     * @throws UnsupportedOperationException if the index cannot be converted to any supported
     *     form
     */
    public NDArray get(NDArray nDArray, NDIndex nDIndex) {
        // Empty index on a scalar: hand back a duplicate rather than the array itself.
        if (nDIndex.getRank() == 0 && nDArray.getShape().isScalar()) {
            return nDArray.duplicate();
        }

        List<NDIndexElement> indices = nDIndex.getIndices();
        if (startsWithBooleans(indices)) {
            if (indices.size() != 1) {
                throw new IllegalArgumentException(
                        "get() currently doesn't support more than one boolean NDArray");
            }
            return nDArray.booleanMask(((NDIndexBooleans) indices.get(0)).getIndex());
        }

        // Each conversion is attempted only if the previous one did not apply.
        Optional<NDIndexFullTake> fullTake =
                NDIndexFullTake.fromIndex(nDIndex, nDArray.getShape());
        if (fullTake.isPresent()) {
            return get(nDArray, fullTake.get());
        }
        Optional<NDIndexFullPick> fullPick =
                NDIndexFullPick.fromIndex(nDIndex, nDArray.getShape());
        if (fullPick.isPresent()) {
            return get(nDArray, fullPick.get());
        }
        Optional<NDIndexFullSlice> fullSlice =
                NDIndexFullSlice.fromIndex(nDIndex, nDArray.getShape());
        if (fullSlice.isPresent()) {
            return get(nDArray, fullSlice.get());
        }
        throw new UnsupportedOperationException(
                "get() currently supports all, fixed, and slices indices");
    }

    /**
     * Sets the part of {@code nDArray} selected by a general {@link NDIndex} to {@code obj}.
     *
     * <p>If the index converts to a full slice, {@code obj} is dispatched to the slice overload
     * taking a {@link Number} or an {@link NDArray}. Otherwise the index must consist of exactly
     * one boolean NDArray, and the assignment goes through the boolean-mask overload; a {@link
     * Number} value is first wrapped into an NDArray created by {@code nDArray}'s manager.
     *
     * @param nDArray the array to modify
     * @param nDIndex the index selecting the region to assign
     * @param obj the value to assign; must be a {@link Number} or an {@link NDArray}
     * @throws IllegalArgumentException if {@code obj} is neither a Number nor an NDArray, or if
     *     the boolean index has more than one element
     * @throws UnsupportedOperationException if the index is neither a full slice nor a boolean
     *     index
     */
    public void set(NDArray nDArray, NDIndex nDIndex, Object obj) {
        Optional<NDIndexFullSlice> fullSlice =
                NDIndexFullSlice.fromIndex(nDIndex, nDArray.getShape());
        if (fullSlice.isPresent()) {
            requireNumberOrNDArray(obj);
            if (obj instanceof Number) {
                set(nDArray, fullSlice.get(), (Number) obj);
            } else {
                set(nDArray, fullSlice.get(), (NDArray) obj);
            }
            return;
        }

        List<NDIndexElement> indices = nDIndex.getIndices();
        if (!startsWithBooleans(indices)) {
            throw new UnsupportedOperationException(
                    "set() currently supports all, fixed, and slices indices");
        }
        if (indices.size() != 1) {
            throw new IllegalArgumentException(
                    "set() currently doesn't support more than one boolean NDArray");
        }
        requireNumberOrNDArray(obj);
        NDIndexBooleans mask = (NDIndexBooleans) indices.get(0);
        // The boolean-mask overload takes an NDArray, so a Number is wrapped first.
        // NOTE(review): the wrapped scalar is not closed here; its lifetime is presumably tied to
        // nDArray's manager — confirm this is intended.
        NDArray value =
                obj instanceof Number
                        ? nDArray.getManager().create((Number) obj)
                        : (NDArray) obj;
        set(nDArray, mask, value);
    }

    /**
     * Sets the region of {@code nDArray} selected by a full slice to the values of {@code
     * nDArray2}.
     *
     * @param nDArray the array to modify
     * @param nDIndexFullSlice the slice selecting the region to assign
     * @param nDArray2 the values to assign
     */
    public abstract void set(
            NDArray nDArray, NDIndexFullSlice nDIndexFullSlice, NDArray nDArray2);

    /**
     * Assigns {@code nDArray2} wherever the boolean mask is true.
     *
     * <p>Computes {@code NDArrays.where(mask, nDArray2, nDArray)} and writes the result back into
     * {@code nDArray} via {@link NDArray#intern}.
     *
     * @param nDArray the array to modify
     * @param nDIndexBooleans the boolean index holding the mask
     * @param nDArray2 the values to take where the mask is true
     */
    public void set(NDArray nDArray, NDIndexBooleans nDIndexBooleans, NDArray nDArray2) {
        nDArray.intern(NDArrays.where(nDIndexBooleans.getIndex(), nDArray2, nDArray));
    }

    /**
     * Sets the region of {@code nDArray} selected by a full slice to a single number.
     *
     * @param nDArray the array to modify
     * @param nDIndexFullSlice the slice selecting the region to assign
     * @param number the value to assign
     */
    public abstract void set(NDArray nDArray, NDIndexFullSlice nDIndexFullSlice, Number number);

    /**
     * Sets a single element of {@code nDArray}, selected by {@code nDIndex}, to {@code number}.
     *
     * <p>The index must convert to a full slice whose shape size is 1; the assignment itself is
     * delegated to {@link #set(NDArray, NDIndex, Object)}.
     *
     * @param nDArray the array to modify
     * @param nDIndex the index selecting the element
     * @param number the value to assign
     * @throws UnsupportedOperationException if the index does not convert to a full slice
     * @throws IllegalArgumentException if the slice does not select a scalar
     */
    public void setScalar(NDArray nDArray, NDIndex nDIndex, Number number) {
        Optional<NDIndexFullSlice> fullSlice =
                NDIndexFullSlice.fromIndex(nDIndex, nDArray.getShape());
        if (!fullSlice.isPresent()) {
            throw new UnsupportedOperationException(
                    "set() currently supports all, fixed, and slices indices");
        }
        if (fullSlice.get().getShape().size() != 1) {
            throw new IllegalArgumentException("The provided index does not set a scalar");
        }
        // Goes through the general set() (which re-resolves the index) so subclass overrides of
        // that method are honoured.
        set(nDArray, nDIndex, number);
    }

    /** Returns whether {@code indices} is non-empty and its first element is a boolean index. */
    private static boolean startsWithBooleans(List<NDIndexElement> indices) {
        return !indices.isEmpty() && indices.get(0) instanceof NDIndexBooleans;
    }

    /** Throws if {@code value} is neither a {@link Number} nor an {@link NDArray}. */
    private static void requireNumberOrNDArray(Object value) {
        if (!(value instanceof Number) && !(value instanceof NDArray)) {
            throw new IllegalArgumentException(
                    "The type of value to assign cannot be other than NDArray and Number.");
        }
    }
}
