| /*! |
| * Copyright (c) 2015 by Contributors |
| * \file operator_util.h |
| * \brief Utility functions and registries to help quickly build new operators. |
| * [Deprecated] |
| * Use the register functions in this file when possible to simplify operator creations. |
| * Operators registered in this file will be exposed to both NDArray API and symbolic API. |
| * |
| * \author Tianqi Chen |
| */ |
| #ifndef MXNET_OPERATOR_UTIL_H_ |
| #define MXNET_OPERATOR_UTIL_H_ |
| |
| #ifdef _MSC_VER |
| #pragma warning(disable:4503) // disable warning: decorated name length exceeded. |
| #endif |
| |
#include <dmlc/registry.h>
#include <dmlc/parameter.h>
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>
#include "./base.h"
#include "./operator.h"
| |
| #if DMLC_USE_CXX11 |
| #include <functional> |
| #endif |
| |
| namespace mxnet { |
| /*! \brief namespace of arguments */ |
| namespace op { |
| /*! \brief super class of all gradient function argument */ |
| struct GradFunctionArgument { |
| /*! \brief The real data */ |
| TBlob data; |
| }; |
| |
| /*! \brief First input to the function */ |
| struct Input0 : GradFunctionArgument {}; |
| /*! \brief Second input to the function */ |
| struct Input1 : GradFunctionArgument {}; |
| |
| /*! \brief Ouput value of the function to the function */ |
| struct OutputValue : GradFunctionArgument {}; |
| /*! \brief Gradient of output value */ |
| struct OutputGrad : GradFunctionArgument {}; |
| |
| /*! |
| * \brief Environment arguments that is used by the function. |
| * These can be things like scalar arguments when add a value with scalar. |
| */ |
| struct EnvArguments { |
| /*! \brief scalar argument, if enabled */ |
| real_t scalar; |
| /*! \brief keyword arguments */ |
| std::vector<std::pair<std::string, std::string> > kwargs; |
| /*! \brief pointer to the resources requested */ |
| std::vector<Resource> resource; |
| }; |
| |
| /*! |
| * \brief source function that generate output based on env |
| * The result container is pre-allocated with the correct shape. |
| * \param env The Environment arguments. |
| * \param ret The containter to store return value. |
| * \param req The requirement to stroe the ret. |
| * \param ctx Runtime context to execute the function. |
| */ |
| typedef void (*SourceFunction)(const EnvArguments& env, |
| TBlob* ret, |
| OpReqType req, |
| RunContext ctx); |
| |
| /*! |
| * \brief Shape inference function to get the correct shape. |
| * \param env The Environment arguments. |
| * \return The inferred result shape. |
| */ |
| typedef TShape (*SourceShapeFunction)(const EnvArguments& env); |
| |
| /*! |
| * \brief Unary function that takes a src and save result to ret. |
| * The result container is pre-allocated with the correct shape. |
| * \param src The source data. |
| * \param env The Environment arguments. |
| * \param ret The containter to store return value. |
| * \param req The requirement to stroe the ret. |
| * \param ctx Runtime context to execute the function. |
| */ |
| typedef void (*UnaryFunction)(const TBlob& src, |
| const EnvArguments& env, |
| TBlob* ret, |
| OpReqType req, |
| RunContext ctx); |
| /*! |
| * \brief Shape inference function to get the correct shape given source. |
| * \param src The source shape |
| * \param env The Environment arguments. |
| * \return The inferred result shape. |
| */ |
| typedef TShape (*UnaryShapeFunction)(const TShape& src, |
| const EnvArguments& env); |
| |
| /*! |
| * \brief Gradient function that takes output value of function and computes gradient wrt to input. |
| * \param out_grad the gradient wrt to output of the function. |
| * \param env The Environment arguments. |
| * \param in_grad The container to store result input gradient. |
| * \param req The requirement to store the ret value. |
| * \param ctx Runtime context to execute the function. |
| */ |
| typedef void (*UnaryGradFunctionT0)(const OutputGrad& out_grad, |
| const EnvArguments& env, |
| TBlob* in_grad, |
| OpReqType req, |
| RunContext ctx); |
| /*! |
| * \brief Gradient function that takes output value of function and computes gradient wrt to input. |
| * \param out_grad the gradient wrt to output of the function. |
| * \param out_value the value of the function. |
| * \param env The Environment arguments. |
| * \param in_grad The container to store result input gradient. |
| * \param req The requirement to store the ret value. |
| * \param ctx Runtime context to execute the function. |
| */ |
| typedef void (*UnaryGradFunctionT1)(const OutputGrad& out_grad, |
| const OutputValue& out_value, |
| const EnvArguments& env, |
| TBlob* in_grad, |
| OpReqType req, |
| RunContext ctx); |
| /*! |
| * \brief Gradient function that takes input value of function and computes gradient wrt to input. |
| * \param out_grad the gradient wrt to output of the function. |
| * \param in_data0 the input value of the function. |
| * \param env The Environment arguments. |
| * \param in_grad The container to store result input gradient. |
| * \param req The requirement to store the ret value. |
| * \param ctx Runtime context to execute the function. |
| */ |
| typedef void (*UnaryGradFunctionT2)(const OutputGrad& out_grad, |
| const Input0& in_data0, |
| const EnvArguments& env, |
| TBlob* in_grad, |
| OpReqType req, |
| RunContext ctx); |
| /*! |
| * \brief Binary function that takes lhs, rhs and save result to ret. |
| * The result container is pre-allocated with the correct shape. |
| * \param lhs The left operand |
| * \param rhs The right operand |
| * \param env The Environment arguments. |
| * \param ret The containter to store return value. |
| * \param req The requirement to stroe the ret. |
| * \param ctx Runtime context to execute the function. |
| */ |
| typedef void (*BinaryFunction)(const TBlob& lhs, |
| const TBlob& rhs, |
| const EnvArguments& env, |
| TBlob* ret, |
| OpReqType req, |
| RunContext ctx); |
| |
| /*! |
| * \brief Shape inference function to get the correct shape given source shapes. |
| * \param lhs The shape of left operand. |
| * \param rhs The shape of right operand. |
| * \param env The Environment arguments. |
| * \return The inferred result shape. |
| */ |
| typedef TShape (*BinaryShapeFunction)(const TShape& lhs, |
| const TShape& rhs, |
| const EnvArguments& env); |
| /*! |
| * \brief Gradient function that takes only output gradient and computes gradient wrt to input. |
| * We support total gradient as a whole to make it easy to combine a few ops. |
| * \param out_grad the gradient wrt to output of the function. |
| * \param env The Environment arguments. |
| * \param lhs_grad The container to store result of lhs gradient. |
| * \param rhs_grad The container to store result of lhs gradient. |
| * \param req_lhs_grad The requirement to store the lhs_grad |
| * \param req_rhs_grad The requirement to store the rhs_grad |
| * \param ctx Runtime context to execute the function. |
| */ |
| typedef void (*BinaryGradFunctionT0)(const OutputGrad& out_grad, |
| const EnvArguments& env, |
| TBlob* lhs_grad, |
| TBlob* rhs_grad, |
| OpReqType req_lhs_grad, |
| OpReqType req_rhs_grad, |
| RunContext ctx); |
| /*! |
| * \brief Gradient function that takes inputs of function anod computes gradient wrt to input. |
| * \param out_grad the gradient wrt to output of the function. |
| * \param lhs The left operand to the function. |
| * \param rhs The right operand to the function. |
| * \param env The Environment arguments. |
| * \param lhs_grad The container to store result of lhs gradient. |
| * \param rhs_grad The container to store result of lhs gradient. |
| * \param req_lhs_grad The requirement to store the lhs_grad |
| * \param req_rhs_grad The requirement to store the rhs_grad |
| * \param ctx Runtime context to execute the function. |
| */ |
| typedef void (*BinaryGradFunctionT1)(const OutputGrad& out_grad, |
| const Input0& lhs, |
| const Input1& rhs, |
| const EnvArguments& env, |
| TBlob* lhs_grad, |
| TBlob* rhs_grad, |
| OpReqType req_lhs_grad, |
| OpReqType req_rhs_grad, |
| RunContext ctx); |
| |
/*! \brief Options used in the registry to declare which in-place optimization is allowed. */
enum SimpleOpInplaceOption {
  /*! \brief No argument may be computed in place. */
  kNoInplace = 0,
  /*! \brief Unary forward: the input may share memory with the output. */
  kInplaceInOut = 1,
  /*! \brief Unary backward: out_grad may share memory with in_grad. */
  kInplaceOutIn = 2,
  /*! \brief Binary forward: the left operand may share memory with the output. */
  kInplaceLhsOut = 3,
  /*! \brief Binary backward: out_grad may share memory with lhs_grad. */
  kInplaceOutLhs = 4
};
| |
/*! \brief Options used in the registry to declare where the scalar argument is positioned. */
enum SimpleOpScalarOption {
  /*! \brief The scalar comes before the array. */
  kScalarBeforeArray = 0,
  /*! \brief The array comes before the scalar. */
  kArrayBeforeScalar = 1
};
| |
/*! \brief Options used in the registry to control symbolic registration. */
enum SimpleOpRegOption {
  /*! \brief Do not register a symbolic operator. */
  kNotRegisterSymbolic = 0,
  /*! \brief Register a symbolic operator as well. */
  kRegisterSymbolic = 1
};
| |
| /*! \brief registry entry to register simple operators via functions. */ |
| class SimpleOpRegEntry { |
| public: |
| /*! \brief declare self type */ |
| typedef SimpleOpRegEntry TSelf; |
| /*! \brief name of the operator */ |
| std::string name; |
| /*! |
| * \brief set a seperate name for symbol |
| * This must be called before set_function. |
| * Default: this is set to be same as the name of operator. |
| * \param symbol_name the name of symbolic operator. |
| */ |
| virtual TSelf& set_symbol_op_name(char const* symbol_name) = 0; |
| /*! |
| * \brief set number of scalar arguments needed to be passed in env |
| * A function cannot have both kwargs and scalar arguments. |
| * Default: this is set to false |
| * \param enable_scalar whether to enable scalar argument |
| * \param type_mask the position of the scalar argument. |
| */ |
| virtual TSelf& set_enable_scalar( |
| bool enable_scalar, |
| SimpleOpScalarOption type_mask = kArrayBeforeScalar) = 0; |
| /*! |
| * \brief set whether to enable kwargs |
| * A function cannot have both kwargs and scalar arguments. |
| * Default: this is set to false |
| * \param enable_kwargs whether to enable kwargs |
| */ |
| virtual TSelf& set_enable_kwargs(bool enable_kwargs) = 0; |
| /*! |
| * \brief set resource request |
| * By default there is no resource request. |
| * The resource will be presented in both forward and backward. |
| * \param reqs the request. |
| */ |
| virtual TSelf& set_resource_request( |
| const std::vector<ResourceRequest>& reqs) = 0; |
| /*! |
| * \brief set resource request |
| * By default there is no resource request. |
| * The resource will be presented in both forward and backward. |
| * \param req the request. |
| */ |
| virtual TSelf& set_resource_request(ResourceRequest req) = 0; |
| /*! |
| * \brief set source inference function. |
| * \param fshapeinfer The source function that peforms the operation. |
| */ |
| virtual TSelf& set_shape_function(SourceShapeFunction fshapeinfer) = 0; |
| /*! |
| * \brief set shape inference function. |
| * Default: out_shape = in_shape |
| * \param fshapeinfer The unary function that peforms the operation. |
| */ |
| virtual TSelf& set_shape_function(UnaryShapeFunction fshapeinfer) = 0; |
| /*! |
| * \brief set shape inference function to be the binary inference function |
| * Default: out_shape = lhs_shape, and lhs_shape must equal rhs_shape. |
| * \param fshapeinfer The binary function that peforms the operation. |
| */ |
| virtual TSelf& set_shape_function(BinaryShapeFunction fshapeinfer) = 0; |
| /*! |
| * \brief set function of the function to be fsource |
| * \param dev_mask The device mask of the function can act on. |
| * \param fsource The unary function that peforms the operation. |
| * \param register_symbolic Whether register a symbolic operator as well. |
| */ |
| virtual TSelf& set_function( |
| int dev_mask, |
| SourceFunction fsource, |
| SimpleOpRegOption register_symbolic = kRegisterSymbolic) = 0; |
| /*! |
| * \brief set function of the function to be funary |
| * \param dev_mask The device mask of the function can act on. |
| * \param funary The unary function that peforms the operation. |
| * \param inplace_in_out Whether do inplace optimization on in and out. |
| * \param register_symbolic Whether register a symbolic operator as well. |
| */ |
| virtual TSelf& set_function( |
| int dev_mask, |
| UnaryFunction funary, |
| SimpleOpInplaceOption inplace_in_out, |
| SimpleOpRegOption register_symbolic = kRegisterSymbolic) = 0; |
| /*! |
| * \brief set function of the function to be funary |
| * \param dev_mask The device mask of the function can act on. |
| * \param fbinary The binary function that peforms the operation. |
| * \param inplace_lhs_out Whether do inplace optimization on lhs and out. |
| * \param register_symbolic Whether register a symbolic operator as well. |
| */ |
| virtual TSelf& set_function( |
| int dev_mask, |
| BinaryFunction fbinary, |
| SimpleOpInplaceOption inplace_lhs_out, |
| SimpleOpRegOption register_symbolic = kRegisterSymbolic) = 0; |
| /*! |
| * \brief set gradient of the function of this function. |
| * \param dev_mask The device mask of the function can act on. |
| * \param fgrad The gradient function to be set. |
| * \param inplace_out_in_grad whether out_grad and in_grad can share memory. |
| */ |
| virtual TSelf& set_gradient(int dev_mask, |
| UnaryGradFunctionT0 fgrad, |
| SimpleOpInplaceOption inplace_out_in_grad) = 0; |
| /*! |
| * \brief set gradient of the function of this function. |
| * \param dev_mask The device mask of the function can act on. |
| * \param fgrad The gradient function to be set. |
| * \param inplace_out_in_grad whether out_grad and in_grad can share memory. |
| */ |
| virtual TSelf& set_gradient(int dev_mask, |
| UnaryGradFunctionT1 fgrad, |
| SimpleOpInplaceOption inplace_out_in_grad) = 0; |
| /*! |
| * \brief set gradient of the function of this function. |
| * \param dev_mask The device mask of the function can act on. |
| * \param fgrad The gradient function to be set. |
| * \param inplace_out_in_grad whether out_grad and in_grad can share memory. |
| */ |
| virtual TSelf& set_gradient(int dev_mask, |
| UnaryGradFunctionT2 fgrad, |
| SimpleOpInplaceOption inplace_out_in_grad) = 0; |
| /*! |
| * \brief set gradient of the function of this function. |
| * \param dev_mask The device mask of the function can act on. |
| * \param fgrad The gradient function to be set. |
| * \param inplace_out_lhs_grad whether out_grad and lhs_grad can share memory. |
| */ |
| virtual TSelf& set_gradient(int dev_mask, |
| BinaryGradFunctionT0 fgrad, |
| SimpleOpInplaceOption inplace_out_lhs_grad) = 0; |
| /*! |
| * \brief set gradient of the function of this function. |
| * \param dev_mask The device mask of the function can act on. |
| * \param fgrad The gradient function to be set. |
| * \param inplace_out_lhs_grad whether out_grad and lhs_grad can share memory. |
| */ |
| virtual TSelf& set_gradient(int dev_mask, |
| BinaryGradFunctionT1 fgrad, |
| SimpleOpInplaceOption inplace_out_lhs_grad) = 0; |
| /*! |
| * \brief Describe the function. |
| * \param description The description of the function. |
| * \return reference to self. |
| */ |
| virtual TSelf& describe(const std::string &description) = 0; |
| /*! |
| * \brief Describe the function. |
| * \param args argument information. |
| * Add additional arguments to the function. |
| * \return reference to self. |
| */ |
| virtual TSelf& add_arguments(const std::vector<dmlc::ParamFieldInfo> &args) = 0; |
| /*! \brief virtual destructor */ |
| virtual ~SimpleOpRegEntry() {} |
| }; |
| |
| /*! \brief registry for TBlob functions */ |
| class SimpleOpRegistry { |
| public: |
| /*! |
| * \brief Internal function to register a name function under name. |
| * \param name name of the function |
| * \return ref to the registered entry, used to set properties |
| */ |
| SimpleOpRegEntry &__REGISTER_OR_FIND__(char const* name); |
| /*! |
| * \brief Find the entry with corresponding name. |
| * \param name name of the function |
| * \return the corresponding function, can be NULL |
| */ |
| inline static const SimpleOpRegEntry *Find(const std::string &name) { |
| return Get()->fmap_.at(name); |
| } |
| /*! \return global singleton of the registry */ |
| static SimpleOpRegistry* Get(); |
| |
| private: |
| // destructor |
| ~SimpleOpRegistry(); |
| /*! \brief internal registry map */ |
| std::map<std::string, SimpleOpRegEntry*> fmap_; |
| }; |
| |
/*!
 * \brief Assign the expression to out according to the request.
 *  - kNullOp: nothing is done, neither out nor exp is evaluated.
 *  - kWriteTo, kWriteInplace: (out) = (exp)
 *  - kAddTo: (out) += (exp)
 *  - any other value: logs a fatal "not reached" message.
 *
 *  The macro expands to a plain braced block, not to a do/while(0) statement, so a
 *  semicolon written after the invocation is a separate empty statement. Do not use it
 *  as the unbraced body of an `if` that is followed by `else`.
 * \param out the data to be assigned
 * \param req the assignment request
 * \param exp the expression
 */
#define ASSIGN_DISPATCH(out, req, exp)  \
  {                                     \
    switch (req) {                      \
      case kAddTo:                      \
        (out) += (exp);                 \
        break;                          \
      case kWriteTo:                    \
      case kWriteInplace:               \
        (out) = (exp);                  \
        break;                          \
      case kNullOp:                     \
        break;                          \
      default:                          \
        LOG(FATAL) << "not reached";    \
    }                                   \
  }
| |
/*!
 * \brief Maximum number of dimensions supported by special operators, such as
 *  broadcasting with non-contiguous lhs/rhs.
 *  Kept as a preprocessor macro so that it stays usable in #if directives.
 */
#define MXNET_SPECIAL_MAX_NDIM 5
| |
| |
//--------------------------------------------------------------
// The following part is the API registration of simple operators
//--------------------------------------------------------------
/*!
 * \brief Macro to register a simple operator to both the imperative and the symbolic API.
 *
 * The macro expands to the definition of a static reference whose name is built from
 * Name and DEV. The reference is bound to the entry returned by
 * SimpleOpRegistry::Get()->__REGISTER_OR_FIND__("Name"), so the setters of
 * SimpleOpRegEntry can be chained directly after the macro invocation.
 *
 * see src/operator/elementwise_unary_op-inl.h for example
 *
 * \code
 * // example of registering a sigmoid operator on GPU
 * // MySigmoid<gpu> is of type UnaryFunction,
 * // MySigmoidGrad<gpu> is of type UnaryGradFunctionT2,
 * // gpu_dev_mask stands for the int device mask of the GPU device.
 *
 * MXNET_REGISTER_SIMPLE_OP(sigmoid, gpu)
 * .set_function(gpu_dev_mask, MySigmoid<gpu>, kInplaceInOut)
 * .set_gradient(gpu_dev_mask, MySigmoidGrad<gpu>, kInplaceOutIn)
 * .describe("Sigmoid function");
 *
 * \endcode
 *
 * \param Name The name of the operator; it is also stringified as the registry key.
 * \param DEV A device tag that only contributes to the name of the static reference.
 */
#define MXNET_REGISTER_SIMPLE_OP(Name, DEV)                               \
  static ::mxnet::op::SimpleOpRegEntry &                                  \
  __make_SimpleOpRegEntry_ ## Name ## __ ## DEV ## __ =                   \
      ::mxnet::op::SimpleOpRegistry::Get()->__REGISTER_OR_FIND__(#Name)
| |
| } // namespace op |
| } // namespace mxnet |
| #endif // MXNET_OPERATOR_UTIL_H_ |