/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
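
/*!
 * \file imperative.cc
 * \brief Imperative-mode operator invocation and autograd recording/backward
 *        for MXNet NDArrays (Imperative::Invoke, RecordOp, Backward, etc.).
 */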
#include <unordered_set>
#include <iostream>
#include "./imperative_utils.h"
#include "./cached_op.h"
namespace mxnet {
#if DMLC_CXX11_THREAD_LOCAL
thread_local bool Imperative::is_train_ = false;
thread_local bool Imperative::is_recording_ = false;
#else
MX_THREAD_LOCAL bool Imperative::is_train_ = false;
MX_THREAD_LOCAL bool Imperative::is_recording_ = false;
#endif
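
// Returns the process-wide Imperative singleton used by the imperative
// (NDArray) execution frontend.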
Imperative* Imperative::Get() {
static Imperative inst;
return &inst;
}
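
// Dispatches a single operator call to the execution engine. Based on the
// operator's registered attributes and the chosen dispatch mode this pushes
// either an FComputeEx kernel, an FCompute kernel, or a stateful operator
// (creating the state through FCreateOpState when none is supplied). Returns
// the operator state, which may have been created here.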
OpStatePtr Imperative::InvokeOp(
const Context& ctx,
const nnvm::NodeAttrs& attrs,
const std::vector<NDArray*>& inputs,
const std::vector<NDArray*>& outputs,
const std::vector<OpReqType>& req,
const DispatchMode dispatch_mode,
OpStatePtr state) {
using namespace imperative;
static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
static auto& is_layer_backward = Op::GetAttr<bool>("TIsLayerOpBackward");
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
const nnvm::Op *op = attrs.op;
std::vector<engine::VarHandle> read_vars, write_vars;
std::vector<Resource> requested;
std::vector<uint32_t> mutate_idx;
SetDependency(attrs, ctx, inputs, outputs,
&read_vars, &write_vars, &requested, &mutate_idx, dispatch_mode);
FCompute fn = common::GetFCompute<FCompute>(op, "FCompute", ctx);
FComputeEx fn_ex = common::GetFCompute<FComputeEx>(op, "FComputeEx", ctx);
// FComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
CHECK(dispatch_mode != DispatchMode::kUndefined);
bool dispatch_fcompex = dispatch_mode == DispatchMode::kFComputeEx;
if (fn_ex && dispatch_fcompex) {
PushFComputeEx(fn_ex, op, attrs, ctx, read_vars, write_vars,
requested, inputs, outputs, req);
} else if (fn) {
PushFCompute(fn, op, attrs, ctx, read_vars, write_vars,
requested, inputs, outputs, mutate_idx, req);
} else if (createop.count(op) || is_layer_backward.get(op, false)) {
if (!state) {
state = createop[op](attrs, ctx, ret->arg_shapes, ret->arg_types);
}
write_vars.push_back(state.get_var());
PushOperator(state, op, attrs, ctx, read_vars, write_vars,
requested, inputs, outputs, mutate_idx, req, dispatch_mode);
} else {
LOG(FATAL)
<< "Operator " << op->name << " is not implemented for "
<< (ctx.dev_mask() == gpu::kDevMask ? "GPU." : "CPU.");
}
return state;
}
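
// Top-level imperative invocation. Legacy FNDArrayFunction operators are
// handled directly; otherwise a context is selected, shape/dtype/storage
// type and write-inplace requests are inferred, and the call is delegated to
// InvokeOp. Outputs whose shape is dynamic are waited on afterwards so their
// shape can be read back from the allocated chunk.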
OpStatePtr Imperative::Invoke(
const Context& default_ctx,
const nnvm::NodeAttrs& attrs,
const std::vector<NDArray*>& inputs,
const std::vector<NDArray*>& outputs) {
using namespace imperative;
static auto& ndfunc = nnvm::Op::GetAttr<FNDArrayFunction>("FNDArrayFunction");
if (ndfunc.count(attrs.op)) {
std::vector<NDArray> p_inputs, p_outputs;
DerefInputOutput(inputs, outputs, &p_inputs, &p_outputs);
ndfunc[attrs.op](attrs, p_inputs, &p_outputs);
for (size_t i = 0; i < outputs.size(); ++i) *outputs[i] = std::move(p_outputs[i]);
return OpStatePtr();
}
// TODO(piiswrong): infer ctx
DispatchMode dispatch_mode = DispatchMode::kUndefined;
Context ctx = GetContext(attrs, inputs, outputs, default_ctx);
SetShapeType(ctx, attrs, inputs, outputs, &dispatch_mode);
std::vector<OpReqType> req;
SetWriteInplaceReq(inputs, outputs, &req);
OpStatePtr ret = InvokeOp(ctx, attrs, inputs, outputs, req, dispatch_mode);
// The following loop determines the correct output shapes when some shapes are dynamic (known only after execution).
for (size_t i = 0; i < outputs.size(); i++) {
if (outputs[i]->shape().ndim() == 0) {
// the WaitToRead overhead here does not seem to be avoidable
outputs[i]->WaitToRead();
outputs[i]->SetShapeFromChunk();
}
}
return ret;
}
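
// Marks NDArrays as autograd variables and attaches their gradient buffers.
// Each variable and its gradient receive a fresh variable node whose AGInfo
// records the gradient request, context and a detached copy of the array.
// In the Python frontend this is typically reached through the autograd
// attach_grad path, though that entry point lives outside this file.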
void Imperative::MarkVariables(
const std::vector<NDArray*>& variables,
const std::vector<mx_uint>& grad_reqs,
const std::vector<NDArray*>& gradients) {
for (uint32_t i = 0; i < variables.size(); ++i) {
std::string str_c(std::to_string(variable_count_++));
variables[i]->entry_ = nnvm::NodeEntry{
nnvm::Symbol::CreateVariable("var" + str_c).outputs[0].node, 0, 0};
AGInfo& info = AGInfo::Create(variables[i]->entry_.node);
info.outputs.emplace_back(variables[i]->Detach());
info.out_grads.emplace_back(gradients[i]->Detach());
info.grad_req = static_cast<OpReqType>(grad_reqs[i]);
info.ctx = variables[i]->ctx();
gradients[i]->entry_ = nnvm::NodeEntry{
nnvm::Symbol::CreateVariable("grad" + str_c).outputs[0].node, 0, 0};
AGInfo& grad_info = AGInfo::Create(gradients[i]->entry_.node);
grad_info.outputs.emplace_back(gradients[i]->Detach());
grad_info.ctx = gradients[i]->ctx();
}
}
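
// Computes which inputs and outputs of a node must be saved for its backward
// pass. Placeholder NodeEntry values (a null node with version 0 for inputs,
// version 1 for head gradients) are passed to the op's FGradient function,
// and whichever placeholders the resulting gradient graph references are
// marked as needing to be saved.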
void Imperative::GetBackwardDependency(
const nnvm::NodePtr& node,
uint32_t num_inputs, uint32_t num_outputs,
std::vector<bool> *p_save_inputs,
std::vector<bool> *p_save_outputs) {
static auto& fgradient = nnvm::Op::GetAttr<nnvm::FGradient>("FGradient");
std::vector<bool>& save_inputs = *p_save_inputs;
std::vector<bool>& save_outputs = *p_save_outputs;
save_inputs.resize(num_inputs);
save_outputs.resize(num_outputs);
std::fill(save_inputs.begin(), save_inputs.end(), false);
std::fill(save_outputs.begin(), save_outputs.end(), false);
node->inputs.clear();
node->inputs.reserve(num_inputs);
for (uint32_t i = 0; i < num_inputs; ++i) {
node->inputs.emplace_back(nnvm::NodeEntry{nullptr, i, 0});
}
if (fgradient.count(node->op())) {
std::vector<nnvm::NodeEntry> ograd_entries;
ograd_entries.reserve(num_outputs);
for (uint32_t i = 0; i < num_outputs; ++i) {
ograd_entries.emplace_back(nnvm::NodeEntry{nullptr, i, 1});
}
auto igrad_entries = fgradient[node->op()](node, ograd_entries);
for (const auto& i : igrad_entries) {
if (i.node == nullptr && i.version == 0) {
save_inputs[i.index] = true;
} else if (i.node == node) {
save_outputs[i.index] = true;
}
}
DFSVisit(igrad_entries, [&](const nnvm::NodePtr& gnode) {
if (!gnode || gnode == node) return;
for (const auto& i : gnode->inputs) {
if (i.node == nullptr && i.version == 0) {
save_inputs[i.index] = true;
} else if (i.node == node) {
save_outputs[i.index] = true;
}
}
});
}
}
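
// Records one imperative operator call into the autograd graph. Recording is
// skipped when no input requires gradients. Saved inputs and outputs are
// stored as detached copies; entries the backward pass will never read are
// replaced by dummy arrays that only carry shape, dtype and storage type.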
void Imperative::RecordOp(
nnvm::NodeAttrs&& attrs,
const std::vector<NDArray*>& inputs,
const std::vector<NDArray*>& outputs,
const OpStatePtr& state,
std::vector<bool>* p_save_inputs,
std::vector<bool>* p_save_outputs) {
MXAPIThreadLocalEntry *local_buff = MXAPIThreadLocalStore::Get();
for (auto output : outputs) {
CHECK(AGInfo::IsNone(*output))
<< "Assigning to NDArrays that are already in a computational graph "
<< "will cause undefined behavior when evaluating gradients. "
<< "Please call backward first to clear the graph or do this out side of "
<< "a record section. Also note that you cannot use inplace operations "
<< "like +=, *=, relu(x, out=x), y[idx]=x, etc inside a record section.";
}
bool need_grad = false;
for (const auto& i : inputs) {
if (AGInfo::IsNone(*i)) continue;
need_grad = true;
break;
}
if (!need_grad) return;
nnvm::NodePtr node = nnvm::Node::Create();
node->attrs = std::move(attrs);
node->attrs.name = "node_" + std::to_string(node_count_++);
AGInfo& info = AGInfo::Create(node);
info.state = state;
info.ctx = outputs[0]->ctx();
if (p_save_inputs == nullptr) {
p_save_inputs = &(local_buff->save_inputs);
p_save_outputs = &(local_buff->save_outputs);
GetBackwardDependency(
node, inputs.size(), outputs.size(), p_save_inputs, p_save_outputs);
} else {
node->inputs.resize(inputs.size());
}
std::vector<bool>& save_inputs = *p_save_inputs;
std::vector<bool>& save_outputs = *p_save_outputs;
for (size_t i = 0; i < inputs.size(); ++i) {
if (AGInfo::IsNone(*(inputs[i]))) {
nnvm::NodeEntry entry{nnvm::Symbol::CreateVariable(
"null" + std::to_string(variable_count_++)).outputs[0].node, 0, 0};
AGInfo& input_info = AGInfo::Create(entry.node);
input_info.ctx = inputs[i]->ctx();
if (save_inputs[i]) {
input_info.outputs.emplace_back(*inputs[i]);
} else {
// Put a dummy array here since it will not be used.
input_info.outputs.emplace_back();
input_info.outputs.back().shape_ = inputs[i]->shape();
input_info.outputs.back().dtype_ = inputs[i]->dtype();
input_info.outputs.back().storage_type_ = inputs[i]->storage_type();
}
inputs[i]->entry_ = std::move(entry); // assign last to prevent cyclic reference
} else if (save_inputs[i]) {
AGInfo::Get(inputs[i]->entry_.node).outputs[inputs[i]->entry_.index] = inputs[i]->Detach();
}
node->inputs[i] = inputs[i]->entry_;
}
for (auto output : outputs) {
CHECK(AGInfo::IsNone(*output))
<< "Inplace operations (+=, -=, x[:]=, etc) are not supported when "
<< "recording with autograd.";
}
for (uint32_t i = 0; i < outputs.size(); ++i) {
if (save_outputs[i]) {
info.outputs.emplace_back(outputs[i]->Detach());
} else {
// Put a dummy array here since it will not be used.
info.outputs.emplace_back();
info.outputs.back().shape_ = outputs[i]->shape();
info.outputs.back().dtype_ = outputs[i]->dtype();
info.outputs.back().storage_type_ = outputs[i]->storage_type();
}
outputs[i]->entry_ = nnvm::NodeEntry{node, i, 0};
}
}
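
// Computes gradients of `outputs` with respect to `variables` (or to all
// recorded inputs that require gradients when `variables` is empty). Head
// gradients default to ones when not provided. The gradient graph is built
// with pass::MXGradient, missing arrays are allocated after shape/dtype/
// storage inference, the backward nodes are executed through RunGraph, and
// the recorded history is cleared unless retain_graph is set.
//
// A rough sketch of the frontend flow that reaches this function (Python
// autograd API; the exact frontend names are assumptions as far as this
// file is concerned):
//   with autograd.record():   # is_recording -> RecordOp() on each op
//       y = net(x)
//   y.backward()              # -> Imperative::Backward(...)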
std::vector<NDArray*> Imperative::Backward(
const std::vector<NDArray*>& outputs,
const std::vector<NDArray*>& ograds,
const std::vector<NDArray*>& variables,
bool is_train, bool retain_graph,
bool create_graph) {
using namespace nnvm;
using namespace imperative;
static const std::vector<const Op*> zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")};
static const Op* copy_op = Op::Get("_copy");
// Construct forward graph
Graph graph;
graph.outputs.reserve(outputs.size());
for (const auto& i : outputs) {
CHECK(!AGInfo::IsNone(*i))
<< "Cannot differentiate node because it is not in a computational graph. "
<< "You need to set is_recording to true or use autograd.record() to save "
<< "computational graphs for backward. If you want to differentiate the same "
<< "graph twice, you need to pass retain_graph=True to backward.";
graph.outputs.emplace_back(i->entry_);
}
size_t num_forward_outputs = graph.outputs.size();
// Prepare head gradients
std::vector<NodeEntry> ograd_entries;
ograd_entries.reserve(ograds.size());
for (size_t i = 0; i < outputs.size(); ++i) {
ograd_entries.emplace_back(NodeEntry{Node::Create(), 0, 0});
AGInfo& info = AGInfo::Create(ograd_entries.back().node);
info.ctx = outputs[i]->ctx();
if (ograds[i] != nullptr) {
info.outputs.emplace_back(*ograds[i]);
} else {
info.outputs.emplace_back(outputs[i]->shape(), outputs[i]->ctx(),
true, outputs[i]->dtype());
info.outputs.back() = static_cast<real_t>(1.0);
}
}
// Get gradient graph
Symbol sym;
sym.outputs = graph.outputs;
std::vector<NodeEntry> xs;
std::vector<NDArray*> x_grads;
std::vector<OpReqType> x_reqs;
if (variables.size()) {
xs.reserve(variables.size());
x_grads.reserve(variables.size());
x_reqs.reserve(variables.size());
for (size_t i = 0; i < variables.size(); ++i) {
CHECK(!AGInfo::IsNone(*variables[i]) &&
AGInfo::IsVariable(variables[i]->entry_.node))
<< "Cannot differentiate with respect to the " << i+1 << "-th variable"
<< " because it does not require gradient.";
xs.emplace_back(variables[i]->entry_);
x_grads.push_back(new NDArray());
x_reqs.push_back(kWriteTo);
}
} else {
std::vector<NodePtr> args = sym.ListInputs(Symbol::kReadOnlyArgs);
xs.reserve(args.size());
x_grads.reserve(args.size());
x_reqs.reserve(args.size());
for (const auto& i : args) {
AGInfo& info = AGInfo::Get(i);
if (info.grad_req == kNullOp) continue;
xs.emplace_back(NodeEntry{i, 0, 0});
x_grads.push_back(&info.out_grads[0]);
x_reqs.push_back(info.grad_req);
info.fresh_out_grad = true;
}
CHECK_GT(xs.size(), 0)
<< "There are no inputs in computation graph that require gradients.";
}
Graph g_graph = pass::MXGradient(
graph, graph.outputs, xs, ograd_entries,
exec::AggregateGradient, nullptr, nullptr,
zero_ops, "_copy");
CHECK_EQ(g_graph.outputs.size(), xs.size());
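// Gradient outputs that come directly from a variable node (no op) are
// wrapped in a _copy node, so each backward output is produced by an op node
// rather than a bare variable entry.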
for (const auto& e : g_graph.outputs) {
if (e.node->op() == nullptr) {
auto node = Node::Create();
node->attrs.op = copy_op;
node->inputs.push_back(e);
graph.outputs.push_back(NodeEntry{node, 0, 0});
} else {
graph.outputs.push_back(e);
}
}
const auto& idx = graph.indexed_graph();
// Get the number of nodes and entries used in the forward pass
size_t num_forward_nodes = 0;
size_t num_forward_entries = 0;
for (size_t i = 0; i < num_forward_outputs; ++i) {
num_forward_nodes = std::max(
num_forward_nodes, static_cast<size_t>(idx.outputs()[i].node_id + 1));
num_forward_entries = std::max(
num_forward_entries, static_cast<size_t>(idx.entry_id(idx.outputs()[i])) + 1);
}
// Allocate buffer
std::vector<NDArray> buff(idx.num_node_entries());
std::vector<uint32_t> ref_count(buff.size(), 0);
std::vector<OpStatePtr> states;
std::vector<NDArray*> arrays;
arrays.reserve(buff.size());
for (auto& buffered_array : buff) {
arrays.push_back(&buffered_array);
}
if (create_graph) {
states.resize(num_forward_nodes);
nnvm::DFSVisit(sym.outputs, [&](const nnvm::NodePtr& n) {
AGInfo& info = AGInfo::Get(n);
states[idx.node_id(n.get())] = info.state;
for (uint32_t i = 0; i < info.outputs.size(); ++i) {
CHECK(idx.exist(n.get()));
size_t nid = idx.node_id(n.get());
size_t eid = idx.entry_id(nid, i);
buff[eid] = info.outputs[i];
buff[eid].entry_ = NodeEntry{n, i, 0};
ref_count[eid] = 1;
}
});
for (auto& ograd_entry : ograd_entries) {
AGInfo& info = AGInfo::Get(ograd_entry.node);
if (!idx.exist(ograd_entry.node.get())) continue;
size_t eid = idx.entry_id(ograd_entry);
buff[eid] = info.outputs[0];
buff[eid].entry_ = ograd_entry;
}
} else {
states.reserve(num_forward_nodes);
for (size_t i = 0; i < num_forward_nodes; ++i) {
const AGInfo& info = dmlc::get<AGInfo>(idx[i].source->info);
states.emplace_back(info.state);
for (size_t j = 0; j < info.outputs.size(); ++j) {
size_t eid = idx.entry_id(i, j);
arrays[eid] = const_cast<NDArray*>(&(info.outputs[j]));
if (retain_graph || info.grad_req != kNullOp) ref_count[eid] = 1;
}
}
for (auto& ograd_entry : ograd_entries) {
if (!idx.exist(ograd_entry.node.get())) continue;
AGInfo& info = AGInfo::Get(ograd_entry.node);
arrays[idx.entry_id(ograd_entry)] = &info.outputs[0];
}
}
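// Map the gradient graph's outputs onto the caller-provided (or freshly
// allocated) gradient arrays and mark them as referenced so their write
// requests are not dropped.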
for (size_t i = num_forward_outputs; i < graph.outputs.size(); ++i) {
size_t eid = idx.entry_id(graph.outputs[i]);
arrays[eid] = x_grads[i - num_forward_outputs];
ref_count[eid] = 1;
}
// Assign context
auto vctx = PlaceDevice(idx);
// Infer shape, dtype and storage type for the backward part of the graph
{
std::pair<uint32_t, uint32_t> node_range, entry_range;
node_range = {num_forward_nodes, idx.num_nodes()};
entry_range = {num_forward_entries, idx.num_node_entries()};
ShapeVector shapes;
shapes.reserve(idx.num_node_entries());
bool contain_unknown = false;
for (const auto& i : arrays) shapes.emplace_back(i->shape());
CheckAndInferShape(&graph, std::move(shapes), false,
node_range, entry_range, &contain_unknown);
DTypeVector dtypes;
dtypes.reserve(idx.num_node_entries());
for (const auto& i : arrays) dtypes.emplace_back(i->dtype());
CheckAndInferType(&graph, std::move(dtypes), false,
node_range, entry_range);
StorageTypeVector stypes;
stypes.reserve(idx.num_node_entries());
for (const auto& i : arrays) stypes.emplace_back(i->storage_type());
exec::DevMaskVector dev_mask;
dev_mask.reserve(idx.num_nodes());
for (const auto& i : vctx) dev_mask.emplace_back(i.dev_mask());
CheckAndInferStorageType(&graph, std::move(dev_mask), std::move(stypes), false,
node_range, entry_range);
}
// Calculate ref count
for (size_t i = num_forward_nodes; i < idx.num_nodes(); ++i) {
for (const auto& j : idx[i].inputs) {
++ref_count[idx.entry_id(j)];
}
}
// Assign reqs
std::vector<OpReqType> array_reqs(arrays.size(), kWriteTo);
for (size_t i = num_forward_entries; i < idx.num_node_entries(); ++i) {
if (ref_count[i] == 0) array_reqs[i] = kNullOp;
}
for (size_t i = num_forward_outputs; i < idx.outputs().size(); ++i) {
size_t eid = idx.entry_id(idx.outputs()[i]);
array_reqs[eid] = x_reqs[i - num_forward_outputs];
}
const auto& shapes = graph.GetAttr<mxnet::ShapeVector>("shape");
const auto& dtypes = graph.GetAttr<DTypeVector>("dtype");
const auto& stypes = graph.GetAttr<StorageTypeVector>("storage_type");
const auto& dispatch_modes = graph.GetAttr<DispatchModeVector>("dispatch_mode");
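// Allocate any backward-pass entries that are still empty, using the
// inferred shape, dtype and storage type on the node's assigned context.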
for (size_t i = num_forward_nodes; i < idx.num_nodes(); ++i) {
auto num_outputs = idx[i].source->num_outputs();
for (size_t j = 0; j < num_outputs; ++j) {
auto eid = idx.entry_id(i, j);
if (!arrays[eid]->is_none()) continue;
if (stypes[eid] == kDefaultStorage) {
*arrays[eid] = NDArray(shapes[eid], vctx[i], true, dtypes[eid]);
} else {
*arrays[eid] = NDArray(static_cast<NDArrayStorageType>(stypes[eid]),
shapes[eid], vctx[i], true, dtypes[eid]);
}
}
}
// Execution
bool prev_recording = set_is_recording(create_graph);
bool prev_training = set_is_training(is_train);
int prev_bulk_size = Engine::Get()->set_bulk_size(backward_bulk_size_);
try {
RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
std::move(array_reqs), std::move(ref_count), &states, dispatch_modes,
is_recording());
} catch (const dmlc::Error& e) {
Engine::Get()->set_bulk_size(prev_bulk_size);
set_is_recording(prev_recording);
set_is_training(prev_training);
throw e;
}
Engine::Get()->set_bulk_size(prev_bulk_size);
set_is_recording(prev_recording);
set_is_training(prev_training);
// Clear history
if (!retain_graph) {
nnvm::DFSVisit(sym.outputs, [&](const nnvm::NodePtr& n) {
AGInfo::Clear(n);
n->inputs.clear();
});
}
if (variables.size()) {
return x_grads;
}
return {};
}
} // namespace mxnet