/*
 * Decompiled with CFR 0.152.
 */
package org.tensorflow.ndarray.impl.sparse;

import java.nio.ReadOnlyBufferException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.stream.LongStream;
import org.tensorflow.ndarray.IllegalRankException;
import org.tensorflow.ndarray.LongNdArray;
import org.tensorflow.ndarray.NdArray;
import org.tensorflow.ndarray.NdArraySequence;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.SparseNdArray;
import org.tensorflow.ndarray.impl.AbstractNdArray;
import org.tensorflow.ndarray.impl.dense.AbstractDenseNdArray;
import org.tensorflow.ndarray.impl.dimension.Dimension;
import org.tensorflow.ndarray.impl.dimension.DimensionalSpace;
import org.tensorflow.ndarray.impl.dimension.RelativeDimensionalSpace;
import org.tensorflow.ndarray.impl.sequence.SingleElementSequence;
import org.tensorflow.ndarray.impl.sequence.SlicingElementSequence;
import org.tensorflow.ndarray.impl.sparse.Validator;
import org.tensorflow.ndarray.index.Index;

/**
 * Base implementation of a read-only sparse {@link NdArray} stored in coordinate-list form.
 *
 * <p>Row {@code i} of {@link #getIndices() indices} holds the full coordinates of one stored
 * element, whose value is element {@code i} of {@link #getValues() values}. Any coordinate that is
 * not listed in {@code indices} reads as {@link #getDefaultValue() the default value}.
 *
 * <p>Element lookup ({@link #getObject(long...)}) performs a binary search over the rows of
 * {@code indices}, so those rows are expected to be in ascending lexicographic order;
 * {@link #sortIndicesAndValues()} establishes that order.
 *
 * <p>Element-wise mutation through {@link #setObject} and {@link #set} is rejected with a
 * {@link ReadOnlyBufferException}.
 *
 * @param <T> the element type
 * @param <U> the NdArray type used to hold the stored values
 */
public abstract class AbstractSparseNdArray<T, U extends NdArray<T>>
extends AbstractNdArray<T, U>
implements SparseNdArray<T, U> {
    /** Coordinates of the stored elements: one row per element, one column per dimension. */
    private LongNdArray indices;
    /** Stored element values, aligned with the rows of {@link #indices}. */
    private U values;
    /** Value reported for every coordinate that is not present in {@link #indices}. */
    private T defaultValue;
    /** Lazily built by {@link #getDefaultArray()}; cleared whenever the default value changes. */
    private U defaultArray;

    /**
     * Creates a sparse array from explicit coordinates and values.
     *
     * @param indices a matrix with one row per stored element and one column per dimension
     * @param values the stored values; its first dimension must equal the row count of {@code indices}
     * @param defaultValue the value reported for coordinates absent from {@code indices}
     * @param dimensions the dimensional space of this array
     * @throws IllegalArgumentException if the row count of {@code indices} differs from the first
     *     dimension of {@code values}, or if its column count differs from the rank of this array
     */
    protected AbstractSparseNdArray(LongNdArray indices, U values, T defaultValue, DimensionalSpace dimensions) {
        super(dimensions);
        this.indices = indices;
        this.values = values;
        this.setDefaultValue(defaultValue);
        long numRows = this.indices.shape().get(0);
        long numValues = this.values.shape().get(0);
        if (numRows != numValues) {
            throw new IllegalArgumentException(String.format(
                "The number of rows in indices (%d) does not match the number of elements in values (%d).",
                numRows, numValues));
        }
        long numColumns = this.indices.shape().get(1);
        int rank = this.shape().numDimensions();
        if (numColumns != (long) rank) {
            // Report the rank itself; the previous message printed the size of the first dimension.
            throw new IllegalArgumentException(String.format(
                "The number of columns in indices (%d) does not match the number of dimensions in shape (%d).",
                numColumns, rank));
        }
    }

    /**
     * Creates a sparse array with only a default value. Indices and values are left unset (null)
     * and can be supplied later through {@link #setIndices} and {@link #setValues}.
     *
     * @param defaultValue the value reported for coordinates with no stored element
     * @param dimensions the dimensional space of this array
     */
    protected AbstractSparseNdArray(T defaultValue, DimensionalSpace dimensions) {
        super(dimensions);
        this.setDefaultValue(defaultValue);
    }

    /**
     * Returns a sequence over the elements found in dimension {@code dimensionIdx}.
     *
     * <p>For a rank-0 array and a negative {@code dimensionIdx}, the sequence contains only this array.
     *
     * @throws IllegalArgumentException if {@code dimensionIdx} is not smaller than the rank
     */
    @Override
    public NdArraySequence<U> elements(int dimensionIdx) {
        if (dimensionIdx >= this.shape().numDimensions()) {
            throw new IllegalArgumentException("Cannot iterate elements in dimension '" + dimensionIdx + "' of array with shape " + this.shape());
        }
        if (this.rank() == 0 && dimensionIdx < 0) {
            return new SingleElementSequence<>(this);
        }
        DimensionalSpace elemDims = this.dimensions().from(dimensionIdx + 1);
        return new SlicingElementSequence<>(this, dimensionIdx, elemDims);
    }

    /**
     * Converts a flat position into one coordinate per dimension, dividing by each dimension's
     * element size in turn and carrying the remainder to the next dimension.
     */
    protected long[] toCoordinates(DimensionalSpace dimensions, long position) {
        int rank = dimensions.numDimensions();
        long[] result = new long[rank];
        long remainder = position;
        for (int dim = 0; dim < rank; ++dim) {
            long elementSize = dimensions.get(dim).elementSize();
            result[dim] = remainder / elementSize;
            remainder %= elementSize;
        }
        return result;
    }

    /**
     * Copies the elements of {@code l} into a new {@code long[]}, reading element {@code i} with
     * {@code l.getLong(i)} (so {@code l} is presumably a vector, e.g. one row of the indices).
     */
    protected long[] getIndicesCoordinates(LongNdArray l) {
        long[] results = new long[(int) l.size()];
        for (int i = 0; (long) i < l.size(); ++i) {
            results[i] = l.getLong(i);
        }
        return results;
    }

    /** Returns a dense array holding the same elements as this sparse array. */
    public abstract U toDense();

    /**
     * Not supported by sparse arrays.
     *
     * @throws UnsupportedOperationException always
     */
    public U withShape(Shape shape) {
        throw new UnsupportedOperationException("Sparse NdArrays cannot be viewed with a different shape");
    }

    /**
     * Returns a view of the region of this array selected by {@code indices}.
     *
     * @throws IllegalArgumentException if {@code indices} is null
     */
    @Override
    public NdArray<T> slice(Index ... indices) {
        if (indices == null) {
            throw new IllegalArgumentException("Slicing requires at least one index");
        }
        RelativeDimensionalSpace sliceDimensions = this.dimensions().mapTo(indices);
        return this.slice(sliceDimensions.position(), sliceDimensions);
    }

    /** Returns the sub-array located at the given, possibly partial, coordinates. */
    @Override
    public NdArray<T> get(long ... coordinates) {
        return this.slice(this.positionOf(coordinates, false), this.dimensions().from(coordinates.length));
    }

    /**
     * Returns the scalar at {@code coordinates}: the stored value when these coordinates appear in
     * the indices, otherwise the default value.
     *
     * @throws IllegalRankException if the number of coordinates differs from the rank
     */
    @Override
    public T getObject(long ... coordinates) {
        if (coordinates.length != this.shape().numDimensions()) {
            throw new IllegalRankException(String.format("Length of coordinates (%s)%s does not match the rank %d", coordinates.length, Arrays.toString(coordinates), this.shape().numDimensions()));
        }
        long index = this.locateIndex(coordinates);
        return index >= 0L ? this.getValues().getObject(index) : this.defaultValue;
    }

    /** Sparse arrays are read-only element-wise. @throws ReadOnlyBufferException always */
    @Override
    public NdArray<T> setObject(T value, long ... coords) {
        throw new ReadOnlyBufferException();
    }

    /** Sparse arrays are read-only element-wise. @throws ReadOnlyBufferException always */
    @Override
    public NdArray<T> set(NdArray<T> src, long ... coordinates) {
        throw new ReadOnlyBufferException();
    }

    /** Allocates an array suitable for holding stored values of the given shape. */
    public abstract U createValues(Shape shape);

    /**
     * Copies this array into {@code dst}.
     *
     * <p>A sparse destination receives fresh copies of this array's indices and values (its default
     * value is left untouched); any other destination receives the contents of {@link #toDense()}.
     *
     * @return this array
     */
    @Override
    public NdArray<T> copyTo(NdArray<T> dst) {
        if (dst instanceof AbstractSparseNdArray) {
            @SuppressWarnings("unchecked") // same element type T; U is fixed by the concrete subclass
            AbstractSparseNdArray<T, U> sparse = (AbstractSparseNdArray<T, U>) dst;
            LongNdArray indicesCopy = NdArrays.ofLongs(this.indices.shape());
            this.indices.copyTo(indicesCopy);
            U valuesCopy = this.createValues(this.values.shape());
            this.values.copyTo(valuesCopy);
            sparse.setIndices(indicesCopy);
            sparse.setValues(valuesCopy);
        } else {
            U dense = this.toDense();
            dense.copyTo(dst);
        }
        return this;
    }

    /**
     * Returns the flat position of {@code coords} after validating them; null or empty
     * coordinates map to position 0.
     *
     * @param isValue forwarded to {@code Validator.coordinates} (NOTE(review): presumably whether
     *     the coordinates must address a single scalar — confirm)
     */
    protected long positionOf(long[] coords, boolean isValue) {
        if (coords == null || coords.length == 0) {
            return 0L;
        }
        Validator.coordinates(this.dimensions, coords, isValue);
        return this.dimensions.positionOf(coords);
    }

    /**
     * Element-by-element copy fallback. Dense destinations receive every scalar of this array, in
     * iteration order, at increasing single-coordinate offsets. Sparse destinations receive this
     * array's indices and values copied into their existing buffers.
     */
    @Override
    protected void slowCopyTo(NdArray<T> array) {
        if (array instanceof AbstractDenseNdArray) {
            AbstractDenseNdArray<T, ?> dst = (AbstractDenseNdArray<T, ?>) array;
            long offset = 0L;
            for (U scalar : this.scalars()) {
                dst.setObject(scalar.getObject(), offset++);
            }
        } else if (array instanceof AbstractSparseNdArray) {
            @SuppressWarnings("unchecked") // same element type T; U is fixed by the concrete subclass
            AbstractSparseNdArray<T, U> dst = (AbstractSparseNdArray<T, U>) array;
            this.indices.copyTo(dst.getIndices());
            this.values.copyTo(dst.values);
        } else {
            super.slowCopyTo(array);
        }
    }

    /** Returns the coordinate matrix (one row per stored element); not a copy. */
    @Override
    public LongNdArray getIndices() {
        return this.indices;
    }

    /** Replaces the coordinate matrix; no validation against the values is performed. */
    public void setIndices(LongNdArray indices) {
        this.indices = indices;
    }

    /** Returns the stored values (aligned with the rows of the indices); not a copy. */
    @Override
    public U getValues() {
        return this.values;
    }

    /** Replaces the stored values; no validation against the indices is performed. */
    public void setValues(U values) {
        this.values = values;
    }

    /**
     * Returns the row of the indices equal to {@code coordinates}, or a negative value when the
     * coordinates have no stored element. Assumes the indices are sorted.
     */
    protected long locateIndex(long[] coordinates) {
        long size = this.indices.shape().get(0);
        LongNdArray coordArray = NdArrays.vectorOf(coordinates);
        return this.binarySearch(size, coordArray);
    }

    /** Hashes indices, values and shape, or falls back to {@code slowHashCode()} for segmented arrays. */
    @Override
    public int hashCode() {
        if (this.dimensions().isSegmented()) {
            return this.slowHashCode();
        }
        int result = 1;
        result = 31 * result + this.indices.hashCode();
        result = 31 * result + this.values.hashCode();
        result = 31 * result + this.shape().hashCode();
        return result;
    }

    /**
     * Two sparse arrays are equal when their shapes, indices and values are equal (the default
     * value is not compared). Comparison with any non-sparse object is delegated to the superclass.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof AbstractSparseNdArray)) {
            return super.equals(obj);
        }
        AbstractSparseNdArray<?, ?> other = (AbstractSparseNdArray<?, ?>) obj;
        return this.shape().equals(other.shape())
            && this.indices.equals(other.indices)
            && this.values.equals(other.values);
    }

    /** Summarises the class name, default value, number of stored elements and shape. */
    @Override
    public String toString() {
        long numElements = this.values == null ? 0L : this.values.size();
        String strDefault;
        if (this.defaultValue == null) {
            strDefault = "<null>";
        } else if (this.defaultValue instanceof Number) {
            strDefault = this.defaultValue.toString();
        } else {
            strDefault = "'" + this.defaultValue + "'";
        }
        return this.getClass().getSimpleName() + "(defaultValue=" + strDefault + ", numElements=" + numElements + ", shape=" + this.shape() + ")";
    }

    /**
     * Binary search over the first {@code toIndex} rows of the indices.
     *
     * @return the matching row, or {@code -(insertionPoint + 1)} when no row matches
     */
    private long binarySearch(long toIndex, LongNdArray coordinates) {
        long low = 0L;
        long high = toIndex - 1L;
        while (low <= high) {
            long mid = (low + high) >>> 1; // unsigned shift avoids overflow of the sum
            int rc = this.compareCoordinates(this.indices.get(mid), coordinates);
            if (rc < 0) {
                low = mid + 1L;
            } else if (rc > 0) {
                high = mid - 1L;
            } else {
                return mid;
            }
        }
        return -(low + 1L);
    }

    /**
     * Sorts the rows of the indices into ascending lexicographic order and permutes the values to
     * match. The sort is stable, so rows with equal coordinates keep their relative order. The
     * internal indices and values are replaced by newly allocated arrays.
     *
     * @return this array
     */
    public AbstractSparseNdArray<T, U> sortIndicesAndValues() {
        Long[] order = LongStream.range(0L, this.values.size()).boxed().toArray(Long[]::new);
        Arrays.sort(order, (a, b) -> this.compareCoordinates(this.indices.get(a.longValue()), this.indices.get(b.longValue())));
        LongNdArray newIndices = NdArrays.ofLongs(this.indices.shape());
        U newValues = this.createValues(this.values.shape());
        for (int i = 0; i < order.length; ++i) {
            long source = order[i];
            newIndices.set(this.indices.get(source), (long) i);
            newValues.setObject(this.values.getObject(source), (long) i);
        }
        this.indices = newIndices;
        this.values = newValues;
        return this;
    }

    /**
     * Orders coordinate vectors first by length, then lexicographically by element.
     *
     * <p>Uses {@link Long#compare} rather than subtraction: casting a {@code long} difference to
     * {@code int} truncates it, making distinct coordinates compare equal or with the wrong sign.
     */
    private int compareCoordinates(LongNdArray a, LongNdArray b) {
        int rc = Long.compare(a.size(), b.size());
        if (rc != 0) {
            return rc;
        }
        for (long i = 0L; i < a.size(); ++i) {
            rc = Long.compare(a.getLong(i), b.getLong(i));
            if (rc != 0) {
                return rc;
            }
        }
        return 0;
    }

    /** Returns the value reported for coordinates with no stored element. */
    public T getDefaultValue() {
        return this.defaultValue;
    }

    /** Sets the default value and discards any cached default array. */
    public void setDefaultValue(T defaultValue) {
        this.defaultValue = defaultValue;
        this.defaultArray = null;
    }

    /** Creates the array returned by {@link #getDefaultArray()}. */
    public abstract U createDefaultArray();

    /** Returns the default array, creating and caching it on first use (not synchronized). */
    public U getDefaultArray() {
        if (this.defaultArray == null) {
            this.defaultArray = this.createDefaultArray();
        }
        return this.defaultArray;
    }
}

