package ai.vespa.rankingexpression.importer;

import ai.vespa.rankingexpression.importer.ImportedModel;
import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel;
import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter;
import ai.vespa.rankingexpression.importer.operations.Constant;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.text.ExpressionFormatter;
import com.yahoo.yolean.Exceptions;
import java.io.File;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;

/* loaded from: input_file:ai/vespa/rankingexpression/importer/ModelImporter.class */
/**
 * Base class for importers which translate an external model into an {@link ImportedModel}
 * consisting of Vespa ranking expressions, constants, inputs and functions.
 *
 * <p>Concrete importers decide which models they handle and build an {@link IntermediateGraph};
 * {@link #convertIntermediateGraphToModel} performs the shared graph-to-model conversion.
 */
public abstract class ModelImporter implements MlModelImporter {

    private static final Logger log = Logger.getLogger(ModelImporter.class.getName());

    /**
     * Returns whether this importer can handle the given model.
     *
     * @param modelPath the model to check (presumably a file or directory path — confirm against implementations)
     * @return true if this importer can import it
     */
    @Override
    public abstract boolean canImport(String modelPath);

    /**
     * Imports a model given as a {@link File}; delegates to {@link #importModel(String, String)}
     * with the file's string form.
     */
    @Override
    public final ImportedModel importModel(String modelName, File modelPath) {
        return importModel(modelName, modelPath.toString());
    }

    /**
     * Imports the model found at the given path.
     *
     * @param modelName the name to give the imported model
     * @param modelPath the location of the model
     * @return the imported model
     */
    public abstract ImportedModel importModel(String modelName, String modelPath);

    /**
     * Converts an intermediate graph into an imported model.
     *
     * <p>The graph is optimized (by calling {@code graph.optimize()}), then signatures, output
     * expressions (with the constants, inputs and functions they depend on) and warnings are
     * imported, and finally the types of imported constant variables are logged.
     *
     * <p>NOTE(review): the decompiler reports this method was originally {@code protected};
     * it is kept {@code public} here for compatibility.
     *
     * @param graph the intermediate graph produced by a concrete importer
     * @param modelPath the model source, passed to {@link ImportedModel} and used in log output
     * @param modelType the type of the model being imported
     * @return the imported model
     */
    public static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph,
                                                                String modelPath,
                                                                ImportedMlModel.ModelType modelType) {
        ImportedModel model = new ImportedModel(graph.name(), modelPath, modelType);
        log.log(Level.FINER, () -> "Intermediate graph created from '" + modelPath + "':\n" +
                                   ExpressionFormatter.inTwoColumnMode(70, 50).format(graph.toFullString()));

        graph.optimize();

        importSignatures(graph, model);
        importExpressions(graph, model);
        reportWarnings(graph, model);
        logVariableTypes(graph);

        return model;
    }

    /** Copies every graph signature's inputs and outputs into the model, converting output names to Vespa names. */
    private static void importSignatures(IntermediateGraph graph, ImportedModel model) {
        for (String signatureName : graph.signatures()) {
            ImportedModel.Signature signature = model.signature(signatureName);
            graph.inputs(signatureName).forEach((inputName, value) -> signature.input(inputName, value));
            graph.outputs(signatureName).forEach((outputName, operationName) ->
                    signature.output(IntermediateOperation.vespaName(outputName), operationName));
        }
    }

    /** Returns whether the operation's name is the output value of any signature in the model. */
    private static boolean isSignatureOutput(ImportedModel model, IntermediateOperation operation) {
        String operationName = operation.name();
        for (ImportedModel.Signature signature : model.signatures().values()) {
            for (String output : signature.outputs().values()) {
                if (output.equals(operationName)) return true;
            }
        }
        return false;
    }

    /**
     * Imports the expression behind every signature output. An output that yields no function,
     * or whose import throws {@link IllegalArgumentException}, is recorded as skipped on its
     * signature instead of failing the import.
     */
    private static void importExpressions(IntermediateGraph graph, ImportedModel model) {
        for (ImportedModel.Signature signature : model.signatures().values()) {
            for (String outputName : signature.outputs().values()) {
                try {
                    if (importExpression(graph.get(outputName), model).isEmpty()) {
                        signature.skippedOutput(outputName, "No valid output function could be found.");
                    }
                } catch (IllegalArgumentException e) {
                    signature.skippedOutput(outputName, Exceptions.toMessageString(e));
                }
            }
        }
    }

    /**
     * Imports the given operation, and recursively its inputs, into the model.
     *
     * <p>NOTE(review): the decompiler reports this method was originally {@code private};
     * it is kept {@code public} here for compatibility.
     *
     * @return the operation's function; empty if the operation has no type
     * @throws IllegalArgumentException if a constant operation has no value
     */
    public static Optional<TensorFunction<Reference>> importExpression(IntermediateOperation operation, ImportedModel model) {
        if (model.expressions().containsKey(operation.name())) {
            return operation.function(); // already imported
        }
        if (operation.type().isEmpty()) {
            return Optional.empty();
        }
        if (operation.isConstant()) {
            return importConstant(operation, model);
        }
        importExpressionInputs(operation, model);
        importRankingExpression(operation, model);
        importArgumentExpression(operation, model);
        importFunctionExpression(operation, model);
        return operation.function();
    }

    /** Recursively imports every input of the given operation. */
    private static void importExpressionInputs(IntermediateOperation operation, ImportedModel model) {
        for (IntermediateOperation input : operation.inputs()) {
            importExpression(input, model);
        }
    }

    /**
     * Registers a constant operation's tensor value in the model, unless it is already present.
     * Rank-0 tensors become small constants, others become large constants; non-tensor values
     * are not registered.
     *
     * @throws IllegalArgumentException if the operation has no constant value
     */
    private static Optional<TensorFunction<Reference>> importConstant(IntermediateOperation operation, ImportedModel model) {
        String name = operation.vespaName();
        if (model.hasLargeConstant(name) || model.hasSmallConstant(name)) {
            return operation.function();
        }

        Value value = operation.getConstantValue().orElseThrow(() -> new IllegalArgumentException(
                "Operation '" + operation.vespaName() + "' is constant but does not have a value."));
        if (value instanceof TensorValue) {
            Tensor tensor = value.asTensor();
            if (tensor.type().rank() == 0) {
                model.smallConstant(name, tensor);
            } else {
                model.largeConstant(name, tensor);
            }
        }
        return operation.function();
    }

    /**
     * Adds the operation's function to the model as a named ranking expression, if it has one
     * and no expression of that name exists yet. For signature outputs whose type is not the
     * standard type, dimensions are renamed to the standard dimension names first.
     *
     * @throws RuntimeException if the function's string form cannot be parsed as a ranking expression
     */
    private static void importRankingExpression(IntermediateOperation operation, ImportedModel model) {
        if (operation.function().isEmpty()) return;

        String name = operation.name();
        if (model.expressions().containsKey(name)) return;

        TensorFunction<Reference> function = operation.function().get();
        if (isSignatureOutput(model, operation)) {
            OrderedTensorType type = operation.type().get();
            OrderedTensorType standardType = OrderedTensorType.standardType(type);
            if ( ! type.equals(standardType)) {
                function = new Rename<>(function, type.dimensionNames(), standardType.dimensionNames());
            }
        }

        try {
            model.expression(name, new RankingExpression(name, function.toString()));
        } catch (ParseException e) {
            throw new RuntimeException("Imported function " + function + " cannot be parsed as a ranking expression", e);
        }
    }

    /**
     * Registers an input operation as a model input, using the standard form of its type.
     * The type is known to be present because {@link #importExpression} returns early otherwise.
     */
    private static void importArgumentExpression(IntermediateOperation operation, ImportedModel model) {
        if (operation.isInput()) {
            model.input(operation.vespaName(), OrderedTensorType.standardType(operation.type().get()).type());
        }
    }

    /**
     * Adds the operation's ranking-expression function to the model, if it has one.
     *
     * @throws RuntimeException if the function cannot be parsed as a ranking expression
     */
    private static void importFunctionExpression(IntermediateOperation operation, ImportedModel model) {
        if (operation.rankingExpressionFunction().isEmpty()) return;

        TensorFunction<Reference> function = operation.rankingExpressionFunction().get();
        String functionName = operation.rankingExpressionFunctionName();
        try {
            model.function(functionName, new RankingExpression(functionName, function.toString()));
        } catch (ParseException e) {
            throw new RuntimeException("Model function " + function + " cannot be parsed as a ranking expression", e);
        }
    }

    /**
     * Walks the operations reachable from every signature output and logs their warnings.
     * A single visited set is shared so each operation's warnings are reported once.
     */
    private static void reportWarnings(IntermediateGraph graph, ImportedModel model) {
        Set<String> visited = new HashSet<>();
        for (ImportedModel.Signature signature : model.signatures().values()) {
            for (String outputName : signature.outputs().values()) {
                reportWarnings(graph.get(outputName), model, visited);
            }
        }
    }

    /**
     * Logs the warnings of an operation and, recursively, of its inputs.
     *
     * @param visited names of operations already handled. Operations are marked on entry, so a
     *                cyclic graph cannot cause unbounded recursion.
     */
    private static void reportWarnings(IntermediateOperation operation, ImportedModel model, Set<String> visited) {
        if ( ! visited.add(operation.name())) return;

        for (String warning : operation.warnings()) {
            // Warnings were previously collected and silently dropped; surface them at a low level
            log.log(Level.FINE, () -> "Warning while importing '" + operation.name() + "': " + warning);
        }
        for (IntermediateOperation input : operation.inputs()) {
            reportWarnings(input, model, visited);
        }
    }

    /** Logs the name, Vespa name and type of every typed constant in the graph. */
    private static void logVariableTypes(IntermediateGraph graph) {
        for (IntermediateOperation operation : graph.operations().values()) {
            if ( ! (operation instanceof Constant)) continue;
            if (operation.type().isEmpty()) continue;

            log.info("Importing model variable " + operation.name() + " as " + operation.vespaName() +
                     " of type " + operation.type().get());
        }
    }
}
