package com.yahoo.tensor;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.UnmodifiableIterator;
import com.yahoo.nativec.PosixFAdvise;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:com/yahoo/tensor/MixedTensor.class */
public class MixedTensor implements Tensor {
    private final TensorType type;
    private final List<Tensor.Cell> cells;
    private final Index index;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: com.yahoo.tensor.MixedTensor$2, reason: invalid class name */
    /* loaded from: input_file:com/yahoo/tensor/MixedTensor$2.class */
    /**
     * Lookup table of the kind javac synthesizes for a switch over an enum: it maps each
     * {@link TensorType.Value} ordinal to a small case number (DOUBLE=1, FLOAT=2, BFLOAT16=3,
     * INT8=4). Each slot is filled independently so that a constant missing at runtime
     * ({@link NoSuchFieldError}) just leaves its slot at 0, i.e. "no matching case".
     */
    public static /* synthetic */ class AnonymousClass2 {
        static final /* synthetic */ int[] $SwitchMap$com$yahoo$tensor$TensorType$Value = new int[TensorType.Value.values().length];

        static {
            final int[] table = $SwitchMap$com$yahoo$tensor$TensorType$Value;
            try {
                table[TensorType.Value.DOUBLE.ordinal()] = 1;
            } catch (NoSuchFieldError ignored) {
                // Constant absent at runtime: slot stays 0.
            }
            try {
                table[TensorType.Value.FLOAT.ordinal()] = 2;
            } catch (NoSuchFieldError ignored) {
                // Constant absent at runtime: slot stays 0.
            }
            try {
                table[TensorType.Value.BFLOAT16.ordinal()] = 3;
            } catch (NoSuchFieldError ignored) {
                // Constant absent at runtime: slot stays 0.
            }
            try {
                table[TensorType.Value.INT8.ordinal()] = 4;
            } catch (NoSuchFieldError ignored) {
                // Constant absent at runtime: slot stays 0.
            }
        }
    }

    /* loaded from: input_file:com/yahoo/tensor/MixedTensor$BoundBuilder.class */
    public static class BoundBuilder extends Builder {
        private final Map<TensorAddress, double[]> denseSubspaceMap;
        private final Index.Builder indexBuilder;
        private final Index index;
        private final TensorType denseSubtype;

        private BoundBuilder(TensorType tensorType) {
            super(tensorType);
            this.denseSubspaceMap = new HashMap();
            this.indexBuilder = new Index.Builder(tensorType);
            this.index = this.indexBuilder.index();
            this.denseSubtype = new TensorType(tensorType.valueType(), tensorType.dimensions().stream().filter((v0) -> {
                return v0.isIndexed();
            }).toList());
        }

        public long denseSubspaceSize() {
            return this.index.denseSubspaceSize();
        }

        private double[] denseSubspace(TensorAddress tensorAddress) {
            if (!this.denseSubspaceMap.containsKey(tensorAddress)) {
                this.denseSubspaceMap.put(tensorAddress, new double[(int) denseSubspaceSize()]);
            }
            return this.denseSubspaceMap.get(tensorAddress);
        }

        public IndexedTensor.DirectIndexBuilder denseSubspaceBuilder(TensorAddress tensorAddress) {
            double[] dArr = new double[(int) denseSubspaceSize()];
            this.denseSubspaceMap.put(tensorAddress, dArr);
            return new DenseSubspaceBuilder(this.denseSubtype, dArr);
        }

        @Override // com.yahoo.tensor.Tensor.Builder
        public Tensor.Builder cell(TensorAddress tensorAddress, float f) {
            return cell(tensorAddress, f);
        }

        @Override // com.yahoo.tensor.Tensor.Builder
        public Tensor.Builder cell(TensorAddress tensorAddress, double d) {
            denseSubspace(this.index.sparsePartialAddress(tensorAddress))[(int) this.index.denseOffset(tensorAddress)] = d;
            return this;
        }

        public Tensor.Builder block(TensorAddress tensorAddress, double[] dArr) {
            int denseSubspaceSize = (int) denseSubspaceSize();
            if (dArr.length < denseSubspaceSize) {
                throw new IllegalArgumentException("Block should have " + denseSubspaceSize + " values, but has only " + dArr.length);
            }
            System.arraycopy(dArr, 0, denseSubspace(tensorAddress), 0, denseSubspaceSize);
            return this;
        }

        @Override // com.yahoo.tensor.MixedTensor.Builder, com.yahoo.tensor.Tensor.Builder
        public MixedTensor build() {
            long j = 0;
            ArrayList arrayList = new ArrayList();
            for (Map.Entry<TensorAddress, double[]> entry : this.denseSubspaceMap.entrySet()) {
                TensorAddress key = entry.getKey();
                this.indexBuilder.put(key, j);
                double[] value = entry.getValue();
                long j2 = 0;
                while (true) {
                    long j3 = j2;
                    if (j3 < value.length) {
                        arrayList.add(new Tensor.Cell(this.index.addressOf(key, j3), value[(int) j3]));
                        j++;
                        j2 = j3 + 1;
                    }
                }
            }
            return new MixedTensor(this.type, arrayList, this.indexBuilder.build());
        }

        public static BoundBuilder of(TensorType tensorType) {
            return new BoundBuilder(tensorType);
        }
    }

    /* loaded from: input_file:com/yahoo/tensor/MixedTensor$Builder.class */
    /**
     * Base class for mixed tensor builders. Use {@link #of(TensorType)} to get an
     * {@link UnboundBuilder} when some dimension is an unbound indexed dimension, and a
     * {@link BoundBuilder} otherwise.
     */
    public static abstract class Builder implements Tensor.Builder {
        final TensorType type;

        /** Returns the builder variant suited to {@code tensorType} (see class comment). */
        public static Builder of(TensorType tensorType) {
            boolean hasUnboundIndexedDimension = tensorType.dimensions().stream()
                    .anyMatch(dimension -> dimension instanceof TensorType.IndexedUnboundDimension);
            return hasUnboundIndexedDimension ? new UnboundBuilder(tensorType) : new BoundBuilder(tensorType);
        }

        private Builder(TensorType tensorType) {
            this.type = tensorType;
        }

        @Override // com.yahoo.tensor.Tensor.Builder
        public TensorType type() {
            return this.type;
        }

        @Override // com.yahoo.tensor.Tensor.Builder
        public Tensor.Builder cell(float f, long... jArr) {
            // Widen explicitly: without the cast this call resolves to this same overload and
            // recurses until StackOverflowError.
            return cell((double) f, jArr);
        }

        /**
         * Not supported by mixed tensor builders.
         *
         * @throws UnsupportedOperationException always
         */
        @Override // com.yahoo.tensor.Tensor.Builder
        public Tensor.Builder cell(double d, long... jArr) {
            throw new UnsupportedOperationException("Not implemented.");
        }

        @Override // com.yahoo.tensor.Tensor.Builder
        public Tensor.Builder.CellBuilder cell() {
            return new Tensor.Builder.CellBuilder(type(), this);
        }

        @Override // com.yahoo.tensor.Tensor.Builder
        public abstract MixedTensor build();
    }

    /* loaded from: input_file:com/yahoo/tensor/MixedTensor$DenseSubspaceBuilder.class */
    /**
     * Direct-index builder that writes straight into a caller-supplied value array.
     * The array is not copied, so every write is visible to whoever owns it.
     */
    private static class DenseSubspaceBuilder implements IndexedTensor.DirectIndexBuilder {
        private final TensorType subspaceType;
        private final double[] target;

        public DenseSubspaceBuilder(TensorType tensorType, double[] dArr) {
            subspaceType = tensorType;
            target = dArr;
        }

        @Override // com.yahoo.tensor.IndexedTensor.DirectIndexBuilder
        public TensorType type() {
            return subspaceType;
        }

        @Override // com.yahoo.tensor.IndexedTensor.DirectIndexBuilder
        public void cellByDirectIndex(long j, double d) {
            store(j, d);
        }

        @Override // com.yahoo.tensor.IndexedTensor.DirectIndexBuilder
        public void cellByDirectIndex(long j, float f) {
            store(j, f); // float widens to double
        }

        /** Writes {@code value} at position {@code directIndex} of the target array. */
        private void store(long directIndex, double value) {
            target[(int) directIndex] = value;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/yahoo/tensor/MixedTensor$Index.class */
    /**
     * Locates cells in a mixed tensor's flat cell list.
     *
     * <p>Each sparse (mapped-dimension) partial address owns one dense subspace. {@link #sparseMap}
     * maps that sparse address to the cell offset where its subspace starts. Within a subspace, a
     * cell's position is its row-major offset over the indexed dimensions, with the last indexed
     * dimension varying fastest.
     */
    public static class Index {
        private final TensorType type;
        /** Type made of the mapped dimensions only. */
        private final TensorType sparseType;
        /** Type made of the indexed dimensions only. */
        private final TensorType denseType;
        private final List<TensorType.Dimension> mappedDimensions;
        private final List<TensorType.Dimension> indexedDimensions;
        /** Sparse partial address to subspace start offset; assigned by {@link Builder#build()}. */
        private ImmutableMap<TensorAddress, Long> sparseMap;
        /** Cached by {@link #denseSubspaceSize()}; -1 until first computed. */
        private long denseSubspaceSize = -1;

        /** Collects subspace start offsets and installs them into an {@link Index}. */
        /* loaded from: input_file:com/yahoo/tensor/MixedTensor$Index$Builder.class */
        public static class Builder {
            private final Index index;
            private final ImmutableMap.Builder<TensorAddress, Long> builder = new ImmutableMap.Builder<>();

            public Builder(TensorType tensorType) {
                this.index = new Index(tensorType);
            }

            /** Records that the subspace of sparse address {@code tensorAddress} starts at cell offset {@code j}. */
            public void put(TensorAddress tensorAddress, long j) {
                this.builder.put(tensorAddress, j);
            }

            /** Installs the collected offsets into the index and returns it. */
            public Index build() {
                this.index.sparseMap = this.builder.build();
                return this.index;
            }

            /** Returns the index being built: the same instance {@link #build()} later returns. */
            public Index index() {
                return this.index;
            }
        }

        private Index(TensorType tensorType) {
            this.type = tensorType;
            this.mappedDimensions = tensorType.dimensions().stream()
                                              .filter(dimension -> !dimension.isIndexed())
                                              .toList();
            this.indexedDimensions = tensorType.dimensions().stream()
                                               .filter(TensorType.Dimension::isIndexed)
                                               .toList();
            this.sparseType = MixedTensor.createPartialType(tensorType.valueType(), this.mappedDimensions);
            this.denseType = MixedTensor.createPartialType(tensorType.valueType(), this.indexedDimensions);
        }

        /**
         * Returns the cell-list position computed for a full address, or -1 if its sparse part is unknown.
         * The position is not range-checked here.
         */
        public long indexOf(TensorAddress tensorAddress) {
            TensorAddress sparsePartialAddress = sparsePartialAddress(tensorAddress);
            Long subspaceStart = this.sparseMap.get(sparsePartialAddress); // ImmutableMap holds no null values
            if (subspaceStart == null) {
                return -1L;
            }
            return subspaceStart + denseOffset(tensorAddress);
        }

        /**
         * Returns the product of the indexed dimension sizes (1 if there are none), caching the result.
         *
         * @throws IllegalArgumentException if an indexed dimension has no size
         */
        public long denseSubspaceSize() {
            if (this.denseSubspaceSize == -1) {
                long product = 1L;
                for (TensorType.Dimension dimension : this.type.dimensions()) {
                    if (dimension.isIndexed()) {
                        product *= dimension.size().orElseThrow(() ->
                                new IllegalArgumentException("Unknown size of indexed dimension")).longValue();
                    }
                }
                this.denseSubspaceSize = product;
            }
            return this.denseSubspaceSize;
        }

        /**
         * Returns the address made of only the mapped-dimension labels of a full address.
         *
         * @throws IllegalArgumentException if the address does not have one label per dimension
         */
        private TensorAddress sparsePartialAddress(TensorAddress tensorAddress) {
            if (this.type.dimensions().size() != tensorAddress.size()) {
                throw new IllegalArgumentException("Tensor type of " + this + " is not the same size as " + tensorAddress);
            }
            TensorAddress.Builder builder = new TensorAddress.Builder(this.sparseType);
            for (int i = 0; i < this.type.dimensions().size(); i++) {
                TensorType.Dimension dimension = this.type.dimensions().get(i);
                if (!dimension.isIndexed()) {
                    builder.add(dimension.name(), tensorAddress.label(i));
                }
            }
            return builder.build();
        }

        /** Returns the row-major offset of a full address's indexed labels within its dense subspace. */
        private long denseOffset(TensorAddress tensorAddress) {
            long stride = 1;
            long offset = 0;
            for (int i = this.type.dimensions().size() - 1; i >= 0; i--) {
                TensorType.Dimension dimension = this.type.dimensions().get(i);
                if (dimension.isIndexed()) {
                    offset += tensorAddress.numericLabel(i) * stride;
                    stride *= dimension.size().orElseThrow(() ->
                            new IllegalArgumentException("Unknown size of indexed dimension.")).longValue();
                }
            }
            return offset;
        }

        /**
         * Inverse of {@link #denseOffset}: converts an offset within a dense subspace into indexed labels.
         *
         * @throws IllegalArgumentException if {@code j} is not in {@code [0, denseSubspaceSize())}
         */
        private TensorAddress denseOffsetToAddress(long j) {
            long subspaceSize = denseSubspaceSize(); // the raw field may still be the -1 sentinel
            if (j < 0 || j >= subspaceSize) {
                throw new IllegalArgumentException("Offset out of bounds");
            }
            long remainder = j;
            long stride = subspaceSize;
            long[] labels = new long[this.indexedDimensions.size()];
            for (int i = 0; i < labels.length; i++) {
                stride /= this.indexedDimensions.get(i).size().orElseThrow(() ->
                        new IllegalArgumentException("Unknown size of indexed dimension.")).longValue();
                labels[i] = remainder / stride;
                remainder %= stride;
            }
            return TensorAddress.of(labels);
        }

        /**
         * Builds a full address by merging the sparse partial address {@code tensorAddress} with the
         * indexed labels for dense offset {@code j}, keeping the dimension order of the type.
         */
        private TensorAddress addressOf(TensorAddress tensorAddress, long j) {
            TensorAddress denseAddress = denseOffsetToAddress(j);
            String[] labels = new String[this.type.dimensions().size()];
            int sparseIndex = 0;
            int denseIndex = 0;
            for (TensorType.Dimension dimension : this.type.dimensions()) {
                if (dimension.isIndexed()) {
                    labels[sparseIndex + denseIndex] = denseAddress.label(denseIndex);
                    denseIndex++;
                } else {
                    labels[sparseIndex + denseIndex] = tensorAddress.label(sparseIndex);
                    sparseIndex++;
                }
            }
            return TensorAddress.of(labels);
        }

        public String toString() {
            return "index into " + this.type;
        }

        /**
         * Renders the cells of {@code mixedTensor} in short form, printing at most about
         * {@code maxCells} values. Supports at most one mapped dimension.
         *
         * @throws IllegalStateException if the type has more than one mapped dimension
         */
        private String contentToString(MixedTensor mixedTensor, long maxCells) {
            if (this.mappedDimensions.size() > 1) {
                throw new IllegalStateException("Should be ensured by caller");
            }
            if (this.mappedDimensions.isEmpty()) {
                StringBuilder sb = new StringBuilder();
                int written = denseSubspaceToString(mixedTensor, 0L, maxCells, sb);
                if (written == maxCells && written < mixedTensor.size()) {
                    sb.append("...]");
                }
                return sb.toString();
            }
            // Exactly one mapped dimension: print "{label:subspace, ...}" ordered by sparse address.
            StringBuilder sb = new StringBuilder("{");
            List<Map.Entry<TensorAddress, Long>> entries = new ArrayList<>(this.sparseMap.entrySet());
            entries.sort(Map.Entry.comparingByKey());
            int written = 0;
            for (int i = 0; i < entries.size() && written < maxCells; i++) {
                if (i > 0) {
                    sb.append(", ");
                }
                Map.Entry<TensorAddress, Long> entry = entries.get(i);
                sb.append(TensorAddress.labelToString(entry.getKey().label(0)));
                sb.append(":");
                written += denseSubspaceToString(mixedTensor, entry.getValue(), maxCells - written, sb);
            }
            if (written >= maxCells && written < mixedTensor.size()) {
                sb.append(", ...");
            }
            sb.append("}");
            return sb.toString();
        }

        /**
         * Appends up to {@code maxCells} values of the subspace starting at {@code subspaceOffset},
         * with brackets for the indexed structure, and returns how many values were appended.
         */
        private int denseSubspaceToString(MixedTensor mixedTensor, long subspaceOffset, long maxCells, StringBuilder sb) {
            if (maxCells <= 0) {
                return 0;
            }
            // NOTE(review): this reads the cached field rather than calling denseSubspaceSize().
            // If the size was never computed (-1), nothing is printed. Calling the method instead
            // would try to read cells even when the tensor has none.
            if (this.denseSubspaceSize == 1) {
                sb.append(getDouble(subspaceOffset, 0L, mixedTensor));
                return 1;
            }
            IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(this.denseType);
            int written = 0;
            while (written < this.denseSubspaceSize && written < maxCells) {
                indexes.next();
                if (written > 0) {
                    sb.append(", ");
                }
                for (int i = 0; i < indexes.nextDimensionsAtStart(); i++) {
                    sb.append("[");
                }
                switch (this.type.valueType()) {
                    // Every supported value type is currently rendered through its double value.
                    case DOUBLE, FLOAT, BFLOAT16, INT8 -> sb.append(getDouble(subspaceOffset, written, mixedTensor));
                    default -> throw new IllegalStateException("Unexpected value type " + this.type.valueType());
                }
                for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++) {
                    sb.append("]");
                }
                written++;
            }
            return written;
        }

        /** Returns the value of cell {@code j + j2} in {@code mixedTensor}'s cell list. */
        private double getDouble(long j, long j2, MixedTensor mixedTensor) {
            return mixedTensor.cells.get((int) (j + j2)).getDoubleValue();
        }
    }

    /* loaded from: input_file:com/yahoo/tensor/MixedTensor$UnboundBuilder.class */
    /**
     * Builder for mixed tensor types that may contain indexed dimensions without a declared size.
     *
     * <p>Cells are buffered by full address while the largest label seen in each indexed dimension
     * is tracked. {@link #build()} then derives a bound type from those maxima and delegates to a
     * {@link BoundBuilder}.
     */
    public static class UnboundBuilder extends Builder {
        /** Buffered cells by full address; a later write to the same address replaces the earlier one. */
        private final Map<TensorAddress, Double> cells;
        /** Largest numeric label seen per dimension position; only indexed positions are updated. */
        private final long[] dimensionBounds;

        private UnboundBuilder(TensorType tensorType) {
            super(tensorType);
            this.cells = new HashMap<>();
            this.dimensionBounds = new long[tensorType.dimensions().size()];
        }

        @Override // com.yahoo.tensor.Tensor.Builder
        public Tensor.Builder cell(TensorAddress tensorAddress, float f) {
            // The explicit widening is required: without it this call resolves to this same
            // overload and recurses until StackOverflowError.
            return cell(tensorAddress, (double) f);
        }

        @Override // com.yahoo.tensor.Tensor.Builder
        public Tensor.Builder cell(TensorAddress tensorAddress, double d) {
            this.cells.put(tensorAddress, d);
            trackBounds(tensorAddress);
            return this;
        }

        /** Builds the tensor using the bound type from {@link #createBoundType()}. */
        @Override // com.yahoo.tensor.MixedTensor.Builder, com.yahoo.tensor.Tensor.Builder
        public MixedTensor build() {
            BoundBuilder boundBuilder = new BoundBuilder(createBoundType());
            for (Map.Entry<TensorAddress, Double> entry : this.cells.entrySet()) {
                boundBuilder.cell(entry.getKey(), entry.getValue().doubleValue());
            }
            return boundBuilder.build();
        }

        /** Raises the tracked bound of every indexed dimension to at least the label in {@code tensorAddress}. */
        public void trackBounds(TensorAddress tensorAddress) {
            for (int i = 0; i < this.type.dimensions().size(); i++) {
                if (this.type.dimensions().get(i).isIndexed()) {
                    this.dimensionBounds[i] = Math.max(tensorAddress.numericLabel(i), this.dimensionBounds[i]);
                }
            }
        }

        /**
         * Returns this builder's type with every indexed dimension given a size. Declared sizes are
         * kept; otherwise the size is the largest label seen plus one (1 if no cell was added).
         */
        public TensorType createBoundType() {
            TensorType.Builder builder = new TensorType.Builder(type().valueType());
            for (int i = 0; i < this.type.dimensions().size(); i++) {
                TensorType.Dimension dimension = this.type.dimensions().get(i);
                if (dimension.isIndexed()) {
                    long observedSize = this.dimensionBounds[i] + 1;
                    builder.indexed(dimension.name(), dimension.size().orElse(observedSize).longValue());
                } else {
                    builder.mapped(dimension.name());
                }
            }
            return builder.build();
        }

        /** Creates an unbound builder for {@code tensorType}. */
        public static UnboundBuilder of(TensorType tensorType) {
            return new UnboundBuilder(tensorType);
        }
    }

    /**
     * Creates a tensor from cells already laid out according to {@code index}.
     * The cell list is copied into an unmodifiable list.
     */
    private MixedTensor(TensorType tensorType, List<Tensor.Cell> list, Index index) {
        List<Tensor.Cell> snapshot = List.copyOf(list);
        this.type = tensorType;
        this.index = index;
        this.cells = snapshot;
    }

    /** Returns the type this tensor was created with. */
    @Override // com.yahoo.tensor.Tensor
    public TensorType type() {
        return type;
    }

    /** Returns the number of stored cells. */
    @Override // com.yahoo.tensor.Tensor
    public long size() {
        final int count = cells.size();
        return count;
    }

    /**
     * Returns the value at {@code tensorAddress}, or 0.0 when the address is not stored.
     * The cell found through the index must carry exactly this address to count as a hit.
     */
    @Override // com.yahoo.tensor.Tensor
    public double get(TensorAddress tensorAddress) {
        final long position = index.indexOf(tensorAddress);
        final boolean inRange = position >= 0 && position < cells.size();
        if (!inRange) {
            return 0.0d;
        }
        final Tensor.Cell candidate = cells.get((int) position);
        return tensorAddress.equals(candidate.getKey()) ? candidate.getValue().doubleValue() : 0.0d;
    }

    /** Returns whether a cell with exactly {@code tensorAddress} is stored in this tensor. */
    @Override // com.yahoo.tensor.Tensor
    public boolean has(TensorAddress tensorAddress) {
        final long position = index.indexOf(tensorAddress);
        if (position < 0 || position >= cells.size()) {
            return false;
        }
        return tensorAddress.equals(cells.get((int) position).getKey());
    }

    /** Returns an iterator over the stored cells in cell-list order. */
    @Override // com.yahoo.tensor.Tensor
    public Iterator<Tensor.Cell> cellIterator() {
        final Iterator<Tensor.Cell> iterator = cells.iterator();
        return iterator;
    }

    /** Returns an iterator over the cell values, in the same order as {@link #cellIterator()}. */
    @Override // com.yahoo.tensor.Tensor
    public Iterator<Double> valueIterator() {
        final Iterator<Tensor.Cell> source = cellIterator();
        return new Iterator<Double>() {
            @Override // java.util.Iterator
            public boolean hasNext() {
                return source.hasNext();
            }

            @Override // java.util.Iterator
            public Double next() {
                return source.next().getValue();
            }
        };
    }

    /**
     * Returns an immutable map from cell address to value. Iteration order follows the cell
     * list, since Guava's ImmutableMap keeps insertion order.
     */
    @Override // com.yahoo.tensor.Tensor
    public Map<TensorAddress, Double> cells() {
        ImmutableMap.Builder<TensorAddress, Double> builder = new ImmutableMap.Builder<>();
        for (Tensor.Cell cell : this.cells) {
            builder.put(cell.getKey(), cell.getValue());
        }
        return builder.build();
    }

    /**
     * Returns a tensor with the same cells and index but typed as {@code tensorType}.
     * The index is reused unchanged, so the new type is expected to keep the dimension order and
     * sizes (a pure rename). NOTE(review): confirm that isRenamableTo guarantees this.
     *
     * @throws IllegalArgumentException if this tensor's type is not renamable to {@code tensorType}
     */
    @Override // com.yahoo.tensor.Tensor
    public Tensor withType(TensorType tensorType) {
        // Check against the requested type; comparing this.type with itself always passed.
        if (!this.type.isRenamableTo(tensorType)) {
            throw new IllegalArgumentException("MixedTensor.withType: types are not compatible. Current type: '" + this.type + "', requested type: '" + tensorType + "'");
        }
        return new MixedTensor(tensorType, this.cells, this.index);
    }

    /**
     * Returns a new tensor without the dense subspaces whose sparse (mapped-dimension) partial
     * address is in {@code set}. All cells of every other subspace are copied.
     */
    @Override // com.yahoo.tensor.Tensor
    public Tensor remove(Set<TensorAddress> set) {
        Tensor.Builder builder = Tensor.Builder.of(type());
        for (Map.Entry<TensorAddress, Long> entry : this.index.sparseMap.entrySet()) {
            if (set.contains(entry.getKey())) {
                continue;
            }
            // Use the method rather than the raw field, which is -1 until first computed.
            long subspaceSize = this.index.denseSubspaceSize();
            long subspaceStart = entry.getValue();
            for (long i = 0; i < subspaceSize; i++) {
                Tensor.Cell cell = this.cells.get((int) (subspaceStart + i));
                builder.cell(cell.getKey(), cell.getValue().doubleValue());
            }
        }
        return builder.build();
    }

    /**
     * Returns a hash of the cell list.
     * NOTE(review): equals() delegates to Tensor.equals; confirm that tensors it considers equal
     * always have equal cell lists here, otherwise the equals/hashCode contract may not hold.
     */
    @Override // com.yahoo.tensor.Tensor
    public int hashCode() {
        final List<Tensor.Cell> content = cells;
        return content.hashCode();
    }

    /** Equivalent to {@code toString(true, true)}. */
    @Override // com.yahoo.tensor.Tensor
    public String toString() {
        final boolean withType = true;
        final boolean shortForms = true;
        return toString(withType, shortForms);
    }

    /**
     * Returns the string form of this tensor without limiting how many cells are printed.
     *
     * @param withType   whether to include the tensor type in the output
     * @param shortForms when false, the standard tensor string form is always used
     */
    @Override // com.yahoo.tensor.Tensor
    public String toString(boolean withType, boolean shortForms) {
        final long unlimited = Long.MAX_VALUE;
        return toString(withType, shortForms, unlimited);
    }

    /**
     * Returns a truncated string form. The cell budget is {@code 10 / (mappedDimensions + 1)},
     * but never less than 2.
     */
    @Override // com.yahoo.tensor.Tensor
    public String toAbbreviatedString(boolean withType, boolean shortForms) {
        long mappedDimensionCount = type().dimensions().stream()
                                          .filter(TensorType.Dimension::isMapped)
                                          .count();
        long maxCells = Math.max(2L, 10 / (mappedDimensionCount + 1));
        return toString(withType, shortForms, maxCells);
    }

    /**
     * Renders this tensor, printing roughly at most {@code maxCells} values.
     *
     * <p>Falls back to {@link Tensor#toStandardString} in any of these cases:
     * <ul>
     *   <li>{@code shortForms} is false;</li>
     *   <li>the rank is 0;</li>
     *   <li>the rank is above 1 and some indexed dimension has no size;</li>
     *   <li>there is more than one mapped dimension.</li>
     * </ul>
     * Otherwise it prints the index-based short content, optionally prefixed with {@code type:}.
     */
    private String toString(boolean withType, boolean shortForms, long maxCells) {
        boolean useStandardForm = !shortForms || this.type.rank() == 0;
        if (!useStandardForm && this.type.rank() > 1) {
            useStandardForm = this.type.dimensions().stream()
                                       .filter(TensorType.Dimension::isIndexed)
                                       .anyMatch(dimension -> dimension.size().isEmpty());
        }
        if (!useStandardForm) {
            useStandardForm = this.type.dimensions().stream()
                                       .filter(TensorType.Dimension::isMapped)
                                       .count() > 1;
        }
        if (useStandardForm) {
            return Tensor.toStandardString(this, withType, shortForms, maxCells);
        }
        String prefix = withType ? this.type + ":" : "";
        return prefix + this.index.contentToString(this, maxCells);
    }

    /** Returns whether {@code obj} is a tensor that {@link Tensor#equals(Tensor, Tensor)} considers equal. */
    @Override // com.yahoo.tensor.Tensor
    public boolean equals(Object obj) {
        return obj instanceof Tensor other && Tensor.equals(this, other);
    }

    /**
     * Returns the number of cells in one dense subspace (the product of the indexed dimension sizes).
     *
     * @throws IllegalArgumentException if an indexed dimension of the index's type has no size
     */
    public long denseSubspaceSize() {
        final Index cellIndex = this.index;
        return cellIndex.denseSubspaceSize();
    }

    /** Builds a tensor type with value type {@code value} and the given dimensions, added in list order. */
    public static TensorType createPartialType(TensorType.Value value, List<TensorType.Dimension> list) {
        TensorType.Builder partial = new TensorType.Builder(value);
        for (TensorType.Dimension dimension : list) {
            partial.set(dimension);
        }
        return partial.build();
    }
}
