blob: c5a866009add8af0a8a3c2f118d7ab669e17fd21 [file] [log] [blame]
/*!
 *  Copyright (c) 2015 by Contributors
 * \file executor.cc
 * \brief Rcpp Executor of MXNet.
 */
#include <Rcpp.h>
#include <string>
#include <algorithm>
#include "./base.h"
#include "./executor.h"
#include "./ndarray.h"
#include "./symbol.h"
namespace mxnet {
namespace R {
/*!
 * \brief Update the executor's argument arrays from an R list.
 *
 * Forwards to UpdateArray with the label "arg.arrays" and arg_arrays_ as the
 * destination list.
 *
 * \param array list supplying the source values.
 * \param match_name passed through to UpdateArray unchanged.
 * \param skip_null passed through to UpdateArray unchanged.
 */
void Executor::UpdateArgArray(const Rcpp::List& array,
                              bool match_name,
                              bool skip_null) {
  UpdateArray("arg.arrays", array, arg_arrays_, match_name, skip_null);
}
/*!
 * \brief Update the executor's auxiliary arrays from an R list.
 *
 * Forwards to UpdateArray with the label "aux.arrays" and aux_arrays_ as the
 * destination list.
 *
 * \param array list supplying the source values.
 * \param match_name passed through to UpdateArray unchanged.
 * \param skip_null passed through to UpdateArray unchanged.
 */
void Executor::UpdateAuxArray(const Rcpp::List& array,
                              bool match_name,
                              bool skip_null) {
  UpdateArray("aux.arrays", array, aux_arrays_, match_name, skip_null);
}
/*!
 * \brief Update the executor's gradient arrays from an R list.
 *
 * Forwards to UpdateArray with the label "grad.arrays" and grad_arrays_ as the
 * destination list.
 *
 * \param array list supplying the source values.
 * \param match_name passed through to UpdateArray unchanged.
 * \param skip_null passed through to UpdateArray unchanged.
 */
void Executor::UpdateGradArray(const Rcpp::List& array,
                               bool match_name,
                               bool skip_null) {
  UpdateArray("grad.arrays", array, grad_arrays_, match_name, skip_null);
}
/*!
 * \brief Copy NDArray values from one R list into another, element by element.
 *
 * \param array_name label of the destination list; it only appears in an
 *        error message.
 * \param from list supplying the source values.
 * \param to list whose elements receive the copies.
 * \param match_name when false, elements are paired by position; when true,
 *        each name found in `from` is looked up in `to`.
 * \param skip_null not referenced by any statement visible in this block.
 *
 * NOTE(review): several `<<` message lines below have no visible left-hand
 * operand, so the condition each message reports cannot be read from this
 * view; confirm against the complete file.
 */
void Executor::UpdateArray(const char* array_name,
const Rcpp::List& from,
Rcpp::List* to,
bool match_name,
bool skip_null) {
// Positional mode: from[i] is paired with to->at(i).
if (!match_name) {
// Both lists must have the same length before they are paired by index.
RCHECK(from.size() == to->size())
<< "Update array list must contain names";
for (size_t i = 0; i < from.size(); ++i) {
if (to->at(i) != R_NilValue) {
if (from[i] != R_NilValue) {
// Both slots are populated: copy the source into the existing destination.
NDArray dst = NDArray::FromRObject(to->at(i));
NDArray::CopyFromTo(NDArray::FromRObject(from[i]), &dst);
} else {
// NOTE(review): the check that owns this message is not visible here.
<< "Position " << i << " expected to be not NULL";
} else {
// The destination slot is NULL, so the source slot has to be NULL as well.
RCHECK(from[i] == R_NilValue)
<< "Position " << i << " expected to be NULL";
} else {
// Named mode: an empty source list returns immediately.
if (from.size() == 0) return;
// NOTE(review): the check that owns this message is not visible here.
<< " is set to TRUE, the input list must have names in all elements";
std::vector<std::string> names = from.names();
for (size_t i = 0; i < names.size(); ++i) {
// Every source element must carry a non-empty name.
RCHECK(names[i].length() != 0)
<< " is set to TRUE, the input list must have names in all elements";
// NOTE(review): the check that owns this message is not visible here.
<< "cannot find key " << names[i] << " in the array " << array_name;
// Locate names[i] in the destination list; the result indexes `to` below.
int index = to->findName(names[i]);
if (to->at(index) != R_NilValue) {
if (from[i] != R_NilValue) {
// Both slots are populated: copy the source into the existing destination.
NDArray dst = NDArray::FromRObject(to->at(index));
NDArray::CopyFromTo(NDArray::FromRObject(from[i]), &dst);
} else {
// NOTE(review): the check that owns this message is not visible here.
<< "Element " << names[i] << " expected to be not NULL";
} else {
// The destination slot is NULL, so the source slot has to be NULL as well.
RCHECK(from[i] == R_NilValue)
<< "Element " << names[i] << " expected to be NULL";
/*!
 * \brief Build a list with the same length and names as src whose non-NULL
 *        entries are clones of the corresponding NDArrays.
 * \param src list to duplicate.
 * \return the newly built list; NULL entries of src stay NULL.
 */
Rcpp::List Executor::CloneArray(const Rcpp::List& src) {
Rcpp::List ret(src.size());
ret.names() = src.names();
for (size_t i = 0; i < src.size(); ++i) {
if (src[i] != R_NilValue) {
// NOTE(review): the check that owns this message is not visible here.
<< "Expected exec to be "<< Executor::TypeName();
// Wrap the element as an NDArray, clone it, and store the clone's R object.
ret[i] = NDArray::FromRObject(src[i]).Clone().RObject();
} else {
ret[i] = R_NilValue;
return ret;
/*!
 * \brief Run MXExecutorForward on this executor's handle.
 *
 * \param is_train flag handed to MXExecutorForward.
 * \param kwargs accepted for interface compatibility; this body does not read it.
 */
void Executor::Forward(bool is_train,
                       const Rcpp::List& kwargs) {
  MX_CALL(MXExecutorForward(handle_, is_train));
}
/*!
 * \brief Validate the executor's gradient state and gather the NDArray
 *        handles of output_grads.
 *
 * \param output_grads list handed to NDArray::GetHandles under the label
 *        "output_grads".
 *
 * NOTE(review): grad_handles is not consumed by any statement visible in this
 * block; the call that uses it appears to be missing from this view.
 */
void Executor::Backward(const Rcpp::List &output_grads) {
// Fail with a message when no gradient list is attached to this executor.
RCHECK(grad_arrays_ != nullptr)
<< "This executor has not been bound with req.grad";
std::vector<NDArrayHandle> grad_handles
= NDArray::GetHandles(output_grads, "output_grads", false);
/*!
 * \brief Allocate a heap Rcpp::List holding fresh NDArrays that mirror
 *        source_array, copying each source value into its new array.
 *
 * \param source_array list of source NDArrays.
 * \param key label used in the error message for non-NDArray input.
 * \param names names assigned to the returned list.
 * \param ctx context passed to NDArray::Empty for every new array.
 * \param handles receives the handle of each newly created array at index i.
 * \return pointer to the new list; the caller owns it.
 *
 * NOTE(review): no resize of *handles is visible in this block, while
 * handles->at(i) needs the vector to already hold source_array.size()
 * entries; confirm against the complete file.
 */
inline Rcpp::List* CreateArrayList(const Rcpp::List& source_array,
const std::string& key,
const std::vector<std::string>& names,
const Context::RObjectType& ctx,
std::vector<NDArrayHandle>* handles) {
Rcpp::List* ret = new Rcpp::List(source_array.size());
try {
ret->names() = names;
for (size_t i = 0; i < source_array.size(); ++i) {
// NOTE(review): the check that owns this message is not visible here.
<< "Expect input " << key << " to be list of " << NDArray::TypeName();
NDArray src = NDArray::FromRObject(source_array[i]);
// Create an empty array with the source's dimensions on ctx, then fill it.
ret->at(i) = NDArray::Empty(src.dim(), ctx);
NDArray dst = NDArray::FromRObject(ret->at(i));
handles->at(i) = dst->handle;
NDArray::CopyFromTo(src, &dst);
// Release the partially built list before propagating the error.
// NOTE(review): `throw ex;` throws a copy typed as Rcpp::exception; a bare
// `throw;` would propagate the original object.
} catch(const Rcpp::exception& ex) {
delete ret;
throw ex;
return ret;
/*!
 * \brief Allocate a heap Rcpp::List of gradient arrays according to grad_reqs.
 *
 * Every entry starts as NULL. For each request other than "null", an empty
 * NDArray with the dimensions of the matching source array is created on ctx,
 * its handle is stored in *handles, and the numeric request code is stored in
 * *grad_req_type ("write" -> 1, "add" -> 3). Entries whose request is "null"
 * are left untouched after the initial resize.
 *
 * \param source_array list of NDArrays supplying the dimensions.
 * \param grad_reqs list of request strings: "null", "write" or "add".
 * \param names names assigned to the returned list.
 * \param ctx context passed to NDArray::Empty for every new array.
 * \param handles resized to grad_reqs.size() (new slots nullptr); receives
 *        the handle of each created gradient array.
 * \param grad_req_type resized to grad_reqs.size() (new slots 0); receives
 *        the numeric request code of each non-"null" entry.
 * \return pointer to the new list; the caller owns it.
 */
inline Rcpp::List* CreateGradList(const Rcpp::List& source_array,
                                  const Rcpp::List& grad_reqs,
                                  const std::vector<std::string>& names,
                                  const Context::RObjectType& ctx,
                                  std::vector<NDArrayHandle> *handles,
                                  std::vector<mx_uint> *grad_req_type) {
  Rcpp::List* ret = new Rcpp::List(grad_reqs.size(), R_NilValue);
  try {
    ret->names() = names;
    handles->resize(grad_reqs.size(), nullptr);
    grad_req_type->resize(grad_reqs.size(), 0);
    for (size_t i = 0; i < grad_reqs.size(); ++i) {
      // Convert the request once instead of re-converting it for every test.
      const std::string req = Rcpp::as<std::string>(grad_reqs[i]);
      mx_uint req_code = 0;  // code for "null"
      if (req == "write") {
        req_code = 1;
      } else if (req == "add") {
        req_code = 3;
      } else if (req != "null") {
        RLOG_FATAL << "grad_req must be one of 'null', 'write' or 'add'";
      }
      if (req != "null") {
        ret->at(i) = NDArray::Empty(NDArray::FromRObject(source_array[i]).dim(), ctx);
        handles->at(i) = NDArray::FromRObject(ret->at(i))->handle;
        grad_req_type->at(i) = req_code;
      }
    }
  } catch (...) {
    // Free the list on any failure, not only Rcpp::exception, and rethrow the
    // original exception object so its dynamic type is preserved.
    delete ret;
    throw;
  }
  return ret;
}
/*!
 * \brief Allocate a heap Rcpp::List that wraps an array of output handles.
 *
 * \param out_size number of handles in out_arr.
 * \param out_arr handles to wrap; element i is passed to NDArray::RObject
 *        together with `false`.
 * \param names names assigned to the returned list.
 * \return pointer to the new list; the caller owns it.
 */
inline Rcpp::List* CreateOutList(mx_uint out_size,
                                 NDArrayHandle *out_arr,
                                 const std::vector<std::string>& names) {
  Rcpp::List* ret = new Rcpp::List(out_size);
  try {
    ret->names() = names;
    for (size_t i = 0; i < out_size; ++i) {
      ret->at(i) = NDArray::RObject(out_arr[i], false);
    }
  } catch (...) {
    // Free the list on any failure, not only Rcpp::exception, and rethrow the
    // original exception object so its dynamic type is preserved.
    delete ret;
    throw;
  }
  return ret;
}
/*!
 * \brief Create an Executor for a symbol and return it as an R object.
 *
 * Builds the argument, auxiliary and gradient lists through the Create*List
 * helpers, queries the outputs with MXExecutorOutputs, and wraps the executor
 * with make_new_object.
 *
 * \param symbol R object from which the Symbol pointer is obtained.
 * \param context context handed to the list helpers and used to construct ctx.
 * \param arg_arrays source list for the argument arrays (also the dimension
 *        source for the gradient arrays).
 * \param aux_arrays source list for the auxiliary arrays.
 * \param grad_reqs gradient request list handed to CreateGradList.
 * \return the new executor wrapped by Rcpp::internal::make_new_object.
 *
 * NOTE(review): the helper calls below pass one argument fewer than the helper
 * signatures declared in this file (no names argument), and the call that
 * consumes the device/handle argument lines is not visible; confirm against
 * the complete file.
 */
Executor::RObjectType Executor::Bind(const Symbol::RObjectType& symbol,
const Context::RObjectType& context,
const Rcpp::List& arg_arrays,
const Rcpp::List& aux_arrays,
const Rcpp::List& grad_reqs) {
Executor* exec = new Executor();
try {
Symbol *sym = Symbol::XPtr(symbol);
// handles
std::vector<mx_uint> grad_req_type;
std::vector<NDArrayHandle> arg_handles, grad_handles, aux_handles;
// for failure handling
exec->arg_arrays_ = CreateArrayList(
arg_arrays, "arg_arrays",
context, &arg_handles);
exec->aux_arrays_ = CreateArrayList(
aux_arrays, "aux_arrays",
context, &aux_handles);
// Gradient arrays take their dimensions from arg_arrays.
exec->grad_arrays_ = CreateGradList(
arg_arrays, grad_reqs,
context, &grad_handles, &grad_req_type);
Context ctx(context);
// NOTE(review): the call these arguments belong to is not visible here.
ctx.dev_type, ctx.dev_id,
static_cast<mx_uint>(arg_handles.size()), dmlc::BeginPtr(arg_handles),
dmlc::BeginPtr(grad_handles), dmlc::BeginPtr(grad_req_type),
static_cast<mx_uint>(aux_handles.size()), dmlc::BeginPtr(aux_handles),
// Fetch the output handles and wrap them, named by the symbol's outputs.
mx_uint out_size;
NDArrayHandle *out_arr;
MX_CALL(MXExecutorOutputs(exec->handle_, &out_size, &out_arr));
exec->out_arrays_ = CreateOutList(
out_size, out_arr, sym->ListOuputs());
// Delete the half-built executor before propagating the error.
// NOTE(review): `throw ex;` throws a copy typed as Rcpp::exception; a bare
// `throw;` would propagate the original object.
} catch(const Rcpp::exception& ex) {
delete exec;
throw ex;
return Rcpp::internal::make_new_object(exec);
/*!
 * \brief Set up the Rcpp-facing registrations for Executor.
 *
 * NOTE(review): only description strings, the `.property(...)` chain and a
 * trailing argument list are visible here; the calls those strings and
 * arguments belong to cannot be read from this view.
 */
void Executor::InitRcppModule() {
using namespace Rcpp; // NOLINT(*)
// Description strings for the array-update, forward and backward entries.
"Update auxilary states array of executor, this will mutate the executor")
"Update arguments array of executor, this will mutate the executor")
"Update gradient array of executor, this will mutate the executor")
"Peform a forward operation on exec, this will set the outputs.")
"Peform a backward operation on exec, this will set the gradients requested.")
// "ref."-prefixed properties are bound to the *_arrays members of Executor.
.property("ref.arg.arrays", &Executor::arg_arrays)
.property("ref.grad.arrays", &Executor::grad_arrays)
.property("ref.aux.arrays", &Executor::aux_arrays)
.property("ref.outputs", &Executor::out_arrays)
// Unprefixed properties are bound to the Get*Arrays members of Executor.
.property("arg.arrays", &Executor::GetArgArrays)
.property("grad.arrays", &Executor::GetGradArrays)
.property("aux.arrays", &Executor::GetAuxArrays)
.property("outputs", &Executor::GetOuputArrays);
// Named-argument list and description; presumably for a bind entry point,
// judging by the description text -- TODO confirm.
List::create(_["symbol"], _["ctx"],
_["arg.arrays"], _["aux.arrays"], _["grad.reqs"]),
"Bind the symbol on argument arrays, generate gradient array according to grad_reqs");
} // namespace R
} // namespace mxnet