| /*! |
| * Copyright (c) 2018 by Contributors |
| * \file tvm/relay/op.h |
| * \brief Primitive operator definition. |
| */ |
| #ifndef TVM_RELAY_OP_H_ |
| #define TVM_RELAY_OP_H_ |
| |
| #include <functional> |
| #include <limits> |
| #include <string> |
| #include <typeinfo> |
| #include <utility> |
| #include <vector> |
| |
| #include "base.h" |
| #include "expr.h" |
| #include "type.h" |
| |
| namespace tvm { |
| namespace relay { |
| |
| // forward declare name. |
| template <typename ValueType> |
| class OpMap; |
| class GenericOpMap; |
| class OpRegistry; |
| |
| /*! |
| * \brief Node container of operator structure. |
| */ |
| class OpNode : public relay::ExprNode { |
| public: |
| /*! \brief name of the operator */ |
| std::string name; |
| /*! \brief the type of the operator */ |
| mutable FuncType op_type; |
| /*! |
| * \brief detailed description of the operator |
| * This can be used to generate docstring automatically for the operator. |
| */ |
| std::string description; |
| /* \brief Information of input arguments to the operator */ |
| Array<AttrFieldInfo> arguments; |
| /*! |
| * \brief The type key of the attribute field |
| * This can be empty, in which case it defaults to anything. |
| */ |
| std::string attrs_type_key; |
| /*! |
| * \brief attribute type index, |
| * this field varies in each run and is not exposed to frontend. |
| */ |
| uint32_t attrs_type_index{0}; |
| /*! |
| * \brief number of input arguments to the operator, |
| * -1 means it is variable length |
| */ |
| int32_t num_inputs = -1; |
| /*! |
| * \brief support level of the operator, |
| * The lower the more priority it contains. |
| * This is in analogies to BLAS levels. |
| */ |
| int32_t support_level = 10; |
| |
| void VisitAttrs(tvm::AttrVisitor* v) final { |
| v->Visit("name", &name); |
| v->Visit("op_type", &op_type); |
| v->Visit("description", &description); |
| v->Visit("arguments", &arguments); |
| v->Visit("attrs_type_key", &attrs_type_key); |
| v->Visit("num_inputs", &num_inputs); |
| v->Visit("support_level", &support_level); |
| } |
| |
| /*! |
| * \brief Check that if current op is a "primtive operator". |
| * That is the arguments are all type variables, and there is a single |
| * type relation applied to the input and output types. |
| */ |
| bool IsPrimitiveOp() const { |
| if (is_primitive_ != -1) return is_primitive_ != 0; |
| is_primitive_ = this->IsPrimitiveOp_() ? 1 : 0; |
| return is_primitive_ != 0; |
| } |
| |
| static constexpr const char* _type_key = "relay.Op"; |
| TVM_DECLARE_NODE_TYPE_INFO(OpNode, ExprNode); |
| |
| private: |
| // friend class |
| friend class GenericOpMap; |
| friend class OpRegistry; |
| friend bool IsPrimitiveOp(const Expr&); |
| // Program internal unique index of operator. |
| // Used to help index the program. |
| uint32_t index_{0}; |
| // whether this is a primitive op. -1 means unknown. |
| mutable int is_primitive_{-1}; |
| // Internal function to compute if it is primitive op |
| bool IsPrimitiveOp_() const { |
| const auto& fn_ty = this->op_type; |
| if (fn_ty->type_constraints.size() != 1) return false; |
| const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>(); |
| if (rel == nullptr) return false; |
| // validate if the type parameter matches up |
| for (size_t i = 0; i < fn_ty->type_params.size(); ++i) { |
| if (!fn_ty->type_params[i].same_as(rel->args[i])) return false; |
| } |
| return true; |
| } |
| }; |
| |
| /*! |
| * \brief Operator reference class. |
| */ |
| class Op : public relay::Expr { |
| public: |
| /*! \brief default constructor */ |
| Op() {} |
| /*! \brief constructor from node pointer */ |
| explicit Op(NodePtr<Node> n) : Expr(n) {} |
| /*! |
| * \brief access the internal node container |
| * \return the pointer to the internal node container |
| */ |
| inline const OpNode* operator->() const; |
| /*! |
| * \brief Get additional registered attribute about operators. |
| * If nothing has been registered, an empty OpMap will be returned. |
| * \param attr_name The name of the attribute. |
| * \return An OpMap of specified attr_name. |
| * \tparam ValueType The type of the attribute. |
| */ |
| template <typename ValueType> |
| inline static OpMap<ValueType> GetAttr(const std::string& attr_name); |
| /*! |
| * \brief Get an Op for a given operator name. |
| * Will raise an error if the op has not been registered. |
| * \param op_name Name of the operator. |
| * \return Pointer to a Op, valid throughout program lifetime. |
| */ |
| TVM_DLL static const Op& Get(const std::string& op_name); |
| |
| /*! \brief specify container node */ |
| using ContainerType = OpNode; |
| |
| private: |
| /*! |
| * \brief Get generic attrmap given attr name |
| * \param key The attribute key |
| * \return reference to GenericOpMap |
| */ |
| TVM_DLL static const GenericOpMap& GetGenericAttr(const std::string& key); |
| }; |
| |
| /*! \brief Helper structure to register operators */ |
| class OpRegistry { |
| public: |
| /*! \return the operator */ |
| const Op& op() const { return op_; } |
| /*! |
| * \brief setter function during registration |
| * Set the description of operator |
| * \param descr the description string. |
| * \return reference to self. |
| */ |
| inline OpRegistry& describe(const std::string& descr); // NOLINT(*) |
| /*! |
| * \brief Add argument information to the function. |
| * \param name Name of the argument. |
| * \param type Type of the argument. |
| * \param description Description of the argument. |
| * \return reference to self. |
| */ |
| inline OpRegistry& add_argument(const std::string& name, |
| const std::string& type, |
| const std::string& description); |
| /*! |
| * \brief Attach the type function corresponding to the return type. |
| * \param rel_name The type relation name to register. |
| * \param type_rel_func The backing relation function which can solve an arbitrary |
| * relation on variables. |
| * \return reference to self. |
| */ |
| inline OpRegistry& add_type_rel( |
| const std::string& rel_name, |
| runtime::TypedPackedFunc<bool(const Array<Type>&, |
| int, |
| const Attrs&, |
| const TypeReporter&)> type_rel_func); |
| /*! |
| * \brief Set the type key of attributes. |
| * \param type_key The type of of the attrs field. |
| * \return reference to self. |
| */ |
| inline OpRegistry& set_attrs_type_key(const std::string& type_key); |
| /*! |
| * \brief Set the num_inputs |
| * \param n The number of inputs to be set. |
| * \return reference to self. |
| */ |
| inline OpRegistry& set_num_inputs(int32_t n); // NOLINT(*) |
| /*! |
| * \brief Set the support level of op. |
| * \param level The support level. |
| * \return reference to self. |
| */ |
| inline OpRegistry& set_support_level(int32_t level); // NOLINT(*) |
| /*! |
| * \brief Register additional attributes to operator. |
| * \param attr_name The name of the attribute. |
| * \param value The value to be set. |
| * \param plevel The priority level of this set, |
| * an higher priority level attribute |
| * will replace lower priority level attribute. |
| * Must be bigger than 0. |
| * |
| * Cannot set with same plevel twice in the code. |
| * |
| * \tparam ValueType The type of the value to be set. |
| */ |
| template <typename ValueType> |
| inline OpRegistry& set_attr(const std::string& attr_name, // NOLINT(*) |
| const ValueType& value, int plevel = 10); |
| |
| // set the name of the op to be the same as registry |
| inline OpRegistry& set_name() { // NOLINT(*) |
| if (get()->name.length() == 0) { |
| get()->name = name; |
| } |
| return *this; |
| } |
| /*! \return The global single registry */ |
| TVM_DLL static ::dmlc::Registry<OpRegistry>* Registry(); |
| |
| private: |
| friend class ::dmlc::Registry<OpRegistry>; |
| // the name |
| std::string name; |
| /*! \brief The operator */ |
| Op op_; |
| // private constructor |
| OpRegistry(); |
| // return internal pointer to op. |
| inline OpNode* get(); |
| // update the attribute OpMap |
| TVM_DLL void UpdateAttr(const std::string& key, TVMRetValue value, |
| int plevel); |
| }; |
| |
| /*! |
| * \brief Generic map to store additional information of Op. |
| */ |
| class GenericOpMap { |
| public: |
| /*! |
| * \brief Check if the map has op as key. |
| * \param op The key to the map |
| * \return 1 if op is contained in map, 0 otherwise. |
| */ |
| inline int count(const Op& op) const; |
| /*! |
| * \brief get the corresponding value element at op |
| * \param op The key to the map |
| * \return the const reference to the content value. |
| */ |
| inline const TVMRetValue& operator[](const Op& op) const; |
| /*! |
| * \brief get the corresponding value element at op with default value. |
| * \param op The key to the map |
| * \param def_value The default value when the key does not exist. |
| * \return the const reference to the content value. |
| * \tparam ValueType The content value type. |
| */ |
| template <typename ValueType> |
| inline ValueType get(const Op& op, ValueType def_value) const; |
| /*! |
| * \brief get the corresponding value element at op with default value. |
| * \param expr The key to the map |
| * \param def_value The default value when the key does not exist |
| * or if expr is not an Op. |
| * \return the const reference to the content value. |
| * \tparam ValueType The content value type. |
| */ |
| template <typename ValueType> |
| inline ValueType get(const Expr& expr, ValueType def_value) const; |
| |
| private: |
| friend class OpRegistry; |
| // the attribute field. |
| std::string attr_name_; |
| // internal data |
| std::vector<std::pair<TVMRetValue, int> > data_; |
| // The value |
| GenericOpMap() = default; |
| }; |
| |
| /*! |
| * \brief Map<Op,ValueType> used to store meta-information about Op. |
| * \tparam ValueType The type of the value stored in map. |
| */ |
| template <typename ValueType> |
| class OpMap { |
| public: |
| /*! |
| * \brief Check if the map has op as key. |
| * \param op The key to the map |
| * \return 1 if op is contained in map, 0 otherwise. |
| */ |
| inline int count(const Op& op) const; |
| /*! |
| * \brief get the corresponding value element at op |
| * \param op The key to the map |
| * \return the const reference to the content value. |
| */ |
| inline ValueType operator[](const Op& op) const; |
| /*! |
| * \brief get the corresponding value element at op with default value. |
| * \param op The key to the map |
| * \param def_value The default value when the key does not exist. |
| * \return the const reference to the content value. |
| */ |
| inline ValueType get(const Op& op, ValueType def_value) const; |
| /*! |
| * \brief get the corresponding value element at op with default value. |
| * \param expr The key to the map |
| * \param def_value The default value when the key does not exist |
| * or if expr is not an Op. |
| * \return the const reference to the content value. |
| */ |
| inline ValueType get(const Expr& expr, ValueType def_value) const; |
| |
| private: |
| friend class Op; |
| // constructor |
| explicit OpMap(const GenericOpMap& map) : map_(map) {} |
| /*! \brief The internal map field */ |
| const GenericOpMap& map_; |
| }; |
| |
// Internal helper macro used by RELAY_REGISTER_OP: expands to the declaration
// of a static OpRegistry reference marked DMLC_ATTRIBUTE_UNUSED.
// RELAY_REGISTER_OP concatenates a __COUNTER__ value onto the final name
// token, presumably so that each registration gets a distinct variable name.
#define RELAY_REGISTER_VAR_DEF \
  static DMLC_ATTRIBUTE_UNUSED ::tvm::relay::OpRegistry& __make_##RelayOp
| |
/*!
 * \def RELAY_REGISTER_OP
 * \brief Register a new operator, or set attributes of an existing op.
 *
 *  Expands to a static OpRegistry reference initialized from
 *  OpRegistry::Registry()->__REGISTER_OR_GET__(OpName).set_name(),
 *  so further setters can be chained onto it.
 *
 * \param OpName The name of the operator in the registry.
 *
 * \code
 *
 *  RELAY_REGISTER_OP("add")
 *  .describe("add two inputs together")
 *  .set_num_inputs(2)
 *  .set_attr<OpKernel>("gpu_kernel", AddKernel);
 *
 * \endcode
 */
#define RELAY_REGISTER_OP(OpName)                        \
  DMLC_STR_CONCAT(RELAY_REGISTER_VAR_DEF, __COUNTER__) = \
      ::tvm::relay::OpRegistry::Registry()               \
          ->__REGISTER_OR_GET__(OpName)                  \
          .set_name()
| |
| // implementations |
| inline const OpNode* Op::operator->() const { |
| return static_cast<const OpNode*>(node_.get()); |
| } |
| |
| template <typename ValueType> |
| inline OpMap<ValueType> Op::GetAttr(const std::string& key) { |
| return OpMap<ValueType>(Op::GetGenericAttr(key)); |
| } |
| |
| inline OpNode* OpRegistry::get() { |
| return const_cast<OpNode*>(op_.operator->()); |
| } |
| |
| inline OpRegistry& OpRegistry::describe( |
| const std::string& descr) { // NOLINT(*) |
| get()->description = descr; |
| return *this; |
| } |
| |
| inline OpRegistry& OpRegistry::add_argument(const std::string& name, |
| const std::string& type, |
| const std::string& description) { |
| auto n = make_node<AttrFieldInfoNode>(); |
| n->name = name; |
| n->type_info = type; |
| n->description = description; |
| get()->arguments.push_back(AttrFieldInfo(n)); |
| return *this; |
| } |
| |
| inline OpRegistry& OpRegistry::add_type_rel( |
| const std::string& rel_name, |
| runtime::TypedPackedFunc<bool(const Array<Type>&, |
| int, |
| const Attrs&, |
| const TypeReporter&)> type_rel_func) { |
| auto func_name = std::string("tvm.relay.type_relation.") + rel_name; |
| TypeRelationFn env_type_rel_func; |
| |
| if (runtime::Registry::Get(func_name)) { |
| auto env_func = EnvFunc::Get(func_name); |
| env_type_rel_func = env_func; |
| } else { |
| runtime::Registry::Register(func_name) |
| .set_body(type_rel_func.packed()); |
| auto env_func = EnvFunc::Get(func_name); |
| env_type_rel_func = env_func; |
| } |
| |
| Array<TypeVar> type_params; |
| Array<Type> arg_types; |
| |
| // Add inputs. |
| std::string input_name_prefix = "in"; |
| for (int i = 0; i < get()->num_inputs; i++) { |
| auto name = input_name_prefix + std::to_string(i); |
| auto param = TypeVarNode::make(name, TypeVarNode::Kind::kType); |
| type_params.push_back(param); |
| arg_types.push_back(param); |
| } |
| |
| Array<Type> ty_call_args = arg_types; |
| |
| // Add output type. |
| auto out_param = TypeVarNode::make("out", TypeVarNode::Kind::kType); |
| type_params.push_back(out_param); |
| // this will trigger copy on write. |
| ty_call_args.push_back(out_param); |
| |
| // The attributes of primitive op is nullptr |
| // |
| // The attributes of primitive operator can vary at the call site. |
| // The type of sum is also dependent on Attrs being passed. |
| // So puting nullptr in the Attrs means that the operator is polymorphic on Attrs. |
| // |
| // A common example is sum(x, axis), where the choice of axis |
| // can affect the type of the function. |
| TypeConstraint type_rel = |
| TypeRelationNode::make(env_type_rel_func, |
| ty_call_args, |
| arg_types.size(), |
| Attrs()); |
| |
| auto func_type = |
| FuncTypeNode::make(arg_types, out_param, type_params, {type_rel}); |
| |
| get()->op_type = func_type; |
| |
| return *this; |
| } |
| |
| inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) { // NOLINT(*) |
| get()->num_inputs = n; |
| return *this; |
| } |
| |
| inline OpRegistry& OpRegistry::set_attrs_type_key( // NOLINT(*) |
| const std::string& type_key) { |
| get()->attrs_type_key = type_key; |
| get()->attrs_type_index = Node::TypeKey2Index(type_key.c_str()); |
| return *this; |
| } |
| |
| inline OpRegistry& OpRegistry::set_support_level(int32_t n) { // NOLINT(*) |
| get()->support_level = n; |
| return *this; |
| } |
| |
| template <typename ValueType> |
| inline OpRegistry& OpRegistry::set_attr( // NOLINT(*) |
| const std::string& attr_name, const ValueType& value, int plevel) { |
| CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0"; |
| TVMRetValue rv; |
| rv = value; |
| UpdateAttr(attr_name, rv, plevel); |
| return *this; |
| } |
| |
| // member functions of OpMap |
| inline int GenericOpMap::count(const Op& op) const { |
| if (op.defined()) { |
| const uint32_t idx = op->index_; |
| return idx < data_.size() ? (data_[idx].second != 0) : 0; |
| } else { |
| return 0; |
| } |
| } |
| |
| inline const TVMRetValue& GenericOpMap::operator[](const Op& op) const { |
| CHECK(op.defined()); |
| const uint32_t idx = op->index_; |
| CHECK(idx < data_.size() && data_[idx].second != 0) |
| << "Attribute " << attr_name_ << " has not been registered for Operator " |
| << op->name; |
| return data_[idx].first; |
| } |
| |
| template <typename ValueType> |
| inline ValueType GenericOpMap::get(const Op& op, ValueType value) const { |
| CHECK(op.defined()); |
| const uint32_t idx = op->index_; |
| if (idx < data_.size() && data_[idx].second != 0) { |
| return data_[idx].first; |
| } else { |
| return value; |
| } |
| } |
| |
| template <typename ValueType> |
| inline ValueType GenericOpMap::get(const Expr& expr, ValueType value) const { |
| CHECK(expr.defined()); |
| if (const OpNode* op = expr.as<OpNode>()) { |
| const uint32_t idx = op->index_; |
| if (idx < data_.size() && data_[idx].second != 0) { |
| return data_[idx].first; |
| } else { |
| return value; |
| } |
| } else { |
| return value; |
| } |
| } |
| |
| template <typename ValueType> |
| inline int OpMap<ValueType>::count(const Op& op) const { |
| return map_.count(op); |
| } |
| |
| template <typename ValueType> |
| inline ValueType OpMap<ValueType>::operator[](const Op& op) const { |
| return map_[op]; |
| } |
| |
| template <typename ValueType> |
| inline ValueType OpMap<ValueType>::get(const Op& op, |
| ValueType def_value) const { |
| return map_.get<ValueType>(op, def_value); |
| } |
| |
| template <typename ValueType> |
| inline ValueType OpMap<ValueType>::get(const Expr& expr, |
| ValueType def_value) const { |
| return map_.get<ValueType>(expr, def_value); |
| } |
| |
| |
| /*! |
| * \brief Check that an expression is a "primtive operator". |
| * |
| * Will return true if the expression is an operator which |
| * matches the form of primtive operators registered directly |
| * by the Relay codebase. |
| * |
| * That is the arguments are all type variables, and there is a single |
| * type relation applied to the input and output types. |
| */ |
| inline bool IsPrimitiveOp(const Expr& expr) { |
| const auto* op = expr.as<OpNode>(); |
| return op != nullptr && op->IsPrimitiveOp(); |
| } |
| |
| } // namespace relay |
| } // namespace tvm |
| #endif // TVM_RELAY_OP_H_ |