| /*! |
| * Copyright (c) 2018 by Contributors |
| * \file tvm/attrs.h |
| * \brief TVM attribute module |
| * |
| * This module enables declaration of named attributes |
| * which support default value setup and bound checking. |
| * |
| * \code |
| * struct MyAttrs : public tvm::AttrsNode<MyAttrs> { |
 *   double learning_rate;
 *   int num_hidden;
 *   std::string name;
 *   // declare attribute fields in header file
 *   TVM_DECLARE_ATTRS(MyAttrs, "attrs.MyAttrs") {
 *     TVM_ATTR_FIELD(num_hidden).set_lower_bound(1);
 *     TVM_ATTR_FIELD(learning_rate).set_default(0.01);
| * TVM_ATTR_FIELD(name).set_default("hello"); |
| * } |
| * }; |
| * // register it in cc file |
| * TVM_REGISTER_NODE_TYPE(MyAttrs); |
| * \endcode |
| * |
| * \sa AttrsNode, TVM_DECLARE_ATTRS, TVM_ATTR_FIELD |
| */ |
| #ifndef TVM_ATTRS_H_ |
| #define TVM_ATTRS_H_ |
| |
| #include <dmlc/common.h> |
| #include <unordered_map> |
| #include <vector> |
| #include <functional> |
| #include <type_traits> |
| #include <string> |
| #include "ir.h" |
| #include "base.h" |
| #include "expr.h" |
| #include "packed_func_ext.h" |
| |
| namespace tvm { |
/*!
 * \brief Declare the type information and field list of an attribute class.
 *
 * Expands to the `_type_key` constant, the node type info (with
 * ::tvm::BaseAttrsNode as the parent), and the head of a member template
 * `__VisitAttrs__(FVisit& __fvisit__)`. The macro must be followed by a
 * braced body that lists every field with TVM_ATTR_FIELD.
 *
 * \param ClassName The name of the class.
 * \param TypeKey The type key to be used by the TVM node system.
 */
#define TVM_DECLARE_ATTRS(ClassName, TypeKey)                     \
  static constexpr const char* _type_key = TypeKey;               \
  TVM_DECLARE_NODE_TYPE_INFO(ClassName, ::tvm::BaseAttrsNode);    \
  template<typename FVisit>                                       \
  void __VisitAttrs__(FVisit& __fvisit__)  // NOLINT(*)
| |
| |
/*!
 * \brief Declare an attribute field inside a TVM_DECLARE_ATTRS body.
 *
 * Expands to a call of the visitor `__fvisit__` with the stringified field
 * name and the field's address. It is therefore only valid inside the body
 * that follows TVM_DECLARE_ATTRS. The call yields the visitor's entry object,
 * which lets describe(), set_default(), set_lower_bound() and
 * set_upper_bound() be chained onto it.
 *
 * \param FieldName The field name.
 */
#define TVM_ATTR_FIELD(FieldName) \
  __fvisit__(#FieldName, &FieldName)
| |
| |
| /*! |
| * \brief Create a NodeRef type that represents null. |
| * \tparam TNodeRef the type to be created. |
| * \return A instance that will represent None. |
| */ |
| template<typename TNodeRef> |
| inline TNodeRef NullValue() { |
| return TNodeRef(NodePtr<Node>(nullptr)); |
| } |
| |
| template<> |
| inline Type NullValue<Type>() { |
| return Type(Type::Handle, 0, 0); |
| } |
| |
| /*! \brief Error thrown during attribute checking. */ |
| struct AttrError : public dmlc::Error { |
| /*! |
| * \brief constructor |
| * \param msg error message |
| */ |
| explicit AttrError(const std::string &msg) |
| : dmlc::Error(msg) {} |
| }; |
| |
/*!
 * \brief Information about one attribute field, stored as strings.
 *
 * Instances are produced by AttrsNode::ListFieldInfo and consumed by
 * BaseAttrsNode::PrintDocString.
 */
class AttrFieldInfoNode : public Node {
 public:
  /*! \brief name of the field */
  std::string name;
  /*!
   * \brief type name of the field; set_default() appends ", default=<value>"
   *        to it when documenting.
   */
  std::string type_info;
  /*! \brief detailed description of the field, as passed to describe() */
  std::string description;

  // Report the three string fields to the visitor, in declaration order.
  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("name", &name);
    v->Visit("type_info", &type_info);
    v->Visit("description", &description);
  }
  static constexpr const char* _type_key = "AttrFieldInfo";
  TVM_DECLARE_NODE_TYPE_INFO(AttrFieldInfoNode, Node);
};
| |
/*!
 * \brief Reference type for AttrFieldInfoNode; fields are read through
 *        operator-> (e.g. info->name).
 */
TVM_DEFINE_NODE_REF(AttrFieldInfo, AttrFieldInfoNode);
| |
| class AttrsHashHandler; |
| class AttrsEqualHandler; |
| /*! |
| * \brief Content-aware Equality comparator for attrs. |
| * |
| * This comparator will recursively deep compare the following Attributes. |
| * |
| * - IntImm, UIntImm, FloatImm, StringImm |
| * - Any subclass of BaseAttrsNode |
| * - Array of Attributes. |
| * - Map from string to Attributes. |
| */ |
| class AttrsEqual { |
| public: |
| bool operator()(const double& lhs, const double& rhs) const { |
| return lhs == rhs; |
| } |
| bool operator()(const int64_t& lhs, const int64_t& rhs) const { |
| return lhs == rhs; |
| } |
| bool operator()(const uint64_t& lhs, const uint64_t& rhs) const { |
| return lhs == rhs; |
| } |
| bool operator()(const int& lhs, const int& rhs) const { |
| return lhs == rhs; |
| } |
| bool operator()(const bool& lhs, const bool& rhs) const { |
| return lhs == rhs; |
| } |
| bool operator()(const std::string& lhs, const std::string& rhs) const { |
| return lhs == rhs; |
| } |
| bool operator()(const Type& lhs, const Type& rhs) const { |
| return lhs == rhs; |
| } |
| // node comparator |
| TVM_DLL bool operator()(const NodeRef& lhs, const NodeRef& rhs) const; |
| |
| protected: |
| friend class AttrsEqualHandler; |
| /*! \brief internal handle. */ |
| AttrsEqualHandler* handler_{nullptr}; |
| }; |
| |
| /*! |
| * \brief Content-aware hash function. |
| * |
| * This hash functor will recursively hash the content of the Attributes. |
| * It is guaranteed that if AttrsEqual(a, b) == true, then AttrsHash(a) == AttrsHash(b); |
| */ |
| class AttrsHash { |
| public: |
| size_t operator()(const double& value) const { |
| return std::hash<double>()(value); |
| } |
| size_t operator()(const int64_t& value) const { |
| return std::hash<int64_t>()(value); |
| } |
| size_t operator()(const uint64_t& value) const { |
| return std::hash<uint64_t>()(value); |
| } |
| size_t operator()(const int& value) const { |
| return std::hash<int>()(value); |
| } |
| size_t operator()(const bool& value) const { |
| return std::hash<bool>()(value); |
| } |
| size_t operator()(const std::string& value) const { |
| return std::hash<std::string>()(value); |
| } |
| size_t operator()(const Type& value) const { |
| return std::hash<int>()( |
| static_cast<int>(value.code()) | |
| (static_cast<int>(value.bits()) << 8) | |
| (static_cast<int>(value.lanes()) << 16)); |
| } |
| TVM_DLL size_t operator()(const NodeRef& value) const; |
| |
| private: |
| friend class AttrsHashHandler; |
| /*! \brief internal handle. */ |
| AttrsHashHandler* handler_{nullptr}; |
| }; |
| |
/*!
 * \brief Base class of all attribute classes.
 * \note Do not subclass BaseAttrsNode directly,
 *       subclass AttrsNode instead.
 * \sa AttrsNode
 */
class BaseAttrsNode : public Node {
 public:
  using TVMArgs = runtime::TVMArgs;
  using TVMRetValue = runtime::TVMRetValue;
  /*!
   * \brief Initialize the attributes from a sequence of arguments.
   * \param args The positional arguments in the form
   *     [key0, value0, key1, value1, ..., key_n, value_n]
   */
  template<typename... Args>
  inline void InitBySeq(Args&& ...args);
  /*!
   * \brief Print a readable docstring to a stream: one
   *        "name : type_info" line per field, followed by the field's
   *        description on its own line when one is set.
   * \param os the stream to print the docstring to.
   */
  inline void PrintDocString(std::ostream &os) const; // NOLINT(*)
  /*!
   * \brief Visit attributes that do not equal their default value.
   *
   * \note This is useful to extract fields for concise printing.
   * \param v The visitor
   */
  TVM_DLL virtual void VisitNonDefaultAttrs(AttrVisitor* v) = 0;
  /*!
   * \brief Get the field information.
   * \return The fields in the Attrs.
   */
  TVM_DLL virtual Array<AttrFieldInfo> ListFieldInfo() const = 0;
  /*!
   * \brief Initialize the attributes by arguments.
   * \param kwargs The key value pairs for initialization.
   *     [key0, value0, key1, value1, ..., key_n, value_n]
   * \param allow_unknown Whether to allow additional unknown fields.
   * \note This function throws when a required field is not present.
   * \note The default for allow_unknown only applies to calls made through
   *       a BaseAttrsNode; the AttrsNode override declares no default.
   */
  TVM_DLL virtual void InitByPackedArgs(const TVMArgs& kwargs, bool allow_unknown = false) = 0;
  /*!
   * \brief Whether this attribute's content equals another node's.
   * \param other The pointer to another node.
   * \param equal The equality comparator applied to each field.
   * \return The comparison result.
   */
  TVM_DLL virtual bool ContentEqual(
      const Node* other, AttrsEqual equal) const = 0;
  /*!
   * \brief Content-aware hash.
   * \param hasher The hasher applied to each field.
   * \return the hash result.
   */
  TVM_DLL virtual size_t ContentHash(AttrsHash hasher) const = 0;

  static constexpr const char* _type_key = "Attrs";
  TVM_DECLARE_BASE_NODE_INFO(BaseAttrsNode, Node);
};
| |
| /*! \brief Base attribute container for all attributes */ |
| class Attrs : public NodeRef { |
| public: |
| // normal constructor |
| Attrs() {} |
| // construct from shared ptr. |
| explicit Attrs(NodePtr<Node> n) : NodeRef(n) {} |
| |
| /*! \return The attribute node */ |
| const BaseAttrsNode* operator->() const { |
| return ptr(); |
| } |
| /*! \brief specify container node */ |
| using ContainerType = BaseAttrsNode; |
| |
| private: |
| /*! \return the internal attribute node */ |
| const BaseAttrsNode* ptr() const { |
| return static_cast<const BaseAttrsNode*>(node_.get()); |
| } |
| }; |
| |
/*!
 * \brief Specialized attribute type that is backed by a map.
 *
 * DictAttrsNode implements the Attrs behavior; its fields are directly
 * accessible via object.field_name like those of other normal nodes.
 * The overrides below are defined out of line.
 */
class DictAttrsNode : public BaseAttrsNode {
 public:
  /*! \brief internal attrs map */
  Map<std::string, NodeRef> dict;
  /*!
   * \brief Construct an Attrs backed by DictAttrsNode.
   * \param dict The attributes.
   * \return The dict attributes.
   */
  TVM_DLL static Attrs make(Map<std::string, NodeRef> dict);
  // implementations
  void VisitAttrs(AttrVisitor* v) final;
  void VisitNonDefaultAttrs(AttrVisitor* v) final;
  void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
  Array<AttrFieldInfo> ListFieldInfo() const final;
  bool ContentEqual(const Node* other, AttrsEqual equal) const final;
  size_t ContentHash(AttrsHash hasher) const final;
  // type info
  static constexpr const char* _type_key = "DictAttrs";
  TVM_DECLARE_NODE_TYPE_INFO(DictAttrsNode, BaseAttrsNode);
};
| |
| |
| // Namespace containing detail implementations |
| namespace detail { |
| using runtime::TVMArgValue; |
| |
/*!
 * \brief Chainable entry on which describe/set_default/set_lower_bound/
 *        set_upper_bound are all no-ops.
 *
 * Returned by field visitors that only need the field itself and ignore
 * the metadata attached in a TVM_ATTR_FIELD chain.
 */
struct AttrNopEntry {
  using TSelf = AttrNopEntry;

  TSelf& describe(const char* /*str*/) { return *this; }

  template<typename T>
  TSelf& set_default(const T& /*value*/) { return *this; }

  template<typename T>
  TSelf& set_lower_bound(const T& /*begin*/) { return *this; }

  template<typename T>
  TSelf& set_upper_bound(const T& /*end*/) { return *this; }
};
| |
| // Wrapper for normal visitor. |
| class AttrNormalVisitor { |
| public: |
| explicit AttrNormalVisitor(AttrVisitor* visitor) |
| : visitor_(visitor) { |
| } |
| template<typename T> |
| AttrNopEntry operator()(const char* key, T* value) { |
| visitor_->Visit(key, value); |
| return AttrNopEntry(); |
| } |
| |
| private: |
| AttrVisitor* visitor_; |
| }; |
| |
| // Wrapper for normal visitor. |
| class AttrsEqualVisitor { |
| public: |
| bool result_{true}; |
| // constructor |
| AttrsEqualVisitor(const Node* lhs, const Node* rhs, const AttrsEqual& equal) |
| : lhs_(lhs), rhs_(rhs), equal_(equal) { |
| } |
| template<typename T> |
| AttrNopEntry operator()(const char* key, T* lhs_value) { |
| if (!result_) return AttrNopEntry(); |
| const T* rhs_value = |
| reinterpret_cast<const T*>( |
| reinterpret_cast<const char*>(rhs_) + |
| (reinterpret_cast<const char*>(lhs_value) - |
| reinterpret_cast<const char*>(lhs_))); |
| if (!equal_(*lhs_value, *rhs_value)) { |
| result_ = false; |
| } |
| return AttrNopEntry(); |
| } |
| |
| private: |
| const Node* lhs_; |
| const Node* rhs_; |
| const AttrsEqual& equal_; |
| }; |
| |
| class AttrsHashVisitor { |
| public: |
| explicit AttrsHashVisitor(const AttrsHash& hasher) |
| : hasher_(hasher) {} |
| |
| size_t result_{0}; |
| |
| template<typename T> |
| AttrNopEntry operator()(const char* key, T* value) { |
| result_ = dmlc::HashCombine(result_, hasher_(*value)); |
| return AttrNopEntry(); |
| } |
| |
| private: |
| const AttrsHash& hasher_; |
| }; |
| |
| // helper entry that does initialization, set default. |
| template<typename T> |
| struct AttrInitEntry { |
| // The attributes |
| using TSelf = AttrInitEntry<T>; |
| // The type key |
| const char* type_key_; |
| // field name |
| const char* key_; |
| // internal value. |
| T* value_; |
| // whether the value is missing. |
| bool value_missing_{true}; |
| // If the value is still missing in destruction time throw an error. |
| ~AttrInitEntry() DMLC_THROW_EXCEPTION { |
| if (value_missing_) { |
| std::ostringstream os; |
| os << type_key_ << ": Cannot find required field \'" << key_ |
| << "\' during initialization"; |
| throw AttrError(os.str()); |
| } |
| } |
| // override fields. |
| // This function sets the lower bound of the attribute |
| TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) { |
| if (this->value_missing_) return *this; |
| const T& val = *value_; |
| if (begin > val) { |
| std::ostringstream os; |
| os << type_key_ << "." << key_ << ": " |
| << "value " << val |
| << " is smaller than the lower bound " << begin; |
| throw AttrError(os.str()); |
| } |
| return *this; |
| } |
| // This function sets the upper bound of the attribute |
| TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) { |
| if (this->value_missing_) return *this; |
| const T& val = *value_; |
| if (val > end) { |
| std::ostringstream os; |
| os << type_key_ << "." << key_ << ": " |
| << "value " << val |
| << " is bigger than the upper bound " << end; |
| throw AttrError(os.str()); |
| } |
| return *this; |
| } |
| // set default when |
| TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) { |
| if (!value_missing_) return *this; |
| *value_ = value; |
| value_missing_ = false; |
| return *this; |
| } |
| TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { |
| return *this; |
| } |
| }; |
| |
| // Template function to allow smart conversion |
| // from Expr types into the constants. |
| template<typename T> |
| inline void SetValue(T* ptr, const TVMArgValue& val) { |
| *ptr = val.operator T(); |
| } |
| template<typename T> |
| inline void SetIntValue(T* ptr, const TVMArgValue& val) { |
| if (val.type_code() == kDLInt) { |
| *ptr = static_cast<T>(val.value().v_int64); |
| } else { |
| Expr expr = val; |
| CHECK(expr.defined()); |
| if (const ir::IntImm* op = expr.as<ir::IntImm>()) { |
| *ptr = static_cast<T>(op->value); |
| } else if (const ir::UIntImm* op = expr.as<ir::UIntImm>()) { |
| *ptr = static_cast<T>(op->value); |
| } else { |
| LOG(FATAL) << "Expect int value, but get " << expr->type_key(); |
| } |
| } |
| } |
| template<> |
| inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) { |
| if (val.type_code() == kStr) { |
| *ptr = val.operator std::string(); |
| } else { |
| Expr expr = val; |
| const ir::StringImm* op = expr.as<ir::StringImm>(); |
| CHECK(op != nullptr); |
| *ptr = op->value; |
| } |
| } |
| template<> |
| inline void SetValue(Type* ptr, const TVMArgValue& val) { |
| *ptr = val.operator Type(); |
| } |
| template<> |
| inline void SetValue<double>(double* ptr, const TVMArgValue& val) { |
| if (val.type_code() == kDLFloat || val.type_code() == kDLInt) { |
| *ptr = val.operator double(); |
| } else { |
| Expr expr = val; |
| CHECK(expr.defined()); |
| if (const ir::IntImm* op = expr.as<ir::IntImm>()) { |
| *ptr = static_cast<double>(op->value); |
| } else if (const ir::IntImm* op = expr.as<ir::IntImm>()) { |
| *ptr = static_cast<double>(op->value); |
| } else if (const ir::UIntImm* op = expr.as<ir::UIntImm>()) { |
| *ptr = static_cast<double>(op->value); |
| } else { |
| LOG(FATAL) << "Expect float value, but get " << expr->type_key(); |
| } |
| } |
| } |
| template<> |
| inline void SetValue<int>(int* ptr, const TVMArgValue& val) { |
| SetIntValue(ptr, val); |
| } |
| template<> |
| inline void SetValue<int64_t>(int64_t* ptr, const TVMArgValue& val) { |
| SetIntValue(ptr, val); |
| } |
| template<> |
| inline void SetValue<uint64_t>(uint64_t* ptr, const TVMArgValue& val) { |
| SetIntValue(ptr, val); |
| } |
| template<> |
| inline void SetValue<bool>(bool* ptr, const TVMArgValue& val) { |
| SetIntValue(ptr, val); |
| } |
| |
| // Visitor for value initialization |
| template<typename FFind> |
| class AttrInitVisitor { |
| public: |
| // Counter of number of matched attributes during visit. |
| // This is used to decide if there is additional unmatched attributes. |
| size_t hit_count_{0}; |
| // constructor |
| AttrInitVisitor(const char* type_key, FFind ffind) |
| : type_key_(type_key), ffind_(ffind) { |
| } |
| |
| template<typename T> |
| AttrInitEntry<T> operator()(const char* key, T* value) { |
| TVMArgValue val; |
| AttrInitEntry<T> opt; |
| opt.type_key_ = type_key_; |
| opt.key_ = key; |
| opt.value_ = value; |
| if (ffind_(key, &val)) { |
| SetValue(value, val); |
| opt.value_missing_ = false; |
| ++hit_count_; |
| } else { |
| opt.value_missing_ = true; |
| } |
| return opt; |
| } |
| |
| private: |
| // the type key |
| const char* type_key_; |
| FFind ffind_; |
| }; |
| |
| template<typename FFind> |
| inline AttrInitVisitor<FFind> CreateInitVisitor( |
| const char* type_key, |
| FFind ffind) { |
| return AttrInitVisitor<FFind>(type_key, ffind); |
| } |
| |
| /*! |
| * \brief Helper struct to get the type name known to tvm. |
| * \tparam T the type we are interested in. |
| */ |
| template<typename T> |
| struct TypeName { |
| static constexpr const char* value = T::ContainerType::_type_key; |
| }; |
| |
| template<> |
| struct TypeName<int> { |
| static constexpr const char* value = "int"; |
| }; |
| |
| template<> |
| struct TypeName<int64_t> { |
| static constexpr const char* value = "int64"; |
| }; |
| |
| template<> |
| struct TypeName<uint64_t> { |
| static constexpr const char* value = "uint64_t"; |
| }; |
| |
| template<> |
| struct TypeName<Type> { |
| static constexpr const char* value = "Type"; |
| }; |
| |
| template<> |
| struct TypeName<std::string> { |
| static constexpr const char* value = "str"; |
| }; |
| |
| template<> |
| struct TypeName<bool> { |
| static constexpr const char* value = "bool"; |
| }; |
| |
| template<> |
| struct TypeName<void*> { |
| static constexpr const char* value = "handle"; |
| }; |
| |
| template<> |
| struct TypeName<double> { |
| static constexpr const char* value = "double"; |
| }; |
| |
| class AttrDocEntry { |
| public: |
| using TSelf = AttrDocEntry; |
| |
| explicit AttrDocEntry(NodePtr<AttrFieldInfoNode> info) |
| : info_(info) { |
| } |
| TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { |
| info_->description = str; |
| return *this; |
| } |
| template<typename T> |
| TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) { |
| std::ostringstream os; |
| os << info_->type_info << ", default=" << value; |
| info_->type_info = os.str(); |
| return *this; |
| } |
| template<typename T> |
| TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin) { |
| return *this; |
| } |
| template<typename T> |
| TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end) { |
| return *this; |
| } |
| |
| private: |
| NodePtr<AttrFieldInfoNode> info_; |
| }; |
| |
| class AttrDocVisitor { |
| public: |
| template<typename T> |
| AttrDocEntry operator()(const char* key, T* v) { |
| NodePtr<AttrFieldInfoNode> info |
| = make_node<AttrFieldInfoNode>(); |
| info->name = key; |
| info->type_info = TypeName<T>::value; |
| fields_.push_back(AttrFieldInfo(info)); |
| return AttrDocEntry(info); |
| } |
| |
| Array<AttrFieldInfo> fields_; |
| }; |
| |
| class AttrExistVisitor { |
| public: |
| std::string key_; |
| bool exist_{false}; |
| |
| template<typename T> |
| AttrNopEntry operator()(const char* key, T* v) { |
| if (exist_) return AttrNopEntry(); |
| if (key == key_) exist_ = true; |
| return AttrNopEntry(); |
| } |
| }; |
| |
| template<typename T> |
| struct AttrTriggerNonDefaultEntry { |
| using TSelf = AttrTriggerNonDefaultEntry<T>; |
| // constructor |
| AttrTriggerNonDefaultEntry( |
| AttrVisitor* visitor, const char* key, T* data) |
| : visitor_(visitor), key_(key), data_(data) {} |
| |
| ~AttrTriggerNonDefaultEntry() DMLC_THROW_EXCEPTION { |
| if (trigger_) { |
| visitor_->Visit(key_, data_); |
| } |
| } |
| TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { |
| return *this; |
| } |
| TSelf& set_default(const T& value) { |
| if (AttrsEqual()(value, *data_)) { |
| trigger_ = false; |
| } |
| return *this; |
| } |
| TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) { |
| return *this; |
| } |
| TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) { |
| return *this; |
| } |
| |
| private: |
| AttrVisitor* visitor_; |
| const char * key_; |
| T *data_; |
| bool trigger_{true}; |
| }; |
| |
| class AttrNonDefaultVisitor { |
| public: |
| explicit AttrNonDefaultVisitor(AttrVisitor* visitor) |
| : visitor_(visitor) { |
| } |
| template<typename T> |
| AttrTriggerNonDefaultEntry<T> |
| operator()(const char* key, T* value) { |
| return AttrTriggerNonDefaultEntry<T>(visitor_, key, value); |
| } |
| |
| private: |
| AttrVisitor* visitor_; |
| }; |
| } // namespace detail |
| |
/*!
 * \brief The base class of all concrete attribute classes.
 *
 * Uses the "curiously recurring template pattern" (CRTP): DerivedType is
 * the attribute class itself and declares its fields with
 * TVM_DECLARE_ATTRS / TVM_ATTR_FIELD. Each BaseAttrsNode virtual is
 * implemented here by running DerivedType::__VisitAttrs__ with a different
 * field visitor.
 *
 * \tparam DerivedType The final attribute type.
 */
template<typename DerivedType>
class AttrsNode : public BaseAttrsNode {
 public:
  // Report every field to v.
  void VisitAttrs(AttrVisitor* v) final {
    ::tvm::detail::AttrNormalVisitor vis(v);
    self()->__VisitAttrs__(vis);
  }

  // Report only fields whose value differs from the declared default.
  // Fields without set_default are always reported.
  void VisitNonDefaultAttrs(AttrVisitor* v) final {
    ::tvm::detail::AttrNonDefaultVisitor vis(v);
    self()->__VisitAttrs__(vis);
  }

  // Initialize fields from args laid out as [key0, value0, key1, value1, ...].
  // CHECK fails if the count is odd or a key is not a string. Throws AttrError
  // if a required field is missing, a bound is violated, or (unless
  // allow_unknown) a key does not name any field.
  // NOTE(review): with duplicate keys, the linear path uses the first
  // occurrence but the map path keeps the last one.
  void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final {
    CHECK_EQ(args.size() % 2, 0);
    const int kLinearSearchBound = 16;
    int hit_count = 0;
    // applies two strategies to lookup
    if (args.size() < kLinearSearchBound) {
      // linear search: cheaper than building a map for few arguments.
      auto ffind = [&args](const char* key, runtime::TVMArgValue* val) {
        for (int i = 0; i < args.size(); i += 2) {
          CHECK_EQ(args.type_codes[i], kStr);
          if (!std::strcmp(key, args.values[i].v_str)) {
            *val = args[i + 1];
            return true;
          }
        }
        return false;
      };
      auto vis = ::tvm::detail::CreateInitVisitor(DerivedType::_type_key, ffind);
      self()->__VisitAttrs__(vis);
      hit_count = vis.hit_count_;
    } else {
      // construct a map then do lookup.
      std::unordered_map<std::string, runtime::TVMArgValue> kwargs;
      for (int i = 0; i < args.size(); i += 2) {
        CHECK_EQ(args.type_codes[i], kStr);
        kwargs[args[i].operator std::string()] = args[i + 1];
      }
      auto ffind = [&kwargs](const char *key, runtime::TVMArgValue* val) {
        auto it = kwargs.find(key);
        if (it != kwargs.end()) {
          *val = it->second;
          return true;
        }
        return false;
      };
      auto vis = ::tvm::detail::CreateInitVisitor(DerivedType::_type_key, ffind);
      self()->__VisitAttrs__(vis);
      hit_count = vis.hit_count_;
    }
    // error handling, slow path: only runs when some argument matched no
    // field. Re-visit the fields for each key to find the first unknown one.
    if (hit_count * 2 != args.size() && !allow_unknown) {
      for (int i = 0; i < args.size(); i += 2) {
        ::tvm::detail::AttrExistVisitor visitor;
        visitor.key_ = args[i].operator std::string();
        self()->__VisitAttrs__(visitor);
        if (!visitor.exist_) {
          std::ostringstream os;
          os << DerivedType::_type_key
             << ": does not have field \'" << visitor.key_
             << "\', Possible fields:\n";
          os << "----------------\n";
          this->PrintDocString(os);
          throw AttrError(os.str());
        }
      }
    }
  }

  // One AttrFieldInfo per field, in declaration order of TVM_ATTR_FIELD.
  Array<AttrFieldInfo> ListFieldInfo() const final {
    ::tvm::detail::AttrDocVisitor visitor;
    self()->__VisitAttrs__(visitor);
    return visitor.fields_;
  }

  // Same object -> true; null or different type index -> false; otherwise
  // compare the fields of both objects pairwise with `equal`.
  bool ContentEqual(const Node* other, AttrsEqual equal) const final {
    DerivedType* pself = self();
    if (pself == other) return true;
    if (other == nullptr) return false;
    if (pself->type_index() != other->type_index()) return false;
    ::tvm::detail::AttrsEqualVisitor visitor(pself, other, equal);
    self()->__VisitAttrs__(visitor);
    return visitor.result_;
  }

  // Seed with the hash of the type key, then combine each field's hash.
  size_t ContentHash(AttrsHash hasher) const final {
    ::tvm::detail::AttrsHashVisitor visitor(hasher);
    visitor.result_ = std::hash<std::string>()(this->type_key());
    self()->__VisitAttrs__(visitor);
    return visitor.result_;
  }

 private:
  // CRTP downcast. const is cast away because __VisitAttrs__ is non-const;
  // the visitors used from the const methods above only read the fields.
  DerivedType* self() const {
    return const_cast<DerivedType*>(
        static_cast<const DerivedType*>(this));
  }
};
| |
| |
| template<typename... Args> |
| inline void BaseAttrsNode::InitBySeq(Args&& ...args) { |
| runtime::PackedFunc pf([this](const TVMArgs& args, TVMRetValue *rv) { |
| this->InitByPackedArgs(args); |
| }); |
| pf(std::forward<Args>(args)...); |
| } |
| |
| inline void BaseAttrsNode::PrintDocString(std::ostream &os) const { // NOLINT(*) |
| Array<AttrFieldInfo> entry = this->ListFieldInfo(); |
| for (AttrFieldInfo info : entry) { |
| os << info->name << " : " << info->type_info << '\n'; |
| if (info->description.length() != 0) { |
| os << " " << info->description << '\n'; |
| } |
| } |
| } |
| |
| } // namespace tvm |
| #endif // TVM_ATTRS_H_ |