| /*! |
| * Copyright (c) 2017 by Contributors |
| * \file loop_partition.cc |
| */ |
| #include <tvm/ir.h> |
| #include <tvm/ir_visitor.h> |
| #include <tvm/ir_mutator.h> |
| #include <tvm/ir_pass.h> |
| #include <tvm/arithmetic.h> |
| #include <unordered_map> |
| #include <unordered_set> |
| #include "../arithmetic/int_set_internal.h" |
| #include "../runtime/thread_storage_scope.h" |
| |
| namespace tvm { |
| namespace ir { |
| |
| using arith::IntSet; |
| using arith::DeduceBound; |
| using arith::Intersect; |
| |
| // a partition means the expr is equal to true in the interval |
| struct Partition { |
| Expr expr; |
| IntSet interval; |
| }; |
| |
| bool ExprUseVars(Expr expr, const std::unordered_set<const Variable*>& vars) { |
| bool success = false; |
| PostOrderVisit(expr, [&vars, &success](const NodeRef& node) { |
| if (const Variable* v = node.as<Variable>()) { |
| if (vars.count(v)) { |
| success = true; |
| return; |
| } |
| } |
| }); |
| return success; |
| } |
| |
| // Select potential candidate IRs that can be partitioned. |
| // Rule: |
| // - the range should not be const |
| // - there exist a condition expression in the scope that use the var |
| class CandidateSelector final : public IRVisitor { |
| public: |
| using VarIsUsed = bool; |
| explicit CandidateSelector(bool split_const_loop) |
| : split_const_loop_(split_const_loop) {} |
| |
| void Visit_(const For* op) { |
| // partition const loop when sets split_const_loop_ |
| if (!is_const(op->min) || !is_const(op->extent) || split_const_loop_) { |
| const Variable* var = op->loop_var.get(); |
| record_.insert({var, false}); |
| IRVisitor::Visit_(op); |
| if (record_.at(var) && !no_split_) { |
| candidates.insert(op); |
| } |
| record_.erase(var); |
| } else { |
| IRVisitor::Visit_(op); |
| } |
| } |
| |
| void Visit_(const AttrStmt* op) { |
| if (op->attr_key == attr::thread_extent) { |
| const IterVarNode *iv = op->node.as<IterVarNode>(); |
| CHECK(iv); |
| Var var = iv->var; |
| runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag); |
| if ((scope.rank == 0) && (!is_const(op->value) || split_const_loop_)) { |
| record_.insert({var.get(), false}); |
| IRVisitor::Visit_(op); |
| if (record_.at(var.get()) && !no_split_) { |
| candidates.insert(op); |
| } |
| record_.erase(var.get()); |
| return; |
| } |
| } |
| IRVisitor::Visit_(op); |
| } |
| |
| void Visit_(const Block* op) { |
| bool temp = no_split_; |
| this->Visit(op->first); |
| // erase the no split state of first when visit rest. |
| std::swap(temp, no_split_); |
| this->Visit(op->rest); |
| // restore the no split flag. |
| no_split_ = no_split_ || temp; |
| } |
| |
| void Visit_(const Call* op) { |
| if (op->is_intrinsic(Call::likely)) { |
| in_likely_ = true; |
| IRVisitor::Visit_(op); |
| in_likely_ = false; |
| } else if (op->is_intrinsic(intrinsic::tvm_thread_allreduce)) { |
| // no split if the body contains allreduce. |
| no_split_ = true; |
| return; |
| } else { |
| IRVisitor::Visit_(op); |
| } |
| } |
| |
| void Visit_(const Variable* op) { |
| if (in_likely_ && record_.count(op)) { |
| record_.at(op) = true; |
| } |
| } |
| |
| std::unordered_set<const Node*> candidates; |
| |
| private: |
| bool in_likely_{false}; |
| bool no_split_{false}; |
| bool split_const_loop_{false}; |
| std::unordered_map<const Variable*, VarIsUsed> record_; |
| }; |
| |
| // Find valid partition for specific variable |
| class PartitionFinder : public IRVisitor { |
| public: |
| explicit PartitionFinder(VarExpr current_var, |
| const std::unordered_map<const Variable*, IntSet>& hint_map, |
| const std::unordered_map<const Variable*, IntSet>& relax_map) |
| : current_var_(current_var), hint_map_(hint_map), relax_map_(relax_map) { |
| for (const auto& kv : hint_map) { |
| out_vars_.insert(kv.first); |
| } |
| for (const auto& kv : relax_map) { |
| out_vars_.insert(kv.first); |
| } |
| } |
| |
| void Visit_(const For* op) { |
| if (ExprUseVars(op->min, out_vars_) || ExprUseVars(op->extent, out_vars_)) return; |
| |
| const Variable* var = op->loop_var.get(); |
| hint_map_.insert({var, IntSet::interval(op->min, op->min + op->extent - 1)}); |
| relax_map_.insert({var, IntSet::interval(op->min, op->min + op->extent - 1)}); |
| IRVisitor::Visit_(op); |
| relax_map_.erase(var); |
| hint_map_.erase(var); |
| } |
| |
| void Visit_(const AttrStmt* op) { |
| // handle thread_axis |
| if (op->attr_key == attr::thread_extent) { |
| const IterVarNode* thread_axis = op->node.as<IterVarNode>(); |
| CHECK(thread_axis); |
| const Variable* var = thread_axis->var.get(); |
| IntSet dom = IntSet::range(Range(make_zero(op->value.type()), op->value)); |
| hint_map_.insert({var, dom}); |
| relax_map_.insert({var, dom}); |
| IRVisitor::Visit_(op); |
| relax_map_.erase(var); |
| hint_map_.erase(var); |
| } else { |
| IRVisitor::Visit_(op); |
| } |
| } |
| |
| void Visit_(const Call* op) { |
| if (op->is_intrinsic(Call::likely)) { |
| Expr cond = op->args[0]; |
| if (ExprUseVars(cond, |
| std::unordered_set<const Variable*>({current_var_.get()}))) { |
| IntSet interval = |
| DeduceBound(current_var_, cond, hint_map_, relax_map_); |
| if (!interval.is_nothing()) { |
| partitions[cond.get()] = Partition{cond, interval}; |
| } |
| } |
| } else { |
| IRVisitor::Visit_(op); |
| } |
| } |
| |
| std::unordered_map<const Node*, Partition> partitions; |
| |
| private: |
| VarExpr current_var_; |
| std::unordered_set<const Variable*> out_vars_; |
| std::unordered_map<const Variable*, IntSet> hint_map_; |
| std::unordered_map<const Variable*, IntSet> relax_map_; |
| }; |
| |
| // Eliminate the condition expressions by partitions |
| class ConditionEliminator : public IRMutator { |
| public: |
| explicit ConditionEliminator(const std::unordered_map<const Node*, Partition>& ps) |
| : ps_(ps) {} |
| |
| using IRMutator::Mutate; |
| Expr Mutate(Expr e) final { |
| if (ps_.count(e.get())) return Mutate(const_true()); |
| return IRMutator::Mutate(e); |
| } |
| |
| private: |
| const std::unordered_map<const Node*, Partition>& ps_; |
| }; |
| |
| |
| // Insert the partition branch at the innermost thread scope |
| class ThreadPartitionInserter : public IRMutator { |
| public: |
| explicit ThreadPartitionInserter(const std::unordered_map<const Node*, Partition>& ps, |
| Expr cond) : ps_(ps), cond_(cond), innermost_thread_scope_(false) {} |
| |
| Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { |
| if (op->attr_key == attr::thread_extent) { |
| innermost_thread_scope_ = true; |
| Stmt stmt = IRMutator::Mutate_(op, s); |
| // add branch code inside the innermost thread scope |
| if (innermost_thread_scope_) { |
| Stmt simplified_body = ConditionEliminator(ps_).Mutate(op->body); |
| Stmt body = IfThenElse::make(cond_, simplified_body, op->body); |
| Expr value = this->Mutate(op->value); |
| stmt = AttrStmt::make(op->node, op->attr_key, value, body); |
| } |
| innermost_thread_scope_ = false; |
| return stmt; |
| } else { |
| return IRMutator::Mutate_(op, s); |
| } |
| } |
| |
| private: |
| const std::unordered_map<const Node*, Partition>& ps_; |
| Expr cond_; |
| bool innermost_thread_scope_; |
| }; |
| |
| // Try to do partition at the candidate IRs |
| class LoopPartitioner : public IRMutator { |
| public: |
| explicit LoopPartitioner(bool split_const_loop) |
| : selector(CandidateSelector(split_const_loop)) {} |
| |
| Stmt VisitAndMutate(const Stmt& stmt) { |
| selector.Visit(stmt); |
| return Mutate(stmt); |
| } |
| |
| Stmt Mutate_(const For* op, const Stmt& stmt) { |
| if (selector.candidates.count(op)) { |
| Stmt s = TryPartition(op, stmt, op->loop_var, |
| op->min, op->min + op->extent - 1, op->body, false); |
| if (s.defined()) return s; |
| } |
| |
| // normal path when loop parittion fails |
| // normal loop variable can be put into hint map. |
| hint_map_.insert({op->loop_var.get(), |
| IntSet::interval(op->min, op->min + op->extent - 1)}); |
| Stmt res = IRMutator::Mutate_(op, stmt); |
| hint_map_.erase(op->loop_var.get()); |
| return res; |
| } |
| |
| Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) { |
| if (op->attr_key != attr::thread_extent) { |
| return IRMutator::Mutate_(op, stmt); |
| } |
| |
| const IterVarNode *iv = op->node.as<IterVarNode>(); |
| CHECK(iv); |
| Var var = iv->var; |
| if (selector.candidates.count(op)) { |
| Stmt s = TryPartition(op, stmt, var, 0, op->value - 1, op->body, true); |
| if (s.defined()) return s; |
| } |
| |
| // normal path when loop parittion fails. |
| runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag); |
| Stmt res; |
| if (scope.rank == 1) { |
| // threadIdx should be put into relax map, in case of divergence. |
| relax_map_.insert({var.get(), |
| IntSet::interval(make_zero(var.type()), op->value - 1)}); |
| res = IRMutator::Mutate_(op, stmt); |
| relax_map_.erase(var.get()); |
| } else { |
| hint_map_.insert({var.get(), |
| IntSet::interval(make_zero(var.type()), op->value - 1)}); |
| res = IRMutator::Mutate_(op, stmt); |
| hint_map_.erase(var.get()); |
| } |
| return res; |
| } |
| |
| private: |
| Stmt TryPartition(const Node* op, const Stmt& stmt, VarExpr var, |
| Expr min, Expr max, Stmt body, bool partition_thread_scope); |
| inline Stmt MakeFor(const Node* op, Expr extent, Stmt body); |
| |
| /* Candidate IRs that may be partitioned potentially */ |
| std::unordered_map<const Variable*, IntSet> hint_map_; |
| std::unordered_map<const Variable*, IntSet> relax_map_; |
| CandidateSelector selector; |
| }; |
| |
| Stmt LoopPartitioner::TryPartition(const Node* node, |
| const Stmt& stmt, |
| VarExpr var, |
| Expr min, |
| Expr max, |
| Stmt body, |
| bool partition_thread_scope) { |
| PartitionFinder finder(var, hint_map_, relax_map_); |
| finder.Visit(body); |
| const auto& partitions = finder.partitions; |
| if (partitions.empty()) return Stmt(); |
| |
| Array<IntSet> sets; |
| // merge partitions (take their intersect) |
| for (const auto& kv : partitions) { |
| sets.push_back(kv.second.interval); |
| } |
| IntSet true_itrv = Intersect(sets); |
| |
| Expr body_begin; |
| Stmt pre_stmt; |
| arith::Interval true_itrv_i = true_itrv.as<arith::IntervalSet>()->i; |
| if (true_itrv_i.has_lower_bound()) { |
| body_begin = ir::Simplify(true_itrv.min()); |
| if (!can_prove(body_begin == min)) { |
| Expr cond = (body_begin - min >= 0); |
| if (!can_prove(cond)) { |
| LOG(WARNING) << "Cannot prove: " << cond |
| << ", when generating the pre doubt loop"; |
| body_begin = Max::make(body_begin, min); |
| } |
| // [min, body_begin) |
| if (!partition_thread_scope) { |
| Stmt pre_body = Substitute(body, {{Var{var}, var + min}}); |
| pre_stmt = MakeFor(node, body_begin - min, pre_body); |
| } |
| } |
| } else { |
| body_begin = min; |
| } |
| |
| Expr post_doubt_begin; |
| Stmt post_stmt; |
| if (true_itrv_i.has_upper_bound()) { |
| post_doubt_begin = ir::Simplify(true_itrv.max() + 1); |
| if (!can_prove(true_itrv.max() == max)) { |
| // require the extent to be non-negative |
| Expr cond = (max - post_doubt_begin + 1 >= 0); |
| if (!can_prove(cond)) { |
| LOG(WARNING) << "Cannot prove: " << cond |
| << ", when generating the post doubt loop"; |
| post_doubt_begin = Min::make(post_doubt_begin, max); |
| } |
| // [post_doubt_begin, max] |
| if (!partition_thread_scope) { |
| Stmt post_body; |
| // If the loop is going from 0 to 1, replace the loop var with min value |
| if (as_const_int(max) && as_const_int(post_doubt_begin)) { |
| if (*as_const_int(max) == *as_const_int(post_doubt_begin)) { |
| post_body = Substitute(body, {{Var{var}, post_doubt_begin}}); |
| post_stmt = post_body; |
| } |
| } else { |
| post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}}); |
| post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body); |
| } |
| } |
| } |
| } else { |
| post_doubt_begin = max + 1; |
| } |
| |
| Stmt s; |
| if (!partition_thread_scope) { |
| // [body_begin, post_doubt_begin) |
| Stmt simplified_body = ConditionEliminator(partitions).Mutate(body); |
| Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}}); |
| s = MakeFor(node, post_doubt_begin - body_begin, new_body); |
| |
| if (!(pre_stmt.defined() && post_stmt.defined())) s = VisitAndMutate(s); |
| if (pre_stmt.defined()) s = Block::make(pre_stmt, s); |
| if (post_stmt.defined()) { |
| if (as_const_int(max) && as_const_int(post_doubt_begin)) { |
| post_stmt = VisitAndMutate(post_stmt); |
| } |
| s = Block::make(s, post_stmt); |
| } |
| } else { |
| Expr cond = const_true(); |
| if (!can_prove(body_begin == min)) cond = cond && (var >= body_begin); |
| if (!can_prove(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin); |
| s = ThreadPartitionInserter(partitions, cond).Mutate(stmt); |
| } |
| s = ConvertSSA(s); |
| return s; |
| } |
| |
| inline Stmt LoopPartitioner::MakeFor(const Node *node, Expr extent, Stmt body) { |
| const For *for_node = static_cast<const For*>(node); |
| CHECK(for_node); |
| return For::make(for_node->loop_var, 0, extent, |
| for_node->for_type, for_node->device_api, body); |
| } |
| |
| class RemoveLikelyTags : public IRMutator { |
| public: |
| using IRMutator::Mutate; |
| |
| Expr Mutate_(const Call *op, const Expr& e) { |
| if (op->is_intrinsic(Call::likely)) { |
| CHECK_EQ(op->args.size(), 1); |
| return IRMutator::Mutate(op->args[0]); |
| } else { |
| return IRMutator::Mutate_(op, e); |
| } |
| } |
| }; |
| |
| Stmt LoopPartition(Stmt stmt, bool split_const_loop) { |
| stmt = LoopPartitioner(split_const_loop).VisitAndMutate(stmt); |
| stmt = RemoveLikelyTags().Mutate(stmt); |
| return stmt; |
| } |
| |
| } // namespace ir |
| } // namespace tvm |