| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| #include <algorithm> |
| #include <iostream> |
| #include <unordered_map> |
| #include <unordered_set> |
| |
| #include "./imperative_utils.h" |
| #include "./cached_op.h" |
| |
| namespace nnvm { |
| ObjectPtr CreateVariableNode(const std::string& name); |
| } |
| |
| namespace mxnet { |
| #if DMLC_CXX11_THREAD_LOCAL |
| thread_local bool Imperative::is_train_ = false; |
| thread_local bool Imperative::is_recording_ = false; |
| thread_local bool Imperative::is_deferred_compute_ = false; |
| thread_local OptConstraint Imperative::opt_constraints_ = OptConstraint::None; |
| thread_local bool Imperative::is_np_shape_thread_local_ = false; |
| #else |
| MX_THREAD_LOCAL bool Imperative::is_train_ = false; |
| MX_THREAD_LOCAL bool Imperative::is_recording_ = false; |
| MX_THREAD_LOCAL bool Imperative::is_deferred_compute_ = false; |
| MX_THREAD_LOCAL OptConstraint Imperative::opt_constraints_ = OptConstraint::None; |
| MX_THREAD_LOCAL bool Imperative::is_np_shape_thread_local_ = false; |
| #endif |
| |
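// Returns the process-wide singleton. Note that the mode flags defined above
// (is_train_, is_recording_, ...) are thread-local, so each thread observes
// its own recording/training state through this shared instance.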
| Imperative* Imperative::Get() { |
| static Imperative inst; |
| return &inst; |
| } |
| |
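// Pushes a single operator to the engine for execution. Dispatch order:
// FComputeEx when dispatch_mode selects it, then FCompute, then stateful
// operators (FCreateOpState / TIsLayerOpBackward); otherwise the operator is
// not implemented for this device and we abort.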
| OpStatePtr Imperative::InvokeOp(const Context& ctx, |
| const nnvm::NodeAttrs& attrs, |
| const std::vector<NDArray*>& inputs, |
| const std::vector<NDArray*>& outputs, |
| const std::vector<OpReqType>& req, |
| const DispatchMode dispatch_mode, |
| OpStatePtr state) { |
| using namespace imperative; |
| static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState"); |
| static auto& is_layer_backward = Op::GetAttr<bool>("TIsLayerOpBackward"); |
| MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get(); |
| |
| const nnvm::Op* op = attrs.op; |
| |
| std::vector<engine::VarHandle> read_vars, write_vars; |
| std::vector<Resource> requested; |
| std::vector<uint32_t> mutate_idx; |
| SetDependency( |
| attrs, ctx, inputs, outputs, &read_vars, &write_vars, &requested, &mutate_idx, dispatch_mode); |
| |
| FCompute fn = common::GetFCompute<FCompute>(op, "FCompute", ctx); |
| FComputeEx fn_ex = common::GetFCompute<FComputeEx>(op, "FComputeEx", ctx); |
| |
| // FComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx |
| CHECK(dispatch_mode != DispatchMode::kUndefined); |
| bool dispatch_fcompex = dispatch_mode == DispatchMode::kFComputeEx; |
| if (fn_ex && dispatch_fcompex) { |
| PushFComputeEx(fn_ex, op, attrs, ctx, read_vars, write_vars, requested, inputs, outputs, req); |
| } else if (fn) { |
| PushFCompute( |
| fn, op, attrs, ctx, read_vars, write_vars, requested, inputs, outputs, mutate_idx, req); |
| } else if (createop.count(op) || is_layer_backward.get(op, false)) { |
| if (!state) { |
| state = createop[op](attrs, ctx, ret->arg_shapes, ret->arg_types); |
| } |
| write_vars.push_back(state.get_var()); |
| PushOperator(state, |
| op, |
| attrs, |
| ctx, |
| read_vars, |
| write_vars, |
| requested, |
| inputs, |
| outputs, |
| mutate_idx, |
| req, |
| dispatch_mode); |
| } else { |
| LOG(FATAL) << "Operator " << op->name << " is not implemented for " |
| << (ctx.dev_mask() == gpu::kDevMask ? "GPU." : "CPU."); |
| } |
| |
| return state; |
| } |
| |
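// Top-level imperative invocation. Legacy FNDArrayFunction operators are
// called directly; everything else goes through context/shape/type inference
// and InvokeOp. Outputs with dynamic shapes are waited on so the caller sees
// a known shape.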
| OpStatePtr Imperative::Invoke(const Context& default_ctx, |
| const nnvm::NodeAttrs& attrs, |
| const std::vector<NDArray*>& inputs, |
| const std::vector<NDArray*>& outputs) { |
| using namespace imperative; |
| static auto& ndfunc = nnvm::Op::GetAttr<FNDArrayFunction>("FNDArrayFunction"); |
| |
| if (ndfunc.count(attrs.op)) { |
| std::vector<NDArray> p_inputs, p_outputs; |
| DerefInputOutput(inputs, outputs, &p_inputs, &p_outputs); |
| ndfunc[attrs.op](attrs, p_inputs, &p_outputs); |
| for (size_t i = 0; i < outputs.size(); ++i) |
| *outputs[i] = std::move(p_outputs[i]); |
| return OpStatePtr(); |
| } |
| |
| // TODO(piiswrong): infer ctx |
| DispatchMode dispatch_mode = DispatchMode::kUndefined; |
| Context ctx = GetContext(attrs, inputs, outputs, default_ctx); |
| SetShapeType(ctx, attrs, inputs, outputs, &dispatch_mode); |
| std::vector<OpReqType> req; |
| SetWriteInplaceReq(inputs, outputs, &req); |
| OpStatePtr ret = InvokeOp(ctx, attrs, inputs, outputs, req, dispatch_mode); |
  // The following loop infers the correct output shapes when some shapes are dynamic.
| for (auto output : outputs) { |
| if (!shape_is_known(output->shape())) { |
| // the WaitToRead overhead here does not seem to be avoidable |
| output->WaitToRead(); |
| output->SetShapeFromChunk(); |
| } |
| } |
| return ret; |
| } |
| |
// Create an nnvm::NodeEntry for the autograd_entry_ attribute of each variable
// and gradient, and associate an AGInfo with the node's info attribute.
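// For example, the Python call x.attach_grad() reaches this function through
// the autograd C API, roughly as MarkVariables({&x}, {kWriteTo}, {&x_grad}).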
| void Imperative::MarkVariables(const std::vector<NDArray*>& variables, |
| const std::vector<uint32_t>& grad_reqs, |
| const std::vector<NDArray*>& gradients) { |
| for (uint32_t i = 0; i < variables.size(); ++i) { |
| // Unmarked leaf nodes have null autograd_entry_, while marked nonleaf nodes don't. |
| if (!variables[i]->autograd_entry_.node || variables[i]->autograd_entry_.node->is_variable()) { |
| std::string str_c(std::to_string(variable_count_++)); |
| variables[i]->autograd_entry_ = |
| nnvm::NodeEntry{nnvm::Symbol::CreateVariable("var" + str_c).outputs[0].node, 0, 0}; |
| AGInfo& info = AGInfo::Create(variables[i]->autograd_entry_.node); |
| info.outputs.emplace_back(variables[i]->Detach()); |
| info.out_grads.emplace_back(gradients[i]->Detach()); |
| info.grad_req = static_cast<OpReqType>(grad_reqs[i]); |
| info.ctx = variables[i]->ctx(); |
| |
| gradients[i]->autograd_entry_ = |
| nnvm::NodeEntry{nnvm::Symbol::CreateVariable("grad" + str_c).outputs[0].node, 0, 0}; |
| AGInfo& grad_info = AGInfo::Create(gradients[i]->autograd_entry_.node); |
| grad_info.outputs.emplace_back(gradients[i]->Detach()); |
| grad_info.ctx = gradients[i]->ctx(); |
| } else { |
| AGInfo& info = AGInfo::Get(variables[i]->autograd_entry_.node); |
| CHECK_EQ(info.out_grads.size(), 0) |
| << "The node has already been marked. Cannot mark it again."; |
| info.out_grads.emplace_back(gradients[i]->Detach()); |
| info.grad_req = static_cast<OpReqType>(grad_reqs[i]); |
| info.ctx = variables[i]->ctx(); |
| } |
| } |
| } |
| |
// Unmark the variables and release the memory held by their gradients.
| void Imperative::DropGrads(const std::vector<NDArray*>& variables) { |
| for (auto variable : variables) { |
| if (variable->autograd_entry_.node) { |
| AGInfo& info = AGInfo::Get(variable->autograd_entry_.node); |
      CHECK_NE(info.out_grads.size(), 0)
          << "The node's out_grads are already empty. Cannot call DropGrads again.";
      // ReInit through a reference; iterating by value would reset a copy and
      // leave the stored gradient arrays untouched.
      for (auto& grad : info.out_grads) {
        grad.ReInit();
      }
| info.out_grads.clear(); |
| info.grad_req = kNullOp; |
| } |
| } |
| } |
| |
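// Determines which of a node's forward inputs and outputs must be saved for
// the backward pass: the node's gradient subgraph is built via FGradient with
// placeholder entries, and the placeholders it references are inspected.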
| void Imperative::GetBackwardDependency(const nnvm::ObjectPtr& node, |
| uint32_t num_inputs, |
| uint32_t num_outputs, |
| std::vector<bool>* p_save_inputs, |
| std::vector<bool>* p_save_outputs) { |
| static auto& fgradient = nnvm::Op::GetAttr<nnvm::FGradient>("FGradient"); |
| std::vector<bool>& save_inputs = *p_save_inputs; |
| std::vector<bool>& save_outputs = *p_save_outputs; |
| save_inputs.resize(num_inputs); |
| save_outputs.resize(num_outputs); |
| std::fill(save_inputs.begin(), save_inputs.end(), false); |
| std::fill(save_outputs.begin(), save_outputs.end(), false); |
| |
| node->inputs.clear(); |
| node->inputs.reserve(num_inputs); |
| for (uint32_t i = 0; i < num_inputs; ++i) { |
| node->inputs.emplace_back(nnvm::NodeEntry{nullptr, i, 0}); |
| } |
| |
| if (fgradient.count(node->op())) { |
| std::vector<nnvm::NodeEntry> ograd_entries; |
| ograd_entries.reserve(num_outputs); |
| for (uint32_t i = 0; i < num_outputs; ++i) { |
| ograd_entries.emplace_back(nullptr, i, 1); |
| } |
| auto igrad_entries = fgradient[node->op()](node, ograd_entries); |
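    // Placeholder convention: an entry with a null node and version 0 is a
    // forward input, a null node with version 1 is an output gradient, and an
    // entry whose node is `node` itself refers to a forward output.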
| for (const auto& i : igrad_entries) { |
| if (i.node == nullptr && i.version == 0) { |
| save_inputs[i.index] = true; |
| } else if (i.node == node) { |
| save_outputs[i.index] = true; |
| } |
| } |
| DFSVisit(igrad_entries, [&](const nnvm::ObjectPtr& gnode) { |
| if (!gnode || gnode == node) |
| return; |
| for (const auto& i : gnode->inputs) { |
| if (i.node == nullptr && i.version == 0) { |
| save_inputs[i.index] = true; |
| } else if (i.node == node) { |
| save_outputs[i.index] = true; |
| } |
| } |
| }); |
| } |
| } |
| |
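// Records one executed operator into the autograd graph: creates a graph node
// for the op, attaches an AGInfo holding the state and saved arrays, and links
// the inputs' and outputs' autograd_entry_ so Backward can replay the graph.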
| void Imperative::RecordOp(nnvm::NodeAttrs&& attrs, |
| const std::vector<NDArray*>& inputs, |
| const std::vector<NDArray*>& outputs, |
| const OpStatePtr& state, |
| std::vector<bool>* p_save_inputs, |
| std::vector<bool>* p_save_outputs) { |
| MXAPIThreadLocalEntry<>* local_buff = MXAPIThreadLocalStore<>::Get(); |
| |
| CHECK(!is_deferred_compute()) |
| << "Autograd recording is not supported during deferred compute mode."; |
| |
| for (auto output : outputs) { |
| CHECK(AGInfo::IsNone(*output)) |
| << "Assigning to NDArrays that are already in a computational graph " |
| << "will cause undefined behavior when evaluating gradients. " |
| << "Please call backward first to clear the graph or do this out side of " |
| << "a record section. Also note that you cannot use inplace operations " |
| << "like +=, *=, relu(x, out=x), y[idx]=x, etc inside a record section. " |
| << "Issue occurred while recording op: " << attrs.name; |
| } |
| |
| bool need_grad = false; |
| for (const auto& i : inputs) { |
| if (AGInfo::IsNone(*i)) |
| continue; |
| need_grad = true; |
| break; |
| } |
| if (!need_grad) |
| return; |
| |
| nnvm::ObjectPtr node = nnvm::Node::Create(); |
| node->attrs = std::move(attrs); |
  // If the node name is empty or identical to the op name, assign a unique name.
  if (node->attrs.name.empty() || node->attrs.op->name == node->attrs.name) {
| node->attrs.name = "node_" + std::to_string(node_count_++); |
| } else { |
| node_count_++; |
| } |
| AGInfo& info = AGInfo::Create(node); |
| info.state = state; |
| info.ctx = outputs[0]->ctx(); |
| |
| if (p_save_inputs == nullptr) { |
| p_save_inputs = &(local_buff->save_inputs); |
| p_save_outputs = &(local_buff->save_outputs); |
| GetBackwardDependency(node, inputs.size(), outputs.size(), p_save_inputs, p_save_outputs); |
| } else { |
| node->inputs.resize(inputs.size()); |
| } |
| |
| std::vector<bool>& save_inputs = *p_save_inputs; |
| std::vector<bool>& save_outputs = *p_save_outputs; |
| |
| for (size_t i = 0; i < inputs.size(); ++i) { |
| if (AGInfo::IsNone(*(inputs[i]))) { |
| nnvm::NodeEntry entry{ |
| nnvm::Symbol::CreateVariable("null" + std::to_string(variable_count_++)).outputs[0].node, |
| 0, |
| 0}; |
| AGInfo& input_info = AGInfo::Create(entry.node); |
| input_info.ctx = inputs[i]->ctx(); |
| if (save_inputs[i]) { |
| input_info.outputs.emplace_back(*inputs[i]); |
| } else { |
| // Put a dummy array here since it will not be used. |
| input_info.outputs.emplace_back(); |
| input_info.outputs.back().shape_ = inputs[i]->shape(); |
| input_info.outputs.back().dtype_ = inputs[i]->dtype(); |
| input_info.outputs.back().storage_type_ = inputs[i]->storage_type(); |
| } |
| inputs[i]->autograd_entry_ = std::move(entry); // assign last to prevent cyclic reference |
| } else if (save_inputs[i]) { |
| nnvm::NodeEntry& entry = inputs[i]->autograd_entry_; |
| AGInfo::Get(entry.node).outputs[entry.index] = inputs[i]->Detach(); |
| } |
| node->inputs[i] = inputs[i]->autograd_entry_; |
| } |
| |
| for (auto output : outputs) { |
| CHECK(AGInfo::IsNone(*output)) |
| << "NotImplementedError: Inplace operations (+=, -=, x[:]=, etc) " |
| << "are not supported when recording with autograd."; |
| } |
| |
| for (uint32_t i = 0; i < outputs.size(); ++i) { |
| if (save_outputs[i]) { |
| info.outputs.emplace_back(outputs[i]->Detach()); |
| } else { |
| // Put a dummy array here since it will not be used. |
| info.outputs.emplace_back(); |
| info.outputs.back().shape_ = outputs[i]->shape(); |
| info.outputs.back().dtype_ = outputs[i]->dtype(); |
| info.outputs.back().storage_type_ = outputs[i]->storage_type(); |
| } |
| outputs[i]->autograd_entry_ = nnvm::NodeEntry{node, i, 0}; |
| } |
| } |
| |
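// Records one operator into the deferred compute graph without executing it.
// Shapes, dtypes and storage types are still inferred eagerly so that
// subsequent recorded operators can depend on them.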
| void Imperative::RecordDeferredCompute(nnvm::NodeAttrs&& attrs, |
| const std::vector<NDArray*>& inputs, |
| const std::vector<NDArray*>& outputs) { |
| CHECK(!is_recording()) |
| << "MXNetError: Autograd recording is not supported during deferred compute mode."; |
| |
| for (const NDArray* input : inputs) { |
| CHECK(!DCInfo::IsNone(*input)) |
| << "ValueError: All inputs to deferred compute recording must be associated " |
| << "with a symbolic variable or be the output of a deferred compute operator."; |
| } |
| for (const NDArray* output : outputs) { |
| CHECK(DCInfo::IsNone(*output)) |
| << "NotImplementedError: Inplace operations (+=, -=, x[:]=, etc) " |
| << "are not supported when recording in deferred compute mode."; |
| } |
| DispatchMode dispatch_mode = DispatchMode::kUndefined; |
| Context ctx = imperative::GetContext(attrs, inputs, outputs, Context::CPU()); |
| imperative::SetShapeType(ctx, attrs, inputs, outputs, &dispatch_mode); |
| |
| nnvm::ObjectPtr node = nnvm::Node::Create(); |
| node->inputs.reserve(inputs.size()); |
| // Get NodeEntries for inputs |
| for (const NDArray* array : inputs) { |
| CHECK(array->deferredcompute_entry_.node); // Must not be nullptr |
| node->inputs.emplace_back(array->deferredcompute_entry_); |
| } |
| node->attrs = std::move(attrs); |
  // TODO: support a NameManager in the imperative API to generate better names
  // for node->attrs.name.
  // If the node name is empty or identical to the op name, assign a unique name.
  if (node->attrs.name.empty() || node->attrs.op->name == node->attrs.name) {
| node->attrs.name = "node_" + std::to_string(node_count_++); |
| } else { |
| node_count_++; |
| } |
| |
| if (get_opt_constraints() != OptConstraint::None) { |
| node->attrs.dict[OPT_CONSTRAINT_ATTR] = |
| std::to_string(static_cast<OptConstraint_int_t>(get_opt_constraints())); |
| } |
| |
| for (uint32_t i = 0; i < outputs.size(); ++i) { |
| outputs[i]->deferredcompute_entry_ = nnvm::NodeEntry{node, i, 0}; |
| } |
| |
| DCInfo::Create(node, inputs, outputs); |
| } |
| |
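// Builds a Symbol whose outputs are the deferred compute entries of the given
// arrays. A deep copy (s.Copy()) is returned so the caller's symbol is
// decoupled from the recorded graph.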
| nnvm::Symbol Imperative::GetDeferredComputeSymbol(const std::vector<NDArray*>& outputs) { |
| nnvm::Symbol s; |
| s.outputs.reserve(outputs.size()); |
| for (NDArray* ndoutput : outputs) { |
| CHECK(!Imperative::DCInfo::IsNone(*ndoutput)) |
| << "ValueError: output_arrays for GetDeferredComputeSymbol " |
| << "must have a deferred compute history associated with them."; |
| s.outputs.emplace_back(ndoutput->deferredcompute_entry_); |
| } |
| return s.Copy(); |
| } |
| |
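// Associates arrays with symbolic variables for deferred compute recording.
// Exposed through the C API as MXNDArraySetDeferredComputeVariable (see the
// error messages below); arrays and variables are parallel, num-long lists.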
| void Imperative::SetDeferredComputeVariable(NDArrayHandle* arrays, |
| SymbolHandle* variables, |
| const int num) { |
| // Sanity check all inputs |
| for (int i = 0; i < num; i++) { |
| nnvm::Symbol* s = reinterpret_cast<nnvm::Symbol*>(variables[i]); |
| NDArray* nd = reinterpret_cast<NDArray*>(arrays[i]); |
| CHECK_EQ(s->outputs.size(), 1) |
| << "MXNDArraySetDeferredComputeVariable expects variables as input. " |
| << "Instead got a Symbol with " << s->outputs.size() << " outputs as input " << i; |
| CHECK(s->outputs[0].node->is_variable()) |
| << "MXNDArraySetDeferredComputeVariable expects variables as input. " |
| << "Instead got a Symbol associated with an operator as input " << i; |
| CHECK(DCInfo::IsNone(*nd) || nd->deferredcompute_entry_.node == s->outputs[0].node) |
| << "ValueError: array " << i << " is already associated with a different variable. " |
| << "You can call array.detach() to obtain a copy without the variable"; |
| } |
| |
| // Store variables in DCInfo of arrays |
| for (int i = 0; i < num; i++) { |
| nnvm::Symbol* s = reinterpret_cast<nnvm::Symbol*>(variables[i]); |
| NDArray* nd = reinterpret_cast<NDArray*>(arrays[i]); |
| nd->deferredcompute_entry_ = nnvm::NodeEntry{s->outputs[0].node, 0, 0}; |
| |
| std::vector<NDArray*> inputs; |
| std::vector<NDArray*> outputs; // No need to specify outputs, as we will set is_computed_ |
| Imperative::DCInfo& info = Imperative::DCInfo::Create(s->outputs[0].node, inputs, outputs); |
| info.is_computed_ = true; |
| } |
| } |
| |
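// Detaches the given arrays from the deferred compute graph, releasing the
// arrays cached in every reachable node's DCInfo.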
| void Imperative::DeferredComputeClear(NDArrayHandle* arrays, const int num) { |
| std::vector<nnvm::NodeEntry> outputs; |
| outputs.reserve(num); |
| for (int i = 0; i < num; i++) { |
| NDArray* nd = reinterpret_cast<NDArray*>(arrays[i]); |
| outputs.emplace_back(nd->deferredcompute_entry_); |
| } |
| nnvm::DFSVisit(outputs, [&](const nnvm::ObjectPtr& n) { |
| if (n != nullptr && !n->info.empty()) { |
      // Hold a reference: clearing the containers of a by-value copy would
      // leave the arrays cached in the node's DCInfo alive.
      Imperative::DCInfo& info = Imperative::DCInfo::Get(n);
| info.inputs_.clear(); |
| info.input_handles_.clear(); |
| info.outputs_.clear(); |
| info.Clear(n); |
| } |
| }); |
| } |
| |
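// Computes gradients of `outputs` with respect to `variables`, or, when
// `variables` is empty, with respect to all recorded inputs that require
// gradients (writing into the arrays attached by MarkVariables). A minimal
// usage sketch, mirroring what the autograd C API does (illustrative only):
//
//   auto* imp = Imperative::Get();
//   bool prev = imp->set_is_recording(true);
//   /* ... run forward operators; they are captured via RecordOp ... */
//   imp->set_is_recording(prev);
//   imp->Backward({&loss}, {nullptr} /* head grad defaults to ones */,
//                 {} /* w.r.t. all recorded inputs */, /* is_train */ false,
//                 /* retain_graph */ false, /* create_graph */ false);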
| std::vector<NDArray*> Imperative::Backward(const std::vector<NDArray*>& outputs, |
| const std::vector<NDArray*>& ograds, |
| const std::vector<NDArray*>& variables, |
| bool is_train, |
| bool retain_graph, |
| bool create_graph) { |
| using namespace nnvm; |
| using namespace imperative; |
| static const std::vector<const Op*> zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")}; |
| static const Op* copy_op = Op::Get("_copy"); |
| |
| // Construct forward graph |
| Graph graph; |
| graph.outputs.reserve(outputs.size()); |
| for (const auto& i : outputs) { |
| CHECK(!AGInfo::IsNone(*i)) |
| << "Cannot differentiate node because it is not in a computational graph. " |
| << "You need to set is_recording to true or use autograd.record() to save " |
| << "computational graphs for backward. If you want to differentiate the same " |
| << "graph twice, you need to pass retain_graph=True to backward."; |
| graph.outputs.emplace_back(i->autograd_entry_); |
| } |
| size_t num_forward_outputs = graph.outputs.size(); |
| |
| // Prepare head gradients |
| std::vector<NodeEntry> ograd_entries; |
| ograd_entries.reserve(ograds.size()); |
| for (size_t i = 0; i < outputs.size(); ++i) { |
| nnvm::ObjectPtr np = Node::Create(); |
| np->attrs.name = "_head_grad_" + std::to_string(i); |
| ograd_entries.emplace_back(NodeEntry{np, 0, 0}); |
| AGInfo& info = AGInfo::Create(ograd_entries.back().node); |
| info.ctx = outputs[i]->ctx(); |
| if (ograds[i] != nullptr) { |
| info.outputs.emplace_back(*ograds[i]); |
| } else { |
| info.outputs.emplace_back(outputs[i]->shape(), outputs[i]->ctx(), true, outputs[i]->dtype()); |
| if (info.outputs.back().shape().Size() != 0) { |
| info.outputs.back() = static_cast<real_t>(1.0); |
| } |
| } |
| } |
| |
| // Get gradient graph |
| Symbol sym; |
| sym.outputs = graph.outputs; |
| std::vector<NodeEntry> xs; |
| std::vector<NDArray*> x_grads; |
| std::vector<OpReqType> x_reqs; |
| if (variables.size()) { |
| xs.reserve(variables.size()); |
| x_grads.reserve(variables.size()); |
| x_reqs.reserve(variables.size()); |
| for (size_t i = 0; i < variables.size(); ++i) { |
| CHECK(!AGInfo::IsNone(*variables[i]) && |
| AGInfo::IsVariable(variables[i]->autograd_entry_.node)) |
| << "Cannot differentiate with respect to the " << i + 1 << "-th variable" |
| << " because it does not require gradient."; |
| xs.emplace_back(variables[i]->autograd_entry_); |
| x_grads.push_back(new NDArray()); |
| x_reqs.push_back(kWriteTo); |
| } |
| } else { |
| std::vector<ObjectPtr> args = sym.ListInputs(Symbol::kReadOnlyArgs); |
| xs.reserve(args.size()); |
| x_grads.reserve(args.size()); |
| x_reqs.reserve(args.size()); |
| for (const auto& i : args) { |
| AGInfo& info = AGInfo::Get(i); |
| if (info.grad_req == kNullOp) |
| continue; |
| xs.emplace_back(NodeEntry{i, 0, 0}); |
| x_grads.push_back(&info.out_grads[0]); |
| x_reqs.push_back(info.grad_req); |
| info.fresh_out_grad = true; |
| } |
| CHECK_GT(xs.size(), 0) << "There are no inputs in computation graph that require gradients."; |
| } |
| std::vector<ObjectPtr> nleaf_vars = ListNonleafVariables(sym); |
| std::vector<NodeEntry> us; |
| us.reserve(nleaf_vars.size()); |
| for (const auto& i : nleaf_vars) { |
| us.emplace_back(NodeEntry{i, 0, 0}); |
| } |
| |
| Graph g_graph = pass::MXGradient(graph, |
| graph.outputs, |
| xs, |
| ograd_entries, |
| mxnet::AggregateGradient, |
| nullptr, |
| zero_ops, |
| "_copy", |
| ShapeVector(), |
| DTypeVector(), |
| us); |
| CHECK_EQ(g_graph.outputs.size(), xs.size()); |
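  // Wrap gradients that are bare variables (op() == nullptr) in an explicit
  // _copy node so that every graph output is produced by an operator.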
| for (const auto& e : g_graph.outputs) { |
| if (e.node->op() == nullptr) { |
| auto node = Node::Create(); |
| node->attrs.op = copy_op; |
| node->inputs.push_back(e); |
| graph.outputs.emplace_back(std::move(node)); |
| } else { |
| graph.outputs.push_back(e); |
| } |
| } |
| const auto& idx = graph.indexed_graph(); |
| // get number of nodes used in forward pass |
| size_t num_forward_nodes = 0; |
| size_t num_forward_entries = 0; |
| for (size_t i = 0; i < num_forward_outputs; ++i) { |
| num_forward_nodes = |
| std::max(num_forward_nodes, static_cast<size_t>(idx.outputs()[i].node_id + 1)); |
| num_forward_entries = |
| std::max(num_forward_entries, static_cast<size_t>(idx.entry_id(idx.outputs()[i])) + 1); |
| } |
| |
| // Allocate buffer |
| std::vector<NDArray> buff(idx.num_node_entries()); |
| std::vector<uint32_t> ref_count(buff.size(), 0); |
| std::vector<OpStatePtr> states; |
| std::vector<NDArray*> arrays; |
| arrays.reserve(buff.size()); |
| for (auto& buffered_array : buff) { |
| arrays.push_back(&buffered_array); |
| } |
| if (create_graph) { |
| states.resize(num_forward_nodes); |
| nnvm::DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& n) { |
| AGInfo& info = AGInfo::Get(n); |
| states[idx.node_id(n.get())] = info.state; |
| for (uint32_t i = 0; i < info.outputs.size(); ++i) { |
| CHECK(idx.exist(n.get())); |
| size_t nid = idx.node_id(n.get()); |
| size_t eid = idx.entry_id(nid, i); |
| buff[eid] = info.outputs[i]; |
| buff[eid].autograd_entry_ = NodeEntry{n, i, 0}; |
| ref_count[eid] = 1; |
| } |
| }); |
| for (auto& ograd_entry : ograd_entries) { |
| AGInfo& info = AGInfo::Get(ograd_entry.node); |
| if (!idx.exist(ograd_entry.node.get())) |
| continue; |
| size_t eid = idx.entry_id(ograd_entry); |
| buff[eid] = info.outputs[0]; |
| buff[eid].autograd_entry_ = ograd_entry; |
| } |
| } else { |
| states.reserve(num_forward_nodes); |
| for (size_t i = 0; i < num_forward_nodes; ++i) { |
| const AGInfo& info = dmlc::get<AGInfo>(idx[i].source->info); |
| states.emplace_back(info.state); |
| for (size_t j = 0; j < info.outputs.size(); ++j) { |
| size_t eid = idx.entry_id(i, j); |
| arrays[eid] = const_cast<NDArray*>(&(info.outputs[j])); |
| |
| if (retain_graph || info.grad_req != kNullOp) |
| ref_count[eid] = 1; |
| } |
| } |
| for (auto& ograd_entry : ograd_entries) { |
| if (!idx.exist(ograd_entry.node.get())) |
| continue; |
| AGInfo& info = AGInfo::Get(ograd_entry.node); |
| arrays[idx.entry_id(ograd_entry)] = &info.outputs[0]; |
| } |
| } |
| for (size_t i = num_forward_outputs; i < graph.outputs.size(); ++i) { |
| size_t eid = idx.entry_id(graph.outputs[i]); |
| arrays[eid] = x_grads[i - num_forward_outputs]; |
| ref_count[eid] = 1; |
| } |
| const std::vector<NodeEntry>& us_grads = g_graph.GetAttr<std::vector<NodeEntry>>("nleaf_grads"); |
| CHECK_EQ(us_grads.size(), us.size()) |
| << "Size of queried nleaf_vars and size of their gradients don't match."; |
| for (size_t i = 0; i < us_grads.size(); i++) { |
| size_t eid = idx.entry_id(us_grads[i]); |
| AGInfo& info = AGInfo::Get(us[i].node); |
| if (arrays[eid]->dtype_ == -1) { |
| arrays[eid] = &info.out_grads[0]; |
| } else { |
| info.out_grads[0] = *arrays[eid]; |
| } |
| ref_count[eid] = 1; |
| } |
| |
| // Assign context |
| auto vctx = PlaceDevice(idx); |
| |
| // Infer shape type |
| { |
| std::pair<uint32_t, uint32_t> node_range, entry_range; |
| node_range = {num_forward_nodes, idx.num_nodes()}; |
| entry_range = {num_forward_entries, idx.num_node_entries()}; |
| |
| ShapeVector shapes; |
| shapes.reserve(idx.num_node_entries()); |
| bool contain_unknown = false; |
| for (const auto& i : arrays) |
| shapes.emplace_back(i->shape()); |
| CheckAndInferShape(&graph, std::move(shapes), false, node_range, entry_range, &contain_unknown); |
| |
| DTypeVector dtypes; |
| dtypes.reserve(idx.num_node_entries()); |
| for (const auto& i : arrays) |
| dtypes.emplace_back(i->dtype()); |
| CheckAndInferType(&graph, std::move(dtypes), false, node_range, entry_range); |
| |
| StorageTypeVector stypes; |
| stypes.reserve(idx.num_node_entries()); |
| for (const auto& i : arrays) |
| stypes.emplace_back(i->storage_type()); |
| exec::DevMaskVector dev_mask; |
| dev_mask.reserve(idx.num_nodes()); |
| for (const auto& i : vctx) |
| dev_mask.emplace_back(i.dev_mask()); |
| CheckAndInferStorageType( |
| &graph, std::move(dev_mask), std::move(stypes), false, node_range, entry_range); |
| } |
| |
| // Calculate ref count |
| for (size_t i = num_forward_nodes; i < idx.num_nodes(); ++i) { |
| for (const auto& j : idx[i].inputs) { |
| ++ref_count[idx.entry_id(j)]; |
| } |
| } |
| |
| // Assign reqs |
| std::vector<OpReqType> array_reqs(arrays.size(), kWriteTo); |
| for (size_t i = num_forward_entries; i < idx.num_node_entries(); ++i) { |
| if (ref_count[i] == 0) |
| array_reqs[i] = kNullOp; |
| } |
| for (size_t i = num_forward_outputs; i < idx.outputs().size(); ++i) { |
| size_t eid = idx.entry_id(idx.outputs()[i]); |
| array_reqs[eid] = x_reqs[i - num_forward_outputs]; |
| } |
| for (size_t i = 0; i < us_grads.size(); i++) { |
| size_t eid = idx.entry_id(us_grads[i]); |
| AGInfo& info = AGInfo::Get(us[i].node); |
| array_reqs[eid] = info.grad_req; |
| } |
| |
| const auto& shapes = graph.GetAttr<mxnet::ShapeVector>("shape"); |
| const auto& dtypes = graph.GetAttr<DTypeVector>("dtype"); |
| const auto& stypes = graph.GetAttr<StorageTypeVector>("storage_type"); |
| const auto& dispatch_modes = graph.GetAttr<DispatchModeVector>("dispatch_mode"); |
| |
| for (size_t i = num_forward_nodes; i < idx.num_nodes(); ++i) { |
| auto num_outputs = idx[i].source->num_outputs(); |
| for (size_t j = 0; j < num_outputs; ++j) { |
| auto eid = idx.entry_id(i, j); |
| if (arrays[eid]->is_none()) |
| arrays[eid]->ReInit( |
| static_cast<NDArrayStorageType>(stypes[eid]), shapes[eid], vctx[i], dtypes[eid]); |
| } |
| } |
| |
| for (size_t nid = num_forward_nodes; nid < idx.num_nodes(); ++nid) { |
| const nnvm::NodeAttrs& attrs = idx[nid].source->attrs; |
| for (size_t oid = 0; oid < idx[nid].source->num_outputs(); ++oid) { |
| size_t eid = idx.entry_id(nid, oid); |
| arrays[eid]->AssignStorageInfo(common::NodeAttrsGetProfilerScope(attrs), attrs.name); |
| } |
| } // for (nid ∈ [num_forward_nodes, idx.num_nodes())) |
| |
| if (dmlc::GetEnv("MXNET_MEM_PLAN_VERBOSE_LOGGING", false)) { |
| common::LogMemoryPlan(graph); |
| } |
| |
| // Execution |
| |
| bool prev_recording = set_is_recording(create_graph); |
| bool prev_training = set_is_training(is_train); |
| int prev_bulk_size = Engine::Get()->set_bulk_size(backward_bulk_size_); |
| |
| try { |
| RunGraph(retain_graph, |
| idx, |
| arrays, |
| num_forward_nodes, |
| idx.num_nodes(), |
| std::move(array_reqs), |
| std::move(ref_count), |
| &states, |
| dispatch_modes, |
| is_recording()); |
  } catch (const dmlc::Error&) {
    Engine::Get()->set_bulk_size(prev_bulk_size);
    set_is_recording(prev_recording);
    set_is_training(prev_training);
    throw;  // Rethrow in place to preserve the original exception type.
| } |
| |
| Engine::Get()->set_bulk_size(prev_bulk_size); |
| set_is_recording(prev_recording); |
| set_is_training(prev_training); |
| |
| // Clear history |
| if (!retain_graph) { |
| nnvm::DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& n) { |
| AGInfo::Clear(n); |
| n->inputs.clear(); |
| }); |
| } |
| |
| if (variables.size()) { |
| return x_grads; |
| } |
| return {}; |
| } |
| |
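// DCInfo caches copies of a deferred compute node's input and output arrays
// (plus raw input pointers) so the node can be executed later, even if the
// user-visible NDArrays have gone out of scope.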
| Imperative::DCInfo::DCInfo(const std::vector<NDArray*>& inputs, |
| const std::vector<NDArray*>& outputs) { |
| this->inputs_.reserve(inputs.size()); |
| this->input_handles_.reserve(inputs.size()); |
| for (const NDArray* arr : inputs) { |
| CHECK(!arr->is_none()); |
| this->inputs_.push_back(*arr); |
| this->input_handles_.push_back(arr); |
| } |
| |
| this->outputs_.reserve(outputs.size()); |
| for (const NDArray* arr : outputs) { |
| CHECK(!arr->is_none()); |
| this->outputs_.push_back(*arr); |
| } |
| } |
| |
| Imperative::DCInfo& Imperative::DCInfo::Create(const nnvm::ObjectPtr& node, |
| const std::vector<NDArray*>& inputs, |
| const std::vector<NDArray*>& outputs) { |
| node->info.construct<DCInfo>(inputs, outputs); |
| return Imperative::DCInfo::Get(node); |
| } |
| |
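// Lazily evaluates the deferred compute graph producing `arr`: inputs are
// computed by recursing through their DCInfo, then the node's operator is
// invoked imperatively. The cached input/output copies are dropped afterwards.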
| void Imperative::DCInfo::Compute(const NDArray& arr) { |
| if (Imperative::DCInfo::IsComputed(arr)) { |
| if (!shape_is_known(arr.shape())) { |
      // We can't call arr.WaitToRead() here, as WaitToRead calls Compute,
      // leading to an infinite loop.
| Engine::Get()->WaitForVar(arr.ptr_->var); |
| if (shape_is_known(arr.ptr_->storage_shape)) { |
| arr.SetShapeFromChunk(); |
| } else { |
| CHECK(shape_is_known(arr.shape())); |
| } |
| } |
| return; |
| } |
| |
| DCInfo& info = Imperative::DCInfo::Get(arr.deferredcompute_entry_.node); |
| info.is_computed_ = true; // We will Invoke at the end of this function. |
| |
| // Recursively compute input arrays |
| for (const NDArray& input : info.inputs_) { |
| Compute(input); |
| } |
| |
| // Prepare pointers |
| std::vector<NDArray*> ndinputs, ndoutputs; |
| ndinputs.reserve(info.inputs_.size()); |
| ndoutputs.reserve(info.outputs_.size()); |
| for (NDArray& input : info.inputs_) |
| ndinputs.push_back(&input); |
| for (NDArray& output : info.outputs_) |
| ndoutputs.push_back(&output); |
| |
| // Compute this array |
| Imperative::Get()->Invoke( |
| Context::CPU(), arr.deferredcompute_entry_.node->attrs, ndinputs, ndoutputs); |
| if (!shape_is_known(arr.shape())) { |
| arr.WaitToRead(); |
| arr.SetShapeFromChunk(); |
| } |
| |
| // Deallocate copies |
| info.inputs_.clear(); |
| info.outputs_.clear(); |
| } |
| |
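// Collects nonleaf nodes that were marked through MarkVariables (operator
// nodes whose AGInfo carries out_grads) so that Backward also retains
// gradients for them.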
| std::vector<nnvm::ObjectPtr> Imperative::ListNonleafVariables(const nnvm::Symbol& sym) const { |
| using namespace nnvm; |
| std::vector<ObjectPtr> ret; |
| DFSVisit(sym.outputs, [&ret](const ObjectPtr& node) { |
| AGInfo& info = AGInfo::Get(node); |
| if (info.out_grads.size() > 0 && !node->is_variable()) { |
| ret.push_back(node); |
| } |
| }); |
| return ret; |
| } |
| |
| } // namespace mxnet |