Autograd for NDArray (#5129)
* [Autograd] Support for stateless ndarray operator
* [Autograd] Support for ndarray operator with state
* [Autograd] Remove set_mark_for_record, add comments
* [Autograd] Fix lint, refactor
* [Autograd] Add mark_variables, comment for autograd c_api
* [Autograd] Fix lint.
* Increase rtol & atol of test_operator::deconvolution
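
For reference, a minimal usage sketch of the Python API added by this change. It mirrors `test_binary_func` in the new tests/python/unittest/test_autograd.py, so every name used here (`mxnet.autograd.grad_and_loss`, `nd.uniform`) comes from this diff; treat it as an illustration rather than final documentation.

```python
import mxnet.ndarray as nd
from mxnet.autograd import grad_and_loss

def forward(x, y):
    # any composition of recorded NDArray operators
    return x + x * y

# grad_and_loss marks the inputs as variables, turns recording on around
# the forward call, and returns (input gradients, forward output).
grad_func = grad_and_loss(forward)

x = nd.uniform(shape=(4, 5))
y = nd.uniform(shape=(4, 5))
grad_vals, output = grad_func(x, y)
# elementwise: grad_vals[0] == 1 + y, grad_vals[1] == x
```

The lower-level helpers `set_recording`, `mark_variables`, and `compute_gradient` map one-to-one onto the new C API entry points `MXAutogradSetRecording`, `MXAutogradMarkVariables`, and `MXAutogradComputeGradient`.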
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 7ac32a8..3fb9d6a 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -515,7 +515,32 @@
int num_params,
const char **param_keys,
const char **param_vals);
-
+/*!
+ * \brief set whether to record operators for autograd
+ * \param recording 1 to turn on recording, 0 to turn off recording
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXAutogradSetRecording(int recording);
+/*!
+ * \brief mark NDArrays as variables to compute gradient for autograd
+ * \param num_var number of variable NDArrays
+ * \param var_handles variable NDArrays
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXAutogradMarkVariables(mx_uint num_var,
+ NDArrayHandle* var_handles);
+/*!
+ * \brief compute the gradient of outputs w.r.t. variables
+ * \param num_output number of output NDArrays
+ * \param output_handles output NDArrays
+ * \param num_grad number of gradient NDArrays (returned)
+ * \param grad_handles gradient NDArrays (returned)
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXAutogradComputeGradient(mx_uint num_output,
+ NDArrayHandle* output_handles,
+ mx_uint* num_grad,
+ NDArrayHandle** grad_handles);
//--------------------------------------------
// Part 3: symbolic configuration generation
//--------------------------------------------
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index d7871d6..2ea03bf 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -11,6 +11,7 @@
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <dmlc/registry.h>
+#include <nnvm/node.h>
#include <vector>
#include <map>
#include <string>
@@ -27,6 +28,12 @@
#endif
namespace mxnet {
+
+// forward declaration
+namespace autograd {
+class AutogradRuntime;
+}
+
/*!
* \brief ndarray interface
*/
@@ -48,7 +55,7 @@
NDArray(const TShape &shape, Context ctx,
bool delay_alloc = false, int dtype = mshadow::default_type_flag)
: ptr_(std::make_shared<Chunk>(shape.Size(), ctx, delay_alloc, dtype)),
- shape_(shape), offset_(0), dtype_(dtype) {
+ shape_(shape), offset_(0), dtype_(dtype), entry_({nullptr, 0, 0}) {
#if MKL_EXPERIMENTAL == 1
Mkl_mem_ = std::make_shared<MKLMemHolder>();
#endif
@@ -62,7 +69,7 @@
*/
NDArray(const TBlob &data, int dev_id)
: ptr_(std::make_shared<Chunk>(data, dev_id)), shape_(data.shape_), offset_(0),
- dtype_(data.type_flag_) {
+ dtype_(data.type_flag_), entry_({nullptr, 0, 0}) {
#if MKL_EXPERIMENTAL == 1
Mkl_mem_ = std::make_shared<MKLMemHolder>();
#endif
@@ -344,6 +351,7 @@
std::vector<std::string>* keys);
private:
+ friend class autograd::AutogradRuntime;
/*! \brief the real data chunk that backs NDArray */
struct Chunk {
/*! \brief storage handlefrom storage engine */
@@ -414,6 +422,8 @@
size_t offset_;
/*! \brief type of data */
int dtype_ = -1;
+ /*! \brief node entry for autograd */
+ nnvm::NodeEntry entry_;
};
/*!
diff --git a/nnvm b/nnvm
index 85aaf57..0d64855 160000
--- a/nnvm
+++ b/nnvm
@@ -1 +1 @@
-Subproject commit 85aaf570c261986eebb076c7e160254c75f89ebb
+Subproject commit 0d64855f741e0482be7a3dfecde05f290bfec85d
diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py
index 2fb5655..805ebeb 100644
--- a/python/mxnet/__init__.py
+++ b/python/mxnet/__init__.py
@@ -17,6 +17,7 @@
from . import operator
# use mx.nd as short for mx.ndarray
from . import ndarray as nd
+from . import autograd
# use mx.rnd as short for mx.random
from . import random as rnd
from . import random
diff --git a/python/mxnet/autograd.py b/python/mxnet/autograd.py
new file mode 100644
index 0000000..758c9db
--- /dev/null
+++ b/python/mxnet/autograd.py
@@ -0,0 +1,104 @@
+# coding: utf-8
+"""Autograd for NDArray."""
+from __future__ import absolute_import
+from __future__ import division
+
+import ctypes
+import functools
+from .base import _LIB, check_call
+from .base import mx_uint, NDArrayHandle, c_array
+from .ndarray import NDArray
+
+def set_recording(recording):
+ """Turn on or turn of operator recording.
+
+ Parameters
+ ----------
+ recording: bool
+ """
+ check_call(_LIB.MXAutogradSetRecording(
+ ctypes.c_int(recording)))
+
+def mark_variables(variables):
+ """Mark NDArrays as variables to compute gradient for autograd.
+
+ Parameters
+ ----------
+ variables: list of NDArray
+ """
+ variable_handles = []
+ for var in variables:
+ variable_handles.append(var.handle)
+ check_call(_LIB.MXAutogradMarkVariables(
+ len(variable_handles),
+ c_array(NDArrayHandle, variable_handles)))
+
+def compute_gradient(outputs):
+ """Compute the gradients of outputs w.r.t variables.
+
+ Parameters
+ ----------
+ outputs: list of NDArray
+
+ Returns
+ -------
+ gradients: list of NDArray
+ """
+ output_handles = []
+ for arr in outputs:
+ output_handles.append(arr.handle)
+
+ num_grad = mx_uint()
+ grad_handles = ctypes.POINTER(NDArrayHandle)()
+ check_call(_LIB.MXAutogradComputeGradient(
+ len(output_handles),
+ c_array(NDArrayHandle, output_handles),
+ ctypes.byref(num_grad),
+ ctypes.byref(grad_handles)))
+ return [NDArray(NDArrayHandle(grad_handles[i])) for i in range(num_grad.value)]
+
+def grad_and_loss(func):
+ """Return function that computes both gradient of arguments and loss value.
+
+ Parameters
+ ----------
+ func: a python function
+ The forward (loss) function.
+
+ Returns
+ -------
+ grad_and_loss_func: a python function
+ A function that computes both the gradients of the arguments and the loss value.
+ """
+ @functools.wraps(func)
+ def wrapped(*args):
+ """Wrapped function."""
+ for x in args:
+ assert isinstance(x, NDArray), "type of autograd input should be NDArray."
+ mark_variables(args)
+ set_recording(True)
+ outputs = func(*args)
+ set_recording(False)
+ grad_vals = compute_gradient(
+ outputs if isinstance(outputs, list) else [outputs])
+ return grad_vals, outputs
+ return wrapped
+
+def grad(func):
+ """Return function that computes gradient of arguments.
+
+ Parameters
+ ----------
+ func: a python function
+ The forward (loss) function.
+
+ Returns
+ -------
+ grad_func: a python function
+ A function that computes the gradients of the arguments.
+ """
+ grad_with_loss_func = grad_and_loss(func)
+ @functools.wraps(grad_with_loss_func)
+ def wrapped(*args):
+ return grad_with_loss_func(*args)[0]
+ return wrapped
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index 7d9d8d6..1cd3a35 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -11,10 +11,13 @@
#include <mxnet/op_attr_types.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
+#include <string>
#include "./c_api_common.h"
#include "../common/utils.h"
+#include "../ndarray/autograd.h"
using namespace mxnet;
+using mxnet::autograd::AutogradRuntime;
void SetOpAttrs(const nnvm::Op *op,
nnvm::NodeAttrs *p_attrs,
@@ -261,7 +264,8 @@
0, PROFILER_MESSAGE(op->name.c_str()));
}
-void PushOperator(const nnvm::Op* op,
+void PushOperator(std::shared_ptr<Operator> opr,
+ const nnvm::Op* op,
const nnvm::NodeAttrs& attrs,
const Context& ctx,
const std::vector<engine::VarHandle>& read_vars,
@@ -270,12 +274,9 @@
const std::vector<uint32_t>& auxidx,
const std::vector<NDArray>& ndinputs,
const std::vector<NDArray>& ndoutputs) {
- static auto& createop = nnvm::Op::GetAttr<FCreateLayerOp>("FCreateLayerOp");
- MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
- Operator* opr = createop[op](attrs, ctx, ret->arg_shapes, ret->arg_types);
struct Capture {
engine::CallbackOnComplete on_complete;
- Operator *opr;
+ std::shared_ptr<Operator> opr;
};
Engine::Get()->PushAsync(
@@ -302,7 +303,6 @@
[](Engine* engine, void *cpt_handle) {
Capture* cpt = static_cast<Capture*>(cpt_handle);
cpt->on_complete();
- delete cpt->opr;
delete cpt;
}, static_cast<void*>(capture)),
requested};
@@ -312,7 +312,6 @@
if (ctx.dev_mask() == gpu::kDevMask) {
rctx.get_stream<gpu>()->Wait();
}
- delete opr;
delete capture;
on_complete();
}
@@ -372,10 +371,20 @@
}
if (fn) {
+ if (AutogradRuntime::Get()->IsRecording()) {
+ AutogradRuntime::Get()->RecordImperativeFCompute(fn, op,
+ attrs, &ndinputs, &ndoutputs);
+ }
PushFCompute(fn, op, attrs, ctx, read_vars, write_vars,
requested, ndinputs, ndoutputs);
} else if (createop.count(op)) {
- PushOperator(op, attrs, ctx, read_vars, write_vars,
+ std::shared_ptr<Operator> opr(
+ createop[op](attrs, ctx, ret->arg_shapes, ret->arg_types));
+ if (AutogradRuntime::Get()->IsRecording()) {
+ AutogradRuntime::Get()->RecordImperativeOperator(opr, op,
+ attrs, &ndinputs, &ndoutputs);
+ }
+ PushOperator(opr, op, attrs, ctx, read_vars, write_vars,
requested, auxidx, ndinputs, ndoutputs);
} else {
LOG(FATAL)
@@ -385,6 +394,7 @@
}
}
+
if (outarray == nullptr) {
ret->ret_handles.clear();
for (int i = 0; i < num_visible_outputs; ++i) {
@@ -399,3 +409,48 @@
}
API_END();
}
+
+int MXAutogradSetRecording(int recording) {
+ API_BEGIN();
+ AutogradRuntime::Get()->SetRecording(static_cast<bool>(recording));
+ API_END();
+}
+
+int MXAutogradMarkVariables(mx_uint num_var,
+ NDArrayHandle *var_handles) {
+ API_BEGIN();
+ std::vector<NDArray*> variables;
+ variables.reserve(num_var);
+ for (mx_uint i = 0; i < num_var; ++i) {
+ variables.emplace_back(static_cast<NDArray*>(var_handles[i]));
+ }
+ AutogradRuntime::Get()->MarkVariables(&variables);
+ API_END();
+}
+
+int MXAutogradComputeGradient(mx_uint num_output,
+ NDArrayHandle *output_handles,
+ mx_uint* num_grad,
+ NDArrayHandle **grad_handles) {
+ API_BEGIN();
+ MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+
+ std::vector<NDArray> outputs;
+ outputs.reserve(num_output);
+ for (mx_uint i = 0; i < num_output; ++i) {
+ outputs.emplace_back(*static_cast<NDArray*>(output_handles[i]));
+ }
+
+ std::vector<NDArray> grads =
+ AutogradRuntime::Get()->ComputeGradient(outputs);
+
+ ret->ret_handles.resize(grads.size());
+ for (size_t i = 0; i < grads.size(); ++i) {
+ NDArray *ptr = new NDArray();
+ *ptr = grads[i];
+ ret->ret_handles[i] = ptr;
+ }
+ *num_grad = static_cast<mx_uint>(grads.size());
+ *grad_handles = dmlc::BeginPtr(ret->ret_handles);
+ API_END();
+}
diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc
index 71fb635..9a41554 100644
--- a/src/executor/attach_op_execs_pass.cc
+++ b/src/executor/attach_op_execs_pass.cc
@@ -42,7 +42,8 @@
Operator::ExecType exec_type() const override {
return op_->exec_type();
}
- explicit ForwardOpExecutor(Operator* op, std::vector<uint32_t> aux_index)
+ explicit ForwardOpExecutor(std::shared_ptr<Operator> op,
+ std::vector<uint32_t> aux_index)
: op_(op), aux_index_(aux_index) {
std::sort(aux_index_.begin(), aux_index_.end());
}
@@ -170,6 +171,8 @@
const auto& vdtype = g.GetAttr<DTypeVector>("dtype");
const auto& vshape = g.GetAttr<ShapeVector>("shape");
const auto& vctx = g.GetAttr<ContextVector>("context");
+ const auto& saved_opr = g.GetAttr<
+ std::unordered_map<const nnvm::Node*, std::shared_ptr<Operator>>>("saved_opr");
// get the graph
const auto& idx = g.indexed_graph();
@@ -191,12 +194,17 @@
ishape.emplace_back(vshape[idx.entry_id(e)]);
itype.emplace_back(vdtype[idx.entry_id(e)]);
}
- ret[i] = std::make_shared<ForwardOpExecutor>(
- fcreate_layer_op[inode.source->op()](
- inode.source->attrs, vctx[i], ishape, itype), mutate_index);
+ std::shared_ptr<Operator> opr;
+ if (saved_opr.count(inode.source)) {
+ opr = saved_opr.at(inode.source);
+ } else {
+ opr.reset(fcreate_layer_op[inode.source->op()](
+ inode.source->attrs, vctx[i], ishape, itype));
+ }
+ ret[i] = std::make_shared<ForwardOpExecutor>(opr, mutate_index);
} else if (is_layer_backward.get(inode.source->op(), false)) {
+ CHECK_GE(inode.control_deps.size(), 1);
uint32_t fwd_id = inode.control_deps[0];
- CHECK_GE(inode.control_deps.size(), 1U);
CHECK(vctx[fwd_id] == vctx[i]);
CHECK(ret[fwd_id] != nullptr);
ret[i] = std::make_shared<BackwardOpExecutor>(
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 2a372d9..ca7d45a 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -330,6 +330,7 @@
nnvm::Graph g = InitGraph(symbol, default_ctx,
ctx_map, in_args, arg_grad_store,
grad_req_type, aux_states);
+ g.attrs["saved_opr"] = std::make_shared<nnvm::any>(std::move(saved_opr_));
g = AttachOpExecs(g);
g = AttachOpResources(g);
graph_ = std::move(g);
diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h
index 46b33a5..8ac8631 100644
--- a/src/executor/graph_executor.h
+++ b/src/executor/graph_executor.h
@@ -20,6 +20,25 @@
#include "./exec_pass.h"
namespace mxnet {
+
+using NodeOperatorMap = std::unordered_map<const nnvm::Node*,
+ std::shared_ptr<Operator>>;
+
+// forward declaration
+namespace exec {
+class GraphExecutor;
+}
+
+// forward declaration
+namespace autograd {
+exec::GraphExecutor *Bind(nnvm::Symbol symbol,
+ const nnvm::NodeEntryMap<TShape>& shapes,
+ const nnvm::NodeEntryMap<Context>& ctxs,
+ const NodeOperatorMap& saved_opr);
+std::vector<NDArray> Run(exec::GraphExecutor* exec,
+ const nnvm::NodeEntryMap<NDArray>& feed_dict);
+}
+
namespace exec {
using nnvm::Graph;
@@ -27,6 +46,13 @@
// graph executors
class GraphExecutor : public Executor {
public:
+ friend GraphExecutor *autograd::Bind(nnvm::Symbol symbol,
+ const nnvm::NodeEntryMap<TShape>& shapes,
+ const nnvm::NodeEntryMap<Context>& ctxs,
+ const NodeOperatorMap& saved_opr);
+ friend std::vector<NDArray> autograd::Run(GraphExecutor* exec,
+ const nnvm::NodeEntryMap<NDArray>& feed_dict);
+
using Executor::MonitorCallback;
virtual ~GraphExecutor();
@@ -133,6 +159,8 @@
size_t num_forward_inputs_{0};
// number of forward nodes
size_t num_forward_nodes_{0};
+ // saved operator for autograd
+ NodeOperatorMap saved_opr_;
// monitor call back
std::function<void(const char*, void*)> monitor_callback_{nullptr};
// whether to enable bulk execution
diff --git a/src/ndarray/autograd.cc b/src/ndarray/autograd.cc
new file mode 100644
index 0000000..18e7882
--- /dev/null
+++ b/src/ndarray/autograd.cc
@@ -0,0 +1,218 @@
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file autograd.cc
+ * \brief Implementation of AutogradRuntime module.
+ */
+
+#include <mxnet/operator.h>
+#include <mxnet/executor.h>
+#include <nnvm/pass_functions.h>
+#include <unordered_set>
+#include <iostream>
+#include "../executor/graph_executor.h"
+#include "./autograd.h"
+
+namespace mxnet {
+namespace autograd {
+
+using nnvm::Symbol;
+using nnvm::Node;
+using nnvm::NodePtr;
+using nnvm::NodeEntry;
+using nnvm::NodeEntryMap;
+using exec::GraphExecutor;
+
+AutogradRuntime::AutogradRuntime() {}
+
+void AutogradRuntime::SetRecording(bool recording) {
+ is_recording_ = recording;
+}
+
+bool AutogradRuntime::IsRecording() const {
+ return is_recording_;
+}
+
+void AutogradRuntime::MarkVariables(std::vector<NDArray*> *p_variables) {
+ std::vector<NDArray*>& variables = *p_variables;
+ for (NDArray* var : variables) {
+ NodeEntry& e = var->entry_;
+ e.node = Node::Create();
+ e.node->attrs.name = "ag_variables_" + std::to_string(variable_count_++);
+ }
+}
+
+void AutogradRuntime::RecordImperativeFCompute(FCompute fn,
+ const nnvm::Op* op,
+ const nnvm::NodeAttrs& attrs,
+ std::vector<NDArray> *p_inputs,
+ std::vector<NDArray> *p_outputs) {
+ RecordOp(op, attrs, p_inputs, p_outputs);
+}
+
+void AutogradRuntime::RecordImperativeOperator(std::shared_ptr<Operator> opr,
+ const nnvm::Op* op,
+ const nnvm::NodeAttrs& attrs,
+ std::vector<NDArray> *p_inputs,
+ std::vector<NDArray> *p_outputs) {
+ NodePtr node = RecordOp(op, attrs, p_inputs, p_outputs);
+ saved_opr_.insert({node.get(), opr});
+}
+
+std::vector<NDArray> Execute(Symbol sym,
+ const NodeEntryMap<NDArray>& feed_dict,
+ const NodeOperatorMap& saved_opr);
+
+std::vector<NDArray> AutogradRuntime::ComputeGradient(const std::vector<NDArray>& outputs) {
+ Symbol ff_sym;
+ for (size_t i = 0; i < outputs.size(); ++i) {
+ ff_sym.outputs.push_back(outputs[i].entry_);
+ }
+ std::vector<NDArray> result = Execute(ff_sym, saved_ndarray_, saved_opr_);
+ ClearRecords();
+ return result;
+}
+
+std::shared_ptr<AutogradRuntime> AutogradRuntime::_GetSharedRef() {
+ static std::shared_ptr<AutogradRuntime> inst(new AutogradRuntime());
+ return inst;
+}
+
+AutogradRuntime* AutogradRuntime::Get() {
+ static AutogradRuntime *ptr = _GetSharedRef().get();
+ return ptr;
+}
+
+NodePtr AutogradRuntime::RecordOp(const nnvm::Op* op,
+ const nnvm::NodeAttrs& attrs,
+ std::vector<NDArray> *p_inputs,
+ std::vector<NDArray> *p_outputs) {
+ std::vector<NDArray>& inputs = *p_inputs;
+ std::vector<NDArray>& outputs = *p_outputs;
+
+ NodePtr node = Node::Create();
+ node->attrs = attrs;
+ node->attrs.name = "ag_" + op->name + "_" + std::to_string(node_count_++);
+
+ for (size_t i = 0; i < outputs.size(); ++i) {
+ NodeEntry &e = outputs[i].entry_;
+ e.node = node;
+ e.index = i;
+ saved_ndarray_[e] = outputs[i];
+ }
+
+ for (size_t i = 0; i < inputs.size(); ++i) {
+ NodeEntry &e = inputs[i].entry_;
+ CHECK(e.node.get() != nullptr)
+ << "not support partial gradient yet, all the "
+ << "inputs of autograd should be marked as variable";
+ if (!saved_ndarray_.count(e)) {
+ saved_ndarray_[e] = inputs[i];
+ }
+ node->inputs.push_back(e);
+ }
+ return node;
+}
+
+void AutogradRuntime::ClearRecords() {
+ node_count_ = 0;
+ variable_count_ = 0;
+ saved_ndarray_.clear();
+ saved_opr_.clear();
+}
+
+GraphExecutor *Bind(Symbol symbol,
+ const NodeEntryMap<TShape>& shapes,
+ const NodeEntryMap<Context>& ctxs,
+ const NodeOperatorMap& saved_opr) {
+ std::vector<NodePtr> input_nodes =
+ symbol.ListInputs(Symbol::ListInputOption::kAll);
+
+ size_t input_size = input_nodes.size();
+ std::vector<NDArray> inputs;
+ inputs.reserve(input_size);
+ std::vector<NDArray> grads;
+ grads.reserve(input_size);
+ std::vector<OpReqType> grad_reqs;
+ grad_reqs.reserve(input_size);
+
+ // prepare inputs and set grad for every input
+ for (size_t i = 0; i < input_size; ++i) {
+ NodeEntry e = NodeEntry{input_nodes[i], 0, 0};
+ if (shapes.count(e) && ctxs.count(e)) {
+ TShape shape = shapes.at(e);
+ Context ctx = ctxs.at(e);
+ inputs.emplace_back(shape, ctx);
+ NDArray grad(shape, ctx);
+ grad = static_cast<real_t>(1.0);
+ grads.emplace_back(grad);
+ grad_reqs.emplace_back(OpReqType::kWriteTo);
+ } else {
+ LOG(FATAL) << "no corresponding ndarray: "
+ << input_nodes[i]->attrs.name << "(0)";
+ }
+ }
+
+ // default context; assume all inputs share the same context
+ CHECK_GT(ctxs.size(), 0)
+ << "The size of context mapping should be greater than zero";
+ Context ctx = ctxs.begin()->second;
+
+ std::map<std::string, Context> ctx_map;
+ std::vector<NDArray> aux_states;
+
+ auto exec = new exec::GraphExecutor();
+ // TODO: this is hacky; saved operators are injected directly into the executor
+ exec->saved_opr_ = saved_opr;
+ exec->Init(symbol, ctx, ctx_map,
+ inputs, grads, grad_reqs, aux_states);
+
+ return exec;
+}
+
+std::vector<NDArray> Run(GraphExecutor* exec,
+ const NodeEntryMap<NDArray>& feed_dict) {
+ const nnvm::IndexedGraph& idx = exec->graph_.indexed_graph();
+
+ for (const auto& kv : feed_dict) {
+ if (idx.exist(kv.first.node.get())) {
+ uint32_t entry_id = idx.entry_id(kv.first);
+ CopyFromTo(kv.second, &(exec->data_entry_[entry_id]));
+ }
+ }
+
+ std::vector<NDArray> head_grads;
+ head_grads.reserve(exec->head_grad_array_.size());
+
+ for (size_t i = 0; i < exec->output_arrays_.size(); ++i) {
+ NDArray grad(exec->output_arrays_[i].shape(), exec->output_arrays_[i].ctx());
+ grad = static_cast<real_t>(1.0);
+ head_grads.push_back(grad);
+ }
+
+ exec->Backward(head_grads);
+
+ std::vector<NDArray> results;
+ results.reserve(exec->grad_store_.size());
+ for (const auto& kv : exec->grad_store_) {
+ results.emplace_back(kv.second);
+ }
+ return results;
+}
+
+std::vector<NDArray> Execute(Symbol sym,
+ const NodeEntryMap<NDArray>& feed_dict,
+ const NodeOperatorMap& saved_opr) {
+ NodeEntryMap<TShape> shapes;
+ NodeEntryMap<Context> ctxs;
+ for (const auto& kv : feed_dict) {
+ const NodeEntry& e = kv.first;
+ shapes.insert({kv.first, kv.second.shape()});
+ ctxs.insert({kv.first, kv.second.ctx()});
+ }
+ exec::GraphExecutor *exec = Bind(sym, shapes, ctxs, saved_opr);
+ std::vector<NDArray> res = Run(exec, feed_dict);
+ return res;
+}
+
+} // namespace autograd
+} // namespace mxnet
diff --git a/src/ndarray/autograd.h b/src/ndarray/autograd.h
new file mode 100644
index 0000000..67f845f
--- /dev/null
+++ b/src/ndarray/autograd.h
@@ -0,0 +1,86 @@
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file autograd.h
+ * \brief AutogradRuntime can automatically compute gradients
+ */
+#ifndef MXNET_NDARRAY_AUTOGRAD_H_
+#define MXNET_NDARRAY_AUTOGRAD_H_
+
+#include <dmlc/logging.h>
+#include <mxnet/base.h>
+#include <mxnet/ndarray.h>
+#include <mxnet/op_attr_types.h>
+#include <nnvm/symbolic.h>
+#include <nnvm/op.h>
+#include <nnvm/graph.h>
+#include <vector>
+#include <unordered_map>
+
+namespace mxnet {
+namespace autograd {
+
+/*!
+ * \brief AutogradRuntime Interface
+ */
+class AutogradRuntime {
+ public:
+ /*! \brief turn on or turn off operator recording for autograd. */
+ void SetRecording(bool recording);
+ /*! \brief whether operator recording is on. */
+ bool IsRecording() const;
+ /*! \brief mark variables for computing gradients. */
+ void MarkVariables(std::vector<NDArray*>* p_variables);
+ /*! \brief record an imperative operator that is executed via FCompute. */
+ void RecordImperativeFCompute(FCompute fn,
+ const nnvm::Op* op,
+ const nnvm::NodeAttrs& attrs,
+ std::vector<NDArray>* p_inputs,
+ std::vector<NDArray>* p_outputs);
+ /*! \brief record an imperative operator that is executed via a stateful Operator. */
+ void RecordImperativeOperator(std::shared_ptr<Operator> opr,
+ const nnvm::Op* op,
+ const nnvm::NodeAttrs& attrs,
+ std::vector<NDArray>* p_inputs,
+ std::vector<NDArray>* p_outputs);
+ /*! \brief compute the gradients of outputs w.r.t. variables. */
+ std::vector<NDArray> ComputeGradient(const std::vector<NDArray>& outputs);
+ /*! \return AutogradRuntime singleton */
+ static AutogradRuntime* Get();
+ /*! \brief Get shared pointer reference to AutogradRuntime singleton.
+ * Most users should not call this function.
+ * It is intended for another singleton X that requires
+ * AutogradRuntime to be destructed after X.
+ *
+ * \return A shared pointer to AutogradRuntime singleton.
+ */
+ static std::shared_ptr<AutogradRuntime> _GetSharedRef();
+
+ protected:
+ /*! \brief make constructor protected. */
+ AutogradRuntime();
+
+ private:
+ /*! \brief record an operator and return the corresponding node. */
+ nnvm::NodePtr RecordOp(const nnvm::Op* op,
+ const nnvm::NodeAttrs& attrs,
+ std::vector<NDArray>* p_inputs,
+ std::vector<NDArray>* p_outputs);
+ /*! \brief clear the record data. */
+ void ClearRecords();
+ /*! \brief AutogradRuntime singleton. */
+ static AutogradRuntime* instance_;
+ /*! \brief indicate whether operator recording is on. */
+ bool is_recording_{false};
+ /*! \brief node count used for naming */
+ int node_count_{0};
+ /*! \brief variable count used for naming */
+ int variable_count_{0};
+ /*! \brief mapping from node entry to saved ndarray. */
+ nnvm::NodeEntryMap<NDArray> saved_ndarray_;
+ /*! \brief mapping from node to saved operator. */
+ std::unordered_map<const nnvm::Node*, std::shared_ptr<Operator>> saved_opr_;
+};
+
+} // namespace autograd
+} // namespace mxnet
+#endif // MXNET_NDARRAY_AUTOGRAD_H_
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h
index 947df05..d0b2a84 100755
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -174,7 +174,6 @@
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
- // LOG(INFO) << attrs.name;
using namespace broadcast;
TShape new_lshape, new_rshape, new_oshape;
int ndim = BinaryBroadcastShapeCompact(outputs[0].shape_, outputs[1].shape_, inputs[0].shape_,
diff --git a/tests/python/unittest/test_autograd.py b/tests/python/unittest/test_autograd.py
new file mode 100644
index 0000000..3ded9fa
--- /dev/null
+++ b/tests/python/unittest/test_autograd.py
@@ -0,0 +1,62 @@
+import mxnet.ndarray as nd
+from mxnet.autograd import grad, grad_and_loss
+
+def autograd_assert(*args, **kwargs):
+ f = kwargs["func"]
+ grad_f = kwargs["grad_func"]
+
+ grad_func = grad_and_loss(f)
+ grad_vals, output = grad_func(*args)
+ res = f(*args)
+ assert output == res
+ grad_res = grad_f(*args)
+ assert len(grad_vals) == len(grad_res)
+ for a, b in zip(grad_vals, grad_res):
+ assert a == b
+
+def test_unary_func():
+ x = nd.uniform(shape=(4, 5))
+ f_exp = lambda x: nd.exp(x)
+ f_exp_grad = lambda x: [x]
+ autograd_assert(x, func=f_exp, grad_func=f_exp_grad)
+ f_half = lambda x: x/2
+ f_half_grad = lambda x: [nd.ones(x.shape) * 0.5]
+ autograd_assert(x, func=f_half, grad_func=f_half_grad)
+ f_square = lambda x: x**2
+ f_square_grad = lambda x: [2*x]
+ autograd_assert(x, func=f_square, grad_func=f_square_grad)
+
+def test_binary_func():
+ x = nd.uniform(shape=(4, 5))
+ y = nd.uniform(shape=(4, 5))
+ f_add = lambda x, y: x+y
+ f_add_grad = lambda x, y: [nd.ones(x.shape), nd.ones(y.shape)]
+ autograd_assert(x, y, func=f_add, grad_func=f_add_grad)
+ f_mul = lambda x, y: x*y
+ f_mul_grad = lambda x, y: [y, x]
+ autograd_assert(x, y, func=f_mul, grad_func=f_mul_grad)
+ f_compose = lambda x, y: x+x*y
+ f_compose_grad = lambda x, y: [nd.ones(x.shape) + y, x]
+ autograd_assert(x, y, func=f_compose, grad_func=f_compose_grad)
+
+def test_operator_with_state():
+ def f_fc(a, b, weight, bias):
+ x = a*b
+ fc = nd.FullyConnected(
+ x, weight, bias, num_hidden=32)
+ return fc
+
+ a = nd.uniform(shape=(64, 50))
+ b = nd.uniform(shape=(64, 50))
+ weight = nd.uniform(shape=(32, 50))
+ bias = nd.uniform(shape=(32, ))
+
+ grad_func = grad_and_loss(f_fc)
+ grad_vals, outputs = grad_func(a, b, weight, bias)
+ # TODO: add assertions on grad_vals and outputs
+
+
+if __name__ == "__main__":
+ test_unary_func()
+ test_binary_func()
+ test_operator_with_state()
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index fd05225..2045293 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -691,7 +691,8 @@
exe_deconv.forward(is_train=True)
deconv_out_grad = conv_data[:]
exe_deconv.backward(deconv_out_grad)
- assert_almost_equal(conv_args_grad[1].asnumpy(), deconv_args_grad[1].asnumpy(), rtol=1e-3)
+ assert_almost_equal(conv_args_grad[1].asnumpy(), deconv_args_grad[1].asnumpy(),
+ rtol=2e-3, atol=1e-2)
# Test AddTo
exe_deconv_addto = deconv.bind(default_context(), args=deconv_args,
args_grad=deconv_addto_args_grad,