// blob: 69b7ec1f6e607b38392a75ca6c9a03648389ec62 [file] [log] [blame]
/*!
* Copyright (c) 2018 by Contributors
* \file attr_functor.h
* \brief A way to define arbitrary function signature
* with dispatch on common attributes.
*
* Common attributes include:
* - int, float, str constants
* - array of attributes
* - map of attributes
*/
#ifndef TVM_LANG_ATTR_FUNCTOR_H_
#define TVM_LANG_ATTR_FUNCTOR_H_
namespace tvm {
// Primary template, declared but never defined. Only the partial
// specialization for signatures of the form R(const NodeRef&, Args...)
// further down in this file provides an implementation.
template <typename FType>
class AttrFunctor;
// Default body for a VisitAttr_ overload: forwards the node pointer and the
// extra arguments to VisitAttrDefault_.
// The expansion refers to `op`, `args` and the parameter pack `Args`, so those
// names must be in scope wherever the macro is used.
#define ATTR_FUNCTOR_DEFAULT \
{ return VisitAttrDefault_(op, std::forward<Args>(args)...); }
// Registers a dispatch entry for node type OP in a variable named `vtable`.
// The registered lambda casts the node held by the NodeRef to `const OP*` and
// calls the matching VisitAttr_ overload on `self`.
// The expansion refers to `vtable`, `TSelf` and the parameter pack `Args`, so
// those names must be in scope wherever the macro is used.
// The final line deliberately has no trailing backslash: a dangling line
// continuation would splice whatever source line follows into the macro.
#define ATTR_FUNCTOR_DISPATCH(OP)                                      \
  vtable.template set_dispatch<OP>(                                    \
      [](const NodeRef& n, TSelf* self, Args... args) {                \
        return self->VisitAttr_(static_cast<const OP*>(n.node_.get()), \
                                std::forward<Args>(args)...);          \
      });
// A functor for common attribute information.
/*!
 * \brief Functor that dispatches on the runtime type of an attribute node.
 *
 * VisitAttr consults a dispatch table that is built once per instantiation.
 * When the table can dispatch the node, the call is routed to the matching
 * VisitAttr_ overload; otherwise it goes to VisitAttrDefault_. Every
 * VisitAttr_ overload defaults to VisitAttrDefault_, so a subclass only
 * overrides the node types it wants to treat specially.
 *
 * \tparam R The result type of the functor.
 * \tparam Args Types of the additional arguments passed to every overload.
 */
template <typename R, typename... Args>
class AttrFunctor<R(const NodeRef& n, Args...)> {
 private:
  using TSelf = AttrFunctor<R(const NodeRef& n, Args...)>;
  using FType = tvm::IRFunctor<R(const NodeRef& n, TSelf* self, Args...)>;

 public:
  /*! \brief the result type of this functor */
  using result_type = R;
  /*! \brief virtual destructor, since this class is a polymorphic base */
  virtual ~AttrFunctor() = default;
  /*!
   * \brief The functor call.
   * \param n The expression node.
   * \param args Additional arguments.
   * \return The result of the call
   */
  virtual R VisitAttr(const NodeRef& n, Args... args) {
    // Function-local static: the table is initialized on the first call only.
    static FType vtable = InitVTable();
    if (vtable.can_dispatch(n)) {
      return vtable(n, this, std::forward<Args>(args)...);
    } else {
      // No table entry for this node type: use the generic fallback.
      return VisitAttrDefault_(n.get(), std::forward<Args>(args)...);
    }
  }
  /*!
   * \brief Fallback invoked for every node type that is not handled elsewhere.
   * \param node The node being visited.
   * \param args Additional arguments.
   * \return The result of the call
   */
  virtual R VisitAttrDefault_(const Node* node, Args... args) = 0;
  // containers and constants
  virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::IntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::UIntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::FloatImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::StringImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  // deep comparison of symbolic integer expressions.
  virtual R VisitAttr_(const Variable* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::Add* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::Sub* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::Mul* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::Div* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::Mod* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::Min* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::Max* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::GE* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::GT* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::LT* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::LE* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::EQ* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::NE* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::And* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::Or* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::Not* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::Cast* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::Call* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::Select* op, Args... args) ATTR_FUNCTOR_DEFAULT;

 private:
  /*!
   * \brief Build the dispatch table used by VisitAttr.
   *
   * Each node type with a VisitAttr_ overload above needs an entry here;
   * a type without one never reaches its overload and is routed to
   * VisitAttrDefault_ instead.
   *
   * \return The populated dispatch table.
   */
  static FType InitVTable() {
    using namespace ir;
    FType vtable;
    // Set dispatch
    ATTR_FUNCTOR_DISPATCH(StrMapNode);
    ATTR_FUNCTOR_DISPATCH(ArrayNode);
    ATTR_FUNCTOR_DISPATCH(IntImm);
    ATTR_FUNCTOR_DISPATCH(UIntImm);
    ATTR_FUNCTOR_DISPATCH(FloatImm);
    ATTR_FUNCTOR_DISPATCH(StringImm);
    ATTR_FUNCTOR_DISPATCH(Variable);
    ATTR_FUNCTOR_DISPATCH(Add);
    ATTR_FUNCTOR_DISPATCH(Sub);
    ATTR_FUNCTOR_DISPATCH(Mul);
    ATTR_FUNCTOR_DISPATCH(Div);
    // Mod was previously missing from the table, so the VisitAttr_ overload
    // for ir::Mod (and any subclass override of it) was never called.
    ATTR_FUNCTOR_DISPATCH(Mod);
    ATTR_FUNCTOR_DISPATCH(Min);
    ATTR_FUNCTOR_DISPATCH(Max);
    ATTR_FUNCTOR_DISPATCH(GE);
    ATTR_FUNCTOR_DISPATCH(GT);
    ATTR_FUNCTOR_DISPATCH(LE);
    ATTR_FUNCTOR_DISPATCH(LT);
    ATTR_FUNCTOR_DISPATCH(EQ);
    ATTR_FUNCTOR_DISPATCH(NE);
    ATTR_FUNCTOR_DISPATCH(And);
    ATTR_FUNCTOR_DISPATCH(Or);
    ATTR_FUNCTOR_DISPATCH(Not);
    ATTR_FUNCTOR_DISPATCH(Cast);
    ATTR_FUNCTOR_DISPATCH(Call);
    ATTR_FUNCTOR_DISPATCH(Select);
    return vtable;
  }
};
/*!
 * \brief Equality handler built on AttrFunctor.
 *
 * Every overload receives the dispatched node as `lhs` and the functor's
 * extra NodeRef argument as `other`, and returns a bool. The member functions
 * are only declared here; their definitions are not part of this header.
 * ir::Variable is not overridden, so it keeps the base-class default.
 */
class AttrsEqualHandler
    : protected AttrFunctor<bool(const NodeRef&, const NodeRef&)> {
 public:
  /*!
   * \brief Check if lhs equals rhs
   * \param lhs The left operand.
   * \param rhs The right operand.
   */
  bool Equal(const NodeRef& lhs, const NodeRef& rhs);

 protected:
  // Fallback for node types without a dedicated overload.
  bool VisitAttrDefault_(const Node* lhs, const NodeRef& other) final;

  // Containers.
  bool VisitAttr_(const ArrayNode* lhs, const NodeRef& other) final;
  bool VisitAttr_(const StrMapNode* lhs, const NodeRef& other) final;

  // Immediate constants.
  bool VisitAttr_(const ir::IntImm* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::UIntImm* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::FloatImm* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::StringImm* lhs, const NodeRef& other) final;

  // Arithmetic expressions.
  bool VisitAttr_(const ir::Add* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::Sub* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::Mul* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::Div* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::Mod* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::Min* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::Max* lhs, const NodeRef& other) final;

  // Comparison expressions.
  bool VisitAttr_(const ir::GE* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::GT* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::LT* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::LE* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::EQ* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::NE* lhs, const NodeRef& other) final;

  // Logical expressions.
  bool VisitAttr_(const ir::And* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::Or* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::Not* lhs, const NodeRef& other) final;

  // Remaining expression kinds.
  bool VisitAttr_(const ir::Cast* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::Call* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::Select* lhs, const NodeRef& other) final;
};
/*!
 * \brief Hash handler built on AttrFunctor.
 *
 * Every overload takes the dispatched node and returns a size_t. Apart from
 * Hash and Combine, the member functions are only declared here; their
 * definitions are not part of this header. ir::Variable is not overridden,
 * so it keeps the base-class default.
 */
class AttrsHashHandler
    : protected AttrFunctor<size_t(const NodeRef&)> {
 public:
  /*!
   * \brief Get hash value of node
   * \param node The node to be hashed.
   * \return 0 when the reference is not defined, otherwise the value
   *  produced by dispatching on the node.
   */
  size_t Hash(const NodeRef& node) {
    return node.defined() ? this->VisitAttr(node) : 0;
  }

 protected:
  // Fallback for node types without a dedicated overload.
  size_t VisitAttrDefault_(const Node* lhs) final;

  // Immediate constants.
  size_t VisitAttr_(const ir::IntImm* lhs) final;
  size_t VisitAttr_(const ir::UIntImm* lhs) final;
  size_t VisitAttr_(const ir::FloatImm* lhs) final;
  size_t VisitAttr_(const ir::StringImm* lhs) final;

  // Containers.
  size_t VisitAttr_(const ArrayNode* lhs) final;
  size_t VisitAttr_(const StrMapNode* lhs) final;

  // Arithmetic expressions.
  size_t VisitAttr_(const ir::Add* op) final;
  size_t VisitAttr_(const ir::Sub* op) final;
  size_t VisitAttr_(const ir::Mul* op) final;
  size_t VisitAttr_(const ir::Div* op) final;
  size_t VisitAttr_(const ir::Mod* op) final;
  size_t VisitAttr_(const ir::Min* op) final;
  size_t VisitAttr_(const ir::Max* op) final;

  // Comparison expressions.
  size_t VisitAttr_(const ir::GE* op) final;
  size_t VisitAttr_(const ir::GT* op) final;
  size_t VisitAttr_(const ir::LE* op) final;
  size_t VisitAttr_(const ir::LT* op) final;
  size_t VisitAttr_(const ir::EQ* op) final;
  size_t VisitAttr_(const ir::NE* op) final;

  // Logical expressions.
  size_t VisitAttr_(const ir::And* op) final;
  size_t VisitAttr_(const ir::Or* op) final;
  size_t VisitAttr_(const ir::Not* op) final;

  // Remaining expression kinds.
  size_t VisitAttr_(const ir::Cast* op) final;
  size_t VisitAttr_(const ir::Call* op) final;
  size_t VisitAttr_(const ir::Select* op) final;

  /*!
   * \brief alias of dmlc::HashCombine
   * \param lhs The first hash value.
   * \param rhs The second hash value.
   * \return The value returned by dmlc::HashCombine(lhs, rhs).
   */
  static size_t Combine(size_t lhs, size_t rhs) {
    return dmlc::HashCombine(lhs, rhs);
  }
};
} // namespace tvm
#endif // TVM_LANG_ATTR_FUNCTOR_H_