package com.yahoo.vespa.model.ml;

import com.fasterxml.jackson.core.JsonEncoding;
import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.VespaModel;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import onnx.Onnx;

/* loaded from: input_file:com/yahoo/vespa/model/ml/OnnxModelInfo.class */
/**
 * Type information about an ONNX model in an application package.
 * <p>
 * The input and output tensor types and the initializer names of a model are extracted from the ONNX
 * protobuf, serialized to JSON and stored under {@link ApplicationPackage#MODELS_GENERATED_REPLICATED_DIR}.
 * When the model file itself is not present, the information is read back from that stored JSON instead.
 * Output types can be converted to Vespa tensor types. Symbolic and unbound dimension sizes are resolved
 * from the Vespa types of the inputs, with the model probed when that is not sufficient.
 */
public class OnnxModelInfo {

    private static final Logger log = Logger.getLogger(OnnxModelInfo.class.getName());

    private final ApplicationPackage app;
    private final String modelPath;
    private final String defaultOutput;
    private final Map<String, OnnxTypeInfo> inputs;
    private final Map<String, OnnxTypeInfo> outputs;
    /** Cache of converted types for outputs whose dimension sizes are all known. */
    // NOTE(review): unsynchronized HashMap mutated by getTensorType; confirm instances are not shared across threads.
    private final Map<String, TensorType> vespaTypes = new HashMap<>();
    private final Set<String> initializers;

    /** A single dimension of an ONNX tensor type: either a numeric size or a symbolic (named) size. */
    public static class OnnxDimensionInfo {

        private final long size;
        private final String symbolicName;

        /** Creates a dimension with a numeric size; zero or negative sizes count as unknown. */
        OnnxDimensionInfo(long size) {
            this.size = size;
            this.symbolicName = null;
        }

        /** Creates a symbolic dimension; its numeric size is recorded as 0. */
        OnnxDimensionInfo(String symbolicName) {
            this.size = 0L;
            this.symbolicName = symbolicName;
        }

        long getSize() {
            return size;
        }

        String getSymbolicName() {
            return symbolicName;
        }

        boolean hasSymbolicName() {
            return symbolicName != null;
        }

        /** Returns true if the size of this dimension is not known: it is symbolic, or its size is not positive. */
        boolean unknownDimensionSize() {
            return hasSymbolicName() || size <= 0;
        }

        @Override
        public String toString() {
            return hasSymbolicName() ? "\"" + symbolicName + "\"" : Long.toString(size);
        }
    }

    /** The element value type and the ordered dimensions of an ONNX tensor. */
    public static class OnnxTypeInfo {

        private final TensorType.Value valueType;
        private final List<OnnxDimensionInfo> dimensions = new ArrayList<>();

        OnnxTypeInfo(TensorType.Value valueType) {
            this.valueType = valueType;
        }

        void addDimension(long size) {
            dimensions.add(new OnnxDimensionInfo(size));
        }

        void addDimension(String symbolicName) {
            dimensions.add(new OnnxDimensionInfo(symbolicName));
        }

        /** Returns true if any dimension is symbolic or has a non-positive size. */
        boolean containsUnknownDimensionSizes() {
            return dimensions.stream().anyMatch(OnnxDimensionInfo::unknownDimensionSize);
        }

        TensorType.Value valueType() {
            return valueType;
        }

        List<OnnxDimensionInfo> dimensions() {
            return dimensions;
        }

        /** Converts to a Vespa tensor type without any externally resolved dimension sizes. */
        public TensorType toVespaTensorType() {
            return toVespaTensorType(null, null);
        }

        /**
         * Converts to an indexed Vespa tensor type whose dimensions are named d0, d1, ...
         *
         * @param symbolicSizes sizes resolved for symbolic dimension names, or null
         * @param unboundSizes sizes resolved for unbound (negative-size) dimensions, or null
         * @return the tensor type, or {@link TensorType#empty} if some dimension size cannot be
         *         resolved to a positive value
         */
        TensorType toVespaTensorType(Map<String, Long> symbolicSizes, Set<Long> unboundSizes) {
            TensorType.Builder builder = new TensorType.Builder(valueType);
            for (int i = 0; i < dimensions.size(); i++) {
                OnnxDimensionInfo dimension = dimensions.get(i);
                long size = dimension.getSize();
                if (dimension.hasSymbolicName() && symbolicSizes != null
                        && symbolicSizes.containsKey(dimension.getSymbolicName())) {
                    size = symbolicSizes.get(dimension.getSymbolicName());
                }
                // Still unresolved: if all resolved symbolic dimensions share one size, use that size
                if (size == 0 && symbolicSizes != null) {
                    Set<Long> distinctSizes = new HashSet<>(symbolicSizes.values());
                    if (distinctSizes.size() == 1) {
                        size = distinctSizes.iterator().next();
                    }
                }
                // Unbound dimension: use the size found for unbound input dimensions, if any
                if (size < 0 && unboundSizes != null && ! unboundSizes.isEmpty()) {
                    size = unboundSizes.iterator().next();
                }
                if (size <= 0) {
                    return TensorType.empty;
                }
                builder.indexed("d" + i, size);
            }
            return builder.build();
        }

        /**
         * Returns true if some dimension size cannot be determined without probing the model.
         * This is the case when a symbolic dimension has no entry in the given map, or a non-symbolic
         * dimension has size 0. Unbound (negative-size) dimensions do not require probing.
         */
        boolean needModelProbe(Map<String, Long> symbolicSizes) {
            for (OnnxDimensionInfo dimension : dimensions) {
                if (dimension.hasSymbolicName()) {
                    if (symbolicSizes == null || ! symbolicSizes.containsKey(dimension.getSymbolicName())) {
                        return true;
                    }
                } else if (dimension.getSize() == 0) {
                    return true;
                }
            }
            return false;
        }

        @Override
        public String toString() {
            return "(" + valueType.id() + ")["
                   + dimensions.stream().map(OnnxDimensionInfo::toString).collect(Collectors.joining(","))
                   + "]";
        }
    }

    private OnnxModelInfo(ApplicationPackage app, String modelPath,
                          Map<String, OnnxTypeInfo> inputs, Map<String, OnnxTypeInfo> outputs,
                          Set<String> initializers, String defaultOutput) {
        this.app = app;
        this.modelPath = modelPath;
        this.inputs = Collections.unmodifiableMap(inputs);
        this.outputs = Collections.unmodifiableMap(outputs);
        this.defaultOutput = defaultOutput;
        this.initializers = Set.copyOf(initializers);
    }

    public String getModelPath() {
        return modelPath;
    }

    /** Returns the model input names, excluding inputs that are also initializers. */
    public Set<String> getInputs() {
        return inputs.keySet();
    }

    public Set<String> getOutputs() {
        return outputs.keySet();
    }

    public Set<String> getInitializers() {
        return initializers;
    }

    /** Returns the name of the first model output. */
    public String getDefaultOutput() {
        return defaultOutput;
    }

    /**
     * Returns the Vespa tensor type of a model output.
     *
     * @param onnxName the name of the ONNX model output
     * @param inputTypes Vespa tensor types of the model inputs, keyed by ONNX input name; these are used
     *                   to resolve symbolic and unbound dimension sizes
     * @return the output type, or {@link TensorType#empty} if it could not be resolved
     * @throws IllegalArgumentException if the model has no such output, or the input types imply
     *                                  conflicting dimension sizes
     */
    public TensorType getTensorType(String onnxName, Map<String, TensorType> inputTypes) {
        OnnxTypeInfo onnxTypeInfo = outputs.get(onnxName);
        if (onnxTypeInfo == null) {
            throw new IllegalArgumentException("Could not find type for output '" + onnxName + "'");
        }
        if ( ! onnxTypeInfo.containsUnknownDimensionSizes()) {
            return vespaTypes.computeIfAbsent(onnxName, name -> onnxTypeInfo.toVespaTensorType());
        }

        Set<Long> unboundSizes = new HashSet<>();
        Map<String, Long> symbolicSizes = new HashMap<>();
        resolveUnknownDimensionSizes(inputTypes, symbolicSizes, unboundSizes);

        TensorType type = TensorType.empty;
        // Probe the model only when the input types alone cannot determine every dimension size
        if ( ! inputTypes.isEmpty() && onnxTypeInfo.needModelProbe(symbolicSizes)) {
            type = OnnxModelProbe.probeModel(app, Path.fromString(modelPath), onnxName, inputTypes);
        }
        if (type.equals(TensorType.empty)) {
            type = onnxTypeInfo.toVespaTensorType(symbolicSizes, unboundSizes);
        }
        return type;
    }

    /**
     * Collects sizes for the unknown dimensions of the model inputs from the given Vespa input types.
     * The following are ignored:
     * <ul>
     *   <li>inputs with no given type</li>
     *   <li>inputs whose given type has a different number of dimensions</li>
     *   <li>Vespa dimensions without a size</li>
     * </ul>
     *
     * @param inputTypes Vespa tensor types keyed by ONNX input name
     * @param symbolicSizes receives the size found for each symbolic dimension name
     * @param unboundSizes receives the sizes found for ONNX dimensions of size -1
     * @throws IllegalArgumentException if more than one size is found for unbound dimensions, or
     *                                  different sizes are found for the same symbolic dimension
     */
    private void resolveUnknownDimensionSizes(Map<String, TensorType> inputTypes,
                                              Map<String, Long> symbolicSizes,
                                              Set<Long> unboundSizes) {
        for (Map.Entry<String, OnnxTypeInfo> input : inputs.entrySet()) {
            String onnxName = input.getKey();
            OnnxTypeInfo onnxType = input.getValue();
            TensorType vespaType = inputTypes.get(onnxName);
            if (vespaType == null || vespaType.dimensions().size() != onnxType.dimensions().size()) {
                continue;
            }
            for (int i = 0; i < vespaType.dimensions().size(); i++) {
                TensorType.Dimension vespaDimension = vespaType.dimensions().get(i);
                if (vespaDimension.size().isEmpty()) {
                    continue;
                }
                Long size = vespaDimension.size().get();
                OnnxDimensionInfo onnxDimension = onnxType.dimensions().get(i);

                if (onnxDimension.getSize() == -1) {
                    unboundSizes.add(size);
                    if (unboundSizes.size() > 1) {
                        throw new IllegalArgumentException("Found conflicting sizes for unbound dimension for type '" + onnxType + "'");
                    }
                } else if (onnxDimension.hasSymbolicName()) {
                    String symbolicName = onnxDimension.getSymbolicName();
                    if (symbolicSizes.containsKey(symbolicName) && ! symbolicSizes.get(symbolicName).equals(size)) {
                        throw new IllegalArgumentException("Found conflicting sizes for symbolic dimension '" + symbolicName + "' for input '" + onnxName + "'");
                    }
                    symbolicSizes.put(symbolicName, size);
                }
            }
        }
    }

    /**
     * Loads model info for the model at the given application-package path.
     * If the model file is present, it is parsed, and the generated info is stored.
     * Otherwise previously generated info is read.
     *
     * @throws IllegalArgumentException if neither the model nor generated info exists, or parsing fails
     */
    public static OnnxModelInfo load(String path, ApplicationPackage app) {
        Path modelPath = Path.fromString(path);
        if (app.getFile(modelPath).exists()) {
            return loadFromFile(modelPath, app);
        }
        if (app.getFile(generatedModelInfoPath(modelPath)).exists()) {
            return loadFromGeneratedInfo(modelPath, app);
        }
        throw new IllegalArgumentException("Unable to find ONNX model '" + path + "'");
    }

    /** Returns true if either the model file or previously generated info for it exists. */
    public static boolean modelExists(String path, ApplicationPackage app) {
        Path modelPath = Path.fromString(path);
        return app.getFile(modelPath).exists() || app.getFile(generatedModelInfoPath(modelPath)).exists();
    }

    /** Parses the ONNX model file, stores the generated info as JSON, and builds model info from it. */
    private static OnnxModelInfo loadFromFile(Path path, ApplicationPackage app) {
        // try-with-resources: the stream must be closed even if parsing or storing fails
        try (InputStream in = app.getFile(path).createInputStream()) {
            String json = onnxModelToJson(Onnx.ModelProto.parseFrom(in), path);
            storeGeneratedInfo(json, path, app);
            return jsonToModelInfo(json, app);
        } catch (IOException e) {
            throw new IllegalArgumentException("Unable to parse ONNX model", e);
        }
    }

    private static OnnxModelInfo loadFromGeneratedInfo(Path path, ApplicationPackage app) {
        try {
            return jsonToModelInfo(readGeneratedInfo(path, app), app);
        } catch (IOException e) {
            throw new IllegalArgumentException("Unable to parse ONNX model", e);
        }
    }

    private static String readGeneratedInfo(Path path, ApplicationPackage app) throws IOException {
        // Closed explicitly here; closing an already closed Reader has no effect
        try (Reader reader = app.getFile(generatedModelInfoPath(path)).createReader()) {
            return IOUtils.readAll(reader);
        }
    }

    private static void storeGeneratedInfo(String json, Path path, ApplicationPackage app) throws IOException {
        IOUtils.writeFile(app.getFileReference(generatedModelInfoPath(path)), json, false);
    }

    /** Returns the path where the generated JSON info for the model at the given path is stored. */
    private static Path generatedModelInfoPath(Path path) {
        return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(asValidIdentifier(path.getRelative()) + ".modelinfo.json");
    }

    /**
     * Serializes the inputs, outputs and initializer names of an ONNX model to JSON.
     * Graph inputs that are also initializers are left out of "inputs".
     */
    private static String onnxModelToJson(Onnx.ModelProto model, Path path) throws IOException {
        Set<String> initializerNames = model.getGraph().getInitializerList().stream()
                .map(Onnx.TensorProto::getName)
                .collect(Collectors.toSet());

        ByteArrayOutputStream out = new ByteArrayOutputStream();
        try (JsonGenerator g = new JsonFactory().createGenerator(out, JsonEncoding.UTF8)) {
            g.writeStartObject();
            g.writeStringField("path", path.toString());

            g.writeArrayFieldStart("inputs");
            int skippedInputs = 0;
            for (Onnx.ValueInfoProto valueInfo : model.getGraph().getInputList()) {
                if (initializerNames.contains(valueInfo.getName())) {
                    log.fine(() -> "For '%s': skipping name '%s' as it's an initializer".formatted(path.getName(), valueInfo.getName()));
                    skippedInputs++;
                } else {
                    onnxTypeToJson(g, valueInfo);
                }
            }
            if (skippedInputs > 0) {
                log.info("For '%s': skipped %d inputs that were also listed in initializers".formatted(path.getName(), skippedInputs));
            }
            g.writeEndArray();

            g.writeArrayFieldStart("outputs");
            for (Onnx.ValueInfoProto valueInfo : model.getGraph().getOutputList()) {
                onnxTypeToJson(g, valueInfo);
            }
            g.writeEndArray();

            g.writeArrayFieldStart("initializers");
            for (Onnx.TensorProto initializer : model.getGraph().getInitializerList()) {
                g.writeStartObject();
                g.writeStringField("name", initializer.getName());
                g.writeEndObject();
            }
            g.writeEndArray();

            g.writeEndObject();
        }
        // The generator writes UTF-8, so decode with UTF-8 rather than the platform default charset
        return out.toString(StandardCharsets.UTF_8);
    }

    /**
     * Builds model info from JSON in the format written by this class.
     * The default output is the first listed output.
     *
     * @throws IOException if the JSON cannot be parsed
     */
    public static OnnxModelInfo jsonToModelInfo(String json, ApplicationPackage app) throws IOException {
        JsonNode root = new ObjectMapper().readTree(json);
        Map<String, OnnxTypeInfo> inputs = new HashMap<>();
        Map<String, OnnxTypeInfo> outputs = new HashMap<>();
        Set<String> initializers = new HashSet<>();
        // Used when the model lists no outputs.
        // NOTE(review): confirm VespaModel.ROOT_CONFIGID is the intended default here.
        String defaultOutput = VespaModel.ROOT_CONFIGID;

        String path = root.has("path") ? root.get("path").textValue() : null;
        for (JsonNode input : root.get("inputs")) {
            inputs.put(input.get("name").textValue(), jsonToTypeInfo(input));
        }
        for (JsonNode output : root.get("outputs")) {
            outputs.put(output.get("name").textValue(), jsonToTypeInfo(output));
        }
        if (root.get("outputs").has(0)) {
            defaultOutput = root.get("outputs").get(0).get("name").textValue();
        }
        JsonNode initializerNodes = root.get("initializers");
        if (initializerNodes != null) {
            for (JsonNode initializer : initializerNodes) {
                initializers.add(initializer.get("name").textValue());
            }
        }
        return new OnnxModelInfo(app, path, inputs, outputs, initializers, defaultOutput);
    }

    /**
     * Writes one input/output as a JSON object with "name", "type" and "dim".
     * Each dimension is written as either {"type":"param","size":name} or {"type":"value","size":n}.
     */
    private static void onnxTypeToJson(JsonGenerator g, Onnx.ValueInfoProto valueInfo) throws IOException {
        g.writeStartObject();
        g.writeStringField("name", valueInfo.getName());
        g.writeStringField("type", onnxValueTypeToString(valueInfo.getType().getTensorType().getElemType()));
        g.writeArrayFieldStart("dim");
        for (Onnx.TensorShapeProto.Dimension dimension : valueInfo.getType().getTensorType().getShape().getDimList()) {
            g.writeStartObject();
            if (dimension.hasDimParam()) {
                g.writeStringField("type", "param");
                g.writeStringField("size", dimension.getDimParam());
            } else {
                g.writeStringField("type", "value");
                g.writeNumberField("size", dimension.getDimValue());
            }
            g.writeEndObject();
        }
        g.writeEndArray();
        g.writeEndObject();
    }

    /** Reads type info from a JSON object in the format written by {@link #onnxTypeToJson}. */
    private static OnnxTypeInfo jsonToTypeInfo(JsonNode node) {
        OnnxTypeInfo typeInfo = new OnnxTypeInfo(stringToValueType(node.get("type").textValue()));
        for (JsonNode dim : node.get("dim")) {
            if (dim.get("type").textValue().equals("param")) {
                typeInfo.addDimension(dim.get("size").textValue());
            } else {
                typeInfo.addDimension(dim.get("size").longValue());
            }
        }
        return typeInfo;
    }

    /**
     * Maps an ONNX element type to the name of a Vespa tensor value type.
     * Boolean and integer element types are mapped to "float", which may be lossy.
     *
     * @throws IllegalArgumentException for element types that have no mapping
     */
    private static String onnxValueTypeToString(Onnx.TensorProto.DataType dataType) {
        return switch (dataType) {
            case FLOAT -> "float";
            case DOUBLE -> "double";
            case BOOL, INT8, INT16, INT32, INT64, UINT8, UINT16, UINT32, UINT64 -> "float";
            default -> throw new IllegalArgumentException("A ONNX tensor with data type " + dataType + " cannot be converted to a Vespa tensor type");
        };
    }

    /** Parses a value type name written by {@link #onnxValueTypeToString}. */
    private static TensorType.Value stringToValueType(String type) {
        return switch (type) {
            case "float" -> TensorType.Value.FLOAT;
            case "double" -> TensorType.Value.DOUBLE;
            default -> throw new IllegalArgumentException("Unknown tensor value type: " + type);
        };
    }

    /** Replaces every character that is not a word character, '$' or '@' with '_'. */
    public static String asValidIdentifier(String str) {
        return str.replaceAll("[^\\w\\d\\$@_]", "_");
    }
}
