blob: ec82c91bb6522ced00bbd236f94a4303cdf24c91 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Reflection utilities.
* \file node/reflection.cc
*/
#include <tvm/ir/attrs.h>
#include <tvm/node/container.h>
#include <tvm/node/node.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/registry.h>
namespace tvm {
using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;
// Attr getter.
/*!
 * \brief AttrVisitor that copies the value of one named field into a TVMRetValue.
 *
 * Every Visit overload compares the visited key with skey and writes *ret only
 * when they match; all other fields are ignored and *ret is left untouched.
 * found_ref_object is set when the matched field has a type whose value may
 * legitimately be null (ObjectRef, NDArray, void*), so that the caller can tell
 * "field present but null" apart from "no such field".
 */
class AttrGetter : public AttrVisitor {
 public:
  /*! \brief Name of the field being looked up. */
  const String& skey;
  /*! \brief Destination for the field value. */
  TVMRetValue* ret;
  AttrGetter(const String& skey, TVMRetValue* ret) : skey(skey), ret(ret) {}
  /*! \brief True once a matched field of a nullable type has been written to *ret. */
  bool found_ref_object{false};

  void Visit(const char* key, double* value) final {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, int64_t* value) final {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, uint64_t* value) final {
    // Validate only the field actually requested: an unrelated large uint64
    // field must not make lookups of other fields on the same object fail.
    if (skey == key) {
      CHECK_LE(value[0], static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
          << "cannot return too big constant";
      *ret = static_cast<int64_t>(value[0]);
    }
  }
  void Visit(const char* key, int* value) final {
    // Small integral fields are widened to int64 for the FFI.
    if (skey == key) *ret = static_cast<int64_t>(value[0]);
  }
  void Visit(const char* key, bool* value) final {
    if (skey == key) *ret = static_cast<int64_t>(value[0]);
  }
  void Visit(const char* key, void** value) final {
    if (skey == key) {
      *ret = static_cast<void*>(value[0]);
      // A null handle may be stored with the same type code GetAttr uses to
      // detect "not found", so record the match explicitly.
      found_ref_object = true;
    }
  }
  void Visit(const char* key, DataType* value) final {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, std::string* value) final {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, runtime::NDArray* value) final {
    if (skey == key) {
      *ret = value[0];
      found_ref_object = true;
    }
  }
  void Visit(const char* key, runtime::ObjectRef* value) final {
    if (skey == key) {
      *ret = value[0];
      found_ref_object = true;
    }
  }
};
/*!
 * \brief Read the field named \p field_name from \p self.
 *
 * - "type_key" always yields the object's type key.
 * - DictAttrsNode objects are looked up in their runtime dictionary.
 * - Every other object is visited with an AttrGetter via its registered VisitAttrs.
 *
 * \param self The object to inspect.
 * \param field_name Name of the field to read.
 * \return The field value.
 * \note Fails through LOG(FATAL) with an "AttributeError: " prefixed message
 *       when the field does not exist.
 */
runtime::TVMRetValue ReflectionVTable::GetAttr(Object* self, const String& field_name) const {
  runtime::TVMRetValue ret;
  AttrGetter getter(field_name, &ret);
  bool success = false;
  if (getter.skey == "type_key") {
    ret = self->GetTypeKey();
    success = true;
  } else if (!self->IsInstance<DictAttrsNode>()) {
    VisitAttrs(self, &getter);
    // A null result means "not found" unless the getter reports that a
    // nullable field was actually matched.
    success = getter.found_ref_object || ret.type_code() != kTVMNullptr;
  } else {
    // DictAttrs keep their fields in a dictionary rather than in C++ members.
    const DictAttrsNode* dnode = static_cast<const DictAttrsNode*>(self);
    auto it = dnode->dict.find(getter.skey);
    if (it != dnode->dict.end()) {
      success = true;
      ret = (*it).second;
    }
  }
  if (!success) {
    LOG(FATAL) << "AttributeError: " << self->GetTypeKey() << " object has no attribute "
               << getter.skey;
  }
  return ret;
}
// List attribute names.
/*!
 * \brief AttrVisitor that appends the key of every visited field to *names.
 *
 * The caller must point names at an output vector before visiting.
 */
class AttrDir : public AttrVisitor {
 public:
  /*! \brief Output list of field names; initialized to nullptr until set by the caller. */
  std::vector<std::string>* names{nullptr};
  void Visit(const char* key, double* value) final { names->push_back(key); }
  void Visit(const char* key, int64_t* value) final { names->push_back(key); }
  void Visit(const char* key, uint64_t* value) final { names->push_back(key); }
  void Visit(const char* key, bool* value) final { names->push_back(key); }
  void Visit(const char* key, int* value) final { names->push_back(key); }
  void Visit(const char* key, void** value) final { names->push_back(key); }
  void Visit(const char* key, DataType* value) final { names->push_back(key); }
  void Visit(const char* key, std::string* value) final { names->push_back(key); }
  void Visit(const char* key, runtime::NDArray* value) final { names->push_back(key); }
  void Visit(const char* key, runtime::ObjectRef* value) final { names->push_back(key); }
};
/*!
 * \brief Collect the names of all fields reachable by reflection on \p self.
 * \param self The object to inspect.
 * \return Field names: dictionary keys for DictAttrsNode, otherwise the keys
 *         reported by the type's VisitAttrs, in visiting order.
 */
std::vector<std::string> ReflectionVTable::ListAttrNames(Object* self) const {
  std::vector<std::string> result;
  if (self->IsInstance<DictAttrsNode>()) {
    // DictAttrs keep their fields in a dictionary, so list its keys directly.
    DictAttrsNode* dict_attrs = static_cast<DictAttrsNode*>(self);
    for (const auto& entry : dict_attrs->dict) {
      result.push_back(entry.first);
    }
  } else {
    AttrDir collector;
    collector.names = &result;
    VisitAttrs(self, &collector);
  }
  return result;
}
/*!
 * \brief Access the process-wide reflection table.
 * \return Pointer to a function-local static instance, constructed on first use.
 */
ReflectionVTable* ReflectionVTable::Global() {
  static ReflectionVTable global_table;
  ReflectionVTable* table_ptr = &global_table;
  return table_ptr;
}
/*!
 * \brief Construct a fresh object of \p type_key with its registered creator.
 * \param type_key Registered type key of the object.
 * \param repr_bytes Serialized representation forwarded to the creator.
 * \return The newly created object.
 * \note Fails through LOG(FATAL) when no creator is registered for the type.
 */
ObjectPtr<Object> ReflectionVTable::CreateInitObject(const std::string& type_key,
                                                     const std::string& repr_bytes) const {
  const uint32_t type_index = Object::TypeKey2Index(type_key);
  const bool has_creator = type_index < fcreate_.size() && fcreate_[type_index] != nullptr;
  if (!has_creator) {
    LOG(FATAL) << "TypeError: " << type_key << " is not registered via TVM_REGISTER_NODE_TYPE";
  }
  return fcreate_[type_index](repr_bytes);
}
class NodeAttrSetter : public AttrVisitor {
public:
std::string type_key;
std::unordered_map<std::string, runtime::TVMArgValue> attrs;
void Visit(const char* key, double* value) final { *value = GetAttr(key).operator double(); }
void Visit(const char* key, int64_t* value) final { *value = GetAttr(key).operator int64_t(); }
void Visit(const char* key, uint64_t* value) final { *value = GetAttr(key).operator uint64_t(); }
void Visit(const char* key, int* value) final { *value = GetAttr(key).operator int(); }
void Visit(const char* key, bool* value) final { *value = GetAttr(key).operator bool(); }
void Visit(const char* key, std::string* value) final {
*value = GetAttr(key).operator std::string();
}
void Visit(const char* key, void** value) final { *value = GetAttr(key).operator void*(); }
void Visit(const char* key, DataType* value) final { *value = GetAttr(key).operator DataType(); }
void Visit(const char* key, runtime::NDArray* value) final {
*value = GetAttr(key).operator runtime::NDArray();
}
void Visit(const char* key, ObjectRef* value) final {
*value = GetAttr(key).operator ObjectRef();
}
private:
runtime::TVMArgValue GetAttr(const char* key) {
auto it = attrs.find(key);
if (it == attrs.end()) {
LOG(FATAL) << type_key << ": require field " << key;
}
runtime::TVMArgValue v = it->second;
attrs.erase(it);
return v;
}
};
/*!
 * \brief Initialize the reflected fields of \p n from a flat key/value argument list.
 * \param reflection Reflection table used to visit the fields of \p n.
 * \param n The object to initialize.
 * \param args Alternating field names and values: key1, value1, ..., key_n, value_n.
 * \note Every visited field must be supplied, and every supplied key must name a
 *       field; leftover keys are reported together through LOG(FATAL).
 */
void InitNodeByPackedArgs(ReflectionVTable* reflection, Object* n, const TVMArgs& args) {
  NodeAttrSetter setter;
  setter.type_key = n->GetTypeKey();
  CHECK_EQ(args.size() % 2, 0);
  const int num_args = args.size();
  for (int pos = 0; pos < num_args; pos += 2) {
    setter.attrs.emplace(args[pos].operator std::string(), args[pos + 1]);
  }
  reflection->VisitAttrs(n, &setter);
  if (setter.attrs.empty()) {
    return;
  }
  // Whatever remains in attrs did not correspond to any field of this type.
  std::ostringstream os;
  os << setter.type_key << " does not contain field ";
  for (const auto& leftover : setter.attrs) {
    os << " " << leftover.first;
  }
  LOG(FATAL) << os.str();
}
/*!
 * \brief Create an object of \p type_key and initialize it from keyword arguments.
 * \param type_key Registered type key of the object.
 * \param kwargs Alternating field names and values.
 * \return The initialized object.
 *
 * Attrs types (BaseAttrsNode subclasses) initialize themselves through
 * InitByPackedArgs; every other type is filled by InitNodeByPackedArgs.
 */
ObjectRef ReflectionVTable::CreateObject(const std::string& type_key, const TVMArgs& kwargs) {
  ObjectPtr<Object> obj = this->CreateInitObject(type_key);
  Object* raw = obj.get();
  if (!raw->IsInstance<BaseAttrsNode>()) {
    InitNodeByPackedArgs(this, raw, kwargs);
  } else {
    static_cast<BaseAttrsNode*>(raw)->InitByPackedArgs(kwargs);
  }
  return ObjectRef(obj);
}
/*!
 * \brief Create an object of \p type_key from a String -> ObjectRef map of fields.
 * \param type_key Registered type key of the object.
 * \param kwargs Field values keyed by field name.
 * \return The initialized object.
 *
 * The map is flattened into a key/value TVMArgs list and forwarded to the
 * TVMArgs overload. This is not the most efficient route, but CreateObject is
 * a flexible frontend API rather than a fast code path.
 */
ObjectRef ReflectionVTable::CreateObject(const std::string& type_key,
                                         const Map<String, ObjectRef>& kwargs) {
  const size_t num_slots = kwargs.size() * 2;
  std::vector<TVMValue> arg_values(num_slots);
  std::vector<int32_t> arg_type_codes(num_slots);
  runtime::TVMArgsSetter arg_setter(arg_values.data(), arg_type_codes.data());
  const MapNode* map_node = static_cast<const MapNode*>(kwargs.get());
  int slot = 0;
  for (const auto& entry : *map_node) {
    // The key's characters are owned by the string object that kwargs holds.
    arg_setter(slot, Downcast<String>(entry.first).c_str());
    arg_setter(slot + 1, entry.second);
    slot += 2;
  }
  return CreateObject(type_key,
                      runtime::TVMArgs(arg_values.data(), arg_type_codes.data(), num_slots));
}
// Expose to FFI APIs.
/*!
 * \brief FFI entry: read attribute args[1] (a name) from object handle args[0].
 * \param args args[0] must be an object handle; args[1] is the field name.
 * \param ret Receives the attribute value.
 */
void NodeGetAttr(TVMArgs args, TVMRetValue* ret) {
  CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
  auto* self = static_cast<Object*>(args[0].value().v_handle);
  ReflectionVTable* vtable = ReflectionVTable::Global();
  *ret = vtable->GetAttr(self, args[1]);
}
/*!
 * \brief FFI entry: return a PackedFunc that enumerates the attribute names of an object.
 * \param args args[0] must be an object handle.
 * \param ret Receives the enumerator function.
 *
 * The returned function takes one integer i. When i == -1 it returns the number
 * of names; otherwise it returns the i-th name. Indices outside [0, size) are
 * rejected instead of reading past the end of the name list.
 */
void NodeListAttrNames(TVMArgs args, TVMRetValue* ret) {
  CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
  Object* self = static_cast<Object*>(args[0].value().v_handle);
  // Shared with the returned closure, which outlives this call.
  auto names =
      std::make_shared<std::vector<std::string> >(ReflectionVTable::Global()->ListAttrNames(self));
  *ret = PackedFunc([names](TVMArgs args, TVMRetValue* rv) {
    int64_t i = args[0];
    if (i == -1) {
      *rv = static_cast<int64_t>(names->size());
    } else {
      CHECK(i >= 0 && static_cast<size_t>(i) < names->size())
          << "IndexError: attribute index " << i << " is out of range for " << names->size()
          << " attribute names";
      *rv = (*names)[i];
    }
  });
}
// API function to make node.
// args format:
//   type_key, key1, value1, ..., key_n, value_n
/*!
 * \brief FFI entry: create an object from its type key and keyword arguments.
 * \param args args[0] is the type key; the rest alternate field names and values.
 * \param rv Receives the created object.
 */
void MakeNode(const TVMArgs& args, TVMRetValue* rv) {
  std::string type_key = args[0];
  // Everything after the type key is forwarded as key/value pairs.
  TVMArgs kwargs(args.values + 1, args.type_codes + 1, args.size() - 1);
  *rv = ReflectionVTable::Global()->CreateObject(type_key, kwargs);
}
// Register the reflection FFI entry points as global functions under the "node." prefix.
TVM_REGISTER_GLOBAL("node.NodeGetAttr").set_body(NodeGetAttr);
TVM_REGISTER_GLOBAL("node.NodeListAttrNames").set_body(NodeListAttrNames);
TVM_REGISTER_GLOBAL("node.MakeNode").set_body(MakeNode);
} // namespace tvm