package com.github.tjake.jlama.tensor;

import com.github.tjake.jlama.util.Pair;
import com.google.common.base.Preconditions;
import java.util.Arrays;
import java.util.Objects;
import java.util.Optional;

/* loaded from: input_file:com/github/tjake/jlama/tensor/TensorShape.class */
/**
 * Immutable description of a tensor's dimensions, optionally restricted to a
 * contiguous range of columns (last dimension) and/or rows (second-to-last
 * dimension).
 *
 * <p>A shape always has at least two dimensions; a one-dimensional request made
 * through {@link #of(int...)} is stored as {@code [1, n]}. A sparse range is a
 * {@code Pair} whose {@code left()} is the offset and whose {@code right()} is
 * the length of the retained range. {@link #size()} is computed from the range
 * lengths rather than the full dimension values.
 */
public class TensorShape {
    // NOTE(review): this field is not final, so it can be reassigned by any caller.
    // It is left as-is for backward compatibility; treat it as a constant.
    public static TensorShape one = of(1, 1);

    private final int[] tshape;
    private final long capacity;
    private final Optional<Pair<Integer, Integer>> sparseColumnRange;
    private final Optional<Pair<Integer, Integer>> sparseRowRange;
    private final boolean isSparse;
    private final int sparseColumnOffset;
    private final int sparseColumnLength;
    private final int sparseRowOffset;
    private final int sparseRowLength;

    /**
     * Creates a dense shape.
     *
     * @param iArr the dimensions; a single value {@code n} is stored as {@code [1, n]}
     * @return a shape with no sparse ranges
     * @throws IllegalArgumentException if no dimensions are given
     */
    public static TensorShape of(int... iArr) {
        // Copy so that later writes to the caller's array cannot alter this shape.
        int[] dims = iArr.length == 1 ? new int[] {1, iArr[0]} : iArr.clone();
        return new TensorShape(dims, Optional.empty(), Optional.empty());
    }

    /**
     * Creates a shape restricted to a range of columns.
     *
     * @param iArr the full dimensions (at least two)
     * @param pair column range as (offset, length)
     * @return a column-sparse shape
     * @throws IllegalArgumentException if fewer than two dimensions are given
     */
    public static TensorShape sparseColumn(int[] iArr, Pair<Integer, Integer> pair) {
        return new TensorShape(iArr.clone(), Optional.empty(), Optional.of(pair));
    }

    /**
     * Creates a shape restricted to a range of rows.
     *
     * @param iArr the full dimensions (at least two)
     * @param pair row range as (offset, length)
     * @return a row-sparse shape
     * @throws IllegalArgumentException if fewer than two dimensions are given
     */
    public static TensorShape sparseRow(int[] iArr, Pair<Integer, Integer> pair) {
        return new TensorShape(iArr.clone(), Optional.of(pair), Optional.empty());
    }

    /**
     * Stores {@code dims} without copying; callers inside this class must pass an
     * array that is never modified afterwards.
     */
    private TensorShape(int[] dims, Optional<Pair<Integer, Integer>> rowRange, Optional<Pair<Integer, Integer>> columnRange) {
        Preconditions.checkArgument(dims.length > 1, "Shape must have at least two dimensions, even if first is 1 (to represent a vector)");
        this.tshape = dims;
        this.sparseColumnRange = columnRange;
        this.sparseRowRange = rowRange;
        this.isSparse = columnRange.isPresent() || rowRange.isPresent();

        // Without a range, the offset is 0 and the length is the full dimension.
        this.sparseColumnOffset = columnRange.map(Pair::left).orElse(0);
        this.sparseColumnLength = columnRange.map(Pair::right).orElse(dims[dims.length - 1]);
        this.sparseRowOffset = rowRange.map(Pair::left).orElse(0);
        this.sparseRowLength = rowRange.map(Pair::right).orElse(dims[dims.length - 2]);

        // Accumulate in a long: leading dims times the (possibly reduced) row and column lengths.
        long leading = 1;
        for (int d = 0; d < dims.length - 2; d++) {
            leading *= dims[d];
        }
        this.capacity = leading * this.sparseRowLength * this.sparseColumnLength;
    }

    /** Returns {@code true} if a column range or a row range is set. */
    public final boolean isSparse() {
        return this.isSparse;
    }

    /** Returns the number of dimensions. */
    public int dims() {
        return this.tshape.length;
    }

    /**
     * Returns the full (non-sparse) value of dimension {@code i}.
     *
     * @throws IllegalArgumentException if {@code i} is not less than {@link #dims()}
     */
    public int dim(int i) {
        Preconditions.checkArgument(i < this.tshape.length);
        return this.tshape[i];
    }

    /**
     * Converts coordinates into a flat element offset, using the sparse column
     * length as the row stride and subtracting the sparse column offset.
     *
     * <p>NOTE(review): the branches treat the sparse row offset differently. The
     * 1-, 2- and 3-coordinate forms subtract it from the first coordinate, while
     * the general form ignores it. The 3-coordinate form also strides by
     * {@code tshape[1]} rather than the sparse row length. Confirm against callers
     * before relying on row-sparse shapes with three or more coordinates.
     *
     * @param iArr one or more coordinates, outermost first
     * @return the flat offset
     */
    public final int getOffset(int... iArr) {
        switch (iArr.length) {
            case 1:
                return (this.sparseColumnLength * (iArr[0] - this.sparseRowOffset)) - this.sparseColumnOffset;
            case 2:
                return ((this.sparseColumnLength * (iArr[0] - this.sparseRowOffset)) + iArr[1]) - this.sparseColumnOffset;
            case 3:
                return ((((this.sparseColumnLength * this.tshape[1]) * (iArr[0] - this.sparseRowOffset)) + (this.sparseColumnLength * iArr[1])) + iArr[2]) - this.sparseColumnOffset;
            default:
                int total = 0;
                for (int axis = 0; axis < iArr.length - 1; axis++) {
                    // Stride of this axis: column length times every dimension between
                    // this axis and the last one.
                    int stride = this.sparseColumnLength;
                    for (int inner = this.tshape.length - 2; inner > axis; inner--) {
                        stride *= this.tshape[inner];
                    }
                    total += iArr[axis] * stride;
                }
                return (total + iArr[iArr.length - 1]) - this.sparseColumnOffset;
        }
    }

    /** Returns the column range length, or the full last dimension if no column range is set. */
    public int sparseColumnLength() {
        return this.sparseColumnLength;
    }

    /** Returns the column range offset, or 0 if no column range is set. */
    public int sparseColumnOffset() {
        return this.sparseColumnOffset;
    }

    /** Returns the row range length, or the full second-to-last dimension if no row range is set. */
    public int sparseRowLength() {
        return this.sparseRowLength;
    }

    /** Returns the row range offset, or 0 if no row range is set. */
    public int sparseRowOffset() {
        return this.sparseRowOffset;
    }

    /**
     * Returns a copy of this shape with the last dimension multiplied by {@code f}
     * and truncated to an int. If a column range is set, its offset and length are
     * scaled the same way. Any row range is not carried over to the result.
     *
     * @param f the scale factor
     * @return the scaled shape
     */
    public TensorShape scaleLastDim(float f) {
        int[] scaled = Arrays.copyOf(this.tshape, this.tshape.length);
        int last = scaled.length - 1;
        scaled[last] = (int) (scaled[last] * f);
        if (!this.sparseColumnRange.isPresent()) {
            return new TensorShape(scaled, Optional.empty(), Optional.empty());
        }
        int scaledOffset = (int) (this.sparseColumnOffset * f);
        int scaledLength = (int) (this.sparseColumnLength * f);
        return new TensorShape(scaled, Optional.empty(), Optional.of(Pair.of(scaledOffset, scaledLength)));
    }

    /**
     * Returns a copy of this shape with dimension {@code i} set to {@code i2}. If a
     * column range is set, the result keeps the column offset and uses the new last
     * dimension as the column length. Any row range is not carried over.
     *
     * @param i  index of the dimension to replace
     * @param i2 the new value
     * @return the modified shape
     * @throws IllegalArgumentException if {@code i} is not less than {@link #dims()}
     */
    public TensorShape setDimValue(int i, int i2) {
        Preconditions.checkArgument(i < this.tshape.length);
        int[] updated = Arrays.copyOf(this.tshape, this.tshape.length);
        updated[i] = i2;
        if (!this.sparseColumnRange.isPresent()) {
            return new TensorShape(updated, Optional.empty(), Optional.empty());
        }
        int newColumnLength = updated[updated.length - 1];
        return new TensorShape(updated, Optional.empty(), Optional.of(Pair.of(this.sparseColumnOffset, newColumnLength)));
    }

    /** Returns the first dimension. */
    public int first() {
        return this.tshape[0];
    }

    /** Returns the full (non-sparse) last dimension. */
    public int last() {
        return this.tshape[this.tshape.length - 1];
    }

    /** Returns the element count, computed from the sparse row and column lengths. */
    public long size() {
        return this.capacity;
    }

    /**
     * Returns a column-sparse view of this dense shape.
     *
     * @param i  column offset
     * @param i2 column length
     * @return a shape with the same dimensions and the given column range
     * @throws IllegalArgumentException if this shape is already sparse
     */
    public TensorShape sparsifyColumns(int i, int i2) {
        Preconditions.checkArgument(!this.isSparse, "Cannot sparsify a sparse tensor");
        // Sharing tshape is safe: the array is never modified after construction.
        return new TensorShape(this.tshape, Optional.empty(), Optional.of(Pair.of(i, i2)));
    }

    /**
     * Returns the shape left after dropping the first {@code i} dimensions. If only
     * one dimension would remain, the result is {@code [1, last]}. Both sparse
     * ranges are carried over unchanged.
     *
     * @param i number of leading dimensions to drop
     * @return the sliced shape
     * @throws IllegalArgumentException if {@code i} is not less than {@link #dims()}
     */
    public TensorShape slice(int i) {
        Preconditions.checkArgument(i < this.tshape.length, "Too many dimensions specified for tensor");
        int[] remaining = this.tshape.length - i == 1
            ? new int[] {1, this.tshape[this.tshape.length - 1]}
            : Arrays.copyOfRange(this.tshape, i, this.tshape.length);
        return new TensorShape(remaining, this.sparseRowRange, this.sparseColumnRange);
    }

    /**
     * Two shapes are equal when their dimensions and both sparse ranges match. The
     * row range is included because {@link #size()} and {@link #getOffset(int...)}
     * depend on it.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        TensorShape other = (TensorShape) obj;
        return Arrays.equals(this.tshape, other.tshape)
            && Objects.equals(this.sparseColumnRange, other.sparseColumnRange)
            && Objects.equals(this.sparseRowRange, other.sparseRowRange);
    }

    /** Hashes the same fields that {@link #equals(Object)} compares. */
    @Override
    public int hashCode() {
        return (31 * Objects.hash(this.sparseColumnRange, this.sparseRowRange)) + Arrays.hashCode(this.tshape);
    }

    /** Returns a debug string with the dimensions, element count and column range. */
    @Override
    public String toString() {
        return "TensorShape{tshape=" + Arrays.toString(this.tshape)
            + ", capacity=" + this.capacity
            + ", sparseRange=" + this.sparseColumnRange + "}";
    }
}
