blob: d5a4e8c3aa6c9219e6f45faaf917720a577a1ce1 [file] [log] [blame]
/*!
* Copyright (c) 2016 by Contributors
* \file graph_executor.h
* \brief Executor to execute the computation graph.
*/
#ifndef MXNET_EXECUTOR_GRAPH_EXECUTOR_H_
#define MXNET_EXECUTOR_GRAPH_EXECUTOR_H_
#include <mxnet/base.h>
#include <mxnet/ndarray.h>
#include <mxnet/operator.h>
#include <mxnet/executor.h>
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <functional>
#include <map>
#include <memory>
#include <ostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "./exec_pass.h"
namespace mxnet {
// Forward declarations: these classes are only named (not used by value)
// before their full definitions are visible.
namespace autograd {
class AutogradRuntime;
}  // namespace autograd
namespace exec {
class GraphExecutor;
}  // namespace exec

// Associates a graph node (by address) with a shared Operator instance.
using NodeOperatorMap =
    std::unordered_map<const nnvm::Node*, std::shared_ptr<Operator>>;
namespace exec {
using nnvm::Graph;
// graph executors
class GraphExecutor : public Executor {
public:
friend class autograd::AutogradRuntime;
using Executor::MonitorCallback;
virtual ~GraphExecutor();
void Forward(bool is_train) override;
void PartialForward(bool is_train, int step, int *step_left) override;
void Backward(const std::vector<NDArray> &head_grads) override;
const std::vector<NDArray>& outputs() const override;
const std::unordered_map<std::string, NDArray>& in_arg_map() const override;
const std::unordered_map<std::string, NDArray>& arg_grad_map() const override;
const std::unordered_map<std::string, NDArray>& aux_state_map() const override;
void Print(std::ostream &os) const override; // NOLINT(*)
void SetMonitorCallback(const MonitorCallback& callback) override;
// Initialize the rest of attributes
// after setting up arguments.
void FinishInitGraph(nnvm::Symbol symbol, nnvm::Graph g,
Executor* shared_exec = nullptr,
const nnvm::NodeEntryMap<NDArray>& feed_dict
= nnvm::NodeEntryMap<NDArray>());
// initialize executor for bind
void Init(nnvm::Symbol symbol,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<NDArray>& in_args,
const std::vector<NDArray>& arg_grad_store,
const std::vector<OpReqType>& grad_req_types,
const std::vector<NDArray>& aux_states,
Executor* shared_exec = nullptr,
const nnvm::NodeEntryMap<NDArray>& feed_dict
= nnvm::NodeEntryMap<NDArray>());
// initialize executor for simple bind
void Init(nnvm::Symbol symbol,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
const std::unordered_map<std::string, TShape>& arg_shape_map,
const std::unordered_map<std::string, int>& arg_dtype_map,
const std::vector<OpReqType>& grad_req_types,
const std::unordered_set<std::string>& shared_arg_names,
std::vector<NDArray>* in_arg_vec,
std::vector<NDArray>* arg_grad_vec,
std::vector<NDArray>* aux_state_vec,
std::unordered_map<std::string, NDArray>* shared_buffer = nullptr,
Executor* shared_exec = nullptr,
const nnvm::NodeEntryMap<NDArray>& feed_dict
= nnvm::NodeEntryMap<NDArray>());
protected:
// Information about operational node
struct OpNode {
// The name of the operator
const char* opr_name;
// the context of the node
Context ctx;
// The executor
std::shared_ptr<OpExecutor> exec;
// skip the execution of this node
bool skip_exec_node{false};
// cached operator handle
Engine::OprHandle cached_opr{nullptr};
// cached const vars, used for seg ops creation
std::vector<Engine::VarHandle> use_vars;
// cached mutate vars, used for seg ops creation
std::vector<Engine::VarHandle> mutate_vars;
};
// a cached segment operator that executes a segment
struct CachedSegOpr {
// context of the operator
Context ctx;
// begin in topo order
size_t topo_start;
// end in topo order
size_t topo_end;
// the cached operator
Engine::OprHandle opr = nullptr;
// list of op executors
std::vector<OpExecutor*> exec_list;
};
// Initialize in_args, arg_grads, and aux_states
void InitArguments(const nnvm::IndexedGraph& idx,
const nnvm::ShapeVector& inferred_shapes,
const nnvm::DTypeVector& inferred_dtypes,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
const std::vector<OpReqType>& grad_req_types,
std::vector<NDArray>* in_arg_vec,
std::vector<NDArray>* arg_grad_vec,
std::vector<NDArray>* aux_state_vec);
// Initialize in_args, arg_grads and aux_states with
// shared_buffer and shared_exec
void InitArguments(const nnvm::IndexedGraph& idx,
const nnvm::ShapeVector& inferred_shapes,
const nnvm::DTypeVector& inferred_dtypes,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
const std::vector<OpReqType>& grad_req_types,
const std::unordered_set<std::string>& shared_arg_names,
const Executor* shared_exec,
std::unordered_map<std::string, NDArray>* shared_buffer,
std::vector<NDArray>* in_arg_vec,
std::vector<NDArray>* arg_grad_vec,
std::vector<NDArray>* aux_state_vec);
// internal initialization of the graph for simple bind
Graph InitGraph(nnvm::Symbol symbol,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
const std::vector<OpReqType>& grad_req_types);
// intialize the full graph for simple bind, including gradient
Graph InitFullGraph(nnvm::Symbol symbol,
const std::vector<OpReqType>& grad_req_types);
// initialize the cached operator
void InitCachedOps();
// initialize the opr segments for bulk exec
void InitOpSegs();
// initialize the resources in the graph
// initialize the memory of data entries
// shared_pool: extra memory shared from other parts
void InitDataEntryMemory(std::vector<NDArray>* shared_pool);
// run ops from topo order start to end
void RunOps(bool is_train, size_t topo_start, size_t topo_end);
/*!
* \brief Try to create a cached operator to run segments between start and end
* \param topo_start beginning of segment
* \param topo_end end of segment
* \return the cached operator.
* ret.opr Can be nullptr if creation failed.
*/
CachedSegOpr CreateCachedSegOpr(size_t topo_start, size_t topo_end);
// run the monitor callback for node `nid`
void ExecuteMonCallback(size_t nid);
// internal graph
nnvm::Graph graph_;
// operator node
std::vector<OpNode> op_nodes_;
// internal data entry of each node
std::vector<NDArray> data_entry_;
// internal data pool of allocated entries
std::vector<NDArray> data_pool_;
// output arrays
std::vector<NDArray> output_arrays_;
// input argument map, key is arg name, value is arg's NDArray
std::unordered_map<std::string, NDArray> in_arg_map_;
// arg grad map, key is arg name, value is arg grad NDArray
std::unordered_map<std::string, NDArray> arg_grad_map_;
// aux state map, key is aux state name, value is aux state NDArray
std::unordered_map<std::string, NDArray> aux_state_map_;
// gradient store
std::vector<std::pair<OpReqType, NDArray> > grad_store_;
// array to hold head gradient.
std::vector<NDArray> head_grad_array_;
// entry to hold head gradient
std::vector<nnvm::NodeEntry> head_grad_entry_;
// the index map of entry to map.
std::unordered_map<const nnvm::Node*, size_t> head_grad_map_;
// number of outputs.
size_t num_forward_outputs_{0};
// number of inputs
size_t num_forward_inputs_{0};
// number of forward nodes
size_t num_forward_nodes_{0};
// saved operator for autograd
NodeOperatorMap saved_opr_;
// monitor call back
std::function<void(const char*, void*)> monitor_callback_{nullptr};
// whether to enable bulk execution
bool prefer_bulk_execution_;
// cached segment operator
std::vector<CachedSegOpr> cached_seg_opr_;
};
} // namespace exec
} // namespace mxnet
#endif // MXNET_EXECUTOR_GRAPH_EXECUTOR_H_