/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.rankingexpression.importer;

import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction;
import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel;
import com.yahoo.io.IOUtils;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.stream.CustomCollectors;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.regex.Pattern;

/**
 * An {@link ImportedMlModel} that keeps its inputs, constants, expressions, functions and
 * signatures in in-memory maps which are filled through the mutator methods of this class.
 *
 * <p>Map-returning accessors hand out copies, never the internal mutable maps.
 */
public class ImportedModel implements ImportedMlModel {

    /** Name of the signature returned by {@link #defaultSignature()}. */
    private static final String defaultSignatureName = "default";

    /** Pattern a model name must match in full; note that it also accepts the empty string. */
    private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*");

    private final String name;
    private final String source;
    private final ImportedMlModel.ModelType modelType;

    /** Signatures by name; entries are created on demand by {@link #signature(String)}. */
    private final Map<String, Signature> signatures = new HashMap<String, Signature>();

    /** Input (argument) types by argument name. */
    private final Map<String, TensorType> inputs = new HashMap<String, TensorType>();

    private final Map<String, Tensor> smallConstants = new HashMap<String, Tensor>();
    private final Map<String, Tensor> largeConstants = new HashMap<String, Tensor>();
    private final Map<String, RankingExpression> expressions = new HashMap<String, RankingExpression>();
    private final Map<String, RankingExpression> functions = new HashMap<String, RankingExpression>();

    /**
     * Creates an empty imported model.
     *
     * @param name      the model name; may only consist of the characters {@code [A-Za-z0-9_]}
     * @param source    the source of the model; {@link #relativeFile(String, String)} treats it as a file path
     * @param modelType the type of the model
     * @throws IllegalArgumentException if {@code name} contains any other character
     */
    public ImportedModel(String name, String source, ImportedMlModel.ModelType modelType) {
        if (!nameRegexp.matcher(name).matches()) {
            throw new IllegalArgumentException("An imported model name can only contain [A-Za-z0-9_], but is '" + name + "'");
        }
        this.name = name;
        this.source = source;
        this.modelType = modelType;
    }

    /** Returns the name given at construction. */
    @Override
    public String name() {
        return this.name;
    }

    /** Returns the source given at construction. */
    @Override
    public String source() {
        return this.source;
    }

    /** Returns the model type given at construction. */
    @Override
    public ImportedMlModel.ModelType modelType() {
        return this.modelType;
    }

    @Override
    public String toString() {
        return "imported model '" + this.name + "' from " + this.source;
    }

    /** Returns an immutable snapshot of the input types, keyed by argument name. */
    public Map<String, TensorType> inputs() {
        return Map.copyOf(this.inputs);
    }

    /** Returns the string form of the type registered for the given input, or empty if there is none. */
    @Override
    public Optional<String> inputTypeSpec(String input) {
        return Optional.ofNullable(this.inputs.get(input)).map(TensorType::toString);
    }

    /** Returns an immutable snapshot of the small constants, keyed by name. */
    @Override
    public Map<String, Tensor> smallConstantTensors() {
        return Map.copyOf(this.smallConstants);
    }

    /**
     * Returns a new map of the small constants with each tensor converted by {@code toString()}.
     *
     * @deprecated marked for removal; {@link #smallConstantTensors()} exposes the same entries as tensors
     */
    @Override
    @Deprecated(forRemoval=true)
    public Map<String, String> smallConstants() {
        return this.asStrings(this.smallConstants);
    }

    /** Returns whether a small constant is registered under the given name. */
    boolean hasSmallConstant(String name) {
        return this.smallConstants.containsKey(name);
    }

    /** Returns an immutable snapshot of the large constants, keyed by name. */
    @Override
    public Map<String, Tensor> largeConstantTensors() {
        return Map.copyOf(this.largeConstants);
    }

    /**
     * Returns a new map of the large constants with each tensor converted by {@code toString()}.
     *
     * @deprecated marked for removal; {@link #largeConstantTensors()} exposes the same entries as tensors
     */
    @Override
    @Deprecated(forRemoval=true)
    public Map<String, String> largeConstants() {
        return this.asStrings(this.largeConstants);
    }

    /** Returns whether a large constant is registered under the given name. */
    boolean hasLargeConstant(String name) {
        return this.largeConstants.containsKey(name);
    }

    /** Returns an immutable snapshot of the expressions, keyed by name. */
    public Map<String, RankingExpression> expressions() {
        return Map.copyOf(this.expressions);
    }

    /** Returns a new map from function name to the string form of the root of its expression. */
    @Override
    public Map<String, String> functions() {
        return this.asExpressionStrings(this.functions);
    }

    /** Returns an immutable snapshot of the signatures, keyed by name. */
    public Map<String, Signature> signatures() {
        return Map.copyOf(this.signatures);
    }

    /** Returns the signature with the given name, creating and registering it if it does not exist yet. */
    public Signature signature(String name) {
        return this.signatures.computeIfAbsent(name, Signature::new);
    }

    /** Returns the signature named "default", creating it if necessary. */
    public Signature defaultSignature() {
        return this.signature(defaultSignatureName);
    }

    /** Registers (or replaces) the type of the input with the given argument name. */
    public void input(String name, TensorType argumentType) {
        this.inputs.put(name, argumentType);
    }

    /** Registers (or replaces) a small constant under the given name. */
    public void smallConstant(String name, Tensor constant) {
        this.smallConstants.put(name, constant);
    }

    /** Registers (or replaces) a large constant under the given name. */
    public void largeConstant(String name, Tensor constant) {
        this.largeConstants.put(name, constant);
    }

    /** Registers (or replaces) an expression under the given name. */
    public void expression(String name, RankingExpression expression) {
        this.expressions.put(name, expression);
    }

    /** Registers (or replaces) a function under the given name. */
    public void function(String name, RankingExpression expression) {
        this.functions.put(name, expression);
    }

    /**
     * Parses the given text and registers the result as the expression with the given name.
     *
     * <p>If the trimmed text starts with {@code file:}, the remainder is taken as a file path
     * ({@code .expression} is appended when missing) that is resolved with
     * {@link #relativeFile(String, String)}; the content of that file is then parsed instead.
     *
     * @param name       the name to register the expression under
     * @param expression the expression text, or a {@code file:} reference to it
     * @throws IllegalArgumentException if the referenced file does not exist or cannot be read,
     *                                  or if the expression cannot be parsed
     */
    public void expression(String name, String expression) {
        try {
            expression = expression.trim();
            if (expression.startsWith("file:")) {
                String filePath = expression.substring("file:".length()).trim();
                if (!filePath.endsWith(".expression")) {
                    filePath = filePath + ".expression";
                }
                expression = IOUtils.readFile(this.relativeFile(filePath, "function '" + name + "'"));
            }
            this.expression(name, new RankingExpression(expression));
        }
        catch (IOException e) {
            // Keep the I/O failure as the cause so the reason the read failed is not lost
            throw new IllegalArgumentException("Could not read file referenced in '" + name + "'", e);
        }
        catch (ParseException e) {
            throw new IllegalArgumentException("Could not parse function '" + name + "'", e);
        }
    }

    /**
     * Resolves a path against the parent directory of {@link #source()}.
     *
     * @param relativePath      the path, relative to the directory containing the source
     * @param descriptionOfPath a description of what references the path, used in the error message
     * @return the resolved file, which exists at the time of the check
     * @throws IllegalArgumentException if the resolved file does not exist
     */
    public File relativeFile(String relativePath, String descriptionOfPath) {
        File file = new File(new File(this.source()).getParent(), relativePath);
        if (!file.exists()) {
            throw new IllegalArgumentException(descriptionOfPath + " references '" + relativePath + "', but this file does not exist");
        }
        return file;
    }

    /**
     * Returns the functions this model offers as outputs.
     *
     * <ul>
     *   <li>For each signature with outputs: one function per output, named
     *       {@code <signature name>.<output name>}.</li>
     *   <li>For each signature without outputs: one function named after the signature, taken from
     *       the expression registered under the signature's name.</li>
     *   <li>If there are no signatures at all: one function per expression, taking all the inputs
     *       of this model as arguments.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if a signature refers to an expression which is not registered
     */
    @Override
    public List<ImportedMlFunction> outputExpressions() {
        List<ImportedMlFunction> functions = new ArrayList<ImportedMlFunction>();
        for (Map.Entry<String, Signature> signatureEntry : this.signatures.entrySet()) {
            String signatureName = signatureEntry.getKey();
            Signature signature = signatureEntry.getValue();
            // Read the backing map to keep the declaration order of the outputs
            for (String outputName : signature.outputs.keySet()) {
                functions.add(signature.outputFunction(outputName, signatureName + "." + outputName));
            }
            if (signature.outputs.isEmpty()) {
                functions.add(this.functionOfSignatureWithoutOutputs(signatureName, signature));
            }
        }
        if (this.signatures.isEmpty()) {
            // Fallback for a model without signatures: expose every expression
            for (Map.Entry<String, RankingExpression> expressionEntry : this.expressions.entrySet()) {
                functions.add(new ImportedMlFunction(expressionEntry.getKey(), new ArrayList<String>(this.inputs.keySet()), expressionEntry.getValue().getRoot().toString(), this.asStrings(this.inputs), Optional.empty()));
            }
        }
        return functions;
    }

    /** Builds the function of a signature without outputs from the expression registered under the signature's name. */
    private ImportedMlFunction functionOfSignatureWithoutOutputs(String signatureName, Signature signature) {
        RankingExpression expression = this.expressions.get(signatureName);
        if (expression == null) {
            throw new IllegalArgumentException("Missing expression '" + signatureName + "' for " + signature + ", which has no outputs");
        }
        // Arguments are listed in the order the inputs were added, as in Signature.outputFunction
        return new ImportedMlFunction(signatureName, new ArrayList<String>(signature.inputs.values()), expression.getRoot().toString(), this.asStrings(signature.inputMap()), Optional.empty());
    }

    /** Returns a new mutable map with the same keys and each value converted by {@code toString()}. */
    private Map<String, String> asStrings(Map<String, ?> map) {
        Map<String, String> values = new HashMap<String, String>();
        for (Map.Entry<String, ?> entry : map.entrySet()) {
            values.put(entry.getKey(), entry.getValue().toString());
        }
        return values;
    }

    /** Returns a new mutable map with the same keys and each expression replaced by the string form of its root. */
    private Map<String, String> asExpressionStrings(Map<String, RankingExpression> map) {
        Map<String, String> values = new HashMap<String, String>();
        for (Map.Entry<String, RankingExpression> entry : map.entrySet()) {
            values.put(entry.getKey(), entry.getValue().getRoot().toString());
        }
        return values;
    }

    /** Always returns true. */
    @Override
    public boolean isNative() {
        return true;
    }

    /** Returns this instance. */
    @Override
    public ImportedModel asNative() {
        return this;
    }

    /**
     * A named set of inputs and outputs of the enclosing model.
     *
     * <p>Inputs map an input name in this signature to an argument name of the owning model,
     * outputs map an output name to the name of an expression of the owning model.
     */
    public class Signature {

        private final String name;

        /** Input name in this signature to argument name in the owning model, in insertion order. */
        private final Map<String, String> inputs = new LinkedHashMap<String, String>();

        /** Output name to expression name in the owning model, in insertion order. */
        private final Map<String, String> outputs = new LinkedHashMap<String, String>();

        /** Name of each skipped output to the reason it was skipped. */
        private final Map<String, String> skippedOutputs = new HashMap<String, String>();

        Signature(String name) {
            this.name = name;
        }

        /** Returns the name of this signature. */
        public String name() {
            return this.name;
        }

        /** Returns the model this signature belongs to. */
        ImportedModel owner() {
            return ImportedModel.this;
        }

        /**
         * Returns an immutable snapshot of the inputs: input name to argument name.
         * The iteration order of the snapshot is unspecified.
         */
        public Map<String, String> inputs() {
            return Map.copyOf(this.inputs);
        }

        /**
         * Returns an immutable map from the argument name of each input of this signature to the
         * type registered for that argument in the owning model.
         *
         * <p>Every argument must have a type registered in the owning model, as the returned map
         * cannot hold null values.
         */
        Map<String, TensorType> inputMap() {
            Map<String, TensorType> typesByArgument = this.inputs.entrySet().stream()
                    .collect(CustomCollectors.toLinkedMap(Map.Entry::getValue,
                                                          e -> this.owner().inputs.get(e.getValue())));
            return Map.copyOf(typesByArgument);
        }

        /**
         * Returns the type of the argument the given input of this signature maps to.
         *
         * @return the type, or null if this signature has no such input or the owning model has no
         *         type registered for its argument
         */
        public TensorType inputArgument(String inputName) {
            // Look up in the backing map: an immutable snapshot would throw on a null key
            return this.owner().inputs.get(this.inputs.get(inputName));
        }

        /**
         * Returns an immutable snapshot of the outputs: output name to expression name.
         * The iteration order of the snapshot is unspecified.
         */
        public Map<String, String> outputs() {
            return Map.copyOf(this.outputs);
        }

        /** Returns an immutable snapshot of the skipped outputs: output name to reason. */
        public Map<String, String> skippedOutputs() {
            return Map.copyOf(this.skippedOutputs);
        }

        /**
         * Returns the function computing the given output of this signature.
         *
         * @param outputName   the name of an output of this signature
         * @param functionName the name to give the returned function
         * @throws IllegalArgumentException if this signature has no such output, or the owning model
         *                                  has no expression registered for it
         */
        public ImportedMlFunction outputFunction(String outputName, String functionName) {
            // Look up in the backing map so that an unknown output name (null expression name)
            // reaches the check below instead of failing inside an immutable snapshot
            RankingExpression outputExpression = this.owner().expressions.get(this.outputs.get(outputName));
            if (outputExpression == null) {
                throw new IllegalArgumentException("Missing output '" + outputName + "' in " + this);
            }
            return new ImportedMlFunction(functionName, new ArrayList<String>(this.inputs.values()), outputExpression.getRoot().toString(), ImportedModel.this.asStrings(this.inputMap()), Optional.empty());
        }

        @Override
        public String toString() {
            return "signature '" + this.name + "'";
        }

        /** Adds (or replaces) an input of this signature, mapped to the given argument name of the owning model. */
        void input(String inputName, String argumentName) {
            this.inputs.put(inputName, argumentName);
        }

        /** Adds (or replaces) an output of this signature, mapped to the given expression name of the owning model. */
        void output(String name, String expressionName) {
            this.outputs.put(name, expressionName);
        }

        /** Records that the output with the given name was skipped, with the reason why. */
        void skippedOutput(String name, String reason) {
            this.skippedOutputs.put(name, reason);
        }
    }
}

