blob: e853c2617ea41cbc2e44bf83c50e4adf486d0aa3 [file] [log] [blame]
/*!
* Copyright (c) 2016 by Contributors
* \file symbol.h
* \brief definition of symbol
* \author Chuntao Hong, Zhang Chen
*/
#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_SYMBOL_H_
#define CPP_PACKAGE_INCLUDE_MXNET_CPP_SYMBOL_H_
#include <map>
#include <string>
#include <vector>
#include "mxnet-cpp/base.h"
#include "mxnet-cpp/ndarray.h"
#include "mxnet-cpp/op_map.h"
namespace mxnet {
namespace cpp {
class Executor;
/*!
* \brief struct to store SymbolHandle
*/
struct SymBlob {
public:
/*!
* \brief default constructor
*/
SymBlob() : handle_(nullptr) {}
/*!
* \brief construct with SymbolHandle to store
*/
explicit SymBlob(SymbolHandle handle) : handle_(handle) {}
/*!
* \brief destructor, free the SymbolHandle
*/
~SymBlob() { MXSymbolFree(handle_); }
/*!
* \brief the SymbolHandle to store
*/
SymbolHandle handle_;
private:
SymBlob(const SymBlob &);
SymBlob &operator=(const SymBlob &);
};
/*!
* \brief Symbol interface
*/
class Symbol {
public:
Symbol() {}
/*!
* \brief construct a Symbol with SymbolHandle
* \param handle the given SymbolHandle
*/
explicit Symbol(SymbolHandle handle);
/*!
* \brief construct a variable Symbol
* \param name the name of the variable
*/
explicit Symbol(const char *name);
/*!
* \brief construct a variable Symbol
* \param name the name of the variable
*/
explicit Symbol(const std::string &name);
Symbol operator+(const Symbol &rhs) const;
Symbol operator-(const Symbol &rhs) const;
Symbol operator*(const Symbol &rhs) const;
Symbol operator/(const Symbol &rhs) const;
Symbol operator%(const Symbol &rhs) const;
Symbol operator+(mx_float scalar) const;
Symbol operator-(mx_float scalar) const;
Symbol operator*(mx_float scalar) const;
Symbol operator/(mx_float scalar) const;
Symbol operator%(mx_float scalar) const;
Symbol Copy() const;
/*!
* \brief construct a variable Symbol
* \param name the name of the variable
*/
static Symbol Variable(const std::string &name = "");
Symbol operator[](int index);
Symbol operator[](const std::string &index);
/*!
* \brief Create a symbol that groups symbols together
* \param symbols List of symbols to be groupe
*/
static Symbol Group(const std::vector<Symbol> &symbols);
/*!
* \brief load Symbol from a JSON file
* \param file_name the name of the file
*/
static Symbol Load(const std::string &file_name);
/*!
* \brief load Symbol from a JSON string
* \param json_str the JSON string
*/
static Symbol LoadJSON(const std::string &json_str);
/*!
* \brief save Symbol to a file
* \param file_name the name of the file
*/
void Save(const std::string &file_name) const;
/*!
* \brief save Symbol into a JSON string
*/
std::string ToJSON() const;
/*!
* \brief save Symbol into a JSON string
* \retutrn the symbol whose outputs are all the internals.
*/
Symbol GetInternals() const;
/*!
* \return the SymbolHandle
*/
SymbolHandle GetHandle() const { return blob_ptr_->handle_; }
/*!
* \brief construct an operator Symbol, with given input Symbol and config
* \param name the name of the Symbol
* \param input_keys the vector of keys of the input
* \param input_values the vector of the intput Symbols
* \param config_keys the vector of keys of the config
* \param config_values the vecotr of values of the config
*/
Symbol(const std::string &operator_name, const std::string &name,
std::vector<const char *> input_keys,
std::vector<SymbolHandle> input_values,
std::vector<const char *> config_keys,
std::vector<const char *> config_values);
/*!
* \brief infer the shapes by providing shapes of known argument shapes.
* \param arg_shapes map of argument name to shape of arguments with known
* shapes.
* \param in_shapes used to store infered shapes of input arguments.
* \param out_shapes used to store infered shapes of outputs.
* \param aux_shapes use to store the infered shapes of auxiliary states
*/
void InferShape(
const std::map<std::string, std::vector<mx_uint> > &arg_shapes,
std::vector<std::vector<mx_uint> > *in_shape,
std::vector<std::vector<mx_uint> > *aux_shape,
std::vector<std::vector<mx_uint> > *out_shape) const;
/*!
* \brief List the arguments names.
*
* The position of the returned list also corresponds to calling position in
*operator()
* \return the arguments list of this symbol, they can be either named or
*unnamed (empty string).
*/
std::vector<std::string> ListArguments() const;
/*! \return get the descriptions of outputs for this symbol */
std::vector<std::string> ListOutputs() const;
/*! \return get the descriptions of auxiliary data for this symbol */
std::vector<std::string> ListAuxiliaryStates() const;
/*!
* \brief infer and construct all the arrays to bind to executor by providing
* some known arrays.
* \param context the context of all the infered arrays
* \param arg_arrays infered input arguments arrays.
* \param arad_arrays infered arrays to store the gradient output of the input
* arguments.
* \param aux_arrays infered arrays that is used as internal state in op.
* \param args_map map of some given arguments arrays.
* \param args_grad_store map of some gradient given store arrays.
* \param args_req_type map of some given type of gradient saving. Can only be
* in {kNullOp, kAddTo, kWriteTo}.
* \param aux_map NDArray that stores the internal state in op
*/
void InferExecutorArrays(
const Context &context, std::vector<NDArray> *arg_arrays,
std::vector<NDArray> *grad_arrays, std::vector<OpReqType> *grad_reqs,
std::vector<NDArray> *aux_arrays,
const std::map<std::string, NDArray> &args_map,
const std::map<std::string, NDArray> &arg_grad_store =
std::map<std::string, NDArray>(),
const std::map<std::string, OpReqType> &grad_req_type =
std::map<std::string, OpReqType>(),
const std::map<std::string, NDArray> &aux_map =
std::map<std::string, NDArray>()) const;
/*!
* \brief infer and construct all the input arguments arrays to bind to
* executor by providing some known arguments arrays.
* \param context the context of all the infered arrays.
* \param args_map map of all the infered input arguments arrays.
* \param known_args map of some given arguments arrays.
*/
void InferArgsMap(const Context &context,
std::map<std::string, NDArray> *args_map,
const std::map<std::string, NDArray> &known_args) const;
/*!
* \brief Create an executor by bind symbol with context and arguments.
* If user do not want to compute the gradients of i-th argument,
*grad_req_type[i] can be kNullOp.
* The input arrays in the given maps should have the same name with the input
*symbol.
* Only need some of the necessary arrays, and the other arrays can be infered
*automatically.
*
* \param context the context of binding.
* \param args_map the NDArray that stores the input arguments to the symbol.
* \param arg_grad_store NDArray that is used to store the gradient output of
*the input arguments.
* \param grad_req_type requirment type of gradient saving. Can only be in
*{kNullOp, kAddTo, kWriteTo}.
* \param aux_map NDArray that stores the internal state in op
* \return a new executor, which need to be free manually.
*/
Executor *SimpleBind(const Context &context,
const std::map<std::string, NDArray> &args_map,
const std::map<std::string, NDArray> &arg_grad_store =
std::map<std::string, NDArray>(),
const std::map<std::string, OpReqType> &grad_req_type =
std::map<std::string, OpReqType>(),
const std::map<std::string, NDArray> &aux_map =
std::map<std::string, NDArray>());
/*!
* \brief Create an executor by bind symbol with context and arguments.
* If user do not want to compute the gradients of i-th argument,
*grad_req_type[i] can be kNullOp.
*
* \param context the context of binding.
* \param arg_arrays the NDArray that stores the input arguments to the symbol.
* \param grad_arrays NDArray that is used to store the gradient output of the
*input arguments.
* \param grad_reqs requirment type of gradient saving. Can only be in
*{kNullOp, kAddTo, kWriteTo}.
* \param aux_arrays NDArray that is used as internal state in op
* \param group_to_ctx dict of string to mx.Context
* \param shared_exec Executor to share memory with. This is intended for
*runtime reshaping, variable length sequencesn etc. The returned executor
*shares state with shared_exec, and should not be used in parallel with it.
* \return a new executor, which need to be free manually.
*/
Executor *Bind(const Context &context, const std::vector<NDArray> &arg_arrays,
const std::vector<NDArray> &grad_arrays,
const std::vector<OpReqType> &grad_reqs,
const std::vector<NDArray> &aux_arrays,
const std::map<std::string, Context> &group_to_ctx =
std::map<std::string, Context>(),
Executor *shared_exec = nullptr);
private:
std::shared_ptr<SymBlob> blob_ptr_;
static OpMap*& op_map();
};
Symbol operator+(mx_float lhs, const Symbol &rhs);
Symbol operator-(mx_float lhs, const Symbol &rhs);
Symbol operator*(mx_float lhs, const Symbol &rhs);
Symbol operator/(mx_float lhs, const Symbol &rhs);
Symbol operator%(mx_float lhs, const Symbol &rhs);
} // namespace cpp
} // namespace mxnet
#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_SYMBOL_H_