|  | /*! | 
|  | *  Copyright (c) 2017 by Contributors | 
|  | *  Loop unrolling as in Halide pipeline. | 
|  | * \file unroll_loop.cc | 
|  | */ | 
|  | // Unrolls the loop as in Halide pipeline. | 
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <algorithm>
#include <cstdint>
#include <limits>
#include <unordered_set>
#include <unordered_map>
#include <utility>
#include <vector>
#include "../arithmetic/compute_expr.h"
|  |  | 
|  | namespace tvm { | 
|  | namespace ir { | 
|  |  | 
|  | class LoopUnroller : public IRMutator { | 
|  | public: | 
|  | explicit LoopUnroller(int auto_max_step, | 
|  | int auto_max_depth, | 
|  | int auto_max_extent, | 
|  | bool explicit_unroll) | 
|  | : auto_max_step_(auto_max_step), | 
|  | auto_max_depth_(auto_max_depth), | 
|  | auto_max_extent_(auto_max_extent), | 
|  | explicit_unroll_(explicit_unroll) { | 
|  | } | 
|  |  | 
|  | Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) final { | 
|  | if (op->attr_key == "pragma_auto_unroll_max_step") { | 
|  | int value = 0; | 
|  | CHECK(arith::GetConstInt(op->value, &value)); | 
|  | std::swap(value, auto_max_step_); | 
|  | Stmt ret = this->Mutate(op->body); | 
|  | std::swap(value, auto_max_step_); | 
|  | return ret; | 
|  | } else if (op->attr_key == "pragma_unroll_explicit") { | 
|  | int value = 0; | 
|  | CHECK(arith::GetConstInt(op->value, &value)); | 
|  | bool explicit_unroll = value; | 
|  | std::swap(explicit_unroll, explicit_unroll_); | 
|  | Stmt ret = this->Mutate(op->body); | 
|  | std::swap(explicit_unroll, explicit_unroll_); | 
|  | return ret; | 
|  | } else { | 
|  | return IRMutator::Mutate_(op, stmt); | 
|  | } | 
|  | } | 
|  |  | 
|  | Stmt Mutate_(const For* op, const Stmt& s) { | 
|  | Stmt stmt = IRMutator::Mutate_(op, s); | 
|  | op = stmt.as<For>(); | 
|  | int value = GetExtent(op); | 
|  | // condition for auto unroll | 
|  | bool auto_unroll = ( | 
|  | op->for_type == ForType::Serial && | 
|  | value >= 0 && | 
|  | normal_loop_depth_ == 0 && | 
|  | unroll_depth_ <= auto_max_depth_); | 
|  |  | 
|  | auto_unroll = auto_unroll && ( | 
|  | value * step_count_ <= auto_max_step_|| | 
|  | value <= auto_max_extent_); | 
|  |  | 
|  | if (op->for_type == ForType::Unrolled) { | 
|  | CHECK_GE(value, 0) | 
|  | << "Cannot unroll non-constant loop"; | 
|  | auto_unroll = true; | 
|  | } | 
|  |  | 
|  | if (auto_unroll) { | 
|  | step_count_  *=  value; | 
|  | unroll_depth_ += 1; | 
|  | } else { | 
|  | normal_loop_depth_ += 1; | 
|  | } | 
|  |  | 
|  | if ((auto_unroll && explicit_unroll_) || | 
|  | // unroll loops with extent = 1, no matter how many steps in body | 
|  | (value <= auto_max_extent_ && auto_max_extent_ == 1)) { | 
|  | return Unroll(op); | 
|  | } else { | 
|  | if (auto_unroll) { | 
|  | if (op->for_type != ForType::Unrolled) { | 
|  | return For::make( | 
|  | op->loop_var, op->min, op->extent, | 
|  | ForType::Unrolled, op->device_api, op->body); | 
|  | } | 
|  | } | 
|  | return stmt; | 
|  | } | 
|  | } | 
|  |  | 
|  | Stmt Mutate_(const Store* op, const Stmt& stmt) final { | 
|  | ++step_count_; | 
|  | return IRMutator::Mutate_(op, stmt); | 
|  | } | 
|  |  | 
|  | Stmt Mutate_(const Evaluate* op, const Stmt& stmt) final { | 
|  | ++step_count_; | 
|  | return IRMutator::Mutate_(op, stmt); | 
|  | } | 
|  |  | 
|  | Stmt Mutate_(const Block* op, const Stmt& stmt) final { | 
|  | Stmt first = this->Mutate(op->first); | 
|  | // cleanup state | 
|  | int step_count = step_count_; | 
|  | int unroll_depth = unroll_depth_; | 
|  | int normal_loop_depth = normal_loop_depth_; | 
|  | step_count_ = 0; | 
|  | unroll_depth_ = 0; | 
|  | normal_loop_depth_ = 0; | 
|  | // work on rest part | 
|  | Stmt rest = this->Mutate(op->rest); | 
|  | step_count_ += step_count; | 
|  | normal_loop_depth_ = std::max(normal_loop_depth, normal_loop_depth_); | 
|  | unroll_depth_ = std::max(unroll_depth_, unroll_depth); | 
|  | if (first.same_as(op->first) && | 
|  | rest.same_as(op->rest)) { | 
|  | return stmt; | 
|  | } else { | 
|  | return Block::make(first, rest); | 
|  | } | 
|  | } | 
|  |  | 
|  | Stmt Unroll(const For* op) { | 
|  | using arith::ComputeExpr; | 
|  | int value = GetExtent(op); | 
|  | // For loop must have a constant integer extent | 
|  | CHECK_NE(value, -1) << "loop doesn't have a constant integer extent"; | 
|  | if (value == 0) return Evaluate::make(0); | 
|  | Stmt body = op->body; | 
|  | Map<Var, Expr> vmap; | 
|  | Stmt unrolled; | 
|  | for (int i = 0; i < value; ++i) { | 
|  | Var lv(op->loop_var.node_); | 
|  | vmap.Set(lv, | 
|  | ComputeExpr<Add>( | 
|  | op->min, make_const(op->loop_var.type(), i))); | 
|  | Stmt step = Substitute(body, vmap); | 
|  | if (unrolled.defined()) { | 
|  | unrolled = Block::make(unrolled, step); | 
|  | } else { | 
|  | unrolled = step; | 
|  | } | 
|  | } | 
|  | return unrolled; | 
|  | } | 
|  |  | 
|  | private: | 
|  | // returns the extent of the loop if it's a constant integer, otherwise return -1 | 
|  | int GetExtent(const For* op) { | 
|  | // constant folding. | 
|  | Expr extent = ir::Simplify(op->extent); | 
|  | const IntImm  *v1 = extent.as<IntImm>(); | 
|  | const UIntImm *v2 = extent.as<UIntImm>(); | 
|  | int value = -1; | 
|  | if (v1 != nullptr) { | 
|  | value = static_cast<int>(v1->value); | 
|  | } | 
|  | if (v2 != nullptr) { | 
|  | value = static_cast<int>(v2->value); | 
|  | } | 
|  | return value; | 
|  | } | 
|  |  | 
|  | // maximum number of step to perform auto unroll. | 
|  | int auto_max_step_; | 
|  | int auto_max_depth_; | 
|  | // max extent of loop to auto unroll | 
|  | // this not not count the total steps, only count the number of loops | 
|  | int auto_max_extent_; | 
|  | bool explicit_unroll_; | 
|  | // Number of normal loops in scope | 
|  | int normal_loop_depth_{0}; | 
|  | // number of unrolled cases in current scope. | 
|  | int unroll_depth_{0}; | 
|  | // Number of total steps unrolled | 
|  | int step_count_{0}; | 
|  | }; | 
|  |  | 
|  |  | 
|  | Stmt UnrollLoop(Stmt stmt, | 
|  | int auto_max_step, | 
|  | int auto_max_depth, | 
|  | int auto_max_extent, | 
|  | bool explicit_unroll) { | 
|  | Stmt ret = LoopUnroller( | 
|  | auto_max_step, | 
|  | auto_max_depth, | 
|  | auto_max_extent, | 
|  | explicit_unroll).Mutate(stmt); | 
|  | if (!ret.same_as(stmt)) { | 
|  | return ConvertSSA(ret); | 
|  | } else { | 
|  | return ret; | 
|  | } | 
|  | } | 
|  |  | 
|  | Stmt UnrollLoopExplicitly(Stmt stmt) { | 
|  | const For* op = stmt.as<For>(); | 
|  | if (!op) { | 
|  | LOG(FATAL) << "attempted to unroll a non-loop statement"; | 
|  | } | 
|  | return LoopUnroller(0, 0, 0, false).Unroll(op); | 
|  | } | 
|  |  | 
|  | }  // namespace ir | 
|  | }  // namespace tvm |