package ai.vespa.rankingexpression.importer.onnx;

import ai.vespa.rankingexpression.importer.ImportedModel;
import ai.vespa.rankingexpression.importer.ModelImporter;
import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import onnx.Onnx;

/* loaded from: input_file:ai/vespa/rankingexpression/importer/onnx/OnnxImporter.class */
/**
 * Model importer for ONNX files.
 *
 * <p>Offers two import paths: {@link #importModel(String, String)} wraps the parsed model in an
 * {@code ImportedOnnxModel} and registers one {@code onnx(...)} expression per graph output, while
 * {@link #importModelAsNative(String, String, ImportedMlModel.ModelType)} converts the parsed
 * graph through {@link #convertModel}.
 */
public class OnnxImporter extends ModelImporter {

    /**
     * Tells whether this importer accepts the given path.
     *
     * @param str path of the candidate model file
     * @return {@code true} only if the path denotes an existing regular file and the path string
     *         ends with {@code ".onnx"} (the suffix comparison is case-sensitive)
     */
    @Override
    public boolean canImport(String str) {
        File candidate = new File(str);
        return candidate.isFile() && candidate.toString().endsWith(".onnx");
    }

    /**
     * Parses the ONNX file at {@code str2} and returns it wrapped in an {@code ImportedOnnxModel}.
     * For every output of the model graph an expression is registered under the sanitized output
     * name (see {@link #asValidIdentifier(String)}) with the value
     * {@code "onnx(" + str + ")." + sanitizedOutputName}.
     *
     * @param str  model name; passed to the imported model and embedded in each output expression
     * @param str2 path of the ONNX file to read
     * @return the imported model with one expression per graph output
     * @throws IllegalArgumentException if the file cannot be opened, read, parsed or closed
     */
    @Override
    public ImportedModel importModel(String str, String str2) {
        // try-with-resources guarantees the stream is closed even when parsing or model
        // construction throws; an IOException raised by close() is still handled below.
        try (FileInputStream modelStream = new FileInputStream(str2)) {
            Onnx.ModelProto modelProto = Onnx.ModelProto.parseFrom(modelStream);
            ImportedOnnxModel importedModel = new ImportedOnnxModel(str, str2, modelProto);
            int outputCount = modelProto.getGraph().getOutputCount();
            for (int i = 0; i < outputCount; i++) {
                String outputName = asValidIdentifier(modelProto.getGraph().getOutput(i).getName());
                importedModel.expression(outputName, "onnx(" + str + ")." + outputName);
            }
            return importedModel;
        } catch (IOException e) {
            throw new IllegalArgumentException("Could not import ONNX model from '" + str2 + "'", e);
        }
    }

    /**
     * Parses the ONNX file at {@code str2} and converts it with
     * {@link #convertModel(String, String, Onnx.ModelProto, ImportedMlModel.ModelType)}.
     *
     * @param str       model name, forwarded to the conversion
     * @param str2      path of the ONNX file to read
     * @param modelType model type, forwarded to the conversion
     * @return the converted model
     * @throws IllegalArgumentException if the file cannot be opened, read, parsed or closed
     */
    public ImportedModel importModelAsNative(String str, String str2, ImportedMlModel.ModelType modelType) {
        // Same resource handling as importModel: the stream is closed on every exit path.
        try (FileInputStream modelStream = new FileInputStream(str2)) {
            return convertModel(str, str2, Onnx.ModelProto.parseFrom(modelStream), modelType);
        } catch (IOException e) {
            throw new IllegalArgumentException("Could not import ONNX model from '" + str2 + "'", e);
        }
    }

    /**
     * Replaces every character that is not a word character, {@code $} or {@code @} with
     * an underscore.
     *
     * @param str the raw name to sanitize
     * @return the sanitized name, of the same length as the input
     */
    public static String asValidIdentifier(String str) {
        // \w already covers letters, digits and '_'; the extra \d and _ in the class are redundant
        // but harmless, and the pattern is kept exactly as it was.
        return str.replaceAll("[^\\w\\d\\$@_]", "_");
    }

    /**
     * Imports the graph of an already parsed ONNX model via {@code GraphImporter.importGraph} and
     * hands the result to {@code convertIntermediateGraphToModel}.
     *
     * <p>NOTE(review): the decompiler reported that the access modifier was widened from
     * package-private; the visibility is left as found here so existing callers keep compiling.
     *
     * @param str        model name, passed to the graph importer
     * @param str2       model path, passed to the model conversion
     * @param modelProto the parsed ONNX model
     * @param modelType  model type, passed to the model conversion
     * @return the converted model
     */
    public static ImportedModel convertModel(String str, String str2, Onnx.ModelProto modelProto, ImportedMlModel.ModelType modelType) {
        return convertIntermediateGraphToModel(GraphImporter.importGraph(str, modelProto), str2, modelType);
    }
}
