/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2016 by Contributors
* \file gradients.cc
* \brief Passes that take the gradient of the graph
* This code was modified based on the mxnet codebase by Min Lin
*/
#include <nnvm/graph_attr_types.h>
#include <nnvm/pass.h>
#include <nnvm/op_attr_types.h>
#include <mxnet/base.h>
#include <algorithm>
#include <deque>
#include <fstream>
#include <functional>
#include <queue>
#include <sstream>
#include <unordered_set>
#include <vector>
#include "error.h"
#include "../imperative/exec_pass.h"
namespace nnvm {
namespace pass {
extern size_t MXGetDTypeSize(const int type_flag); // defined in plan_memory.cc
namespace {
/*! \brief Auxiliary data structure for gradient entries. */
struct GradEntry {
// aggregated gradient, computed lazily from `grads` by the aggregation function
NodeEntry sum = NodeEntry(nullptr, 0, 0);
// individual gradient contributions that will be aggregated into `sum`
std::vector<NodeEntry> grads;
};
/*!
* \brief Build the backward graph from the mirror map. This function will be
* invoked twice if backward mirroring has been enabled.
*/
Graph BuildGradientGraph(
const Graph& src,
const std::vector<NodeEntry>& xs,
const std::vector<ObjectPtr>& topo_order,
std::unordered_map<const Node*, std::vector<GradEntry> > output_grads,
std::function<int(const Node&)> mirror_fun,
const std::unordered_map<const Node*, ObjectPtr>& mirror_map);
/*!
* \brief Auxiliary function that maps the forward node of the source graph to
* its corresponding node on the mirror path.
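* A node that is absent from the map, or mapped to nullptr (i.e., not on
* the mirror path), is returned unchanged.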
*/
inline const ObjectPtr& MapFwdNodeToMirrorPath(
const ObjectPtr& n,
const std::unordered_map<const Node*, ObjectPtr>& mirror_map) {
auto iter = mirror_map.find(n.get());
if (iter == mirror_map.end() ||
iter->second == nullptr) {
return n;
}
return iter->second;
}
Graph Gradient(Graph src) {
CHECK_NE(src.attrs.count("grad_ys"), 0U)
<< "Gradient pass requires grad_ys to be present.";
CHECK_NE(src.attrs.count("grad_xs"), 0U)
<< "Gradient pass requires grad_xs to be present.";
CHECK_NE(src.attrs.count("grad_ys_out_grad"), 0U)
<< "Gradient pass requires grad_ys_out_grad to be present.";
const std::vector<NodeEntry>& xs =
src.GetAttr<std::vector<NodeEntry> >("grad_xs");
const std::vector<NodeEntry>& ys =
src.GetAttr<std::vector<NodeEntry> >("grad_ys");
const std::vector<NodeEntry>& ys_out_grad =
src.GetAttr<std::vector<NodeEntry> >("grad_ys_out_grad");
CHECK_EQ(ys.size(), ys_out_grad.size());
// initialize a topological order of the graph nodes and `output_grads`
// that maps every operator node to its gradient entries
std::vector<ObjectPtr> topo_order;
std::unordered_map<const Node*, std::vector<GradEntry> > output_grads;
DFSVisit(ys,
[&](const ObjectPtr& node) {
if (output_grads.count(node.get()) == 0) {
output_grads[node.get()].resize(node->num_outputs());
}
topo_order.push_back(node);
});
for (size_t i = 0; i < ys.size(); ++i) {
output_grads[ys[i].node.get()][ys[i].index].grads = {ys_out_grad[i]};
}
// check that all xs are reachable from ys
for (size_t i = 0; i < xs.size(); ++i) {
CHECK(output_grads.find(xs[i].node.get()) != output_grads.end())
<< "Cannot differentiate with respect to the "
<< (i + 1) << "-th variable "
<< "because it is unreachable from the outputs.";
}
using MirrorFun = std::function<int (const Node&)>;
MirrorFun mirror_fun = nullptr;
if (src.attrs.count("mirror_fun") != 0) {
mirror_fun = src.GetAttr<MirrorFun>("mirror_fun");
}
std::unordered_map<const Node*, ObjectPtr> mirror_map;
// complete the backward graph of the src, but without backward mirroring
nnvm::Graph gsrc = BuildGradientGraph(src, xs, topo_order,
output_grads,
nullptr, mirror_map);
if (mirror_fun == nullptr) {
return gsrc; // Gradient pass without mirroring ends here.
}
const IndexedGraph& idx = src.indexed_graph(),
& gidx = gsrc.indexed_graph();
// ===========================================================================
// ----- Gradient Pass w/ Backward Mirroring -----
// ===========================================================================
// Record, for each node entry ∈ gsrc, the nodes that reference it as an input.
// Note that since the node entry reference mapping is constructed from the
// gradient graph, it can only be indexed using gidx entry IDs.
std::vector<std::unordered_set<const Node*> > node_entry_ref_map(
gidx.num_node_entries());
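// E.g., if the output entry of some node D is consumed by nodes A and B, the
// set stored at that entry's gidx entry ID will be {A, B}.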
static const auto& fignore_inputs = Op::GetAttr<FIgnoreInputs>("FIgnoreInputs");
for (uint32_t gnid = 0; gnid < gidx.num_nodes(); ++gnid) {
const IndexedGraph::Node& inode = gidx[gnid];
if (inode.source->is_variable()) {
continue;
}
for (uint32_t i = 0; i < inode.inputs.size(); ++i) {
if (fignore_inputs.count(inode.source->op()) != 0) {
std::vector<uint32_t> ignore_inputs =
fignore_inputs[inode.source->op()](inode.source->attrs);
if (std::find(ignore_inputs.begin(), ignore_inputs.end(), i)
!= ignore_inputs.end()) {
continue;
}
}
node_entry_ref_map[gidx.entry_id(inode.inputs[i])].insert(inode.source);
}
} // for (gnid ∈ gidx.num_nodes())
// Infer the shapes and data types of the source (forward) graph. This
// information is needed in later stages to determine whether putting a node
// on the mirror path is beneficial or not.
using mxnet::ShapeVector;
ShapeVector in_arg_shapes = src.GetAttr<ShapeVector>("in_arg_shapes");
DTypeVector in_arg_dtypes = src.GetAttr<DTypeVector>("in_arg_dtypes");
src = mxnet::exec::InferShape(std::move(src), std::move(in_arg_shapes), "__shape__");
src = mxnet::exec::InferType(std::move(src), std::move(in_arg_dtypes), "__dtype__");
CHECK(src.GetAttr<size_t>("shape_num_unknown_nodes") == 0U);
CHECK(src.GetAttr<size_t>("dtype_num_unknown_nodes") == 0U);
const ShapeVector& src_shapes = src.GetAttr<ShapeVector>("shape");
const DTypeVector& src_dtypes = src.GetAttr<DTypeVector>("dtype");
std::queue<const Node*> worklist;
// initialize the worklist to the output nodes
for (const NodeEntry& e : src.outputs) {
worklist.push(e.node.get());
}
for (; !worklist.empty(); worklist.pop()) {
const Node* const workitem = worklist.front();
// skip the current node if it has already been recorded in the mirror map
if (mirror_map.find(workitem) != mirror_map.end()) {
continue;
}
// the subgraph, its frontier, and its topologically-sorted view
std::unordered_set<const Node*> subgraph;
// The associated boolean value marks whether a frontier node has already been
// forward propagated (see the forward pass below).
std::unordered_map<const Node*, bool> subgraph_frontier;
std::deque<const Node*> subgraph_topo_order;
// =========================================================================
// --- Backward Pass ---
// =========================================================================
// The sub-worklist is used for constructing the subgraph. It is initialized
// to have the current workitem node.
std::queue<const Node*> subworklist;
subworklist.push(workitem);
// Local auxiliary function that does backpropagation on the subworklist
// items to construct the subgraph. E.g.,
// A subworklist = {A}
// ↑
// B
// After invoking this function, `subgraph` will become {A, B}.
// Note that this function will be invoked multiple times.
auto subworklist_backprop = [&subworklist, &subgraph,
&subgraph_topo_order,
&mirror_fun, &worklist]() {
std::deque<const Node*> subworklist_topo_order;
for (; !subworklist.empty(); subworklist.pop()) {
const Node* const subworkitem = subworklist.front();
if (subgraph.find(subworkitem) == subgraph.end()) {
subgraph.insert(subworkitem);
subworklist_topo_order.push_front(subworkitem);
}
for (const NodeEntry& e : subworkitem->inputs) {
if (!mirror_fun(*(e.node))) {
worklist.push(e.node.get());
} else {
subworklist.push(e.node.get());
}
}
for (const ObjectPtr& n : subworkitem->control_deps) {
if (!mirror_fun(*n)) {
worklist.push(n.get());
} else {
subworklist.push(n.get());
}
}
} // for (subworkitem ∈ subworklist)
// please refer to later comments for why the topological order of the
// subworklist should be directly appended to that of the subgraph
subgraph_topo_order.insert(subgraph_topo_order.end(),
subworklist_topo_order.begin(),
subworklist_topo_order.end());
};
// Start propagating from the current workitem node backward until the
// mirroring function returns false (indicating that a compute-heavy layer
// has been hit), in which case we put the node that fails the mirroring
// function into the worklist as the new head. During the traversal, we
// build up the subgraph and its topological order at the same time.
subworklist_backprop();
// Forward propagate the subgraph nodes in topological order and make sure
// that all the node entries that are part of the forward propagation belong
// to the same subgraph. This process continues until all the node entries
// have been included, in which case we say that the subgraph has converged.
//
// This step is needed because of cases like the example below:
// A B C subworklist = {A}
// ↑ ↑ ↑
// ↖ ↑ ↗
// D
// Without loss of generality, suppose that the previous backpropagation
// starts from node A; then the subgraph will only contain the branch D → A.
// However, we want to include the branches D → B and D → C as well, since all
// three branches share the same node entries (i.e., the outputs of D) and
// hence they are all affected by the decision on whether D should be put
// onto the mirror path or not.
bool has_subgraph_converged;
do {
has_subgraph_converged = true;
for (const Node* const subgraph_node : subgraph_topo_order) {
for (const NodeEntry& subgraph_node_entry :
subgraph_node->inputs) {
const std::unordered_set<const Node*> ref_nodes =
node_entry_ref_map[gidx.entry_id(subgraph_node_entry)];
for (const Node* const ref_node : ref_nodes) {
// If another node references the node entry and satisfies the
// following conditions:
// (1) it belongs to the forward graph, and
// (2) it is not yet part of the subgraph,
// we add that node to the subgraph and adjust the topological order
// accordingly.
if (ref_node != subgraph_node && idx.exist(ref_node) &&
subgraph.find(ref_node) == subgraph.end()) {
// Forward propagate from the reference node until the mirroring
// function returns false. This indicates that the head of the
// branch has been reached (i.e., B or C in our previously
// illustrated example), and we add it to the subworklist for
// another backpropagation.
std::queue<const Node*> ref_node_heads;
ref_node_heads.push(ref_node);
for (; !ref_node_heads.empty(); ref_node_heads.pop()) {
const Node* const ref_node_head = ref_node_heads.front();
bool is_ref_node_head_output = false;
for (const NodeEntry& y : ys) {
if (ref_node_head == y.node.get()) {
is_ref_node_head_output = true;
}
}
if (!mirror_fun(*ref_node_head) || is_ref_node_head_output) {
subworklist.push(ref_node_head);
continue;
}
uint32_t gnid = gidx.node_id(ref_node_head);
for (uint32_t oid = 0; oid < ref_node_head->num_outputs(); ++oid) {
uint32_t geid = gidx.entry_id(gnid, oid);
for (const Node* const n : node_entry_ref_map[geid]) {
if (idx.exist(n)) {
ref_node_heads.push(n);
}
}
} // for (oid ∈ [0, ref_node_head->num_outputs()))
} // for (ref_node_head ∈ ref_node_heads)
// Do the backpropagation again. The topological order of the
// subworklist can be directly appended to the end of the existing
// order. E.g., in our previous example, we expect to have
// `subgraph_topo_order` = {D, A} + {B} + {C}.
subworklist_backprop();
// Indicate that the subgraph has not yet converged and quit the loops.
has_subgraph_converged = false;
break;
} // if (ref_node != subgraph_node && idx.exist(ref_node) &&
// subgraph.find(ref_node) == subgraph.end()
} // for (ref_node ∈ ref_nodes)
if (!has_subgraph_converged) {
break;
}
} // for (subgraph_node_entry ∈ subgraph_node->inputs)
if (!has_subgraph_converged) {
break;
}
} // for (subgraph_node ∈ subgraph_topo_order)
} while (!has_subgraph_converged);
// =========================================================================
// --- Forward Pass ---
// =========================================================================
// Now that the subgraph is complete, we start by assuming that all the
// nodes in the subgraph can be mirrored, and forward propagate starting
// from the subgraph frontier. The propagation is successful if the amount
// of storage released by removing the frontier nodes from the mirror path is
// greater than or equal to the amount of storage newly allocated.
do {
has_subgraph_converged = true;
// Obtain the subgraph frontier. The subgraph frontier denotes the group of
// nodes each of whose inputs satisfies one of the following conditions:
// (1) it fails the mirroring function, or
// (2) it has been marked as NOT on the mirror path, i.e.,
// `mirror_map[input_node] == nullptr`.
// E.g., consider the subgraph below:
// A
// ↑
// B
// ↑
// C
// The subgraph frontier in this example is {C}, as C is the place where
// the mirror path (and hence the forward propagation) starts.
subgraph_frontier.clear();
for (const Node* const subgraph_node : subgraph) {
if (!mirror_fun(*subgraph_node)) {
mirror_map[subgraph_node] = nullptr;
continue;
}
if (mirror_map.find(subgraph_node) != mirror_map.end()) {
continue;
}
bool is_frontier = true;
for (const NodeEntry& e : subgraph_node->inputs) {
auto iter = mirror_map.find(e.node.get());
if (mirror_fun(*(e.node)) &&
!(iter != mirror_map.end() && iter->second == nullptr)) {
is_frontier = false;
}
}
for (const ObjectPtr& n : subgraph_node->control_deps) {
auto iter = mirror_map.find(n.get());
if (mirror_fun(*n) &&
!(iter != mirror_map.end() && iter->second == nullptr)) {
is_frontier = false;
}
}
if (is_frontier) {
subgraph_frontier.emplace(subgraph_node, false);
}
} // for (subgraph_node ∈ subgraph)
for (std::pair<const Node* const, bool>& frontier_node : subgraph_frontier) {
if (frontier_node.second) {
// If the frontier node has been marked as true, this indicates that it
// has already been forward propagated (together with other frontier
// nodes that share the same input).
continue;
}
// As we do the forward propagation, we propagate not only the current
// frontier node individually, but also all the frontier nodes that share
// an input with it. This is a recursive process because it is possible
// for A to share an input with B and, at the same time, for B to share
// an input with C, as in the graph below:
// D E
// ↑ ↑
// ↗ ↖ ↗ ↖
// A B C
std::unordered_set<const Node*> forward_candidates{frontier_node.first};
frontier_node.second = true;
bool has_forward_candidates_converged;
do {
has_forward_candidates_converged = true;
for (const Node* const candidate : forward_candidates) {
for (const NodeEntry& candidate_input : candidate->inputs) {
uint32_t geid = gidx.entry_id(candidate_input);
const std::unordered_set<const Node*>& ref_nodes = node_entry_ref_map[geid];
for (const Node* const ref_node : ref_nodes) {
if (ref_node != frontier_node.first &&
subgraph_frontier.find(ref_node) != subgraph_frontier.end() &&
forward_candidates.find(ref_node) == forward_candidates.end()) {
subgraph_frontier[ref_node] = true;
forward_candidates.insert(ref_node);
has_forward_candidates_converged = false;
}
} // for (ref_node ∈ ref_nodes)
if (!has_forward_candidates_converged) {
break;
}
} // for (candidate_input ∈ candidate->inputs)
if (!has_forward_candidates_converged) {
break;
}
} // for (candidate ∈ forward_candidates)
} while (!has_forward_candidates_converged);
// Record the node entries that are newly allocated and those that are
// released. A node entry can be released if all its referencing nodes
// are part of the subgraph frontier. Otherwise, it is in the allocated set.
std::unordered_set<uint32_t> newly_allocated_node_entries,
released_node_entries;
for (const Node* const candidate : forward_candidates) {
uint32_t nid = idx.node_id(candidate),
gnid = gidx.node_id(candidate);
for (uint32_t oid = 0; oid < candidate->num_outputs(); ++oid) {
uint32_t eid = idx.entry_id(nid, oid),
geid = gidx.entry_id(gnid, oid);
if (node_entry_ref_map[geid].size() != 0) {
newly_allocated_node_entries.insert(eid);
}
}
for (const NodeEntry& candidate_input : candidate->inputs) {
uint32_t eid = idx.entry_id(candidate_input),
geid = gidx.entry_id(candidate_input);
const std::unordered_set<const Node*>& ref_nodes = node_entry_ref_map[geid];
bool can_be_released = true;
for (const Node* const ref_node : ref_nodes) {
if (subgraph_frontier.find(ref_node) == subgraph_frontier.end()) {
newly_allocated_node_entries.insert(eid);
can_be_released = false;
}
}
if (can_be_released) {
released_node_entries.insert(eid);
}
} // for (candidate_input ∈ candidate->input)
} // for (candidate ∈ forward_candidates)
// Now, compare the total amount of newly allocated storage against the
// released storage. If the latter is greater than or equal to the former,
// we remove the forward candidates from the frontier and mark them as NOT
// on the mirror path. Otherwise, the forward candidate nodes will later be
// marked as on the mirror path.
size_t newly_allocated_storage = 0, released_storage = 0;
for (const uint32_t eid : newly_allocated_node_entries) {
newly_allocated_storage += src_shapes[eid].Size() *
MXGetDTypeSize(src_dtypes[eid]);
}
for (const uint32_t eid : released_node_entries) {
released_storage += src_shapes[eid].Size() * MXGetDTypeSize(src_dtypes[eid]);
}
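// (Hypothetical example: a float32 node entry of shape (64, 128) would
// contribute 64 * 128 * 4 = 32768 bytes to whichever sum it belongs to.)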
if (released_storage >= newly_allocated_storage) {
for (const Node* const candidate : forward_candidates) {
CHECK(subgraph_frontier.find(candidate) != subgraph_frontier.end());
subgraph_frontier.erase(candidate);
mirror_map[candidate] = nullptr;
}
has_subgraph_converged = false;
break;
} // if (released_storage >= newly_allocated_storage)
} // for (frontier_node ∈ subgraph_frontier)
} while (!has_subgraph_converged);
// Finally, mark all the remaining nodes of the subgraph as on the mirror path.
for (const Node* const subgraph_node : subgraph_topo_order) {
if (mirror_map.find(subgraph_node) != mirror_map.end()) {
continue;
}
ObjectPtr subgraph_node_mirror = Node::Create();
*subgraph_node_mirror = *subgraph_node;
subgraph_node_mirror->attrs.name += "_mirror";
for (NodeEntry& e : subgraph_node_mirror->inputs) {
e.node = MapFwdNodeToMirrorPath(e.node, mirror_map);
}
for (ObjectPtr& n : subgraph_node_mirror->control_deps) {
n = MapFwdNodeToMirrorPath(n, mirror_map);
}
mirror_map[subgraph_node] = subgraph_node_mirror;
}
} // for (workitem ∈ worklist)
DFSVisit(ys,
[&](const ObjectPtr& node) {
if (mirror_map.at(node.get()) != nullptr) {
node->attrs.dict["__mirror_stage__"] = "1";
} else {
node->attrs.dict["__mirror_stage__"] = "0";
}
});
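// "__mirror_stage__" == "1" marks a forward node as being on the mirror path
// (i.e., its outputs may be recomputed during the backward pass), while "0"
// marks it as not mirrored.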
return BuildGradientGraph(src, xs, topo_order,
output_grads,
mirror_fun, mirror_map);
}
/*!
* \brief Auxiliary function that checks whether all the gradients are zero or not.
*/
inline bool CheckGradAllZero(const std::vector<NodeEntry>& grads,
const std::vector<const Op*>& zero_ops) {
if (!grads.size() || !zero_ops.size()) return false;
for (const auto& g : grads) {
bool found = false;
for (const auto& op : zero_ops) {
if (g.node->op() == op) {
found = true;
break;
}
}
if (!found) return false;
}
return true;
}
Graph BuildGradientGraph(
const Graph& src,
const std::vector<NodeEntry>& xs,
const std::vector<ObjectPtr>& topo_order,
std::unordered_map<const Node*, std::vector<GradEntry> > output_grads,
std::function<int(const Node&)> mirror_fun,
const std::unordered_map<const Node*, ObjectPtr>& mirror_map) {
static auto& grad_fun_map = Op::GetAttr<nnvm::FGradient>("FGradient");
// gradient aggregation function
using AggFun = std::function<NodeEntry (std::vector<NodeEntry>&&)>;
AggFun agg_fun = [](std::vector<NodeEntry>&& v)->NodeEntry {
if (v.size() == 1) {
return std::move(v[0]);
} else if (v.size() == 0) {
ObjectPtr zero_grad_node = Node::Create();
zero_grad_node->attrs.op = Op::Get("zeros");
zero_grad_node->attrs.name = "zero_grad";
zero_grad_node->attrs.op->attr_parser(&(zero_grad_node->attrs));
return NodeEntry{zero_grad_node, 0, 0};
} else {
ObjectPtr grad_sum_node = Node::Create();
grad_sum_node->attrs.op = Op::Get("elemwise_sum");
grad_sum_node->inputs = std::move(v);
grad_sum_node->attrs.name = "grad_sum";
grad_sum_node->attrs.dict["num_args"] =
std::to_string(grad_sum_node->inputs.size());
grad_sum_node->attrs.op->attr_parser(&(grad_sum_node->attrs));
return NodeEntry{grad_sum_node, 0, 0};
}
};
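// With the default agg_fun above: an empty gradient list produces a "zeros"
// node, a single gradient is returned unchanged, and multiple gradients are
// summed with an "elemwise_sum" node whose "num_args" matches the input count.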
if (src.attrs.count("grad_aggregate_fun") != 0) {
agg_fun = src.GetAttr<AggFun>("grad_aggregate_fun");
}
// zero and copy operators
std::vector<const Op*> zero_ops;
if (src.attrs.count("zero_ops") != 0) {
zero_ops = src.GetAttr<std::vector<const Op*> >("zero_ops");
}
const Op* copy_op = (src.attrs.count("copy_op_str") != 0) ?
Op::Get(src.GetAttr<std::string>("copy_op_str")) : nullptr;
std::vector<NodeEntry> out_agg_grads;
for (auto topo_order_rit = topo_order.rbegin();
topo_order_rit != topo_order.rend(); ++topo_order_rit) {
const ObjectPtr& src_fwd_node = *topo_order_rit;
if (src_fwd_node->is_variable()) continue;
// gather all the output gradient entries and apply the aggregation function
out_agg_grads.clear();
auto& out_grad_vec = output_grads.at(src_fwd_node.get());
for (auto & e : out_grad_vec) {
e.sum = agg_fun(std::move(e.grads));
out_agg_grads.push_back(e.sum);
}
if (src_fwd_node->inputs.size() != 0) {
// If the current node has inputs, the gradients need to be further
// propagated backward.
ObjectPtr fwd_node = MapFwdNodeToMirrorPath(src_fwd_node, mirror_map);
// calculate the input gradients
std::vector<NodeEntry> input_grads;
if (grad_fun_map.count(src_fwd_node->op())) {
input_grads = grad_fun_map[src_fwd_node->op()](fwd_node, out_agg_grads);
CHECK_EQ(src_fwd_node->inputs.size(), input_grads.size())
<< "The Gradient function is not returning enough gradients.";
// If the operator node fails the mirroring function, it is nevertheless
// still possible for its feature maps to be recomputed without incurring
// significant runtime overhead. This is because some operators have
// their feature maps sit on the inputs rather than the outputs.
// E.g., the fully-connected layer (Y = XW^T), whose gradients are given
// by dX = dYW and dW = dY^TX, hence has no data dependency on Y.
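// (In that case, the loop below rewires the inputs of each gradient node to
// the mirrored copies of the forward inputs, so that the backward pass reads
// the recomputed feature maps.)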
if (mirror_fun != nullptr && !mirror_fun(*fwd_node)) {
for (NodeEntry& input_grad : input_grads) {
for (NodeEntry& grad_input : input_grad.node->inputs) {
const ObjectPtr& grad_input_node_mirrored = MapFwdNodeToMirrorPath(
grad_input.node, mirror_map);
grad_input = NodeEntry(
grad_input_node_mirrored,
grad_input.index,
grad_input.version);
} // for (grad_input ∈ input_grad.node->inputs)
} // for (input_grad ∈ input_grads)
} // if (mirror_fun != nullptr && !mirror_fun(*fwd_node))
} else if (CheckGradAllZero(out_agg_grads, zero_ops)) {
for (size_t i = 0; i < src_fwd_node->num_inputs(); ++i) {
std::ostringstream os;
if (1 == src_fwd_node->num_inputs()) {
os << fwd_node->attrs.name << "_backward";
} else {
os << fwd_node->attrs.name << "_in" << i << "_backward";
}
auto p = Node::Create();
p->attrs.op = zero_ops[0];
p->attrs.name = os.str();
p->inputs.push_back(fwd_node->inputs[i]);
p->control_deps.emplace_back(fwd_node);
if (p->op()->attr_parser != nullptr) {
p->op()->attr_parser(&(p->attrs));
}
input_grads.emplace_back(p, 0, 0);
} // for (i ∈ src_fwd_node->num_inputs())
} else {
std::string message = "Operator " + std::string(src_fwd_node->op()->name)
+ " is non-differentiable because it did not register the FGradient attribute.";
throw nnvm::pass::InvalidGraphError(message);
}
for (const auto& e : input_grads) {
CHECK(e.node);
}
auto input_grad_iter = input_grads.begin();
CHECK(src_fwd_node->inputs.size() <= input_grads.size());
for (auto input_iter = src_fwd_node->inputs.begin();
input_iter != src_fwd_node->inputs.end();
++input_iter, ++input_grad_iter) {
// propagate the input gradients to the output gradients of the input nodes
output_grads[input_iter->node.get()][input_iter->index]
.grads.emplace_back(std::move(*input_grad_iter));
}
} // if (src_fwd_node->inputs.size() != 0)
} // for (topo_order_rit ∈ reverse(topo_order))
// take out the xs' grads
Graph ret;
ret.outputs.resize(xs.size());
NodeEntryMap<std::pair<size_t, size_t> > unique_grads;
size_t counter = 0;
for (const NodeEntry& e : xs) {
GradEntry& entry = output_grads[e.node.get()][e.index];
// aggregate the gradients into a sum if this has not been done yet
if (entry.sum.node.get() == nullptr) {
entry.sum = agg_fun(std::move(entry.grads));
}
if (copy_op != nullptr) {
auto kv = unique_grads.find(entry.sum);
if (kv == unique_grads.end()) {
unique_grads.emplace(std::move(entry.sum), std::make_pair(1, counter));
} else {
ObjectPtr copy_node = Node::Create();
std::ostringstream os;
os << entry.sum.node->attrs.name << "_" << kv->second.first << "_copy";
kv->second.first++;
copy_node->attrs.op = copy_op;
copy_node->attrs.name = os.str();
copy_node->inputs.emplace_back(entry.sum);
if (copy_node->attrs.op->attr_parser != nullptr) {
copy_node->attrs.op->attr_parser(&(copy_node->attrs));
}
unique_grads.emplace(NodeEntry{std::move(copy_node), 0, 0},
std::make_pair(1, counter));
}
} else {
ret.outputs[counter] = entry.sum;
}
++counter;
} // for (e ∈ xs)
if (copy_op != nullptr) {
for (const auto& kv : unique_grads) {
ret.outputs[kv.second.second] = kv.first;
}
}
return ret;
}
// register pass
NNVM_REGISTER_PASS(MXGradient)
.describe(R"(Return a gradient graph of src.attrs["ys"] wrt src.attrs["xs"])")
.set_body(Gradient)
.set_change_graph(true)
.depend_graph_attr("grad_ys")
.depend_graph_attr("grad_xs")
.depend_graph_attr("in_arg_shapes")
.depend_graph_attr("in_arg_dtypes")
.depend_graph_attr("grad_ys_out_grad");
} // namespace
} // namespace pass
} // namespace nnvm