| /*! |
| * Copyright (c) 2016 by Contributors |
| * \file reflection.cc |
| * \brief Utilities to save/load/construct TVM objects |
| */ |
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/attrs.h>
#include <tvm/node/container.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <dmlc/json.h>
#include <dmlc/memory_io.h>
#include <sstream>
#include <string>
#include <utility>
#include "../common/base64.h"
| |
// Enable the dmlc registry for NodeFactoryReg in this translation unit.
// NodeFactoryReg::Registry() below returns ::dmlc::Registry<NodeFactoryReg>::Get();
// this macro presumably provides that registry's static storage (see dmlc/registry.h).
namespace dmlc {
DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg);
}  // namespace dmlc
| |
| namespace tvm { |
| |
| ::dmlc::Registry<NodeFactoryReg>* NodeFactoryReg::Registry() { |
| return ::dmlc::Registry<NodeFactoryReg>::Get(); |
| } |
| |
| inline std::string Type2String(const Type& t) { |
| return runtime::TVMType2String(Type2TVMType(t)); |
| } |
| |
| |
| inline Type String2Type(std::string s) { |
| return TVMType2Type(runtime::String2TVMType(s)); |
| } |
| |
| |
| // indexer to index all the ndoes |
| class NodeIndexer : public AttrVisitor { |
| public: |
| std::unordered_map<Node*, size_t> node_index{{nullptr, 0}}; |
| std::vector<Node*> node_list{nullptr}; |
| std::unordered_map<DLTensor*, size_t> tensor_index; |
| std::vector<DLTensor*> tensor_list; |
| |
| void Visit(const char* key, double* value) final {} |
| void Visit(const char* key, int64_t* value) final {} |
| void Visit(const char* key, uint64_t* value) final {} |
| void Visit(const char* key, int* value) final {} |
| void Visit(const char* key, bool* value) final {} |
| void Visit(const char* key, std::string* value) final {} |
| void Visit(const char* key, void** value) final {} |
| void Visit(const char* key, Type* value) final {} |
| void Visit(const char* key, NodeRef* value) final { |
| MakeIndex(value->node_.get()); |
| } |
| void Visit(const char* key, runtime::NDArray* value) final { |
| DLTensor* ptr = const_cast<DLTensor*>((*value).operator->()); |
| if (tensor_index.count(ptr)) return; |
| CHECK_EQ(tensor_index.size(), tensor_list.size()); |
| tensor_index[ptr] = tensor_list.size(); |
| tensor_list.push_back(ptr); |
| } |
| // make index of all the children of node |
| void MakeIndex(Node* node) { |
| if (node == nullptr) return; |
| if (node_index.count(node)) return; |
| CHECK_EQ(node_index.size(), node_list.size()); |
| node_index[node] = node_list.size(); |
| node_list.push_back(node); |
| |
| if (node->is_type<ArrayNode>()) { |
| ArrayNode* n = static_cast<ArrayNode*>(node); |
| for (const auto& sp : n->data) { |
| MakeIndex(sp.get()); |
| } |
| } else if (node->is_type<MapNode>()) { |
| MapNode* n = static_cast<MapNode*>(node); |
| for (const auto& kv : n->data) { |
| MakeIndex(kv.first.get()); |
| MakeIndex(kv.second.get()); |
| } |
| } else if (node->is_type<StrMapNode>()) { |
| StrMapNode* n = static_cast<StrMapNode*>(node); |
| for (const auto& kv : n->data) { |
| MakeIndex(kv.second.get()); |
| } |
| } else { |
| node->VisitAttrs(this); |
| } |
| } |
| }; |
| |
// Attribute name -> serialized value. std::map (rather than an unordered
// container) iterates keys in sorted order, so attributes come out in a
// stable order regardless of insertion order.
using AttrMap = std::map<std::string, std::string>;
| |
| // A Node structure for JSON node. |
| struct JSONNode { |
| // The type key of the data |
| std::string type_key; |
| // The global key for global object |
| std::string global_key; |
| // the attributes |
| AttrMap attrs; |
| // container keys |
| std::vector<std::string> keys; |
| // container data |
| std::vector<size_t> data; |
| |
| void Save(dmlc::JSONWriter *writer) const { |
| writer->BeginObject(); |
| writer->WriteObjectKeyValue("type_key", type_key); |
| if (global_key.size() != 0) { |
| writer->WriteObjectKeyValue("global_key", global_key); |
| } |
| if (attrs.size() != 0) { |
| writer->WriteObjectKeyValue("attrs", attrs); |
| } |
| if (keys.size() != 0) { |
| writer->WriteObjectKeyValue("keys", keys); |
| } |
| if (data.size() != 0) { |
| writer->WriteObjectKeyValue("data", data); |
| } |
| writer->EndObject(); |
| } |
| |
| void Load(dmlc::JSONReader *reader) { |
| attrs.clear(); |
| data.clear(); |
| global_key.clear(); |
| type_key.clear(); |
| dmlc::JSONObjectReadHelper helper; |
| helper.DeclareOptionalField("type_key", &type_key); |
| helper.DeclareOptionalField("global_key", &global_key); |
| helper.DeclareOptionalField("attrs", &attrs); |
| helper.DeclareOptionalField("keys", &keys); |
| helper.DeclareOptionalField("data", &data); |
| helper.ReadAllFields(reader); |
| } |
| }; |
| |
| class JSONAttrGetter : public AttrVisitor { |
| public: |
| const std::unordered_map<Node*, size_t>* node_index_; |
| const std::unordered_map<DLTensor*, size_t>* tensor_index_; |
| JSONNode* node_; |
| |
| void Visit(const char* key, double* value) final { |
| node_->attrs[key] = std::to_string(*value); |
| } |
| void Visit(const char* key, int64_t* value) final { |
| node_->attrs[key] = std::to_string(*value); |
| } |
| void Visit(const char* key, uint64_t* value) final { |
| node_->attrs[key] = std::to_string(*value); |
| } |
| void Visit(const char* key, int* value) final { |
| node_->attrs[key] = std::to_string(*value); |
| } |
| void Visit(const char* key, bool* value) final { |
| node_->attrs[key] = std::to_string(*value); |
| } |
| void Visit(const char* key, std::string* value) final { |
| node_->attrs[key] = *value; |
| } |
| void Visit(const char* key, void** value) final { |
| LOG(FATAL) << "not allowed to serialize a pointer"; |
| } |
| void Visit(const char* key, Type* value) final { |
| node_->attrs[key] = Type2String(*value); |
| } |
| void Visit(const char* key, NodeRef* value) final { |
| node_->attrs[key] = std::to_string( |
| node_index_->at(value->node_.get())); |
| } |
| void Visit(const char* key, runtime::NDArray* value) final { |
| node_->attrs[key] = std::to_string( |
| tensor_index_->at(const_cast<DLTensor*>((*value).operator->()))); |
| } |
| // Get the node |
| void Get(Node* node) { |
| if (node == nullptr) { |
| node_->type_key.clear(); |
| return; |
| } |
| node_->type_key = node->type_key(); |
| // sepcially handle global object |
| auto* f = dmlc::Registry<NodeFactoryReg>::Find(node_->type_key); |
| CHECK(f != nullptr) |
| << "Node type \'" << node_->type_key << "\' is not registered in TVM"; |
| if (f->fglobal_key != nullptr) { |
| node_->global_key = f->fglobal_key(node); |
| return; |
| } |
| node_->attrs.clear(); |
| node_->data.clear(); |
| if (node->is_type<ArrayNode>()) { |
| ArrayNode* n = static_cast<ArrayNode*>(node); |
| for (size_t i = 0; i < n->data.size(); ++i) { |
| node_->data.push_back( |
| node_index_->at(n->data[i].get())); |
| } |
| } else if (node->is_type<MapNode>()) { |
| MapNode* n = static_cast<MapNode*>(node); |
| for (const auto& kv : n->data) { |
| node_->data.push_back( |
| node_index_->at(kv.first.get())); |
| node_->data.push_back( |
| node_index_->at(kv.second.get())); |
| } |
| } else if (node->is_type<StrMapNode>()) { |
| StrMapNode* n = static_cast<StrMapNode*>(node); |
| for (const auto& kv : n->data) { |
| node_->keys.push_back(kv.first); |
| node_->data.push_back( |
| node_index_->at(kv.second.get())); |
| } |
| } else { |
| // do not need to recover content of global singleton object |
| // they are registered via the environment |
| auto* f = dmlc::Registry<NodeFactoryReg>::Find(node->type_key()); |
| if (f != nullptr && f->fglobal_key != nullptr) return; |
| // recursively index normal object. |
| node->VisitAttrs(this); |
| } |
| } |
| }; |
| |
| class JSONAttrSetter : public AttrVisitor { |
| public: |
| const std::vector<NodePtr<Node> >* node_list_; |
| const std::vector<runtime::NDArray>* tensor_list_; |
| JSONNode* node_; |
| |
| std::string GetValue(const char* key) const { |
| auto it = node_->attrs.find(key); |
| if (it == node_->attrs.end()) { |
| LOG(FATAL) << "JSONReader: cannot find field " << key; |
| } |
| return it->second; |
| } |
| template<typename T> |
| void ParseValue(const char* key, T* value) const { |
| std::istringstream is(GetValue(key)); |
| is >> *value; |
| if (is.fail()) { |
| LOG(FATAL) << "Wrong value format for field " << key; |
| } |
| } |
| void Visit(const char* key, double* value) final { |
| ParseValue(key, value); |
| } |
| void Visit(const char* key, int64_t* value) final { |
| ParseValue(key, value); |
| } |
| void Visit(const char* key, uint64_t* value) final { |
| ParseValue(key, value); |
| } |
| void Visit(const char* key, int* value) final { |
| ParseValue(key, value); |
| } |
| void Visit(const char* key, bool* value) final { |
| ParseValue(key, value); |
| } |
| void Visit(const char* key, std::string* value) final { |
| *value = GetValue(key); |
| } |
| void Visit(const char* key, void** value) final { |
| LOG(FATAL) << "not allowed to deserialize a pointer"; |
| } |
| void Visit(const char* key, Type* value) final { |
| std::string stype = GetValue(key); |
| *value = String2Type(stype); |
| } |
| void Visit(const char* key, NodeRef* value) final { |
| size_t index; |
| ParseValue(key, &index); |
| CHECK_LE(index, node_list_->size()); |
| value->node_ = node_list_->at(index); |
| } |
| void Visit(const char* key, runtime::NDArray* value) final { |
| size_t index; |
| ParseValue(key, &index); |
| CHECK_LE(index, tensor_list_->size()); |
| *value = tensor_list_->at(index); |
| } |
| // set node to be current JSONNode |
| void Set(Node* node) { |
| if (node == nullptr) return; |
| if (node->is_type<ArrayNode>()) { |
| ArrayNode* n = static_cast<ArrayNode*>(node); |
| n->data.clear(); |
| for (size_t index : node_->data) { |
| n->data.push_back(node_list_->at(index)); |
| } |
| } else if (node->is_type<MapNode>()) { |
| MapNode* n = static_cast<MapNode*>(node); |
| CHECK_EQ(node_->data.size() % 2, 0U); |
| for (size_t i = 0; i < node_->data.size(); i += 2) { |
| n->data[node_list_->at(node_->data[i])] |
| = node_list_->at(node_->data[i + 1]); |
| } |
| } else if (node->is_type<StrMapNode>()) { |
| StrMapNode* n = static_cast<StrMapNode*>(node); |
| CHECK_EQ(node_->data.size(), node_->keys.size()); |
| for (size_t i = 0; i < node_->data.size(); ++i) { |
| n->data[node_->keys[i]] |
| = node_list_->at(node_->data[i]); |
| } |
| } else { |
| node->VisitAttrs(this); |
| } |
| } |
| }; |
| |
| // json graph structure to store node |
| struct JSONGraph { |
| // the root of the graph |
| size_t root; |
| // the nodes of the graph |
| std::vector<JSONNode> nodes; |
| // base64 b64ndarrays of arrays |
| std::vector<std::string> b64ndarrays; |
| // global attributes |
| AttrMap attrs; |
| |
| void Save(dmlc::JSONWriter *writer) const { |
| writer->BeginObject(); |
| writer->WriteObjectKeyValue("root", root); |
| writer->WriteObjectKeyValue("nodes", nodes); |
| writer->WriteObjectKeyValue("b64ndarrays", b64ndarrays); |
| if (attrs.size() != 0) { |
| writer->WriteObjectKeyValue("attrs", attrs); |
| } |
| writer->EndObject(); |
| } |
| |
| void Load(dmlc::JSONReader *reader) { |
| attrs.clear(); |
| dmlc::JSONObjectReadHelper helper; |
| helper.DeclareField("root", &root); |
| helper.DeclareField("nodes", &nodes); |
| helper.DeclareOptionalField("b64ndarrays", &b64ndarrays); |
| helper.DeclareOptionalField("attrs", &attrs); |
| helper.ReadAllFields(reader); |
| } |
| |
| static JSONGraph Create(const NodeRef& root) { |
| JSONGraph g; |
| NodeIndexer indexer; |
| indexer.MakeIndex(root.node_.get()); |
| JSONAttrGetter getter; |
| getter.node_index_ = &indexer.node_index; |
| getter.tensor_index_ = &indexer.tensor_index; |
| for (Node* n : indexer.node_list) { |
| JSONNode jnode; |
| getter.node_ = &jnode; |
| getter.Get(n); |
| g.nodes.emplace_back(std::move(jnode)); |
| } |
| g.attrs["tvm_version"] = TVM_VERSION; |
| g.root = indexer.node_index.at(root.node_.get()); |
| // serialize tensor |
| for (DLTensor* tensor : indexer.tensor_list) { |
| std::string blob; |
| dmlc::MemoryStringStream mstrm(&blob); |
| common::Base64OutStream b64strm(&mstrm); |
| runtime::SaveDLTensor(&b64strm, tensor); |
| b64strm.Finish(); |
| g.b64ndarrays.emplace_back(std::move(blob)); |
| } |
| return g; |
| } |
| }; |
| |
| std::string SaveJSON(const NodeRef& n) { |
| auto jgraph = JSONGraph::Create(n); |
| std::ostringstream os; |
| dmlc::JSONWriter writer(&os); |
| jgraph.Save(&writer); |
| return os.str(); |
| } |
| |
| NodePtr<Node> LoadJSON_(std::string json_str) { |
| std::istringstream is(json_str); |
| dmlc::JSONReader reader(&is); |
| JSONGraph jgraph; |
| // load in json graph. |
| jgraph.Load(&reader); |
| std::vector<NodePtr<Node> > nodes; |
| std::vector<runtime::NDArray> tensors; |
| // load in tensors |
| for (const std::string& blob : jgraph.b64ndarrays) { |
| dmlc::MemoryStringStream mstrm(const_cast<std::string*>(&blob)); |
| common::Base64InStream b64strm(&mstrm); |
| b64strm.InitPosition(); |
| runtime::NDArray temp; |
| CHECK(temp.Load(&b64strm)); |
| tensors.emplace_back(temp); |
| } |
| // node 0 is always null |
| nodes.reserve(jgraph.nodes.size()); |
| for (const JSONNode& jnode : jgraph.nodes) { |
| if (jnode.type_key.length() != 0) { |
| auto* f = dmlc::Registry<NodeFactoryReg>::Find(jnode.type_key); |
| CHECK(f != nullptr) |
| << "Node type \'" << jnode.type_key << "\' is not registered in TVM"; |
| nodes.emplace_back(f->fcreator(jnode.global_key)); |
| } else { |
| nodes.emplace_back(NodePtr<Node>()); |
| } |
| } |
| CHECK_EQ(nodes.size(), jgraph.nodes.size()); |
| JSONAttrSetter setter; |
| setter.node_list_ = &nodes; |
| setter.tensor_list_ = &tensors; |
| |
| for (size_t i = 0; i < nodes.size(); ++i) { |
| setter.node_ = &jgraph.nodes[i]; |
| // do not need to recover content of global singleton object |
| // they are registered via the environment |
| if (setter.node_->global_key.length() == 0) { |
| setter.Set(nodes[i].get()); |
| } |
| } |
| return nodes.at(jgraph.root); |
| } |
| |
| class NodeAttrSetter : public AttrVisitor { |
| public: |
| std::string type_key; |
| std::unordered_map<std::string, runtime::TVMArgValue> attrs; |
| |
| void Visit(const char* key, double* value) final { |
| *value = GetAttr(key).operator double(); |
| } |
| void Visit(const char* key, int64_t* value) final { |
| *value = GetAttr(key).operator int64_t(); |
| } |
| void Visit(const char* key, uint64_t* value) final { |
| *value = GetAttr(key).operator uint64_t(); |
| } |
| void Visit(const char* key, int* value) final { |
| *value = GetAttr(key).operator int(); |
| } |
| void Visit(const char* key, bool* value) final { |
| *value = GetAttr(key).operator bool(); |
| } |
| void Visit(const char* key, std::string* value) final { |
| *value = GetAttr(key).operator std::string(); |
| } |
| void Visit(const char* key, void** value) final { |
| *value = GetAttr(key).operator void*(); |
| } |
| void Visit(const char* key, Type* value) final { |
| *value = GetAttr(key).operator Type(); |
| } |
| void Visit(const char* key, NodeRef* value) final { |
| *value = GetAttr(key).operator NodeRef(); |
| } |
| void Visit(const char* key, runtime::NDArray* value) final { |
| *value = GetAttr(key).operator runtime::NDArray(); |
| } |
| |
| private: |
| runtime::TVMArgValue GetAttr(const char* key) { |
| auto it = attrs.find(key); |
| if (it == attrs.end()) { |
| LOG(FATAL) << type_key << ": require field " << key; |
| } |
| runtime::TVMArgValue v = it->second; |
| attrs.erase(it); |
| return v; |
| } |
| }; |
| |
| |
| void InitNodeByPackedArgs(Node* n, const TVMArgs& args) { |
| NodeAttrSetter setter; |
| setter.type_key = n->type_key(); |
| CHECK_EQ(args.size() % 2, 0); |
| for (int i = 0; i < args.size(); i += 2) { |
| setter.attrs.emplace(args[i].operator std::string(), |
| args[i + 1]); |
| } |
| n->VisitAttrs(&setter); |
| if (setter.attrs.size() != 0) { |
| std::ostringstream os; |
| os << setter.type_key << " does not contain field "; |
| for (const auto &kv : setter.attrs) { |
| os << " " << kv.first; |
| } |
| LOG(FATAL) << os.str(); |
| } |
| } |
| |
| // API function to make node. |
| // args format: |
| // key1, value1, ..., key_n, value_n |
| void MakeNode(const TVMArgs& args, TVMRetValue* rv) { |
| std::string type_key = args[0]; |
| std::string empty_str; |
| auto* f = dmlc::Registry<NodeFactoryReg>::Find(type_key); |
| CHECK(f != nullptr) |
| << "Node type \'" << type_key << "\' is not registered in TVM"; |
| TVMArgs kwargs(args.values + 1, args.type_codes + 1, args.size() - 1); |
| CHECK(f->fglobal_key == nullptr) |
| << "Cannot make node type \'" << type_key << "\' with global_key."; |
| NodePtr<Node> n = f->fcreator(empty_str); |
| if (n->derived_from<BaseAttrsNode>()) { |
| static_cast<BaseAttrsNode*>(n.get())->InitByPackedArgs(kwargs); |
| } else { |
| InitNodeByPackedArgs(n.get(), kwargs); |
| } |
| *rv = NodeRef(n); |
| } |
| |
// Register MakeNode as the global packed function "make._Node".
TVM_REGISTER_GLOBAL("make._Node")
.set_body(MakeNode);
| } // namespace tvm |