blob: 713dd2d1703a3a24b74d0ca451c39ed468b16917 [file] [log] [blame]
/*!
* Copyright (c) 2015 by Contributors
* \file operator_common.h
* \brief common internal header of most operators
* this header includes utility functions operator can use
* \author Bing Xu
*/
#ifndef MXNET_OPERATOR_OPERATOR_COMMON_H_
#define MXNET_OPERATOR_OPERATOR_COMMON_H_
#include <dmlc/json.h>
#include <dmlc/logging.h>
#include <mxnet/operator.h>
#include <mxnet/base.h>
#include <istream>
#include <ostream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
namespace mxnet {
namespace op {
/*!
 * \brief assign the expression to out according to request
 *  Wrapped in do { } while (0) so the macro expands to a single statement
 *  and is safe inside un-braced if/else bodies (a bare { } block followed
 *  by the caller's ';' would bind a dangling else incorrectly).
 * \param out the data to be assigned
 * \param req the assignment request
 * \param exp the expression
 * \tparam OType output type
 * \tparam Exp expression type
 */
#define Assign(out, req, exp)            \
  do {                                   \
    switch (req) {                       \
      case kNullOp:                      \
        break;                           \
      case kWriteTo:                     \
      case kWriteInplace:                \
        (out) = (exp);                   \
        break;                           \
      case kAddTo:                       \
        (out) += (exp);                  \
        break;                           \
      default:                           \
        LOG(FATAL) << "not reached";     \
    }                                    \
  } while (0)
/*! \brief exception thrown when shape inference finds an inconsistency */
struct InferShapeError : public dmlc::Error {
  /*! \brief human-readable description of the inconsistency */
  std::string msg;
  /*! \brief index of the input whose shape was inconsistent */
  int index;
  /*!
   * \brief constructor
   * \param msg_ description of the error
   * \param index index of the offending input
   */
  InferShapeError(const std::string& msg_, int index)
      : dmlc::Error(msg_), msg(msg_), index(index) {}
};
/*! \brief exception thrown when type inference finds an inconsistency */
struct InferTypeError : public dmlc::Error {
  /*! \brief human-readable description of the inconsistency */
  std::string msg;
  /*! \brief index of the input whose type was inconsistent */
  int index;
  /*!
   * \brief constructor
   * \param msg_ description of the error
   * \param index index of the offending input
   */
  InferTypeError(const std::string& msg_, int index)
      : dmlc::Error(msg_), msg(msg_), index(index) {}
};
/*! \brief check if shape is empty or contains an unknown (0) dim. */
inline bool shape_is_none(const TShape& x) {
  if (x.ndim() == 0) return true;   // completely unknown shape
  return x.Size() == 0;             // total size 0 means some dim is 0 (unknown)
}
/*! \brief check whether a type id is the "unknown" sentinel (-1) */
inline bool type_is_none(const int& x) {
  return -1 == x;
}
/*!
 * \brief Assign x to y. Checks for compatibility when y is not empty.
 *  A dim of 0 in either shape is treated as "unknown" and may be filled
 *  in from the other side.
 * \param y target shape.
 * \param x source shape.
 * \return whether x and y are compatible.
 */
inline bool shape_assign(TShape *y, const TShape& x) {
  // target completely unknown: take the source as-is
  if (y->ndim() == 0) {
    *y = x;
    return true;
  }
  // rank mismatch is only acceptable when the source carries no information
  if (y->ndim() != x.ndim()) {
    return x.ndim() == 0;
  }
  // same rank: fill unknown dims of y from x, reject any hard conflict
  for (size_t i = 0; i < y->ndim(); ++i) {
    if ((*y)[i] == 0) {
      (*y)[i] = x[i];
    } else if (x[i] != 0 && (*y)[i] != x[i]) {
      return false;
    }
  }
  return true;
}
/*!
 * \brief Assign x to y. Checks for compatibility when y is not -1.
 *  -1 in either slot means "unknown" and is compatible with anything.
 * \param y target type.
 * \param x source type.
 * \return whether x and y are compatible.
 */
inline bool type_assign(int *y, const int& x) {
  // target unknown: take the source unconditionally
  if (-1 == *y) {
    *y = x;
    return true;
  }
  // otherwise compatible only when equal or the source is unknown
  return (*y == x) || (x == -1);
}
/*!
 * \brief macro assign shape to out if out is unknown otherwise check consistency
 *  Use macro so we can see the error file more clearly.
 *  Wrapped in do { } while (0) so the macro is a single statement and is
 *  safe inside un-braced if/else bodies.
 * \param shape_array the shape array to store the result
 * \param index the index of in the array
 * \param shape the inferred shape
 */
#define SHAPE_ASSIGN_CHECK(shape_array, index, shape)                        \
  do {                                                                       \
    if (!shape_assign(&(shape_array)[index], TShape(shape))) {               \
      std::ostringstream os;                                                 \
      os << "Shape inconsistent, Provided=" << (shape_array)[index] << ','   \
         << " inferred shape=" << shape;                                     \
      throw ::mxnet::op::InferShapeError(os.str(), index);                   \
    }                                                                        \
  } while (0)
/*!
 * \brief macro assign type to out if out is unknown (-1) otherwise check consistency
 *  Use macro so we can see the error file more clearly.
 *  Wrapped in do { } while (0) so the macro is a single statement and is
 *  safe inside un-braced if/else bodies.
 * \param type_array the type array to store the result
 * \param index the index of in the array
 * \param type the inferred type
 */
#define TYPE_ASSIGN_CHECK(type_array, index, type)                           \
  do {                                                                       \
    if (!type_assign(&(type_array)[index], type)) {                          \
      std::ostringstream os;                                                 \
      os << "Type inconsistent, Provided=" << (type_array)[index] << ','     \
         << " inferred type=" << type;                                       \
      throw ::mxnet::op::InferTypeError(os.str(), index);                    \
    }                                                                        \
  } while (0)
// helper macro to implement bind dispatch
// Expands inside a Bind-style function that has a `ctx` in scope; calls the
// cpu or gpu template instantiation of Method depending on ctx's device and
// returns its result directly from the enclosing function.
#if MXNET_USE_CUDA
#define DO_BIND_DISPATCH(Method, ...) \
  if (ctx.dev_mask() == cpu::kDevMask) { \
    return Method<cpu>(__VA_ARGS__); \
  } else { \
    return Method<gpu>(__VA_ARGS__); \
  }
#else
// CUDA disabled: any non-cpu context is a fatal error.
// LOG(FATAL) aborts; the `return nullptr;` only silences compiler warnings
// about a missing return path.
#define DO_BIND_DISPATCH(Method, ...) \
  if (ctx.dev_mask() == cpu::kDevMask) { \
    return Method<cpu>(__VA_ARGS__); \
  } else { \
    LOG(FATAL) << "GPU is not enabled"; \
    return nullptr; \
  }
#endif
// describe op registration point
// TODO(eric): move to dmlc-core
// Two-step stringize so that __LINE__ is expanded to its number before being
// turned into a string literal.
#define STRINGIZE_DETAIL(x) #x
#define STRINGIZE(x) STRINGIZE_DETAIL(x)
// Append "From: <file>:<line>" to an operator's describe() text so error
// messages point back at the registration site.
#define MXNET_DESCRIBE(...) describe(__VA_ARGS__ "\n\nFrom:" __FILE__ ":" STRINGIZE(__LINE__))
// String literal "\n\nDefined in <file>:L<line>" for appending to docstrings.
#define ADD_FILELINE "\n\nDefined in " __FILE__ ":L" STRINGIZE(__LINE__)
// quick helper to make a gradient node of op `op_name` for node `n`,
// returning one NodeEntry per output of the new node
inline std::vector<nnvm::NodeEntry> MakeGradNode(
    const char* op_name,
    const nnvm::NodePtr& n,
    std::vector<nnvm::NodeEntry> inputs,
    std::unordered_map<std::string, std::string> dict) {
  nnvm::NodePtr grad = nnvm::Node::Create();
  grad->attrs.op = nnvm::Op::Get(op_name);
  grad->attrs.name = n->attrs.name + "_backward";
  grad->attrs.dict = std::move(dict);
  // let the op parse its keyword arguments, if it has a parser
  if (grad->op()->attr_parser != nullptr) {
    grad->op()->attr_parser(&(grad->attrs));
  }
  grad->control_deps.emplace_back(n);
  grad->inputs = std::move(inputs);
  // wrap every output of the gradient node in a NodeEntry
  std::vector<nnvm::NodeEntry> entries;
  entries.reserve(grad->num_outputs());
  for (index_t i = 0; i < grad->num_outputs(); ++i) {
    entries.emplace_back(nnvm::NodeEntry{grad, i, 0});
  }
  return entries;
}
// quick helper to make gradient nodes that simply pass back zero. could be used in output ops.
// `ograds` is unused here but keeps the FGradient-style signature callers expect.
inline std::vector<nnvm::NodeEntry> MakeZeroGradNodes(
    const nnvm::NodePtr& n,
    const std::vector<nnvm::NodeEntry>& ograds) {
  std::vector<nnvm::NodeEntry> zero_grads;
  for (index_t i = 0; i < n->num_inputs(); ++i) {
    nnvm::NodePtr zero = nnvm::Node::Create();
    zero->attrs.op = nnvm::Op::Get("_zeros");
    // single-input ops drop the "_in<i>" part of the name
    std::ostringstream name_os;
    if (n->num_inputs() == 1) {
      name_os << n->attrs.name << "_backward";
    } else {
      name_os << n->attrs.name << "_in" << i << "_backward";
    }
    zero->attrs.name = name_os.str();
    zero->attrs.dict = std::unordered_map<std::string, std::string>();
    zero->control_deps.emplace_back(n);
    zero_grads.emplace_back(nnvm::NodeEntry{zero, 0, 0});
  }
  return zero_grads;
}
/*!
 * \brief Parse keyword arguments in attrs->dict as a PType parameter struct
 *  and store the result in attrs->parsed.
 * \tparam PType dmlc parameter struct with an Init(dict) method
 * \param attrs node attributes to parse; mutated in place
 * \throws dmlc::ParamError re-thrown with the operator name, node name, and
 *  all provided kwargs appended for easier debugging
 */
template<typename PType>
inline void ParamParser(nnvm::NodeAttrs* attrs) {
  PType param;
  try {
    param.Init(attrs->dict);
  } catch (const dmlc::ParamError& e) {
    // rebuild the error with full context about where it happened
    std::ostringstream msg;
    msg << e.what()
        << ", in operator " << attrs->op->name << "("
        << "name=\"" << attrs->name << "\"";
    for (const auto& kv : attrs->dict) {
      msg << ", " << kv.first << "=\"" << kv.second << "\"";
    }
    msg << ")";
    throw dmlc::ParamError(msg.str());
  }
  attrs->parsed = std::move(param);
}
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_OPERATOR_COMMON_H_