/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
/*!
 * Copyright (c) 2016 by Contributors
 * \file gradients.cc
 * \brief Pass that takes the gradient of a graph.
 * This code was modified based on the mxnet codebase by Min Lin.
 */
#include <nnvm/pass.h>
#include <nnvm/op_attr_types.h>
#include <mxnet/base.h>
#include <algorithm>
#include <functional>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace nnvm {
namespace pass {
namespace {
// Default gradient aggregation function.
// Requires the operators "zeros" and "elemwise_sum" to be registered.
NodeEntry DefaultAggregateGradient(std::vector<NodeEntry>&& v) {
  if (v.size() == 1) {
    return std::move(v[0]);
  } else if (v.size() == 0) {
    // No gradient flows back: emit an explicit zero-gradient node.
    NodePtr zero_node = Node::Create();
    zero_node->attrs.op = Op::Get("zeros");
    zero_node->attrs.name = "zero_grad";
    zero_node->attrs.op->attr_parser(&(zero_node->attrs));
    return NodeEntry{zero_node, 0, 0};
  } else {
    // Multiple gradients: sum them with a single elemwise_sum node.
    NodePtr sum_node = Node::Create();
    sum_node->attrs.op = Op::Get("elemwise_sum");
    sum_node->inputs = std::move(v);
    sum_node->attrs.name = "grad_sum";
    sum_node->attrs.dict["num_args"] = std::to_string(sum_node->inputs.size());
    sum_node->attrs.op->attr_parser(&(sum_node->attrs));
    return NodeEntry{sum_node, 0, 0};
  }
}
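
// Returns true if every given gradient entry is produced by one of the
// registered zero ops, i.e. the aggregated gradient is identically zero and
// the input gradients of the node can simply be zeros as well.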
bool CheckGradAllZero(const std::vector<NodeEntry>& grads,
                      const std::vector<const Op*>& zero_ops) {
  if (!grads.size() || !zero_ops.size()) return false;
  for (const auto& g : grads) {
    bool found = false;
    for (const auto& op : zero_ops) {
      if (g.node->op() == op) {
        found = true;
        break;
      }
    }
    if (!found) return false;
  }
  return true;
}
// Helper entry that accumulates the gradients flowing into one forward output.
struct GradEntry {
#ifdef _MSC_VER
  NodeEntry sum = NodeEntry{nullptr, 0, 0};
#else
  NodeEntry sum{nullptr, 0, 0};
#endif
  std::vector<NodeEntry> grads;
  bool need_attr_hint{true};
};
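
// The gradient pass: given output entries "grad_ys" with head gradients
// "grad_ys_out_grad" and input entries "grad_xs", build a graph whose i-th
// output is the gradient of the ys (weighted by the head gradients) with
// respect to grad_xs[i]. The pass visits the forward graph in reverse
// topological order, aggregates the gradients flowing into each node's
// outputs, and calls the node's FGradient function to propagate them to its
// inputs.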
Graph Gradient(Graph src) {
  using nnvm::FGradient;
  using MirrorFun = std::function<int (const Node& node)>;
  using AttrHintFun = std::function<NodeEntry (const NodeEntry& src, const NodeEntry &like)>;

  CHECK_NE(src.attrs.count("grad_ys"), 0U)
      << "Gradient requires grad_ys to be present.";
  CHECK_NE(src.attrs.count("grad_ys_out_grad"), 0U)
      << "Gradient requires grad_ys_out_grad to be present.";
  CHECK_NE(src.attrs.count("grad_xs"), 0U)
      << "Gradient requires grad_xs to be present.";
  const std::vector<NodeEntry>& ys =
      src.GetAttr<std::vector<NodeEntry> >("grad_ys");
  const std::vector<NodeEntry>& ys_out_grad =
      src.GetAttr<std::vector<NodeEntry> >("grad_ys_out_grad");
  const std::vector<NodeEntry>& xs =
      src.GetAttr<std::vector<NodeEntry> >("grad_xs");
  using AggFun = std::function<NodeEntry (std::vector<NodeEntry>&& inputs)>;
  AggFun agg_fun = DefaultAggregateGradient;
  if (src.attrs.count("grad_aggregate_fun") != 0) {
    agg_fun = src.GetAttr<AggFun>("grad_aggregate_fun");
  }
  MirrorFun mirror_fun = nullptr;
  if (src.attrs.count("grad_mirror_fun") != 0) {
    mirror_fun = src.GetAttr<MirrorFun>("grad_mirror_fun");
  }
  AttrHintFun attr_hint_fun = nullptr;
  if (src.attrs.count("attr_hint_fun") != 0) {
    attr_hint_fun = src.GetAttr<AttrHintFun>("attr_hint_fun");
  }
  std::vector<const Op*> zero_ops;
  if (src.attrs.count("zero_ops") != 0) {
    zero_ops = src.GetAttr<std::vector<const Op*> >("zero_ops");
  }
  const Op* copy_op = (src.attrs.count("copy_op") != 0) ?
      Op::Get(src.GetAttr<std::string>("copy_op")) :
      nullptr;
  // Topologically order the forward graph and allocate a gradient slot
  // for every output of every node reachable from ys.
  std::vector<NodePtr> topo_order;
  std::unordered_map<Node*, std::vector<GradEntry> > output_grads;
  DFSVisit(ys, [&](const NodePtr& node) {
      if (output_grads.count(node.get()) == 0) {
        output_grads[node.get()].resize(node->num_outputs());
      }
      topo_order.push_back(node);
    });

  CHECK_EQ(ys.size(), ys_out_grad.size());
  for (size_t i = 0; i < ys.size(); ++i) {
    NodeEntry ograd = ys_out_grad[i];
    output_grads[ys[i].node.get()][ys[i].index].grads = { ograd };
  }

  // Check that all xs are reachable from ys.
  for (size_t i = 0; i < xs.size(); ++i) {
    CHECK(output_grads.find(xs[i].node.get()) != output_grads.end())
        << "Cannot differentiate with respect to the " << i+1 << "-th variable "
        << "because it is unreachable from the outputs.";
  }
  // Construct the mirror (recomputation) map for memory reduction, if needed.
  std::unordered_map<Node*, NodePtr> mirror_map;
  if (mirror_fun != nullptr) {
    for (const NodePtr& n : topo_order) {
      if (mirror_fun(*n)) {
        NodePtr new_node = Node::Create();
        *new_node = *n;
        new_node->attrs.name += "_mirror";
        for (auto& e : new_node->inputs) {
          e.node = mirror_map.at(e.node.get());
        }
        for (auto& n : new_node->control_deps) {
          n = mirror_map.at(n.get());
        }
        mirror_map[n.get()] = std::move(new_node);
      } else {
        mirror_map[n.get()] = n;
      }
    }
  }
  // traverse backward
  static auto& grad_fun_map = Op::GetAttr<FGradient>("FGradient");
  static auto& finfer_shape = Op::GetAttr<mxnet::FInferShape>("FInferShape");

  std::vector<NodeEntry> out_agg_grads;
  for (auto rit = topo_order.rbegin(); rit != topo_order.rend(); ++rit) {
    const NodePtr& ptr = *rit;
    if (ptr->is_variable()) continue;
    out_agg_grads.clear();
    auto& out_grad_vec = output_grads.at(ptr.get());
    for (uint32_t i = 0; i < out_grad_vec.size(); ++i) {
      GradEntry& e = out_grad_vec[i];
      e.sum = agg_fun(std::move(e.grads));
      if (e.need_attr_hint && attr_hint_fun != nullptr) {
        e.sum = attr_hint_fun(e.sum, NodeEntry{ptr, 0, i});
      }
      out_agg_grads.push_back(e.sum);
    }
    if ((*rit)->inputs.size() != 0) {
      NodePtr fwd_node = (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()));
      std::vector<NodeEntry> input_grads;
      if (grad_fun_map.count(ptr->op())) {
        input_grads = grad_fun_map[ptr->op()](fwd_node, out_agg_grads);
        CHECK_EQ((*rit)->inputs.size(), input_grads.size())
            << "Gradient function did not return enough gradients";
      } else if (CheckGradAllZero(out_agg_grads, zero_ops)) {
        // All aggregated output gradients are zero ops, so every input gradient
        // is zero as well: emit explicit zero nodes instead of calling FGradient.
        for (size_t i = 0; i < fwd_node->num_inputs(); ++i) {
          std::ostringstream os;
          if (1 == fwd_node->num_inputs()) {
            os << fwd_node->attrs.name << "_backward";
          } else {
            os << fwd_node->attrs.name << "_in" << i << "_backward";
          }
          auto p = Node::Create();
          p->attrs.op = zero_ops[0];
          p->attrs.name = os.str();
          p->inputs.push_back(fwd_node->inputs[i]);
          p->control_deps.emplace_back(fwd_node);
          if (p->op()->attr_parser != nullptr) {
            p->op()->attr_parser(&(p->attrs));
          }
          input_grads.emplace_back(nnvm::NodeEntry{p, 0, 0});
        }
      } else {
        LOG(FATAL) << "Operator " << fwd_node->op()->name << " is non-differentiable "
                   << "because it didn't register FGradient attribute.";
      }
      auto git = input_grads.begin();
      for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) {
        auto& ge = output_grads[it->node.get()][it->index];
        // If the backward op producing this gradient can infer shapes,
        // the attr hint is not necessary.
        if (finfer_shape.count(git->node->op())) {
          ge.need_attr_hint = false;
        }
        ge.grads.emplace_back(std::move(*git));
      }
    }
  }
  // Take out the gradients of the xs.
  Graph ret;
  ret.outputs.resize(xs.size());
  NodeEntryMap<std::pair<size_t, size_t> > unique_grads;
  size_t counter = 0;
  for (const NodeEntry& e : xs) {
    GradEntry& entry = output_grads[e.node.get()][e.index];
    // Aggregate the sum if it has not been done yet.
    if (entry.sum.node.get() == nullptr) {
      entry.sum = agg_fun(std::move(entry.grads));
      if (entry.need_attr_hint && attr_hint_fun != nullptr) {
        entry.sum = attr_hint_fun(entry.sum, e);
      }
    }
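    // If a copy_op is provided, gradient entries requested by more than one x
    // are de-duplicated: the first request reuses the entry directly and every
    // later request goes through an inserted copy node, so each graph output
    // stays a distinct NodeEntry.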
    if (copy_op != nullptr) {
      auto kv = unique_grads.find(entry.sum);
      if (kv == unique_grads.end()) {
        unique_grads.emplace(std::move(entry.sum), std::make_pair(1, counter));
      } else {
        NodePtr copy_node = Node::Create();
        std::ostringstream os;
        os << entry.sum.node->attrs.name << "_" << kv->second.first << "_copy";
        kv->second.first++;
        copy_node->attrs.op = copy_op;
        copy_node->attrs.name = os.str();
        copy_node->inputs.emplace_back(entry.sum);
        if (copy_node->attrs.op->attr_parser != nullptr) {
          copy_node->attrs.op->attr_parser(&(copy_node->attrs));
        }
        unique_grads.emplace(NodeEntry{std::move(copy_node), 0, 0}, std::make_pair(1, counter));
      }
    } else {
      ret.outputs[counter] = entry.sum;
    }
    ++counter;
  }
  if (copy_op != nullptr) {
    for (const auto& kv : unique_grads) {
      ret.outputs[kv.second.second] = kv.first;
    }
  }
  return ret;
}
// register pass
NNVM_REGISTER_PASS(MXGradient)
.describe("Return a gradient graph of src.attrs[\"ys\"] wrt src.attrs[\"xs\"]")
.set_body(Gradient)
.set_change_graph(true)
.depend_graph_attr("grad_ys")
.depend_graph_attr("grad_xs")
.depend_graph_attr("grad_ys_out_grad");
} // namespace
} // namespace pass
} // namespace nnvm