/*!
* Copyright (c) 2017 by Contributors
* \file autograd.cc
* \brief Implementation of AutogradRuntime module.
*/
#include <mxnet/operator.h>
#include <mxnet/executor.h>
#include <nnvm/pass_functions.h>
#include <algorithm>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "../executor/graph_executor.h"
#include "./autograd.h"
namespace mxnet {
namespace autograd {
using nnvm::Symbol;
using nnvm::Node;
using nnvm::NodePtr;
using nnvm::NodeEntry;
using nnvm::NodeEntryMap;
using exec::GraphExecutor;
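
// Thread-local flag marking whether the current thread is inside a training
// (recording) section.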
#if DMLC_CXX11_THREAD_LOCAL
thread_local bool AutogradRuntime::is_train_;
#else
MX_THREAD_LOCAL bool AutogradRuntime::is_train_;
#endif
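
/*!
 * \brief Post-order DFS over the autograd graph.
 * Visits every AGNode reachable from \p heads exactly once, children before
 * parents, by delegating to nnvm::PostOrderDFSVisit.
 */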
template<typename FVisit>
inline void AGDFSVisit(const std::vector<AGNodeEntry>& heads,
FVisit fvisit) {
typedef const AGNodePtr* GNode;
std::vector<GNode> head_nodes(heads.size());
std::transform(heads.begin(), heads.end(), head_nodes.begin(),
[](const AGNodeEntry& e)->GNode {
return &e.ag_node;
});
nnvm::PostOrderDFSVisit<GNode, AGNode*>(
head_nodes,
[fvisit](GNode n) { fvisit(*n); }, // FVisit
[](GNode n)->AGNode* { return n->get(); }, // HashFunc
[](GNode n)->uint32_t { return (*n)->inputs.size(); },  // InDegree
[](GNode n, uint32_t index)->GNode { return &(*n)->inputs.at(index).ag_node; });  // GetInput
}
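
// Convert this autograd entry into the corresponding entry of the nnvm graph.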
nnvm::NodeEntry AGNodeEntry::nn_entry() const {
return nnvm::NodeEntry{ag_node->nn_node, index, version};
}
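
// An entry is "none" when it has no autograd node or that node has no outputs.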
bool AGNodeEntry::is_none() const {
return ag_node == nullptr || ag_node->outputs.empty();
}
AutogradRuntime::AutogradRuntime() {}
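
/*!
 * \brief Attach autograd variable nodes to NDArrays.
 * For each variable a fresh AGNode is created, the matching gradient NDArray
 * is registered as its output gradient buffer together with the requested
 * OpReqType, and the node is stored back into the NDArray's entry_ field.
 */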
void AutogradRuntime::MarkVariables(
const std::vector<NDArray*>& variables,
const std::vector<mx_uint>& grad_reqs,
const std::vector<NDArray*>& gradients) {
for (uint32_t i = 0; i < variables.size(); ++i) {
std::string str_c(std::to_string(variable_count_++));
AGNodeEntry e{AGNode::Create(Node::Create()), 0, 0};
variables[i]->entry_.clear();
e.ag_node->outputs.emplace_back(*variables[i]);
AGNodeEntry ge{AGNode::Create(Node::Create()), 0, 0};
gradients[i]->entry_.clear();
ge.ag_node->outputs.emplace_back(*gradients[i]);
ge.ag_node->nn_node->attrs.name = "grad" + str_c;
gradients[i]->entry_ = std::move(ge);
e.ag_node->out_grads.emplace_back(*gradients[i]);
e.ag_node->grad_req = static_cast<OpReqType>(grad_reqs[i]);
e.ag_node->nn_node->attrs.name = "var" + str_c;
variables[i]->entry_ = std::move(e); // assign last to prevent cyclic reference
}
}
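
// Record an imperative call to a stateless FCompute operator.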
void AutogradRuntime::RecordImperativeFCompute(const nnvm::Op* op,
const nnvm::NodeAttrs& attrs,
std::vector<NDArray> *p_inputs,
std::vector<NDArray> *p_outputs) {
RecordOp(op, attrs, p_inputs, p_outputs, nullptr);
}
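
// Record an imperative call that carries a stateful Operator; the operator is
// saved so the backward pass can reuse it instead of recreating it.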
void AutogradRuntime::RecordImperativeOperator(const std::shared_ptr<Operator>& opr,
const nnvm::Op* op,
const nnvm::NodeAttrs& attrs,
std::vector<NDArray> *p_inputs,
std::vector<NDArray> *p_outputs) {
RecordOp(op, attrs, p_inputs, p_outputs, opr);
}
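
// Singleton access: the runtime is shared through a static shared_ptr so that
// it stays alive for as long as any user holds a reference.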
std::shared_ptr<AutogradRuntime> AutogradRuntime::_GetSharedRef() {
static std::shared_ptr<AutogradRuntime> inst(new AutogradRuntime());
return inst;
}
AutogradRuntime* AutogradRuntime::Get() {
static AutogradRuntime *ptr = _GetSharedRef().get();
return ptr;
}
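
/*!
 * \brief Record one executed operator into the autograd graph.
 * Creates a new nnvm node for the op, binds each output NDArray to it, and
 * gives every input that is not yet part of a graph an implicit variable
 * node. Returns the newly created AGNode.
 */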
AGNodePtr AutogradRuntime::RecordOp(const nnvm::Op* op,
const nnvm::NodeAttrs& attrs,
std::vector<NDArray> *p_inputs,
std::vector<NDArray> *p_outputs,
const std::shared_ptr<Operator>& opr) {
std::vector<NDArray>& inputs = *p_inputs;
std::vector<NDArray>& outputs = *p_outputs;
NodePtr nn_node = Node::Create();
nn_node->attrs = attrs;
nn_node->attrs.name = "node_" + std::to_string(node_count_++);
AGNodePtr ag_node = AGNode::Create(nn_node);
ag_node->opr = opr;
for (uint32_t i = 0; i < outputs.size(); ++i) {
CHECK(outputs[i].entry_.is_none())
<< "Output NDArray is non-empty and already in another computation graph. "
<< "Assigning to it will cause undefined behavior when evaluating gradients. "
<< "Please call backward first to clear the graph, or do this outside of "
<< "a train section.";
outputs[i].entry_.clear();
ag_node->outputs.push_back(outputs[i]);
outputs[i].entry_ = AGNodeEntry{ag_node, i, 0};
}
for (size_t i = 0; i < inputs.size(); ++i) {
if (inputs[i].entry_.is_none()) {
AGNodeEntry e{AGNode::Create(Node::Create()), 0, 0};
e.ag_node->outputs.emplace_back(inputs[i]);
e.ag_node->out_grads.emplace_back();
e.ag_node->nn_node->attrs.name = "var_" + std::to_string(variable_count_++);
inputs[i].entry_ = std::move(e); // assign last to prevent cyclic reference
}
nn_node->inputs.push_back(inputs[i].entry_.nn_entry());
ag_node->inputs.push_back(inputs[i].entry_);
}
return ag_node;
}
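
/*!
 * \brief Run the backward pass for the given output NDArrays.
 * Rebuilds a symbolic graph from the recorded AGNodes, constructs a temporary
 * GraphExecutor over the recorded arrays, feeds the head gradients
 * (defaulting to ones where \p ograds entries are empty), and clears the
 * recorded history unless \p retain_graph is set.
 *
 * Illustrative call sequence (simplified; driver code lives outside this file):
 *   AutogradRuntime* rt = AutogradRuntime::Get();
 *   rt->MarkVariables(vars, grad_reqs, grads);    // attach gradient buffers
 *   // ... imperative ops are recorded via RecordImperative* while training ...
 *   rt->ComputeGradient(outputs, ograds, false);  // backward pass
 */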
void AutogradRuntime::ComputeGradient(const std::vector<NDArray>& outputs,
const std::vector<NDArray>& ograds,
bool retain_graph) {
static auto& fmutate_inputs = nnvm::Op::GetAttr<nnvm::FMutateInputs>("FMutateInputs");
std::vector<AGNodeEntry> heads;
Symbol sym;
NodeEntryMap<NDArray> feed_dict;
for (const auto& i : outputs) {
CHECK(!i.entry_.is_none())
<< "Cannot differentiate node because it is not in a computational graph. "
<< "You need to set is_training to true or use a train_section to save "
<< "computational graphs for backward. If you want to differentiate the same "
<< "graph twice, you need to pass retain_graph=True to backward.";
heads.emplace_back(i.entry_);
sym.outputs.emplace_back(i.entry_.nn_entry());
}
std::unordered_set<AGNode*> mutable_set;
std::vector<AGNodePtr> vlist;
std::vector<NDArray> args, args_grad;
std::vector<NDArray> aux_states;
std::vector<OpReqType> grad_reqs;
std::unordered_map<const nnvm::Node*, std::shared_ptr<Operator>> saved_opr;
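
// Walk the recorded graph once: remember stateful operators, collect variable
// nodes, mark variables that some operator mutates in place, and record every
// intermediate output in the feed dictionary.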
AGDFSVisit(heads, [&](const AGNodePtr& n) {
if (n->nn_node->is_variable()) {
vlist.push_back(n);
} else {
if (n->opr != nullptr) {
saved_opr.insert({n->nn_node.get(), n->opr});
}
if (fmutate_inputs.count(n->nn_node->op())) {
for (uint32_t i : fmutate_inputs[n->nn_node->op()](n->nn_node->attrs)) {
mutable_set.insert(n->inputs[i].ag_node.get());
}
}
}
for (uint32_t i = 0; i < n->outputs.size(); ++i) {
feed_dict.insert({NodeEntry{n->nn_node, i, 0}, n->outputs[i]});
}
});
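
// Split the collected variables: mutated ones become auxiliary states, the
// rest become executor arguments together with their gradient buffers and
// gradient request types.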
for (const auto& n : vlist) {
if (mutable_set.count(n.get())) {
aux_states.push_back(n->outputs[0]);
} else {
if (n->grad_req != kNullOp) {
n->fresh_out_grad = true;
}
args.push_back(n->outputs[0]);
args_grad.push_back(n->out_grads[0]);
grad_reqs.push_back(n->grad_req);
}
}
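
// Build a temporary GraphExecutor over the recorded graph and run its
// backward pass with the supplied (or default) head gradients.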
if (args.size()) {
std::map<std::string, Context> ctx_map;
std::unique_ptr<exec::GraphExecutor> exec(new exec::GraphExecutor());
// TODO: this is hacky; hand the recorded forward operators directly to the
// executor so the backward pass can reuse them.
exec->saved_opr_ = saved_opr;
exec->Init(sym, args[0].ctx(), ctx_map,
args, args_grad, grad_reqs,
aux_states, nullptr, feed_dict);
std::vector<NDArray> head_grads;
head_grads.reserve(exec->head_grad_array_.size());
CHECK_EQ(ograds.size(), exec->output_arrays_.size());
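// Where no output gradient was supplied, use an array of ones with matching
// shape, context, and dtype as the head gradient.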
for (size_t i = 0; i < ograds.size(); ++i) {
if (ograds[i].is_none()) {
head_grads.emplace_back(
exec->output_arrays_[i].shape(), exec->output_arrays_[i].ctx(),
false, exec->output_arrays_[i].dtype());
head_grads.back() = static_cast<real_t>(1.0);
} else {
head_grads.emplace_back(ograds[i]);
}
}
exec->Backward(head_grads);
}
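
// Drop the recorded history so the arrays can be reused, unless the caller
// asked to retain the graph for another backward pass.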
if (!retain_graph) {
for (auto& i : heads) {
i.ag_node->clear_history();
}
}
}
} // namespace autograd
} // namespace mxnet