/*!
 * Copyright (c) 2016 by Contributors
 * \file executor.h
 * \brief executor definition
 * \author Chuntao Hong, Zhang Chen
 */
#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_EXECUTOR_H_
#define CPP_PACKAGE_INCLUDE_MXNET_CPP_EXECUTOR_H_
#include <vector>
#include <map>
#include <set>
#include <string>
#include "mxnet-cpp/base.h"
#include "mxnet-cpp/symbol.h"
namespace mxnet {
namespace cpp {
class Optimizer;
/*!
 * \brief Executor interface
 */
class Executor {
  friend class Monitor;

 public:
  Executor(const Symbol &symbol, Context context,
           const std::vector<NDArray> &arg_arrays,
           const std::vector<NDArray> &grad_arrays,
           const std::vector<OpReqType> &grad_reqs,
           const std::vector<NDArray> &aux_arrays,
           const std::map<std::string, Context> &group_to_ctx =
               std::map<std::string, Context>(),
           Executor *shared_exec = nullptr);
  explicit Executor(const ExecutorHandle &h) { handle_ = h; }
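  /*!
   * Usage sketch (illustrative, not part of this header): binding a symbol
   * to an executor. The symbol `net` and the filled NDArray vectors are
   * assumed to have been prepared elsewhere, one entry per argument.
   *
   *   Context ctx = Context::cpu();
   *   std::vector<NDArray> args, grads, aux;            // filled beforehand
   *   std::vector<OpReqType> reqs(args.size(), kWriteTo);
   *   Executor exec(net, ctx, args, grads, reqs, aux);
   */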
  /*!
   * \brief Perform a forward pass with this executor.
   * After this operation, the results can be read from `outputs`.
   * \param is_train whether the forward pass is run in training mode
   */
  void Forward(bool is_train) {
    CHECK_EQ(MXExecutorForward(handle_, is_train ? 1 : 0), 0);
    mx_uint out_size;
    NDArrayHandle *out_array;
    CHECK_EQ(MXExecutorOutputs(handle_, &out_size, &out_array), 0);
    // refresh the cached output NDArrays from the executor handle
    for (mx_uint i = 0; i < out_size; ++i) {
      outputs[i] = NDArray(out_array[i]);
    }
  }
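  /*!
   * Usage sketch (assumes `exec` points to a bound Executor): run one
   * forward pass and read the first output. Pass `true` while training so
   * operators such as Dropout behave accordingly.
   *
   *   exec->Forward(true);
   *   NDArray out = exec->outputs[0];
   *   NDArray::WaitAll();  // block until the asynchronous computation finishes
   */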
  /*!
   * \brief Perform a backward pass with this executor.
   * This must be called after Forward.
   * After this operation, the NDArrays in `grad_arrays` will be updated
   * accordingly.
   * The user may pass an empty vector if the head node is a loss function
   * and no head gradient is needed.
   *
   * \param head_grads the gradients of the head nodes to be backpropagated.
   */
  void Backward(const std::vector<NDArray> &head_grads =
                    std::vector<NDArray>()) {
    std::vector<NDArrayHandle> head_grads_;
    for (const auto &d : head_grads) {
      head_grads_.push_back(d.GetHandle());
    }
    // pass the head gradient handles to the C API, or nullptr if none are given
    if (head_grads_.size() > 0) {
      CHECK_EQ(MXExecutorBackward(handle_, head_grads_.size(),
                                  head_grads_.data()), 0);
    } else {
      CHECK_EQ(MXExecutorBackward(handle_, 0, nullptr), 0);
    }
  }
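  /*!
   * Usage sketch: one training step. When the head of the graph is a loss
   * symbol, no head gradient has to be supplied to Backward.
   *
   *   exec->Forward(true);
   *   exec->Backward();   // gradients are written into grad_arrays
   */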
  // TODO(zhangchen-qinyinghua)
  // To implement reshape function
  void Reshape();
  /*!
   * \brief get a debug string describing the internal execution plan
   * \return the debug string
   */
  std::string DebugStr();
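  /*!
   * Usage sketch: dump the execution plan, e.g. while debugging memory use
   * (printing this way requires <iostream>, which this header does not include).
   *
   *   std::cout << exec->DebugStr() << std::endl;
   */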
  /*!
   * \brief update the arguments with the given learning rate and optimizer
   * \param opt the pointer to the optimizer
   * \param lr learning rate
   * \param wd weight decay
   * \param arg_update_begin begin index of the arguments to be updated; by
   * default it starts after the input data
   * \param arg_update_end end index of the arguments to be updated; by
   * default it ends before the label data
   */
  void UpdateAll(Optimizer *opt, float lr, float wd, int arg_update_begin = 1,
                 int arg_update_end = -1);
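  /*!
   * Usage sketch (creating the optimizer via OptimizerRegistry::Find is an
   * assumption about the surrounding code): apply one optimizer update to
   * the learnable arguments after a Forward/Backward pair.
   *
   *   Optimizer *opt = OptimizerRegistry::Find("sgd");
   *   exec->UpdateAll(opt, 0.01f, 1e-4f);  // learning rate, weight decay
   */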
  /*!
   * \brief destructor, free the handle
   */
  ~Executor() { MXExecutorFree(handle_); }
  /*! \brief argument arrays bound to this executor */
  std::vector<NDArray> arg_arrays;
  /*! \brief gradient arrays, written by Backward */
  std::vector<NDArray> grad_arrays;
  /*! \brief auxiliary state arrays (e.g. batch-norm moving statistics) */
  std::vector<NDArray> aux_arrays;
  /*!
   * \brief arrays that store the outputs of Forward
   */
  std::vector<NDArray> outputs;
  /*! \brief map from argument name to the bound argument NDArray */
  std::map<std::string, NDArray> arg_dict() {
    return GetDict(symbol_.ListArguments(), arg_arrays);
  }
  /*! \brief map from argument name to the corresponding gradient NDArray */
  std::map<std::string, NDArray> grad_dict() {
    return GetDict(symbol_.ListArguments(), grad_arrays);
  }
  /*! \brief map from auxiliary state name to the bound auxiliary NDArray */
  std::map<std::string, NDArray> aux_dict() {
    return GetDict(symbol_.ListAuxiliaryStates(), aux_arrays);
  }
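  /*!
   * Usage sketch: address bound arrays by argument name rather than by
   * position. The name "data" is illustrative and depends on the symbol.
   *
   *   auto args = exec->arg_dict();
   *   args["data"].SyncCopyFromCPU(batch);  // batch is a std::vector<mx_float>
   */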
 private:
  // disable copy construction and assignment
  Executor(const Executor &e);
  Executor &operator=(const Executor &e);
  ExecutorHandle handle_;
  Symbol symbol_;
  // build a name -> NDArray map, checking that names are unique and that
  // the number of names matches the number of arrays
  std::map<std::string, NDArray> GetDict(const std::vector<std::string> &names,
                                         const std::vector<NDArray> &arrays) {
    std::map<std::string, NDArray> ret;
    std::set<std::string> name_set;
    for (const auto &s : names) {
      CHECK(name_set.find(s) == name_set.end()) << "Duplicate names detected, "
                                                << s;
      name_set.insert(s);
    }
    CHECK_EQ(name_set.size(), arrays.size())
        << "names size not equal to arrays size";
    for (size_t i = 0; i < names.size(); ++i) {
      ret[names[i]] = arrays[i];
    }
    return ret;
  }
};
} // namespace cpp
} // namespace mxnet
#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_EXECUTOR_H_