/*!
* Copyright (c) 2017 by Contributors
* \file autograd.cc
* \brief Implementation of AutogradRuntime module.
*/
#include <mxnet/operator.h>
#include <mxnet/executor.h>
#include <nnvm/pass_functions.h>
#include <memory>
#include <unordered_set>
#include <iostream>
#include "../executor/graph_executor.h"
#include "./autograd.h"
namespace mxnet {
namespace autograd {
using nnvm::Symbol;
using nnvm::Node;
using nnvm::NodePtr;
using nnvm::NodeEntry;
using nnvm::NodeEntryMap;
using exec::GraphExecutor;
AutogradRuntime::AutogradRuntime() {}
void AutogradRuntime::SetRecording(bool recording) {
is_recording_ = recording;
}
bool AutogradRuntime::IsRecording() const {
return is_recording_;
}
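// Attach a fresh variable node to each NDArray's graph entry so that
// operations recorded on these arrays become part of the autograd graph.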
void AutogradRuntime::MarkVariables(std::vector<NDArray*> *p_variables) {
std::vector<NDArray*>& variables = *p_variables;
for (NDArray* var : variables) {
NodeEntry& e = var->entry_;
e.node = Node::Create();
e.node->attrs.name = "ag_variables_" + std::to_string(variable_count_++);
}
}
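// Record an imperative call to an FCompute-backed operator. The function
// pointer itself is not saved; only the graph node is recorded.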
void AutogradRuntime::RecordImperativeFCompute(FCompute fn,
const nnvm::Op* op,
const nnvm::NodeAttrs& attrs,
std::vector<NDArray> *p_inputs,
std::vector<NDArray> *p_outputs) {
RecordOp(op, attrs, p_inputs, p_outputs);
}
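// Record an imperative call to a stateful operator. Besides the graph node,
// the Operator instance is kept in saved_opr_ so the backward pass can
// reuse its state.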
void AutogradRuntime::RecordImperativeOperator(std::shared_ptr<Operator> opr,
const nnvm::Op* op,
const nnvm::NodeAttrs& attrs,
std::vector<NDArray> *p_inputs,
std::vector<NDArray> *p_outputs) {
NodePtr node = RecordOp(op, attrs, p_inputs, p_outputs);
saved_opr_.insert({node.get(), opr});
}
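// forward declaration; Execute is defined at the bottom of this file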
std::vector<NDArray> Execute(Symbol sym,
const NodeEntryMap<NDArray>& feed_dict,
const NodeOperatorMap& saved_opr);
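// Build a symbol whose outputs are the recorded entries of `outputs`,
// run the backward pass over it, and return the input gradients.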
std::vector<NDArray> AutogradRuntime::ComputeGradient(const std::vector<NDArray>& outputs) {
Symbol ff_sym;
for (size_t i = 0; i < outputs.size(); ++i) {
ff_sym.outputs.push_back(outputs[i].entry_);
}
std::vector<NDArray> result = Execute(ff_sym, saved_ndarray_, saved_opr_);
ClearRecords();
return result;
}
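// Singleton access: _GetSharedRef owns the instance for the lifetime of the
// process; Get returns a raw pointer for convenience.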
std::shared_ptr<AutogradRuntime> AutogradRuntime::_GetSharedRef() {
static std::shared_ptr<AutogradRuntime> inst(new AutogradRuntime());
return inst;
}
AutogradRuntime* AutogradRuntime::Get() {
static AutogradRuntime *ptr = _GetSharedRef().get();
return ptr;
}
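// Create a graph node for one recorded operation and wire the entries of
// the input and output NDArrays to it.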
NodePtr AutogradRuntime::RecordOp(const nnvm::Op* op,
const nnvm::NodeAttrs& attrs,
std::vector<NDArray> *p_inputs,
std::vector<NDArray> *p_outputs) {
std::vector<NDArray>& inputs = *p_inputs;
std::vector<NDArray>& outputs = *p_outputs;
NodePtr node = Node::Create();
node->attrs = attrs;
node->attrs.name = "ag_" + op->name + "_" + std::to_string(node_count_++);
for (size_t i = 0; i < outputs.size(); ++i) {
NodeEntry &e = outputs[i].entry_;
e.node = node;
e.index = i;
saved_ndarray_[e] = outputs[i];
}
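  // link the inputs, saving their current values for later execution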
for (size_t i = 0; i < inputs.size(); ++i) {
NodeEntry &e = inputs[i].entry_;
    CHECK(e.node.get() != nullptr)
        << "partial gradients are not supported yet; every input "
        << "to autograd must be marked as a variable";
if (!saved_ndarray_.count(e)) {
saved_ndarray_[e] = inputs[i];
}
node->inputs.push_back(e);
}
return node;
}
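// Reset all recorded state so the next recording starts from a clean slate.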
void AutogradRuntime::ClearRecords() {
node_count_ = 0;
variable_count_ = 0;
saved_ndarray_.clear();
saved_opr_.clear();
}
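// Bind the recorded symbol into a GraphExecutor, allocating an input array
// and a gradient buffer for every input node from the shapes and contexts
// collected out of the feed dictionary.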
GraphExecutor *Bind(Symbol symbol,
const NodeEntryMap<TShape>& shapes,
const NodeEntryMap<Context>& ctxs,
const NodeOperatorMap& saved_opr) {
std::vector<NodePtr> input_nodes =
symbol.ListInputs(Symbol::ListInputOption::kAll);
size_t input_size = input_nodes.size();
std::vector<NDArray> inputs;
inputs.reserve(input_size);
std::vector<NDArray> grads;
grads.reserve(input_size);
std::vector<OpReqType> grad_reqs;
grad_reqs.reserve(input_size);
  // prepare an input array and a gradient buffer for every input node
for (size_t i = 0; i < input_size; ++i) {
    NodeEntry e{input_nodes[i], 0, 0};
if (shapes.count(e) && ctxs.count(e)) {
TShape shape = shapes.at(e);
Context ctx = ctxs.at(e);
inputs.emplace_back(shape, ctx);
NDArray grad(shape, ctx);
      // placeholder value; overwritten during backward since grad_req is kWriteTo
      grad = static_cast<real_t>(1.0);
grads.emplace_back(grad);
grad_reqs.emplace_back(OpReqType::kWriteTo);
} else {
      LOG(FATAL) << "no corresponding NDArray for input: "
                 << input_nodes[i]->attrs.name << "(0)";
}
}
  // default context; assume all recorded entries share the same context
CHECK_GT(ctxs.size(), 0)
<< "The size of context mapping should be greater than zero";
Context ctx = ctxs.begin()->second;
std::map<std::string, Context> ctx_map;
std::vector<NDArray> aux_states;
  auto* exec = new exec::GraphExecutor();
  // TODO: this is a hack; the saved operators should be passed through a
  // proper executor interface rather than assigned to a member directly
  exec->saved_opr_ = saved_opr;
exec->Init(symbol, ctx, ctx_map,
inputs, grads, grad_reqs, aux_states);
return exec;
}
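// Feed the recorded NDArray values into the executor's data entries, run
// the backward pass seeded with all-ones head gradients, and collect the
// input gradients from the executor's gradient store.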
std::vector<NDArray> Run(GraphExecutor* exec,
const NodeEntryMap<NDArray>& feed_dict) {
const nnvm::IndexedGraph& idx = exec->graph_.indexed_graph();
for (const auto& kv : feed_dict) {
if (idx.exist(kv.first.node.get())) {
uint32_t entry_id = idx.entry_id(kv.first);
CopyFromTo(kv.second, &(exec->data_entry_[entry_id]));
}
}
  std::vector<NDArray> head_grads;
  head_grads.reserve(exec->output_arrays_.size());
for (size_t i = 0; i < exec->output_arrays_.size(); ++i) {
NDArray grad(exec->output_arrays_[i].shape(), exec->output_arrays_[i].ctx());
grad = static_cast<real_t>(1.0);
head_grads.push_back(grad);
}
exec->Backward(head_grads);
std::vector<NDArray> results;
results.reserve(exec->grad_store_.size());
for (const auto& kv : exec->grad_store_) {
results.emplace_back(kv.second);
}
return results;
}
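// Derive shapes and contexts from the feed dictionary, bind the symbol,
// and run the backward pass; used by AutogradRuntime::ComputeGradient.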
std::vector<NDArray> Execute(Symbol sym,
const NodeEntryMap<NDArray>& feed_dict,
const NodeOperatorMap& saved_opr) {
NodeEntryMap<TShape> shapes;
NodeEntryMap<Context> ctxs;
  for (const auto& kv : feed_dict) {
    const NodeEntry& e = kv.first;
    shapes.insert({e, kv.second.shape()});
    ctxs.insert({e, kv.second.ctx()});
  }
  // free the executor once Run returns: the returned NDArrays hold
  // reference-counted storage that outlives the executor
  std::unique_ptr<exec::GraphExecutor> exec(Bind(sym, shapes, ctxs, saved_opr));
  std::vector<NDArray> res = Run(exec.get(), feed_dict);
return res;
}
} // namespace autograd
} // namespace mxnet