| /*! |
| * Copyright (c) 2016 by Contributors |
| * \file graph_executor.h |
| * \brief Executor to execute the computation graph. |
| */ |
| #ifndef MXNET_EXECUTOR_GRAPH_EXECUTOR_H_ |
| #define MXNET_EXECUTOR_GRAPH_EXECUTOR_H_ |
| |
| #include <mxnet/base.h> |
| #include <mxnet/ndarray.h> |
| #include <mxnet/operator.h> |
| #include <mxnet/executor.h> |
| #include <nnvm/graph.h> |
| #include <nnvm/op_attr_types.h> |
| #include <nnvm/graph_attr_types.h> |
| #include <functional> |
| #include <iosfwd> |
| #include <map> |
| #include <memory> |
| #include <string> |
| #include <unordered_map> |
| #include <unordered_set> |
| #include <utility> |
| #include <vector> |
| #include "./exec_pass.h" |
| |
| namespace mxnet { |
| |
| // Lookup table keyed by graph node (const nnvm::Node*) whose values are |
| // shared-ownership handles to Operator instances. |
| using NodeOperatorMap = |
| std::unordered_map<const nnvm::Node*, std::shared_ptr<Operator>>; |
| |
| // Forward declaration only; the full class definition of |
| // exec::GraphExecutor appears further down in this header. |
| namespace exec { class GraphExecutor; } // namespace exec |
| |
| // Forward declaration only; exec::GraphExecutor names |
| // autograd::AutogradRuntime as a friend class. |
| namespace autograd { class AutogradRuntime; } // namespace autograd |
| |
| namespace exec { |
| |
| // Make nnvm::Graph nameable without qualification inside mxnet::exec |
| // (it is the return type of InitGraph and InitFullGraph). |
| using nnvm::Graph; |
| |
| // Executor implementation that executes an nnvm computation graph. |
| // Declares the Executor overrides, two Init entry points (bind and |
| // simple bind), and the protected helpers and state used to set up and |
| // run the graph. Member functions are only declared here; their |
| // definitions live outside this header. |
| class GraphExecutor : public Executor { |
| public: |
| // AutogradRuntime may access the protected members declared below. |
| friend class autograd::AutogradRuntime; |
| using Executor::MonitorCallback; |
| |
| virtual ~GraphExecutor(); |
| // ---- overrides of the Executor interface (see Executor for contract) ---- |
| void Forward(bool is_train) override; |
| void PartialForward(bool is_train, int step, int *step_left) override; |
| void Backward(const std::vector<NDArray> &head_grads) override; |
| // Read-only accessors; each returns a const reference. |
| // NOTE(review): presumably these expose output_arrays_, in_arg_map_, |
| // arg_grad_map_ and aux_state_map_ respectively - confirm in the .cc. |
| const std::vector<NDArray>& outputs() const override; |
| const std::unordered_map<std::string, NDArray>& in_arg_map() const override; |
| const std::unordered_map<std::string, NDArray>& arg_grad_map() const override; |
| const std::unordered_map<std::string, NDArray>& aux_state_map() const override; |
| void Print(std::ostream &os) const override; // NOLINT(*) |
| void SetMonitorCallback(const MonitorCallback& callback) override; |
| // Initialize the rest of the attributes after the arguments have been |
| // set up. |
| // shared_exec and feed_dict are optional (nullptr / empty map by default). |
| void FinishInitGraph(nnvm::Symbol symbol, nnvm::Graph g, |
| Executor* shared_exec = nullptr, |
| const nnvm::NodeEntryMap<NDArray>& feed_dict |
| = nnvm::NodeEntryMap<NDArray>()); |
| |
| // Initialize the executor for bind. |
| // The argument, gradient and auxiliary-state arrays are supplied by the |
| // caller (in_args, arg_grad_store, aux_states) together with one |
| // OpReqType per entry of grad_req_types. |
| void Init(nnvm::Symbol symbol, |
| const Context& default_ctx, |
| const std::map<std::string, Context>& ctx_map, |
| const std::vector<NDArray>& in_args, |
| const std::vector<NDArray>& arg_grad_store, |
| const std::vector<OpReqType>& grad_req_types, |
| const std::vector<NDArray>& aux_states, |
| Executor* shared_exec = nullptr, |
| const nnvm::NodeEntryMap<NDArray>& feed_dict |
| = nnvm::NodeEntryMap<NDArray>()); |
| // Initialize the executor for simple bind. |
| // Instead of arrays the caller passes contexts plus name->shape and |
| // name->dtype maps; in_arg_vec, arg_grad_vec and aux_state_vec are |
| // non-const pointers (presumably outputs filled by this call - confirm |
| // in the .cc). shared_buffer and shared_exec are optional. |
| void Init(nnvm::Symbol symbol, |
| const Context& default_ctx, |
| const std::map<std::string, Context>& ctx_map, |
| const std::vector<Context>& in_arg_ctxes, |
| const std::vector<Context>& arg_grad_ctxes, |
| const std::vector<Context>& aux_state_ctxes, |
| const std::unordered_map<std::string, TShape>& arg_shape_map, |
| const std::unordered_map<std::string, int>& arg_dtype_map, |
| const std::vector<OpReqType>& grad_req_types, |
| const std::unordered_set<std::string>& shared_arg_names, |
| std::vector<NDArray>* in_arg_vec, |
| std::vector<NDArray>* arg_grad_vec, |
| std::vector<NDArray>* aux_state_vec, |
| std::unordered_map<std::string, NDArray>* shared_buffer = nullptr, |
| Executor* shared_exec = nullptr, |
| const nnvm::NodeEntryMap<NDArray>& feed_dict |
| = nnvm::NodeEntryMap<NDArray>()); |
| |
| protected: |
| // Information about an operational node. |
| // Every scalar member carries a default member initializer so that a |
| // default-initialized OpNode never holds an indeterminate value. |
| struct OpNode { |
| // The name of the operator (null until assigned) |
| const char* opr_name{nullptr}; |
| // the context of the node |
| Context ctx; |
| // The executor |
| std::shared_ptr<OpExecutor> exec; |
| // skip the execution of this node |
| bool skip_exec_node{false}; |
| // cached operator handle |
| Engine::OprHandle cached_opr{nullptr}; |
| // cached const vars, used for seg ops creation |
| std::vector<Engine::VarHandle> use_vars; |
| // cached mutate vars, used for seg ops creation |
| std::vector<Engine::VarHandle> mutate_vars; |
| }; |
| // A cached segment operator that executes a segment of the graph. |
| // topo_start/topo_end are zero-initialized for the same reason opr is |
| // null-initialized: a default-initialized value must be well defined. |
| struct CachedSegOpr { |
| // context of the operator |
| Context ctx; |
| // begin in topo order |
| size_t topo_start = 0; |
| // end in topo order |
| size_t topo_end = 0; |
| // the cached operator |
| Engine::OprHandle opr = nullptr; |
| // list of op executors (raw, non-smart pointers) |
| std::vector<OpExecutor*> exec_list; |
| }; |
| // Initialize in_args, arg_grads, and aux_states |
| void InitArguments(const nnvm::IndexedGraph& idx, |
| const nnvm::ShapeVector& inferred_shapes, |
| const nnvm::DTypeVector& inferred_dtypes, |
| const std::vector<Context>& in_arg_ctxes, |
| const std::vector<Context>& arg_grad_ctxes, |
| const std::vector<Context>& aux_state_ctxes, |
| const std::vector<OpReqType>& grad_req_types, |
| std::vector<NDArray>* in_arg_vec, |
| std::vector<NDArray>* arg_grad_vec, |
| std::vector<NDArray>* aux_state_vec); |
| // Initialize in_args, arg_grads and aux_states with |
| // shared_buffer and shared_exec |
| void InitArguments(const nnvm::IndexedGraph& idx, |
| const nnvm::ShapeVector& inferred_shapes, |
| const nnvm::DTypeVector& inferred_dtypes, |
| const std::vector<Context>& in_arg_ctxes, |
| const std::vector<Context>& arg_grad_ctxes, |
| const std::vector<Context>& aux_state_ctxes, |
| const std::vector<OpReqType>& grad_req_types, |
| const std::unordered_set<std::string>& shared_arg_names, |
| const Executor* shared_exec, |
| std::unordered_map<std::string, NDArray>* shared_buffer, |
| std::vector<NDArray>* in_arg_vec, |
| std::vector<NDArray>* arg_grad_vec, |
| std::vector<NDArray>* aux_state_vec); |
| // internal initialization of the graph for simple bind |
| Graph InitGraph(nnvm::Symbol symbol, |
| const Context& default_ctx, |
| const std::map<std::string, Context>& ctx_map, |
| const std::vector<Context>& in_arg_ctxes, |
| const std::vector<Context>& arg_grad_ctxes, |
| const std::vector<Context>& aux_state_ctxes, |
| const std::vector<OpReqType>& grad_req_types); |
| // initialize the full graph for simple bind, including gradient |
| Graph InitFullGraph(nnvm::Symbol symbol, |
| const std::vector<OpReqType>& grad_req_types); |
| // initialize the cached operator |
| void InitCachedOps(); |
| // initialize the opr segments for bulk exec |
| void InitOpSegs(); |
| // initialize the memory of data entries |
| // shared_pool: extra memory shared from other parts |
| void InitDataEntryMemory(std::vector<NDArray>* shared_pool); |
| // run ops from topo order start to end |
| void RunOps(bool is_train, size_t topo_start, size_t topo_end); |
| /*! |
| * \brief Try to create a cached operator to run segments between start and end |
| * \param topo_start beginning of segment |
| * \param topo_end end of segment |
| * \return the cached operator. |
| * ret.opr Can be nullptr if creation failed. |
| */ |
| CachedSegOpr CreateCachedSegOpr(size_t topo_start, size_t topo_end); |
| // run the monitor callback for node `nid` |
| void ExecuteMonCallback(size_t nid); |
| |
| // internal graph |
| nnvm::Graph graph_; |
| // operator node |
| std::vector<OpNode> op_nodes_; |
| // internal data entry of each node |
| std::vector<NDArray> data_entry_; |
| // internal data pool of allocated entries |
| std::vector<NDArray> data_pool_; |
| // output arrays |
| std::vector<NDArray> output_arrays_; |
| // input argument map, key is arg name, value is arg's NDArray |
| std::unordered_map<std::string, NDArray> in_arg_map_; |
| // arg grad map, key is arg name, value is arg grad NDArray |
| std::unordered_map<std::string, NDArray> arg_grad_map_; |
| // aux state map, key is aux state name, value is aux state NDArray |
| std::unordered_map<std::string, NDArray> aux_state_map_; |
| // gradient store: pairs of (request type, gradient array) |
| std::vector<std::pair<OpReqType, NDArray> > grad_store_; |
| // array to hold head gradient. |
| std::vector<NDArray> head_grad_array_; |
| // entry to hold head gradient |
| std::vector<nnvm::NodeEntry> head_grad_entry_; |
| // maps a graph node to a size_t index |
| // NOTE(review): presumably the index of that node's head gradient in |
| // head_grad_array_ / head_grad_entry_ - confirm against the .cc. |
| std::unordered_map<const nnvm::Node*, size_t> head_grad_map_; |
| // number of outputs. |
| size_t num_forward_outputs_{0}; |
| // number of inputs |
| size_t num_forward_inputs_{0}; |
| // number of forward nodes |
| size_t num_forward_nodes_{0}; |
| // saved operator for autograd |
| NodeOperatorMap saved_opr_; |
| // monitor call back (empty until one is installed) |
| std::function<void(const char*, void*)> monitor_callback_{nullptr}; |
| // whether to enable bulk execution. |
| // Explicitly initialized so the flag is never indeterminate; false is |
| // the value that value-initialization of the object already yields. |
| bool prefer_bulk_execution_{false}; |
| // cached segment operator |
| std::vector<CachedSegOpr> cached_seg_opr_; |
| }; |
| |
| } // namespace exec |
| } // namespace mxnet |
| #endif // MXNET_EXECUTOR_GRAPH_EXECUTOR_H_ |