blob: 1c2c294a5f3065d51fc150e6569c7708fa92a0fa [file] [log] [blame]
/*!
* Copyright (c) 2016 by Contributors
* Implementation of DSL API
* \file dsl_api.cc
*/
#include <dmlc/base.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/api_registry.h>
#include <tvm/attrs.h>
#include <vector>
#include <string>
#include <exception>
#include <limits>
#include "../runtime/dsl_api.h"
namespace tvm {
namespace runtime {
/*!
 * \brief Entry that holds values returned through the API.
 *
 * The API functions in this file hand out raw pointers into these members,
 * so the members own the storage those pointers refer to.
 */
struct TVMAPIThreadLocalEntry {
/*! \brief result holder for returning strings (filled with attribute names by NodeListAttrNames) */
std::vector<std::string> ret_vec_str;
/*! \brief result holder for returning string pointers (c_str() pointers into ret_vec_str) */
std::vector<const char *> ret_vec_charp;
/*! \brief result holder for returning a single string (used by NodeGetAttr) */
std::string ret_str;
};
/*! \brief Thread local store that keeps the values returned by the API functions. */
using TVMAPIThreadLocalStore = dmlc::ThreadLocalStore<TVMAPIThreadLocalEntry>;
/*! \brief Node pointer type; NodeHandle values in this file are cast to TVMAPINode*. */
using TVMAPINode = NodePtr<Node>;
/*!
 * \brief Attribute visitor that fetches the value of one named attribute.
 *
 * The caller sets `skey` to the attribute name and `ret` to the destination
 * before handing the visitor to VisitAttrs. Each Visit overload receives a
 * field name; only a field whose name equals `skey` writes to `*ret`.
 */
struct APIAttrGetter : public AttrVisitor {
  /*! \brief name of the attribute to look up */
  std::string skey;
  /*! \brief destination for the attribute value; must be set before visiting */
  TVMRetValue* ret;
  /*!
   * \brief set to true when the matched attribute is a NodeRef or NDArray,
   *  so the caller can report success without relying on the type code of *ret.
   */
  bool found_ref_object{false};

  void Visit(const char* key, double* value) final {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, int64_t* value) final {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, uint64_t* value) final {
    // Range-check only the requested attribute: an oversized uint64 field
    // with a different name must not abort the lookup of another key.
    if (skey != key) return;
    CHECK_LE(value[0], static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
        << "cannot return too big constant";
    // The value is returned as int64_t, hence the range check above.
    *ret = static_cast<int64_t>(value[0]);
  }
  void Visit(const char* key, int* value) final {
    if (skey == key) *ret = static_cast<int64_t>(value[0]);
  }
  void Visit(const char* key, bool* value) final {
    if (skey == key) *ret = static_cast<int64_t>(value[0]);
  }
  void Visit(const char* key, void** value) final {
    if (skey == key) *ret = static_cast<void*>(value[0]);
  }
  void Visit(const char* key, Type* value) final {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, std::string* value) final {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, NodeRef* value) final {
    if (skey == key) {
      *ret = value[0];
      found_ref_object = true;
    }
  }
  void Visit(const char* key, runtime::NDArray* value) final {
    if (skey == key) {
      *ret = value[0];
      found_ref_object = true;
    }
  }
};
/*!
 * \brief Attribute visitor that records the name of every attribute it visits.
 *
 * Each Visit overload ignores the value and appends the key to the vector
 * that `names` points to; the caller must set `names` before visiting.
 */
struct APIAttrDir : public AttrVisitor {
  /*! \brief output vector receiving the attribute names in visiting order */
  std::vector<std::string>* names;

  void Visit(const char* key, double*) final { AppendName(key); }
  void Visit(const char* key, int64_t*) final { AppendName(key); }
  void Visit(const char* key, uint64_t*) final { AppendName(key); }
  void Visit(const char* key, bool*) final { AppendName(key); }
  void Visit(const char* key, int*) final { AppendName(key); }
  void Visit(const char* key, void**) final { AppendName(key); }
  void Visit(const char* key, Type*) final { AppendName(key); }
  void Visit(const char* key, std::string*) final { AppendName(key); }
  void Visit(const char* key, NodeRef*) final { AppendName(key); }
  void Visit(const char* key, runtime::NDArray*) final { AppendName(key); }

 private:
  /*! \brief Append one attribute name to the output vector. */
  void AppendName(const char* key) { names->push_back(key); }
};
class DSLAPIImpl : public DSLAPI {
public:
void NodeFree(NodeHandle handle) const final {
delete static_cast<TVMAPINode*>(handle);
}
void NodeTypeKey2Index(const char* type_key,
int* out_index) const final {
*out_index = static_cast<int>(Node::TypeKey2Index(type_key));
}
void NodeGetTypeIndex(NodeHandle handle,
int* out_index) const final {
*out_index = static_cast<int>(
(*static_cast<TVMAPINode*>(handle))->type_index());
}
void NodeGetAttr(NodeHandle handle,
const char* key,
TVMValue* ret_val,
int* ret_type_code,
int* ret_success) const final {
TVMRetValue rv;
APIAttrGetter getter;
TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
getter.skey = key;
getter.ret = &rv;
if (getter.skey == "type_key") {
ret_val->v_str = (*tnode)->type_key();
*ret_type_code = kStr;
*ret_success = 1;
return;
} else if (!(*tnode)->is_type<DictAttrsNode>()) {
(*tnode)->VisitAttrs(&getter);
*ret_success = getter.found_ref_object || rv.type_code() != kNull;
} else {
// specially handle dict attr
DictAttrsNode* dnode = static_cast<DictAttrsNode*>(tnode->get());
auto it = dnode->dict.find(key);
if (it != dnode->dict.end()) {
*ret_success = 1;
rv = (*it).second;
} else {
*ret_success = 0;
}
}
if (*ret_success) {
if (rv.type_code() == kStr ||
rv.type_code() == kTVMType) {
TVMAPIThreadLocalEntry *e = TVMAPIThreadLocalStore::Get();
e->ret_str = rv.operator std::string();
*ret_type_code = kStr;
ret_val->v_str = e->ret_str.c_str();
} else {
rv.MoveToCHost(ret_val, ret_type_code);
}
}
}
void NodeListAttrNames(NodeHandle handle,
int *out_size,
const char*** out_array) const final {
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
ret->ret_vec_str.clear();
TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
APIAttrDir dir;
dir.names = &(ret->ret_vec_str);
if (!(*tnode)->is_type<DictAttrsNode>()) {
(*tnode)->VisitAttrs(&dir);
} else {
// specially handle dict attr
DictAttrsNode* dnode = static_cast<DictAttrsNode*>(tnode->get());
for (const auto& kv : dnode->dict) {
ret->ret_vec_str.push_back(kv.first);
}
}
ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
}
*out_array = dmlc::BeginPtr(ret->ret_vec_charp);
*out_size = static_cast<int>(ret->ret_vec_str.size());
}
};
// Register the global function "dsl_api.singleton". It returns the address
// of a function-local static DSLAPIImpl instance as a void* handle.
TVM_REGISTER_GLOBAL("dsl_api.singleton")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    static DSLAPIImpl singleton;
    *rv = static_cast<void*>(&singleton);
  });
} // namespace runtime
} // namespace tvm