package com.yahoo.tensor;

import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.UnmodifiableIterator;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * The type of a tensor: a set of dimensions, each identified by a unique name and being either
 * mapped ({@code name{}}), indexed without a size bound ({@code name[]}), or indexed with a size
 * bound ({@code name[size]}).
 *
 * <p>Instances are immutable. Dimensions are kept sorted by name, so types built from the same
 * dimensions are equal regardless of the order in which those dimensions were specified.
 */
@Beta
public class TensorType {

    /** The type of a tensor with no dimensions. */
    public static final TensorType empty = new TensorType(Collections.emptyList());

    /** The dimensions of this type, sorted by name. */
    private final ImmutableList<Dimension> dimensions;

    /**
     * A mutable builder of tensor types. Dimensions are kept in insertion order while building and
     * are sorted by name when {@link #build()} creates the type.
     */
    public static class Builder {

        /** The dimensions added so far, keyed by dimension name. */
        private final Map<String, Dimension> dimensions = new LinkedHashMap<>();

        /** Creates a builder with no dimensions. */
        public Builder() {
        }

        /**
         * Creates a builder containing the combination of the dimensions of the given types.
         *
         * <p>A dimension present in several types is combined using {@link Dimension#combineWith}:
         * two bound sizes combine to the smaller one, bound and unbound combine to unbound, and
         * indexed and mapped combine to mapped. Mixing mapped and indexed dimensions is not supported:
         * if any of the combined dimensions is mapped, every dimension of the result becomes mapped.
         *
         * @param types the types whose dimensions should be combined
         */
        public Builder(TensorType... types) {
            for (TensorType type : types) {
                addDimensionsOf(type);
            }
        }

        /** Adds the dimensions of the given type. Types mixing mapped and indexed dimensions are not supported. */
        private void addDimensionsOf(TensorType type) {
            addDimensionsOfAndDisallowMixedDimensions(type);
        }

        /**
         * Adds the dimensions of the given type, combining each with any existing dimension of the same name.
         * If either the dimensions already in this builder or those of the given type include a mapped one,
         * all dimensions (existing and added) are made mapped so the result never mixes the two kinds.
         */
        private void addDimensionsOfAndDisallowMixedDimensions(TensorType type) {
            boolean anyMapped = containsMapped(dimensions.values()) || containsMapped(type.dimensions());
            if (anyMapped) {
                // Demote previously added indexed dimensions too; otherwise the outcome depends on argument
                // order: (x[3], y{}) would give the mixed tensor(x[3],y{}) while (y{}, x[3]) gives tensor(x{},y{}).
                dimensions.replaceAll((name, existing) -> existing.isIndexed() ? new MappedDimension(name) : existing);
            }
            for (Dimension dimension : type.dimensions) {
                Dimension toAdd = anyMapped ? new MappedDimension(dimension.name()) : dimension;
                set(toAdd.combineWith(Optional.ofNullable(dimensions.get(toAdd.name()))));
            }
        }

        /** Returns whether any of the given dimensions is mapped (i.e. not indexed). */
        private static boolean containsMapped(Collection<Dimension> candidates) {
            return candidates.stream().anyMatch(candidate -> ! candidate.isIndexed());
        }

        /**
         * Adds a new dimension.
         *
         * @throws NullPointerException if the dimension is null
         * @throws IllegalArgumentException if a dimension with the same name is already present
         */
        private Builder add(Dimension dimension) {
            Objects.requireNonNull(dimension, "A dimension cannot be null");
            if (dimensions.containsKey(dimension.name())) {
                throw new IllegalArgumentException("Could not add dimension " + dimension + " as this dimension is already present");
            }
            dimensions.put(dimension.name(), dimension);
            return this;
        }

        /**
         * Adds the given dimension, replacing any existing dimension with the same name.
         *
         * @throws NullPointerException if the dimension is null
         */
        public Builder set(Dimension dimension) {
            Objects.requireNonNull(dimension, "A dimension cannot be null");
            dimensions.put(dimension.name(), dimension);
            return this;
        }

        /**
         * Adds an indexed dimension with the given size bound.
         *
         * @throws IllegalArgumentException if the name is already present or the size is less than 1
         */
        public Builder indexed(String name, int size) {
            return add(new IndexedBoundDimension(name, size));
        }

        /**
         * Adds an indexed dimension without a size bound.
         *
         * @throws IllegalArgumentException if the name is already present
         */
        public Builder indexed(String name) {
            return add(new IndexedUnboundDimension(name));
        }

        /**
         * Adds a mapped dimension.
         *
         * @throws IllegalArgumentException if the name is already present
         */
        public Builder mapped(String name) {
            return add(new MappedDimension(name));
        }

        /**
         * Adds the given dimension.
         *
         * @throws IllegalArgumentException if a dimension with the same name is already present
         */
        public Builder dimension(Dimension dimension) {
            return add(dimension);
        }

        /** Returns the dimension with the given name added to this builder, or empty if none. */
        public Optional<Dimension> getDimension(String name) {
            return Optional.ofNullable(dimensions.get(name));
        }

        /**
         * Adds a dimension of the given type. Only {@code mapped} and {@code indexedUnbound} are supported,
         * as a bound dimension also needs a size.
         *
         * @throws IllegalArgumentException if the type is not supported or the name is already present
         */
        public Builder dimension(String name, Dimension.Type type) {
            switch (type) {
                case mapped:
                    return mapped(name);
                case indexedUnbound:
                    return indexed(name);
                default:
                    throw new IllegalArgumentException("This can not create a dimension of type " + type);
            }
        }

        /** Creates a tensor type from the dimensions added so far. */
        public TensorType build() {
            return new TensorType(dimensions.values());
        }
    }

    /**
     * A dimension of a tensor type. Equality is by concrete class and name (plus size for bound dimensions).
     * Note that {@link #compareTo} orders by name only, so it is not consistent with {@link #equals} for
     * same-named dimensions of different kinds.
     */
    public static abstract class Dimension implements Comparable<Dimension> {

        /** The kinds of dimension. */
        public enum Type {
            indexedBound,
            indexedUnbound,
            mapped
        }

        private final String name;

        private Dimension(String name) {
            this.name = Objects.requireNonNull(name, "A tensor name cannot be null");
        }

        /** Returns the name of this dimension. */
        public final String name() {
            return name;
        }

        /** Returns the size bound of this dimension, or empty if it has none. */
        public abstract Optional<Integer> size();

        /** Returns the kind of this dimension. */
        public abstract Type type();

        /** Returns a dimension of the same kind (and size) as this but with the given name. */
        public abstract Dimension withName(String name);

        /** Returns whether this dimension is indexed, bound or unbound. */
        public boolean isIndexed() {
            return type() == Type.indexedBound || type() == Type.indexedUnbound;
        }

        /**
         * Returns the result of combining this with another dimension assumed to have the same name,
         * degrading to whichever makes the fewer promises:
         * <ul>
         *   <li>no other dimension: this</li>
         *   <li>either is mapped: the mapped one (this, if both are)</li>
         *   <li>either is indexed unbound: the unbound one (this, if both are)</li>
         *   <li>both are indexed bound: the one with the smaller size (the other one on ties)</li>
         * </ul>
         */
        Dimension combineWith(Optional<Dimension> other) {
            if ( ! other.isPresent() || this instanceof MappedDimension) {
                return this;
            }
            Dimension that = other.get();
            if (that instanceof MappedDimension) {
                return that;
            }
            if (this instanceof IndexedUnboundDimension) {
                return this;
            }
            if (that instanceof IndexedUnboundDimension) {
                return that;
            }
            // Both are indexed bound: keep the tighter bound
            int thisSize = ((IndexedBoundDimension) this).size().get();
            int thatSize = ((IndexedBoundDimension) that).size().get();
            return thisSize < thatSize ? this : that;
        }

        /** Returns the type-spec form of this dimension, e.g. {@code x{}}, {@code x[]} or {@code x[3]}. */
        @Override
        public abstract String toString();

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            return name.equals(((Dimension) obj).name);
        }

        @Override
        public int hashCode() {
            return name.hashCode();
        }

        /** Orders dimensions by name only. */
        @Override
        public int compareTo(Dimension other) {
            return name.compareTo(other.name);
        }

        /**
         * Returns an indexed dimension with the given size bound.
         *
         * @throws IllegalArgumentException if the size is less than 1
         */
        public static Dimension indexed(String name, int size) {
            return new IndexedBoundDimension(name, size);
        }
    }

    /** An indexed dimension with a size bound, written {@code name[size]}. */
    public static class IndexedBoundDimension extends Dimension {

        private final int size;

        private IndexedBoundDimension(String name, int size) {
            super(name);
            if (size < 1) {
                throw new IllegalArgumentException("Size of bound dimension '" + name + "' must be at least 1");
            }
            this.size = size;
        }

        @Override
        public Optional<Integer> size() {
            return Optional.of(size);
        }

        @Override
        public Dimension.Type type() {
            return Dimension.Type.indexedBound;
        }

        @Override
        public IndexedBoundDimension withName(String name) {
            return new IndexedBoundDimension(name, size);
        }

        @Override
        public String toString() {
            return name() + "[" + size + "]";
        }

        @Override
        public boolean equals(Object obj) {
            // super.equals checks identity, null, exact class and name, so the cast below is safe
            return super.equals(obj) && size == ((IndexedBoundDimension) obj).size;
        }

        @Override
        public int hashCode() {
            return 31 * super.hashCode() + Integer.hashCode(size);
        }
    }

    /** An indexed dimension without a size bound, written {@code name[]}. */
    public static class IndexedUnboundDimension extends Dimension {

        private IndexedUnboundDimension(String name) {
            super(name);
        }

        @Override
        public Optional<Integer> size() {
            return Optional.empty();
        }

        @Override
        public Dimension.Type type() {
            return Dimension.Type.indexedUnbound;
        }

        @Override
        public IndexedUnboundDimension withName(String name) {
            return new IndexedUnboundDimension(name);
        }

        @Override
        public String toString() {
            return name() + "[]";
        }
    }

    /** A mapped dimension, written {@code name{}}. */
    public static class MappedDimension extends Dimension {

        private MappedDimension(String name) {
            super(name);
        }

        @Override
        public Optional<Integer> size() {
            return Optional.empty();
        }

        @Override
        public Dimension.Type type() {
            return Dimension.Type.mapped;
        }

        @Override
        public MappedDimension withName(String name) {
            return new MappedDimension(name);
        }

        @Override
        public String toString() {
            return name() + "{}";
        }
    }

    /** Creates a type from the given dimensions, stored as an immutable list sorted by name. */
    private TensorType(Collection<Dimension> dimensions) {
        List<Dimension> sorted = new ArrayList<>(dimensions);
        Collections.sort(sorted);
        this.dimensions = ImmutableList.copyOf(sorted);
    }

    /** Parses a tensor type from its string form, e.g. {@code tensor(x[3],y[])}, via {@code TensorTypeParser}. */
    public static TensorType fromSpec(String specString) {
        return TensorTypeParser.fromSpec(specString);
    }

    /** Returns an immutable list of the dimensions of this type, sorted by name. */
    public List<Dimension> dimensions() {
        return dimensions;
    }

    /** Returns a new set containing the names of the dimensions of this type. */
    public Set<String> dimensionNames() {
        return dimensions.stream().map(Dimension::name).collect(Collectors.toSet());
    }

    /** Returns the dimension with the given name, or empty if this type has no such dimension. */
    public Optional<Dimension> dimension(String name) {
        return indexOfDimension(name).map(dimensions::get);
    }

    /** Returns the position of the dimension with the given name in {@link #dimensions()}, or empty if absent. */
    public Optional<Integer> indexOfDimension(String name) {
        for (int i = 0; i < dimensions.size(); i++) {
            if (dimensions.get(i).name().equals(name)) {
                return Optional.of(i);
            }
        }
        return Optional.empty();
    }

    /**
     * Returns whether a tensor of this type can be assigned to the given type. This holds when both
     * have the same dimension names, each dimension is indexed in both or mapped in both, and wherever
     * the given type has a size bound, this type has a bound no larger than it.
     */
    public boolean isAssignableTo(TensorType generalization) {
        if (generalization.dimensions().size() != dimensions().size()) {
            return false;
        }
        for (int i = 0; i < generalization.dimensions().size(); i++) {
            Dimension thisDimension = dimensions().get(i);
            Dimension generalDimension = generalization.dimensions().get(i);
            if (thisDimension.isIndexed() != generalDimension.isIndexed()) {
                return false;
            }
            if ( ! thisDimension.name().equals(generalDimension.name())) {
                return false;
            }
            if (generalDimension.size().isPresent()) {
                if ( ! thisDimension.size().isPresent()) {
                    return false;
                }
                if (thisDimension.size().get() > generalDimension.size().get()) {
                    return false;
                }
            }
        }
        return true;
    }

    /** Returns the type-spec form of this type, e.g. {@code tensor(x[3],y{})}. */
    @Override
    public String toString() {
        return "tensor(" + dimensions.stream().map(Dimension::toString).collect(Collectors.joining(",")) + ")";
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        return dimensions.equals(((TensorType) obj).dimensions);
    }

    /**
     * Returns whether this type has the same dimension names as the given one, ignoring the kind
     * (mapped, indexed) and size bound of each dimension.
     */
    public boolean mathematicallyEquals(TensorType other) {
        if (dimensions().size() != other.dimensions().size()) {
            return false;
        }
        for (int i = 0; i < dimensions().size(); i++) {
            if ( ! dimensions().get(i).name().equals(other.dimensions().get(i).name())) {
                return false;
            }
        }
        return true;
    }

    @Override
    public int hashCode() {
        return dimensions.hashCode();
    }
}
