/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.modelintegration.evaluator;

import com.google.protobuf.CodedInputStream;
import com.google.protobuf.ExtensionRegistryLite;
import com.google.protobuf.MessageLite;
import com.google.protobuf.WireFormat;
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Set;
import onnx.Onnx;

public class OnnxStreamParser {
    private OnnxStreamParser() {
    }

    public static Set<Path> getExternalDataLocations(Path model) throws IOException {
        try (BufferedInputStream bis = new BufferedInputStream(Files.newInputStream(model, new OpenOption[0]));){
            int tag;
            CodedInputStream stream = CodedInputStream.newInstance((InputStream)bis);
            HashSet externalDataPaths = new HashSet();
            block9: while (!stream.isAtEnd() && WireFormat.getTagFieldNumber((int)(tag = stream.readTag())) != 0) {
                switch (WireFormat.getTagFieldNumber((int)tag)) {
                    case 7: {
                        OnnxStreamParser.parseDelimited(stream, s -> OnnxStreamParser.parseGraphProto(s, externalDataPaths));
                        continue block9;
                    }
                    case 20: {
                        OnnxStreamParser.parseDelimited(stream, s -> OnnxStreamParser.parseTrainingInfoProto(s, externalDataPaths));
                        continue block9;
                    }
                }
                stream.skipField(tag);
            }
            Set<Path> set = Set.copyOf(externalDataPaths);
            return set;
        }
    }

    private static void parseGraphProto(CodedInputStream stream, Set<Path> externalDataPaths) throws IOException {
        block5: while (!stream.isAtEnd()) {
            int tag = stream.readTag();
            if (WireFormat.getTagFieldNumber((int)tag) == 0) {
                return;
            }
            switch (WireFormat.getTagFieldNumber((int)tag)) {
                case 1: {
                    OnnxStreamParser.parseDelimited(stream, s -> OnnxStreamParser.parseNodeProto(s, externalDataPaths));
                    continue block5;
                }
                case 5: {
                    OnnxStreamParser.parseDelimited(stream, s -> OnnxStreamParser.parseTensorProto(s, externalDataPaths));
                    continue block5;
                }
                case 15: {
                    OnnxStreamParser.parseDelimited(stream, s -> OnnxStreamParser.parseSparseTensorProto(s, externalDataPaths));
                    continue block5;
                }
            }
            stream.skipField(tag);
        }
    }

    private static void parseNodeProto(CodedInputStream stream, Set<Path> externalDataPaths) throws IOException {
        while (!stream.isAtEnd()) {
            int tag = stream.readTag();
            if (WireFormat.getTagFieldNumber((int)tag) == 0) {
                return;
            }
            if (WireFormat.getTagFieldNumber((int)tag) == 5) {
                OnnxStreamParser.parseDelimited(stream, s -> OnnxStreamParser.parseAttributeProto(s, externalDataPaths));
                continue;
            }
            stream.skipField(tag);
        }
    }

    private static void parseAttributeProto(CodedInputStream stream, Set<Path> externalDataPaths) throws IOException {
        int tag;
        int fieldNumber;
        Onnx.AttributeProto.AttributeType type = Onnx.AttributeProto.AttributeType.UNDEFINED;
        ArrayList<byte[]> tensorData = new ArrayList<byte[]>();
        ArrayList<byte[]> sparseTensorData = new ArrayList<byte[]>();
        block5: while (!stream.isAtEnd() && (fieldNumber = WireFormat.getTagFieldNumber((int)(tag = stream.readTag()))) != 0) {
            switch (fieldNumber) {
                case 5: 
                case 10: {
                    tensorData.add(OnnxStreamParser.readDelimitedBytes(stream));
                    continue block5;
                }
                case 22: 
                case 23: {
                    sparseTensorData.add(OnnxStreamParser.readDelimitedBytes(stream));
                    continue block5;
                }
                case 20: {
                    type = Onnx.AttributeProto.AttributeType.forNumber(stream.readEnum());
                    continue block5;
                }
            }
            stream.skipField(tag);
        }
        if (type == Onnx.AttributeProto.AttributeType.TENSOR || type == Onnx.AttributeProto.AttributeType.TENSORS) {
            for (byte[] data : tensorData) {
                OnnxStreamParser.parseTensorProto(CodedInputStream.newInstance((byte[])data), externalDataPaths);
            }
        }
        if (type == Onnx.AttributeProto.AttributeType.SPARSE_TENSOR || type == Onnx.AttributeProto.AttributeType.SPARSE_TENSORS) {
            for (byte[] data : sparseTensorData) {
                OnnxStreamParser.parseSparseTensorProto(CodedInputStream.newInstance((byte[])data), externalDataPaths);
            }
        }
    }

    private static void parseTrainingInfoProto(CodedInputStream stream, Set<Path> externalDataPaths) throws IOException {
        block3: while (!stream.isAtEnd()) {
            int tag = stream.readTag();
            if (WireFormat.getTagFieldNumber((int)tag) == 0) {
                return;
            }
            switch (WireFormat.getTagFieldNumber((int)tag)) {
                case 1: 
                case 2: {
                    OnnxStreamParser.parseDelimited(stream, s -> OnnxStreamParser.parseGraphProtoInitializers(s, externalDataPaths));
                    continue block3;
                }
            }
            stream.skipField(tag);
        }
    }

    private static void parseGraphProtoInitializers(CodedInputStream stream, Set<Path> externalDataPaths) throws IOException {
        block4: while (!stream.isAtEnd()) {
            int tag = stream.readTag();
            if (WireFormat.getTagFieldNumber((int)tag) == 0) {
                return;
            }
            switch (WireFormat.getTagFieldNumber((int)tag)) {
                case 5: {
                    OnnxStreamParser.parseDelimited(stream, s -> OnnxStreamParser.parseTensorProto(s, externalDataPaths));
                    continue block4;
                }
                case 15: {
                    OnnxStreamParser.parseDelimited(stream, s -> OnnxStreamParser.parseSparseTensorProto(s, externalDataPaths));
                    continue block4;
                }
            }
            stream.skipField(tag);
        }
    }

    private static void parseTensorProto(CodedInputStream stream, Set<Path> externalDataPaths) throws IOException {
        int tag;
        Onnx.TensorProto.DataLocation dataLocation = Onnx.TensorProto.DataLocation.DEFAULT;
        String location = null;
        block4: while (!stream.isAtEnd() && WireFormat.getTagFieldNumber((int)(tag = stream.readTag())) != 0) {
            switch (WireFormat.getTagFieldNumber((int)tag)) {
                case 13: {
                    Onnx.StringStringEntryProto.Builder builder = Onnx.StringStringEntryProto.newBuilder();
                    stream.readMessage((MessageLite.Builder)builder, ExtensionRegistryLite.getEmptyRegistry());
                    Onnx.StringStringEntryProto entry = builder.build();
                    if (!"location".equals(entry.getKey())) continue block4;
                    location = entry.getValue();
                    continue block4;
                }
                case 14: {
                    dataLocation = Onnx.TensorProto.DataLocation.forNumber(stream.readEnum());
                    continue block4;
                }
            }
            stream.skipField(tag);
        }
        if (dataLocation == Onnx.TensorProto.DataLocation.EXTERNAL && location != null) {
            if (location.contains("..")) {
                throw new IllegalArgumentException("External data path '" + location + "' must not contain '..'");
            }
            externalDataPaths.add(Paths.get(location, new String[0]));
        }
    }

    private static void parseSparseTensorProto(CodedInputStream stream, Set<Path> externalDataPaths) throws IOException {
        block3: while (!stream.isAtEnd()) {
            int tag = stream.readTag();
            if (WireFormat.getTagFieldNumber((int)tag) == 0) {
                return;
            }
            switch (WireFormat.getTagFieldNumber((int)tag)) {
                case 1: 
                case 2: {
                    OnnxStreamParser.parseDelimited(stream, s -> OnnxStreamParser.parseTensorProto(s, externalDataPaths));
                    continue block3;
                }
            }
            stream.skipField(tag);
        }
    }

    private static byte[] readDelimitedBytes(CodedInputStream stream) throws IOException {
        int length = stream.readRawVarint32();
        return stream.readRawBytes(length);
    }

    private static void parseDelimited(CodedInputStream stream, CodedInputStreamConsumer consumer) throws IOException {
        int length = stream.readRawVarint32();
        int limit = stream.pushLimit(length);
        consumer.accept(stream);
        stream.popLimit(limit);
    }

    @FunctionalInterface
    private static interface CodedInputStreamConsumer {
        public void accept(CodedInputStream var1) throws IOException;
    }
}

