package com.yahoo.tensor;

import com.google.common.collect.ImmutableSet;
import com.yahoo.nativec.PosixFAdvise;
import com.yahoo.text.Ascii7BitMatcher;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/* loaded from: input_file:com/yahoo/tensor/TensorType.class */
public class TensorType {
    /**
     * Matcher used by {@code Dimension.requireIdentifier} to validate dimension names.
     * NOTE(review): the two arguments look like "legal characters" and "legal first characters" —
     * confirm against the Ascii7BitMatcher constructor contract.
     */
    static Ascii7BitMatcher labelMatcher = new Ascii7BitMatcher("-_@" + Ascii7BitMatcher.charsAndNumbers(), "_@$" + Ascii7BitMatcher.charsAndNumbers());
    /** The shared type with no dimensions and value type DOUBLE, created by the private no-arg constructor. */
    public static final TensorType empty = new TensorType();
    /** The cell value type of this tensor type. */
    private final Value valueType;
    /** The dimensions of this type as an immutable list, sorted by dimension name. */
    private final List<Dimension> dimensions;
    /** The names of all dimensions in {@link #dimensions}. */
    private final Set<String> dimensionNames;
    /** The type made from only the non-indexed (mapped) dimensions; {@link #empty} when there are none. */
    private final TensorType mappedSubtype;
    /** The type made from only the indexed dimensions; {@link #empty} when all dimensions are mapped. */
    private final TensorType indexedSubtype;
    /** The number of dimensions whose type is {@code indexedUnbound}. */
    private final int indexedUnBoundCount;

    /* renamed from: com.yahoo.tensor.TensorType$1, reason: invalid class name */
    /* loaded from: input_file:com/yahoo/tensor/TensorType$1.class */
    static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$com$yahoo$tensor$TensorType$Value;

        static {
            try {
                $SwitchMap$com$yahoo$tensor$TensorType$Dimension$Type[Dimension.Type.indexedUnbound.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$com$yahoo$tensor$TensorType$Dimension$Type[Dimension.Type.indexedBound.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$com$yahoo$tensor$TensorType$Dimension$Type[Dimension.Type.mapped.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            $SwitchMap$com$yahoo$tensor$TensorType$Value = new int[Value.values().length];
            try {
                $SwitchMap$com$yahoo$tensor$TensorType$Value[Value.DOUBLE.ordinal()] = 1;
            } catch (NoSuchFieldError e4) {
            }
            try {
                $SwitchMap$com$yahoo$tensor$TensorType$Value[Value.FLOAT.ordinal()] = 2;
            } catch (NoSuchFieldError e5) {
            }
            try {
                $SwitchMap$com$yahoo$tensor$TensorType$Value[Value.BFLOAT16.ordinal()] = 3;
            } catch (NoSuchFieldError e6) {
            }
            try {
                $SwitchMap$com$yahoo$tensor$TensorType$Value[Value.INT8.ordinal()] = 4;
            } catch (NoSuchFieldError e7) {
            }
        }
    }

    /* loaded from: input_file:com/yahoo/tensor/TensorType$Builder.class */
    /**
     * Mutable builder that collects named dimensions and a cell value type and produces a
     * {@link TensorType} from them.
     */
    public static final class Builder {
        /** Dimensions collected so far, keyed by dimension name, in insertion order. */
        private final Map<String, Dimension> dimensions = new LinkedHashMap<>();
        /** Value type handed to the built {@link TensorType}. */
        private final Value valueType;

        /** Creates a builder for a type with value type {@code DOUBLE}. */
        public Builder() {
            this(Value.DOUBLE);
        }

        /**
         * Creates a builder for a type with the given value type.
         *
         * @param value the cell value type of the type to build
         */
        public Builder(Value value) {
            this.valueType = value;
        }

        /**
         * Creates a builder holding the combined dimensions of the given types; equivalent to
         * {@code new Builder(true, tensorTypeArr)}.
         *
         * @param tensorTypeArr the types to combine
         */
        public Builder(TensorType... tensorTypeArr) {
            this(true, tensorTypeArr);
        }

        /**
         * Creates a builder holding the combined dimensions of the given types. The value type is
         * {@link TensorType#combinedValueType(TensorType...)} of the arguments.
         *
         * @param z             when true, two bound dimensions with the same name but different sizes
         *                      combine to the smaller one; when false such a mismatch throws
         * @param tensorTypeArr the types to combine
         * @throws IllegalArgumentException if {@code z} is false and two same-named bound dimensions differ in size
         */
        public Builder(boolean z, TensorType... tensorTypeArr) {
            this.valueType = TensorType.combinedValueType(tensorTypeArr);
            for (TensorType source : tensorTypeArr) {
                addDimensionsOf(source, z);
            }
        }

        /**
         * Creates a builder with value type {@code DOUBLE} and the given dimensions.
         *
         * @param iterable the dimensions to add
         * @throws IllegalArgumentException if two dimensions share a name
         */
        public Builder(Iterable<Dimension> iterable) {
            this(Value.DOUBLE, iterable);
        }

        /**
         * Creates a builder with the given value type and dimensions.
         *
         * @param value    the cell value type of the type to build
         * @param iterable the dimensions to add
         * @throws IllegalArgumentException if two dimensions share a name
         */
        public Builder(Value value, Iterable<Dimension> iterable) {
            this.valueType = value;
            for (Dimension each : iterable) {
                dimension(each);
            }
        }

        /** Merges every dimension of the given type into this builder, combining with any same-named dimension already present. */
        private void addDimensionsOf(TensorType tensorType, boolean z) {
            for (Dimension incoming : tensorType.dimensions) {
                Optional<Dimension> existing = getDimension(incoming.name());
                set(incoming.combineWith(existing, z));
            }
        }

        /** Returns the number of dimensions added so far. */
        public int rank() {
            return this.dimensions.size();
        }

        /** Adds a dimension, rejecting null and names that are already present. */
        private Builder add(Dimension dimension) {
            Objects.requireNonNull(dimension, "A dimension cannot be null");
            // Stored values are never null, so a non-null result means the name was already taken.
            if (this.dimensions.putIfAbsent(dimension.name(), dimension) != null) {
                throw new IllegalArgumentException("Could not add dimension " + dimension + " as this dimension is already present");
            }
            return this;
        }

        /**
         * Adds the given dimension, replacing any existing dimension with the same name.
         *
         * @param dimension the dimension to store, not null
         * @return this builder
         */
        public Builder set(Dimension dimension) {
            Objects.requireNonNull(dimension, "A dimension cannot be null");
            this.dimensions.put(dimension.name(), dimension);
            return this;
        }

        /**
         * Adds a new bound indexed dimension.
         *
         * @param str the dimension name
         * @param j   the dimension size
         * @return this builder
         * @throws IllegalArgumentException if the name is already present, or name/size is invalid
         */
        public Builder indexed(String str, long j) {
            return add(new IndexedBoundDimension(str, j));
        }

        /**
         * Adds a new unbound indexed dimension.
         *
         * @param str the dimension name
         * @return this builder
         * @throws IllegalArgumentException if the name is already present or invalid
         */
        public Builder indexed(String str) {
            return add(new IndexedUnboundDimension(str));
        }

        /**
         * Adds a new mapped dimension.
         *
         * @param str the dimension name
         * @return this builder
         * @throws IllegalArgumentException if the name is already present or invalid
         */
        public Builder mapped(String str) {
            return add(new MappedDimension(str));
        }

        /**
         * Adds the given dimension.
         *
         * @param dimension the dimension to add, not null
         * @return this builder
         * @throws IllegalArgumentException if a dimension with the same name is already present
         */
        public Builder dimension(Dimension dimension) {
            return add(dimension);
        }

        /**
         * Returns the dimension currently registered under the given name, if any.
         *
         * @param str the dimension name to look up
         */
        public Optional<Dimension> getDimension(String str) {
            return Optional.ofNullable(this.dimensions.get(str));
        }

        /**
         * Adds a new size-less dimension of the given type.
         *
         * @param str  the dimension name
         * @param type {@code indexedUnbound} or {@code mapped}
         * @return this builder
         * @throws IllegalArgumentException for any other type, or if the name is already present
         */
        public Builder dimension(String str, Dimension.Type type) {
            switch (type) {
                case indexedUnbound:
                    return indexed(str);
                case mapped:
                    return mapped(str);
                default:
                    throw new IllegalArgumentException("This can not create a dimension of type " + type);
            }
        }

        /** Builds a {@link TensorType} from the value type and the dimensions added so far. */
        public TensorType build() {
            return new TensorType(this.valueType, this.dimensions.values());
        }
    }

    /* loaded from: input_file:com/yahoo/tensor/TensorType$Dimension.class */
    /**
     * A named tensor dimension. Concrete kinds are bound indexed, unbound indexed and mapped.
     * Dimensions compare (via {@link #compareTo}) by name only.
     */
    public static abstract class Dimension implements Comparable<Dimension> {
        /** The validated dimension name. */
        private final String name;

        /** The kinds of dimension. */
        public enum Type {
            indexedBound,
            indexedUnbound,
            mapped
        }

        /** Creates a dimension after validating the name with {@link #requireIdentifier}. */
        private Dimension(String str) {
            this.name = requireIdentifier(str);
        }

        /** Returns the name of this dimension. */
        public final String name() {
            return this.name;
        }

        /** Returns the size of this dimension, or empty if it has none. */
        public abstract Optional<Long> size();

        /** Returns the kind of this dimension. */
        public abstract Type type();

        /** Returns a dimension of the same kind as this one but with the given name. */
        public abstract Dimension withName(String str);

        /**
         * Returns a bound indexed dimension with this name and the given size.
         *
         * @param j the size of the returned dimension
         */
        public Dimension withSize(long j) {
            return indexed(this.name, j);
        }

        /** Returns true if this is a bound or unbound indexed dimension. */
        public boolean isIndexed() {
            Type kind = type();
            return kind == Type.indexedBound || kind == Type.indexedUnbound;
        }

        /** Returns true if this is a mapped dimension. */
        public boolean isMapped() {
            return type() == Type.mapped;
        }

        /**
         * Combines this dimension with another one (if present), picking the more general of the two:
         * mapped wins over unbound, which wins over bound.
         *
         * @param optional the dimension to combine with, or empty to keep this unchanged
         * @param z        for two bound dimensions: when true the smaller one is returned, when false
         *                 the sizes must be equal
         * @return the combined dimension
         * @throws IllegalArgumentException if both are bound, sizes differ and {@code z} is false
         */
        Dimension combineWith(Optional<Dimension> optional, boolean z) {
            if (optional.isEmpty() || this instanceof MappedDimension) {
                return this;
            }
            Dimension other = optional.get();
            if (other instanceof MappedDimension) {
                return other;
            }
            if (this instanceof IndexedUnboundDimension) {
                return this;
            }
            if (other instanceof IndexedUnboundDimension) {
                return other;
            }

            // Only the bound/bound case remains.
            IndexedBoundDimension mine = (IndexedBoundDimension) this;
            IndexedBoundDimension theirs = (IndexedBoundDimension) other;
            if (z) {
                boolean mineIsSmaller = mine.size().get().longValue() < theirs.size().get().longValue();
                return mineIsSmaller ? mine : theirs;
            }
            if (!mine.size().equals(theirs.size())) {
                throw new IllegalArgumentException("Unequal dimension sizes in " + mine + " and " + theirs);
            }
            return mine;
        }

        @Override
        public abstract String toString();

        /** Two dimensions are equal when they are of the same class and have the same name. */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (obj == null || obj.getClass() != getClass()) {
                return false;
            }
            Dimension that = (Dimension) obj;
            return this.name.equals(that.name);
        }

        @Override
        public int hashCode() {
            return this.name.hashCode();
        }

        /** Orders dimensions by name. */
        @Override
        public int compareTo(Dimension dimension) {
            return this.name.compareTo(dimension.name);
        }

        /** Creates a bound indexed dimension with the given name and size. */
        public static Dimension indexed(String str, long j) {
            return new IndexedBoundDimension(str, j);
        }

        /** Creates an unbound indexed dimension with the given name. */
        public static Dimension indexed(String str) {
            return new IndexedUnboundDimension(str);
        }

        /** Creates a mapped dimension with the given name. */
        public static Dimension mapped(String str) {
            return new MappedDimension(str);
        }

        /**
         * Returns the given name if it is non-null and accepted by {@code TensorType.labelMatcher}.
         *
         * @throws IllegalArgumentException otherwise
         */
        private static String requireIdentifier(String str) {
            if (str == null) {
                throw new IllegalArgumentException("A dimension name cannot be null");
            }
            if (!TensorType.labelMatcher.matches(str)) {
                throw new IllegalArgumentException("A dimension name must be an identifier or integer, not '" + str + "'");
            }
            return str;
        }
    }

    /* loaded from: input_file:com/yahoo/tensor/TensorType$IndexedBoundDimension.class */
    /** An indexed dimension with a fixed size between 1 and {@code Integer.MAX_VALUE} inclusive. */
    public static class IndexedBoundDimension extends Dimension {
        /** The size of this dimension. */
        private final Long size;

        /**
         * @param str the dimension name
         * @param j   the size, which must be in {@code [1, 2147483647]}
         * @throws IllegalArgumentException if the name is invalid or the size is out of range
         */
        private IndexedBoundDimension(String str, long j) {
            super(str);
            if (j < 1) {
                throw new IllegalArgumentException("Size of bound dimension '" + str + "' must be at least 1");
            }
            if (j > Integer.MAX_VALUE) {
                throw new IllegalArgumentException("Size of bound dimension '" + str + "' cannot be larger than " + Integer.MAX_VALUE);
            }
            this.size = j;
        }

        /** Returns the size of this dimension; always present. */
        @Override
        public Optional<Long> size() {
            return Optional.of(this.size);
        }

        @Override
        public Dimension.Type type() {
            return Dimension.Type.indexedBound;
        }

        /** Returns a bound dimension with the given name and the same size as this one. */
        @Override
        public IndexedBoundDimension withName(String str) {
            return new IndexedBoundDimension(str, this.size.longValue());
        }

        /** Returns {@code name[size]}. */
        @Override
        public String toString() {
            return name() + "[" + this.size + "]";
        }

        /** Equal when the other object is of the same class with the same name and size. */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (obj == null || obj.getClass() != getClass()) {
                return false;
            }
            if (!super.equals(obj)) {
                return false;
            }
            IndexedBoundDimension that = (IndexedBoundDimension) obj;
            return this.size.equals(that.size);
        }

        @Override
        public int hashCode() {
            int nameHash = super.hashCode();
            return 31 * nameHash + this.size.hashCode();
        }
    }

    /* loaded from: input_file:com/yahoo/tensor/TensorType$IndexedUnboundDimension.class */
    /** An indexed dimension without a declared size. */
    public static class IndexedUnboundDimension extends Dimension {
        /** Creates an unbound indexed dimension; the name is validated by the superclass constructor. */
        private IndexedUnboundDimension(String str) {
            super(str);
        }

        /** Always empty: an unbound dimension has no size. */
        @Override // com.yahoo.tensor.TensorType.Dimension
        public Optional<Long> size() {
            return Optional.empty();
        }

        @Override // com.yahoo.tensor.TensorType.Dimension
        public Dimension.Type type() {
            return Dimension.Type.indexedUnbound;
        }

        /** Returns a new unbound indexed dimension with the given name. */
        @Override // com.yahoo.tensor.TensorType.Dimension
        public IndexedUnboundDimension withName(String str) {
            return new IndexedUnboundDimension(str);
        }

        /** Returns {@code name[]}. */
        @Override // com.yahoo.tensor.TensorType.Dimension
        public String toString() {
            return name() + "[]";
        }
    }

    /* loaded from: input_file:com/yahoo/tensor/TensorType$MappedDimension.class */
    /** A mapped (non-indexed) dimension; it has no size. */
    public static class MappedDimension extends Dimension {
        /** Creates a mapped dimension; the name is validated by the superclass constructor. */
        private MappedDimension(String str) {
            super(str);
        }

        /** Always empty: a mapped dimension has no size. */
        @Override // com.yahoo.tensor.TensorType.Dimension
        public Optional<Long> size() {
            return Optional.empty();
        }

        @Override // com.yahoo.tensor.TensorType.Dimension
        public Dimension.Type type() {
            return Dimension.Type.mapped;
        }

        /** Returns a new mapped dimension with the given name. */
        @Override // com.yahoo.tensor.TensorType.Dimension
        public MappedDimension withName(String str) {
            return new MappedDimension(str);
        }

        /** Returns <code>name{}</code>. */
        @Override // com.yahoo.tensor.TensorType.Dimension
        public String toString() {
            return name() + "{}";
        }
    }

    /* loaded from: input_file:com/yahoo/tensor/TensorType$Value.class */
    /**
     * The cell value types a tensor can have, declared from widest (8 bytes per cell) to
     * narrowest (1 byte per cell).
     */
    public enum Value {
        DOUBLE("double"),
        FLOAT("float"),
        BFLOAT16("bfloat16"),
        INT8("int8");

        /** The textual id of this value type, as accepted by {@link #fromId(String)}. */
        private final String id;

        Value(String str) {
            this.id = str;
        }

        /**
         * Returns the number of bytes one cell of this value type occupies.
         *
         * <p>Switches directly on the enum constant rather than through a synthetic ordinal table,
         * so the method has no dependency outside this enum.
         */
        public int sizeOfCell() {
            switch (this) {
                case DOUBLE:
                    return 8;
                case FLOAT:
                    return 4;
                case BFLOAT16:
                    return 2;
                case INT8:
                    return 1;
                default:
                    // Only reachable if a constant is added without updating this switch.
                    throw new IncompatibleClassChangeError();
            }
        }

        /** Returns the textual id of this value type, e.g. {@code "bfloat16"}. */
        public String id() {
            return this.id;
        }

        /**
         * Returns true if this value type is the same as, or wider than, the given one.
         *
         * @param value the value type to compare against
         */
        public boolean isEqualOrLargerThan(Value value) {
            return this == value || largestOf(this, value) == this;
        }

        /**
         * Returns the widest of the given value types.
         *
         * @param list the value types to compare
         * @return the widest element, or {@code DOUBLE} if the list is empty
         */
        public static Value largestOf(List<Value> list) {
            if (list.isEmpty()) {
                return DOUBLE;
            }
            Value largest = null;
            for (Value candidate : list) {
                largest = (largest == null) ? candidate : largestOf(largest, candidate);
            }
            return largest;
        }

        /**
         * Returns the wider of the two given value types.
         *
         * @throws IllegalArgumentException if no ordering can be determined (e.g. a null argument
         *                                  paired with {@code INT8})
         */
        public static Value largestOf(Value value, Value value2) {
            if (value == DOUBLE || value2 == DOUBLE) {
                return DOUBLE;
            }
            if (value == FLOAT || value2 == FLOAT) {
                return FLOAT;
            }
            if (value == BFLOAT16 || value2 == BFLOAT16) {
                return BFLOAT16;
            }
            if (value == INT8 && value2 == INT8) {
                return INT8;
            }
            throw new IllegalArgumentException("Cannot find largest of " + value + " and " + value2);
        }

        /**
         * Returns the lower-cased constant name. Lower-casing uses {@link Locale#ROOT} so the result
         * does not depend on the JVM default locale (a Turkish default locale would otherwise turn
         * "INT8" into "ınt8").
         */
        @Override
        public String toString() {
            return name().toLowerCase(Locale.ROOT);
        }

        /**
         * Returns the value type with the given id.
         *
         * @param str the id to look up, e.g. {@code "float"}
         * @throws IllegalArgumentException if no value type has this id
         */
        public static Value fromId(String str) {
            for (Value candidate : values()) {
                if (candidate.id.equals(str)) {
                    return candidate;
                }
            }
            throw new IllegalArgumentException("Value type must be either 'double', 'float', 'bfloat16', or 'int8' but was '" + str + "'");
        }
    }

    /**
     * Creates the type with no dimensions and value type {@code DOUBLE}; used to initialise
     * {@link #empty}. Both the mapped and the indexed subtype refer back to this instance.
     */
    private TensorType() {
        this.valueType = Value.DOUBLE;
        this.dimensions = List.of();
        this.dimensionNames = Set.of();
        this.mappedSubtype = this;
        this.indexedSubtype = this;
        this.indexedUnBoundCount = 0;
    }

    /**
     * Creates a tensor type with the given cell value type and dimensions.
     *
     * <p>The dimensions are copied and sorted by name. The mapped and indexed subtypes are derived
     * here: when the type has no mapped dimensions the indexed subtype is this instance and the mapped
     * subtype is {@link #empty}; when it has only mapped dimensions it is the other way around;
     * otherwise two new types are created, one per kind.
     *
     * @param value      the cell value type
     * @param collection the dimensions of the type; not modified
     */
    public TensorType(Value value, Collection<Dimension> collection) {
        this.valueType = value;
        List<Dimension> sorted = new ArrayList<>(collection);
        Collections.sort(sorted);
        this.dimensions = List.copyOf(sorted);

        ImmutableSet.Builder<String> names = new ImmutableSet.Builder<>();
        int boundCount = 0;
        int unboundCount = 0;
        int mappedCount = 0;
        for (Dimension dimension : sorted) {
            names.add(dimension.name());
            switch (dimension.type()) {
                case indexedUnbound:
                    unboundCount++;
                    break;
                case indexedBound:
                    boundCount++;
                    break;
                case mapped:
                    mappedCount++;
                    break;
            }
        }
        this.indexedUnBoundCount = unboundCount;
        this.dimensionNames = names.build();

        if (mappedCount == 0) {
            this.mappedSubtype = empty;
            this.indexedSubtype = this;
        } else if (boundCount + unboundCount == 0) {
            this.mappedSubtype = this;
            this.indexedSubtype = empty;
        } else {
            // Mixed type: split the dimensions into one subtype per kind.
            this.mappedSubtype = new TensorType(value, collection.stream().filter(d -> !d.isIndexed()).toList());
            this.indexedSubtype = new TensorType(value, collection.stream().filter(Dimension::isIndexed).toList());
        }
    }

    /**
     * Returns true when the indexed subtype is not the shared {@link #empty} instance (identity check).
     * NOTE(review): a dimensionless type created through the public constructor has
     * {@code indexedSubtype == this}, so it reports true here — confirm this is intended.
     */
    public boolean hasIndexedDimensions() {
        return this.indexedSubtype != empty;
    }

    /** Returns true when the mapped subtype is not the shared {@link #empty} instance (identity check). */
    public boolean hasMappedDimensions() {
        return this.mappedSubtype != empty;
    }

    /** Returns true when this type has neither mapped nor unbound indexed dimensions. */
    public boolean hasOnlyIndexedBoundDimensions() {
        return !hasMappedDimensions() && !hasIndexedUnboundDimensions();
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /** Returns true when at least one dimension of this type is of kind {@code indexedUnbound}. */
    public boolean hasIndexedUnboundDimensions() {
        return this.indexedUnBoundCount > 0;
    }

    /**
     * Returns the widest value type among the given types, ignoring types without dimensions.
     *
     * @param tensorTypeArr the types whose value types to combine
     * @return the widest value type, or {@code DOUBLE} when no argument has any dimension
     */
    public static Value combinedValueType(TensorType... tensorTypeArr) {
        List<Value> cellTypes = new ArrayList<>();
        for (TensorType candidate : tensorTypeArr) {
            if (candidate.rank() <= 0) {
                continue; // dimensionless types do not influence the result
            }
            cellTypes.add(candidate.valueType());
        }
        return Value.largestOf(cellTypes);
    }

    /**
     * Creates a tensor type from its string specification by delegating to
     * {@code TensorTypeParser.fromSpec}.
     *
     * @param str the type specification to parse
     */
    public static TensorType fromSpec(String str) {
        return TensorTypeParser.fromSpec(str);
    }

    /** Returns the cell value type of this tensor type. */
    public Value valueType() {
        return this.valueType;
    }

    /** Returns the type made from the mapped dimensions of this type, as set up by the constructor. */
    public TensorType mappedSubtype() {
        return this.mappedSubtype;
    }

    /** Returns the type made from the indexed dimensions of this type, as set up by the constructor. */
    public TensorType indexedSubtype() {
        return this.indexedSubtype;
    }

    /** Returns the number of dimensions of this type. */
    public int rank() {
        return this.dimensions.size();
    }

    /** Returns the immutable list of dimensions of this type, sorted by dimension name. */
    public List<Dimension> dimensions() {
        return this.dimensions;
    }

    /** Returns the names of the dimensions of this type. */
    public Set<String> dimensionNames() {
        return this.dimensionNames;
    }

    /**
     * Returns the dimension with the given name, or empty if this type has no such dimension.
     *
     * @param str the dimension name to look up
     */
    public Optional<Dimension> dimension(String str) {
        return indexOfDimension(str).map(this.dimensions::get);
    }

    /**
     * Returns the position of the dimension with the given name in {@link #dimensions()}, or empty
     * if this type has no such dimension.
     *
     * @param str the dimension name to look up
     */
    public Optional<Integer> indexOfDimension(String str) {
        int position = 0;
        for (Dimension candidate : this.dimensions) {
            if (candidate.name().equals(str)) {
                return Optional.of(position);
            }
            position++;
        }
        return Optional.empty();
    }

    /**
     * Returns the position of the dimension with the given name in {@link #dimensions()}, or -1
     * if this type has no such dimension.
     *
     * @param str the dimension name to look up
     */
    public int indexOfDimensionAsInt(String str) {
        int position = 0;
        for (Dimension candidate : this.dimensions) {
            if (candidate.name().equals(str)) {
                return position;
            }
            position++;
        }
        return -1;
    }

    /**
     * Returns the size of the dimension with the given name; empty if there is no such dimension
     * or the dimension has no size.
     *
     * @param str the dimension name to look up
     */
    public Optional<Long> sizeOfDimension(String str) {
        return dimension(str).flatMap(Dimension::size);
    }

    /**
     * Returns whether this type is assignable to the given type: the target value type must be equal
     * or larger, and the dimensions must match position by position in name and kind (indexed vs.
     * not); wherever the target dimension has a size, this dimension must have exactly that size.
     *
     * @param tensorType the target type
     */
    public boolean isAssignableTo(TensorType tensorType) {
        return isConvertibleOrAssignableTo(tensorType, false, true);
    }

    /**
     * Returns whether this type is convertible to the given type: same checks as
     * {@link #isAssignableTo}, except that where the target dimension has a size, this dimension's
     * size may be smaller than or equal to it.
     *
     * @param tensorType the target type
     */
    public boolean isConvertibleTo(TensorType tensorType) {
        return isConvertibleOrAssignableTo(tensorType, true, true);
    }

    /**
     * Returns whether this type can be renamed to the given type: same checks as
     * {@link #isAssignableTo}, except that dimension names are not compared (dimensions are still
     * matched position by position).
     *
     * @param tensorType the target type
     */
    public boolean isRenamableTo(TensorType tensorType) {
        return isConvertibleOrAssignableTo(tensorType, false, false);
    }

    /**
     * Shared implementation of the assignable/convertible/renamable checks.
     *
     * @param tensorType the target type
     * @param z          when true, a sized dimension of this type may be smaller than the target's;
     *                   when false the sizes must be equal
     * @param z2         when true, dimension names must match position by position
     * @return true if every check passes
     */
    private boolean isConvertibleOrAssignableTo(TensorType tensorType, boolean z, boolean z2) {
        if (!tensorType.valueType().isEqualOrLargerThan(this.valueType)) {
            return false;
        }
        List<Dimension> own = dimensions();
        List<Dimension> target = tensorType.dimensions();
        if (target.size() != own.size()) {
            return false;
        }
        for (int i = 0; i < target.size(); i++) {
            Dimension from = own.get(i);
            Dimension to = target.get(i);
            if (from.isIndexed() != to.isIndexed()) {
                return false;
            }
            if (z2 && !from.name().equals(to.name())) {
                return false;
            }
            Optional<Long> toSize = to.size();
            if (toSize.isEmpty()) {
                continue; // a size-less target dimension accepts any size
            }
            Optional<Long> fromSize = from.size();
            if (fromSize.isEmpty()) {
                return false;
            }
            boolean sizeFits = z
                    ? fromSize.get().longValue() <= toSize.get().longValue()
                    : fromSize.equals(toSize);
            if (!sizeFits) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the textual form {@code tensor<valueType>(dim1,dim2,...)}; the {@code <valueType>}
     * part is left out when the value type is {@code DOUBLE}.
     */
    @Override
    public String toString() {
        String cellType = this.valueType == Value.DOUBLE ? "" : "<" + this.valueType.id() + ">";
        String dimensionList = this.dimensions.stream()
                .map(Dimension::toString)
                .collect(Collectors.joining(","));
        return "tensor" + cellType + "(" + dimensionList + ")";
    }

    /**
     * Two tensor types are equal when both have no dimensions, or when they have the same value type
     * and equal dimension lists.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (obj == null || obj.getClass() != getClass()) {
            return false;
        }
        TensorType that = (TensorType) obj;
        // Dimensionless types are equal regardless of their value type.
        boolean bothDimensionless = rank() == 0 && that.rank() == 0;
        return bothDimensionless
                || (this.valueType == that.valueType && this.dimensions.equals(that.dimensions));
    }

    /**
     * Returns true when both types have the same number of dimensions and the same dimension names
     * position by position; dimension kinds, sizes and value types are not compared.
     *
     * @param tensorType the type to compare with
     */
    public boolean mathematicallyEquals(TensorType tensorType) {
        List<Dimension> mine = dimensions();
        List<Dimension> theirs = tensorType.dimensions();
        if (mine.size() != theirs.size()) {
            return false;
        }
        Iterator<Dimension> other = theirs.iterator();
        for (Dimension own : mine) {
            if (!own.name().equals(other.next().name())) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the type that generalizes both this and the given type dimension by dimension, if one
     * exists.
     *
     * <p>Both types must have the same dimension names and the same kind (indexed vs. mapped) at each
     * position. Two sized indexed dimensions must have equal sizes; if only one of them is sized, the
     * size-less one is used. The value type of the result is the larger of the two value types.
     *
     * @param tensorType the type to generalize with
     * @return the generalized type, this instance if the types are equal, or empty if the types
     *         cannot be generalized
     */
    public Optional<TensorType> dimensionwiseGeneralizationWith(TensorType tensorType) {
        if (equals(tensorType)) {
            return Optional.of(this);
        }
        if (this.dimensions.size() != tensorType.dimensions.size()) {
            return Optional.empty();
        }

        Builder generalized = new Builder(Value.largestOf(this.valueType, tensorType.valueType));
        for (int i = 0; i < this.dimensions.size(); i++) {
            Dimension mine = dimensions().get(i);
            Dimension theirs = tensorType.dimensions().get(i);
            if (!mine.name().equals(theirs.name())) {
                return Optional.empty();
            }
            if (mine.isIndexed() != theirs.isIndexed()) {
                return Optional.empty(); // one indexed, one mapped
            }
            if (!mine.isIndexed()) {
                generalized.dimension(mine); // both mapped
                continue;
            }

            // Both indexed from here on.
            boolean mineSized = mine.size().isPresent();
            boolean theirsSized = theirs.size().isPresent();
            if (mineSized && theirsSized) {
                if (!mine.size().equals(theirs.size())) {
                    return Optional.empty();
                }
                generalized.dimension(mine);
            } else if (mineSized) {
                generalized.dimension(theirs); // theirs is the size-less one
            } else {
                generalized.dimension(mine);
            }
        }
        return Optional.of(generalized.build());
    }

    /**
     * Returns a hash code consistent with {@link #equals(Object)}.
     *
     * <p>{@code equals} treats all dimensionless types as equal regardless of their value type, so the
     * value type only takes part in the hash when the type has dimensions; dimensionless types hash
     * as if their value type were {@code DOUBLE}.
     */
    @Override
    public int hashCode() {
        Value hashedValueType = this.dimensions.isEmpty() ? Value.DOUBLE : this.valueType;
        return Objects.hash(this.dimensions, hashedValueType);
    }
}
