package ai.vespa.models.evaluation;

import ai.vespa.models.evaluation.OnnxModel;
import com.yahoo.api.annotations.Beta;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
import com.yahoo.tensor.TensorType;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.logging.Logger;
import java.util.stream.Collectors;

/**
 * A named collection of ranking expression functions which can be evaluated.
 *
 * <p>On construction, references to ONNX models in function bodies are replaced by the expressions
 * computing the referenced model outputs, an evaluation context prototype is prepared for each function,
 * and missing return and argument types are filled in. Use {@link #evaluatorOf(String...)} to obtain an
 * evaluator of a single function. Closing this closes the ONNX models it was created with.
 */
@Beta
public class Model implements AutoCloseable {

    private static final Logger logger = Logger.getLogger(Model.class.getName());

    /** Name prefix of generated functions for intermediate operations; such functions are not public. */
    private static final String INTERMEDIATE_OPERATION_FUNCTION_PREFIX = "imported_ml_function_";

    private final String name;

    /** All functions of this model, including intermediate operation functions, after optimization. */
    private final List<ExpressionFunction> functions;

    /** The functions whose names do not start with the intermediate operation prefix. */
    private final List<ExpressionFunction> publicFunctions;

    /** Immutable map of the functions referenced from the functions of this model. */
    private final Map<FunctionReference, ExpressionFunction> referencedFunctions;

    /** Immutable map of evaluation context prototypes by function name; each evaluator gets a copy. */
    private final Map<String, LazyArrayContext> contextPrototypes;

    private final ExpressionOptimizer expressionOptimizer;

    /** Actions run by {@link #close()}: one per ONNX model this was created with. */
    private final List<Runnable> closeActions;

    /**
     * Replaces references named {@code onnx} or {@code onnxModel} by the expression computing the
     * referenced output of the named ONNX model, and records the input types wanted by replaced models
     * in the declared types map.
     */
    static class OnnxReplacer extends ExpressionTransformer<TransformContext> {

        private final List<OnnxModel> onnxModels;
        private final Map<String, TensorType> declaredTypes;

        /**
         * @param onnxModels the models references may resolve to
         * @param declaredTypes declared types by name; this map is <em>mutated</em> to add the input types
         *                      wanted by models whose outputs are substituted
         */
        public OnnxReplacer(List<OnnxModel> onnxModels, Map<String, TensorType> declaredTypes) {
            this.onnxModels = onnxModels;
            this.declaredTypes = declaredTypes;
        }

        @Override
        public ExpressionNode transform(ExpressionNode node, TransformContext context) {
            ExpressionNode result = (node instanceof ReferenceNode referenceNode)
                                    ? replaceOnnxReference(referenceNode)
                                    : node;
            // Recurse also into a substituted model output expression, as the replacement may be composite
            if (result instanceof CompositeNode composite) {
                result = transformChildren(composite, context);
            }
            // Compare against the node we were given; comparing the variable to itself would never log
            if (result != node) {
                ExpressionNode transformed = result;
                logger.fine(() -> "transformed: " + node + " => " + transformed);
            }
            return result;
        }

        /**
         * Returns the expression computing the ONNX model output referenced by the given node, or the node
         * itself if it is not an ONNX reference, or names a model or output which does not exist.
         *
         * @throws IllegalArgumentException if the model wants an input type conflicting with a declared type
         */
        private ExpressionNode replaceOnnxReference(ReferenceNode node) {
            Reference reference = node.reference();
            if ( ! reference.name().equals("onnx") && ! reference.name().equals("onnxModel")) {
                return node;
            }
            logger.fine(() -> "consider replacing: " + reference);
            OnnxModel model = findModel(reference.simpleArgument().orElse(null));
            if (model == null) {
                logger.fine(() -> "no onnx model named " + reference.simpleArgument());
                return node;
            }
            model.load();
            ExpressionNode replacement = model.getExpressionForOutput(reference.output());
            if (replacement == null) {
                logger.fine(() -> "no output named " + reference.output() + " from " + model);
                return node;
            }
            logger.fine(() -> "Replacing " + node + " => " + replacement);
            registerInputTypes(model);
            return replacement;
        }

        /**
         * Adds the type wanted for each input of the given model to the declared types, unless a type is
         * already declared for that input source.
         *
         * @throws IllegalArgumentException if an already declared type is not assignable to the wanted type
         */
        private void registerInputTypes(OnnxModel model) {
            for (OnnxModel.InputSpec spec : model.inputSpecs) {
                TensorType declared = declaredTypes.get(spec.source);
                if (declared == null) {
                    declaredTypes.put(spec.source, spec.wantedType);
                } else if ( ! declared.isAssignableTo(spec.wantedType)) {
                    throw new IllegalArgumentException("Conflicting types needed for " + spec.source + "; " +
                                                       declared + " cannot be assigned to " + spec.wantedType);
                }
            }
        }

        /** Returns the first model with the given name, or null if there is none (or the name is null). */
        private OnnxModel findModel(String modelName) {
            for (OnnxModel model : onnxModels) {
                if (model.name().equals(modelName)) return model;
            }
            return null;
        }

    }

    /**
     * Creates a model from the given functions, with no referenced functions, declared types, constants
     * or ONNX models. Functions keep the iteration order of the given collection.
     *
     * @throws IllegalStateException if two of the functions map to the same function reference
     */
    public Model(String name, Collection<ExpressionFunction> functions) {
        this(name,
             functions.stream().collect(Collectors.toMap(function -> FunctionReference.fromName(function.getName()),
                                                         function -> function,
                                                         (first, second) -> {
                                                             throw new IllegalStateException("Duplicate function '" +
                                                                                             first.getName() + "'");
                                                         },
                                                         LinkedHashMap::new)),
             Map.of(), Map.of(), List.of(), List.of());
    }

    /**
     * Creates a model.
     *
     * <p>Note that the given {@code functions} map is updated in place. Each function body has its ONNX
     * references replaced. Each entry whose function lacks a return type or argument types is replaced,
     * through {@link Map.Entry#setValue}, with a completed version. {@code declaredTypes} may receive the
     * input types of referenced ONNX models.
     *
     * NOTE(review): the decompiler reports this constructor as package-private in the original source;
     * it is kept public here so existing callers keep compiling.
     *
     * @param name the name of this model
     * @param functions the functions of this model, by reference
     * @param referencedFunctions functions referenced from the functions of this model
     * @param declaredTypes declared argument types, by argument name
     * @param constants constants made available to the evaluation contexts
     * @param onnxModels ONNX models available to the functions; closed when this is closed
     * @throws IllegalArgumentException if an evaluation context cannot be prepared for some function
     */
    public Model(String name,
                 Map<FunctionReference, ExpressionFunction> functions,
                 Map<FunctionReference, ExpressionFunction> referencedFunctions,
                 Map<String, TensorType> declaredTypes,
                 List<Constant> constants,
                 List<OnnxModel> onnxModels) {
        this.expressionOptimizer = new ExpressionOptimizer();
        this.name = name;
        BindingExtractor bindingExtractor = new BindingExtractor(referencedFunctions, onnxModels);

        Map<String, LazyArrayContext> contextBuilder = new LinkedHashMap<>();
        for (Map.Entry<FunctionReference, ExpressionFunction> entry : functions.entrySet()) {
            try {
                ExpressionFunction function = entry.getValue();
                RankingExpression body = function.getBody();
                body.setRoot(new OnnxReplacer(onnxModels, declaredTypes).transform(body.getRoot(), null));
                LazyArrayContext context = new LazyArrayContext(function, bindingExtractor, referencedFunctions, constants, this);
                contextBuilder.put(function.getName(), context);

                ExpressionFunction completed = completeSignature(function, context, declaredTypes);
                // Entry.setValue is the only modification allowed while iterating over entrySet()
                if (completed != function) {
                    entry.setValue(completed);
                }
            } catch (RuntimeException e) {
                throw new IllegalArgumentException("Could not prepare an evaluation context for " + entry, e);
            }
        }
        this.contextPrototypes = Map.copyOf(contextBuilder);
        this.functions = functions.entrySet().stream()
                                  .map(entry -> optimize(entry.getValue(),
                                                         contextPrototypes.get(entry.getKey().functionName())))
                                  .toList();
        this.publicFunctions = functions.values().stream()
                                        .filter(function -> ! function.getName().startsWith(INTERMEDIATE_OPERATION_FUNCTION_PREFIX))
                                        .toList();
        this.referencedFunctions = Map.copyOf(referencedFunctions);
        this.closeActions = onnxModels.stream().map(model -> (Runnable) model::close).toList();
    }

    /**
     * Returns the given function with missing signature information filled in. It gets an empty return
     * type if none is set, and an argument for each input of the ONNX models in the context. For each
     * context argument, intermediate operation functions get an untyped argument if it is missing. Other
     * functions get a typed argument if no type is set; the type is taken from the declared types, or is
     * empty when nothing is declared.
     */
    private static ExpressionFunction completeSignature(ExpressionFunction function,
                                                        LazyArrayContext context,
                                                        Map<String, TensorType> declaredTypes) {
        if (function.returnType().isEmpty()) {
            function = function.withReturnType(TensorType.empty);
        }
        for (OnnxModel onnxModel : context.onnxModels().values()) {
            for (Map.Entry<String, TensorType> input : onnxModel.inputs().entrySet()) {
                function = function.withArgument(input.getKey(), input.getValue());
            }
        }
        for (String argument : context.arguments()) {
            if (function.getName().startsWith(INTERMEDIATE_OPERATION_FUNCTION_PREFIX)) {
                // Generated functions carry no type information: only add missing arguments
                if ( ! function.arguments().contains(argument)) {
                    function = function.withArgument(argument);
                }
            } else if (function.getArgumentType(argument) == null) {
                function = function.withArgument(argument, declaredTypes.getOrDefault(argument, TensorType.empty));
            }
        }
        return function;
    }

    /** Optimizes the body of the given function in place, using the given context, and returns the function. */
    private ExpressionFunction optimize(ExpressionFunction function, ContextIndex context) {
        expressionOptimizer.optimize(function.getBody(), context);
        return function;
    }

    /** Returns the name of this model. */
    public String name() {
        return name;
    }

    /** Returns an immutable list of the public (non-intermediate) functions of this model. */
    public List<ExpressionFunction> functions() {
        return publicFunctions;
    }

    /** Returns the context prototype of the named function. */
    private LazyArrayContext requireContextPrototype(String functionName) {
        LazyArrayContext prototype = contextPrototypes.get(functionName);
        if (prototype == null) {
            throw new IllegalArgumentException("No function named '" + functionName + "' in " + this +
                                               ". Available functions: " + availableFunctionNames());
        }
        return prototype;
    }

    /** Returns the function (public or intermediate) with exactly the given name, or null if none. */
    ExpressionFunction function(String functionName) {
        for (ExpressionFunction function : functions) {
            if (function.getName().equals(functionName)) return function;
        }
        return null;
    }

    /** Returns the immutable map of functions referenced from the functions of this model. */
    Map<FunctionReference, ExpressionFunction> referencedFunctions() {
        return referencedFunctions; // already an immutable copy
    }

    /**
     * Returns the given referenced function.
     *
     * NOTE(review): the decompiler reports this as package-private in the original source; kept public.
     *
     * @throws IllegalArgumentException if this does not reference the given function
     */
    public ExpressionFunction requireReferencedFunction(FunctionReference reference) {
        ExpressionFunction function = referencedFunctions.get(reference);
        if (function == null) {
            throw new IllegalArgumentException("No " + reference + " in " + this + ". References: " +
                                               referencedFunctions.keySet().stream()
                                                                  .map(FunctionReference::serialForm)
                                                                  .collect(Collectors.joining(", ")));
        }
        return function;
    }

    /**
     * Returns an evaluator of a function in this model.
     *
     * @param names no names, to choose the only function in this model; one name, which is resolved as
     *              described for {@link #evaluatorOfSingleName}; or two names, which are joined by "."
     *              and resolved as one name
     * @throws IllegalArgumentException if the names do not identify exactly one function, or more than
     *                                  two names are given
     */
    public FunctionEvaluator evaluatorOf(String... names) {
        if (names.length == 0) {
            if (functions.isEmpty()) {
                throw undeterminedFunction("No functions are available in " + this);
            }
            if (functions.size() > 1) {
                throw undeterminedFunction("More than one function is available in " + this + ", but no name is given");
            }
            return evaluatorOf(functions.get(0));
        }
        if (names.length == 1) {
            return evaluatorOfSingleName(names[0]);
        }
        if (names.length == 2) {
            return evaluatorOfSingleName(names[0] + "." + names[1]);
        }
        throw new IllegalArgumentException("No more than 2 names can be given when choosing a function, got " +
                                           Arrays.toString(names));
    }

    /**
     * Resolves a single function name by trying these rules in order:
     * <ol>
     *   <li>a function with exactly this name;</li>
     *   <li>the only function whose name starts with {@code name + "."};</li>
     *   <li>the only function whose name ends with {@code "." + name};</li>
     *   <li>the same lookup with a {@code serving_default} prefix replaced by {@code default};</li>
     *   <li>the same lookup with a {@code default.} prefix removed.</li>
     * </ol>
     */
    private FunctionEvaluator evaluatorOfSingleName(String functionName) {
        ExpressionFunction exact = function(functionName);
        if (exact != null) return evaluatorOf(exact);

        List<ExpressionFunction> byPrefix = functions.stream()
                                                     .filter(function -> function.getName().startsWith(functionName + "."))
                                                     .toList();
        if (byPrefix.size() == 1) return evaluatorOf(byPrefix.get(0));
        if (byPrefix.size() > 1) {
            throw undeterminedFunction("Multiple functions start by '" + functionName + "' in " + this);
        }

        List<ExpressionFunction> bySuffix = functions.stream()
                                                     .filter(function -> function.getName().endsWith("." + functionName))
                                                     .toList();
        if (bySuffix.size() == 1) return evaluatorOf(bySuffix.get(0));
        if (bySuffix.size() > 1) {
            throw undeterminedFunction("Multiple functions called '" + functionName + "' in " + this);
        }

        if (functionName.startsWith("serving_default")) {
            return evaluatorOfSingleName("default" + functionName.substring("serving_default".length()));
        }
        if (functionName.startsWith("default.")) {
            return evaluatorOfSingleName(functionName.substring("default.".length()));
        }
        throw undeterminedFunction("No function '" + functionName + "' in " + this);
    }

    /** Returns an evaluator of the given function, using a copy of its context prototype. */
    private FunctionEvaluator evaluatorOf(ExpressionFunction function) {
        return new FunctionEvaluator(function, requireContextPrototype(function.getName()).copy());
    }

    /** Returns (does not throw) an exception with the given message followed by the available functions. */
    private IllegalArgumentException undeterminedFunction(String message) {
        return new IllegalArgumentException(message + ". Available functions: " + availableFunctionNames());
    }

    /** Returns the names of all functions of this model, comma separated. */
    private String availableFunctionNames() {
        return functions.stream().map(ExpressionFunction::getName).collect(Collectors.joining(", "));
    }

    @Override
    public String toString() {
        return "model '" + name + "'";
    }

    /**
     * Closes the ONNX models of this. All close actions are run even if some fail; the first failure is
     * rethrown, with any later failures added as suppressed exceptions.
     */
    @Override
    public void close() {
        RuntimeException failure = null;
        for (Runnable closeAction : closeActions) {
            try {
                closeAction.run();
            } catch (RuntimeException e) {
                if (failure == null) failure = e;
                else failure.addSuppressed(e);
            }
        }
        if (failure != null) throw failure;
    }

}
