blob: 42cf5dd775c20694754e075882d67a1a2d577fdd [file] [log] [blame]
/*!
* Copyright (c) 2016 by Contributors
* \file nnvm/symbolic.h
* \brief Symbolic graph construction API
*
* This API is optional, but it makes it easy for users
* to construct an NNVM Graph and to quickly create
* front ends in host languages.
*/
#ifndef NNVM_SYMBOLIC_H_
#define NNVM_SYMBOLIC_H_
#include <iosfwd>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
#include "base.h"
#include "node.h"
namespace nnvm {
/*!
 * \brief Symbol is a helper class used to represent an operator node in a Graph.
 *
 * Symbol acts as an interface for building graphs from different components
 * like Variable, Functor and Group. Symbol is also exported to the python front-end
 * (while Graph is not) to enable quick testing and deployment. Conceptually,
 * a symbol is the final operation of a graph and thus includes all the information
 * required (the graph) to evaluate its output value.
 */
class NNVM_DLL Symbol {
 public:
  /*! \brief option passed to ListAttr */
  enum ListAttrOption {
    /*! \brief recursively list all attributes */
    kRecursive = 0,
    /*! \brief only list attributes in current node */
    kShallow = 1
  };
  /*! \brief option passed to ListInputNames */
  enum ListInputOption {
    /*! \brief list all the arguments */
    kAll = 0,
    /*! \brief list only read-only arguments */
    kReadOnlyArgs = 1,
    /*!
     * \brief List auxiliary states that can be mutated by the graph.
     *  This excludes the read-only arguments.
     */
    kAuxiliaryStates = 2
  };

  /*! \brief output entries contained in the symbol */
  std::vector<NodeEntry> outputs;

  /*!
   * \brief Copy the symbol.
   * \return A deep copy of this symbol.
   */
  Symbol Copy() const;
  /*!
   * \brief Print the symbol info to an output stream.
   * \param os The output stream to print to.
   */
  void Print(std::ostream &os) const; // NOLINT(*)
  /*!
   * \brief Get the index-th element from the returned tuple.
   * \param index Index of the output in a multi-output symbol.
   * \return The symbol corresponding to the indexed element.
   */
  Symbol operator[] (size_t index) const;
  /*!
   * \brief List the input variable nodes.
   *
   *  The order of the returned list is the same as the order of the input list to `operator()`.
   *
   * \param option The options to list the arguments.
   * \return The input variable nodes of this symbol.
   * \sa ListInputOption
   */
  std::vector<NodePtr> ListInputs(ListInputOption option) const;
  /*!
   * \brief List the input names.
   *
   *  The order of the returned list is the same as the order of the input list to `operator()`.
   *
   * \param option The options to list the arguments.
   * \return The argument names of this symbol; each can be either named or unnamed (empty string).
   * \sa ListInputOption
   */
  std::vector<std::string> ListInputNames(ListInputOption option) const;
  /*!
   * \brief List the names of outputs for this symbol.
   *
   *  For normal operators, it is usually the symbol node name + "_output".
   *
   * \return The names of the outputs of this symbol.
   */
  std::vector<std::string> ListOutputNames() const;
  /*!
   * \brief Compose the symbol with arguments; this changes the current symbol.
   *
   *  The kwargs passed in may be incomplete: the remaining inputs
   *  keep their original names.
   *
   * \param args Positional arguments.
   * \param kwargs Keyword arguments for the symbol.
   * \param name Name of the resulting (composed) symbol.
   */
  void Compose(const array_view<const Symbol*>& args,
               const std::unordered_map<std::string, const Symbol*>& kwargs,
               const std::string& name);
  /*!
   * \brief Apply the symbol as a function, composing it with arguments.
   *
   *  This is equivalent to Copy then Compose.
   *
   * \param args Positional arguments for the symbol.
   * \param kwargs Keyword arguments for the symbol.
   * \param name Name of the returned symbol.
   * \return A new Symbol which is the composition of the current symbol with its arguments.
   */
  Symbol operator () (const array_view<const Symbol*>& args,
                      const std::unordered_map<std::string, const Symbol*>& kwargs,
                      const std::string& name) const;
  /*!
   * \brief Add control flow dependencies to the operators in symbols.
   *
   *  For a grouped symbol, an error will be raised. This mutates the current symbolic Node.
   *
   * \param src The symbols to depend on.
   */
  void AddControlDeps(const Symbol& src);
  /*!
   * \brief Get all the internal nodes of the symbol.
   * \return A new symbol whose outputs contain all the outputs of the symbol's nodes,
   *  including input variables and intermediate outputs.
   */
  Symbol GetInternals() const;
  /*!
   * \brief Get the direct inputs of the head node(s) of this symbol.
   * \return A new symbol whose outputs contain all the inputs of the head
   *  node(s).
   */
  Symbol GetChildren() const;
  /*!
   * \brief Set additional attributes on the current node.
   *
   *  This only works for a symbol whose outputs come from a single operator.
   *  For a grouped symbol, an error will be raised.
   *
   *  This function mutates the node's symbol and is not recommended.
   *
   * \param attrs The attributes to set.
   */
  void SetAttrs(const std::vector<std::pair<std::string, std::string> >& attrs);
  /*!
   * \brief Get an attribute from the symbol.
   *
   *  This only works for a symbol whose outputs come from a single operator.
   *  For a grouped symbol, an error will be raised.
   *
   * \param key Key of the attribute. When key == "name", it returns the name attribute.
   * \param out The output value of the attribute.
   * \return true if the attribute exists, false if the attribute does not exist.
   */
  bool GetAttr(const std::string& key, std::string* out) const;
  /*!
   * \brief Get the attribute dictionary from the symbol.
   *
   *  For a grouped symbol, an error will be raised.
   *
   * \param option If the recursive flag is set, the attributes of all children are retrieved;
   *  the name of the symbol is then prepended to each key.
   * \return The created attribute dictionary.
   */
  std::unordered_map<std::string, std::string> ListAttrs(ListAttrOption option) const;
  /*!
   * \brief Get the attribute dictionary from the symbol and all its children.
   *
   *  For a grouped symbol, an error will be raised.
   *
   * \return The created attributes, in the format <operator_name, key, value>.
   */
  std::vector<std::tuple<std::string, std::string, std::string> >
      ListAttrsRecursive() const;
  /*!
   * \brief Create a symbolic functor (AtomicSymbol) from the given operator and attributes.
   * \param op The operator.
   * \param attrs The additional attributes.
   * \return Symbol that can be used to call compose further.
   */
  static Symbol CreateFunctor(const Op* op,
                              std::unordered_map<std::string, std::string> attrs);
  /*!
   * \brief Create a symbolic functor (AtomicSymbol) from the given node attributes.
   * \param attrs Pre-initialized Node attributes.
   * \return Symbol that can be used to call compose further.
   */
  static Symbol CreateFunctor(const NodeAttrs& attrs);
  /*!
   * \brief Create a symbol node representing a variable.
   * \param name Name of the variable.
   * \return The symbol.
   */
  static Symbol CreateVariable(const std::string& name);
  /*!
   * \brief Create a symbol equivalent to grouping the given symbols together.
   * \param symbols A list of symbols to be grouped.
   * \return The grouped symbol.
   */
  static Symbol CreateGroup(const std::vector<Symbol>& symbols);
};
} // namespace nnvm
#endif // NNVM_SYMBOLIC_H_