package ai.vespa.rankingexpression.importer;

import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction;
import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel;
import com.yahoo.io.IOUtils;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.stream.CustomCollectors;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.regex.Pattern;

/* loaded from: input_file:ai/vespa/rankingexpression/importer/ImportedModel.class */
/**
 * The result of importing a machine-learned model: its inputs, constants, expressions, functions
 * and signatures, collected under a model name and the source it was read from.
 *
 * <p>Content is added through the mutator methods ({@link #input}, {@link #smallConstant},
 * {@link #largeConstant}, {@link #expression(String, RankingExpression)}, {@link #function} and
 * {@link #signature}). Accessors built on {@link Map#copyOf} return immutable snapshots; accessors
 * returning string renderings return a newly created map on every call.
 */
public class ImportedModel implements ImportedMlModel {

    private static final String defaultSignatureName = "default";
    private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*");

    private final String name;
    private final String source;
    private final ImportedMlModel.ModelType modelType;

    private final Map<String, Signature> signatures = new HashMap<>();
    private final Map<String, TensorType> inputs = new HashMap<>();
    private final Map<String, Tensor> smallConstants = new HashMap<>();
    private final Map<String, Tensor> largeConstants = new HashMap<>();
    private final Map<String, RankingExpression> expressions = new HashMap<>();
    private final Map<String, RankingExpression> functions = new HashMap<>();

    /**
     * A named collection of inputs and outputs of the owning model.
     *
     * <p>Inputs map an input name in this signature to an input name in the owning model, and
     * outputs map an output name in this signature to an expression name in the owning model.
     * Both are kept in insertion order.
     */
    public class Signature {

        private final String name;
        private final Map<String, String> inputs = new LinkedHashMap<>();
        private final Map<String, String> outputs = new LinkedHashMap<>();
        private final Map<String, String> skippedOutputs = new HashMap<>();

        Signature(String name) {
            this.name = name;
        }

        /** Returns the name of this signature. */
        public String name() {
            return this.name;
        }

        /** Returns the model this signature belongs to. */
        ImportedModel owner() {
            return ImportedModel.this;
        }

        /**
         * Returns an immutable snapshot of the inputs of this signature, from input name in this
         * signature to input name in the owning model. The snapshot's iteration order is unspecified.
         */
        public Map<String, String> inputs() {
            return Map.copyOf(this.inputs);
        }

        /**
         * Returns an immutable snapshot mapping each model input name referenced by this signature
         * to the type registered for it in the owning model.
         */
        Map<String, TensorType> inputMap() {
            // Keys are the names the inputs map *to* in the model, not the names used in this signature.
            Map<String, TensorType> typeByModelInput =
                    this.inputs.entrySet().stream()
                               .collect(CustomCollectors.toLinkedMap(entry -> entry.getValue(),
                                                                     entry -> owner().inputs.get(entry.getValue())));
            return Map.copyOf(typeByModelInput);
        }

        /**
         * Returns the type of the model input which the given signature input refers to.
         *
         * @param inputName the name of an input in this signature
         * @return the type, or null if this signature has no such input or the owning model has no
         *         type registered for the input it refers to
         */
        public TensorType inputArgument(String inputName) {
            String modelInputName = this.inputs.get(inputName);
            // Look up in the backing map: the immutable snapshot from inputs() rejects null keys.
            return modelInputName == null ? null : owner().inputs.get(modelInputName);
        }

        /**
         * Returns an immutable snapshot of the outputs of this signature, from output name in this
         * signature to expression name in the owning model. The snapshot's iteration order is unspecified.
         */
        public Map<String, String> outputs() {
            return Map.copyOf(this.outputs);
        }

        /**
         * Returns an immutable snapshot of the skipped outputs of this signature, keyed by name.
         * NOTE(review): the value is presumably the reason the output was skipped - confirm with callers.
         */
        public Map<String, String> skippedOutputs() {
            return Map.copyOf(this.skippedOutputs);
        }

        /**
         * Returns the expression behind the given output as a function taking the inputs of this
         * signature as arguments.
         *
         * @param outputName the name of an output in this signature
         * @param functionName the name to give the returned function
         * @return the function, with no return type set
         * @throws IllegalArgumentException if this signature has no such output, or the owning
         *         model has no expression registered under the name the output refers to
         */
        public ImportedMlFunction outputFunction(String outputName, String functionName) {
            String expressionName = this.outputs.get(outputName);
            // Guard the null key explicitly so an unknown output yields the descriptive exception below.
            RankingExpression expression =
                    expressionName == null ? null : owner().expressions.get(expressionName);
            if (expression == null) {
                throw new IllegalArgumentException("Missing output '" + outputName + "' in " + this);
            }
            return new ImportedMlFunction(functionName,
                                          new ArrayList<>(this.inputs.values()),
                                          expression.getRoot().toString(),
                                          asStrings(inputMap()),
                                          Optional.empty());
        }

        @Override
        public String toString() {
            return "signature '" + this.name + "'";
        }

        /** Adds (or replaces) an input: a name in this signature referring to an input name in the model. */
        public void input(String inputName, String modelInputName) {
            this.inputs.put(inputName, modelInputName);
        }

        /** Adds (or replaces) an output: a name in this signature referring to an expression name in the model. */
        public void output(String outputName, String expressionName) {
            this.outputs.put(outputName, expressionName);
        }

        /** Records (or replaces) a skipped output under the given name. */
        public void skippedOutput(String outputName, String value) {
            this.skippedOutputs.put(outputName, value);
        }
    }

    /**
     * Creates a new, empty imported model.
     *
     * @param name the model name, which may only consist of the characters [A-Za-z0-9_]
     * @param source the source this model is read from; {@link #relativeFile} treats it as a file path
     * @param modelType the type of the model
     * @throws IllegalArgumentException if the name contains other characters than [A-Za-z0-9_]
     */
    public ImportedModel(String name, String source, ImportedMlModel.ModelType modelType) {
        if (!nameRegexp.matcher(name).matches()) {
            throw new IllegalArgumentException("An imported model name can only contain [A-Za-z0-9_], but is '" + name + "'");
        }
        this.name = name;
        this.source = source;
        this.modelType = modelType;
    }

    /** Returns the name of this model. */
    @Override
    public String name() {
        return this.name;
    }

    /** Returns the source this model was given at construction. */
    @Override
    public String source() {
        return this.source;
    }

    /** Returns the type of this model. */
    @Override
    public ImportedMlModel.ModelType modelType() {
        return this.modelType;
    }

    @Override
    public String toString() {
        return "imported model '" + this.name + "' from " + this.source;
    }

    /** Returns an immutable snapshot of the inputs of this model, from input name to type. */
    public Map<String, TensorType> inputs() {
        return Map.copyOf(this.inputs);
    }

    /** Returns the string form of the type of the given input, or empty if there is no such input. */
    @Override
    public Optional<String> inputTypeSpec(String inputName) {
        return Optional.ofNullable(this.inputs.get(inputName)).map(TensorType::toString);
    }

    /** Returns an immutable snapshot of the small constants of this model, by name. */
    @Override
    public Map<String, Tensor> smallConstantTensors() {
        return Map.copyOf(this.smallConstants);
    }

    /**
     * Returns a new map of the small constants of this model, each tensor rendered with toString().
     *
     * @deprecated marked for removal; {@link #smallConstantTensors()} exposes the same constants as tensors
     */
    @Override
    @Deprecated(forRemoval = true)
    public Map<String, String> smallConstants() {
        return asStrings(this.smallConstants);
    }

    /** Returns whether a small constant with the given name has been added to this model. */
    public boolean hasSmallConstant(String constantName) {
        return this.smallConstants.containsKey(constantName);
    }

    /** Returns an immutable snapshot of the large constants of this model, by name. */
    @Override
    public Map<String, Tensor> largeConstantTensors() {
        return Map.copyOf(this.largeConstants);
    }

    /**
     * Returns a new map of the large constants of this model, each tensor rendered with toString().
     *
     * @deprecated marked for removal; {@link #largeConstantTensors()} exposes the same constants as tensors
     */
    @Override
    @Deprecated(forRemoval = true)
    public Map<String, String> largeConstants() {
        return asStrings(this.largeConstants);
    }

    /** Returns whether a large constant with the given name has been added to this model. */
    public boolean hasLargeConstant(String constantName) {
        return this.largeConstants.containsKey(constantName);
    }

    /** Returns an immutable snapshot of the expressions of this model, by name. */
    public Map<String, RankingExpression> expressions() {
        return Map.copyOf(this.expressions);
    }

    /** Returns a new map of the functions of this model, from name to the string form of the expression root. */
    @Override
    public Map<String, String> functions() {
        return asExpressionStrings(this.functions);
    }

    /** Returns an immutable snapshot of the signatures of this model, by name. */
    public Map<String, Signature> signatures() {
        return Map.copyOf(this.signatures);
    }

    /** Returns the signature with the given name, creating and registering it if it does not exist. */
    public Signature signature(String signatureName) {
        return this.signatures.computeIfAbsent(signatureName, Signature::new);
    }

    /** Returns the signature named "default", creating and registering it if it does not exist. */
    public Signature defaultSignature() {
        return signature(defaultSignatureName);
    }

    /** Adds (or replaces) an input of this model. */
    public void input(String inputName, TensorType type) {
        this.inputs.put(inputName, type);
    }

    /** Adds (or replaces) a small constant of this model. */
    public void smallConstant(String constantName, Tensor constant) {
        this.smallConstants.put(constantName, constant);
    }

    /** Adds (or replaces) a large constant of this model. */
    public void largeConstant(String constantName, Tensor constant) {
        this.largeConstants.put(constantName, constant);
    }

    /** Adds (or replaces) an expression of this model. */
    public void expression(String expressionName, RankingExpression expression) {
        this.expressions.put(expressionName, expression);
    }

    /** Adds (or replaces) a function of this model. */
    public void function(String functionName, RankingExpression expression) {
        this.functions.put(functionName, expression);
    }

    /**
     * Parses and adds (or replaces) an expression of this model.
     *
     * <p>If the trimmed argument starts with "file:", the rest is taken as a path relative to the
     * directory of the model source (with ".expression" appended unless already present), and the
     * expression is read from that file.
     *
     * @param expressionName the name to register the expression under
     * @param expressionOrFile the expression text, or a "file:" reference to a file containing it
     * @throws IllegalArgumentException if a referenced file does not exist or cannot be read,
     *         or if the expression cannot be parsed
     */
    public void expression(String expressionName, String expressionOrFile) {
        try {
            String expressionText = expressionOrFile.trim();
            if (expressionText.startsWith("file:")) {
                String path = expressionText.substring("file:".length()).trim();
                if (!path.endsWith(".expression")) {
                    path = path + ".expression";
                }
                expressionText = IOUtils.readFile(relativeFile(path, "function '" + expressionName + "'"));
            }
            expression(expressionName, new RankingExpression(expressionText));
        } catch (IOException e) {
            // Keep the cause, so the underlying I/O problem is visible to whoever diagnoses this
            throw new IllegalArgumentException("Could not read file referenced in '" + expressionName + "'", e);
        } catch (ParseException e) {
            throw new IllegalArgumentException("Could not parse function '" + expressionName + "'", e);
        }
    }

    /**
     * Returns the file at the given path, resolved against the parent directory of the model source.
     *
     * @param relativePath the path of the file, relative to the directory containing the source
     * @param descriptionOfPath a description of what references the path, used in the error message
     * @throws IllegalArgumentException if the file does not exist
     */
    public File relativeFile(String relativePath, String descriptionOfPath) {
        File file = new File(new File(source()).getParent(), relativePath);
        if (!file.exists()) {
            throw new IllegalArgumentException(descriptionOfPath + " references '" + relativePath +
                                               "', but this file does not exist");
        }
        return file;
    }

    /**
     * Returns the output expressions of this model as functions.
     *
     * <ul>
     *   <li>For each output of each signature: a function named "signatureName.outputName".</li>
     *   <li>For each signature without outputs: a function named as the signature, built from
     *       the expression registered under the signature name.</li>
     *   <li>If there are no signatures: one function per expression, taking all model inputs.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if an output, or a signature without outputs, refers to
     *         an expression which is not registered in this model
     */
    @Override
    public List<ImportedMlFunction> outputExpressions() {
        List<ImportedMlFunction> result = new ArrayList<>();
        // Read the backing maps directly: this avoids re-copying them on every access and keeps
        // the insertion order of signature inputs/outputs, which Map.copyOf does not preserve.
        for (Map.Entry<String, Signature> signatureEntry : this.signatures.entrySet()) {
            String signatureName = signatureEntry.getKey();
            Signature signature = signatureEntry.getValue();
            for (String outputName : signature.outputs.keySet()) {
                result.add(signature.outputFunction(outputName, signatureName + "." + outputName));
            }
            if (signature.outputs.isEmpty()) { // fallback: signature without outputs
                RankingExpression expression = this.expressions.get(signatureName);
                if (expression == null) {
                    throw new IllegalArgumentException("Missing expression '" + signatureName + "' for " +
                                                       signature + " in " + this);
                }
                result.add(new ImportedMlFunction(signatureName,
                                                  new ArrayList<>(signature.inputs.values()),
                                                  expression.getRoot().toString(),
                                                  asStrings(signature.inputMap()),
                                                  Optional.empty()));
            }
        }
        if (this.signatures.isEmpty()) { // fallback: model without signatures
            for (Map.Entry<String, RankingExpression> expressionEntry : this.expressions.entrySet()) {
                result.add(modelLevelFunction(expressionEntry.getKey(), expressionEntry.getValue()));
            }
        }
        return result;
    }

    /** Returns the given expression as a function taking all the inputs of this model as arguments. */
    private ImportedMlFunction modelLevelFunction(String functionName, RankingExpression expression) {
        return new ImportedMlFunction(functionName,
                                      new ArrayList<>(this.inputs.keySet()),
                                      expression.getRoot().toString(),
                                      asStrings(this.inputs),
                                      Optional.empty());
    }

    /** Returns a new map with the keys of the given map, each value replaced by its toString() form. */
    private Map<String, String> asStrings(Map<String, ?> map) {
        Map<String, String> strings = new HashMap<>();
        for (Map.Entry<String, ?> entry : map.entrySet()) {
            strings.put(entry.getKey(), entry.getValue().toString());
        }
        return strings;
    }

    /** Returns a new map with the keys of the given map, each expression replaced by the string form of its root. */
    private Map<String, String> asExpressionStrings(Map<String, RankingExpression> map) {
        Map<String, String> strings = new HashMap<>();
        for (Map.Entry<String, RankingExpression> entry : map.entrySet()) {
            strings.put(entry.getKey(), entry.getValue().getRoot().toString());
        }
        return strings;
    }

    /** Always returns true. */
    @Override
    public boolean isNative() {
        return true;
    }

    /** Returns this. */
    @Override
    public ImportedModel asNative() {
        return this;
    }
}
