/*!
* Copyright (c) 2015 by Contributors
* \file operator_util.h
* \brief Utility functions and registries to help quickly build new operators.
* [Deprecated]
* Use the register functions in this file when possible to simplify operator creation.
* Operators registered in this file will be exposed to both NDArray API and symbolic API.
*
* \author Tianqi Chen
*/
#ifndef MXNET_OPERATOR_UTIL_H_
#define MXNET_OPERATOR_UTIL_H_
#ifdef _MSC_VER
#pragma warning(disable:4503) // disable warning: decorated name length exceeded.
#endif
#include <dmlc/registry.h>
#include <dmlc/parameter.h>
#include <map>
#include <vector>
#include <string>
#include <utility>
#include "./base.h"
#include "./operator.h"
#if DMLC_USE_CXX11
#include <functional>
#endif
namespace mxnet {
/*! \brief namespace of operators */
namespace op {
/*! \brief base class of all gradient function arguments */
struct GradFunctionArgument {
/*! \brief The real data */
TBlob data;
};
/*! \brief First input to the function */
struct Input0 : GradFunctionArgument {};
/*! \brief Second input to the function */
struct Input1 : GradFunctionArgument {};
/*! \brief Output value of the function */
struct OutputValue : GradFunctionArgument {};
/*! \brief Gradient of output value */
struct OutputGrad : GradFunctionArgument {};
/*!
* \brief Environment arguments that are used by the function.
* These can be things like a scalar argument when adding a scalar to a tensor.
*/
struct EnvArguments {
/*! \brief scalar argument, if enabled */
real_t scalar;
/*! \brief keyword arguments */
std::vector<std::pair<std::string, std::string> > kwargs;
/*! \brief pointer to the resources requested */
std::vector<Resource> resource;
};
/*!
* \brief Source function that generates output based on env.
* The result container is pre-allocated with the correct shape.
* \param env The Environment arguments.
* \param ret The container to store the return value.
* \param req The requirement to store the ret.
* \param ctx Runtime context to execute the function.
*/
typedef void (*SourceFunction)(const EnvArguments& env,
TBlob* ret,
OpReqType req,
RunContext ctx);
/*!
* \brief Shape inference function to get the correct shape.
* \param env The Environment arguments.
* \return The inferred result shape.
*/
typedef TShape (*SourceShapeFunction)(const EnvArguments& env);
/*!
* \brief Unary function that takes a src and save result to ret.
* The result container is pre-allocated with the correct shape.
* \param src The source data.
* \param env The Environment arguments.
* \param ret The container to store the return value.
* \param req The requirement to store the ret.
* \param ctx Runtime context to execute the function.
*/
typedef void (*UnaryFunction)(const TBlob& src,
const EnvArguments& env,
TBlob* ret,
OpReqType req,
RunContext ctx);
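/*
 * A minimal sketch of a UnaryFunction implementation (hypothetical names,
 * for illustration only). It squares every element of src into *ret using
 * mshadow expressions and the ASSIGN_DISPATCH macro defined later in this
 * file; the same pattern applies to any elementwise unary operator.
 *
 * template<typename xpu>
 * void SquareForward_(const TBlob& src,
 *                     const EnvArguments& env,
 *                     TBlob* ret,
 *                     OpReqType req,
 *                     RunContext ctx) {
 *   using namespace mshadow;
 *   using namespace mshadow::expr;
 *   Stream<xpu>* s = ctx.get_stream<xpu>();
 *   Tensor<xpu, 2> out = ret->FlatTo2D<xpu, real_t>(s);
 *   Tensor<xpu, 2> in = src.FlatTo2D<xpu, real_t>(s);
 *   // elementwise square; req decides whether to write or accumulate
 *   ASSIGN_DISPATCH(out, req, in * in);
 * }
 */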
/*!
* \brief Shape inference function to get the correct shape given source.
* \param src The source shape
* \param env The Environment arguments.
* \return The inferred result shape.
*/
typedef TShape (*UnaryShapeFunction)(const TShape& src,
const EnvArguments& env);
/*!
* \brief Gradient function that takes only the output gradient and computes the gradient w.r.t. the input.
* \param out_grad the gradient w.r.t. the output of the function.
* \param env The Environment arguments.
* \param in_grad The container to store result input gradient.
* \param req The requirement to store the ret value.
* \param ctx Runtime context to execute the function.
*/
typedef void (*UnaryGradFunctionT0)(const OutputGrad& out_grad,
const EnvArguments& env,
TBlob* in_grad,
OpReqType req,
RunContext ctx);
/*!
* \brief Gradient function that takes the output gradient and output value of the function and computes the gradient w.r.t. the input.
* \param out_grad the gradient w.r.t. the output of the function.
* \param out_value the value of the function.
* \param env The Environment arguments.
* \param in_grad The container to store result input gradient.
* \param req The requirement to store the ret value.
* \param ctx Runtime context to execute the function.
*/
typedef void (*UnaryGradFunctionT1)(const OutputGrad& out_grad,
const OutputValue& out_value,
const EnvArguments& env,
TBlob* in_grad,
OpReqType req,
RunContext ctx);
/*!
* \brief Gradient function that takes the output gradient and input value of the function and computes the gradient w.r.t. the input.
* \param out_grad the gradient w.r.t. the output of the function.
* \param in_data0 the input value of the function.
* \param env The Environment arguments.
* \param in_grad The container to store result input gradient.
* \param req The requirement to store the ret value.
* \param ctx Runtime context to execute the function.
*/
typedef void (*UnaryGradFunctionT2)(const OutputGrad& out_grad,
const Input0& in_data0,
const EnvArguments& env,
TBlob* in_grad,
OpReqType req,
RunContext ctx);
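/*
 * A matching UnaryGradFunctionT2 sketch (hypothetical names, for illustration
 * only). The argument wrapper structs expose the underlying TBlob through
 * their .data member; for the square above, d(x^2)/dx = 2 * x, so
 * in_grad = 2 * in * out_grad.
 *
 * template<typename xpu>
 * void SquareBackward_(const OutputGrad& out_grad,
 *                      const Input0& in_data0,
 *                      const EnvArguments& env,
 *                      TBlob* in_grad,
 *                      OpReqType req,
 *                      RunContext ctx) {
 *   using namespace mshadow;
 *   using namespace mshadow::expr;
 *   Stream<xpu>* s = ctx.get_stream<xpu>();
 *   Tensor<xpu, 2> igrad = in_grad->FlatTo2D<xpu, real_t>(s);
 *   Tensor<xpu, 2> ograd = out_grad.data.FlatTo2D<xpu, real_t>(s);
 *   Tensor<xpu, 2> in = in_data0.data.FlatTo2D<xpu, real_t>(s);
 *   ASSIGN_DISPATCH(igrad, req, scalar<real_t>(2.0f) * in * ograd);
 * }
 */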
/*!
* \brief Binary function that takes lhs, rhs and save result to ret.
* The result container is pre-allocated with the correct shape.
* \param lhs The left operand
* \param rhs The right operand
* \param env The Environment arguments.
* \param ret The container to store the return value.
* \param req The requirement to store the ret.
* \param ctx Runtime context to execute the function.
*/
typedef void (*BinaryFunction)(const TBlob& lhs,
const TBlob& rhs,
const EnvArguments& env,
TBlob* ret,
OpReqType req,
RunContext ctx);
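/*
 * A minimal BinaryFunction sketch (hypothetical names, for illustration
 * only) computing elementwise addition of lhs and rhs into *ret.
 *
 * template<typename xpu>
 * void PlusForward_(const TBlob& lhs,
 *                   const TBlob& rhs,
 *                   const EnvArguments& env,
 *                   TBlob* ret,
 *                   OpReqType req,
 *                   RunContext ctx) {
 *   using namespace mshadow;
 *   using namespace mshadow::expr;
 *   Stream<xpu>* s = ctx.get_stream<xpu>();
 *   Tensor<xpu, 2> out = ret->FlatTo2D<xpu, real_t>(s);
 *   Tensor<xpu, 2> l = lhs.FlatTo2D<xpu, real_t>(s);
 *   Tensor<xpu, 2> r = rhs.FlatTo2D<xpu, real_t>(s);
 *   ASSIGN_DISPATCH(out, req, l + r);
 * }
 */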
/*!
* \brief Shape inference function to get the correct shape given source shapes.
* \param lhs The shape of left operand.
* \param rhs The shape of right operand.
* \param env The Environment arguments.
* \return The inferred result shape.
*/
typedef TShape (*BinaryShapeFunction)(const TShape& lhs,
const TShape& rhs,
const EnvArguments& env);
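/*
 * A BinaryShapeFunction sketch (hypothetical name, for illustration only)
 * that requires both operands to have the same shape and returns it; this
 * matches the default behavior when no shape function is registered.
 *
 * inline TShape SameShape_(const TShape& lhs,
 *                          const TShape& rhs,
 *                          const EnvArguments& env) {
 *   CHECK_EQ(lhs, rhs) << "operands must have the same shape";
 *   return lhs;
 * }
 */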
/*!
* \brief Gradient function that takes only the output gradient and computes the gradient w.r.t. both inputs.
* We pass the total gradient as a whole to make it easy to combine a few ops.
* \param out_grad the gradient w.r.t. the output of the function.
* \param env The Environment arguments.
* \param lhs_grad The container to store result of lhs gradient.
* \param rhs_grad The container to store result of rhs gradient.
* \param req_lhs_grad The requirement to store the lhs_grad
* \param req_rhs_grad The requirement to store the rhs_grad
* \param ctx Runtime context to execute the function.
*/
typedef void (*BinaryGradFunctionT0)(const OutputGrad& out_grad,
const EnvArguments& env,
TBlob* lhs_grad,
TBlob* rhs_grad,
OpReqType req_lhs_grad,
OpReqType req_rhs_grad,
RunContext ctx);
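/*
 * A BinaryGradFunctionT0 sketch for elementwise addition (hypothetical name,
 * for illustration only): both input gradients are simply the output
 * gradient, written according to the respective requests. This assumes
 * mshadow_op::identity from src/operator/mshadow_op.h is available.
 *
 * template<typename xpu>
 * void PlusBackward_(const OutputGrad& out_grad,
 *                    const EnvArguments& env,
 *                    TBlob* lhs_grad,
 *                    TBlob* rhs_grad,
 *                    OpReqType req_lhs_grad,
 *                    OpReqType req_rhs_grad,
 *                    RunContext ctx) {
 *   using namespace mshadow;
 *   using namespace mshadow::expr;
 *   Stream<xpu>* s = ctx.get_stream<xpu>();
 *   Tensor<xpu, 2> ograd = out_grad.data.FlatTo2D<xpu, real_t>(s);
 *   Tensor<xpu, 2> lgrad = lhs_grad->FlatTo2D<xpu, real_t>(s);
 *   Tensor<xpu, 2> rgrad = rhs_grad->FlatTo2D<xpu, real_t>(s);
 *   ASSIGN_DISPATCH(lgrad, req_lhs_grad, F<mshadow_op::identity>(ograd));
 *   ASSIGN_DISPATCH(rgrad, req_rhs_grad, F<mshadow_op::identity>(ograd));
 * }
 */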
/*!
* \brief Gradient function that takes the output gradient and the inputs of the function and computes the gradient w.r.t. both inputs.
* \param out_grad the gradient w.r.t. the output of the function.
* \param lhs The left operand to the function.
* \param rhs The right operand to the function.
* \param env The Environment arguments.
* \param lhs_grad The container to store result of lhs gradient.
* \param rhs_grad The container to store result of rhs gradient.
* \param req_lhs_grad The requirement to store the lhs_grad
* \param req_rhs_grad The requirement to store the rhs_grad
* \param ctx Runtime context to execute the function.
*/
typedef void (*BinaryGradFunctionT1)(const OutputGrad& out_grad,
const Input0& lhs,
const Input1& rhs,
const EnvArguments& env,
TBlob* lhs_grad,
TBlob* rhs_grad,
OpReqType req_lhs_grad,
OpReqType req_rhs_grad,
RunContext ctx);
/*! \brief options in the registry to set inplace of operator */
enum SimpleOpInplaceOption {
/*! \brief do not allow inplace in arguments */
kNoInplace,
/*! \brief in unary forward, allow inplace in with out */
kInplaceInOut,
/*! \brief in unary backward, allow inplace out_grad with in_grad */
kInplaceOutIn,
/*! \brief in binary forward, allow inplace left operand with out */
kInplaceLhsOut,
/*! \brief in binary backward, allow inplace out_grad with lhs_grad */
kInplaceOutLhs
};
/*! \brief options in the registry to set the position of the scalar argument */
enum SimpleOpScalarOption {
kScalarBeforeArray,
kArrayBeforeScalar
};
/*! \brief options in the registry to set symbolic registration */
enum SimpleOpRegOption {
kNotRegisterSymbolic,
kRegisterSymbolic
};
/*! \brief registry entry to register simple operators via functions. */
class SimpleOpRegEntry {
public:
/*! \brief declare self type */
typedef SimpleOpRegEntry TSelf;
/*! \brief name of the operator */
std::string name;
/*!
* \brief set a separate name for the symbolic operator
* This must be called before set_function.
* Default: this is set to be the same as the name of the operator.
* \param symbol_name the name of symbolic operator.
*/
virtual TSelf& set_symbol_op_name(char const* symbol_name) = 0;
/*!
* \brief set whether a scalar argument needs to be passed in env
* A function cannot have both kwargs and scalar arguments.
* Default: this is set to false
* \param enable_scalar whether to enable scalar argument
* \param type_mask the position of the scalar argument.
*/
virtual TSelf& set_enable_scalar(
bool enable_scalar,
SimpleOpScalarOption type_mask = kArrayBeforeScalar) = 0;
/*!
* \brief set whether to enable kwargs
* A function cannot have both kwargs and scalar arguments.
* Default: this is set to false
* \param enable_kwargs whether to enable kwargs
*/
virtual TSelf& set_enable_kwargs(bool enable_kwargs) = 0;
/*!
* \brief set resource request
* By default there is no resource request.
* The resources will be available in both forward and backward.
* \param reqs the request.
*/
virtual TSelf& set_resource_request(
const std::vector<ResourceRequest>& reqs) = 0;
/*!
* \brief set resource request
* By default there is no resource request.
* The resources will be available in both forward and backward.
* \param req the request.
*/
virtual TSelf& set_resource_request(ResourceRequest req) = 0;
/*!
* \brief set shape inference function for a source operator.
* \param fshapeinfer The function that performs shape inference.
*/
virtual TSelf& set_shape_function(SourceShapeFunction fshapeinfer) = 0;
/*!
* \brief set shape inference function.
* Default: out_shape = in_shape
* \param fshapeinfer The unary function that performs shape inference.
*/
virtual TSelf& set_shape_function(UnaryShapeFunction fshapeinfer) = 0;
/*!
* \brief set shape inference function to be the binary inference function
* Default: out_shape = lhs_shape, and lhs_shape must equal rhs_shape.
* \param fshapeinfer The binary function that performs shape inference.
*/
virtual TSelf& set_shape_function(BinaryShapeFunction fshapeinfer) = 0;
/*!
* \brief set the forward function to fsource
* \param dev_mask The device mask the function can act on.
* \param fsource The source function that performs the operation.
* \param register_symbolic Whether register a symbolic operator as well.
*/
virtual TSelf& set_function(
int dev_mask,
SourceFunction fsource,
SimpleOpRegOption register_symbolic = kRegisterSymbolic) = 0;
/*!
* \brief set the forward function to funary
* \param dev_mask The device mask the function can act on.
* \param funary The unary function that performs the operation.
* \param inplace_in_out Whether do inplace optimization on in and out.
* \param register_symbolic Whether register a symbolic operator as well.
*/
virtual TSelf& set_function(
int dev_mask,
UnaryFunction funary,
SimpleOpInplaceOption inplace_in_out,
SimpleOpRegOption register_symbolic = kRegisterSymbolic) = 0;
/*!
* \brief set the forward function to fbinary
* \param dev_mask The device mask the function can act on.
* \param fbinary The binary function that performs the operation.
* \param inplace_lhs_out Whether do inplace optimization on lhs and out.
* \param register_symbolic Whether register a symbolic operator as well.
*/
virtual TSelf& set_function(
int dev_mask,
BinaryFunction fbinary,
SimpleOpInplaceOption inplace_lhs_out,
SimpleOpRegOption register_symbolic = kRegisterSymbolic) = 0;
/*!
* \brief set the gradient function of this operator.
* \param dev_mask The device mask the function can act on.
* \param fgrad The gradient function to be set.
* \param inplace_out_in_grad whether out_grad and in_grad can share memory.
*/
virtual TSelf& set_gradient(int dev_mask,
UnaryGradFunctionT0 fgrad,
SimpleOpInplaceOption inplace_out_in_grad) = 0;
/*!
* \brief set the gradient function of this operator.
* \param dev_mask The device mask the function can act on.
* \param fgrad The gradient function to be set.
* \param inplace_out_in_grad whether out_grad and in_grad can share memory.
*/
virtual TSelf& set_gradient(int dev_mask,
UnaryGradFunctionT1 fgrad,
SimpleOpInplaceOption inplace_out_in_grad) = 0;
/*!
* \brief set the gradient function of this operator.
* \param dev_mask The device mask the function can act on.
* \param fgrad The gradient function to be set.
* \param inplace_out_in_grad whether out_grad and in_grad can share memory.
*/
virtual TSelf& set_gradient(int dev_mask,
UnaryGradFunctionT2 fgrad,
SimpleOpInplaceOption inplace_out_in_grad) = 0;
/*!
* \brief set the gradient function of this operator.
* \param dev_mask The device mask the function can act on.
* \param fgrad The gradient function to be set.
* \param inplace_out_lhs_grad whether out_grad and lhs_grad can share memory.
*/
virtual TSelf& set_gradient(int dev_mask,
BinaryGradFunctionT0 fgrad,
SimpleOpInplaceOption inplace_out_lhs_grad) = 0;
/*!
* \brief set the gradient function of this operator.
* \param dev_mask The device mask the function can act on.
* \param fgrad The gradient function to be set.
* \param inplace_out_lhs_grad whether out_grad and lhs_grad can share memory.
*/
virtual TSelf& set_gradient(int dev_mask,
BinaryGradFunctionT1 fgrad,
SimpleOpInplaceOption inplace_out_lhs_grad) = 0;
/*!
* \brief Describe the function.
* \param description The description of the function.
* \return reference to self.
*/
virtual TSelf& describe(const std::string &description) = 0;
/*!
* \brief Add additional arguments to the function.
* \param args argument information.
* \return reference to self.
*/
virtual TSelf& add_arguments(const std::vector<dmlc::ParamFieldInfo> &args) = 0;
/*! \brief virtual destructor */
virtual ~SimpleOpRegEntry() {}
};
/*! \brief registry for TBlob functions */
class SimpleOpRegistry {
public:
/*!
* \brief Internal function to register or find an entry under the given name.
* \param name name of the function
* \return ref to the registered entry, used to set properties
*/
SimpleOpRegEntry &__REGISTER_OR_FIND__(char const* name);
/*!
* \brief Find the entry with corresponding name.
* \param name name of the function
* \return the corresponding registered entry.
*/
inline static const SimpleOpRegEntry *Find(const std::string &name) {
return Get()->fmap_.at(name);
}
/*! \return global singleton of the registry */
static SimpleOpRegistry* Get();
private:
// destructor
~SimpleOpRegistry();
/*! \brief internal registry map */
std::map<std::string, SimpleOpRegEntry*> fmap_;
};
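/*
 * Example of looking up a registered entry by name at runtime (illustrative;
 * "plus" stands for any operator name previously registered through
 * MXNET_REGISTER_SIMPLE_OP).
 *
 * const SimpleOpRegEntry* entry = SimpleOpRegistry::Find("plus");
 * LOG(INFO) << "found simple op: " << entry->name;
 */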
/*!
* \brief assign the expression to out according to request
* \param out the data to be assigned
* \param req the assignment request
* \param exp the expression
* \tparam OType output type
* \tparam Exp expression type
*/
#define ASSIGN_DISPATCH(out, req, exp) \
{ \
switch (req) { \
case kNullOp: \
break; \
case kWriteTo: \
case kWriteInplace: \
(out) = (exp); \
break; \
case kAddTo: \
(out) += (exp); \
break; \
default: \
LOG(FATAL) << "not reached"; \
} \
}
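// Typical use inside an operator kernel (illustrative; out and in are
// pre-allocated mshadow tensors of matching shape):
//   ASSIGN_DISPATCH(out, req, in * in);
// expands to "out = in * in" for kWriteTo/kWriteInplace, "out += in * in"
// for kAddTo, and a no-op for kNullOp.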
/*!
* \brief Maximum ndim supported for special operators like broadcasting with non-contiguous lhs/rhs
*/
#define MXNET_SPECIAL_MAX_NDIM 5
//--------------------------------------------------------------
// The following part are API Registration of Simple Operators
//--------------------------------------------------------------
/*!
* \brief Macro to register simple operator to both imperative and symbolic API.
*
* see src/operator/elementwise_unary_op-inl.h for example
*
* \code
* // example of registering a sigmoid operator on CPU
* // MySigmoid is of type UnaryFunction,
* // MySigmoidGrad is of type UnaryGradFunctionT2
*
* MXNET_REGISTER_SIMPLE_OP(sigmoid, cpu)
* .set_function(cpu::kDevMask, MySigmoid<cpu>, kInplaceInOut)
* .set_gradient(cpu::kDevMask, MySigmoidGrad<cpu>, kInplaceOutIn)
* .describe("Sigmoid function");
*
* \endcode
*/
#define MXNET_REGISTER_SIMPLE_OP(Name, DEV) \
static ::mxnet::op::SimpleOpRegEntry & \
__make_ ## SimpleOpRegEntry ## _ ## Name ## __ ## DEV ##__ = \
::mxnet::op::SimpleOpRegistry::Get()->__REGISTER_OR_FIND__(#Name)
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_UTIL_H_