| /*! |
| * Copyright (c) 2018 by Contributors |
| * \file attr_functor.h |
 * \brief A way to define functors with arbitrary function signatures
 *  that dispatch on common attribute node types.
| * |
| * Common attributes include: |
| * - int, float, str constants |
| * - array of attributes |
| * - map of attributes |
| */ |
| #ifndef TVM_LANG_ATTR_FUNCTOR_H_ |
| #define TVM_LANG_ATTR_FUNCTOR_H_ |
| |
| namespace tvm { |
| |
/*!
 * \brief Primary template, intentionally left undefined.
 *  Only the function-signature specialization below is usable.
 * \tparam FType A function type of the form R(const NodeRef&, Args...).
 */
template <class FType> class AttrFunctor;
| |
/*!
 * \brief Body shared by every default VisitAttr_ overload of AttrFunctor:
 *  forward the node and the extra arguments to VisitAttrDefault_.
 *  Must expand where `op` and the parameter pack `args` (of type `Args...`)
 *  are in scope.
 */
#define ATTR_FUNCTOR_DEFAULT                                          \
  { return VisitAttrDefault_(op, std::forward<Args>(args)...); }

/*!
 * \brief Register a vtable entry for node type OP that downcasts the
 *  incoming NodeRef to `const OP*` and calls the matching VisitAttr_
 *  overload on the functor.
 *  Must expand inside AttrFunctor::InitVTable, where `vtable`, `TSelf` and
 *  the parameter pack `Args` are in scope. The caller supplies the
 *  trailing semicolon.
 *
 *  The last line deliberately has no trailing backslash, so the macro
 *  cannot silently swallow whatever line follows it.
 */
#define ATTR_FUNCTOR_DISPATCH(OP)                                     \
  vtable.template set_dispatch<OP>(                                   \
      [](const NodeRef& n, TSelf* self, Args... args) {               \
        return self->VisitAttr_(static_cast<const OP*>(n.get()),      \
                                std::forward<Args>(args)...);         \
      })
| |
| // A functor for common attribute information. |
| template <typename R, typename... Args> |
| class AttrFunctor<R(const NodeRef& n, Args...)> { |
| private: |
| using TSelf = AttrFunctor<R(const NodeRef& n, Args...)>; |
| using FType = tvm::IRFunctor<R(const NodeRef& n, TSelf* self, Args...)>; |
| |
| public: |
| /*! \brief the result type of this functor */ |
| using result_type = R; |
| /*! |
| * \brief The functor call. |
| * \param n The expression node. |
| * \param args Additional arguments. |
| * \return The result of the call |
| */ |
| virtual R VisitAttr(const NodeRef& n, Args... args) { |
| static FType vtable = InitVTable(); |
| if (vtable.can_dispatch(n)) { |
| return vtable(n, this, std::forward<Args>(args)...); |
| } else { |
| return VisitAttrDefault_(n.get(), std::forward<Args>(args)...); |
| } |
| } |
| virtual R VisitAttrDefault_(const Node* node, Args... args) = 0; |
| virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
| virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
| virtual R VisitAttr_(const ir::IntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
| virtual R VisitAttr_(const ir::UIntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
| virtual R VisitAttr_(const ir::FloatImm* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
| virtual R VisitAttr_(const ir::StringImm* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
| // deep comparison of symbolic integer expressions. |
| virtual R VisitAttr_(const Variable* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
| virtual R VisitAttr_(const ir::Add* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
| virtual R VisitAttr_(const ir::Sub* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
| virtual R VisitAttr_(const ir::Mul* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
| virtual R VisitAttr_(const ir::Div* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
| virtual R VisitAttr_(const ir::Mod* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
| virtual R VisitAttr_(const ir::Min* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
| virtual R VisitAttr_(const ir::Max* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
| virtual R VisitAttr_(const ir::GE* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
| virtual R VisitAttr_(const ir::GT* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
| virtual R VisitAttr_(const ir::LT* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
| virtual R VisitAttr_(const ir::LE* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
| virtual R VisitAttr_(const ir::EQ* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
| virtual R VisitAttr_(const ir::NE* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
| virtual R VisitAttr_(const ir::And* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
| virtual R VisitAttr_(const ir::Or* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
| virtual R VisitAttr_(const ir::Not* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
| virtual R VisitAttr_(const ir::Cast* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
| virtual R VisitAttr_(const ir::Call* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
| virtual R VisitAttr_(const ir::Select* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
| |
| private: |
| // initialize the vtable. |
| static FType InitVTable() { |
| using namespace ir; |
| FType vtable; |
| // Set dispatch |
| ATTR_FUNCTOR_DISPATCH(StrMapNode); |
| ATTR_FUNCTOR_DISPATCH(ArrayNode); |
| ATTR_FUNCTOR_DISPATCH(IntImm); |
| ATTR_FUNCTOR_DISPATCH(UIntImm); |
| ATTR_FUNCTOR_DISPATCH(FloatImm); |
| ATTR_FUNCTOR_DISPATCH(StringImm); |
| ATTR_FUNCTOR_DISPATCH(Variable); |
| ATTR_FUNCTOR_DISPATCH(Add); |
| ATTR_FUNCTOR_DISPATCH(Sub); |
| ATTR_FUNCTOR_DISPATCH(Mul); |
| ATTR_FUNCTOR_DISPATCH(Div); |
| ATTR_FUNCTOR_DISPATCH(Min); |
| ATTR_FUNCTOR_DISPATCH(Max); |
| ATTR_FUNCTOR_DISPATCH(GE); |
| ATTR_FUNCTOR_DISPATCH(GT); |
| ATTR_FUNCTOR_DISPATCH(LE); |
| ATTR_FUNCTOR_DISPATCH(LT); |
| ATTR_FUNCTOR_DISPATCH(EQ); |
| ATTR_FUNCTOR_DISPATCH(NE); |
| ATTR_FUNCTOR_DISPATCH(And); |
| ATTR_FUNCTOR_DISPATCH(Or); |
| ATTR_FUNCTOR_DISPATCH(Not); |
| ATTR_FUNCTOR_DISPATCH(Cast); |
| ATTR_FUNCTOR_DISPATCH(Call); |
| ATTR_FUNCTOR_DISPATCH(Select); |
| return vtable; |
| } |
| }; |
| |
| class AttrsEqualHandler : |
| protected AttrFunctor<bool(const NodeRef&, const NodeRef&)> { |
| public: |
| /*! |
| * \brief Check if lhs equals rhs |
| * \param lhs The left operand. |
| * \param rhs The right operand. |
| */ |
| bool Equal(const NodeRef& lhs, const NodeRef& rhs); |
| |
| protected: |
| bool VisitAttrDefault_(const Node* lhs, const NodeRef& other) final; |
| bool VisitAttr_(const ArrayNode* lhs, const NodeRef& other) final; |
| bool VisitAttr_(const StrMapNode* lhs, const NodeRef& other) final; |
| bool VisitAttr_(const ir::IntImm* lhs, const NodeRef& other) final; |
| bool VisitAttr_(const ir::UIntImm* lhs, const NodeRef& other) final; |
| bool VisitAttr_(const ir::FloatImm* lhs, const NodeRef& other) final; |
| bool VisitAttr_(const ir::StringImm* lhs, const NodeRef& other) final; |
| bool VisitAttr_(const ir::Add* lhs, const NodeRef& other) final; |
| bool VisitAttr_(const ir::Sub* lhs, const NodeRef& other) final; |
| bool VisitAttr_(const ir::Mul* lhs, const NodeRef& other) final; |
| bool VisitAttr_(const ir::Div* lhs, const NodeRef& other) final; |
| bool VisitAttr_(const ir::Mod* lhs, const NodeRef& other) final; |
| bool VisitAttr_(const ir::Min* lhs, const NodeRef& other) final; |
| bool VisitAttr_(const ir::Max* lhs, const NodeRef& other) final; |
| bool VisitAttr_(const ir::GE* lhs, const NodeRef& other) final; |
| bool VisitAttr_(const ir::GT* lhs, const NodeRef& other) final; |
| bool VisitAttr_(const ir::LT* lhs, const NodeRef& other) final; |
| bool VisitAttr_(const ir::LE* lhs, const NodeRef& other) final; |
| bool VisitAttr_(const ir::EQ* lhs, const NodeRef& other) final; |
| bool VisitAttr_(const ir::NE* lhs, const NodeRef& other) final; |
| bool VisitAttr_(const ir::And* lhs, const NodeRef& other) final; |
| bool VisitAttr_(const ir::Or* lhs, const NodeRef& other) final; |
| bool VisitAttr_(const ir::Not* lhs, const NodeRef& other) final; |
| bool VisitAttr_(const ir::Cast* lhs, const NodeRef& other) final; |
| bool VisitAttr_(const ir::Call* lhs, const NodeRef& other) final; |
| bool VisitAttr_(const ir::Select* lhs, const NodeRef& other) final; |
| }; |
| |
| class AttrsHashHandler : |
| protected AttrFunctor<size_t(const NodeRef&)> { |
| public: |
| /*! |
| * \brief Get hash value of node |
| * \param node The node to be hashed. |
| */ |
| size_t Hash(const NodeRef& node) { |
| if (!node.defined()) return 0; |
| return this->VisitAttr(node); |
| } |
| |
| protected: |
| size_t VisitAttrDefault_(const Node* lhs) final; |
| size_t VisitAttr_(const ir::IntImm* lhs) final; |
| size_t VisitAttr_(const ir::UIntImm* lhs) final; |
| size_t VisitAttr_(const ir::FloatImm* lhs) final; |
| size_t VisitAttr_(const ir::StringImm* lhs) final; |
| size_t VisitAttr_(const ArrayNode* lhs) final; |
| size_t VisitAttr_(const StrMapNode* lhs) final; |
| size_t VisitAttr_(const ir::Add* op) final; |
| size_t VisitAttr_(const ir::Sub* op) final; |
| size_t VisitAttr_(const ir::Mul* op) final; |
| size_t VisitAttr_(const ir::Div* op) final; |
| size_t VisitAttr_(const ir::Mod* op) final; |
| size_t VisitAttr_(const ir::Min* op) final; |
| size_t VisitAttr_(const ir::Max* op) final; |
| size_t VisitAttr_(const ir::GE* op) final; |
| size_t VisitAttr_(const ir::GT* op) final; |
| size_t VisitAttr_(const ir::LE* op) final; |
| size_t VisitAttr_(const ir::LT* op) final; |
| size_t VisitAttr_(const ir::EQ* op) final; |
| size_t VisitAttr_(const ir::NE* op) final; |
| size_t VisitAttr_(const ir::And* op) final; |
| size_t VisitAttr_(const ir::Or* op) final; |
| size_t VisitAttr_(const ir::Not* op) final; |
| size_t VisitAttr_(const ir::Cast* op) final; |
| size_t VisitAttr_(const ir::Call* op) final; |
| size_t VisitAttr_(const ir::Select* op) final; |
| /*! |
| * \brief alias of dmlc::HashCombine |
| * \param lhs The first hash value. |
| * \param rhs The second hash value. |
| */ |
| static size_t Combine(size_t lhs, size_t rhs) { |
| return dmlc::HashCombine(lhs, rhs); |
| } |
| }; |
| } // namespace tvm |
| #endif // TVM_LANG_ATTR_FUNCTOR_H_ |