blob: 693222bba9d1653db1407c4b66252aa75cde30aa [file] [log] [blame]
/*!
* Copyright (c) 2016 by Contributors
* \file elemwise_op_common.h
* \brief common functions for element-wise operators: attribute (shape/type) inference and gradient-node helpers
* \author Xingjian Shi
*/
#ifndef MXNET_OPERATOR_ELEMWISE_OP_COMMON_H_
#define MXNET_OPERATOR_ELEMWISE_OP_COMMON_H_
#include <dmlc/logging.h>
#include <mxnet/operator.h>
#include <mxnet/operator_util.h>
#include <mxnet/op_attr_types.h>
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <vector>
#include <utility>
#include "./operator_common.h"
namespace mxnet {
namespace op {
/*!
 * \brief Infer a single attribute value that every input and output of a node must share.
 *
 * Works in two passes over the attribute vectors:
 *  1. fold each known entry into one consensus value (inputs always, outputs
 *     only when reverse_infer is true);
 *  2. assign the consensus value back into every input and every output.
 * Whenever `assign` returns false a CHECK is tripped with a message that names
 * the node, the offending position and both attribute values.
 *
 * \tparam AttrType attribute type; must be streamable for the error message
 * \tparam is_none predicate used on the final consensus value to decide the return value
 * \tparam assign merge function; a false return is treated as an incompatibility
 * \tparam reverse_infer whether output attributes also take part in the first pass
 * \param attrs attributes of the node; only attrs.name is read, for error messages
 * \param in_attrs input attributes, updated in place
 * \param out_attrs output attributes, updated in place
 * \param none initial value of the consensus attribute
 * \return false if is_none reports the consensus value as none, true otherwise
 */
template<typename AttrType, bool (*is_none)(const AttrType&),
         bool (*assign)(AttrType*, const AttrType&), bool reverse_infer>
inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs,
                         std::vector<AttrType> *in_attrs,
                         std::vector<AttrType> *out_attrs,
                         const AttrType& none) {
  // Consensus attribute shared by the whole node; starts from the caller's `none`.
  AttrType dattr = none;

  // Pass 1: merge every entry of `vec` into the consensus value.
  const auto gather = [&](std::vector<AttrType> *vec, const char *kind) {
    const size_t count = vec->size();
    for (size_t i = 0; i != count; ++i) {
      CHECK(assign(&dattr, (*vec)[i]))
        << "Incompatible attr in node " << attrs.name << " at " << i << "-th "
        << kind << ": " << "expected " << dattr << ", got " << (*vec)[i];
    }
  };

  // Pass 2: push the consensus value back into every entry of `vec`.
  const auto scatter = [&](std::vector<AttrType> *vec, const char *kind) {
    const size_t count = vec->size();
    for (size_t i = 0; i != count; ++i) {
      CHECK(assign(&(*vec)[i], dattr))
        << "Incompatible attr in node " << attrs.name << " at " << i << "-th "
        << kind << ": " << "expected " << dattr << ", got " << (*vec)[i];
    }
  };

  gather(in_attrs, "input");
  if (reverse_infer) {
    gather(out_attrs, "output");
  }

  scatter(in_attrs, "input");
  scatter(out_attrs, "output");

  // Inference succeeded only if something other than "none" was determined.
  return !is_none(dattr);
}
/*!
 * \brief Shape inference for a node whose inputs and outputs all carry the same shape.
 *
 * Verifies the input/output counts and then delegates to ElemwiseAttr with
 * reverse inference enabled, using shape_is_none / shape_assign (declared
 * outside this file) as the predicate and merge function.
 *
 * \tparam n_in number of inputs the node must have
 * \tparam n_out number of outputs the node must have
 * \param attrs attributes of the node; attrs.name is used in error messages
 * \param in_attrs input shapes, updated in place
 * \param out_attrs output shapes, updated in place
 * \return the result of ElemwiseAttr for these shapes
 */
template<int n_in, int n_out>
inline bool ElemwiseShape(const nnvm::NodeAttrs& attrs,
                          std::vector<TShape> *in_attrs,
                          std::vector<TShape> *out_attrs) {
  CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in)) << " in operator " << attrs.name;
  CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;

  // A default-constructed TShape is the starting ("nothing known yet") value.
  const TShape unknown_shape = TShape();
  const bool inferred = ElemwiseAttr<TShape, shape_is_none, shape_assign, true>(
      attrs, in_attrs, out_attrs, unknown_shape);
  return inferred;
}
/*!
 * \brief Type inference for a node whose inputs and outputs all carry the same type flag.
 *
 * Verifies the input/output counts and then delegates to ElemwiseAttr with
 * reverse inference enabled, using type_is_none / type_assign (declared
 * outside this file) as the predicate and merge function.
 *
 * \tparam n_in number of inputs the node must have
 * \tparam n_out number of outputs the node must have
 * \param attrs attributes of the node; attrs.name is used in error messages
 * \param in_attrs input type flags, updated in place
 * \param out_attrs output type flags, updated in place
 * \return the result of ElemwiseAttr for these type flags
 */
template<int n_in, int n_out>
inline bool ElemwiseType(const nnvm::NodeAttrs& attrs,
                         std::vector<int> *in_attrs,
                         std::vector<int> *out_attrs) {
  CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in)) << " in operator " << attrs.name;
  CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;

  // -1 is the starting ("nothing known yet") value handed to ElemwiseAttr.
  const int unknown_type = -1;
  const bool inferred = ElemwiseAttr<int, type_is_none, type_assign, true>(
      attrs, in_attrs, out_attrs, unknown_type);
  return inferred;
}
/*!
 * \brief Gradient functor that hands the output gradients followed by the
 *  forward node's inputs to the backward operator.
 */
struct ElemwiseGradUseIn {
  /*! \brief name of the backward operator, forwarded to MakeGradNode */
  const char *op_name;
  /*!
   * \brief Build the gradient entries for a forward node.
   * \param n the forward node
   * \param ograds gradient entries; they come first in the head list
   * \return the result of MakeGradNode for op_name, n, the assembled head list
   *  (ograds followed by n->inputs) and n->attrs.dict
   */
  std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
                                          const std::vector<nnvm::NodeEntry>& ograds) const {
    // The final size is known up front, so allocate once.
    std::vector<nnvm::NodeEntry> heads;
    heads.reserve(ograds.size() + n->inputs.size());
    heads.insert(heads.end(), ograds.begin(), ograds.end());
    heads.insert(heads.end(), n->inputs.begin(), n->inputs.end());
    return MakeGradNode(op_name, n, heads, n->attrs.dict);
  }
};
/*!
 * \brief Gradient functor that hands the output gradients followed by the
 *  forward node's own outputs to the backward operator.
 */
struct ElemwiseGradUseOut {
  /*! \brief name of the backward operator, forwarded to MakeGradNode */
  const char *op_name;
  /*!
   * \brief Build the gradient entries for a forward node.
   * \param n the forward node
   * \param ograds gradient entries; they come first in the head list
   * \return the result of MakeGradNode for op_name, n, the assembled head list
   *  (ograds followed by one entry per output of n) and n->attrs.dict
   */
  std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
                                          const std::vector<nnvm::NodeEntry>& ograds) const {
    // Count outputs in the type the graph API itself returns instead of the
    // tensor-index typedef index_t, so the brace-initialised NodeEntry below
    // does not depend on how index_t happens to be defined.
    auto n_out = n->num_outputs();

    std::vector<nnvm::NodeEntry> heads;
    heads.reserve(ograds.size() + n_out);
    heads.insert(heads.end(), ograds.begin(), ograds.end());
    for (decltype(n_out) i = 0; i < n_out; ++i) {
      // Entry referring to the i-th output of n; the last field is set to 0.
      heads.push_back(nnvm::NodeEntry{n, i, 0});
    }
    return MakeGradNode(op_name, n, heads, n->attrs.dict);
  }
};
/*!
 * \brief Gradient functor that hands only the output gradients to the
 *  backward operator.
 */
struct ElemwiseGradUseNone {
  /*! \brief name of the backward operator, forwarded to MakeGradNode */
  const char *op_name;
  /*!
   * \brief Build the gradient entries for a forward node.
   * \param n the forward node
   * \param ograds gradient entries; passed through unchanged as the head list
   * \return the result of MakeGradNode for op_name, n, ograds and n->attrs.dict
   */
  std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
                                          const std::vector<nnvm::NodeEntry>& ograds) const {
    // The functor holds no mutable state, hence the const call operator.
    return MakeGradNode(op_name, n, ograds, n->attrs.dict);
  }
};
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_ELEMWISE_OP_COMMON_H_