// blob: c4fccfbe6b1b5a79e7857bd0944ac87db277e36e [file] [log] [blame]
/*!
* Copyright (c) 2016 by Contributors
* \file tvm/ir_visitor.h
* \brief Visitor to quickly visit IR trees
*/
#ifndef TVM_IR_VISITOR_H_
#define TVM_IR_VISITOR_H_
#include <functional>
#include "ir.h"
#include "tvm/node/ir_functor.h"
namespace tvm {
namespace ir {
/*!
* \brief A base class for visitors that recursively traverse the IR
*
* This IRVisitor is implemented via IRFunctor.
* This makes it possible to extend the visitor to new Node types.
*
* \sa ExprFunctor, StmtFunctor, PostOrderVisit
*
* \note If you need to return values during Visit:
* - If it is a mutation of the IR, use IRMutator
* - If you want to return other things, consider using ExprFunctor/StmtFunctor
* - Watch out for possible bug patterns if you use IRVisitor to simulate returns.
*
* \code
*
* // This is example code to showcase the traps in IRVisitor
* // The use case is to count number of Variables in the ir tree.
* class MyCounter : public IRVisitor {
* public:
* int Count(const NodeRef& n) {
* ret_ = 0;
* this->Visit(n);
* return ret_;
* }
* void Visit_(const Variable* op) final {
* ret_ = 1;
* }
* void Visit_(const Add* op) final {
* ret_ = Count(op->a) + Count(op->b);
* }
* private:
* int ret_;
* };
* MyCounter counter;
* Var x("x");
* // this returns 2
* CHECK_EQ(counter.Count(x + x), 2);
* // Think what is the result of the following count
* counter.Count(Max::make(x, x));
* // The result is actually 1
* // This is because Visit_ is not overridden for Max
* // so it simply calls Visit for the left and right children
* // and because Count is not called, ret_ is not cleared.
* // There can also be cases where setting ret_ is forgotten.
*
* // These traps may not happen if we program carefully
* // But it is recommended to use ExprFunctor, which allows returning
* // the value directly; this helps us to avoid such problems.
*
* \endcode
*/
class TVM_DLL IRVisitor {
public:
/*!
* \brief recursively visit an IR node
*/
virtual void Visit(const NodeRef& node) {
static const FVisit& f = vtable();
if (node.defined()) f(node, this);
}
/*! \brief destructor */
virtual ~IRVisitor() {}
/*! \brief functor type of visitor */
using FVisit = IRFunctor<void(const NodeRef&, IRVisitor*)>;
/*! \return internal vtable*/
static FVisit& vtable();
// overloadable visit function.
virtual void Visit_(const Variable* op);
virtual void Visit_(const LetStmt* op);
virtual void Visit_(const AttrStmt* op);
virtual void Visit_(const IfThenElse* op);
virtual void Visit_(const For* op);
virtual void Visit_(const Allocate* op);
virtual void Visit_(const Load* op);
virtual void Visit_(const Store* op);
virtual void Visit_(const Let* op);
virtual void Visit_(const Free* op);
virtual void Visit_(const Call* op);
virtual void Visit_(const Add* op);
virtual void Visit_(const Sub* op);
virtual void Visit_(const Mul* op);
virtual void Visit_(const Div* op);
virtual void Visit_(const Mod* op);
virtual void Visit_(const Min* op);
virtual void Visit_(const Max* op);
virtual void Visit_(const EQ* op);
virtual void Visit_(const NE* op);
virtual void Visit_(const LT* op);
virtual void Visit_(const LE* op);
virtual void Visit_(const GT* op);
virtual void Visit_(const GE* op);
virtual void Visit_(const And* op);
virtual void Visit_(const Or* op);
virtual void Visit_(const Reduce* op);
virtual void Visit_(const Cast* op);
virtual void Visit_(const Not* op);
virtual void Visit_(const Select* op);
virtual void Visit_(const Ramp* op);
virtual void Visit_(const Broadcast* op);
virtual void Visit_(const AssertStmt* op);
virtual void Visit_(const ProducerConsumer* op);
virtual void Visit_(const Provide* op);
virtual void Visit_(const Realize* op);
virtual void Visit_(const Prefetch* op);
virtual void Visit_(const Block* op);
virtual void Visit_(const Evaluate* op);
virtual void Visit_(const IntImm* op);
virtual void Visit_(const UIntImm* op);
virtual void Visit_(const FloatImm* op);
virtual void Visit_(const StringImm* op);
};
/*!
 * \brief Recursively visit the IR rooted at node in post DFS order and
 *  apply fvisit during the traversal.
 *
 *  Each node is guaranteed to be visited only once.
 *
 * \param node The IR to be visited.
 * \param fvisit The visitor function to be applied; it is called with a
 *  reference to the node being visited.
 */
TVM_DLL void PostOrderVisit(const NodeRef& node, std::function<void(const NodeRef&)> fvisit);
} // namespace ir
} // namespace tvm
#endif // TVM_IR_VISITOR_H_