| /*! |
| * Copyright (c) 2016 by Contributors |
| * \file ir_mutator.cc |
| */ |
| #include <tvm/ir.h> |
| #include <tvm/ir_mutator.h> |
| #include <tvm/packed_func_ext.h> |
| #include "ir_util.h" |
| |
| namespace tvm { |
| namespace ir { |
| |
| class IRTransformer final : public IRMutator { |
| public: |
| IRTransformer(const runtime::PackedFunc& f_preorder, |
| const runtime::PackedFunc& f_postorder, |
| const std::unordered_set<uint32_t>& only_enable) |
| : f_preorder_(f_preorder), |
| f_postorder_(f_postorder), |
| only_enable_(only_enable) { |
| } |
| Stmt Mutate(Stmt stmt) final { |
| return MutateInternal<Stmt>(stmt); |
| } |
| Expr Mutate(Expr expr) final { |
| return MutateInternal<Expr>(expr); |
| } |
| |
| private: |
| template<typename T> |
| T MutateInternal(T node) { |
| if (only_enable_.size() && |
| !only_enable_.count(node->type_index())) { |
| return IRMutator::Mutate(node); |
| } |
| if (f_preorder_ != nullptr) { |
| T pre = f_preorder_(node); |
| if (pre.defined()) return pre; |
| } |
| node = IRMutator::Mutate(node); |
| if (f_postorder_ != nullptr) { |
| T post = f_postorder_(node); |
| if (post.defined()) return post; |
| } |
| return node; |
| } |
| // The functions |
| const runtime::PackedFunc& f_preorder_; |
| const runtime::PackedFunc& f_postorder_; |
| // type indices enabled. |
| const std::unordered_set<uint32_t>& only_enable_; |
| }; |
| |
| Stmt IRTransform(const Stmt& ir_node, |
| const runtime::PackedFunc& f_preorder, |
| const runtime::PackedFunc& f_postorder, |
| const Array<Expr>& only_enable) { |
| std::unordered_set<uint32_t> only_type_index; |
| for (Expr s : only_enable) { |
| only_type_index.insert(Node::TypeKey2Index(s.as<StringImm>()->value.c_str())); |
| } |
| return IRTransformer(f_preorder, f_postorder, only_type_index) |
| .Mutate(ir_node); |
| } |
| |
| IRMutator::FMutateExpr& IRMutator::vtable_expr() { // NOLINT(*) |
| static FMutateExpr inst; return inst; |
| } |
| |
| IRMutator::FMutateStmt& IRMutator::vtable_stmt() { // NOLINT(*) |
| static FMutateStmt inst; return inst; |
| } |
| |
| inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) { |
| return UpdateArray(arr, [&m] (const Expr& e) { return m->Mutate(e); }); |
| } |
| |
| inline Array<IterVar> MutateIterVarArr(Array<IterVar> rdom, IRMutator *m) { |
| std::vector<IterVar> new_dom(rdom.size()); |
| bool changed = false; |
| for (size_t i = 0; i < rdom.size(); i++) { |
| IterVar v = rdom[i]; |
| Range r = v->dom; |
| Expr new_min = m->Mutate(r->min); |
| Expr new_extent = m->Mutate(r->extent); |
| if (!r->min.same_as(new_min)) changed = true; |
| if (!r->extent.same_as(new_extent)) changed = true; |
| new_dom[i] = IterVarNode::make( |
| Range::make_by_min_extent(new_min, new_extent), |
| v->var, v->iter_type, v->thread_tag); |
| } |
| if (!changed) { |
| return rdom; |
| } else { |
| return Array<IterVar>(new_dom); |
| } |
| } |
| |
| |
| // Mutate Stmt |
| |
// Registers a vtable_stmt entry that forwards a statement node of type OP to
// the matching IRMutator::Mutate_ overload.
#define DISPATCH_TO_MUTATE_STMT(OP)                                          \
  set_dispatch<OP>([](const OP* node, const Stmt& stmt, IRMutator* self) {   \
    return self->Mutate_(node, stmt);                                        \
  })
| |
| Stmt IRMutator::Mutate_(const AttrStmt* op, const Stmt& s) { |
| Expr value = this->Mutate(op->value); |
| Stmt body = this->Mutate(op->body); |
| if (value.same_as(op->value) && |
| body.same_as(op->body)) { |
| return s; |
| } else { |
| return AttrStmt::make(op->node, op->attr_key, value, body); |
| } |
| } |
| |
| Stmt IRMutator::Mutate_(const LetStmt *op, const Stmt& s) { |
| Expr value = this->Mutate(op->value); |
| Stmt body = this->Mutate(op->body); |
| if (value.same_as(op->value) && |
| body.same_as(op->body)) { |
| return s; |
| } else { |
| return LetStmt::make(op->var, value, body); |
| } |
| } |
| |
| Stmt IRMutator::Mutate_(const For *op, const Stmt& s) { |
| Expr min = this->Mutate(op->min); |
| Expr extent = this->Mutate(op->extent); |
| Stmt body = this->Mutate(op->body); |
| if (min.same_as(op->min) && |
| extent.same_as(op->extent) && |
| body.same_as(op->body)) { |
| return s; |
| } else { |
| return For::make( |
| op->loop_var, min, extent, op->for_type, op->device_api, body); |
| } |
| } |
| |
| Stmt IRMutator::Mutate_(const Allocate* op, const Stmt& s) { |
| IRMutator* m = this; |
| std::vector<Expr> new_extents; |
| bool all_extents_unmodified = true; |
| for (size_t i = 0; i < op->extents.size(); i++) { |
| new_extents.push_back(m->Mutate(op->extents[i])); |
| all_extents_unmodified &= new_extents[i].same_as(op->extents[i]); |
| } |
| Stmt body = m->Mutate(op->body); |
| Expr condition = m->Mutate(op->condition); |
| Expr new_expr; |
| if (op->new_expr.defined()) { |
| new_expr = m->Mutate(op->new_expr); |
| } |
| if (all_extents_unmodified && |
| body.same_as(op->body) && |
| condition.same_as(op->condition) && |
| new_expr.same_as(op->new_expr)) { |
| return s; |
| } else { |
| return Allocate::make( |
| op->buffer_var, op->type, |
| new_extents, condition, body, |
| new_expr, op->free_function); |
| } |
| } |
| |
| Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) { |
| Expr condition = this->Mutate(op->condition); |
| Stmt then_case = this->Mutate(op->then_case); |
| Stmt else_case; |
| if (op->else_case.defined()) { |
| else_case = this->Mutate(op->else_case); |
| } |
| if (condition.same_as(op->condition) && |
| then_case.same_as(op->then_case) && |
| else_case.same_as(op->else_case)) { |
| return s; |
| } else { |
| return IfThenElse::make(condition, then_case, else_case); |
| } |
| } |
| |
| Stmt IRMutator::Mutate_(const Store *op, const Stmt& s) { |
| Expr value = this->Mutate(op->value); |
| Expr index = this->Mutate(op->index); |
| Expr pred = this->Mutate(op->predicate); |
| if (value.same_as(op->value) && index.same_as(op->index) && pred.same_as(op->predicate)) { |
| return s; |
| } else { |
| return Store::make(op->buffer_var, value, index, pred); |
| } |
| } |
| |
| Stmt IRMutator::Mutate_(const Provide* op, const Stmt& s) { |
| auto new_args = MutateArray(op->args, this); |
| auto new_value = this->Mutate(op->value); |
| if (op->args.same_as(new_args) && op->value.same_as(new_value)) { |
| return s; |
| } else { |
| return Provide::make(op->func, op->value_index, new_value, new_args); |
| } |
| } |
| |
| Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) { |
| IRMutator* m = this; |
| HalideIR::Internal::Region new_bounds; |
| bool bounds_changed = false; |
| |
| // Mutate the bounds |
| for (size_t i = 0; i < op->bounds.size(); i++) { |
| Expr old_min = op->bounds[i]->min; |
| Expr old_extent = op->bounds[i]->extent; |
| Expr new_min = m->Mutate(old_min); |
| Expr new_extent = m->Mutate(old_extent); |
| if (!new_min.same_as(old_min)) bounds_changed = true; |
| if (!new_extent.same_as(old_extent)) bounds_changed = true; |
| new_bounds.push_back( |
| Range::make_by_min_extent(new_min, new_extent)); |
| } |
| |
| Stmt body = m->Mutate(op->body); |
| Expr condition = m->Mutate(op->condition); |
| if (!bounds_changed && |
| body.same_as(op->body) && |
| condition.same_as(op->condition)) { |
| return s; |
| } else { |
| return Realize::make(op->func, op->value_index, |
| op->type, new_bounds, |
| condition, body); |
| } |
| } |
| |
| Stmt IRMutator::Mutate_(const Prefetch* op, const Stmt& s) { |
| IRMutator* m = this; |
| HalideIR::Internal::Region new_bounds; |
| bool bounds_changed = false; |
| |
| // Mutate the bounds |
| for (size_t i = 0; i < op->bounds.size(); i++) { |
| Expr old_min = op->bounds[i]->min; |
| Expr old_extent = op->bounds[i]->extent; |
| Expr new_min = m->Mutate(old_min); |
| Expr new_extent = m->Mutate(old_extent); |
| if (!new_min.same_as(old_min)) bounds_changed = true; |
| if (!new_extent.same_as(old_extent)) bounds_changed = true; |
| new_bounds.push_back( |
| Range::make_by_min_extent(new_min, new_extent)); |
| } |
| |
| if (!bounds_changed) { |
| return s; |
| } else { |
| return Prefetch::make(op->func, op->value_index, |
| op->type, new_bounds); |
| } |
| } |
| |
| Stmt IRMutator::Mutate_(const Block* op, const Stmt& s) { |
| Stmt first = this->Mutate(op->first); |
| Stmt rest = this->Mutate(op->rest); |
| if (first.same_as(op->first) && |
| rest.same_as(op->rest)) { |
| return s; |
| } else { |
| return Block::make(first, rest); |
| } |
| } |
| |
| Stmt IRMutator::Mutate_(const AssertStmt *op, const Stmt& s) { |
| Expr condition = this->Mutate(op->condition); |
| Expr message = this->Mutate(op->message); |
| Stmt body = this->Mutate(op->body); |
| |
| if (condition.same_as(op->condition) && |
| message.same_as(op->message) && |
| body.same_as(op->body)) { |
| return s; |
| } else { |
| return AssertStmt::make(condition, message, body); |
| } |
| } |
| |
| Stmt IRMutator::Mutate_(const ProducerConsumer *op, const Stmt& s) { |
| Stmt body = this->Mutate(op->body); |
| if (body.same_as(op->body)) { |
| return s; |
| } else { |
| return ProducerConsumer::make(op->func, op->is_producer, body); |
| } |
| } |
| |
| Stmt IRMutator::Mutate_(const Evaluate *op, const Stmt& s) { |
| Expr v = this->Mutate(op->value); |
| if (v.same_as(op->value)) { |
| return s; |
| } else { |
| return Evaluate::make(v); |
| } |
| } |
| |
| Stmt IRMutator::Mutate_(const Free *op, const Stmt& s) { |
| return s; |
| } |
| |
| TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) |
| .DISPATCH_TO_MUTATE_STMT(LetStmt) |
| .DISPATCH_TO_MUTATE_STMT(AttrStmt) |
| .DISPATCH_TO_MUTATE_STMT(IfThenElse) |
| .DISPATCH_TO_MUTATE_STMT(For) |
| .DISPATCH_TO_MUTATE_STMT(Allocate) |
| .DISPATCH_TO_MUTATE_STMT(Store) |
| .DISPATCH_TO_MUTATE_STMT(Free) |
| .DISPATCH_TO_MUTATE_STMT(AssertStmt) |
| .DISPATCH_TO_MUTATE_STMT(ProducerConsumer) |
| .DISPATCH_TO_MUTATE_STMT(Provide) |
| .DISPATCH_TO_MUTATE_STMT(Realize) |
| .DISPATCH_TO_MUTATE_STMT(Block) |
| .DISPATCH_TO_MUTATE_STMT(Evaluate) |
| .DISPATCH_TO_MUTATE_STMT(Prefetch); |
| |
| |
| // Mutate Expr |
| |
// Registers a vtable_expr entry that forwards an expression node of type OP
// to the matching IRMutator::Mutate_ overload.
#define DISPATCH_TO_MUTATE_EXPR(OP)                                          \
  set_dispatch<OP>([](const OP* node, const Expr& expr, IRMutator* self) {   \
    return self->Mutate_(node, expr);                                        \
  })
| |
| Expr IRMutator::Mutate_(const Variable *op, const Expr& e) { |
| return e; |
| } |
| |
| Expr IRMutator::Mutate_(const Load *op, const Expr& e) { |
| Expr index = this->Mutate(op->index); |
| Expr pred = this->Mutate(op->predicate); |
| if (index.same_as(op->index) && pred.same_as(op->predicate)) { |
| return e; |
| } else { |
| return Load::make(op->type, op->buffer_var, index, pred); |
| } |
| } |
| |
| Expr IRMutator::Mutate_(const Let *op, const Expr& e) { |
| Expr value = this->Mutate(op->value); |
| Expr body = this->Mutate(op->body); |
| if (value.same_as(op->value) && |
| body.same_as(op->body)) { |
| return e; |
| } else { |
| return Let::make(op->var, value, body); |
| } |
| } |
| |
| Expr IRMutator::Mutate_(const Call* op, const Expr& e) { |
| auto new_args = MutateArray(op->args, this); |
| if (op->args.same_as(new_args)) { |
| return e; |
| } else { |
| return Call::make(op->type, op->name, new_args, op->call_type, |
| op->func, op->value_index); |
| } |
| } |
| |
// Defines Mutate_ for a binary node type OP with operands a and b. Both
// operands are mutated (a first), and the node is rebuilt only if one changed.
#define DEFINE_BIOP_EXPR_MUTATE_(OP)                         \
  Expr IRMutator::Mutate_(const OP* op, const Expr& e) {     \
    Expr lhs = this->Mutate(op->a);                          \
    Expr rhs = this->Mutate(op->b);                          \
    if (!lhs.same_as(op->a) || !rhs.same_as(op->b)) {        \
      return OP::make(lhs, rhs);                             \
    }                                                        \
    return e;                                                \
  }
| |
| DEFINE_BIOP_EXPR_MUTATE_(Add) |
| DEFINE_BIOP_EXPR_MUTATE_(Sub) |
| DEFINE_BIOP_EXPR_MUTATE_(Mul) |
| DEFINE_BIOP_EXPR_MUTATE_(Div) |
| DEFINE_BIOP_EXPR_MUTATE_(Mod) |
| DEFINE_BIOP_EXPR_MUTATE_(Min) |
| DEFINE_BIOP_EXPR_MUTATE_(Max) |
| DEFINE_BIOP_EXPR_MUTATE_(EQ) |
| DEFINE_BIOP_EXPR_MUTATE_(NE) |
| DEFINE_BIOP_EXPR_MUTATE_(LT) |
| DEFINE_BIOP_EXPR_MUTATE_(LE) |
| DEFINE_BIOP_EXPR_MUTATE_(GT) |
| DEFINE_BIOP_EXPR_MUTATE_(GE) |
| DEFINE_BIOP_EXPR_MUTATE_(And) |
| DEFINE_BIOP_EXPR_MUTATE_(Or) |
| |
| Expr IRMutator::Mutate_(const Reduce *op, const Expr& e) { |
| Array<IterVar> new_axis = MutateIterVarArr(op->axis, this); |
| Array<Expr> new_source = MutateArray(op->source, this); |
| Expr new_cond = this->Mutate(op->condition); |
| if (op->axis.same_as(new_axis) && |
| op->source.same_as(new_source) && |
| op->condition.same_as(new_cond)) { |
| return e; |
| } else { |
| return Reduce::make( |
| op->combiner, new_source, new_axis, new_cond, op->value_index); |
| } |
| } |
| |
| Expr IRMutator::Mutate_(const Cast *op, const Expr& e) { |
| Expr value = this->Mutate(op->value); |
| if (value.same_as(op->value)) { |
| return e; |
| } else { |
| return Cast::make(op->type, value); |
| } |
| } |
| |
| Expr IRMutator::Mutate_(const Not *op, const Expr& e) { |
| Expr a = this->Mutate(op->a); |
| if (a.same_as(op->a)) { |
| return e; |
| } else { |
| return Not::make(a); |
| } |
| } |
| |
| Expr IRMutator::Mutate_(const Select *op, const Expr& e) { |
| Expr cond = this->Mutate(op->condition); |
| Expr t = this->Mutate(op->true_value); |
| Expr f = this->Mutate(op->false_value); |
| if (cond.same_as(op->condition) && |
| t.same_as(op->true_value) && |
| f.same_as(op->false_value)) { |
| return e; |
| } else { |
| return Select::make(cond, t, f); |
| } |
| } |
| |
| Expr IRMutator::Mutate_(const Ramp *op, const Expr& e) { |
| Expr base = this->Mutate(op->base); |
| Expr stride = this->Mutate(op->stride); |
| if (base.same_as(op->base) && |
| stride.same_as(op->stride)) { |
| return e; |
| } else { |
| return Ramp::make(base, stride, op->lanes); |
| } |
| } |
| |
| Expr IRMutator::Mutate_(const Broadcast *op, const Expr& e) { |
| Expr value = this->Mutate(op->value); |
| if (value.same_as(op->value)) { |
| return e; |
| } else { |
| return Broadcast::make(value, op->lanes); |
| } |
| } |
| |
| Expr IRMutator::Mutate_(const Shuffle *op, const Expr& e) { |
| auto new_vec = MutateArray(op->vectors, this); |
| if (new_vec.same_as(op->vectors)) { |
| return e; |
| } else { |
| return Shuffle::make(new_vec, op->indices); |
| } |
| } |
| |
// Defines a Mutate_ overload for node type OP that visits nothing and returns
// the incoming expression as-is. Used for the immediate constant nodes.
#define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP)                \
  Expr IRMutator::Mutate_(const OP *op, const Expr& e) {      \
    return e;                                                 \
  }
| |
| DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImm) |
| DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImm) |
| DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImm) |
| DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImm) |
| |
| TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) |
| .DISPATCH_TO_MUTATE_EXPR(Variable) |
| .DISPATCH_TO_MUTATE_EXPR(Load) |
| .DISPATCH_TO_MUTATE_EXPR(Let) |
| .DISPATCH_TO_MUTATE_EXPR(Call) |
| .DISPATCH_TO_MUTATE_EXPR(Add) |
| .DISPATCH_TO_MUTATE_EXPR(Sub) |
| .DISPATCH_TO_MUTATE_EXPR(Mul) |
| .DISPATCH_TO_MUTATE_EXPR(Div) |
| .DISPATCH_TO_MUTATE_EXPR(Mod) |
| .DISPATCH_TO_MUTATE_EXPR(Min) |
| .DISPATCH_TO_MUTATE_EXPR(Max) |
| .DISPATCH_TO_MUTATE_EXPR(EQ) |
| .DISPATCH_TO_MUTATE_EXPR(NE) |
| .DISPATCH_TO_MUTATE_EXPR(LT) |
| .DISPATCH_TO_MUTATE_EXPR(LE) |
| .DISPATCH_TO_MUTATE_EXPR(GT) |
| .DISPATCH_TO_MUTATE_EXPR(GE) |
| .DISPATCH_TO_MUTATE_EXPR(And) |
| .DISPATCH_TO_MUTATE_EXPR(Or) |
| .DISPATCH_TO_MUTATE_EXPR(Reduce) |
| .DISPATCH_TO_MUTATE_EXPR(Cast) |
| .DISPATCH_TO_MUTATE_EXPR(Not) |
| .DISPATCH_TO_MUTATE_EXPR(Select) |
| .DISPATCH_TO_MUTATE_EXPR(Ramp) |
| .DISPATCH_TO_MUTATE_EXPR(Broadcast) |
| .DISPATCH_TO_MUTATE_EXPR(IntImm) |
| .DISPATCH_TO_MUTATE_EXPR(UIntImm) |
| .DISPATCH_TO_MUTATE_EXPR(FloatImm) |
| .DISPATCH_TO_MUTATE_EXPR(StringImm) |
| .DISPATCH_TO_MUTATE_EXPR(Shuffle); |
| |
| } // namespace ir |
| } // namespace tvm |