blob: 0fd54ff5b8fa7ee4537f8cca6b863046b7e1c404 [file] [log] [blame]
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/op.h
* \brief Primitive operator definition.
*/
#ifndef TVM_RELAY_OP_H_
#define TVM_RELAY_OP_H_
#include <functional>
#include <limits>
#include <string>
#include <typeinfo>
#include <utility>
#include <vector>
#include "base.h"
#include "expr.h"
#include "type.h"
namespace tvm {
namespace relay {
// Forward declarations of classes referenced before their definitions.
template <typename ValueType>
class OpMap;
class GenericOpMap;
class OpRegistry;
/*!
 * \brief Node container of operator structure.
 */
class OpNode : public relay::ExprNode {
 public:
  /*! \brief name of the operator */
  std::string name;
  /*! \brief the type of the operator */
  mutable FuncType op_type;
  /*!
   * \brief detailed description of the operator
   * This can be used to generate docstring automatically for the operator.
   */
  std::string description;
  /*! \brief Information of input arguments to the operator */
  Array<AttrFieldInfo> arguments;
  /*!
   * \brief The type key of the attribute field
   * This can be empty, in which case it defaults to anything.
   */
  std::string attrs_type_key;
  /*!
   * \brief attribute type index,
   * this field varies in each run and is not exposed to frontend.
   */
  uint32_t attrs_type_index{0};
  /*!
   * \brief number of input arguments to the operator,
   * -1 means it is variable length
   */
  int32_t num_inputs = -1;
  /*!
   * \brief support level of the operator,
   * The lower the more priority it contains.
   * This is analogous to BLAS levels.
   */
  int32_t support_level = 10;

  void VisitAttrs(tvm::AttrVisitor* v) final {
    // attrs_type_index is deliberately not visited: it varies per run.
    v->Visit("name", &name);
    v->Visit("op_type", &op_type);
    v->Visit("description", &description);
    v->Visit("arguments", &arguments);
    v->Visit("attrs_type_key", &attrs_type_key);
    v->Visit("num_inputs", &num_inputs);
    v->Visit("support_level", &support_level);
  }
  /*!
   * \brief Check whether the current op is a "primitive operator".
   * That is, the arguments are all type variables, and there is a single
   * type relation applied to the input and output types.
   * \return true if the op is primitive. An op whose op_type has not been
   *         set is reported as non-primitive.
   * \note The answer is computed on the first call and cached in
   *       is_primitive_ (a plain mutable int), so later changes to op_type
   *       are not reflected.
   */
  bool IsPrimitiveOp() const {
    if (is_primitive_ != -1) return is_primitive_ != 0;
    is_primitive_ = this->IsPrimitiveOp_() ? 1 : 0;
    return is_primitive_ != 0;
  }
  static constexpr const char* _type_key = "relay.Op";
  TVM_DECLARE_NODE_TYPE_INFO(OpNode, ExprNode);

 private:
  // friend class
  friend class GenericOpMap;
  friend class OpRegistry;
  friend bool IsPrimitiveOp(const Expr&);
  // Program-internal unique index of the operator; GenericOpMap uses it to
  // index its per-op attribute table.
  uint32_t index_{0};
  // Cached result of IsPrimitiveOp_(): -1 unknown, 0 false, 1 true.
  mutable int is_primitive_{-1};
  // Internal function to compute whether this is a primitive op.
  bool IsPrimitiveOp_() const {
    const auto& fn_ty = this->op_type;
    // No type has been attached (e.g. add_type_rel was never called);
    // dereferencing an undefined FuncType would access a null node.
    if (!fn_ty.defined()) return false;
    if (fn_ty->type_constraints.size() != 1) return false;
    const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>();
    if (rel == nullptr) return false;
    // Guard the positional comparison below against reading past the end
    // of the relation's argument list.
    if (rel->args.size() < fn_ty->type_params.size()) return false;
    // Validate that each type parameter is passed, in order, to the relation.
    for (size_t i = 0; i < fn_ty->type_params.size(); ++i) {
      if (!fn_ty->type_params[i].same_as(rel->args[i])) return false;
    }
    return true;
  }
};
/*!
* \brief Operator reference class.
*/
class Op : public relay::Expr {
public:
/*! \brief default constructor */
Op() {}
/*! \brief constructor from node pointer */
explicit Op(NodePtr<Node> n) : Expr(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const OpNode* operator->() const;
/*!
* \brief Get additional registered attribute about operators.
* If nothing has been registered, an empty OpMap will be returned.
* \param attr_name The name of the attribute.
* \return An OpMap of specified attr_name.
* \tparam ValueType The type of the attribute.
*/
template <typename ValueType>
inline static OpMap<ValueType> GetAttr(const std::string& attr_name);
/*!
* \brief Get an Op for a given operator name.
* Will raise an error if the op has not been registered.
* \param op_name Name of the operator.
* \return Pointer to a Op, valid throughout program lifetime.
*/
TVM_DLL static const Op& Get(const std::string& op_name);
/*! \brief specify container node */
using ContainerType = OpNode;
private:
/*!
* \brief Get generic attrmap given attr name
* \param key The attribute key
* \return reference to GenericOpMap
*/
TVM_DLL static const GenericOpMap& GetGenericAttr(const std::string& key);
};
/*! \brief Helper structure to register operators */
class OpRegistry {
public:
/*! \return the operator */
const Op& op() const { return op_; }
/*!
* \brief setter function during registration
* Set the description of operator
* \param descr the description string.
* \return reference to self.
*/
inline OpRegistry& describe(const std::string& descr); // NOLINT(*)
/*!
* \brief Add argument information to the function.
* \param name Name of the argument.
* \param type Type of the argument.
* \param description Description of the argument.
* \return reference to self.
*/
inline OpRegistry& add_argument(const std::string& name,
const std::string& type,
const std::string& description);
/*!
* \brief Attach the type function corresponding to the return type.
* \param rel_name The type relation name to register.
* \param type_rel_func The backing relation function which can solve an arbitrary
* relation on variables.
* \return reference to self.
*/
inline OpRegistry& add_type_rel(
const std::string& rel_name,
runtime::TypedPackedFunc<bool(const Array<Type>&,
int,
const Attrs&,
const TypeReporter&)> type_rel_func);
/*!
* \brief Set the type key of attributes.
* \param type_key The type of of the attrs field.
* \return reference to self.
*/
inline OpRegistry& set_attrs_type_key(const std::string& type_key);
/*!
* \brief Set the num_inputs
* \param n The number of inputs to be set.
* \return reference to self.
*/
inline OpRegistry& set_num_inputs(int32_t n); // NOLINT(*)
/*!
* \brief Set the support level of op.
* \param level The support level.
* \return reference to self.
*/
inline OpRegistry& set_support_level(int32_t level); // NOLINT(*)
/*!
* \brief Register additional attributes to operator.
* \param attr_name The name of the attribute.
* \param value The value to be set.
* \param plevel The priority level of this set,
* an higher priority level attribute
* will replace lower priority level attribute.
* Must be bigger than 0.
*
* Cannot set with same plevel twice in the code.
*
* \tparam ValueType The type of the value to be set.
*/
template <typename ValueType>
inline OpRegistry& set_attr(const std::string& attr_name, // NOLINT(*)
const ValueType& value, int plevel = 10);
// set the name of the op to be the same as registry
inline OpRegistry& set_name() { // NOLINT(*)
if (get()->name.length() == 0) {
get()->name = name;
}
return *this;
}
/*! \return The global single registry */
TVM_DLL static ::dmlc::Registry<OpRegistry>* Registry();
private:
friend class ::dmlc::Registry<OpRegistry>;
// the name
std::string name;
/*! \brief The operator */
Op op_;
// private constructor
OpRegistry();
// return internal pointer to op.
inline OpNode* get();
// update the attribute OpMap
TVM_DLL void UpdateAttr(const std::string& key, TVMRetValue value,
int plevel);
};
/*!
 * \brief Untyped per-operator attribute table.
 * Values are stored as TVMRetValue and indexed by each op's internal index.
 */
class GenericOpMap {
 public:
  /*!
   * \brief Test whether a value is registered for op.
   * \param op The operator key.
   * \return 1 if present, 0 otherwise.
   */
  inline int count(const Op& op) const;

  /*!
   * \brief Fetch the value registered for op; it must be present.
   * \param op The operator key.
   * \return Const reference to the stored value.
   */
  inline const TVMRetValue& operator[](const Op& op) const;

  /*!
   * \brief Fetch the value registered for op, or a fallback.
   * \tparam ValueType Type the stored value is converted to.
   * \param op The operator key.
   * \param def_value Returned when op has no registered value.
   * \return The stored value or def_value.
   */
  template <typename ValueType>
  inline ValueType get(const Op& op, ValueType def_value) const;

  /*!
   * \brief Fetch the value registered for an expression, or a fallback.
   * \tparam ValueType Type the stored value is converted to.
   * \param expr The key; only Op expressions can have values.
   * \param def_value Returned when expr is not an Op or has no value.
   * \return The stored value or def_value.
   */
  template <typename ValueType>
  inline ValueType get(const Expr& expr, ValueType def_value) const;

 private:
  friend class OpRegistry;
  // Name of the attribute this table holds (used in error messages).
  std::string attr_name_;
  // Slot i holds (value, plevel) for the op with index_ == i; a plevel of 0
  // marks an empty slot.
  std::vector<std::pair<TVMRetValue, int>> data_;
  // Only OpRegistry creates instances.
  GenericOpMap() = default;
};
/*!
 * \brief Typed view over a GenericOpMap, i.e. Map<Op, ValueType> of
 *        meta-information about operators.
 * \tparam ValueType Type of the values exposed by the view.
 */
template <typename ValueType>
class OpMap {
 public:
  /*!
   * \brief Test whether a value is registered for op.
   * \param op The operator key.
   * \return 1 if present, 0 otherwise.
   */
  inline int count(const Op& op) const;

  /*!
   * \brief Fetch the value registered for op; it must be present.
   * \param op The operator key.
   * \return The stored value converted to ValueType.
   */
  inline ValueType operator[](const Op& op) const;

  /*!
   * \brief Fetch the value registered for op, or a fallback.
   * \param op The operator key.
   * \param def_value Returned when op has no registered value.
   * \return The stored value or def_value.
   */
  inline ValueType get(const Op& op, ValueType def_value) const;

  /*!
   * \brief Fetch the value registered for an expression, or a fallback.
   * \param expr The key; only Op expressions can have values.
   * \param def_value Returned when expr is not an Op or has no value.
   * \return The stored value or def_value.
   */
  inline ValueType get(const Expr& expr, ValueType def_value) const;

 private:
  friend class Op;
  // Built only through Op::GetAttr.
  explicit OpMap(const GenericOpMap& map) : map_(map) {}
  /*! \brief The untyped map this view reads from (not owned). */
  const GenericOpMap& map_;
};
// Internal helper for RELAY_REGISTER_OP: expands to the beginning of a
// declaration of a static OpRegistry& marked DMLC_ATTRIBUTE_UNUSED, whose name
// ends in the token `__make_RelayOp`. RELAY_REGISTER_OP presumably relies on
// DMLC_STR_CONCAT (defined outside this file) to append __COUNTER__ to that
// token, giving every registration a unique variable name.
#define RELAY_REGISTER_VAR_DEF \
static DMLC_ATTRIBUTE_UNUSED ::tvm::relay::OpRegistry& __make_##RelayOp
/*!
 * \def RELAY_REGISTER_OP
 * \brief Register a new operator, or set attributes of the corresponding op.
 *
 * Expands to a static OpRegistry& initialized with
 * OpRegistry::Registry()->__REGISTER_OR_GET__(OpName).set_name(), where
 * set_name() copies the registry name into the op if the op has no name yet.
 * Because the expansion ends in a call returning OpRegistry&, further
 * setters can be chained directly after the macro.
 *
 * \param OpName The name of the registry entry.
 *
 * \code
 *
 * RELAY_REGISTER_OP("add")
 * .describe("add two inputs together")
 * .set_num_inputs(2)
 * .set_attr<OpKernel>("gpu_kernel", AddKernel);
 *
 * \endcode
 */
#define RELAY_REGISTER_OP(OpName) \
DMLC_STR_CONCAT(RELAY_REGISTER_VAR_DEF, __COUNTER__) = \
::tvm::relay::OpRegistry::Registry() \
->__REGISTER_OR_GET__(OpName) \
.set_name()
// implementations
// Reinterpret the held node as an OpNode; no runtime type check is done.
inline const OpNode* Op::operator->() const {
  const auto* raw = node_.get();
  return static_cast<const OpNode*>(raw);
}
// Wrap the untyped attribute map for `key` in a typed OpMap view.
template <typename ValueType>
inline OpMap<ValueType> Op::GetAttr(const std::string& key) {
  const GenericOpMap& generic = Op::GetGenericAttr(key);
  return OpMap<ValueType>(generic);
}
// Mutable access to the op being registered; Op only exposes a const node.
inline OpNode* OpRegistry::get() {
  const OpNode* node = op_.operator->();
  return const_cast<OpNode*>(node);
}
// Store the description text on the op.
inline OpRegistry& OpRegistry::describe(const std::string& descr) {  // NOLINT(*)
  OpNode* node = get();
  node->description = descr;
  return *this;
}
// Record one argument's name, type string, and description on the op.
inline OpRegistry& OpRegistry::add_argument(const std::string& name,
                                            const std::string& type,
                                            const std::string& description) {
  auto field_node = make_node<AttrFieldInfoNode>();
  field_node->name = name;
  field_node->type_info = type;
  field_node->description = description;
  AttrFieldInfo field(field_node);
  get()->arguments.push_back(field);
  return *this;
}
/*!
 * \brief Build op_type as a FuncType over fresh type variables constrained by
 *        the named type relation.
 *
 * The relation is stored in the global function registry under
 * "tvm.relay.type_relation.<rel_name>". Only the first registration under a
 * given rel_name installs type_rel_func; later calls with the same name reuse
 * the existing entry and ignore their type_rel_func.
 *
 * One input type variable ("in0", "in1", ...) is created per num_inputs, so
 * set_num_inputs must be called before this. With the default num_inputs of
 * -1, no input variables are created.
 *
 * \param rel_name The type relation name.
 * \param type_rel_func The function implementing the relation.
 * \return Reference to this entry.
 */
inline OpRegistry& OpRegistry::add_type_rel(
    const std::string& rel_name,
    runtime::TypedPackedFunc<bool(const Array<Type>&,
                                  int,
                                  const Attrs&,
                                  const TypeReporter&)> type_rel_func) {
  auto func_name = std::string("tvm.relay.type_relation.") + rel_name;
  // Register the backing function on first use, then fetch it in one place.
  if (!runtime::Registry::Get(func_name)) {
    runtime::Registry::Register(func_name)
        .set_body(type_rel_func.packed());
  }
  TypeRelationFn env_type_rel_func;
  env_type_rel_func = EnvFunc::Get(func_name);

  // One fresh type variable per declared input.
  const int32_t num_inputs = get()->num_inputs;
  Array<TypeVar> type_params;
  Array<Type> arg_types;
  for (int32_t i = 0; i < num_inputs; ++i) {
    auto param = TypeVarNode::make("in" + std::to_string(i),
                                   TypeVarNode::Kind::kType);
    type_params.push_back(param);
    arg_types.push_back(param);
  }

  // The relation is applied to the inputs followed by the output.
  Array<Type> ty_call_args = arg_types;
  auto out_param = TypeVarNode::make("out", TypeVarNode::Kind::kType);
  type_params.push_back(out_param);
  // This triggers copy on write, so arg_types itself keeps only the inputs.
  ty_call_args.push_back(out_param);

  // The attributes of a primitive op are nullptr.
  //
  // The attributes of a primitive operator can vary at the call site, and the
  // result type may depend on the Attrs being passed. Putting nullptr in the
  // Attrs means that the operator is polymorphic on Attrs.
  //
  // A common example is sum(x, axis), where the choice of axis
  // can affect the type of the function.
  TypeConstraint type_rel =
      TypeRelationNode::make(env_type_rel_func,
                             ty_call_args,
                             arg_types.size(),
                             Attrs());
  auto func_type =
      FuncTypeNode::make(arg_types, out_param, type_params, {type_rel});
  get()->op_type = func_type;
  return *this;
}
// Record the input count (-1 means variable length) on the op.
inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) {  // NOLINT(*)
  OpNode* node = get();
  node->num_inputs = n;
  return *this;
}
// Store the attrs type key, and its runtime type index, on the op.
inline OpRegistry& OpRegistry::set_attrs_type_key(  // NOLINT(*)
    const std::string& type_key) {
  OpNode* node = get();
  node->attrs_type_key = type_key;
  node->attrs_type_index = Node::TypeKey2Index(type_key.c_str());
  return *this;
}
// Record the support level on the op.
inline OpRegistry& OpRegistry::set_support_level(int32_t level) {  // NOLINT(*)
  OpNode* node = get();
  node->support_level = level;
  return *this;
}
// Box `value` into a TVMRetValue and forward it to UpdateAttr at `plevel`.
template <typename ValueType>
inline OpRegistry& OpRegistry::set_attr(  // NOLINT(*)
    const std::string& attr_name, const ValueType& value, int plevel) {
  CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0";
  TVMRetValue boxed;
  boxed = value;
  UpdateAttr(attr_name, boxed, plevel);
  return *this;
}
// member functions of GenericOpMap and OpMap
// 1 when op is defined and its slot holds a value (non-zero plevel), else 0.
inline int GenericOpMap::count(const Op& op) const {
  if (!op.defined()) return 0;
  const uint32_t slot = op->index_;
  if (slot >= data_.size()) return 0;
  return data_[slot].second != 0 ? 1 : 0;
}
// Return the stored value for op; fails a CHECK if op is undefined or has no
// value registered for this attribute.
inline const TVMRetValue& GenericOpMap::operator[](const Op& op) const {
  CHECK(op.defined());
  const uint32_t idx = op->index_;
  CHECK(idx < data_.size() && data_[idx].second != 0)
      << "Attribute " << attr_name_ << " has not been registered for Operator "
      << op->name;
  const auto& entry = data_[idx];
  return entry.first;
}
// Return op's stored value converted to ValueType, or `value` if absent.
// op itself must be defined.
template <typename ValueType>
inline ValueType GenericOpMap::get(const Op& op, ValueType value) const {
  CHECK(op.defined());
  const uint32_t slot = op->index_;
  const bool present = slot < data_.size() && data_[slot].second != 0;
  if (!present) return value;
  return data_[slot].first;
}
template <typename ValueType>
inline ValueType GenericOpMap::get(const Expr& expr, ValueType value) const {
CHECK(expr.defined());
if (const OpNode* op = expr.as<OpNode>()) {
const uint32_t idx = op->index_;
if (idx < data_.size() && data_[idx].second != 0) {
return data_[idx].first;
} else {
return value;
}
} else {
return value;
}
}
// Delegates to the underlying GenericOpMap.
template <typename ValueType>
inline int OpMap<ValueType>::count(const Op& op) const {
  const GenericOpMap& generic = map_;
  return generic.count(op);
}
// Look up op in the generic map and convert the stored value to ValueType.
template <typename ValueType>
inline ValueType OpMap<ValueType>::operator[](const Op& op) const {
  const TVMRetValue& stored = map_[op];
  return stored;
}
// Typed lookup with fallback, delegated to GenericOpMap::get.
template <typename ValueType>
inline ValueType OpMap<ValueType>::get(const Op& op,
                                       ValueType def_value) const {
  const GenericOpMap& generic = map_;
  return generic.get<ValueType>(op, def_value);
}
// Typed lookup by expression with fallback, delegated to GenericOpMap::get.
template <typename ValueType>
inline ValueType OpMap<ValueType>::get(const Expr& expr,
                                       ValueType def_value) const {
  const GenericOpMap& generic = map_;
  return generic.get<ValueType>(expr, def_value);
}
/*!
* \brief Check that an expression is a "primtive operator".
*
* Will return true if the expression is an operator which
* matches the form of primtive operators registered directly
* by the Relay codebase.
*
* That is the arguments are all type variables, and there is a single
* type relation applied to the input and output types.
*/
inline bool IsPrimitiveOp(const Expr& expr) {
const auto* op = expr.as<OpNode>();
return op != nullptr && op->IsPrimitiveOp();
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_H_