package com.yahoo.schema;

import com.google.common.collect.ImmutableMap;
import com.yahoo.schema.expressiontransforms.OnnxModelTransformer;
import com.yahoo.schema.expressiontransforms.TokenTransformer;
import com.yahoo.searchlib.ranking.features.FeatureNames;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.FunctionReferenceContext;
import com.yahoo.searchlib.rankingexpression.rule.NameNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.TypeContext;
import java.util.ArrayDeque;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.stream.Collectors;

/* loaded from: input_file:com/yahoo/schema/MapEvaluationTypeContext.class */
/**
 * A type context which resolves the tensor types of rank feature references.
 * <p>
 * Types are taken from a map of feature types, computed from the bodies of the functions known to this
 * context, or derived from the arguments of the onnx/onnxModel, tokenTypeIds/tokenInputIds/tokenAttentionMask,
 * tensorFromLabels, tensorFromWeightedSet, elementwise and closest features.
 * <p>
 * Contexts created by {@link #withBindings(Map)} get their own copy of the feature types. They share the
 * following with the context they were created from:
 * <ul>
 *   <li>the resolution call stack,</li>
 *   <li>the set of undeclared query features,</li>
 *   <li>the cache of globally resolved types.</li>
 * </ul>
 * Tensor usage detected in a derived context is also recorded on its ancestors.
 * <p>
 * This class is not synchronized.
 */
public class MapEvaluationTypeContext extends FunctionReferenceContext implements TypeContext<Reference> {

    /** The context this was created from by {@link #withBindings(Map)}, or empty for a root context. */
    private final Optional<MapEvaluationTypeContext> parent;

    /** Declared feature types; each context owns its own copy. */
    private final Map<Reference, TensorType> featureTypes;

    /** Types resolved by this context. Cleared by {@link #forgetResolvedTypes()}. */
    private final Map<Reference, TensorType> resolvedTypes;

    /** Resolved types of zero-argument function invocations, shared by all contexts derived from one root. */
    private final Map<Reference, TensorType> globallyResolvedTypes;

    /** References under resolution, innermost last; shared with derived contexts to detect invocation loops. */
    private final Deque<Reference> currentResolutionCallStack;

    /** Query features referenced without a declared type; shared with derived contexts. */
    private final SortedSet<Reference> queryFeaturesNotDeclared;

    /** Whether some reference resolved to a type of rank > 0 in this context or a context derived from it. */
    private boolean tensorsAreUsed;

    /**
     * Creates a root context.
     *
     * @param functions    the functions which may be invoked from expressions typed in this context
     * @param featureTypes the initial feature types; copied, so later changes to the given map are not seen
     */
    public MapEvaluationTypeContext(ImmutableMap<String, ExpressionFunction> functions, Map<Reference, TensorType> featureTypes) {
        super(functions);
        this.parent = Optional.empty();
        this.featureTypes = new HashMap<>(featureTypes);
        this.resolvedTypes = new HashMap<>();
        this.globallyResolvedTypes = new HashMap<>();
        this.currentResolutionCallStack = new ArrayDeque<>();
        this.queryFeaturesNotDeclared = new TreeSet<>();
        this.tensorsAreUsed = false;
    }

    /** Creates a context derived from another; see {@link #withBindings(Map)}. */
    private MapEvaluationTypeContext(Map<String, ExpressionFunction> functions,
                                     Map<String, String> bindings,
                                     Optional<MapEvaluationTypeContext> parent,
                                     Map<Reference, TensorType> featureTypes,
                                     Deque<Reference> currentResolutionCallStack,
                                     SortedSet<Reference> queryFeaturesNotDeclared,
                                     boolean tensorsAreUsed,
                                     Map<Reference, TensorType> globallyResolvedTypes) {
        super(functions, bindings);
        this.parent = parent;
        this.featureTypes = new HashMap<>(featureTypes);
        this.resolvedTypes = new HashMap<>();
        this.globallyResolvedTypes = globallyResolvedTypes;
        this.currentResolutionCallStack = currentResolutionCallStack;
        this.queryFeaturesNotDeclared = queryFeaturesNotDeclared;
        this.tensorsAreUsed = tensorsAreUsed;
    }

    /** Declares the type of a feature, and removes it from the undeclared query features if it was there. */
    public void setType(Reference reference, TensorType type) {
        featureTypes.put(reference, type);
        queryFeaturesNotDeclared.remove(reference);
    }

    /** Returns an unmodifiable view of the feature types of this context. */
    public Map<Reference, TensorType> featureTypes() {
        return Collections.unmodifiableMap(featureTypes);
    }

    /** Not supported: references must be given as {@link Reference} instances. */
    public TensorType getType(String reference) {
        throw new UnsupportedOperationException("Not able to parse general references from string form");
    }

    /** Clears the types resolved by this context (the shared global cache is not cleared). */
    public void forgetResolvedTypes() {
        resolvedTypes.clear();
    }

    /** Returns whether this is an invocation of a zero-argument function, whose type does not depend on bindings. */
    private boolean referenceCanBeResolvedGlobally(Reference reference) {
        return functionInvocation(reference).map(function -> function.arguments().isEmpty()).orElse(false);
    }

    /**
     * Returns the type of the given reference, resolving and caching it if it is not known yet.
     * <p>
     * If the reference cannot be resolved, the result of {@link #defaultTypeOf(Reference)} is returned.
     * That fallback is not cached, since {@link #setType} may declare the type later.
     */
    public TensorType getType(Reference reference) {
        boolean canBeResolvedGlobally = referenceCanBeResolvedGlobally(reference);

        TensorType type = resolvedTypes.get(reference);
        if (type == null && canBeResolvedGlobally)
            type = globallyResolvedTypes.get(reference);
        if (type != null) {
            // A globally cached type may have been resolved by another context; record the usage here too
            if (type.rank() > 0)
                markTensorsUsed();
            return type;
        }

        type = resolveType(reference);
        if (type == null)
            return defaultTypeOf(reference);

        resolvedTypes.put(reference, type);
        if (type.rank() > 0)
            markTensorsUsed();
        if (canBeResolvedGlobally)
            globallyResolvedTypes.put(reference, type);
        return type;
    }

    /** Records that tensors are used, in this context and all the contexts it was derived from. */
    private void markTensorsUsed() {
        for (MapEvaluationTypeContext context = this; context != null; context = context.parent.orElse(null))
            context.tensorsAreUsed = true;
    }

    /**
     * Returns the parent context.
     *
     * @throws IllegalArgumentException if this is a root context
     */
    MapEvaluationTypeContext getParent(String argument, String binding) {
        return parent.orElseThrow(() -> new IllegalArgumentException(
                "argument " + argument + " is bound to " + binding + " but there is no parent context"));
    }

    /** Follows the binding of the given name through the parent contexts, returning the name itself if unbound. */
    public String resolveBinding(String argument) {
        String binding = getBinding(argument);
        if (binding == null)
            return argument;
        return getParent(argument, binding).resolveBinding(binding);
    }

    /**
     * Resolves the type of a reference which is not cached.
     *
     * @return the type, or null if the reference is a simple feature of unknown type
     * @throws IllegalArgumentException if resolving the reference would invoke itself recursively
     */
    private TensorType resolveType(Reference reference) {
        if (currentResolutionCallStack.contains(reference))
            throw new IllegalArgumentException("Invocation loop: " +
                                               currentResolutionCallStack.stream()
                                                                         .map(Reference::toString)
                                                                         .collect(Collectors.joining(" -> ")) +
                                               " -> " + reference);

        // An argument bound to an expression of the calling context is typed in that context
        Optional<String> binding = boundIdentifier(reference);
        if (binding.isPresent()) {
            try {
                return new RankingExpression(binding.get()).type(getParent(reference.name(), binding.get()));
            }
            catch (ParseException e) {
                throw new IllegalArgumentException(e);
            }
        }

        currentResolutionCallStack.addLast(reference);
        try {
            return resolveUnboundType(reference);
        }
        finally {
            currentResolutionCallStack.removeLast();
        }
    }

    /** Resolves the type of a reference which is not a bound argument; null for a simple feature of unknown type. */
    private TensorType resolveUnboundType(Reference reference) {
        if (FeatureNames.isSimpleFeature(reference))
            return featureTypes.get(Reference.simple(reference.name(), resolveBinding(reference.simpleArgument().get())));

        Optional<ExpressionFunction> function = functionInvocation(reference);
        if (function.isPresent())
            return function.get().getBody().type(withBindings(bind(function.get().arguments(), reference.arguments())));

        Optional<TensorType> onnxType = onnxFeatureType(reference);
        if (onnxType.isPresent())
            return onnxType.get();

        Optional<TensorType> tokensType = transformerTokensFeatureType(reference);
        if (tokensType.isPresent())
            return tokensType.get();

        Optional<TensorType> tensorType = tensorFeatureType(reference);
        if (tensorType.isPresent())
            return tensorType.get();

        if (reference.isIdentifier()) {
            if (featureTypes.containsKey(reference))
                return featureTypes.get(reference);
            Reference constant = FeatureNames.asConstantFeature(reference.name());
            if (featureTypes.containsKey(constant))
                return featureTypes.get(constant);
        }
        return TensorType.empty;
    }

    /**
     * Returns the type to use for a simple feature whose type is not known.
     * <p>
     * Query features are recorded as undeclared and typed as {@link TensorType#empty}. Other features get null.
     *
     * @throws IllegalArgumentException if the reference is not a simple feature
     */
    public TensorType defaultTypeOf(Reference reference) {
        if ( ! FeatureNames.isSimpleFeature(reference))
            throw new IllegalArgumentException("This can only be called for simple references, not " + reference);
        if ( ! reference.name().equals("query"))
            return null;
        queryFeaturesNotDeclared.add(reference);
        return TensorType.empty;
    }

    /** Returns the binding of the given reference if it is a plain name (no arguments or output) which is bound. */
    private Optional<String> boundIdentifier(Reference reference) {
        if ( ! reference.arguments().isEmpty() || reference.output() != null)
            return Optional.empty();
        return Optional.ofNullable(getBinding(reference.name()));
    }

    /** Returns the function invoked by the given reference, if it names a function with a matching argument count. */
    private Optional<ExpressionFunction> functionInvocation(Reference reference) {
        if (reference.output() != null)
            return Optional.empty();
        ExpressionFunction function = getFunctions().get(reference.name());
        if (function == null || function.arguments().size() != reference.arguments().size())
            return Optional.empty();
        return Optional.of(function);
    }

    /**
     * Returns the type of an onnx/onnxModel feature, or empty if the reference is not one.
     * <p>
     * If the reference is not a known feature as written, it is looked up in the form
     * onnx(modelConfigName).output instead.
     *
     * @throws IllegalArgumentException if neither form has a known type
     */
    private Optional<TensorType> onnxFeatureType(Reference reference) {
        if ( ! reference.name().equals("onnxModel") && ! reference.name().equals("onnx"))
            return Optional.empty();

        if ( ! featureTypes.containsKey(reference)) {
            String configOrFileName = reference.arguments().expressions().get(0).toString();
            String modelConfigName = OnnxModelTransformer.getModelConfigName(reference);
            String modelOutput = OnnxModelTransformer.getModelOutput(reference, null);
            reference = new Reference("onnx", new Arguments(new ReferenceNode(modelConfigName)), modelOutput);
            if ( ! featureTypes.containsKey(reference))
                throw new IllegalArgumentException("Missing onnx-model config for '" + configOrFileName + "'");
        }
        return Optional.of(featureTypes.get(reference));
    }

    /**
     * Returns the type of a tokenTypeIds, tokenInputIds or tokenAttentionMask feature, or empty if it is not one.
     *
     * @throws IllegalArgumentException if the feature has fewer than two arguments
     */
    private Optional<TensorType> transformerTokensFeatureType(Reference reference) {
        String name = reference.name();
        if ( ! name.equals("tokenTypeIds") && ! name.equals("tokenInputIds") && ! name.equals("tokenAttentionMask"))
            return Optional.empty();
        if (reference.arguments().size() < 2)
            throw new IllegalArgumentException(name + " must have at least 2 arguments");
        return Optional.of(TokenTransformer.createTensorType(name, reference.arguments().expressions().get(0)));
    }

    /** Returns argument number i of the given reference, or null if it has no such argument. */
    private static ExpressionNode getArgExp(Reference reference, int i) {
        if (reference.arguments().size() > i)
            return reference.arguments().expressions().get(i);
        return null;
    }

    /**
     * Returns the type of a tensorFromLabels, tensorFromWeightedSet, elementwise or closest feature,
     * or empty if the reference is none of them.
     * <p>
     * Except for closest, the result is a type with one mapped dimension. The dimension is named by:
     * <ul>
     *   <li>the second argument, if present;</li>
     *   <li>otherwise, for tensorFromLabels/tensorFromWeightedSet, the argument of the simple feature given
     *       as the first argument.</li>
     * </ul>
     * The value type is double, except for elementwise, which may name it in a third argument.
     *
     * @throws IllegalArgumentException if the arguments do not have the required form
     */
    private Optional<TensorType> tensorFeatureType(Reference reference) {
        String name = reference.name();
        if ( ! name.equals("tensorFromLabels") && ! name.equals("tensorFromWeightedSet") &&
             ! name.equals("elementwise") && ! name.equals("closest"))
            return Optional.empty();

        ExpressionNode arg0 = getArgExp(reference, 0);
        if (name.equals("closest"))
            return Optional.of(closestType(reference, arg0));

        TensorType.Value valueType = TensorType.Value.DOUBLE;
        ExpressionNode arg1 = getArgExp(reference, 1);
        ExpressionNode arg2 = getArgExp(reference, 2);
        if (name.equals("elementwise")) {
            if (arg1 == null || reference.arguments().size() > 3)
                throw new IllegalArgumentException(reference + " must have two or three arguments");
            if (arg2 != null)
                valueType = TensorType.Value.fromId(arg2.toString());
        }

        String dimension = null;
        if (name.equals("tensorFromLabels") || name.equals("tensorFromWeightedSet")) {
            if ( ! (arg0 instanceof ReferenceNode) || ! FeatureNames.isSimpleFeature(((ReferenceNode) arg0).reference()))
                throw new IllegalArgumentException("The first argument of " + name + " must be a simple feature, not " + arg0);
            if (arg1 == null) // default: name the dimension by the argument of the simple feature
                dimension = ((ReferenceNode) arg0).reference().arguments().expressions().get(0).toString();
        }
        if (arg1 != null) {
            boolean isIdentifierReference = arg1 instanceof ReferenceNode && ((ReferenceNode) arg1).reference().isIdentifier();
            if ( ! (arg1 instanceof NameNode) && ! isIdentifierReference)
                throw new IllegalArgumentException("The second argument of " + name + " must be a dimension name, not " + arg1);
            dimension = arg1.toString();
        }
        if (dimension == null)
            throw new IllegalArgumentException("Missing dimension name for " + name);
        return Optional.of(new TensorType.Builder(valueType).mapped(dimension).build());
    }

    /**
     * Returns the type of a closest feature.
     * <p>
     * The feature takes one or two arguments. The first must name an attribute whose declared type has
     * rank > 0. The result is the mapped subtype of that attribute type.
     *
     * @throws IllegalArgumentException if the arguments do not have that form, or the mapped subtype has rank 0
     */
    private TensorType closestType(Reference reference, ExpressionNode arg0) {
        if (arg0 == null || reference.arguments().size() > 2)
            throw new IllegalArgumentException(reference.name() + " must have one or two arguments");
        if (arg0 instanceof ReferenceNode) {
            Reference argument = ((ReferenceNode) arg0).reference();
            if (argument.isIdentifier()) {
                Reference attribute = FeatureNames.asAttributeFeature(argument.name());
                TensorType attributeType = featureTypes.get(attribute);
                if (attributeType != null && attributeType.rank() > 0) {
                    TensorType mappedSubtype = attributeType.mappedSubtype();
                    if (mappedSubtype.rank() > 0)
                        return mappedSubtype;
                    throw new IllegalArgumentException("Unexpected tensor type " + attributeType + " for " + attribute +
                                                       " used by " + reference);
                }
            }
        }
        throw new IllegalArgumentException("The first argument of " + reference.name() +
                                           " must be the name of a tensor attribute, not " + arg0);
    }

    /** Maps each argument name to the string form of the argument expression at the same position. */
    private Map<String, String> bind(List<String> argumentNames, Arguments arguments) {
        Map<String, String> bindings = new HashMap<>(argumentNames.size());
        for (int i = 0; i < argumentNames.size(); i++)
            bindings.put(argumentNames.get(i), arguments.expressions().get(i).toString());
        return bindings;
    }

    /** Returns an unmodifiable view of the query features which were referenced without a declared type. */
    public SortedSet<Reference> queryFeaturesNotDeclared() {
        return Collections.unmodifiableSortedSet(queryFeaturesNotDeclared);
    }

    /**
     * Returns whether some reference resolved to a type of rank > 0.
     * <p>
     * This covers references resolved in this context and in contexts derived from it by {@link #withBindings(Map)}.
     */
    public boolean tensorsAreUsed() {
        return tensorsAreUsed;
    }

    /**
     * Returns a child context with the given bindings of argument names to expression strings of this context.
     * <p>
     * The child gets a copy of the feature types. It shares the following with this context:
     * <ul>
     *   <li>the resolution call stack,</li>
     *   <li>the undeclared query features,</li>
     *   <li>the global type cache.</li>
     * </ul>
     */
    public MapEvaluationTypeContext withBindings(Map<String, String> bindings) {
        return new MapEvaluationTypeContext(getFunctions(), bindings, Optional.of(this), featureTypes,
                                            currentResolutionCallStack, queryFeaturesNotDeclared,
                                            tensorsAreUsed, globallyResolvedTypes);
    }

    /**
     * Alias of {@link #withBindings(Map)} returning the supertype.
     * <p>
     * This method was emitted by the decompiler from the compiler-generated bridge method. It is kept for
     * interface compatibility.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public FunctionReferenceContext m20withBindings(Map map) {
        return withBindings((Map<String, String>) map);
    }

}
