blob: d7dff4098b27b90c5ce0b6a8e3b57895092d891b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
 * \file ndarray.h
 * \brief NDArray interface that handles array arithmetic.
 */
#ifndef MXNET_NDARRAY_H_
#define MXNET_NDARRAY_H_
#include <dmlc/base.h>
#include <dmlc/logging.h>
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <dmlc/registry.h>
#include <nnvm/node.h>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "./base.h"
#include "./storage.h"
#include "./engine.h"
#if MKL_EXPERIMENTAL == 1
#include <mkl_memory.h>
#endif
// This header relies on C++11 features (e.g. std::shared_ptr, lambdas and
// alias declarations); fail early with a clear message otherwise.
#if DMLC_USE_CXX11 == 0
#error "cxx11 was required for ndarray module"
#endif
namespace mxnet {
// Forward declarations from the autograd module that NDArray depends on.
namespace autograd {
class AGNode;
typedef std::shared_ptr<AGNode> AGNodePtr;
/*!
 * \brief handle to an entry of an autograd node
 *  (presumably one output of ag_node -- confirm against autograd sources).
 * \note This type is an aggregate on purpose so it can be brace-initialized
 *  as {nullptr, 0, 0}. A default-initialized instance leaves index and
 *  version indeterminate; brace-initialize it or call clear().
 */
class AGNodeEntry {
 public:
  /*! \brief the autograd node this entry refers to (may be null) */
  AGNodePtr ag_node;
  /*! \brief index of the entry within ag_node */
  uint32_t index;
  /*! \brief version counter of the entry */
  uint32_t version;
  /*! \brief drop the node reference and zero index and version */
  void clear() {
    ag_node.reset();
    const uint32_t kZero = 0;
    version = kZero;
    index = kZero;
  }
  /*! \brief convert to an nnvm graph entry (defined out of line) */
  nnvm::NodeEntry nn_entry() const;
  /*! \brief whether this entry is empty (defined out of line) */
  bool is_none() const;
};
class AutogradRuntime;
}  // namespace autograd
/*!
* \brief ndarray interface
*/
class NDArray {
public:
/*! \brief default constructor */
NDArray() {
#if MKL_EXPERIMENTAL == 1
Mkl_mem_ = MKLMemHolder::create();
#endif
}
/*!
* \brief constructs a new dynamic NDArray
* \param shape the shape of array
* \param ctx context of NDArray
* \param delay_alloc whether delay the allocation
* \param dtype data type of this ndarray
*/
NDArray(const TShape &shape, Context ctx,
bool delay_alloc = false, int dtype = mshadow::default_type_flag)
: ptr_(std::make_shared<Chunk>(shape.Size(), ctx, delay_alloc, dtype)),
shape_(shape), dtype_(dtype), entry_({nullptr, 0, 0}) {
#if MKL_EXPERIMENTAL == 1
Mkl_mem_ = std::make_shared<MKLMemHolder>();
#endif
}
/*!
* \brief constructing a static NDArray that shares data with TBlob
* Use with caution: allocate ONLY ONE NDArray for each TBlob,
* make sure the memory region is available through out the life of NDArray
* \param data the memory content of static data
* \param dev_id the device id this tensor sits at
*/
NDArray(const TBlob &data, int dev_id)
: ptr_(std::make_shared<Chunk>(data, dev_id)), shape_(data.shape_),
dtype_(data.type_flag_), entry_({nullptr, 0, 0}) {
#if MKL_EXPERIMENTAL == 1
Mkl_mem_ = std::make_shared<MKLMemHolder>();
#endif
}
/*!
* \return the shape of current NDArray
*/
inline const TShape& shape() const {
return shape_;
}
/*!
* \return the data TBlob
*/
inline const TBlob& data() const {
CheckAndAlloc();
SetTBlob();
return tblob_;
}
/*!
* \return the gradient ndarray.
*/
NDArray grad() const;
/*!
* \return the context of NDArray, this function is only valid when the NDArray is not empty
*/
inline Context ctx() const {
return ptr_->shandle.ctx;
}
/*!
* \return the data type of NDArray, this function is only valid when the NDArray is not empty
*/
inline int dtype() const {
return dtype_;
}
/*! \return whether this ndarray is not initialized */
inline bool is_none() const {
return ptr_.get() == nullptr;
}
/*! \return updated grad state in entry_ */
bool fresh_out_grad() const;
/*! \return updated grad state in entry_ */
void set_fresh_out_grad(bool state) const;
/*!
* \brief Block until all the pending write operations with respect
* to current NDArray are finished, and read can be performed.
*/
inline void WaitToRead() const {
if (is_none()) return;
Engine::Get()->WaitForVar(ptr_->var);
}
/*!
* \brief Block until all the pending read/write operations with respect
* to current NDArray are finished, and write can be performed.
*/
inline void WaitToWrite() const {
if (is_none()) return;
/*!
* Push an empty mutable function to flush all preceding reads to the
* variable.
*/
Engine::Get()->PushSync([](RunContext) {}, Context{}, {}, {ptr_->var});
Engine::Get()->WaitForVar(ptr_->var);
}
/*! \return the associated variable of the ndarray.*/
inline Engine::VarHandle var() const {
return ptr_->var;
}
/*!
* \brief save the content into binary stream
* \param strm the output stream
*/
void Save(dmlc::Stream *strm) const;
/*!
* \brief load the content from binary stream
* \param strm the output stream
* \return whether the load is successful
*/
bool Load(dmlc::Stream *strm);
/*!
* \brief set all the elements in ndarray to be scalar
* \param scalar the scalar to set
* \return reference of self
*/
NDArray &operator=(real_t scalar);
/*!
* \brief elementwise add to current space
* this mutate the current NDArray
* \param src the data to add
* \return reference of self
*/
NDArray &operator+=(const NDArray &src);
/*!
* \brief elementwise add to current space
* this mutate the current NDArray
* \param src the data to add
* \return reference of self
*/
NDArray &operator+=(const real_t &src);
/*!
* \brief elementwise subtract from current ndarray
* this mutate the current NDArray
* \param src the data to subtract
* \return reference of self
*/
NDArray &operator-=(const NDArray &src);
/*!
* \brief elementwise subtract from current ndarray
* this mutate the current NDArray
* \param src the data to subtract
* \return reference of self
*/
NDArray &operator-=(const real_t &src);
/*!
* \brief elementwise multiplication to current ndarray
* this mutate the current NDArray
* \param src the data to subtract
* \return reference of self
*/
NDArray &operator*=(const NDArray &src);
/*!
* \brief elementwise multiplication to current ndarray
* this mutate the current NDArray
* \param src the data to subtract
* \return reference of self
*/
NDArray &operator*=(const real_t &src);
/*!
* \brief elementwise division from current ndarray
* this mutate the current NDArray
* \param src the data to subtract
* \return reference of self
*/
NDArray &operator/=(const NDArray &src);
/*!
* \brief elementwise division from current ndarray
* this mutate the current NDArray
* \param src the data to subtract
* \return reference of self
*/
NDArray &operator/=(const real_t &src);
/*!
* \brief return transpose of current NDArray
* \return a new transposed NDArray
*/
NDArray T() const;
/*!
* \brief return a new copy this NDArray
* \param ctx the new context of this NDArray
* \return the new copy
*/
NDArray Copy(Context ctx) const;
/*!
* \brief Do a synchronize copy from a continugous CPU memory region.
*
* This function will call WaitToWrite before the copy is performed.
* This is useful to copy data from existing memory region that are
* not wrapped by NDArray(thus dependency not being tracked).
*
* \param data the data source to copy from.
* \param size the size of the source array, in sizeof(DType) not raw btyes.
*/
void SyncCopyFromCPU(const void *data, size_t size) const;
/*!
* \brief Do a synchronize copy to a continugous CPU memory region.
*
* This function will call WaitToRead before the copy is performed.
* This is useful to copy data from existing memory region that are
* not wrapped by NDArray(thus dependency not being tracked).
*
* \param data the data source to copyinto.
* \param size the memory size we want to copy into, in sizeof(DType) not raw btyes.
*/
void SyncCopyToCPU(void *data, size_t size) const;
/*!
* \brief Slice a NDArray
* \param begin begin index in first dim
* \param end end index in first dim
* \return sliced NDArray
*/
NDArray Slice(index_t begin, index_t end) const;
/*!
* \brief Index a NDArray
* \param idx the index
* \return idx-th sub array NDArray
*/
NDArray At(index_t idx) const;
/*!
* \brief Create a NDArray that shares memory with current one
* The new array must have smaller memory size than the current array.
* \param shape new shape
* \param dtype The data type.
* \return NDArray in new shape and type.
*/
inline NDArray AsArray(const TShape &shape, int dtype) const {
CHECK_GE(shape_.Size() * mshadow::mshadow_sizeof(dtype_),
shape.Size() * mshadow::mshadow_sizeof(dtype))
<< "NDArray.AsArray: target memory size is bigger";
#if MKL_EXPERIMENTAL == 1
if (Mkl_mem_ != nullptr) {
// convert prv to cpu
Mkl_mem_->check_and_prv_to_cpu(ptr_->shandle.dptr);
}
#endif
NDArray ret = *this;
ret.shape_ = shape;
ret.dtype_ = dtype;
return ret;
}
/*!
* \brief Get an reshaped NDArray
* \param shape new shape
* \return NDArray in new shape
*/
NDArray Reshape(const TShape &shape) const;
/*!
* \brief Return a copy of this NDArray without autograd history
*/
NDArray Detach() const {
NDArray ret(*this);
ret.entry_ = autograd::AGNodeEntry{nullptr, 0, 0};
return ret;
}
nnvm::Symbol get_autograd_symbol() {
CHECK(!entry_.is_none())
<< "NDArray is not part of a computation graph. Did you forget to turn on recording?";
nnvm::Symbol ret;
ret.outputs.emplace_back(entry_.nn_entry());
return ret;
}
/*!
* \brief Allocate the space if it is delayed allocated.
* This is an internal function used by system that normal user should not use
*/
inline void CheckAndAlloc() const {
ptr_->CheckAndAlloc();
}
/*!
* \brief Save list of ndarray into the Stream.x
* \param fo The stream of output.
* \param data the NDArrays to be saved.
* \param names the name of the NDArray, optional, can be zero length.
*/
static void Save(dmlc::Stream* fo,
const std::vector<NDArray>& data,
const std::vector<std::string>& names);
/*!
* \brief Load list of ndarray into from the stream.
* \param fi The stream of the input file.
* \param data the NDArrays to be loaded
* \param keys the name of the NDArray, if saved in the file.
*/
static void Load(dmlc::Stream* fi,
std::vector<NDArray>* data,
std::vector<std::string>* keys);
private:
friend class autograd::AutogradRuntime;
/*! \brief the real data chunk that backs NDArray */
struct Chunk {
/*! \brief storage handlefrom storage engine */
Storage::Handle shandle;
/*! \brief variable from engine */
Engine::VarHandle var;
/*!
* \brief if this is true, this means the data do not come
* from Storage, and do not need to be freed
*/
bool static_data;
/*! \brief whether allocation is delayed */
bool delay_alloc;
/*! \brief default cosntructor */
Chunk() : static_data(true), delay_alloc(false) {
var = Engine::Get()->NewVariable();
}
/*! \brief construct from static data */
Chunk(const TBlob &data, int dev_id)
: static_data(true),
delay_alloc(false) {
var = Engine::Get()->NewVariable();
if (data.dev_mask() == cpu::kDevMask) {
shandle.ctx = Context::CPU();
} else {
CHECK_EQ(data.dev_mask(), gpu::kDevMask);
shandle.ctx = Context::GPU(dev_id);
}
shandle.dptr = data.dptr_;
shandle.size = data.shape_.Size() * mshadow::mshadow_sizeof(data.type_flag_);
}
/*! \brief construct a new chunk */
Chunk(uint64_t size, Context ctx, bool delay_alloc_, int dtype)
: static_data(false), delay_alloc(true) {
var = Engine::Get()->NewVariable();
shandle.size = size * mshadow::mshadow_sizeof(dtype);
shandle.ctx = ctx;
if (!delay_alloc_) this->CheckAndAlloc();
}
/*! \brief check if delay alloc is on, do alloc if not yet done */
inline void CheckAndAlloc(void) {
if (delay_alloc) {
shandle = Storage::Get()->Alloc(shandle.size, shandle.ctx);
delay_alloc = false;
}
}
/*! \brief destructor */
~Chunk() {
if (static_data || delay_alloc) {
Engine::Get()->DeleteVariable([](RunContext s) {}, shandle.ctx, var);
} else {
Storage::Handle h = this->shandle;
Engine::Get()->DeleteVariable([h](RunContext s) {
Storage::Get()->Free(h);
}, shandle.ctx, var);
}
}
};
void SetTBlob() const {
tblob_.dptr_ = static_cast<char*>(ptr_->shandle.dptr) + byte_offset_;
tblob_.shape_ = shape_;
tblob_.type_flag_ = dtype_;
tblob_.SetDLTensor(ptr_->shandle.ctx.dev_mask(), ptr_->shandle.ctx.dev_id);
#if MKL_EXPERIMENTAL == 1
tblob_.Mkl_mem_ = Mkl_mem_;
#endif
}
#if MKL_EXPERIMENTAL == 1
std::shared_ptr<MKLMemHolder> Mkl_mem_;
#endif
/*! \brief internal data of NDArray */
std::shared_ptr<Chunk> ptr_;
/*! \brief shape of current NDArray */
TShape shape_;
/*! \brief byte offset in chunk */
size_t byte_offset_ = 0;
/*! \brief type of data */
int dtype_ = -1;
/*! \brief node entry for autograd */
autograd::AGNodeEntry entry_;
/*!
* \brief internal TBlob
* \note When user access tblob_ by some const methods like
* NDArray::data(), the dptr in tblob_ still need to be updated
* in case that allocation happens. So we make it mutable for
* this situation.
*/
mutable TBlob tblob_;
};
/*!
 * \brief issue a copy operation from one NDArray to another
 *  the two ndarrays can sit on different devices
 *  this operation will be scheduled by the engine
 *
 * \param from the ndarray we want to copy data from
 * \param to the target ndarray
 * \param priority Priority of the action.
 * \note The function name explicitly marks the order of from and to
 *  due to different possible convention carried by copy function.
 */
void CopyFromTo(const NDArray &from, NDArray *to, int priority = 0);
/*!
 * \brief Perform elementwise sum over each data from source, store result into out.
 * \param source the ndarrays we want to sum
 * \param out the target ndarray
 * \param priority Priority of the action.
 */
void ElementwiseSum(const std::vector<NDArray> &source, NDArray *out, int priority = 0);
/*!
 * \brief elementwise add
 * \param lhs left operand
 * \param rhs right operand
 * \return a new result ndarray
 */
NDArray operator+(const NDArray &lhs, const NDArray &rhs);
/*!
 * \brief elementwise add
 * \param lhs left operand
 * \param rhs right operand
 * \return a new result ndarray
 */
NDArray operator+(const NDArray &lhs, const real_t &rhs);
/*!
 * \brief elementwise subtraction
 * \param lhs left operand
 * \param rhs right operand
 * \return a new result ndarray
 */
NDArray operator-(const NDArray &lhs, const NDArray &rhs);
/*!
 * \brief elementwise subtraction
 * \param lhs left operand
 * \param rhs right operand
 * \return a new result ndarray
 */
NDArray operator-(const NDArray &lhs, const real_t &rhs);
/*!
 * \brief elementwise multiplication
 * \param lhs left operand
 * \param rhs right operand
 * \return a new result ndarray
 */
NDArray operator*(const NDArray &lhs, const NDArray &rhs);
/*!
 * \brief elementwise multiplication
 * \param lhs left operand
 * \param rhs right operand
 * \return a new result ndarray
 */
NDArray operator*(const NDArray &lhs, const real_t &rhs);
/*!
 * \brief elementwise division
 * \param lhs left operand
 * \param rhs right operand
 * \return a new result ndarray
 */
NDArray operator/(const NDArray &lhs, const NDArray &rhs);
/*!
 * \brief elementwise division
 * \param lhs left operand
 * \param rhs right operand
 * \return a new result ndarray
 */
NDArray operator/(const NDArray &lhs, const real_t &rhs);
/*!
 * \brief Seed the random number generator.
 * \param seed the seed to set to global random number generators.
 */
void RandomSeed(uint32_t seed);
/*!
 * \brief Sample uniform distribution for each element of out.
 * \param begin lower bound of distribution.
 * \param end upper bound of distribution.
 * \param out output NDArray.
 */
void SampleUniform(real_t begin, real_t end, NDArray *out);
/*!
 * \brief Sample gaussian distribution for each element of out.
 * \param mu mean of gaussian distribution.
 * \param sigma standard deviation of gaussian distribution.
 * \param out output NDArray.
 */
void SampleGaussian(real_t mu, real_t sigma, NDArray *out);
/*!
 * \brief Sample gamma distribution for each element of out.
 * \param alpha parameter (shape) of the gamma distribution
 * \param beta parameter (scale) of the gamma distribution
 * \param out output NDArray.
 */
void SampleGamma(real_t alpha, real_t beta, NDArray *out);
/*!
 * \brief Sample exponential distribution for each element of out.
 * \param lambda parameter (rate) of the exponential distribution
 * \param out output NDArray.
 */
void SampleExponential(real_t lambda, NDArray *out);
/*!
 * \brief Sample Poisson distribution for each element of out.
 * \param lambda parameter (rate) of the Poisson distribution
 * \param out output NDArray.
 */
void SamplePoisson(real_t lambda, NDArray *out);
/*!
 * \brief Sample negative binomial distribution for each element of out.
 * \param k failure limit
 * \param p success probability
 * \param out output NDArray.
 */
void SampleNegBinomial(int32_t k, real_t p, NDArray *out);
/*!
 * \brief Sample generalized negative binomial distribution for each element of out.
 * \param mu parameter (mean) of the distribution
 * \param alpha parameter (over dispersion) of the distribution
 * \param out output NDArray.
 */
void SampleGenNegBinomial(real_t mu, real_t alpha, NDArray *out);
//--------------------------------------------------------------
// The following part is the API registration of NDArray functions.
//--------------------------------------------------------------
/*!
 * \brief definition of NDArray function
 *
 *  Uniform calling convention that every registered function is wrapped into.
 * \param used_vars   input NDArrays read by the function
 * \param scalars     scalar arguments
 * \param mutate_vars output NDArrays written by the function
 * \param num_params  number of string key/value parameters
 * \param param_keys  parameter names (num_params entries)
 * \param param_vals  parameter values, parallel to param_keys
 */
using NDArrayAPIFunction = std::function<void (NDArray **used_vars,
                                                real_t *scalars,
                                                NDArray **mutate_vars,
                                                int num_params,
                                                char **param_keys,
                                                char **param_vals)>;
/*! \brief mask information on how functions can be exposed (bit flags) */
enum NDArrayFunctionTypeMask {
  /*! \brief all the use_vars should go before scalar */
  kNDArrayArgBeforeScalar = 1,
  /*! \brief all the scalar should go before use_vars */
  kScalarArgBeforeNDArray = 1 << 1,
  /*!
   * \brief whether this function allows the handles in the target to
   *  be empty NDArray that are not yet initialized, and will initialize
   *  them when the function is invoked.
   *
   *  most functions should support this, except copy between different
   *  devices, which requires the NDArray to be pre-initialized with context
   */
  kAcceptEmptyMutateTarget = 1 << 2
};
/*!
 * \brief Registry entry for NDArrayFunction
 *
 *  Each set_function overload wraps a strongly-typed callback into the
 *  generic NDArrayAPIFunction `body` and records how many inputs (use vars),
 *  outputs (mutate vars) and scalars the callback consumes.
 */
struct NDArrayFunctionReg
    : public dmlc::FunctionRegEntryBase<NDArrayFunctionReg,
                                        NDArrayAPIFunction> {
  /*! \brief number of variables used by this function */
  unsigned num_use_vars;
  /*! \brief number of variables mutated by this function */
  unsigned num_mutate_vars;
  /*! \brief number of scalars used by this function */
  unsigned num_scalars;
  /*! \brief information on how function should be called from API */
  int type_mask;
  /*!
   * \brief constructor; all counts and the type mask start at zero
   */
  NDArrayFunctionReg()
      : num_use_vars(0),
        num_mutate_vars(0),
        num_scalars(0),
        type_mask(0) {}
  /*!
   * \brief set the function body to a NDArray setvalue function
   *  this will also auto set the parameters correctly
   * \param fsetvalue function body to set; called with scalars[0] and mutate_vars[0]
   * \return ref to the registered entry, used to set properties
   * \note type_mask is left unchanged by this overload.
   */
  inline NDArrayFunctionReg &set_function(void (*fsetvalue)(const real_t &rhs,
                                                            NDArray *out)) {
    body = [fsetvalue] (NDArray **used_vars, real_t *s, NDArray **mutate_vars,
                        int num_params, char **param_keys, char **param_vals) {
      (*fsetvalue)(s[0], mutate_vars[0]);
    };
    num_mutate_vars = 1; num_scalars = 1;
    this->add_argument("src", "real_t", "Source input to the function.");
    return *this;
  }
  /*!
   * \brief set the function body to a ternary NDArray function
   *  this will also auto set the parameters correctly
   * \param fternary function body to set; called with used_vars[0..2] and mutate_vars[0]
   * \return ref to the registered entry, used to set properties
   */
  inline NDArrayFunctionReg &set_function(void(*fternary)(const NDArray &lhs,
                                                          const NDArray &mhs,
                                                          const NDArray &rhs,
                                                          NDArray *out)) {
    body = [fternary](NDArray **used_vars,
                      real_t *s, NDArray **mutate_vars,
                      int num_params, char **param_keys, char **param_vals) {
      (*fternary)(*used_vars[0], *used_vars[1], *used_vars[2], mutate_vars[0]);
    };
    num_use_vars = 3; num_mutate_vars = 1;
    type_mask = kNDArrayArgBeforeScalar | kAcceptEmptyMutateTarget;
    this->add_argument("lhs", "NDArray", "Left operand to the function.");
    this->add_argument("mhs", "NDArray", "Middle operand to the function.");
    this->add_argument("rhs", "NDArray", "Right operand to the function.");
    return *this;
  }
  /*!
   * \brief set the function body to a binary NDArray function
   *  this will also auto set the parameters correctly
   * \param fbinary function body to set; called with used_vars[0..1] and mutate_vars[0]
   * \return ref to the registered entry, used to set properties
   */
  inline NDArrayFunctionReg &set_function(void (*fbinary)(const NDArray &lhs,
                                                          const NDArray &rhs,
                                                          NDArray *out)) {
    body = [fbinary] (NDArray **used_vars, real_t *s, NDArray **mutate_vars,
                      int num_params, char **param_keys, char **param_vals) {
      (*fbinary)(*used_vars[0], *used_vars[1], mutate_vars[0]);
    };
    num_use_vars = 2; num_mutate_vars = 1;
    type_mask = kNDArrayArgBeforeScalar | kAcceptEmptyMutateTarget;
    this->add_argument("lhs", "NDArray", "Left operand to the function.");
    this->add_argument("rhs", "NDArray", "Right operand to the function.");
    return *this;
  }
  /*!
   * \brief set the function body to an NDArray-scalar function
   *  this will also auto set the parameters correctly
   * \param fscalar function body to set; called with used_vars[0], scalars[0]
   *  and mutate_vars[0]
   * \return ref to the registered entry, used to set properties
   */
  inline NDArrayFunctionReg &set_function(void (*fscalar)(const NDArray &lhs,
                                                          const real_t &rhs,
                                                          NDArray *out)) {
    body = [fscalar] (NDArray **used_vars, real_t *s, NDArray **mutate_vars,
                      int num_params, char **param_keys, char **param_vals) {
      (*fscalar)(*used_vars[0], s[0], mutate_vars[0]);
    };
    num_use_vars = 1; num_mutate_vars = 1; num_scalars = 1;
    type_mask = kNDArrayArgBeforeScalar | kAcceptEmptyMutateTarget;
    this->add_argument("lhs", "NDArray", "Left operand to the function.");
    this->add_argument("rhs", "real_t", "Right operand to the function.");
    return *this;
  }
  /*!
   * \brief set the function body to a unary NDArray function
   *  this will also auto set the parameters correctly
   * \param funary function body to set; called with used_vars[0] and mutate_vars[0]
   * \return ref to the registered entry, used to set properties
   */
  inline NDArrayFunctionReg &set_function(void (*funary)(const NDArray &src,
                                                         NDArray *out)) {
    body = [funary] (NDArray **used_vars, real_t *s, NDArray **mutate_vars,
                     int num_params, char **param_keys, char **param_vals) {
      (*funary)(*used_vars[0], mutate_vars[0]);
    };
    num_use_vars = 1; num_mutate_vars = 1;
    type_mask = kNDArrayArgBeforeScalar | kAcceptEmptyMutateTarget;
    this->add_argument("src", "NDArray", "Source input to the function.");
    return *this;
  }
  /*!
   * \brief set the function body to a generic NDArray function
   *  The raw used/mutate arrays and scalars are forwarded as-is, and the
   *  key/value parameters are collected into a std::map (a repeated key keeps
   *  its last value).
   * \note Unlike the other overloads this does NOT set num_use_vars,
   *  num_mutate_vars, num_scalars or type_mask; use the setters for that.
   * \param fgeneric function body to set
   * \return ref to the registered entry, used to set properties
   */
  inline NDArrayFunctionReg &set_function(
      void (*fgeneric)(NDArray **used_vars,
                       real_t *s,
                       NDArray **mutate_vars,
                       const std::map<std::string, std::string>& param)) {
    body = [fgeneric] (NDArray **used_vars, real_t *s, NDArray **mutate_vars,
                       int num_params, char **param_keys, char **param_vals) {
      std::map<std::string, std::string> param;
      for (int i = 0; i < num_params; ++i) {
        param[param_keys[i]] = param_vals[i];
      }
      fgeneric(used_vars, s, mutate_vars, param);
    };
    return *this;
  }
  /*!
   * \brief set the number of used (input) variables
   * \param n number of used variables
   * \return ref to the registered entry, used to set properties
   */
  inline NDArrayFunctionReg &set_num_use_vars(unsigned n) {
    num_use_vars = n; return *this;
  }
  /*!
   * \brief set the number of mutate variables
   * \param n number of mutate variables
   * \return ref to the registered entry, used to set properties
   */
  inline NDArrayFunctionReg &set_num_mutate_vars(unsigned n) {
    num_mutate_vars = n; return *this;
  }
  /*!
   * \brief set the number of scalar arguments
   * \param n number of scalar arguments
   * \return ref to the registered entry, used to set properties
   */
  inline NDArrayFunctionReg &set_num_scalars(unsigned n) {
    num_scalars = n; return *this;
  }
  /*!
   * \brief set type mask
   * \param tmask bitwise OR of NDArrayFunctionTypeMask values
   * \return ref to the registered entry, used to set properties
   */
  inline NDArrayFunctionReg &set_type_mask(int tmask) {
    type_mask = tmask; return *this;
  }
};  // NDArrayFunctionReg
/*!
 * \brief Macro to register NDArray function
 *
 * Example: the following code is an example to register a plus
 * \code
 *
 * MXNET_REGISTER_NDARRAY_FUN(Plus)
 * .set_function(Plus);
 *
 * \endcode
 */
#define MXNET_REGISTER_NDARRAY_FUN(name)                                 \
  DMLC_REGISTRY_REGISTER(::mxnet::NDArrayFunctionReg, NDArrayFunctionReg, name)
} // namespace mxnet
namespace dmlc {
/*!
 * \brief traits: mark mxnet::NDArray as having has_saveload == true, i.e. it
 *  provides Save(dmlc::Stream*) and Load(dmlc::Stream*) members (presumably
 *  consumed by dmlc's generic serialization helpers -- confirm in dmlc).
 */
DMLC_DECLARE_TRAITS(has_saveload, mxnet::NDArray, true);
} // namespace dmlc
#endif // MXNET_NDARRAY_H_