blob: b5d6eca5fbddb8d975def9538b996a8c3b507e77 [file] [log] [blame]
/*!
* Copyright (c) 2015 by Contributors
* \file symbol.cc
* \brief Rcpp Symbol of MXNet.
*/
#include <Rcpp.h>
#include <string>
#include <algorithm>
#include "./base.h"
#include "./symbol.h"
#include "./name.h"
#include "./export.h"
namespace mxnet {
namespace R {
/*!
 * \brief Access the single NameManager used by this translation unit.
 * \return pointer to a function-local static; every call yields the same object.
 */
NameManager* NameManager::Get() {
  static NameManager instance;
  NameManager* const shared = &instance;
  return shared;
}
/*!
 * \brief Copy this symbol by calling MXSymbolCopy on the underlying handle.
 * \return a new R object wrapping the copied handle.
 */
inline Symbol::RObjectType Symbol::Clone() const {
  SymbolHandle copied;
  MX_CALL(MXSymbolCopy(handle_, &copied));
  return Symbol::RObject(copied);
}
/*!
 * \brief Compose a copy (made via Clone) of this symbol with the given inputs.
 * \param kwargs inputs to compose with. An element named "name", if present,
 *        is removed from the inputs and used as the composed symbol's name;
 *        otherwise an empty name is passed to Compose.
 * \return the composed copy.
 */
Symbol::RObjectType Symbol::Apply(const Rcpp::List& kwargs) const {
  RObjectType result = this->Clone();
  if (!kwargs.containsElementNamed("name")) {
    Symbol::XPtr(result)->Compose(kwargs, std::string());
    return result;
  }
  const int name_pos = kwargs.findName("name");
  std::string name = kwargs[name_pos];
  // Strip the "name" entry so only real inputs are forwarded to Compose.
  Rcpp::List inputs(kwargs);
  inputs.erase(name_pos);
  Symbol::XPtr(result)->Compose(inputs, name);
  return result;
}
/*!
 * \brief Debug text for this symbol as produced by MXSymbolPrint.
 * \return a std::string copy of the C string returned by the C API.
 */
std::string Symbol::DebugStr() const {
  const char *text = nullptr;
  MX_CALL(MXSymbolPrint(handle_, &text));
  return std::string(text);
}
/*!
 * \brief Compose this symbol in place with the given input symbols.
 * \param kwargs input symbols, either all positional (unnamed) or all
 *        key=value (named); mixing the two styles is rejected.
 * \param name name passed to NNSymbolCompose for the composed symbol.
 */
void Symbol::Compose(const Rcpp::List& kwargs, const std::string &name) {
  std::vector<std::string> keys = SafeGetListNames(kwargs);
  // The first key decides the style; every other key must agree with it.
  const bool positional = keys.empty() || keys[0].empty();
  for (size_t i = 0; i < keys.size(); ++i) {
    RCHECK(keys[i].empty() == positional)
        << "Input symbols need to be either positional or key=value style, not both\n";
  }
  // Positional composition is signalled to the C API by passing no keys.
  if (positional) keys.clear();
  std::vector<const char*> c_keys = CKeys(keys);
  const size_t num_inputs = static_cast<size_t>(kwargs.size());
  std::vector<SymbolHandle> handles(num_inputs);
  for (size_t i = 0; i < num_inputs; ++i) {
    // Validate before dereferencing, mirroring the check done in Group.
    RCHECK(Rcpp::is<Symbol>(kwargs[i]))
        << "Compose only accept MXSymbol as input\n";
    handles[i] = Symbol::XPtr(kwargs[i])->handle_;
  }
  MX_CALL(NNSymbolCompose(
      handle_, name.c_str(),
      static_cast<mx_uint>(handles.size()),
      dmlc::BeginPtr(c_keys), dmlc::BeginPtr(handles)));
}
std::vector<std::string> Symbol::ListArguments() const {
mx_uint size;
const char **ret;
MX_CALL(MXSymbolListArguments(handle_, &size, &ret));
return std::vector<std::string>(ret, ret + size);
}
std::vector<std::string> Symbol::ListAuxiliaryStates() const {
mx_uint size;
const char **ret;
MX_CALL(MXSymbolListAuxiliaryStates(handle_, &size, &ret));
return std::vector<std::string>(ret, ret + size);
}
std::vector<std::string> Symbol::ListOuputs() const {
mx_uint size;
const char **ret;
MX_CALL(MXSymbolListOutputs(handle_, &size, &ret));
return std::vector<std::string>(ret, ret + size);
}
/*!
 * \brief Attributes of this symbol (from MXSymbolListAttrShallow) as a named list.
 * \return an R list mapping each attribute key to its string value.
 */
Rcpp::List Symbol::getAttrs() const {
  mx_uint num_pairs = 0;
  const char **flat = nullptr;
  MX_CALL(MXSymbolListAttrShallow(handle_, &num_pairs, &flat));
  // The C API returns 2 * num_pairs strings laid out as key, value, key, value, ...
  Rcpp::List attrs;
  for (size_t k = 0; k < num_pairs; ++k) {
    const std::string key(flat[2 * k]);
    const std::string value(flat[2 * k + 1]);
    attrs[key] = value;
  }
  return attrs;
}
/*!
 * \brief Set attributes on this symbol from a named list of strings.
 * \param attr named list; every element must be a single character string
 *        and every element must have a non-empty name.
 *
 * All entries are validated before any attribute is written, so an invalid
 * entry cannot leave the symbol with only part of the list applied.
 */
void Symbol::setAttrs(Rcpp::List attr) {
  RCHECK(HasName(attr))
      << "Need to pass parameters in list of key=value style.\n";
  std::vector<std::string> keys = attr.names();
  const size_t num_attrs = static_cast<size_t>(attr.size());
  for (size_t i = 0; i < num_attrs; ++i) {
    RCHECK(!keys[i].empty())
        << "Need to pass parameters in list of key=value style.\n";
    RCHECK(TYPEOF(attr[i]) == STRSXP) << "Attribute values must be characters.\n";
    // Rcpp::as<std::string> below throws for anything but a length-1 vector;
    // reject that here so it cannot fail halfway through the write loop.
    RCHECK(Rf_length(attr[i]) == 1)
        << "Attribute values must be single strings.\n";
  }
  for (size_t i = 0; i < num_attrs; ++i) {
    MX_CALL(MXSymbolSetAttr(handle_, keys[i].c_str(),
                            Rcpp::as<std::string>(attr[i]).c_str()));
  }
}
/*!
 * \brief Write this symbol to a file via MXSymbolSaveToFile.
 * \param fname path of the destination file.
 */
void Symbol::Save(const std::string& fname) const {
  const char *path = fname.c_str();
  MX_CALL(MXSymbolSaveToFile(handle_, path));
}
/*!
 * \brief JSON serialization of this symbol from MXSymbolSaveToJSON.
 * \return a std::string copy of the JSON text.
 */
std::string Symbol::AsJSON() const {
  const char *text = nullptr;
  MX_CALL(MXSymbolSaveToJSON(handle_, &text));
  return std::string(text);
}
/*!
 * \brief Wrap the handle returned by MXSymbolGetInternals as a new R symbol.
 */
Symbol::RObjectType Symbol::GetInternals() const {
  SymbolHandle internals;
  MX_CALL(MXSymbolGetInternals(handle_, &internals));
  return Symbol::RObject(internals);
}
/*!
 * \brief Wrap the handle returned by MXSymbolGetChildren as a new R symbol.
 */
Symbol::RObjectType Symbol::GetChildren() const {
  SymbolHandle children;
  MX_CALL(MXSymbolGetChildren(handle_, &children));
  return Symbol::RObject(children);
}
/*!
 * \brief Select one output of this symbol.
 * \param index 1-based (R-style) output index; converted to 0-based for the C API.
 * \return a new R symbol wrapping the selected output.
 */
Symbol::RObjectType Symbol::GetOutput(mx_uint index) const {
  // Assuming mx_uint is unsigned (as in the MXNet C API), index 0 would wrap
  // around in `index - 1` and surface as an obscure C-API error; reject it here.
  RCHECK(index >= 1)
      << "Output index must be a positive (1-based) integer.\n";
  SymbolHandle out;
  MX_CALL(MXSymbolGetOutput(handle_, index - 1, &out));
  return Symbol::RObject(out);
}
/*!
 * \brief Convert C-API shape arrays into a named R list of integer vectors.
 * \param shape_size number of shapes.
 * \param shape_ndim number of dimensions of each shape.
 * \param shape_data dimension values of each shape.
 * \param names names assigned to the resulting list elements.
 * \return list whose i-th element holds shape i's dimensions in reverse order
 *         (presumably to match R's column-major convention — TODO confirm).
 */
inline Rcpp::List BuildShapeData(mx_uint shape_size,
                                 const mx_uint *shape_ndim,
                                 const mx_uint **shape_data,
                                 const std::vector<std::string> &names) {
  Rcpp::List shapes(shape_size);
  for (mx_uint k = 0; k < shape_size; ++k) {
    const mx_uint ndim = shape_ndim[k];
    const mx_uint *src = shape_data[k];
    Rcpp::IntegerVector dim(ndim);
    // Fill back-to-front so dim holds the dimensions reversed.
    for (mx_uint j = 0; j < ndim; ++j) {
      dim[j] = src[ndim - 1 - j];
    }
    shapes[k] = dim;
  }
  shapes.names() = names;
  return shapes;
}
/*!
 * \brief Infer argument, output and auxiliary shapes from known input shapes.
 * \param kwargs named list mapping argument names to their shapes.
 * \return a list with arg.shapes / out.shapes / aux.shapes when MXNet reports
 *         inference as complete, otherwise R NULL.
 */
SEXP Symbol::InferShape(const Rcpp::List& kwargs) const {
  RCHECK(HasName(kwargs))
      << "Need to pass parameters in key=value style.\n";
  std::vector<std::string> keys = kwargs.names();
  const size_t num_args = static_cast<size_t>(kwargs.size());
  // Shapes are packed back to back in flat_dims; shape i occupies
  // flat_dims[offsets[i], offsets[i + 1]).
  std::vector<mx_uint> offsets;
  offsets.reserve(num_args + 1);
  offsets.push_back(0);
  std::vector<mx_uint> flat_dims;
  for (size_t i = 0; i < num_args; ++i) {
    RCHECK(!keys[i].empty())
        << "Need to pass parameters in key=value style.\n";
    const std::vector<mx_uint> dims = Dim2InternalShape(kwargs[i]);
    flat_dims.insert(flat_dims.end(), dims.begin(), dims.end());
    offsets.push_back(static_cast<mx_uint>(flat_dims.size()));
  }
  std::vector<const char*> c_keys = CKeys(keys);

  mx_uint in_size;
  const mx_uint *in_ndim;
  const mx_uint **in_data;
  mx_uint out_size;
  const mx_uint *out_ndim;
  const mx_uint **out_data;
  mx_uint aux_size;
  const mx_uint *aux_ndim;
  const mx_uint **aux_data;
  int complete;
  MX_CALL(MXSymbolInferShape(
      handle_, static_cast<mx_uint>(num_args), dmlc::BeginPtr(c_keys),
      dmlc::BeginPtr(offsets), dmlc::BeginPtr(flat_dims),
      &in_size, &in_ndim, &in_data,
      &out_size, &out_ndim, &out_data,
      &aux_size, &aux_ndim, &aux_data,
      &complete));
  if (complete == 0) {
    return R_NilValue;
  }
  return Rcpp::List::create(
      Rcpp::Named("arg.shapes") = BuildShapeData(
          in_size, in_ndim, in_data, ListArguments()),
      Rcpp::Named("out.shapes") = BuildShapeData(
          out_size, out_ndim, out_data, ListOuputs()),
      Rcpp::Named("aux.shapes") = BuildShapeData(
          aux_size, aux_ndim, aux_data, ListAuxiliaryStates()));
}
/*!
 * \brief Create a variable symbol via MXSymbolCreateVariable.
 * \param name name of the variable.
 */
Symbol::RObjectType Symbol::Variable(const std::string& name) {
  SymbolHandle variable;
  MX_CALL(MXSymbolCreateVariable(name.c_str(), &variable));
  return Symbol::RObject(variable);
}
/*!
 * \brief Load a symbol from a file via MXSymbolCreateFromFile.
 * \param filename path of the file to read.
 */
Symbol::RObjectType Symbol::Load(const std::string& filename) {
  SymbolHandle loaded;
  MX_CALL(MXSymbolCreateFromFile(filename.c_str(), &loaded));
  return Symbol::RObject(loaded);
}
/*!
 * \brief Build a symbol from a JSON string via MXSymbolCreateFromJSON.
 * \param json serialized symbol text.
 */
Symbol::RObjectType Symbol::LoadJSON(const std::string& json) {
  SymbolHandle parsed;
  MX_CALL(MXSymbolCreateFromJSON(json.c_str(), &parsed));
  return Symbol::RObject(parsed);
}
/*!
 * \brief Group several symbols into one via MXSymbolCreateGroup.
 * \param symbols the symbols to group; a single list argument is unwrapped,
 *        so both Group(a, b) and Group(list(a, b)) are accepted.
 * \return the grouped symbol.
 */
Symbol::RObjectType Symbol::Group(const Rcpp::List& symbols) {
  Rcpp::List members = symbols;
  const bool wrapped = symbols.size() == 1 && Rcpp::is<Rcpp::List>(symbols[0]);
  if (wrapped) members = symbols[0];
  const size_t count = static_cast<size_t>(members.size());
  std::vector<SymbolHandle> handles;
  handles.reserve(count);
  for (size_t i = 0; i < count; ++i) {
    RCHECK(Rcpp::is<Symbol>(members[i]))
        << "Group only accept MXSymbol as input\n";
    handles.push_back(Symbol::XPtr(members[i])->handle_);
  }
  SymbolHandle grouped;
  MX_CALL(MXSymbolCreateGroup(static_cast<mx_uint>(handles.size()),
                              dmlc::BeginPtr(handles), &grouped));
  return Symbol::RObject(grouped);
}
/*!
 * \brief Describe one MXNet operator as an R-callable symbol function.
 * \param handle operator handle queried with MXSymbolGetAtomicSymbolInfo.
 * \param name operator name; names starting with '_' are exposed under
 *        "mx.varg.symbol.internal.", all others under "mx.varg.symbol.".
 */
SymbolFunction::SymbolFunction(OpHandle handle, std::string name)
    : handle_(handle) {
  const char* real_name;  // required out-parameter; not used afterwards
  const char* description;
  mx_uint num_args;
  const char **arg_names;
  const char **arg_type_infos;
  const char **arg_descriptions;
  const char *key_var_num_args;
  const char *ret_type;
  MX_CALL(MXSymbolGetAtomicSymbolInfo(
      handle_, &real_name, &description, &num_args,
      &arg_names, &arg_type_infos, &arg_descriptions,
      &key_var_num_args, &ret_type));
  if (key_var_num_args != nullptr) {
    key_var_num_args_ = key_var_num_args;
  }
  name_hint_ = name;
  // tolower has undefined behavior for negative char values, so go through
  // unsigned char before calling it.
  std::transform(name_hint_.begin(), name_hint_.end(), name_hint_.begin(),
                 [](unsigned char c) { return static_cast<char>(::tolower(c)); });
  if (!name.empty() && name[0] == '_') {
    name_ = "mx.varg.symbol.internal." + name.substr(1);
  } else {
    name_ = "mx.varg.symbol." + name;
  }
  std::ostringstream os;
  os << name << ':' << description << "\n\n"
     << MakeDocString(num_args, arg_names, arg_type_infos, arg_descriptions)
     << "@param name string, optional\n"
     << " Name of the resulting symbol.\n"
     << "@return out The result mx.symbol\n\n"
     << "@export\n";
  this->docstring = os.str();
}
/*!
 * \brief Create an atomic symbol for this operator and compose it.
 * \param args args[0] is an R list of arguments. The element named "name"
 *        sets the result name; Symbol values become composition inputs; every
 *        other value must be key=value and is passed as a string parameter.
 * \return the composed symbol; R errors are raised through the
 *         BEGIN_RCPP/END_RCPP wrappers.
 */
SEXP SymbolFunction::operator() (SEXP* args) {
  BEGIN_RCPP;
  Rcpp::List kwargs(args[0]);
  std::vector<std::string> keys = SafeGetListNames(kwargs);
  std::string name;
  // Operator parameters, passed to the C API as strings.
  std::vector<std::string> param_keys;
  std::vector<std::string> param_vals;
  // Symbol inputs, passed to Compose.
  std::vector<std::string> input_keys;
  std::vector<Rcpp::RObject> input_vals;
  const size_t count = static_cast<size_t>(kwargs.size());
  for (size_t i = 0; i < count; ++i) {
    const std::string& key = keys[i];
    if (key == "name") {
      name = Rcpp::as<std::string>(kwargs[i]);
    } else if (!isSimple(kwargs[i]) && Rcpp::is<Symbol>(kwargs[i])) {
      input_keys.push_back(key);
      input_vals.push_back(kwargs[i]);
    } else {
      RCHECK(key.length() != 0)
          << "Non Symbol parameters is only accepted via key=value style.";
      param_keys.push_back(FormatParamKey(key));
      param_vals.push_back(toPyString(key, kwargs[i]));
    }
  }
  std::vector<const char*> c_param_keys = CKeys(param_keys);
  std::vector<const char*> c_param_vals = CKeys(param_vals);
  SymbolHandle atomic;
  MX_CALL(NNSymbolCreateAtomicSymbol(
      handle_, static_cast<mx_uint>(param_keys.size()),
      dmlc::BeginPtr(c_param_keys),
      dmlc::BeginPtr(c_param_vals),
      &atomic));
  Symbol::RObjectType result = Symbol::RObject(atomic);
  Rcpp::List inputs = Rcpp::wrap(input_vals);
  inputs.names() = input_keys;
  const std::string full_name = NameManager::Get()->GetName(name, name_hint_);
  Symbol::XPtr(result)->Compose(inputs, full_name);
  return result;
  END_RCPP;
}
/*!
 * \brief Register the MXSymbol class and the free functions
 *        mx.symbol.Variable, mx.symbol.load, mx.symbol.load.json and
 *        mx.varg.symbol.internal.Group with the current Rcpp module.
 *
 * The method/property/function names below are the R-visible API.
 */
void Symbol::InitRcppModule() {
using namespace Rcpp; // NOLINT(*)
// Methods and properties available on MXSymbol objects from R.
class_<Symbol>("MXSymbol")
.method("debug.str", &Symbol::DebugStr,
"Return the debug string of internals of symbol")
.method("apply", &Symbol::Apply,
"Return a new Symbol by applying current symbols into input")
.method("as.json", &Symbol::AsJSON,
"Return a json string representation of symbol")
.method("save", &Symbol::Save,
"Save symbol to file")
.property("arguments", &Symbol::ListArguments,
"List the arguments names of the symbol")
// Read/write property: getter getAttrs, setter setAttrs.
.property("attributes", &Symbol::getAttrs, &Symbol::setAttrs,
"Attributes of the symbol. Specified as named list.")
.property("outputs", &Symbol::ListOuputs,
"List the outputs names of the symbol")
.property("auxiliary.states", &Symbol::ListAuxiliaryStates,
"List the auxiliary state names of the symbol")
.method("get.internals", &Symbol::GetInternals,
"Get a symbol that contains all the internals")
.method("get.children", &Symbol::GetChildren,
"Get a symbol that contains all the children")
// "get.output" and "[[" are aliases: both call GetOutput with a 1-based index.
.method("get.output", &Symbol::GetOutput,
"Get index-th output symbol of current one")
.method("[[", &Symbol::GetOutput,
"Get index-th output symbol of current one")
// Returns NULL to R when shape inference is incomplete.
.method("infer.shape", &Symbol::InferShape,
"Inference the shape information given unknown ones");
// Free functions (not bound to an MXSymbol instance).
function("mx.symbol.Variable",
&Symbol::Variable,
List::create(_["name"]),
"Create a symbolic variable with specified name.");
function("mx.symbol.load",
&Symbol::Load,
List::create(_["file.name"]),
"Load a symbol from file.");
function("mx.symbol.load.json",
&Symbol::LoadJSON,
List::create(_["json.str"]),
"Load a symbol from json string.");
function("mx.varg.symbol.internal.Group",
&Symbol::Group,
List::create(_["slist"]),
"Create a symbol that groups symbols together.");
}
/*!
 * \brief Register one SymbolFunction per operator reported by
 *        MXListAllOpNames in the current Rcpp module scope.
 *
 * Must be called while an Rcpp module scope is active; otherwise an R error
 * is raised.
 */
void SymbolFunction::InitRcppModule() {
  Rcpp::Module* scope = ::getCurrentScope();
  RCHECK(scope != nullptr)
      << "Init Module need to be called inside scope";
  mx_uint out_size;
  const char** op_name_ptrs;
  MX_CALL(MXListAllOpNames(&out_size, &op_name_ptrs));
  // Copy the names before the loop below makes further C-API calls, so the
  // names do not depend on the lifetime of the buffer returned above.
  const std::vector<std::string> op_names(op_name_ptrs, op_name_ptrs + out_size);
  for (const std::string& op_name : op_names) {
    OpHandle handle;
    MX_CALL(NNGetOpHandle(op_name.c_str(), &handle));
    // The object is handed to scope->Add and never deleted here; presumably
    // the module takes ownership — TODO confirm against Rcpp::Module::Add.
    SymbolFunction *f = new SymbolFunction(handle, op_name);
    scope->Add(f->get_name(), f);
  }
}
} // namespace R
} // namespace mxnet