| /*! |
| * Copyright (c) 2015 by Contributors |
| * \file ndarray.h |
| * \brief NDArray interface that handles array arithmetic. |
| */ |
| #ifndef MXNET_NDARRAY_H_ |
| #define MXNET_NDARRAY_H_ |
| |
| #include <dmlc/base.h> |
| #include <dmlc/logging.h> |
| #include <dmlc/io.h> |
| #include <dmlc/type_traits.h> |
| #include <dmlc/registry.h> |
| #include <nnvm/node.h> |
| #include <vector> |
| #include <map> |
| #include <string> |
| #include <memory> |
| #include "./base.h" |
| #include "./storage.h" |
| #include "./engine.h" |
| #if MKL_EXPERIMENTAL == 1 |
| #include <mkl_memory.h> |
| #endif |
| // check c++11 |
| #if DMLC_USE_CXX11 == 0 |
| #error "cxx11 was required for ndarray module" |
| #endif |
| |
| namespace mxnet { |
| |
| // forward declarations and light-weight types of the autograd module |
| namespace autograd { |
| class AGNode; |
| |
| /*! \brief shared ownership handle to an AGNode */ |
| using AGNodePtr = std::shared_ptr<AGNode>; |
| |
| /*! |
|  * \brief an AGNode pointer together with an index and a version number |
|  * \note no constructors or default member initializers are declared here, |
|  *       so the type can be brace-initialized member-wise, e.g. {nullptr, 0, 0}. |
|  */ |
| class AGNodeEntry { |
|  public: |
|   /*! \brief the referenced node, may be null */ |
|   AGNodePtr ag_node; |
|   /*! \brief index stored with the node (presumably the output index, confirm in autograd sources) */ |
|   uint32_t index; |
|   /*! \brief version number stored with the node */ |
|   uint32_t version; |
| |
|   /*! \brief drop the node reference and zero both counters */ |
|   void clear() { |
|     ag_node.reset(); |
|     index = 0; |
|     version = 0; |
|   } |
| |
|   /*! \return an nnvm::NodeEntry for this entry (defined out of line) */ |
|   nnvm::NodeEntry nn_entry() const; |
|   /*! \return whether this entry is none (defined out of line) */ |
|   bool is_none() const; |
| }; |
| |
| class AutogradRuntime; |
| }  // namespace autograd |
| |
| /*! |
| * \brief ndarray interface |
| */ |
| class NDArray { |
| public: |
| /*! \brief default constructor */ |
| NDArray() { |
| #if MKL_EXPERIMENTAL == 1 |
| Mkl_mem_ = MKLMemHolder::create(); |
| #endif |
| } |
| /*! |
|  * \brief create an NDArray backed by a newly constructed Chunk |
|  * \param shape shape of the array; shape.Size() elements are requested |
|  * \param ctx context handed to the chunk |
|  * \param delay_alloc when true the chunk does not allocate in its constructor |
|  * \param dtype type flag of the elements |
|  */ |
| NDArray(const TShape &shape, Context ctx, |
|         bool delay_alloc = false, int dtype = mshadow::default_type_flag) |
|     : ptr_(std::make_shared<Chunk>(shape.Size(), ctx, delay_alloc, dtype)), |
|       shape_(shape), |
|       dtype_(dtype), |
|       entry_({nullptr, 0, 0}) { |
| #if MKL_EXPERIMENTAL == 1 |
|   Mkl_mem_ = std::make_shared<MKLMemHolder>(); |
| #endif |
| } |
| /*! |
|  * \brief create a static NDArray that wraps the memory of an existing TBlob |
|  *  Use with caution: allocate ONLY ONE NDArray for each TBlob and keep the |
|  *  memory region alive for the whole life of the NDArray. |
|  * \param data the blob whose pointer, shape and type flag are taken over |
|  * \param dev_id the device id this tensor sits at |
|  */ |
| NDArray(const TBlob &data, int dev_id) |
|     : ptr_(std::make_shared<Chunk>(data, dev_id)), |
|       shape_(data.shape_), |
|       dtype_(data.type_flag_), |
|       entry_({nullptr, 0, 0}) { |
| #if MKL_EXPERIMENTAL == 1 |
|   Mkl_mem_ = std::make_shared<MKLMemHolder>(); |
| #endif |
| } |
| /*! \return const reference to the shape stored in this NDArray */ |
| const TShape& shape() const { return shape_; } |
| /*! |
|  * \brief access the content as a TBlob |
|  *  Runs the delayed allocation if it is still pending and refreshes the |
|  *  cached blob before handing it out. |
|  * \return reference to the internal cached TBlob |
|  */ |
| const TBlob& data() const { |
|   this->CheckAndAlloc(); |
|   this->SetTBlob(); |
|   return tblob_; |
| } |
| /*! |
|  * \brief context of the NDArray, read from the storage handle of the chunk |
|  *  Only valid when the NDArray is not empty; calling it on an empty NDArray |
|  *  now fails the CHECK instead of dereferencing a null chunk pointer. |
|  * \return the context of NDArray |
|  */ |
| inline Context ctx() const { |
|   CHECK(!is_none()) << "NDArray::ctx() called on an empty NDArray"; |
|   return ptr_->shandle.ctx; |
| } |
| /*! |
|  * \return the stored type flag of the NDArray; only meaningful when the |
|  *  NDArray is not empty (a default constructed array keeps the initial -1) |
|  */ |
| int dtype() const { return dtype_; } |
| /*! \return true when no chunk is attached, i.e. the ndarray is not initialized */ |
| bool is_none() const { return !ptr_; } |
| /*! \return the updated grad state; per the original note it is kept in entry_ (defined out of line) */ |
| bool fresh_out_grad() const; |
| /*! |
| * \brief set the updated grad state (defined out of line, presumably stored through entry_) |
| * \param state the new state |
| */ |
| void set_fresh_out_grad(bool state) const; |
| /*! |
|  * \brief Block until the engine reports the variable of this NDArray as ready, |
|  *  i.e. all pending writes are finished and reading can be performed. |
|  *  Does nothing for an empty NDArray. |
|  */ |
| void WaitToRead() const { |
|   if (!is_none()) { |
|     Engine::Get()->WaitForVar(ptr_->var); |
|   } |
| } |
| /*! |
|  * \brief Block until all the pending read/write operations with respect |
|  *  to current NDArray are finished, and write can be performed. |
|  *  Does nothing for an empty NDArray. |
|  */ |
| void WaitToWrite() const { |
|   if (!is_none()) { |
|     Engine::VarHandle target = ptr_->var; |
|     // An empty function that declares the variable as mutated is pushed first, |
|     // so that all preceding reads of the variable are flushed before we wait. |
|     Engine::Get()->PushSync([](RunContext) {}, Context{}, {}, {target}); |
|     Engine::Get()->WaitForVar(target); |
|   } |
| } |
| /*! \return the engine variable held by the chunk of this ndarray (requires a non-empty ndarray) */ |
| Engine::VarHandle var() const { return ptr_->var; } |
| /*! |
| * \brief save the content into a binary stream |
| * \param strm the output stream |
| */ |
| void Save(dmlc::Stream *strm) const; |
| /*! |
| * \brief load the content from a binary stream |
| * \param strm the input stream |
| * \return whether the load is successful |
| */ |
| bool Load(dmlc::Stream *strm); |
| /*! |
| * \brief set all the elements in ndarray to be scalar |
| * \param scalar the scalar to set |
| * \return reference of self |
| */ |
| NDArray &operator=(real_t scalar); |
| /*! |
| * \brief elementwise add to current space |
| * this mutates the current NDArray |
| * \param src the NDArray to add |
| * \return reference of self |
| */ |
| NDArray &operator+=(const NDArray &src); |
| /*! |
| * \brief elementwise add to current space |
| * this mutates the current NDArray |
| * \param src the scalar to add |
| * \return reference of self |
| */ |
| NDArray &operator+=(const real_t &src); |
| /*! |
| * \brief elementwise subtract from current ndarray |
| * this mutates the current NDArray |
| * \param src the NDArray to subtract |
| * \return reference of self |
| */ |
| NDArray &operator-=(const NDArray &src); |
| /*! |
| * \brief elementwise subtract from current ndarray |
| * this mutates the current NDArray |
| * \param src the scalar to subtract |
| * \return reference of self |
| */ |
| NDArray &operator-=(const real_t &src); |
| /*! |
| * \brief elementwise multiplication to current ndarray |
| * this mutates the current NDArray |
| * \param src the NDArray to multiply with |
| * \return reference of self |
| */ |
| NDArray &operator*=(const NDArray &src); |
| /*! |
| * \brief elementwise multiplication to current ndarray |
| * this mutates the current NDArray |
| * \param src the scalar to multiply with |
| * \return reference of self |
| */ |
| NDArray &operator*=(const real_t &src); |
| /*! |
| * \brief elementwise division from current ndarray |
| * this mutates the current NDArray |
| * \param src the NDArray to divide by |
| * \return reference of self |
| */ |
| NDArray &operator/=(const NDArray &src); |
| /*! |
| * \brief elementwise division from current ndarray |
| * this mutates the current NDArray |
| * \param src the scalar to divide by |
| * \return reference of self |
| */ |
| NDArray &operator/=(const real_t &src); |
| /*! |
| * \brief return transpose of current NDArray |
| * \return a new transposed NDArray |
| */ |
| NDArray T() const; |
| /*! |
| * \brief return a new copy of this NDArray |
| * \param ctx the context of the new copy |
| * \return the new copy |
| */ |
| NDArray Copy(Context ctx) const; |
| /*! |
| * \brief Do a synchronous copy from a contiguous CPU memory region. |
| * |
| * This function will call WaitToWrite before the copy is performed. |
| * This is useful to copy data from an existing memory region that is |
| * not wrapped by NDArray (thus dependency not being tracked). |
| * |
| * \param data the data source to copy from. |
| * \param size the size of the source array, in sizeof(DType) not raw bytes. |
| */ |
| void SyncCopyFromCPU(const void *data, size_t size) const; |
| /*! |
| * \brief Do a synchronous copy to a contiguous CPU memory region. |
| * |
| * This function will call WaitToRead before the copy is performed. |
| * This is useful to copy data into an existing memory region that is |
| * not wrapped by NDArray (thus dependency not being tracked). |
| * |
| * \param data the destination to copy into. |
| * \param size the memory size we want to copy into, in sizeof(DType) not raw bytes. |
| */ |
| void SyncCopyToCPU(void *data, size_t size) const; |
| /*! |
| * \brief Slice an NDArray |
| * \param begin begin index in first dim |
| * \param end end index in first dim |
| * \return sliced NDArray |
| */ |
| NDArray Slice(index_t begin, index_t end) const; |
| /*! |
| * \brief Index an NDArray |
| * \param idx the index |
| * \return the idx-th sub array as NDArray |
| */ |
| NDArray At(index_t idx) const; |
| /*! |
|  * \brief Create an NDArray that shares the chunk with the current one but |
|  *  carries another shape and type. |
|  *  The requested shape and type must not need more bytes than the current |
|  *  shape and type occupy, otherwise the check fails. |
|  * \param shape new shape |
|  * \param dtype The data type. |
|  * \return NDArray in new shape and type. |
|  */ |
| NDArray AsArray(const TShape &shape, int dtype) const { |
|   CHECK_GE(shape_.Size() * mshadow::mshadow_sizeof(dtype_), |
|            shape.Size() * mshadow::mshadow_sizeof(dtype)) |
|       << "NDArray.AsArray: target memory size is bigger"; |
| #if MKL_EXPERIMENTAL == 1 |
|   // convert prv to cpu before the memory gets reinterpreted |
|   if (Mkl_mem_) { |
|     Mkl_mem_->check_and_prv_to_cpu(ptr_->shandle.dptr); |
|   } |
| #endif |
|   NDArray view(*this); |
|   view.shape_ = shape; |
|   view.dtype_ = dtype; |
|   return view; |
| } |
| /*! |
| * \brief Get a reshaped NDArray |
| * \param shape new shape |
| * \return NDArray in new shape |
| */ |
| NDArray Reshape(const TShape &shape) const; |
| /*! |
|  * \brief Return a copy of this NDArray without autograd history |
|  *  The copy keeps sharing the chunk; only its autograd entry is reset |
|  *  (null node, index 0, version 0). |
|  */ |
| NDArray Detach() const { |
|   NDArray detached = *this; |
|   detached.entry_.clear(); |
|   return detached; |
| } |
| /*! |
|  * \brief Forward to the chunk so that a delayed allocation is carried out. |
|  *  This is an internal function used by system that normal user should not use. |
|  *  Requires a non-empty NDArray, the chunk pointer is dereferenced unchecked. |
|  */ |
| void CheckAndAlloc() const { ptr_->CheckAndAlloc(); } |
| /*! |
| * \brief Save a list of ndarrays into the stream. |
| * \param fo The stream of output. |
| * \param data the NDArrays to be saved. |
| * \param names the names of the NDArrays, optional, can be zero length. |
| */ |
| static void Save(dmlc::Stream* fo, |
| const std::vector<NDArray>& data, |
| const std::vector<std::string>& names); |
| /*! |
| * \brief Load a list of ndarrays from the stream. |
| * \param fi The stream of the input file. |
| * \param data the NDArrays to be loaded |
| * \param keys the names of the NDArrays, if saved in the file. |
| */ |
| static void Load(dmlc::Stream* fi, |
| std::vector<NDArray>* data, |
| std::vector<std::string>* keys); |
| |
| private: |
| friend class autograd::AutogradRuntime; |
| /*! \brief the real data chunk that backs NDArray */ |
| struct Chunk { |
|   /*! \brief storage handle from storage engine */ |
|   Storage::Handle shandle; |
|   /*! \brief variable from engine */ |
|   Engine::VarHandle var; |
|   /*! |
|    * \brief if this is true, this means the data do not come |
|    *   from Storage, and do not need to be freed |
|    */ |
|   bool static_data; |
|   /*! \brief whether the allocation is still pending */ |
|   bool delay_alloc; |
| |
|   /*! \brief default constructor: static chunk without pending allocation */ |
|   Chunk() |
|       : var(Engine::Get()->NewVariable()), |
|         static_data(true), |
|         delay_alloc(false) {} |
| |
|   /*! |
|    * \brief construct from static data |
|    * \param data blob whose pointer and byte size are recorded in the handle |
|    * \param dev_id device id used when the blob is on gpu |
|    */ |
|   Chunk(const TBlob &data, int dev_id) |
|       : var(Engine::Get()->NewVariable()), |
|         static_data(true), |
|         delay_alloc(false) { |
|     const bool on_cpu = (data.dev_mask() == cpu::kDevMask); |
|     if (!on_cpu) { |
|       // anything that is not cpu has to be gpu |
|       CHECK_EQ(data.dev_mask(), gpu::kDevMask); |
|       shandle.ctx = Context::GPU(dev_id); |
|     } else { |
|       shandle.ctx = Context::CPU(); |
|     } |
|     shandle.dptr = data.dptr_; |
|     shandle.size = data.shape_.Size() * mshadow::mshadow_sizeof(data.type_flag_); |
|   } |
| |
|   /*! |
|    * \brief construct a new chunk |
|    * \param size number of elements |
|    * \param ctx context to allocate in |
|    * \param delay_alloc_ when false the storage is allocated right away |
|    * \param dtype type flag, determines the element size |
|    */ |
|   Chunk(uint64_t size, Context ctx, bool delay_alloc_, int dtype) |
|       : var(Engine::Get()->NewVariable()), |
|         static_data(false), |
|         delay_alloc(true) { |
|     shandle.size = size * mshadow::mshadow_sizeof(dtype); |
|     shandle.ctx = ctx; |
|     // The flag starts out true so that CheckAndAlloc() below really |
|     // allocates; it resets the flag once the storage exists. |
|     if (!delay_alloc_) { |
|       this->CheckAndAlloc(); |
|     } |
|   } |
| |
|   /*! \brief allocate the storage if the allocation is still pending */ |
|   void CheckAndAlloc(void) { |
|     if (!delay_alloc) return; |
|     shandle = Storage::Get()->Alloc(shandle.size, shandle.ctx); |
|     delay_alloc = false; |
|   } |
| |
|   /*! |
|    * \brief destructor: hands the variable to the engine for deletion and, |
|    *   when the chunk owns allocated storage, frees it in the delete callback |
|    */ |
|   ~Chunk() { |
|     const bool owns_storage = !static_data && !delay_alloc; |
|     if (owns_storage) { |
|       Storage::Handle h = this->shandle; |
|       Engine::Get()->DeleteVariable( |
|           [h](RunContext) { Storage::Get()->Free(h); }, shandle.ctx, var); |
|     } else { |
|       // nothing was allocated through Storage, only the variable goes away |
|       Engine::Get()->DeleteVariable([](RunContext) {}, shandle.ctx, var); |
|     } |
|   } |
| }; |
| |
| /*! |
|  * \brief refresh the cached TBlob from the chunk handle, the byte offset, |
|  *  the shape and the type flag of this NDArray |
|  */ |
| void SetTBlob() const { |
|   Storage::Handle &handle = ptr_->shandle; |
|   tblob_.dptr_ = static_cast<char*>(handle.dptr) + byte_offset_; |
|   tblob_.shape_ = shape_; |
|   tblob_.type_flag_ = dtype_; |
|   tblob_.SetDLTensor(handle.ctx.dev_mask(), handle.ctx.dev_id); |
| #if MKL_EXPERIMENTAL == 1 |
|   tblob_.Mkl_mem_ = Mkl_mem_; |
| #endif |
| } |
| |
| #if MKL_EXPERIMENTAL == 1 |
| std::shared_ptr<MKLMemHolder> Mkl_mem_; |
| #endif |
| /*! \brief internal data of NDArray; null for an empty NDArray, see is_none() */ |
| std::shared_ptr<Chunk> ptr_; |
| /*! \brief shape of current NDArray */ |
| TShape shape_; |
| /*! \brief byte offset in chunk, added to the chunk pointer in SetTBlob() */ |
| size_t byte_offset_ = 0; |
| /*! \brief type flag of the data, -1 until a constructor sets it */ |
| int dtype_ = -1; |
| /*! \brief node entry for autograd */ |
| autograd::AGNodeEntry entry_; |
| /*! |
| * \brief internal TBlob |
| * \note When the user accesses tblob_ through const methods like |
| * NDArray::data(), the dptr in tblob_ still needs to be updated |
| * in case an allocation happens. So we make it mutable for |
| * this situation. |
| */ |
| mutable TBlob tblob_; |
| }; |
| |
| /*! |
| * \brief issue a copy operation from one NDArray to another |
| * the two ndarrays can sit on different devices |
| * this operation will be scheduled by the engine |
| * |
| * \param from the ndarray we want to copy data from |
| * \param to the target ndarray |
| * \param priority Priority of the action. |
| * \note The function name explicitly marks the order of from and to |
| * due to different possible conventions carried by copy functions. |
| */ |
| void CopyFromTo(const NDArray &from, NDArray *to, int priority = 0); |
| |
| |
| /*! |
| * \brief Perform elementwise sum over each array in source, store the result into out. |
| * \param source the ndarrays we want to sum |
| * \param out the target ndarray |
| * \param priority Priority of the action. |
| */ |
| void ElementwiseSum(const std::vector<NDArray> &source, NDArray *out, int priority = 0); |
| |
| /*! |
|  * \brief elementwise addition of two ndarrays |
|  * \param lhs left operand |
|  * \param rhs right operand |
|  * \return a new result ndarray |
|  */ |
| NDArray operator+(const NDArray &lhs, const NDArray &rhs); |
| /*! |
|  * \brief elementwise addition of an ndarray and a scalar |
|  * \param lhs left operand |
|  * \param rhs right operand |
|  * \return a new result ndarray |
|  */ |
| NDArray operator+(const NDArray &lhs, const real_t &rhs); |
| /*! |
|  * \brief elementwise subtraction of two ndarrays |
|  * \param lhs left operand |
|  * \param rhs right operand |
|  * \return a new result ndarray |
|  */ |
| NDArray operator-(const NDArray &lhs, const NDArray &rhs); |
| /*! |
|  * \brief elementwise subtraction of a scalar from an ndarray |
|  * \param lhs left operand |
|  * \param rhs right operand |
|  * \return a new result ndarray |
|  */ |
| NDArray operator-(const NDArray &lhs, const real_t &rhs); |
| /*! |
|  * \brief elementwise multiplication of two ndarrays |
|  * \param lhs left operand |
|  * \param rhs right operand |
|  * \return a new result ndarray |
|  */ |
| NDArray operator*(const NDArray &lhs, const NDArray &rhs); |
| /*! |
|  * \brief elementwise multiplication of an ndarray and a scalar |
|  * \param lhs left operand |
|  * \param rhs right operand |
|  * \return a new result ndarray |
|  */ |
| NDArray operator*(const NDArray &lhs, const real_t &rhs); |
| /*! |
|  * \brief elementwise division of two ndarrays |
|  * \param lhs left operand |
|  * \param rhs right operand |
|  * \return a new result ndarray |
|  */ |
| NDArray operator/(const NDArray &lhs, const NDArray &rhs); |
| /*! |
|  * \brief elementwise division of an ndarray by a scalar |
|  * \param lhs left operand |
|  * \param rhs right operand |
|  * \return a new result ndarray |
|  */ |
| NDArray operator/(const NDArray &lhs, const real_t &rhs); |
| |
| /*! |
| * \brief Seed the random number generator. |
| * \param seed the seed to set to global random number generators. |
| */ |
| void RandomSeed(uint32_t seed); |
| /*! |
| * \brief Sample uniform distribution for each element of out. |
| * \param begin lower bound of distribution. |
| * \param end upper bound of distribution. |
| * \param out output NDArray. |
| */ |
| void SampleUniform(real_t begin, real_t end, NDArray *out); |
| /*! |
| * \brief Sample gaussian distribution for each element of out. |
| * \param mu mean of gaussian distribution. |
| * \param sigma standard deviation of gaussian distribution. |
| * \param out output NDArray. |
| */ |
| void SampleGaussian(real_t mu, real_t sigma, NDArray *out); |
| /*! |
| * \brief Sample gamma distribution for each element of out. |
| * \param alpha parameter (shape) of the gamma distribution |
| * \param beta parameter (scale) of the gamma distribution |
| * \param out output NDArray. |
| */ |
| void SampleGamma(real_t alpha, real_t beta, NDArray *out); |
| /*! |
| * \brief Sample exponential distribution for each element of out. |
| * \param lambda parameter (rate) of the exponential distribution |
| * \param out output NDArray. |
| */ |
| void SampleExponential(real_t lambda, NDArray *out); |
| /*! |
| * \brief Sample Poisson distribution for each element of out. |
| * \param lambda parameter (rate) of the Poisson distribution |
| * \param out output NDArray. |
| */ |
| void SamplePoisson(real_t lambda, NDArray *out); |
| /*! |
| * \brief Sample negative binomial distribution for each element of out. |
| * \param k failure limit |
| * \param p success probability |
| * \param out output NDArray. |
| */ |
| void SampleNegBinomial(int32_t k, real_t p, NDArray *out); |
| /*! |
| * \brief Sample generalized negative binomial distribution for each element of out. |
| * \param mu parameter (mean) of the distribution |
| * \param alpha parameter (over dispersion) of the distribution |
| * \param out output NDArray. |
| */ |
| void SampleGenNegBinomial(real_t mu, real_t alpha, NDArray *out); |
| |
| |
| //-------------------------------------------------------------- |
| // The following part is the API registration of NDArray functions. |
| //-------------------------------------------------------------- |
| |
| /*! |
|  * \brief signature shared by all registered NDArray functions: arrays of |
|  *  used NDArrays, scalars and mutated NDArrays, followed by num_params |
|  *  key/value string pairs |
|  */ |
| using NDArrayAPIFunction = std::function<void (NDArray **used_vars, |
|                                                real_t *scalars, |
|                                                NDArray **mutate_vars, |
|                                                int num_params, |
|                                                char **param_keys, |
|                                                char **param_vals)>; |
| /*! \brief bit mask describing how a function can be exposed; values are combined with bitwise or */ |
| enum NDArrayFunctionTypeMask { |
| /*! \brief all the use_vars should go before the scalars */ |
| kNDArrayArgBeforeScalar = 1, |
| /*! \brief all the scalars should go before the use_vars */ |
| kScalarArgBeforeNDArray = 1 << 1, |
| /*! |
| * \brief whether this function allows the handles in the target to |
| * be empty NDArrays that are not yet initialized, and will initialize |
| * them when the function is invoked. |
| * |
| * most functions should support this, except copy between different |
| * devices, which requires the NDArray to be pre-initialized with context |
| */ |
| kAcceptEmptyMutateTarget = 1 << 2 |
| }; |
| /*! \brief Registry entry that describes one NDArray function */ |
| struct NDArrayFunctionReg |
|     : public dmlc::FunctionRegEntryBase<NDArrayFunctionReg, |
|                                         NDArrayAPIFunction> { |
|   /*! \brief number of NDArrays the function reads */ |
|   unsigned num_use_vars; |
|   /*! \brief number of NDArrays the function mutates */ |
|   unsigned num_mutate_vars; |
|   /*! \brief number of scalars the function takes */ |
|   unsigned num_scalars; |
|   /*! \brief NDArrayFunctionTypeMask bits telling how the API exposes the function */ |
|   int type_mask; |
| |
|   /*! \brief constructor, all counts and the type mask start at zero */ |
|   NDArrayFunctionReg() |
|       : num_use_vars(0), |
|         num_mutate_vars(0), |
|         num_scalars(0), |
|         type_mask(0) {} |
| |
|   /*! |
|    * \brief install a set-value function (one scalar in, one NDArray mutated) |
|    *  Sets the counts and registers the "src" argument; type_mask is left as is. |
|    * \param fsetvalue function body to set |
|    * \return ref to the registered entry, used to set properties |
|    */ |
|   NDArrayFunctionReg &set_function(void (*fsetvalue)(const real_t &rhs, |
|                                                      NDArray *out)) { |
|     body = [fsetvalue](NDArray **, real_t *scalars, NDArray **outputs, |
|                        int, char **, char **) { |
|       fsetvalue(scalars[0], outputs[0]); |
|     }; |
|     num_mutate_vars = 1; |
|     num_scalars = 1; |
|     this->add_argument("src", "real_t", "Source input to the function."); |
|     return *this; |
|   } |
| |
|   /*! |
|    * \brief install a ternary NDArray function (three NDArrays in, one mutated) |
|    *  Sets the counts, the type mask and the argument descriptions. |
|    * \param fternary function body to set |
|    * \return ref to the registered entry, used to set properties |
|    */ |
|   NDArrayFunctionReg &set_function(void(*fternary)(const NDArray &lhs, |
|                                                    const NDArray &mhs, |
|                                                    const NDArray &rhs, |
|                                                    NDArray *out)) { |
|     body = [fternary](NDArray **inputs, real_t *, NDArray **outputs, |
|                       int, char **, char **) { |
|       fternary(*inputs[0], *inputs[1], *inputs[2], outputs[0]); |
|     }; |
|     num_use_vars = 3; |
|     num_mutate_vars = 1; |
|     type_mask = kNDArrayArgBeforeScalar | kAcceptEmptyMutateTarget; |
|     this->add_argument("lhs", "NDArray", "Left operand to the function."); |
|     this->add_argument("mhs", "NDArray", "Middle operand to the function."); |
|     this->add_argument("rhs", "NDArray", "Right operand to the function."); |
|     return *this; |
|   } |
| |
|   /*! |
|    * \brief install a binary NDArray function (two NDArrays in, one mutated) |
|    *  Sets the counts, the type mask and the argument descriptions. |
|    * \param fbinary function body to set |
|    * \return ref to the registered entry, used to set properties |
|    */ |
|   NDArrayFunctionReg &set_function(void (*fbinary)(const NDArray &lhs, |
|                                                    const NDArray &rhs, |
|                                                    NDArray *out)) { |
|     body = [fbinary](NDArray **inputs, real_t *, NDArray **outputs, |
|                      int, char **, char **) { |
|       fbinary(*inputs[0], *inputs[1], outputs[0]); |
|     }; |
|     num_use_vars = 2; |
|     num_mutate_vars = 1; |
|     type_mask = kNDArrayArgBeforeScalar | kAcceptEmptyMutateTarget; |
|     this->add_argument("lhs", "NDArray", "Left operand to the function."); |
|     this->add_argument("rhs", "NDArray", "Right operand to the function."); |
|     return *this; |
|   } |
| |
|   /*! |
|    * \brief install an NDArray-scalar function (one NDArray and one scalar in, |
|    *  one NDArray mutated) |
|    *  Sets the counts, the type mask and the argument descriptions. |
|    * \param fscalar function body to set |
|    * \return ref to the registered entry, used to set properties |
|    */ |
|   NDArrayFunctionReg &set_function(void (*fscalar)(const NDArray &lhs, |
|                                                    const real_t &rhs, |
|                                                    NDArray *out)) { |
|     body = [fscalar](NDArray **inputs, real_t *scalars, NDArray **outputs, |
|                      int, char **, char **) { |
|       fscalar(*inputs[0], scalars[0], outputs[0]); |
|     }; |
|     num_use_vars = 1; |
|     num_mutate_vars = 1; |
|     num_scalars = 1; |
|     type_mask = kNDArrayArgBeforeScalar | kAcceptEmptyMutateTarget; |
|     this->add_argument("lhs", "NDArray", "Left operand to the function."); |
|     this->add_argument("rhs", "real_t", "Right operand to the function."); |
|     return *this; |
|   } |
| |
|   /*! |
|    * \brief install a unary NDArray function (one NDArray in, one mutated) |
|    *  Sets the counts, the type mask and the argument description. |
|    * \param funary function body to set |
|    * \return ref to the registered entry, used to set properties |
|    */ |
|   NDArrayFunctionReg &set_function(void (*funary)(const NDArray &src, |
|                                                   NDArray *out)) { |
|     body = [funary](NDArray **inputs, real_t *, NDArray **outputs, |
|                     int, char **, char **) { |
|       funary(*inputs[0], outputs[0]); |
|     }; |
|     num_use_vars = 1; |
|     num_mutate_vars = 1; |
|     type_mask = kNDArrayArgBeforeScalar | kAcceptEmptyMutateTarget; |
|     this->add_argument("src", "NDArray", "Source input to the function."); |
|     return *this; |
|   } |
| |
|   /*! |
|    * \brief install a generic function that receives the key/value parameters |
|    *  as a string map |
|    *  Neither the counts nor the type mask are touched here; use the |
|    *  set_num_* and set_type_mask setters for them. |
|    * \param fgeneric function body to set |
|    * \return ref to the registered entry, used to set properties |
|    */ |
|   NDArrayFunctionReg &set_function( |
|       void (*fgeneric)(NDArray **used_vars, |
|                        real_t *s, |
|                        NDArray **mutate_vars, |
|                        const std::map<std::string, std::string>& param)) { |
|     body = [fgeneric](NDArray **inputs, real_t *scalars, NDArray **outputs, |
|                       int num_params, char **param_keys, char **param_vals) { |
|       // collect the raw key/value arrays; a repeated key keeps its last value |
|       std::map<std::string, std::string> kwargs; |
|       for (int k = 0; k != num_params; ++k) { |
|         kwargs[param_keys[k]] = param_vals[k]; |
|       } |
|       fgeneric(inputs, scalars, outputs, kwargs); |
|     }; |
|     return *this; |
|   } |
| |
|   /*! |
|    * \brief set the number of used (input) variables |
|    * \param n number of used variables |
|    * \return ref to the registered entry, used to set properties |
|    */ |
|   NDArrayFunctionReg &set_num_use_vars(unsigned n) { |
|     num_use_vars = n; |
|     return *this; |
|   } |
| |
|   /*! |
|    * \brief set the number of mutate variables |
|    * \param n number of mutate variables |
|    * \return ref to the registered entry, used to set properties |
|    */ |
|   NDArrayFunctionReg &set_num_mutate_vars(unsigned n) { |
|     num_mutate_vars = n; |
|     return *this; |
|   } |
| |
|   /*! |
|    * \brief set the number of scalar arguments |
|    * \param n number of scalar arguments |
|    * \return ref to the registered entry, used to set properties |
|    */ |
|   NDArrayFunctionReg &set_num_scalars(unsigned n) { |
|     num_scalars = n; |
|     return *this; |
|   } |
| |
|   /*! |
|    * \brief set type mask |
|    * \param tmask typemask |
|    * \return ref to the registered entry, used to set properties |
|    */ |
|   NDArrayFunctionReg &set_type_mask(int tmask) { |
|     type_mask = tmask; |
|     return *this; |
|   } |
| };  // NDArrayFunctionReg |
| |
| /*! |
| * \brief Macro to register an NDArray function |
| * |
| * Example: the following code registers a function named Plus |
| * \code |
| * |
| * MXNET_REGISTER_NDARRAY_FUN(Plus) |
| * .set_function(Plus); |
| * |
| * \endcode |
| */ |
| #define MXNET_REGISTER_NDARRAY_FUN(name) \ |
| DMLC_REGISTRY_REGISTER(::mxnet::NDArrayFunctionReg, NDArrayFunctionReg, name) |
| |
| } // namespace mxnet |
| |
| namespace dmlc { |
| /*!\brief traits: declares has_saveload as true for mxnet::NDArray */ |
| DMLC_DECLARE_TRAITS(has_saveload, mxnet::NDArray, true); |
| } // namespace dmlc |
| #endif // MXNET_NDARRAY_H_ |