blob: 8645d8576b0e83d907e8cc12166aeb7a99f0cb7d [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2015 by Contributors
* \file base.h
* \brief Rcpp interface of MXNet
* All the interface is done through C API,
* to achieve maximum portability when we need different compiler for libmxnet.
*/
#ifndef MXNET_RCPP_BASE_H_
#define MXNET_RCPP_BASE_H_
#include <Rcpp.h>
#include <dmlc/base.h>
#include <mxnet/c_api.h>
#include <string>
#include <sstream>
#include <set>
#include <vector>
#include <algorithm>
/*! \brief namespace of mxnet */
namespace mxnet {
/*! \brief namespace of R package */
namespace R {
/*!
 * \brief Compatibility shim for builds without C++11.
 *  When DMLC_USE_CXX11 evaluates to 0 and no `nullptr` macro exists yet,
 *  `nullptr` is defined as NULL so the rest of the header can use it.
 */
#if DMLC_USE_CXX11 == 0
#ifndef nullptr
#define nullptr NULL
#endif
#endif
/*!
 * \brief Fatal-error logger for the R binding.
 *  The message is collected through stream() and raised as an
 *  Rcpp::exception when the logger object is destroyed.
 */
class RLogFatal {
 public:
  /*!
   * \brief Start the message with a "file:line: " prefix.
   * \param file name of the source file reporting the error.
   * \param line line number reporting the error.
   */
  RLogFatal(const char* file, int line) {
    log_stream_ << file << ":";
    log_stream_ << line << ": ";
  }
  /*! \return the stream that accumulates the message text */
  std::ostringstream &stream() {
    return log_stream_;
  }
  /*!
   * \brief Throw the accumulated message (with a trailing newline appended)
   *  as an Rcpp::exception.
   *  Under C++11 the destructor is declared noexcept(false) so that the
   *  throw is permitted.
   */
  ~RLogFatal()
#if DMLC_USE_CXX11
  noexcept(false)
#endif
  {
    const std::string text = log_stream_.str() + '\n';
    throw Rcpp::exception(text.c_str());
  }

 private:
  /*! \brief buffer holding the message built so far */
  std::ostringstream log_stream_;
};
/*!
 * \brief LOG FATAL to report an error to the R console.
 *  Expands to the stream of a temporary RLogFatal tagged with the current
 *  __FILE__ and __LINE__. RLogFatal's destructor appends the trailing
 *  newline and throws, so callers do not need to add a newline themselves.
 */
#define RLOG_FATAL ::mxnet::R::RLogFatal(__FILE__, __LINE__).stream()
/*!
 * \brief LOG INFO to report a message to the R console.
 *  Expands to ::Rcpp::Rcout directly; nothing is appended, so the caller
 *  must write its own newline.
 */
#define RLOG_INFO ::Rcpp::Rcout
/*!
 * \brief Checking macro for Rcpp code, reports an error to the R console
 *  through RLOG_FATAL when the condition is false.
 * \code
 * RCHECK(data.size() == 1) << "Data size must be 1";
 * \endcode
 *
 * The macro is written as `if (x) {} else <stream>` rather than
 * `if (!(x)) <stream>` so that it already owns an `else` branch: an `else`
 * written by the caller after the macro therefore binds to the caller's own
 * `if`, not to the one hidden inside the macro.
 */
#define RCHECK(x) \
  if (x) {} else RLOG_FATAL << "RCheck failed: " #x << ' ' /* NOLINT(*) */
/*!
 * \brief protected MXNet C API call, report R error if happens.
 *  Evaluates the expression exactly once; a non-zero result is turned into
 *  an Rcpp::exception carrying MXGetLastError().
 * \param func Expression to call.
 *
 * The status variable uses a deliberately unusual name so that an
 * identifier appearing in `func` (for example a caller variable named `e`)
 * cannot be captured by the macro's own local.
 *
 * \note The expansion is a plain braced block, so it cannot be used as the
 *  unbraced body of an `if` that is followed by `else` when the call is
 *  written with a trailing semicolon.
 */
#define MX_CALL(func)                                              \
  {                                                                \
    int mx_call_ret_code_ = (func);                                \
    if (mx_call_ret_code_ != 0) {                                  \
      throw Rcpp::exception(MXGetLastError());                     \
    }                                                              \
  }
/*!
 * \brief Set the seed of the random number generator.
 *  Only the declaration lives in this header; the definition is not
 *  visible here.
 * \param seed the seed to set.
 */
void SetSeed(int seed);
/*!
 * \brief Base class adding "moved" tracking to MXNet module objects.
 *  A subclass passes itself as the template argument; the helpers below
 *  cast between Class* and MXNetMovable<Class>* to reach the moved flag.
 * \tparam Class The class name of subclass
 */
template<typename Class>
class MXNetMovable {
 public:
  /*! \brief The type of Class in R's side */
  typedef Rcpp::RObject RObjectType;
  /*!
   * \brief Convert an R object into a pointer to the wrapped Class,
   *  rejecting objects that were already moved.
   * \param obj The R object.
   * \return The pointer of the object.
   * \throw Rcpp::exception if the object is moved.
   */
  static Class* XPtr(const Rcpp::RObject& obj) {
    Class* const target = Rcpp::as<Class*>(obj);
    const MXNetMovable<Class>* const as_base =
        static_cast<MXNetMovable<Class>*>(target);
    const bool has_been_moved = as_base->moved_;
    RCHECK(!has_been_moved)
        << "Passed in a moved " << Class::TypeName() << " as parameter."
        << " Moved parameters should no longer be used";
    return target;
  }

 protected:
  /*! \brief default constructor: a fresh object has not been moved */
  MXNetMovable() : moved_(false) {}
  /*!
   * \brief Default implementation of moving an existing R Class object
   *  into a new one. The source is flagged as moved after the subclass's
   *  CreateMoveObject() has produced the replacement.
   * \param src The source R Object.
   * \return A new R object wrapping the object returned by CreateMoveObject().
   */
  static RObjectType Move(const Rcpp::RObject& src) {
    Class* const source = Class::XPtr(src);
    Class* const replacement = source->CreateMoveObject();
    MXNetMovable<Class>* const source_base =
        static_cast<MXNetMovable<Class>*>(source);
    source_base->moved_ = true;
    return Rcpp::internal::make_new_object(replacement);
  }
  /*! \brief Whether the object has been moved */
  bool moved_;
};
/*! \brief Context of device environment */
struct Context {
  /*! \brief The device type id of the context (kCPU or kGPU) */
  int dev_type;
  /*! \brief The device ID of the context */
  int dev_id;
  /*! \brief The R object type of the context */
  typedef Rcpp::List RObjectType;
  /*!
   * \brief default constructor.
   *  Initializes to CPU device 0 so that a default-constructed Context never
   *  carries indeterminate values.
   */
  Context() : dev_type(kCPU), dev_id(0) {}
  /*!
   * \brief Constructor
   *  Reads the fields positionally, matching the layout produced by
   *  RObject(): element 1 is the device id, element 2 the device type id.
   * \param src source R representation.
   * \throw Rcpp::exception if the list has fewer than three elements.
   */
  explicit Context(const Rcpp::RObject& src) {
    Rcpp::List list(src);
    // Guard the positional reads below against a too-short list.
    if (list.size() < 3) {
      throw Rcpp::exception(
          "Invalid MXContext: expected a list of (device, device_id, device_typeid)");
    }
    this->dev_id = list[1];
    this->dev_type = list[2];
  }
  /*!
   * \return R object representation of the context: a list with entries
   *  "device", "device_id" and "device_typeid", and class "MXContext".
   */
  inline RObjectType RObject() const {
    // Any type id other than kGPU is reported under the name "cpu".
    const char *dev_name = "cpu";
    if (dev_type == kGPU) dev_name = "gpu";
    Rcpp::List ret = Rcpp::List::create(
        Rcpp::Named("device") = dev_name,
        Rcpp::Named("device_id") = dev_id,
        Rcpp::Named("device_typeid") = dev_type);
    ret.attr("class") = "MXContext";
    return ret;
  }
  /*!
   * Create a CPU context.
   * \param dev_id the device id.
   * \return CPU Context.
   */
  inline static RObjectType CPU(int dev_id = 0) {
    Context ctx;
    ctx.dev_type = kCPU;
    ctx.dev_id = dev_id;
    return ctx.RObject();
  }
  /*!
   * Create a GPU context.
   * \param dev_id the device id.
   * \return GPU Context.
   */
  inline static RObjectType GPU(int dev_id) {
    Context ctx;
    ctx.dev_type = kGPU;
    ctx.dev_id = dev_id;
    return ctx.RObject();
  }
  /*!
   * \brief initialize all the Rcpp module functions.
   *  Registers CPU() as "mx.cpu" and GPU() as "mx.gpu", each with a
   *  "dev.id" argument defaulting to 0.
   */
  inline static void InitRcppModule() {
    using namespace Rcpp;  // NOLINT(*);
    function("mx.cpu", &CPU,
             List::create(_["dev.id"] = 0),
             "Create a CPU context.");
    function("mx.gpu", &GPU,
             List::create(_["dev.id"] = 0),
             "Create a GPU context with specific device_id.");
  }
  /*! \brief the device type id for CPU */
  static const int kCPU = 1;
  /*! \brief the device type id for GPU */
  static const int kGPU = 2;
};
/*!
 * \brief Build a vector of C string pointers that view the given keys.
 *  The returned pointers borrow from `keys`; `keys` must stay alive (and
 *  unmodified) for as long as the result is used.
 * \param keys the string vector to get keys from
 * \return one `const char*` per key, in the same order
 */
inline std::vector<const char*> CKeys(const std::vector<std::string> &keys) {
  std::vector<const char*> c_keys;
  c_keys.reserve(keys.size());
  for (std::vector<std::string>::const_iterator it = keys.begin();
       it != keys.end(); ++it) {
    c_keys.push_back(it->c_str());
  }
  return c_keys;
}
/*!
 * \brief Get a printable name for the R type of a value.
 * \param args the R object to inspect
 * \return "numeric", "list", "integer", "complex", "logical" or "string"
 *  for the corresponding R types, and "object type" for anything else.
 */
inline const char* TypeName(const Rcpp::RObject& args) {
  const int kind = TYPEOF(args);
  if (kind == REALSXP) return "numeric";
  if (kind == VECSXP) return "list";
  if (kind == INTSXP) return "integer";
  if (kind == CPLXSXP) return "complex";
  if (kind == LGLSXP) return "logical";
  if (kind == STRSXP) return "string";
  return "object type";
}
/*!
 * \brief Convert an R value of known type to its string form.
 *  The value is first converted with Rcpp::as<T> and then written through
 *  a default-configured std::ostringstream.
 * \tparam T the C++ type to convert the value to before printing
 * \param val the value
 * \return the corresponding string
 */
template<typename T>
inline std::string toString(const Rcpp::RObject& val) {
  const T converted = Rcpp::as<T>(val);
  std::ostringstream buffer;
  buffer << converted;
  return buffer.str();
}
/*!
 * \brief Check whether the value is a simple parameter.
 * \param val The value to check.
 * \return true when the R type is string, integer, real or logical;
 *  false for every other type.
 */
inline bool isSimple(const Rcpp::RObject& val) {
  const int kind = TYPEOF(val);
  return kind == STRSXP ||
         kind == INTSXP ||
         kind == REALSXP ||
         kind == LGLSXP;
}
/*!
 * \brief Create an API compatible string representation of a value.
 *  Values whose length is not 1, and any value whose key ends in "shape",
 *  are rendered as a Python-style tuple with the element order reversed
 *  (a one-element tuple gets a trailing comma). Other values are rendered
 *  as a plain scalar.
 * \param key The key name of the parameter
 * \param val The value of the parameter
 * \return A python string representation of val
 */
inline std::string toPyString(const std::string &key, const Rcpp::RObject& val) {
  const int len = Rf_length(val);
  // True when the last five characters of the key spell "shape".
  const bool shape_key =
      key.size() >= 5 && key.compare(key.size() - 5, 5, "shape") == 0;

  if (len != 1 || shape_key) {
    RCHECK(TYPEOF(val) == INTSXP || TYPEOF(val) == REALSXP)
        << "Only accept integer vectors or simple types";
    // Convert the shape back to the reversed layout: walk from the last
    // element to the first.
    Rcpp::IntegerVector dims(val);
    const int count = static_cast<int>(dims.size());
    std::ostringstream tuple;
    tuple << "(";
    for (int pos = count - 1; pos >= 0; --pos) {
      if (pos != count - 1) tuple << ", ";
      const int value = dims[pos];
      tuple << value;
    }
    if (count == 1) tuple << ",";
    tuple << ")";
    return tuple.str();
  }

  const int kind = TYPEOF(val);
  if (kind == STRSXP) return Rcpp::as<std::string>(val);
  if (kind == INTSXP) return toString<int>(val);
  if (kind == REALSXP) return toString<double>(val);
  if (kind == LGLSXP) return toString<bool>(val);

  // RLOG_FATAL throws at the end of the statement.
  RLOG_FATAL << "Unsupported parameter type " << TypeName(val)
             << " for argument " << key
             << ", expect integer, logical, or string.";
  return std::string();
}
/*!
 * \brief Convert dot '.' style separators into underscores '_',
 *  so num.hidden -> num_hidden.
 *  This allows R users to use the dot style separators.
 * \param src the source key
 * \return a converted key
 */
inline std::string FormatParamKey(std::string src) {
  std::replace(src.begin(), src.end(), '.', '_');
  return src;
}
/*!
 * \brief Check whether a list carries a names attribute.
 * \param src the list to inspect
 * \return whether the list has names (its names are not R_NilValue)
 */
inline bool HasName(const Rcpp::List& src) {
  Rcpp::RObject names_attr = src.names();
  const bool present = names_attr != R_NilValue;
  return present;
}
/*!
 * \brief Get names from a list; return a vector of empty strings if the
 *  list has no names.
 * \param src the source list
 * \return vector of strings of the same length as src.
 */
inline std::vector<std::string> SafeGetListNames(const Rcpp::List& src) {
  if (HasName(src)) {
    return src.names();
  }
  // No names attribute: one empty string per element.
  return std::vector<std::string>(src.size(), std::string());
}
/*!
 * \brief Convert Rcpp's Dimension to the internal shape vector.
 *  The element order is reversed: the first R dimension becomes the last
 *  entry of the result.
 * \param rshape The dimension in R
 * \return An internal vector representation of shapes in mxnet.
 */
inline std::vector<mx_uint> Dim2InternalShape(const Rcpp::Dimension &rshape) {
  const size_t ndim = rshape.size();
  std::vector<mx_uint> shape(ndim);
  size_t dst = ndim;
  for (size_t src = 0; src < ndim; ++src) {
    --dst;
    shape[dst] = rshape[src];
  }
  return shape;
}
// Forward declarations of mxnet::R classes that are defined elsewhere;
// declared here so that later code in this header can name them.
class NDArray;
class Symbol;
class Executor;
class KVStore;
} // namespace R
} // namespace mxnet
// This is Rcpp namespace, contains patches to Rcpp
// The following section follows style of Rcpp
namespace Rcpp {
namespace internal { // NOLINT(*)
/*!
 * \brief Check whether an object is an Rcpp module object of a given class.
 *  Looks up ".cppclass" in the object viewed as an Environment; unless that
 *  entry is an external pointer the answer is false, otherwise the pointed-to
 *  class_Base is asked whether it has the given typeinfo name.
 * \param obj the R object to inspect
 * \param clazz the typeinfo name to compare against
 * \return true if the object's class_Base reports the typeinfo name
 */
inline bool is_module_object_internal_fix(SEXP obj, const char* clazz) {
  Environment object_env(obj);
  SEXP cpp_class = object_env.get(".cppclass");
  if (TYPEOF(cpp_class) == EXTPTRSXP) {
    XPtr<class_Base> class_info(cpp_class);
    return class_info->has_typeinfo_name(clazz);
  }
  return false;
}
/*!
 * \brief Check whether x is a module object wrapping type T.
 *  x must first pass is__simple<S4>; the pointer is then stripped from T and
 *  the typeid name of the result is matched via
 *  is_module_object_internal_fix.
 * \tparam T the (possibly pointer) type to test for
 * \param x the R object to inspect
 */
template <typename T> bool is__module__object_fix(SEXP x) {
  typedef typename Rcpp::traits::un_pointer<T>::type CLASS;
  return is__simple<S4>(x) &&
         is_module_object_internal_fix(x, typeid(CLASS).name());
}
} // namespace internal NOLINT(*)
// Declarations of explicit specializations of Rcpp::is for the MXNet
// NDArray, Symbol and Executor classes. Only the declarations appear here;
// the definitions are not visible in this header.
template<>
inline bool is<mxnet::R::NDArray>(SEXP x);
template<>
inline bool is<mxnet::R::Symbol>(SEXP x);
template<>
inline bool is<mxnet::R::Executor>(SEXP x);
} // namespace Rcpp
#endif // MXNET_RCPP_BASE_H_