/*
 * Decompiled with CFR 0.152.
 */
package org.tensorflow;

import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerScope;
import org.tensorflow.Graph;
import org.tensorflow.GraphOperation;
import org.tensorflow.Result;
import org.tensorflow.Session;
import org.tensorflow.SessionFunction;
import org.tensorflow.Signature;
import org.tensorflow.Tensor;
import org.tensorflow.exceptions.TensorFlowException;
import org.tensorflow.internal.c_api.TF_Buffer;
import org.tensorflow.internal.c_api.TF_Graph;
import org.tensorflow.internal.c_api.TF_Session;
import org.tensorflow.internal.c_api.TF_SessionOptions;
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.internal.c_api.global.tensorflow;
import org.tensorflow.proto.CollectionDef;
import org.tensorflow.proto.ConfigProto;
import org.tensorflow.proto.MetaGraphDef;
import org.tensorflow.proto.RunOptions;
import org.tensorflow.proto.SavedModel;
import org.tensorflow.proto.SaverDef;
import org.tensorflow.proto.SignatureDef;

public class SavedModelBundle
implements AutoCloseable {
    public static final String DEFAULT_TAG = "serve";
    private static final String INIT_OP_SIGNATURE_KEY = "__saved_model_init_op";
    private static final String MAIN_OP_COLLECTION_KEY = "saved_model_main_op";
    private static final String LEGACY_INIT_OP_COLLECTION_KEY = "legacy_init_op";
    private static final String TABLE_INITIALIZERS_COLLECTION_KEY = "table_initializer";
    private final Graph graph;
    private final Session session;
    private final MetaGraphDef metaGraphDef;
    private final Map<String, SessionFunction> functions;

    /**
     * Loads a saved model from {@code exportDir}.
     *
     * <p>When {@code tags} is {@code null} or empty, the loader's own default tags are left
     * untouched; otherwise they are replaced by {@code tags}.
     *
     * @param exportDir directory containing the saved model
     * @param tags optional tags selecting the metagraph to load
     * @return the loaded bundle
     */
    public static SavedModelBundle load(String exportDir, String ... tags) {
        Loader modelLoader = SavedModelBundle.loader(exportDir);
        boolean tagsProvided = tags != null && tags.length > 0;
        if (!tagsProvided) {
            return modelLoader.load();
        }
        // withTags returns the same loader, so the fluent form is equivalent.
        return modelLoader.withTags(tags).load();
    }

    /**
     * Creates a {@link Loader} for the saved model located at {@code exportDir}.
     *
     * @param exportDir directory containing the saved model
     * @return a new loader bound to {@code exportDir}
     */
    public static Loader loader(String exportDir) {
        Loader created = new Loader(exportDir);
        return created;
    }

    /**
     * Creates an {@link Exporter} that will write a saved model to {@code exportDir}.
     *
     * @param exportDir target directory of the export
     * @return a new exporter bound to {@code exportDir}
     */
    public static Exporter exporter(String exportDir) {
        Exporter created = new Exporter(exportDir);
        return created;
    }

    /**
     * Returns the {@link MetaGraphDef} this bundle was constructed with.
     *
     * @return the metagraph definition held by this bundle
     */
    public MetaGraphDef metaGraphDef() {
        return metaGraphDef;
    }

    /**
     * Returns the graph held by this bundle. The graph is closed by {@link #close()}.
     *
     * @return the bundle's graph
     */
    public Graph graph() {
        return graph;
    }

    /**
     * Returns the session held by this bundle. The session is closed by {@link #close()}.
     *
     * @return the bundle's session
     */
    public Session session() {
        return session;
    }

    /**
     * Returns the signatures of this bundle's functions, leaving out the one whose key is
     * {@code __saved_model_init_op}.
     *
     * @return a new list of signatures, in the iteration order of the function map
     */
    public List<Signature> signatures() {
        List<Signature> visible = new ArrayList<>();
        for (SessionFunction fn : this.functions.values()) {
            Signature candidate = fn.signature();
            // The init-op signature is internal bookkeeping and is not reported.
            if (candidate.key().equals(INIT_OP_SIGNATURE_KEY)) {
                continue;
            }
            visible.add(candidate);
        }
        return visible;
    }

    /**
     * Looks up a function by its signature key.
     *
     * @param signatureKey key of the signature to look up
     * @return the function registered under {@code signatureKey}
     * @throws IllegalArgumentException if no function is registered under that key
     */
    public SessionFunction function(String signatureKey) {
        SessionFunction match = this.functions.get(signatureKey);
        if (match != null) {
            return match;
        }
        throw new IllegalArgumentException(String.format("Function with signature [%s] not found", signatureKey));
    }

    /**
     * Returns all functions of this bundle.
     *
     * @return a new list holding every function; changes to the list do not affect the bundle
     */
    public List<SessionFunction> functions() {
        List<SessionFunction> snapshot = new ArrayList<>(this.functions.values());
        return snapshot;
    }

    /**
     * Invokes the bundle's default function.
     *
     * <p>If the bundle holds exactly one function, that function is used. Otherwise the function
     * registered under the key {@code serving_default} is used.
     *
     * @param arguments tensors passed to the elected function
     * @return the result of the function call
     * @throws IllegalArgumentException if no default function can be determined
     */
    public Result call(Map<String, Tensor> arguments) {
        SessionFunction elected;
        if (this.functions.size() == 1) {
            elected = this.functions.values().iterator().next();
        } else {
            elected = this.functions.get("serving_default");
        }
        if (elected == null) {
            throw new IllegalArgumentException("Cannot elect a default function for this model");
        }
        return elected.call(arguments);
    }

    /**
     * Closes the session and then the graph held by this bundle.
     *
     * <p>The graph is closed even if closing the session throws, so a failure while closing the
     * session does not leave the graph open. The exception from the session, if any, still
     * propagates to the caller.
     */
    @Override
    public void close() {
        try {
            this.session.close();
        } finally {
            this.graph.close();
        }
    }

    /**
     * Creates a bundle from already-built parts. The arguments are stored as given, without
     * copying.
     *
     * @param graph graph of the model
     * @param session session attached to {@code graph}
     * @param metaGraphDef metagraph definition of the model
     * @param functions functions of the model, keyed by signature key
     */
    private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef, Map<String, SessionFunction> functions) {
        this.functions = functions;
        this.metaGraphDef = metaGraphDef;
        this.session = session;
        this.graph = graph;
    }

    /**
     * Locates the model's init operation, if there is one.
     *
     * <p>Lookup order: the output named {@code __saved_model_init_op} of the signature with the
     * same key; otherwise the single node of the {@code saved_model_main_op} collection when that
     * collection is present, else of the {@code legacy_init_op} collection.
     *
     * @param graph graph in which the operation is resolved
     * @param signatures signatures of the model, keyed by signature key
     * @param collections collection definitions of the metagraph
     * @return the init operation, or {@code null} when neither a signature nor a collection names one
     * @throws IllegalArgumentException if the selected collection does not hold exactly one node
     */
    private static GraphOperation findInitOp(Graph graph, Map<String, Signature> signatures, Map<String, CollectionDef> collections) {
        Signature initSignature = signatures.get(INIT_OP_SIGNATURE_KEY);
        if (initSignature != null) {
            return (GraphOperation) graph
                .outputOrThrow(initSignature.getOutputs().get(INIT_OP_SIGNATURE_KEY).name)
                .op();
        }

        // The main-op collection wins over the legacy one whenever its key is present.
        CollectionDef initCollection;
        if (collections.containsKey(MAIN_OP_COLLECTION_KEY)) {
            initCollection = collections.get(MAIN_OP_COLLECTION_KEY);
        } else {
            initCollection = collections.get(LEGACY_INIT_OP_COLLECTION_KEY);
        }
        if (initCollection == null) {
            return null;
        }

        CollectionDef.NodeList initNodes = initCollection.getNodeList();
        if (initNodes.getValueCount() != 1) {
            throw new IllegalArgumentException("Expected exactly one main op in saved model.");
        }
        return (GraphOperation) graph.outputOrThrow(initNodes.getValue(0)).op();
    }

    /**
     * Builds a bundle around native graph and session handles.
     *
     * <p>Steps, in order: wrap the handles in a {@link Graph} and a {@link Session}; create one
     * {@link Signature} per signature def of the metagraph; register the init op (if any) as a
     * graph initializer; mark all initializers registered so far as ran; register every node of
     * the {@code table_initializer} collection (if present) as an initializer; finally wrap each
     * signature in a {@link SessionFunction} bound to the session.
     *
     * @param graphHandle native graph handle
     * @param sessionHandle native session handle
     * @param metaGraphDef metagraph definition describing the model
     * @return the assembled bundle
     */
    private static SavedModelBundle fromHandle(TF_Graph graphHandle, TF_Session sessionHandle, MetaGraphDef metaGraphDef) {
        Graph graph = new Graph(graphHandle, metaGraphDef.getSaverDef());
        Session session = new Session(graph, sessionHandle);

        Map<String, Signature> signaturesByKey = new HashMap<>(metaGraphDef.getSignatureDefCount());
        for (Map.Entry<String, SignatureDef> entry : metaGraphDef.getSignatureDefMap().entrySet()) {
            String signatureName = entry.getKey();
            if (!signaturesByKey.containsKey(signatureName)) {
                signaturesByKey.put(signatureName, new Signature(signatureName, entry.getValue()));
            }
        }

        GraphOperation initOp = SavedModelBundle.findInitOp(graph, signaturesByKey, metaGraphDef.getCollectionDefMap());
        if (initOp != null) {
            graph.registerInitializer(initOp, false);
        }

        // Must stay between the init-op registration and the table-initializer registration:
        // only what is registered before this call is marked as ran.
        session.markAllInitializersAsRan();

        if (metaGraphDef.containsCollectionDef(TABLE_INITIALIZERS_COLLECTION_KEY)) {
            CollectionDef tableInitializers = metaGraphDef.getCollectionDefMap().get(TABLE_INITIALIZERS_COLLECTION_KEY);
            for (String node : tableInitializers.getNodeList().getValueList()) {
                graph.registerInitializer(graph.operationOrThrow(node), false);
            }
        }

        Map<String, SessionFunction> sessionFunctions = signaturesByKey.entrySet().stream()
            .collect(Collectors.toMap(Map.Entry::getKey, e -> new SessionFunction(e.getValue(), session)));
        return new SavedModelBundle(graph, session, metaGraphDef, sessionFunctions);
    }

    /**
     * Loads a saved model through the native API and wraps the result in a bundle.
     *
     * @param exportDir directory containing the saved model
     * @param tags tags passed to the native loader
     * @param config optional session configuration; skipped when {@code null}
     * @param runOptions run options handed to the native loader (the {@link Loader} in this file
     *     defaults this to {@code null})
     * @return the loaded bundle, after {@code session.initialize()} has been called on it
     * @throws TensorFlowException if the returned MetaGraphDef bytes cannot be parsed
     */
    private static SavedModelBundle load(String exportDir, String[] tags, ConfigProto config, RunOptions runOptions) {
        SavedModelBundle bundle = null;
        // All native objects below are created while this scope is open.
        try (PointerScope scope = new PointerScope();){
            TF_Status status = TF_Status.newStatus();
            TF_SessionOptions opts = TF_SessionOptions.newSessionOptions();
            if (config != null) {
                // Serialize the config proto and apply it to the session options.
                BytePointer configBytes = new BytePointer(config.toByteArray());
                tensorflow.TF_SetConfig(opts, (Pointer)configBytes, configBytes.capacity(), status);
                status.throwExceptionIfNotOK();
            }
            TF_Buffer runOpts = TF_Buffer.newBufferFromString((Message)runOptions);
            TF_Graph graph = tensorflow.TF_NewGraph();
            // Output buffer: receives the serialized MetaGraphDef of the loaded model.
            TF_Buffer metagraphDef = TF_Buffer.newBuffer();
            TF_Session session = TF_Session.loadSessionFromSavedModel(opts, runOpts, exportDir, tags, graph, metagraphDef, status);
            status.throwExceptionIfNotOK();
            try {
                bundle = SavedModelBundle.fromHandle(graph, session, MetaGraphDef.parseFrom(metagraphDef.dataAsByteBuffer()));
                // NOTE(review): presumably keeps the native graph and session alive after the
                // enclosing PointerScope closes - confirm against JavaCPP reference counting.
                graph.retainReference();
                session.retainReference();
            }
            catch (InvalidProtocolBufferException e) {
                throw new TensorFlowException("Cannot parse MetaGraphDef protocol buffer", e);
            }
        }
        // Runs after the pointer scope has been closed.
        bundle.session.initialize();
        return bundle;
    }

    /**
     * Checks that {@code tags} is a non-null array without null elements. An empty array is
     * accepted.
     *
     * @param tags tags to validate
     * @throws IllegalArgumentException if {@code tags} is {@code null} or contains a {@code null}
     */
    private static void validateTags(String[] tags) {
        boolean valid = tags != null;
        if (valid) {
            for (String tag : tags) {
                if (tag == null) {
                    valid = false;
                    break;
                }
            }
        }
        if (!valid) {
            throw new IllegalArgumentException("Invalid tags: " + Arrays.toString(tags));
        }
    }

    // Runs when SavedModelBundle itself is initialized: loads the class
    // org.tensorflow.TensorFlow by name (the one-argument Class.forName also initializes it).
    // NOTE(review): presumably this triggers native library loading in TensorFlow's own
    // static initializer - confirm in that class.
    static {
        try {
            Class.forName("org.tensorflow.TensorFlow");
        }
        catch (ClassNotFoundException e) {
            // The class is expected on the classpath; surface its absence as an unchecked failure.
            throw new RuntimeException(e);
        }
    }

    /**
     * Fluent helper that writes a session's graph, its variables and a set of function signatures
     * to a saved model directory.
     *
     * <p>All functions added to one exporter must share the same {@link Session}.
     */
    public static final class Exporter {
        private final String exportDir;
        private String[] tags = new String[]{"serve"};
        private final MetaGraphDef.Builder metaGraphDefBuilder = MetaGraphDef.newBuilder();
        private Session session;
        private final Map<String, SessionFunction> functions = new LinkedHashMap<>();

        /**
         * @param exportDir directory the saved model is written to
         */
        Exporter(String exportDir) {
            this.exportDir = exportDir;
        }

        /**
         * Replaces the tags written into the exported metagraph.
         *
         * @param tags tags to use; must be non-null and free of null elements
         * @return this exporter
         * @throws IllegalArgumentException if {@code tags} is invalid
         */
        public Exporter withTags(String ... tags) {
            SavedModelBundle.validateTags(tags);
            // Defensive copy: later changes to the caller's array must not bypass validation.
            this.tags = tags.clone();
            return this;
        }

        /**
         * Sets the session to export.
         *
         * @param session session to export
         * @return this exporter
         * @throws IllegalStateException if a different session has already been set
         */
        public Exporter withSession(Session session) {
            if (this.session != null && this.session != session) {
                throw new IllegalStateException("This exporter already has a session that differs from the passed session");
            }
            this.session = session;
            return this;
        }

        /**
         * Adds a function to the export and adopts its session as the exporter's session.
         *
         * @param function function to add
         * @return this exporter
         * @throws IllegalArgumentException if a function with the same signature key was already added
         * @throws UnsupportedOperationException if the function's session differs from the
         *     exporter's current session
         */
        public Exporter withFunction(SessionFunction function) {
            Signature signature = function.signature();
            if (this.functions.containsKey(signature.key())) {
                throw new IllegalArgumentException("Function \"" + signature.key() + "\" was already added to the model");
            }
            if (this.session != null && this.session != function.session()) {
                throw new UnsupportedOperationException("This exporter already has a session that differs from the passed function's session");
            }
            this.session = function.session();
            this.functions.put(signature.key(), function);
            this.metaGraphDefBuilder.putSignatureDef(signature.key(), signature.asSignatureDef());
            return this;
        }

        /**
         * Adds several functions, in order, as if by {@link #withFunction(SessionFunction)}.
         *
         * @param functions functions to add
         * @return this exporter
         */
        public Exporter withFunctions(SessionFunction ... functions) {
            for (SessionFunction f : functions) {
                this.withFunction(f);
            }
            return this;
        }

        /**
         * Adds the function obtained from the current session for {@code signature}.
         *
         * @param signature signature of the function to add
         * @return this exporter
         * @throws IllegalStateException if no session has been set yet
         */
        public Exporter withSignature(Signature signature) {
            if (this.session == null) {
                throw new IllegalStateException("Session has not been set yet, you must call withSession or withFunction first.");
            }
            return this.withFunction(this.session.function(signature));
        }

        /**
         * Adds several signatures, in order, as if by {@link #withSignature(Signature)}.
         *
         * @param signatures signatures to add
         * @return this exporter
         */
        public Exporter withSignatures(Signature ... signatures) {
            for (Signature s : signatures) {
                this.withSignature(s);
            }
            return this;
        }

        /**
         * Writes the saved model: variables under {@code <exportDir>/variables} and the model
         * definition to {@code <exportDir>/saved_model.pb}.
         *
         * @throws IllegalStateException if no function has been added
         * @throws IOException if the variables directory cannot be created or the model file
         *     cannot be written
         */
        public void export() throws IOException {
            if (this.functions.isEmpty()) {
                throw new IllegalStateException("Model should contain at least one valid function");
            }
            // A non-empty function map implies a session: withFunction always sets it.
            Graph graph = this.session.graph();
            SaverDef saverDef = graph.saverDef();
            MetaGraphDef.Builder metaGraphDef = this.metaGraphDefBuilder
                .setSaverDef(saverDef)
                .setGraphDef(graph.toGraphDef())
                .setMetaInfoDef(MetaGraphDef.MetaInfoDef.newBuilder().addAllTags(Arrays.asList(this.tags)));
            // withFunction has already put these entries; re-putting is kept from the original.
            this.functions.forEach((k, f) -> metaGraphDef.putSignatureDef(k, f.signature().asSignatureDef()));

            Path variableDir = Paths.get(this.exportDir, "variables");
            // Fails with an IOException instead of silently ignoring an unsuccessful mkdirs().
            Files.createDirectories(variableDir);
            this.session.save(variableDir.resolve("variables").toString());

            SavedModel savedModelDef = SavedModel.newBuilder().addMetaGraphs(metaGraphDef).build();
            try (FileOutputStream file = new FileOutputStream(Paths.get(this.exportDir, "saved_model.pb").toString())) {
                savedModelDef.writeTo(file);
            }
        }
    }

    /**
     * Fluent helper that collects the options for loading a saved model and then loads it.
     *
     * <p>Defaults: tags {@code ["serve"]}, no config proto, no run options.
     */
    public static final class Loader {
        private final String exportDir;
        private String[] tags = new String[]{"serve"};
        private ConfigProto configProto = null;
        private RunOptions runOptions = null;

        /**
         * @param exportDir directory containing the saved model
         */
        private Loader(String exportDir) {
            this.exportDir = exportDir;
        }

        /**
         * Loads the saved model with the options collected so far.
         *
         * @return the loaded bundle
         */
        public SavedModelBundle load() {
            return SavedModelBundle.load(this.exportDir, this.tags, this.configProto, this.runOptions);
        }

        /**
         * Sets the run options passed to the load call.
         *
         * @param options run options, may be {@code null}
         * @return this loader
         */
        public Loader withRunOptions(RunOptions options) {
            this.runOptions = options;
            return this;
        }

        /**
         * Sets the session configuration used when loading.
         *
         * @param configProto session configuration, may be {@code null}
         * @return this loader
         */
        public Loader withConfigProto(ConfigProto configProto) {
            this.configProto = configProto;
            return this;
        }

        /**
         * Replaces the tags used to select the metagraph to load.
         *
         * @param tags tags to use; must be non-null and free of null elements
         * @return this loader
         * @throws IllegalArgumentException if {@code tags} is invalid
         */
        public Loader withTags(String ... tags) {
            SavedModelBundle.validateTags(tags);
            // Defensive copy: later changes to the caller's array must not bypass validation.
            this.tags = tags.clone();
            return this;
        }
    }
}

