/*!
* Copyright (c) 2015 by Contributors
* \file legacy_op_util.cc
 * \brief Utility to adapt OperatorProperty to the new NNVM registry
*/
#include <dmlc/base.h>
#include <mxnet/base.h>
#include <mxnet/operator.h>
#include <mxnet/op_attr_types.h>
#include <mxnet/ndarray.h>
#include <nnvm/node.h>
#include <nnvm/graph.h>
#include <functional>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
namespace mxnet {
namespace op {
using nnvm::Op;
using nnvm::Node;
using nnvm::NodePtr;
using nnvm::NodeAttrs;
using nnvm::NodeEntry;
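// Holds the parsed legacy OperatorProperty for a node, together with
// cached argument/aux-state/input/output name lists. An instance is
// stored in NodeAttrs::parsed so the adapter functions below can
// recover the legacy property cheaply.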
class ParsedOpProp {
public:
std::shared_ptr<OperatorProperty> ptr;
std::vector<std::string> arguments;
std::vector<std::string> aux_states;
std::vector<std::string> inputs;
std::vector<std::string> outputs;
// Initialize the legacy property from the node's keyword attributes.
// On a parameter error, rethrow with the operator name and arguments
// attached so the message identifies the failing node.
void Init(const NodeAttrs& attrs) {
std::vector<std::pair<std::string, std::string> > kwargs(
attrs.dict.begin(), attrs.dict.end());
try {
ptr->Init(kwargs);
} catch (const dmlc::ParamError& e) {
std::ostringstream os;
os << e.what();
os << ", in operator " << attrs.op->name << "("
<< "name=\"" << attrs.name << "\"";
for (const auto& k : attrs.dict) {
os << ", " << k.first << "=\"" << k.second << "\"";
}
os << ")";
throw dmlc::ParamError(os.str());
}
arguments = ptr->ListArguments();
aux_states = ptr->ListAuxiliaryStates();
outputs = ptr->ListOutputs();
inputs = arguments;
inputs.insert(
inputs.end(), aux_states.begin(), aux_states.end());
}
};
// Recover the legacy OperatorProperty stored in the node attributes.
const OperatorProperty* OpPropGetOpProperty(const NodeAttrs& attrs) {
return nnvm::get<ParsedOpProp>(attrs.parsed).ptr.get();
}
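// Generic attribute inference adapter. NNVM passes a single input
// attribute vector in which auxiliary states follow the regular
// arguments; split it, run the legacy inference callback, then write
// the inferred attributes back into the combined vector.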
template<typename AttrType, typename FInfer>
bool OpPropInferAttr(const NodeAttrs& attrs,
std::vector<AttrType> *iattr,
std::vector<AttrType> *oattr,
FInfer finfer) {
auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
CHECK_EQ(prop.inputs.size(), iattr->size())
<< "op=" << attrs.op->name
<< ", inputs.size=" << prop.inputs.size()
<< ", iattr.size=" << iattr->size()
<< ", arg.size=" << prop.arguments.size();
std::vector<AttrType> in_attr(prop.arguments.size());
std::vector<AttrType> aux_attr(prop.aux_states.size());
for (size_t i = 0; i < prop.arguments.size(); ++i) {
in_attr[i] = (*iattr)[i];
}
for (size_t i = 0; i < prop.aux_states.size(); ++i) {
aux_attr[i] = (*iattr)[i + prop.arguments.size()];
}
if (!finfer(prop.ptr.get(), &in_attr, oattr, &aux_attr)) return false;
for (size_t i = 0; i < prop.arguments.size(); ++i) {
(*iattr)[i] = in_attr[i];
}
for (size_t i = 0; i < prop.aux_states.size(); ++i) {
(*iattr)[i + prop.arguments.size()] = aux_attr[i];
}
return true;
}
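// Shape and type inference adapters built on top of OpPropInferAttr.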
bool OpPropInferShape(const NodeAttrs& attrs,
std::vector<TShape> *iattr,
std::vector<TShape> *oattr) {
auto finfer = [](const OperatorProperty* op,
std::vector<TShape> *in,
std::vector<TShape> *out,
std::vector<TShape> *aux) {
return op->InferShape(in, out, aux);
};
return OpPropInferAttr(attrs, iattr, oattr, finfer);
}
bool OpPropInferType(const NodeAttrs& attrs,
std::vector<int> *iattr,
std::vector<int> *oattr) {
auto finfer = [](const OperatorProperty* op,
std::vector<int> *in,
std::vector<int> *out,
std::vector<int> *aux) {
return op->InferType(in, out, aux);
};
return OpPropInferAttr(attrs, iattr, oattr, finfer);
}
inline uint32_t OpPropNumInputs(const NodeAttrs& attrs) {
auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
return static_cast<uint32_t>(prop.inputs.size());
}
inline uint32_t OpPropNumOutputs(const NodeAttrs& attrs) {
auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
return static_cast<uint32_t>(prop.outputs.size());
}
inline uint32_t OpPropNumVisibleOutputs(const NodeAttrs& attrs) {
auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
return static_cast<uint32_t>(prop.ptr->NumVisibleOutputs());
}
std::vector<std::string> OpPropListInputNames(const NodeAttrs& attrs) {
auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
return prop.inputs;
}
std::vector<std::string> OpPropListOutputNames(const NodeAttrs& attrs) {
auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
return prop.outputs;
}
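// Auxiliary states are the mutable inputs; their indices come after
// the regular arguments in the concatenated input list.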
std::vector<uint32_t> OpPropMutateInputs(const NodeAttrs& attrs) {
auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
std::vector<uint32_t> ret;
for (uint32_t i = 0; i < prop.aux_states.size(); ++i) {
ret.push_back(static_cast<uint32_t>(i + prop.arguments.size()));
}
return ret;
}
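// The legacy ForwardInplaceOption returns (input index, void*) pairs.
// Pass pointers into out_data so each returned void* can be
// dereferenced back to the corresponding output index.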
std::vector<std::pair<int, int> > OpPropInplaceOption(const NodeAttrs& attrs) {
auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
std::vector<int> in_data(prop.arguments.size());
std::vector<int> out_data(prop.outputs.size());
std::vector<void*> out_addr(prop.outputs.size());
for (size_t i = 0; i < in_data.size(); ++i) {
in_data[i] = static_cast<int>(i);
}
for (size_t i = 0; i < out_data.size(); ++i) {
out_data[i] = static_cast<int>(i);
out_addr[i] = &out_data[i];
}
std::vector<std::pair<int, int> > forward_inplace;
for (auto& kv : prop.ptr->ForwardInplaceOption(in_data, out_addr)) {
forward_inplace.push_back(
std::make_pair(kv.first, *static_cast<int*>(kv.second)));
}
return forward_inplace;
}
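// Forward and backward resource requests. The legacy API takes input
// shapes, which are unavailable at this point, so an empty shape
// vector is passed.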
std::vector<ResourceRequest> OpPropResourceRequest(const NodeAttrs& attrs) {
std::vector<TShape> ishape;
auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
return prop.ptr->ForwardResource(ishape);
}
std::vector<ResourceRequest> OpBackResourceRequest(const NodeAttrs& attrs) {
std::vector<TShape> ishape;
auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
return prop.ptr->BackwardResource(ishape);
}
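// Create the legacy Operator. Only the shapes and types of the real
// arguments (not the auxiliary states) are forwarded to
// CreateOperatorEx.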
Operator* OpPropCreateLayerOp(const NodeAttrs& attrs,
Context ctx,
const std::vector<TShape>& ishape,
const std::vector<int>& itype) {
auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
std::vector<TShape> is(ishape.begin(), ishape.begin() + prop.arguments.size());
std::vector<int> it(itype.begin(), itype.begin() + prop.arguments.size());
return prop.ptr->CreateOperatorEx(ctx, &is, &it);
}
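// Build the gradient (backward) node for a legacy op. Its inputs are
// the legacy BackwardInputs selection over (out_grad, in_data,
// out_data), followed by the auxiliary states; gradients for the aux
// states themselves are routed to a _NoGradient placeholder node.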
inline std::vector<NodeEntry> OpPropGradient(
const Op* back_op,
const NodePtr& ptr,
const std::vector<NodeEntry>& out_grads) {
auto& prop = nnvm::get<ParsedOpProp>(ptr->attrs.parsed);
std::vector<NodeEntry> out_data(prop.outputs.size());
for (uint32_t i = 0; i < out_data.size(); ++i) {
out_data[i] = NodeEntry{ptr, i, 0};
}
std::vector<NodeEntry> in_data(
ptr->inputs.begin(), ptr->inputs.begin() + prop.arguments.size());
std::vector<NodeEntry> ograd(
out_grads.begin(), out_grads.begin() + prop.ptr->NumVisibleOutputs());
auto inputs = prop.ptr->BackwardInputs(ograd, in_data, out_data);
// add all the auxiliary data
for (uint32_t i = 0; i < prop.aux_states.size(); ++i) {
inputs.emplace_back(ptr->inputs[i + prop.arguments.size()]);
}
NodePtr gnode = Node::Create();
gnode->inputs = std::move(inputs);
gnode->control_deps.emplace_back(ptr);
gnode->attrs = ptr->attrs;
gnode->attrs.op = back_op;
gnode->attrs.name = ptr->attrs.name + "_backward";
std::vector<NodeEntry> in_grad(prop.arguments.size());
for (uint32_t i = 0; i < prop.arguments.size(); ++i) {
in_grad[i] = NodeEntry{gnode, i, 0};
}
// attach no gradient node to forbid gradient on aux_state
if (prop.aux_states.size() != 0) {
NodePtr ng = Node::Create();
ng->attrs.op = Op::Get("_NoGradient");
ng->attrs.name = "NoGradient";
for (uint32_t i = 0; i < prop.aux_states.size(); ++i) {
in_grad.emplace_back(NodeEntry{ng, 0, 0});
}
}
return in_grad;
}
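// Adapters for the generated backward op: it emits one gradient per
// forward argument.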
inline uint32_t OpBackNumOutputs(const NodeAttrs& attrs) {
auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
return static_cast<uint32_t>(prop.arguments.size());
}
std::vector<std::string> OpBackListOutputNames(const NodeAttrs& attrs) {
auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
return prop.arguments;
}
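// On the backward node the auxiliary states are appended after the
// inputs chosen by DeclareBackwardDependency, so their indices start
// at that dependency count.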
std::vector<uint32_t> OpBackMutateInputs(const NodeAttrs& attrs) {
auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
if (prop.aux_states.size() == 0) return std::vector<uint32_t>{};
std::vector<int> out_grad_index(prop.ptr->NumVisibleOutputs());
std::vector<int> in_data_index(prop.arguments.size());
std::vector<int> out_data_index(prop.outputs.size());
size_t arg_size = prop.ptr->DeclareBackwardDependency(
out_grad_index, in_data_index, out_data_index).size();
std::vector<uint32_t> ret;
for (uint32_t i = 0; i < prop.aux_states.size(); ++i) {
ret.push_back(static_cast<uint32_t>(i + arg_size));
}
return ret;
}
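// Map the legacy BackwardInplaceOption result onto the backward
// node's compacted input list: number out_grad/in_data/out_data with
// running indices, then translate through the subset selected by
// DeclareBackwardDependency.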
std::vector<std::pair<int, int> > OpBackInplaceOption(const NodeAttrs& attrs) {
auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
std::vector<int> out_grad_index(prop.ptr->NumVisibleOutputs());
std::vector<int> in_data_index(prop.arguments.size());
std::vector<int> out_data_index(prop.outputs.size());
int counter = 0;
for (size_t i = 0; i < in_data_index.size(); ++i) {
in_data_index[i] = counter++;
}
for (size_t i = 0; i < out_grad_index.size(); ++i) {
out_grad_index[i] = counter++;
}
for (size_t i = 0; i < out_data_index.size(); ++i) {
out_data_index[i] = counter++;
}
auto args_index = prop.ptr->DeclareBackwardDependency(
out_grad_index, in_data_index, out_data_index);
std::vector<int> args_array(counter, -1);
for (size_t i = 0; i < args_index.size(); ++i) {
args_array[args_index[i]] = static_cast<int>(i);
}
std::vector<void*> in_grad_ptr(in_data_index.size());
for (size_t i = 0; i < in_grad_ptr.size(); ++i) {
// in_data indices were assigned 0..num_inputs-1, so a dereferenced
// pointer yields the in_grad output index directly
in_grad_ptr[i] = (void*)&in_data_index[i]; // NOLINT(*)
}
auto remap_index = prop.ptr->BackwardInplaceOption(
out_grad_index, in_data_index, out_data_index, in_grad_ptr);
std::vector<std::pair<int, int> > remap(remap_index.size());
for (size_t i = 0; i < remap_index.size(); ++i) {
if (args_array[remap_index[i].first] == -1) {
LOG(FATAL) << "BackwardInplaceOption not consistent with DeclareBackwardDependency";
}
remap[i].first = args_array[remap_index[i].first];
remap[i].second = *static_cast<int*>(remap_index[i].second);
}
return remap;
}
// Register all legacy operator properties under the NNVM registry.
void RegisterLegacyOpProp() {
for (auto reg : dmlc::Registry<OperatorPropertyReg>::List()) {
Op& op = ::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(reg->name);
if (op.attr_parser != nullptr) continue;
auto creator = reg->body;
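// Parse attributes lazily: construct the legacy property once and
// cache it in attrs->parsed. The same parser is shared by the forward
// and backward registrations.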
auto attr_parser = [creator](NodeAttrs* attrs) {
if (attrs->parsed.empty()) {
ParsedOpProp op;
op.ptr.reset(creator());
op.Init(*attrs);
attrs->parsed = std::move(op);
}
};
op.add_arguments(reg->arguments);
op.describe(reg->description);
// attribute parser
op.set_attr_parser(attr_parser);
op.set_num_inputs(OpPropNumInputs);
op.set_num_outputs(OpPropNumOutputs);
op.set_attr<nnvm::FListInputNames>("FListInputNames", OpPropListInputNames);
op.set_attr<nnvm::FListOutputNames>("FListOutputNames", OpPropListOutputNames);
op.set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs", OpPropNumVisibleOutputs);
op.set_attr<nnvm::FInferShape>("FInferShape", OpPropInferShape);
op.set_attr<nnvm::FInferType>("FInferType", OpPropInferType);
op.set_attr<nnvm::FMutateInputs>("FMutateInputs", OpPropMutateInputs);
op.set_attr<nnvm::FInplaceOption>("FInplaceOption", OpPropInplaceOption);
op.set_attr<FResourceRequest>("FResourceRequest", OpPropResourceRequest);
op.set_attr<FCreateLayerOp>("FCreateLayerOp", OpPropCreateLayerOp);
if (reg->key_var_num_args.length() != 0) {
op.set_attr<std::string>("key_var_num_args", reg->key_var_num_args);
}
// Register the corresponding backward operator.
std::string back_op_name = "_backward_" + reg->name;
Op& back_op = ::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER__(back_op_name);
op.set_attr<nnvm::FGradient>("FGradient", std::bind(
OpPropGradient, &back_op,
std::placeholders::_1, std::placeholders::_2));
back_op.set_attr_parser(attr_parser);
back_op.set_num_inputs(nnvm::kVarg);
back_op.set_num_outputs(OpBackNumOutputs);
back_op.set_attr<nnvm::FListOutputNames>("FListOutputNames", OpBackListOutputNames);
back_op.set_attr<nnvm::FMutateInputs>("FMutateInputs", OpBackMutateInputs);
back_op.set_attr<nnvm::FInplaceOption>("FInplaceOption", OpBackInplaceOption);
back_op.set_attr<FResourceRequest>(
"FResourceRequest", OpBackResourceRequest);
back_op.set_attr<bool>("TIsLayerOpBackward", true);
back_op.set_attr<bool>("TIsBackward", true);
}
}
// no gradient operator
NNVM_REGISTER_OP(_NoGradient)
.set_num_inputs(0)
.set_num_outputs(1)
.describe("Place holder for variable who cannot perform gradient");
void RegisterLegacyNDFunc() {
for (auto reg : dmlc::Registry<NDArrayFunctionReg>::List()) {
if (reg->type_mask & kScalarArgBeforeNDArray) continue;
Op& op = ::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(reg->name);
if (op.attr_parser != nullptr) continue;
CHECK_LE(reg->num_scalars + reg->num_use_vars, reg->arguments.size())
<< reg->name;
op.describe(reg->description);
op.add_arguments(reg->arguments);
op.set_num_inputs(reg->num_use_vars);
op.set_num_outputs(reg->num_mutate_vars);
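// No structured parameters to parse; attributes stay in attrs->dict.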
op.set_attr_parser([](NodeAttrs* attrs){});
op.set_attr<FNDArrayFunction>("FNDArrayFunction", [reg](const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
std::vector<NDArray>* outputs) {
CHECK_EQ(inputs.size(), reg->num_use_vars);
CHECK_EQ(outputs->size(), reg->num_mutate_vars);
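// Scalars arrive as named string attributes: look each up by the
// argument name that follows the NDArray arguments, parse it as a
// float, and remove it from the remaining keyword dict.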
int n_scalars = reg->num_scalars;
std::vector<float> scalars;
scalars.reserve(n_scalars);
auto dict = attrs.dict;
for (int i = 0; i < n_scalars; ++i) {
const std::string& name = reg->arguments[i+reg->num_use_vars].name;
auto s = dict.find(name);
CHECK(s != dict.end()) << "Missing scalar param " << name;
scalars.push_back(std::stof(s->second));
dict.erase(s);
}
int n_params = dict.size();
std::vector<const char*> keys, vals;
keys.reserve(n_params);
vals.reserve(n_params);
for (auto& i : dict) {
keys.push_back(dmlc::BeginPtr(i.first));
vals.push_back(dmlc::BeginPtr(i.second));
}
std::vector<NDArray*> input_ptrs, output_ptrs;
for (auto& i : inputs) {
input_ptrs.push_back(const_cast<NDArray*>(&i));
}
for (auto& i : *outputs) {
output_ptrs.push_back(&i);
}
reg->body(input_ptrs.data(),
scalars.data(),
output_ptrs.data(),
n_params,
const_cast<char**>(keys.data()),
const_cast<char**>(vals.data()));
});
}
}
} // namespace op
} // namespace mxnet