| /*! |
| * Copyright (c) 2016 by Contributors |
| * \brief common function used for broadcasting and reducing |
| * \author Xingjian Shi |
| */ |
| #ifndef MXNET_OPERATOR_ELEMWISE_OP_COMMON_H_ |
| #define MXNET_OPERATOR_ELEMWISE_OP_COMMON_H_ |
| #include <dmlc/logging.h> |
| #include <mxnet/operator.h> |
| #include <mxnet/operator_util.h> |
| #include <mxnet/op_attr_types.h> |
| #include <nnvm/op.h> |
| #include <nnvm/node.h> |
| #include <nnvm/op_attr_types.h> |
| #include <vector> |
| #include <string> |
| #include <utility> |
| #include "./operator_common.h" |
| |
| namespace mxnet { |
| namespace op { |
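/*!
 * \brief Infer one common attribute (e.g. shape or dtype) shared by all
 *        inputs and outputs of an elementwise operator.
 *
 * Known attributes are merged into a single value via assign, which is then
 * written back to every input and output slot. See ElemwiseShape and
 * ElemwiseType below for concrete instantiations.
 *
 * \tparam AttrType the attribute type being inferred
 * \tparam is_none predicate that tests whether an attribute is still unknown
 * \tparam assign merges the second argument into the first and returns false
 *         if the two values conflict
 * \tparam reverse_infer if true, also deduce the attribute from known outputs
 * \tparam attr_string pretty-printer for attributes, used in error messages
 * \tparam n_in number of inputs to process, or -1 for all
 * \tparam n_out number of outputs to process, or -1 for all
 * \return true iff the attribute was fully inferred
 */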
| template<typename AttrType, bool (*is_none)(const AttrType&), |
| bool (*assign)(AttrType*, const AttrType&), bool reverse_infer, |
| std::string (*attr_string)(const AttrType&), |
| int n_in = -1, int n_out = -1> |
| inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs, |
| std::vector<AttrType> *in_attrs, |
| std::vector<AttrType> *out_attrs, |
| const AttrType& none) { |
| AttrType dattr = none; |
| size_t in_size = in_attrs->size(); |
| size_t out_size = out_attrs->size(); |
| if (n_in != -1) |
| in_size = static_cast<size_t>(n_in); |
| if (n_out != -1) |
| out_size = static_cast<size_t>(n_out); |
| |
  // First pass: merge every known attribute into dattr, failing on conflicts.
  auto deduce = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
| for (size_t i = 0; i < size; ++i) { |
| CHECK(assign(&dattr, (*vec)[i])) |
| << "Incompatible attr in node " << attrs.name << " at " << i << "-th " |
| << name << ": " << "expected " << attr_string(dattr) |
| << ", got " << attr_string((*vec)[i]); |
| } |
| }; |
| deduce(in_attrs, in_size, "input"); |
| if (reverse_infer) deduce(out_attrs, out_size, "output"); |
| |
  // Second pass: write the deduced attribute back into every slot.
  auto write = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
| for (size_t i = 0; i < size; ++i) { |
| CHECK(assign(&(*vec)[i], dattr)) |
| << "Incompatible attr in node " << attrs.name << " at " << i << "-th " |
| << name << ": " << "expected " << attr_string(dattr) |
| << ", got " << attr_string((*vec)[i]); |
| } |
| }; |
| write(in_attrs, in_size, "input"); |
| write(out_attrs, out_size, "output"); |
| |
  // Inference succeeded only if a concrete attribute was deduced.
  return !is_none(dattr);
| } |
| |
| template<int n_in, int n_out> |
| inline bool ElemwiseShape(const nnvm::NodeAttrs& attrs, |
| std::vector<TShape> *in_attrs, |
| std::vector<TShape> *out_attrs) { |
| CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in)) << " in operator " << attrs.name; |
| CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name; |
| return ElemwiseAttr<TShape, shape_is_none, shape_assign, true, shape_string>( |
| attrs, in_attrs, out_attrs, TShape()); |
| } |
| |
| template<int n_in, int n_out> |
| inline bool ElemwiseType(const nnvm::NodeAttrs& attrs, |
| std::vector<int> *in_attrs, |
| std::vector<int> *out_attrs) { |
| CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in)) << " in operator " << attrs.name; |
| CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name; |
| return ElemwiseAttr<int, type_is_none, type_assign, true, type_string>( |
| attrs, in_attrs, out_attrs, -1); |
| } |
| |
// Pass the output gradient and the forward inputs to the FGradient function
| struct ElemwiseGradUseIn { |
| const char *op_name; |
| std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n, |
| const std::vector<nnvm::NodeEntry>& ograds) { |
| return MakeNonlossGradNode(op_name, n, ograds, n->inputs, n->attrs.dict); |
| } |
| }; |
| |
// Pass the output gradient and the forward outputs to the FGradient function
| struct ElemwiseGradUseOut { |
| const char *op_name; |
| std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n, |
| const std::vector<nnvm::NodeEntry>& ograds) { |
| std::vector<nnvm::NodeEntry> heads; |
| index_t n_out = n->num_outputs(); |
| for (index_t i = 0; i < n_out; ++i) { |
| heads.emplace_back(nnvm::NodeEntry{n, i, 0}); |
| } |
| return MakeNonlossGradNode(op_name, n, ograds, heads, n->attrs.dict); |
| } |
| }; |
| |
// Pass the output gradient plus the forward inputs and outputs to the
// FGradient function. The ograds are folded into heads here, so the plain
// MakeGradNode is used instead of MakeNonlossGradNode.
| struct ElemwiseGradUseInOut { |
| const char *op_name; |
| std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n, |
| const std::vector<nnvm::NodeEntry>& ograds) { |
| std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end()); |
| for (auto& h : n->inputs) { |
| heads.push_back(h); |
| } |
| index_t n_out = n->num_outputs(); |
| for (index_t i = 0; i < n_out; ++i) { |
| heads.emplace_back(nnvm::NodeEntry{n, i, 0}); |
| } |
| return MakeGradNode(op_name, n, heads, n->attrs.dict); |
| } |
| }; |
| |
// Pass only the output gradient to the FGradient function
| struct ElemwiseGradUseNone { |
| const char *op_name; |
| std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n, |
| const std::vector<nnvm::NodeEntry>& ograds) { |
| return MakeNonlossGradNode(op_name, n, ograds, {}, n->attrs.dict); |
| } |
| }; |
| |
// Forward the head gradient ograds[0] unchanged to every input of the node
struct CloneGradient {
| const char *op_name; |
| std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n, |
| const std::vector<nnvm::NodeEntry>& ograds) { |
| std::vector<nnvm::NodeEntry> ret; |
| for (size_t i = 0; i < n->inputs.size(); ++i) |
| ret.emplace_back(ograds[0]); |
| return ret; |
| } |
| }; |
| } // namespace op |
| } // namespace mxnet |
| |
| #endif // MXNET_OPERATOR_ELEMWISE_OP_COMMON_H_ |