blob: 86a11a7e5b422d56d20852c9121202ab2397ad9e [file] [log] [blame]
/*!
* Copyright (c) 2016 by Contributors
* \file reflection.cc
* \brief Utilities to save/load/construct TVM objects
*/
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/attrs.h>
#include <tvm/node/container.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <dmlc/json.h>
#include <dmlc/memory_io.h>
#include <string>
#include "../common/base64.h"
namespace dmlc {
DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg);
} // namespace dmlc
namespace tvm {
::dmlc::Registry<NodeFactoryReg>* NodeFactoryReg::Registry() {
return ::dmlc::Registry<NodeFactoryReg>::Get();
}
inline std::string Type2String(const Type& t) {
return runtime::TVMType2String(Type2TVMType(t));
}
inline Type String2Type(std::string s) {
return TVMType2Type(runtime::String2TVMType(s));
}
// indexer to index all the ndoes
class NodeIndexer : public AttrVisitor {
public:
std::unordered_map<Node*, size_t> node_index{{nullptr, 0}};
std::vector<Node*> node_list{nullptr};
std::unordered_map<DLTensor*, size_t> tensor_index;
std::vector<DLTensor*> tensor_list;
void Visit(const char* key, double* value) final {}
void Visit(const char* key, int64_t* value) final {}
void Visit(const char* key, uint64_t* value) final {}
void Visit(const char* key, int* value) final {}
void Visit(const char* key, bool* value) final {}
void Visit(const char* key, std::string* value) final {}
void Visit(const char* key, void** value) final {}
void Visit(const char* key, Type* value) final {}
void Visit(const char* key, NodeRef* value) final {
MakeIndex(value->node_.get());
}
void Visit(const char* key, runtime::NDArray* value) final {
DLTensor* ptr = const_cast<DLTensor*>((*value).operator->());
if (tensor_index.count(ptr)) return;
CHECK_EQ(tensor_index.size(), tensor_list.size());
tensor_index[ptr] = tensor_list.size();
tensor_list.push_back(ptr);
}
// make index of all the children of node
void MakeIndex(Node* node) {
if (node == nullptr) return;
if (node_index.count(node)) return;
CHECK_EQ(node_index.size(), node_list.size());
node_index[node] = node_list.size();
node_list.push_back(node);
if (node->is_type<ArrayNode>()) {
ArrayNode* n = static_cast<ArrayNode*>(node);
for (const auto& sp : n->data) {
MakeIndex(sp.get());
}
} else if (node->is_type<MapNode>()) {
MapNode* n = static_cast<MapNode*>(node);
for (const auto& kv : n->data) {
MakeIndex(kv.first.get());
MakeIndex(kv.second.get());
}
} else if (node->is_type<StrMapNode>()) {
StrMapNode* n = static_cast<StrMapNode*>(node);
for (const auto& kv : n->data) {
MakeIndex(kv.second.get());
}
} else {
node->VisitAttrs(this);
}
}
};
// use map so attributes are ordered.
using AttrMap = std::map<std::string, std::string>;
// A Node structure for JSON node.
struct JSONNode {
// The type key of the data
std::string type_key;
// The global key for global object
std::string global_key;
// the attributes
AttrMap attrs;
// container keys
std::vector<std::string> keys;
// container data
std::vector<size_t> data;
void Save(dmlc::JSONWriter *writer) const {
writer->BeginObject();
writer->WriteObjectKeyValue("type_key", type_key);
if (global_key.size() != 0) {
writer->WriteObjectKeyValue("global_key", global_key);
}
if (attrs.size() != 0) {
writer->WriteObjectKeyValue("attrs", attrs);
}
if (keys.size() != 0) {
writer->WriteObjectKeyValue("keys", keys);
}
if (data.size() != 0) {
writer->WriteObjectKeyValue("data", data);
}
writer->EndObject();
}
void Load(dmlc::JSONReader *reader) {
attrs.clear();
data.clear();
global_key.clear();
type_key.clear();
dmlc::JSONObjectReadHelper helper;
helper.DeclareOptionalField("type_key", &type_key);
helper.DeclareOptionalField("global_key", &global_key);
helper.DeclareOptionalField("attrs", &attrs);
helper.DeclareOptionalField("keys", &keys);
helper.DeclareOptionalField("data", &data);
helper.ReadAllFields(reader);
}
};
class JSONAttrGetter : public AttrVisitor {
public:
const std::unordered_map<Node*, size_t>* node_index_;
const std::unordered_map<DLTensor*, size_t>* tensor_index_;
JSONNode* node_;
void Visit(const char* key, double* value) final {
node_->attrs[key] = std::to_string(*value);
}
void Visit(const char* key, int64_t* value) final {
node_->attrs[key] = std::to_string(*value);
}
void Visit(const char* key, uint64_t* value) final {
node_->attrs[key] = std::to_string(*value);
}
void Visit(const char* key, int* value) final {
node_->attrs[key] = std::to_string(*value);
}
void Visit(const char* key, bool* value) final {
node_->attrs[key] = std::to_string(*value);
}
void Visit(const char* key, std::string* value) final {
node_->attrs[key] = *value;
}
void Visit(const char* key, void** value) final {
LOG(FATAL) << "not allowed to serialize a pointer";
}
void Visit(const char* key, Type* value) final {
node_->attrs[key] = Type2String(*value);
}
void Visit(const char* key, NodeRef* value) final {
node_->attrs[key] = std::to_string(
node_index_->at(value->node_.get()));
}
void Visit(const char* key, runtime::NDArray* value) final {
node_->attrs[key] = std::to_string(
tensor_index_->at(const_cast<DLTensor*>((*value).operator->())));
}
// Get the node
void Get(Node* node) {
if (node == nullptr) {
node_->type_key.clear();
return;
}
node_->type_key = node->type_key();
// sepcially handle global object
auto* f = dmlc::Registry<NodeFactoryReg>::Find(node_->type_key);
CHECK(f != nullptr)
<< "Node type \'" << node_->type_key << "\' is not registered in TVM";
if (f->fglobal_key != nullptr) {
node_->global_key = f->fglobal_key(node);
return;
}
node_->attrs.clear();
node_->data.clear();
if (node->is_type<ArrayNode>()) {
ArrayNode* n = static_cast<ArrayNode*>(node);
for (size_t i = 0; i < n->data.size(); ++i) {
node_->data.push_back(
node_index_->at(n->data[i].get()));
}
} else if (node->is_type<MapNode>()) {
MapNode* n = static_cast<MapNode*>(node);
for (const auto& kv : n->data) {
node_->data.push_back(
node_index_->at(kv.first.get()));
node_->data.push_back(
node_index_->at(kv.second.get()));
}
} else if (node->is_type<StrMapNode>()) {
StrMapNode* n = static_cast<StrMapNode*>(node);
for (const auto& kv : n->data) {
node_->keys.push_back(kv.first);
node_->data.push_back(
node_index_->at(kv.second.get()));
}
} else {
// do not need to recover content of global singleton object
// they are registered via the environment
auto* f = dmlc::Registry<NodeFactoryReg>::Find(node->type_key());
if (f != nullptr && f->fglobal_key != nullptr) return;
// recursively index normal object.
node->VisitAttrs(this);
}
}
};
class JSONAttrSetter : public AttrVisitor {
public:
const std::vector<NodePtr<Node> >* node_list_;
const std::vector<runtime::NDArray>* tensor_list_;
JSONNode* node_;
std::string GetValue(const char* key) const {
auto it = node_->attrs.find(key);
if (it == node_->attrs.end()) {
LOG(FATAL) << "JSONReader: cannot find field " << key;
}
return it->second;
}
template<typename T>
void ParseValue(const char* key, T* value) const {
std::istringstream is(GetValue(key));
is >> *value;
if (is.fail()) {
LOG(FATAL) << "Wrong value format for field " << key;
}
}
void Visit(const char* key, double* value) final {
ParseValue(key, value);
}
void Visit(const char* key, int64_t* value) final {
ParseValue(key, value);
}
void Visit(const char* key, uint64_t* value) final {
ParseValue(key, value);
}
void Visit(const char* key, int* value) final {
ParseValue(key, value);
}
void Visit(const char* key, bool* value) final {
ParseValue(key, value);
}
void Visit(const char* key, std::string* value) final {
*value = GetValue(key);
}
void Visit(const char* key, void** value) final {
LOG(FATAL) << "not allowed to deserialize a pointer";
}
void Visit(const char* key, Type* value) final {
std::string stype = GetValue(key);
*value = String2Type(stype);
}
void Visit(const char* key, NodeRef* value) final {
size_t index;
ParseValue(key, &index);
CHECK_LE(index, node_list_->size());
value->node_ = node_list_->at(index);
}
void Visit(const char* key, runtime::NDArray* value) final {
size_t index;
ParseValue(key, &index);
CHECK_LE(index, tensor_list_->size());
*value = tensor_list_->at(index);
}
// set node to be current JSONNode
void Set(Node* node) {
if (node == nullptr) return;
if (node->is_type<ArrayNode>()) {
ArrayNode* n = static_cast<ArrayNode*>(node);
n->data.clear();
for (size_t index : node_->data) {
n->data.push_back(node_list_->at(index));
}
} else if (node->is_type<MapNode>()) {
MapNode* n = static_cast<MapNode*>(node);
CHECK_EQ(node_->data.size() % 2, 0U);
for (size_t i = 0; i < node_->data.size(); i += 2) {
n->data[node_list_->at(node_->data[i])]
= node_list_->at(node_->data[i + 1]);
}
} else if (node->is_type<StrMapNode>()) {
StrMapNode* n = static_cast<StrMapNode*>(node);
CHECK_EQ(node_->data.size(), node_->keys.size());
for (size_t i = 0; i < node_->data.size(); ++i) {
n->data[node_->keys[i]]
= node_list_->at(node_->data[i]);
}
} else {
node->VisitAttrs(this);
}
}
};
// json graph structure to store node
struct JSONGraph {
// the root of the graph
size_t root;
// the nodes of the graph
std::vector<JSONNode> nodes;
// base64 b64ndarrays of arrays
std::vector<std::string> b64ndarrays;
// global attributes
AttrMap attrs;
void Save(dmlc::JSONWriter *writer) const {
writer->BeginObject();
writer->WriteObjectKeyValue("root", root);
writer->WriteObjectKeyValue("nodes", nodes);
writer->WriteObjectKeyValue("b64ndarrays", b64ndarrays);
if (attrs.size() != 0) {
writer->WriteObjectKeyValue("attrs", attrs);
}
writer->EndObject();
}
void Load(dmlc::JSONReader *reader) {
attrs.clear();
dmlc::JSONObjectReadHelper helper;
helper.DeclareField("root", &root);
helper.DeclareField("nodes", &nodes);
helper.DeclareOptionalField("b64ndarrays", &b64ndarrays);
helper.DeclareOptionalField("attrs", &attrs);
helper.ReadAllFields(reader);
}
static JSONGraph Create(const NodeRef& root) {
JSONGraph g;
NodeIndexer indexer;
indexer.MakeIndex(root.node_.get());
JSONAttrGetter getter;
getter.node_index_ = &indexer.node_index;
getter.tensor_index_ = &indexer.tensor_index;
for (Node* n : indexer.node_list) {
JSONNode jnode;
getter.node_ = &jnode;
getter.Get(n);
g.nodes.emplace_back(std::move(jnode));
}
g.attrs["tvm_version"] = TVM_VERSION;
g.root = indexer.node_index.at(root.node_.get());
// serialize tensor
for (DLTensor* tensor : indexer.tensor_list) {
std::string blob;
dmlc::MemoryStringStream mstrm(&blob);
common::Base64OutStream b64strm(&mstrm);
runtime::SaveDLTensor(&b64strm, tensor);
b64strm.Finish();
g.b64ndarrays.emplace_back(std::move(blob));
}
return g;
}
};
std::string SaveJSON(const NodeRef& n) {
auto jgraph = JSONGraph::Create(n);
std::ostringstream os;
dmlc::JSONWriter writer(&os);
jgraph.Save(&writer);
return os.str();
}
NodePtr<Node> LoadJSON_(std::string json_str) {
std::istringstream is(json_str);
dmlc::JSONReader reader(&is);
JSONGraph jgraph;
// load in json graph.
jgraph.Load(&reader);
std::vector<NodePtr<Node> > nodes;
std::vector<runtime::NDArray> tensors;
// load in tensors
for (const std::string& blob : jgraph.b64ndarrays) {
dmlc::MemoryStringStream mstrm(const_cast<std::string*>(&blob));
common::Base64InStream b64strm(&mstrm);
b64strm.InitPosition();
runtime::NDArray temp;
CHECK(temp.Load(&b64strm));
tensors.emplace_back(temp);
}
// node 0 is always null
nodes.reserve(jgraph.nodes.size());
for (const JSONNode& jnode : jgraph.nodes) {
if (jnode.type_key.length() != 0) {
auto* f = dmlc::Registry<NodeFactoryReg>::Find(jnode.type_key);
CHECK(f != nullptr)
<< "Node type \'" << jnode.type_key << "\' is not registered in TVM";
nodes.emplace_back(f->fcreator(jnode.global_key));
} else {
nodes.emplace_back(NodePtr<Node>());
}
}
CHECK_EQ(nodes.size(), jgraph.nodes.size());
JSONAttrSetter setter;
setter.node_list_ = &nodes;
setter.tensor_list_ = &tensors;
for (size_t i = 0; i < nodes.size(); ++i) {
setter.node_ = &jgraph.nodes[i];
// do not need to recover content of global singleton object
// they are registered via the environment
if (setter.node_->global_key.length() == 0) {
setter.Set(nodes[i].get());
}
}
return nodes.at(jgraph.root);
}
class NodeAttrSetter : public AttrVisitor {
public:
std::string type_key;
std::unordered_map<std::string, runtime::TVMArgValue> attrs;
void Visit(const char* key, double* value) final {
*value = GetAttr(key).operator double();
}
void Visit(const char* key, int64_t* value) final {
*value = GetAttr(key).operator int64_t();
}
void Visit(const char* key, uint64_t* value) final {
*value = GetAttr(key).operator uint64_t();
}
void Visit(const char* key, int* value) final {
*value = GetAttr(key).operator int();
}
void Visit(const char* key, bool* value) final {
*value = GetAttr(key).operator bool();
}
void Visit(const char* key, std::string* value) final {
*value = GetAttr(key).operator std::string();
}
void Visit(const char* key, void** value) final {
*value = GetAttr(key).operator void*();
}
void Visit(const char* key, Type* value) final {
*value = GetAttr(key).operator Type();
}
void Visit(const char* key, NodeRef* value) final {
*value = GetAttr(key).operator NodeRef();
}
void Visit(const char* key, runtime::NDArray* value) final {
*value = GetAttr(key).operator runtime::NDArray();
}
private:
runtime::TVMArgValue GetAttr(const char* key) {
auto it = attrs.find(key);
if (it == attrs.end()) {
LOG(FATAL) << type_key << ": require field " << key;
}
runtime::TVMArgValue v = it->second;
attrs.erase(it);
return v;
}
};
void InitNodeByPackedArgs(Node* n, const TVMArgs& args) {
NodeAttrSetter setter;
setter.type_key = n->type_key();
CHECK_EQ(args.size() % 2, 0);
for (int i = 0; i < args.size(); i += 2) {
setter.attrs.emplace(args[i].operator std::string(),
args[i + 1]);
}
n->VisitAttrs(&setter);
if (setter.attrs.size() != 0) {
std::ostringstream os;
os << setter.type_key << " does not contain field ";
for (const auto &kv : setter.attrs) {
os << " " << kv.first;
}
LOG(FATAL) << os.str();
}
}
// API function to make node.
// args format:
// key1, value1, ..., key_n, value_n
void MakeNode(const TVMArgs& args, TVMRetValue* rv) {
std::string type_key = args[0];
std::string empty_str;
auto* f = dmlc::Registry<NodeFactoryReg>::Find(type_key);
CHECK(f != nullptr)
<< "Node type \'" << type_key << "\' is not registered in TVM";
TVMArgs kwargs(args.values + 1, args.type_codes + 1, args.size() - 1);
CHECK(f->fglobal_key == nullptr)
<< "Cannot make node type \'" << type_key << "\' with global_key.";
NodePtr<Node> n = f->fcreator(empty_str);
if (n->derived_from<BaseAttrsNode>()) {
static_cast<BaseAttrsNode*>(n.get())->InitByPackedArgs(kwargs);
} else {
InitNodeByPackedArgs(n.get(), kwargs);
}
*rv = NodeRef(n);
}
TVM_REGISTER_GLOBAL("make._Node")
.set_body(MakeNode);
} // namespace tvm