blob: b7db7c6f41def54ded4beae1724c8014b45b1282 [file] [log] [blame]
/*!
* Copyright (c) 2015 by Contributors
* \file executor.h
* \brief Rcpp Symbolic Execution interface of MXNet
*/
#ifndef MXNET_RCPP_EXECUTOR_H_
#define MXNET_RCPP_EXECUTOR_H_
#include <Rcpp.h>
#include <mxnet/c_api.h>
#include <string>
#include "./base.h"
#include "./symbol.h"
namespace mxnet {
namespace R {
/*! \brief The Rcpp Symbol class of MXNet */
class Executor : public MXNetMovable<Executor> {
public:
/*! \return typename from R side. */
inline static const char* TypeName() {
return "MXExecutor";
}
/*!
* \return Get reference of the arg arrays of executor.
*/
const Rcpp::List& arg_arrays() const {
return *arg_arrays_;
}
/*!
* \return Get reference of the aux arrays of executor.
*/
const Rcpp::List& aux_arrays() const {
return *aux_arrays_;
}
/*!
* \return Get reference of gradient arrays of executor.
*/
const Rcpp::List& grad_arrays() const {
return *grad_arrays_;
}
/*!
* \return Get reference of gradient arrays of executor.
*/
const Rcpp::List& out_arrays() const {
return *out_arrays_;
}
/*!
* \return Get the arg arrays of executor.
*/
Rcpp::List GetArgArrays() const {
return CloneArray(*arg_arrays_);
}
/*!
* \return Get the grad arrays of executor.
*/
Rcpp::List GetGradArrays() const {
return CloneArray(*grad_arrays_);
}
/*!
* \return Get the auxiliary arrays of executor.
*/
Rcpp::List GetAuxArrays() const {
return CloneArray(*aux_arrays_);
}
/*!
* \return Get the outputx arrays of executor.
*/
Rcpp::List GetOuputArrays() const {
return CloneArray(*out_arrays_);
}
/*!
* \brief Update the arg_arrays of executor, based on name-matching.
* \param array The array to update
* \param match_name whether to use name to match the input, instead of index.
* \param skip_null Whether null is allowed, when there is NULL in the array, simply ignore.
* \return a result executor, moved from exec.
*/
void UpdateArgArray(const Rcpp::List& array,
bool match_name,
bool allow_null);
/*!
* \brief Update the aux_arrays of executor, based on name-matching.
* \param array The array to update
* \param match_name whether to use name to match the input, instead of index.
* \param skip_null Whether null is allowed, when there is NULL in the array, simply ignore.
* \return a result executor, moved from exec.
*/
void UpdateAuxArray(const Rcpp::List& array,
bool match_name,
bool allow_null);
/*!
* \brief Update the grad_arrays of executor, based on name-matching.
* \param array The array to update
* \param match_name whether to use name to match the input, instead of index.
* \param skip_null Whether null is allowed, when there is NULL in the array, simply ignore.
* \return a result executor, moved from exec.
*/
void UpdateGradArray(const Rcpp::List& array,
bool match_name,
bool allow_null);
/*!
* \brief Peform a forward operation on exec, this will set the out_arrays.
* \param is_train whether it is training phase.
* \param kwargs additional parameters.
* \return a result executor, moved from exec.
*/
void Forward(bool is_train,
const Rcpp::List& kwargs);
/*!
* \brief Peform a backward operation on exec, this will set the grad_arrays.
* \param output_grads the gradient on outputs, to be propagated back.
* \return a result executor, moved from exec.
*/
void Backward(const Rcpp::List& output_grads);
/*!
* \brief Create a new R Executor by bind on symbol
* \param symbol The R symbol to bind.
* \param context The device to bind.
* \param arg_arrays The argument arrays giving the initial value of arguments.
* \param aux_arrays The auxiliary state arrays giving the initial value of auxiliary states.
* \param grad_reqs Array of booleans, giving the requirements of gradient.
*/
static RObjectType Bind(const Symbol::RObjectType& symbol,
const Context::RObjectType& context,
const Rcpp::List& arg_arrays,
const Rcpp::List& aux_arrays,
const Rcpp::List& grad_reqs);
/*! \brief static function to initialize the Rcpp functions */
static void InitRcppModule();
// destructor
~Executor() {
delete out_arrays_;
delete arg_arrays_;
delete grad_arrays_;
delete aux_arrays_;
if (!this->moved_) {
MX_CALL(MXExecutorFree(handle_));
}
}
private:
// friend with symbol
friend class Symbol;
// internal constructor, enable trivial operator=
Executor()
: out_arrays_(nullptr),
arg_arrays_(nullptr),
grad_arrays_(nullptr),
aux_arrays_(nullptr) {}
/*! \return a new Object that is moved from current one */
inline Executor* CreateMoveObject() {
Executor *moved = new Executor();
*moved = *this;
out_arrays_ = nullptr;
arg_arrays_ = nullptr;
grad_arrays_ = nullptr;
aux_arrays_ = nullptr;
return moved;
}
/*!
* \brief Clone src into a new space.
* \param src source list of arrays to clone.
* \return A cloned list of arrays under same context.
*/
static Rcpp::List CloneArray(const Rcpp::List& src);
/*!
* \brief Copy arrays from to to
* \param array_name The name of the array, used for error message.
* \param from source list to copy from.
* \param to target list to copy to.
* \param match_name whether to use name to match the input, instead of index.
* \param skip_null Whether null is allowed, when there is NULL in the array, simply ignore.
*/
static void UpdateArray(const char* array_name,
const Rcpp::List& from, Rcpp::List *to,
bool match_name, bool skip_null);
/*! \brief output arrays of Executor */
Rcpp::List *out_arrays_;
/*! \brief argument arrays of Executor */
Rcpp::List *arg_arrays_;
/*! \brief gradient arrays of Executor */
Rcpp::List *grad_arrays_;
/*! \brief auxiliary arrays of Executor */
Rcpp::List *aux_arrays_;
/*! \brief internal executor handle */
ExecutorHandle handle_;
};
} // namespace R
} // namespace mxnet
RCPP_EXPOSED_CLASS_NODECL(::mxnet::R::Executor);
namespace Rcpp {
template<>
inline bool is<mxnet::R::Executor>(SEXP x) {
return internal::is__module__object_fix<mxnet::R::Executor>(x);
}
}
#endif // MXNET_RCPP_EXECUTOR_H_