blob: e280c60447dcef6107f0c65332b02db685270839 [file] [log] [blame]
/*!
* Copyright (c) 2016 by Contributors
* \file elemwise_op_common.h
* \brief common functions used by elementwise operators (attribute inference and gradient helpers)
* \author Xingjian Shi
*/
#ifndef MXNET_OPERATOR_ELEMWISE_OP_COMMON_H_
#define MXNET_OPERATOR_ELEMWISE_OP_COMMON_H_
#include <dmlc/logging.h>
#include <mxnet/operator.h>
#include <mxnet/operator_util.h>
#include <mxnet/op_attr_types.h>
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <vector>
#include <utility>
#include "./operator_common.h"
namespace mxnet {
namespace op {
/*!
 * \brief Propagate one common attribute value to every input and output of a node.
 *
 * The first input attribute for which is_none returns false is taken as the
 * reference value. When no input is known and reverse_infer is true, the first
 * known output attribute is used instead. Every unknown entry is then
 * overwritten with the reference value, and any known entry that differs from
 * it is reported through LOG(FATAL).
 *
 * \tparam AttrType attribute type; must be default-constructible, assignable,
 *         comparable with operator!= and streamable with operator<<.
 * \tparam is_none predicate returning true when an attribute is still unknown.
 * \tparam reverse_infer whether known outputs may supply the reference value.
 * \param attrs node attributes; only attrs.name is read, for the error message.
 * \param in_attrs input attributes, updated in place.
 * \param out_attrs output attributes, updated in place.
 * \return false when no known attribute was found (nothing is modified),
 *         true otherwise.
 */
template<typename AttrType,
         bool (*is_none)(const AttrType&), bool reverse_infer>
inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs,
                         std::vector<AttrType> *in_attrs,
                         std::vector<AttrType> *out_attrs) {
  std::vector<AttrType>& inputs = *in_attrs;
  std::vector<AttrType>& outputs = *out_attrs;

  // Step 1: pick the reference attribute, preferring inputs over outputs.
  AttrType dattr;
  bool found = false;
  for (const AttrType& candidate : inputs) {
    if (is_none(candidate)) continue;
    dattr = candidate;
    found = true;
    break;
  }
  if (!found && reverse_infer) {
    for (const AttrType& candidate : outputs) {
      if (is_none(candidate)) continue;
      dattr = candidate;
      found = true;
      break;
    }
  }
  if (!found) {
    return false;
  }

  // Step 2: fill unknown inputs and verify the known ones.
  for (size_t idx = 0; idx != inputs.size(); ++idx) {
    AttrType& current = inputs[idx];
    if (is_none(current)) {
      current = dattr;
      continue;
    }
    if (current != dattr) {
      LOG(FATAL) << "Incompatible attr in node " << attrs.name << " at " << idx << "-th input: "
                 << "expected " << dattr << ", got " << current;
    }
  }

  // Step 3: same treatment for the outputs.
  for (size_t idx = 0; idx != outputs.size(); ++idx) {
    AttrType& current = outputs[idx];
    if (is_none(current)) {
      current = dattr;
      continue;
    }
    if (current != dattr) {
      LOG(FATAL) << "Incompatible attr in node " << attrs.name << " at " << idx << "-th output: "
                 << "expected " << dattr << ", got " << current;
    }
  }
  return true;
}
/*!
 * \brief Tell whether a shape counts as unknown.
 * \param x shape to inspect.
 * \return true when x reports zero dimensions, false otherwise.
 */
inline bool shape_is_none(const TShape& x) {
  const bool has_no_dims = (x.ndim() == 0);
  return has_no_dims;
}
/*!
 * \brief Shape inference for operators whose inputs and outputs all share one shape.
 *
 * Verifies that the node has exactly n_in inputs and n_out outputs, then
 * delegates to ElemwiseAttr with reverse inference enabled, using
 * shape_is_none to detect unknown shapes.
 *
 * \tparam n_in expected number of inputs.
 * \tparam n_out expected number of outputs.
 * \param attrs node attributes; attrs.name is used in error messages.
 * \param in_attrs input shapes, updated in place.
 * \param out_attrs output shapes, updated in place.
 * \return the result of ElemwiseAttr for the given shapes.
 */
template<int n_in, int n_out>
inline bool ElemwiseShape(const nnvm::NodeAttrs& attrs,
                          std::vector<TShape> *in_attrs,
                          std::vector<TShape> *out_attrs) {
  // Cast explicitly so the size comparison does not mix signed and unsigned
  // operands; both checks name the operator to make failures traceable.
  CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in)) << " in operator " << attrs.name;
  CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;
  return ElemwiseAttr<TShape, shape_is_none, true>(
    attrs, in_attrs, out_attrs);
}
/*!
 * \brief Tell whether a type flag counts as unknown.
 * \param x type flag to inspect.
 * \return true when x equals -1, false otherwise.
 */
inline bool type_is_none(const int& x) {
  const int kUnknownTypeFlag = -1;
  return kUnknownTypeFlag == x;
}
/*!
 * \brief Type inference for operators whose inputs and outputs all share one type.
 *
 * Verifies that the node has exactly n_in inputs and n_out outputs, then
 * delegates to ElemwiseAttr with reverse inference enabled, using
 * type_is_none to detect unknown type flags.
 *
 * \tparam n_in expected number of inputs.
 * \tparam n_out expected number of outputs.
 * \param attrs node attributes; attrs.name is used in error messages.
 * \param in_attrs input type flags, updated in place.
 * \param out_attrs output type flags, updated in place.
 * \return the result of ElemwiseAttr for the given type flags.
 */
template<int n_in, int n_out>
inline bool ElemwiseType(const nnvm::NodeAttrs& attrs,
                         std::vector<int> *in_attrs,
                         std::vector<int> *out_attrs) {
  // Cast explicitly so the size comparison does not mix signed and unsigned
  // operands; both checks name the operator to make failures traceable.
  CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in)) << " in operator " << attrs.name;
  CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;
  return ElemwiseAttr<int, type_is_none, true>(
    attrs, in_attrs, out_attrs);
}
/*!
 * \brief Gradient functor that hands the output gradients followed by the
 *        forward node's inputs to MakeGradNode.
 */
struct ElemwiseGradUseIn {
  /*! \brief name of the backward operator passed to MakeGradNode */
  const char *op_name;
  /*!
   * \brief Build the gradient entries for node n.
   * \param n forward node whose inputs are appended after the output gradients.
   * \param ograds gradients with respect to the outputs of n.
   * \return whatever MakeGradNode returns for (op_name, n, heads, n->attrs.dict).
   */
  std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
                                          const std::vector<nnvm::NodeEntry>& ograds) const {
    // The functor holds no mutable state, so the call operator is const.
    std::vector<nnvm::NodeEntry> heads;
    // Single allocation for both ranges instead of growing while appending.
    heads.reserve(ograds.size() + n->inputs.size());
    heads.insert(heads.end(), ograds.begin(), ograds.end());
    heads.insert(heads.end(), n->inputs.begin(), n->inputs.end());
    return MakeGradNode(op_name, n, heads, n->attrs.dict);
  }
};
/*!
 * \brief Gradient functor that hands the output gradients followed by one
 *        entry per output of the forward node to MakeGradNode.
 */
struct ElemwiseGradUseOut {
  /*! \brief name of the backward operator passed to MakeGradNode */
  const char *op_name;
  /*!
   * \brief Build the gradient entries for node n.
   * \param n forward node whose outputs are appended after the output gradients.
   * \param ograds gradients with respect to the outputs of n.
   * \return whatever MakeGradNode returns for (op_name, n, heads, n->attrs.dict).
   */
  std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
                                          const std::vector<nnvm::NodeEntry>& ograds) const {
    // The functor holds no mutable state, so the call operator is const.
    const index_t n_out = n->num_outputs();
    std::vector<nnvm::NodeEntry> heads;
    // Single allocation for the gradients plus one entry per output.
    heads.reserve(ograds.size() + n_out);
    heads.insert(heads.end(), ograds.begin(), ograds.end());
    for (index_t i = 0; i < n_out; ++i) {
      // Entry referring to the i-th output of n; the last field is set to 0
      // exactly as in the original construction.
      heads.push_back(nnvm::NodeEntry{n, i, 0});
    }
    return MakeGradNode(op_name, n, heads, n->attrs.dict);
  }
};
/*!
 * \brief Gradient functor that hands only the output gradients to MakeGradNode.
 */
struct ElemwiseGradUseNone {
  /*! \brief name of the backward operator passed to MakeGradNode */
  const char *op_name;
  /*!
   * \brief Build the gradient entries for node n.
   * \param n forward node being differentiated.
   * \param ograds gradients with respect to the outputs of n; forwarded unchanged.
   * \return whatever MakeGradNode returns for (op_name, n, ograds, n->attrs.dict).
   */
  std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
                                          const std::vector<nnvm::NodeEntry>& ograds) const {
    // The functor holds no mutable state, so the call operator is const.
    return MakeGradNode(op_name, n, ograds, n->attrs.dict);
  }
};
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_ELEMWISE_OP_COMMON_H_