package com.yahoo.tensor;

import com.yahoo.tensor.TensorType;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;

/* loaded from: input_file:com/yahoo/tensor/TypeResolver.class */
public class TypeResolver {
    private static final Logger logger = Logger.getLogger(TypeResolver.class.getName());

    private static TensorType scalar() {
        return TensorType.empty;
    }

    public static TensorType map(TensorType tensorType) {
        TensorType.Value valueType = tensorType.valueType();
        TensorType.Value largestOf = TensorType.Value.largestOf(valueType, TensorType.Value.FLOAT);
        return largestOf == valueType ? tensorType : new TensorType(largestOf, tensorType.dimensions());
    }

    public static TensorType reduce(TensorType tensorType, List<String> list) {
        if (list.isEmpty()) {
            return scalar();
        }
        HashMap hashMap = new HashMap();
        for (TensorType.Dimension dimension : tensorType.dimensions()) {
            hashMap.put(dimension.name(), dimension);
        }
        for (String str : list) {
            if (hashMap.containsKey(str)) {
                hashMap.remove(str);
            } else {
                logger.log(Level.WARNING, "reducing non-existing dimension " + str + " in type " + tensorType);
            }
        }
        return hashMap.isEmpty() ? scalar() : new TensorType(TensorType.Value.largestOf(tensorType.valueType(), TensorType.Value.FLOAT), hashMap.values());
    }

    public static TensorType peek(TensorType tensorType, List<String> list) {
        if (list.isEmpty()) {
            throw new IllegalArgumentException("Peeking no dimensions makes no sense");
        }
        HashMap hashMap = new HashMap();
        for (TensorType.Dimension dimension : tensorType.dimensions()) {
            hashMap.put(dimension.name(), dimension);
        }
        for (String str : list) {
            if (!hashMap.containsKey(str)) {
                throw new IllegalArgumentException("Peeking non-existing dimension '" + str + "'");
            }
            hashMap.remove(str);
        }
        return hashMap.isEmpty() ? scalar() : new TensorType(tensorType.valueType(), hashMap.values());
    }

    public static TensorType rename(TensorType tensorType, List<String> list, List<String> list2) {
        if (list.isEmpty()) {
            throw new IllegalArgumentException("Renaming no dimensions");
        }
        if (list.size() != list2.size()) {
            throw new IllegalArgumentException("Bad rename, from size " + list.size() + " != to.size " + list2.size());
        }
        HashMap hashMap = new HashMap();
        for (TensorType.Dimension dimension : tensorType.dimensions()) {
            hashMap.put(dimension.name(), dimension);
        }
        HashMap hashMap2 = new HashMap();
        for (int i = 0; i < list.size(); i++) {
            String str = list.get(i);
            String str2 = list2.get(i);
            if (hashMap.containsKey(str)) {
                hashMap2.put(str2, ((TensorType.Dimension) hashMap.remove(str)).withName(str2));
            } else {
                logger.log(Level.WARNING, "Renaming non-existing dimension " + str + " in type " + tensorType);
            }
        }
        for (TensorType.Dimension dimension2 : hashMap.values()) {
            hashMap2.put(dimension2.name(), dimension2);
        }
        if (tensorType.dimensions().size() == hashMap2.size()) {
            return new TensorType(tensorType.valueType(), hashMap2.values());
        }
        throw new IllegalArgumentException("Bad rename, lost some dimensions");
    }

    public static TensorType cell_cast(TensorType tensorType, TensorType.Value value) {
        if (value == TensorType.Value.DOUBLE || !tensorType.dimensions().isEmpty()) {
            return new TensorType(value, tensorType.dimensions());
        }
        throw new IllegalArgumentException("Cannot cast " + tensorType + " to valueType" + value);
    }

    private static boolean firstIsBoundSecond(TensorType.Dimension dimension, TensorType.Dimension dimension2) {
        return dimension.type() == TensorType.Dimension.Type.indexedBound && dimension2.type() == TensorType.Dimension.Type.indexedUnbound && dimension.name().equals(dimension2.name());
    }

    private static boolean firstIsSmaller(TensorType.Dimension dimension, TensorType.Dimension dimension2) {
        return dimension.type() == TensorType.Dimension.Type.indexedBound && dimension2.type() == TensorType.Dimension.Type.indexedBound && dimension.name().equals(dimension2.name()) && dimension.size().isPresent() && dimension2.size().isPresent() && dimension.size().get().longValue() < dimension2.size().get().longValue();
    }

    public static TensorType join(TensorType tensorType, TensorType tensorType2) {
        TensorType.Value value = TensorType.Value.DOUBLE;
        if (tensorType.rank() > 0 && tensorType2.rank() > 0) {
            value = TensorType.Value.largestOf(tensorType.valueType(), tensorType2.valueType());
        } else if (tensorType.rank() > 0) {
            value = tensorType.valueType();
        } else if (tensorType2.rank() > 0) {
            value = tensorType2.valueType();
        }
        TensorType.Value largestOf = TensorType.Value.largestOf(value, TensorType.Value.FLOAT);
        HashMap hashMap = new HashMap();
        for (TensorType.Dimension dimension : tensorType.dimensions()) {
            hashMap.put(dimension.name(), dimension);
        }
        for (TensorType.Dimension dimension2 : tensorType2.dimensions()) {
            if (hashMap.containsKey(dimension2.name())) {
                TensorType.Dimension dimension3 = (TensorType.Dimension) hashMap.get(dimension2.name());
                if (dimension3.equals(dimension2)) {
                    continue;
                } else if (firstIsBoundSecond(dimension2, dimension3)) {
                    hashMap.put(dimension2.name(), dimension2);
                } else if (firstIsBoundSecond(dimension3, dimension2)) {
                    hashMap.put(dimension2.name(), dimension3);
                } else if (dimension2.isMapped() && dimension3.isIndexed()) {
                    hashMap.put(dimension2.name(), dimension2);
                } else {
                    if (!dimension2.isIndexed() || !dimension3.isMapped()) {
                        throw new IllegalArgumentException("Unequal dimension " + dimension2.name() + " in " + tensorType + " and " + tensorType2);
                    }
                    hashMap.put(dimension2.name(), dimension3);
                }
            } else {
                hashMap.put(dimension2.name(), dimension2);
            }
        }
        return new TensorType(largestOf, hashMap.values());
    }

    public static TensorType merge(TensorType tensorType, TensorType tensorType2) {
        int size = tensorType.dimensions().size();
        boolean z = tensorType2.dimensions().size() == size;
        if (z) {
            for (int i = 0; i < size; i++) {
                if (!tensorType.dimensions().get(i).name().equals(tensorType2.dimensions().get(i).name())) {
                    z = false;
                }
            }
        }
        if (z) {
            return join(tensorType, tensorType2);
        }
        throw new IllegalArgumentException("Types in merge() dimensions mismatch: " + tensorType + " != " + tensorType2);
    }

    public static TensorType concat(TensorType tensorType, TensorType tensorType2, String str) {
        TensorType.Value value = TensorType.Value.DOUBLE;
        if (tensorType.rank() > 0 && tensorType2.rank() > 0) {
            value = tensorType.valueType() == tensorType2.valueType() ? tensorType.valueType() : TensorType.Value.largestOf(TensorType.Value.largestOf(tensorType.valueType(), tensorType2.valueType()), TensorType.Value.FLOAT);
        } else if (tensorType.rank() > 0) {
            value = tensorType.valueType();
        } else if (tensorType2.rank() > 0) {
            value = tensorType2.valueType();
        }
        TensorType.Dimension indexed = TensorType.Dimension.indexed(str, 1L);
        TensorType.Dimension indexed2 = TensorType.Dimension.indexed(str, 1L);
        HashMap hashMap = new HashMap();
        for (TensorType.Dimension dimension : tensorType.dimensions()) {
            if (dimension.name().equals(str)) {
                indexed = dimension;
            } else {
                hashMap.put(dimension.name(), dimension);
            }
        }
        for (TensorType.Dimension dimension2 : tensorType2.dimensions()) {
            if (dimension2.name().equals(str)) {
                indexed2 = dimension2;
            } else if (hashMap.containsKey(dimension2.name())) {
                TensorType.Dimension dimension3 = (TensorType.Dimension) hashMap.get(dimension2.name());
                if (dimension3.equals(dimension2)) {
                    continue;
                } else if (firstIsBoundSecond(dimension2, dimension3)) {
                    hashMap.put(dimension2.name(), dimension3);
                } else if (firstIsBoundSecond(dimension3, dimension2)) {
                    hashMap.put(dimension2.name(), dimension2);
                } else if (firstIsSmaller(dimension2, dimension3)) {
                    hashMap.put(dimension2.name(), dimension2);
                } else {
                    if (!firstIsSmaller(dimension3, dimension2)) {
                        throw new IllegalArgumentException("Unequal dimension " + dimension2.name() + " in " + tensorType + " and " + tensorType2);
                    }
                    hashMap.put(dimension2.name(), dimension3);
                }
            } else {
                hashMap.put(dimension2.name(), dimension2);
            }
        }
        if (indexed.type() == TensorType.Dimension.Type.mapped) {
            throw new IllegalArgumentException("Bad concat dimension " + str + " in lhs: " + tensorType);
        }
        if (indexed2.type() == TensorType.Dimension.Type.mapped) {
            throw new IllegalArgumentException("Bad concat dimension " + str + " in rhs: " + tensorType2);
        }
        if (indexed.type() == TensorType.Dimension.Type.indexedUnbound) {
            hashMap.put(str, indexed);
        } else if (indexed2.type() == TensorType.Dimension.Type.indexedUnbound) {
            hashMap.put(str, indexed2);
        } else {
            hashMap.put(str, TensorType.Dimension.indexed(str, indexed.size().get().longValue() + indexed2.size().get().longValue()));
        }
        return new TensorType(value, hashMap.values());
    }
}
