blob: 23f41e1676a63db69ceb785eb4d0a1596ef374d8 [file] [log] [blame]
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file loop_partition.cc
 */
#include <tvm/arith/analyzer.h>
#include <tvm/arith/bound.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_map>
#include <unordered_set>
#include "../../arith/interval_set.h"
#include "../../runtime/thread_storage_scope.h"
#include "ir_util.h"
namespace tvm {
namespace tir {
/*!
 * \brief Attribute node holding the options of the tir.LoopPartition pass.
 */
struct LoopPartitionConfigNode : public tvm::AttrsNode<LoopPartitionConfigNode> {
  /*!
   * \brief When true, loops whose min and extent are both constant integers
   *        are also considered for partitioning (default: false).
   */
  bool partition_const_loop;

  TVM_DECLARE_ATTRS(LoopPartitionConfigNode, "tir.transform.LoopPartitionConfig") {
    TVM_ATTR_FIELD(partition_const_loop).describe("Split constant loop").set_default(false);
  }
};
class LoopPartitionConfig : public Attrs {
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LoopPartitionConfig, Attrs, LoopPartitionConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.LoopPartition", LoopPartitionConfig);
// Arithmetic-analysis helpers used throughout this pass.
using arith::DeduceBound;
using arith::Intersect;
using arith::IntSet;
// Key for one proven fact about a likely() condition:
// (condition expression, value the condition is proven to take).
using PartitionKey = std::pair<PrimExpr, bool>;
/*!
 * \brief Hash functor for PartitionKey.
 *
 * Combines ObjectPtrHash of the condition expression with std::hash of the
 * boolean value, so (cond, true) and (cond, false) hash differently.
 */
struct PartitionKeyHash {
  std::size_t operator()(PartitionKey const& k) const noexcept {
    std::size_t h1 = ObjectPtrHash{}(k.first);  // NOLINT(whitespace/braces)
    std::size_t h2 = std::hash<bool>{}(k.second);
    return h1 ^ h2;
  }
};
/*!
 * \brief Equality functor for PartitionKey.
 *
 * Two keys are equal when their boolean values match and their condition
 * expressions compare equal under ObjectPtrEqual.
 */
struct PartitionKeyEqual {
  bool operator()(const PartitionKey& k1, const PartitionKey& k2) const {
    // NOLINTNEXTLINE(whitespace/braces)
    return k1.second == k2.second && ObjectPtrEqual{}(k1.first, k2.first);
  }
};
// Each mapping (cond, cond_value) -> interval represents the fact that
// condition cond is proven to have value cond_value (true or false) in interval.
using Partition = std::unordered_map<PartitionKey, IntSet, PartitionKeyHash, PartitionKeyEqual>;
// A set of condition expressions, hashed/compared with ObjectPtrHash and
// ObjectPtrEqual (NOTE(review): i.e. by object pointer rather than structurally).
using ExpressionSet = std::unordered_set<PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
/*!
 * \brief Check whether an expression references any of the given variables.
 * \param expr The expression to scan; every node is visited with PostOrderVisit.
 * \param vars The variables to look for, identified by their VarNode pointers.
 * \return true if some VarNode reachable from expr is contained in vars.
 */
bool ExprUseVars(PrimExpr expr, const std::unordered_set<const VarNode*>& vars) {
  bool success = false;
  PostOrderVisit(expr, [&vars, &success](const ObjectRef& node) {
    // Once a match has been found there is nothing left to learn.
    if (success) return;
    if (const VarNode* v = node.as<VarNode>()) {
      if (vars.count(v)) {
        success = true;
      }
    }
  });
  return success;
}
// Select potential candidate IRs that can be partitioned.
// Rule:
// - the range should not be const
// - there exist a condition expression in the scope that use the var
class CandidateSelector final : public StmtExprVisitor {
using VarIsUsed = bool;
explicit CandidateSelector(bool partition_const_loop)
: partition_const_loop_(partition_const_loop) {}
void VisitStmt_(const ForNode* op) final {
// partition const loop when sets partition_const_loop_
if (!is_const_int(op->min) || !is_const_int(op->extent) || partition_const_loop_) {
const VarNode* var = op->loop_var.get();
record_.insert({var, false});
if ( && !no_split_) {
} else {
void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) {
const IterVarNode* iv = op-><IterVarNode>();
Var var = iv->var;
runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag);
if ((scope.rank == 0) && (!is_const_int(op->value) || partition_const_loop_)) {
record_.insert({var.get(), false});
if ( && !no_split_) {
void VisitStmt_(const SeqStmtNode* op) final {
bool init_no_split = no_split_;
for (Stmt stmt : op->seq) {
// erase the no split state of before visiting the next one.
bool temp = init_no_split;
std::swap(temp, no_split_);
// restore the no split flag.
no_split_ = no_split_ || temp;
void VisitExpr_(const CallNode* op) final {
if (op->op.same_as(builtin::likely())) {
in_likely_ = true;
in_likely_ = false;
} else if (op->op.same_as(builtin::tvm_thread_allreduce())) {
// no split if the body contains allreduce.
no_split_ = true;
} else {
void VisitExpr_(const VarNode* op) final {
if (in_likely_ && record_.count(op)) { = true;
std::unordered_set<Stmt, ObjectPtrHash, ObjectPtrEqual> candidates;
bool in_likely_{false};
bool no_split_{false};
bool partition_const_loop_{false};
std::unordered_map<const VarNode*, VarIsUsed> record_;
// Populate partitions data structure, i.e., for a specific variable,
// find an interval in which each condition
// (currently, "likely" conditions) has fixed true or false value
class PartitionFinder : public StmtExprVisitor {
explicit PartitionFinder(Var current_var,
const std::unordered_map<const VarNode*, IntSet>& hint_map,
const std::unordered_map<const VarNode*, IntSet>& relax_map)
: current_var_(current_var), hint_map_(hint_map), relax_map_(relax_map) {
for (const auto& kv : hint_map) {
for (const auto& kv : relax_map) {
void VisitStmt_(const ForNode* op) final {
if (ExprUseVars(op->min, out_vars_) || ExprUseVars(op->extent, out_vars_)) return;
const VarNode* var = op->loop_var.get();
hint_map_.insert({var, IntSet::Interval(op->min, op->min + op->extent - 1)});
relax_map_.insert({var, IntSet::Interval(op->min, op->min + op->extent - 1)});
void VisitStmt_(const AttrStmtNode* op) final {
// handle thread_axis
if (op->attr_key == attr::thread_extent) {
const IterVarNode* thread_axis = op-><IterVarNode>();
const VarNode* var = thread_axis->var.get();
IntSet dom = IntSet::FromRange(Range(make_zero(op->value.dtype()), op->value));
hint_map_.insert({var, dom});
relax_map_.insert({var, dom});
} else {
void VisitExpr_(const CallNode* op) final {
if (op->op.same_as(builtin::likely())) {
PrimExpr cond = op->args[0];
if (ExprUseVars(cond, std::unordered_set<const VarNode*>({current_var_.get()}))) {
// For cond, find out the interval, if exists, in which we can prove that cond is
// true. Also find the interval, if exists, in which we can prove that cond is
// false.
IntSet interval = DeduceBound(current_var_, cond, hint_map_, relax_map_);
if (!interval.IsNothing()) {
// cond is true within interval
partitions[{cond, true}] = interval;
PrimExpr inverse_cond = InverseCond(cond);
if (inverse_cond.defined()) {
IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_);
if (!interval.IsNothing()) {
// cond is false within interval
partitions[{cond, false}] = interval;
} else {
Partition partitions;
PrimExpr InverseCond(const PrimExpr& cond) {
PrimExpr inverse_cond;
if (const LTNode* op =<LTNode>()) {
// a < b -> a >= b
inverse_cond = GE(op->a, op->b);
} else if (const GTNode* op =<GTNode>()) {
// a > b -> a <= b
inverse_cond = LE(op->a, op->b);
} else if (const LENode* op =<LENode>()) {
// a <= b -> a > b
inverse_cond = GT(op->a, op->b);
} else if (const GENode* op =<GENode>()) {
// a >= b -> a < b
inverse_cond = LT(op->a, op->b);
} else if (const EQNode* op =<EQNode>()) {
// a == b -> a != b
inverse_cond = NE(op->a, op->b);
// a != b -> a == b
} else if (const NENode* op =<NENode>()) {
inverse_cond = EQ(op->a, op->b);
return inverse_cond;
Var current_var_;
std::unordered_set<const VarNode*> out_vars_;
std::unordered_map<const VarNode*, IntSet> hint_map_;
std::unordered_map<const VarNode*, IntSet> relax_map_;
// Replace the set of conditions given by ps with cond_value (true or false).
// Matching is done with the ExpressionSet's hash/equality (ObjectPtrHash /
// ObjectPtrEqual); matched expressions become const_true()/const_false().
class ConditionEliminator : public StmtExprMutator {
 public:
  explicit ConditionEliminator(const ExpressionSet& ps, bool cond_value = true)
      : ps_(ps), cond_value_(cond_value) {}

  PrimExpr VisitExpr(const PrimExpr& e) final {
    if (ps_.find(e) != ps_.end()) {
      return VisitExpr(cond_value_ ? const_true() : const_false());
    }
    return StmtExprMutator::VisitExpr(e);
  }

 private:
  ExpressionSet ps_;
  bool cond_value_;
};
// Insert the partition branch at the innermost thread scope
class ThreadPartitionInserter : public StmtMutator {
explicit ThreadPartitionInserter(const ExpressionSet& ps, PrimExpr cond)
: ps_(ps), cond_(cond), innermost_thread_scope_(false) {}
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) {
innermost_thread_scope_ = true;
Stmt stmt = StmtMutator::VisitStmt_(op);
// add branch code inside the innermost thread scope
if (innermost_thread_scope_) {
Stmt simplified_body = ConditionEliminator(ps_)(op->body);
Stmt body = IfThenElse(cond_, simplified_body, op->body);
PrimExpr value = this->VisitExpr(op->value);
stmt = AttrStmt(op->node, op->attr_key, value, body);
innermost_thread_scope_ = false;
return stmt;
} else {
return StmtMutator::VisitStmt_(op);
const ExpressionSet& ps_;
PrimExpr cond_;
bool innermost_thread_scope_;
// Try to partition range of iteration variables in order to remove (some)
// likely conditions
class LoopPartitioner : public StmtMutator {
explicit LoopPartitioner(bool partition_const_loop)
: selector(CandidateSelector(partition_const_loop)) {}
Stmt VisitAndMutate(Stmt stmt) {
return operator()(std::move(stmt));
Stmt VisitStmt_(const ForNode* op) final {
auto fs = GetRef<Stmt>(op);
if (selector.candidates.count(fs)) {
Stmt s = TryPartition(fs, op->loop_var, op->min, op->min + op->extent - 1, op->body, false);
if (s.defined()) return s;
// normal path when loop partition fails
// normal loop variable can be put into hint map.
hint_map_.insert({op->loop_var.get(), IntSet::Interval(op->min, op->min + op->extent - 1)});
Stmt res = StmtMutator::VisitStmt_(op);
return res;
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key != attr::thread_extent) {
return StmtMutator::VisitStmt_(op);
const IterVarNode* iv = op-><IterVarNode>();
Var var = iv->var;
auto as = GetRef<Stmt>(op);
if (selector.candidates.count(as)) {
Stmt s = TryPartition(as, var, 0, op->value - 1, op->body, true);
if (s.defined()) return s;
// normal path when loop parittion fails.
runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag);
Stmt res;
if (scope.rank == 1) {
// threadIdx should be put into relax map, in case of divergence.
relax_map_.insert({var.get(), IntSet::Interval(make_zero(var.dtype()), op->value - 1)});
res = StmtMutator::VisitStmt_(op);
} else {
hint_map_.insert({var.get(), IntSet::Interval(make_zero(var.dtype()), op->value - 1)});
res = StmtMutator::VisitStmt_(op);
return res;
Stmt TryPartition(const Stmt& stmt, Var var, PrimExpr min, PrimExpr max, Stmt body,
bool partition_thread_scope);
std::pair<IntSet, ExpressionSet> GetIntervalAndCondset(const Partition& partitions,
const arith::IntervalSet& for_interval,
bool cond_value);
inline Stmt MakeFor(const Object* op, PrimExpr extent, Stmt body);
/* Candidate IRs that may be partitioned potentially */
std::unordered_map<const VarNode*, IntSet> hint_map_;
std::unordered_map<const VarNode*, IntSet> relax_map_;
arith::Analyzer analyzer_;
CandidateSelector selector;
// Returns an interval (in the first component) in which all the conditions
// given in the second component provably have value given by cond_value.
//
// Only partitions whose proven value equals cond_value and whose interval
// overlaps for_interval contribute. The returned interval is the intersection
// of their intervals, or IntSet::Nothing() when none contribute.
std::pair<IntSet, ExpressionSet> LoopPartitioner::GetIntervalAndCondset(
    const Partition& partitions, const arith::IntervalSet& for_interval, bool cond_value) {
  Array<IntSet> sets;
  ExpressionSet cond_set;

  for (const auto& kv : partitions) {
    if (kv.first.second != cond_value) continue;
    arith::IntervalSet interval = Downcast<arith::IntervalSet>(kv.second);
    arith::IntervalSet intersection = arith::Intersect(&analyzer_, interval, for_interval);
    if (!intersection->IsEmpty()) {
      sets.push_back(kv.second);
      cond_set.insert(kv.first.first);
    }
  }
  IntSet interval = sets.empty() ? IntSet::Nothing() : Intersect(sets);
  return std::make_pair(interval, cond_set);
}
* Tries to recursively partition the range of the variable (given by var) of
* the for loop (given by node and stmt) into a
* number of disjoint ranges such that in some ranges one or more predicates
* in the loopnest are provably true or false in each range. For example, given the
* following loop to partition:
* for (i = 0; i < 4; i++)
* for (j = 0; j < 10; j++)
* if (likely(i*10 + j < 36))
* A[10*i+j] = B[10*i+j]
* We first partition range of i, i.e., [0,3] into subranges [0,2] and [3,3] because the
* likely condition is always true for the first subrange but not always true for the
* second subrange. Therefore, we'll have
* for (i = 0; i < 3; i++)
* for (j = 0; j < 10; j++)
* if (likely(1))
* A[10*i+j] = B[10*i+j]
* for (i = 0; i < 1; i++)
* for (j = 0; j < 10; j++)
* if (likely((i+3)*10 + j < 36))
* A[10*(i+3)+j] = B[10*(i+3)+j]
* Which is simplified as:
* for (i = 0; i < 3; i++)
* for (j = 0; j < 10; j++)
* A[10*i+j] = B[10*i+j]
* for (j = 0; j < 10; j++) // loopnest 1
* if (likely(j < 6))
* A[30+j] = B[30+j]
* Now, we recursively partition j in loopnest 1 into subranges [0,5] and [6,9] where the
* condition is true for the first subrange and now always true for the second subrange.
* for (j = 0; j < 6; j++)
* if (likely(1))
* A[30+j] = B[30+j]
* for (j = 0; j < 4; j++) // loop 2
* if (likely(j < 0))
* A[36+j] = B[36+j]
* Finally we recursively partition loop 2 above into subrange [0,3] where the
* condition is false and empty interval where the condition is not false,
* therefore we generate
* for (j = 0; j < 4; j++)
* if (likely(0))
* A[36+j] = B[36+j]
* which will eventually be simplified to empty code. And because only one loop was generated
* from loop 2 we stop recursing.
Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, PrimExpr max, Stmt body,
bool partition_thread_scope) {
using namespace arith;
// include hint of var.
hint_map_.insert({var.get(), IntSet::Interval(min, max)});
PartitionFinder finder(var, hint_map_, relax_map_);
if (finder.partitions.empty()) return Stmt();
arith::IntervalSet for_interval(min, max);
bool cond_value;
IntSet middle_interval;
ExpressionSet cond_set;
// find an interval in which all conditions on var are true
std::tie(middle_interval, cond_set) =
GetIntervalAndCondset(finder.partitions, for_interval, true);
if (middle_interval.IsNothing()) {
// if such interval doesn't exist, find an interval in which all
// conditions on var are false
std::tie(middle_interval, cond_set) =
GetIntervalAndCondset(finder.partitions, for_interval, false);
if (middle_interval.IsNothing())
// we couldn't find an interval in which the conditions are provably true or false
// Therefore, we can't partition the loop based on those conds
return Stmt();
cond_value = false;
} else {
cond_value = true;
IntervalSet middle_interval_i = Downcast<IntervalSet>(middle_interval);
// middle_interval is the subrange of the loop variable range for which a
// set of conditions are true (or false resp.)
// The part of the loop variable range that is before (after resp.) that
// subrange is prefixed with pre- (post- resp.)
// Calculating pre-subrange and generating code for it.
// pre-subrange = [min, body_begin)
PrimExpr body_begin;
Stmt pre_stmt;
bool pre_stmt_recurse = true;
if (middle_interval_i->HasLowerBound()) {
body_begin = analyzer_.Simplify(middle_interval.min());
if (!analyzer_.CanProve(body_begin == min)) {
PrimExpr cond = (body_begin - min >= 0);
if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond << ", when generating the pre doubt loop";
body_begin = Max(body_begin, min);
// stop recursing on this interval if we can't prove it has non-negative length
pre_stmt_recurse = false;
if (!partition_thread_scope) {
Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
pre_stmt = MakeFor(stmt.get(), body_begin - min, pre_body);
} else {
body_begin = min;
// Calculating post-subrange and generating code for it.
// post-subrange = [post_doubt_begin, max+1)
PrimExpr post_doubt_begin;
Stmt post_stmt;
bool post_stmt_recurse = true;
if (middle_interval_i->HasUpperBound()) {
post_doubt_begin = analyzer_.Simplify(middle_interval.max() + 1);
if (!analyzer_.CanProve(middle_interval.max() == max)) {
// require the extent to be non-negative
PrimExpr cond = (max - post_doubt_begin + 1 >= 0);
if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond << ", when generating the post doubt loop";
post_doubt_begin = Min(post_doubt_begin, max + 1);
// stop recursing on this interval if we can't prove it has non-negative length
post_stmt_recurse = false;
if (!partition_thread_scope) {
Stmt post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}});
post_stmt = MakeFor(stmt.get(), max - post_doubt_begin + 1, post_body);
} else {
post_doubt_begin = max + 1;
Stmt s;
// Generating code for middle subrange
if (!partition_thread_scope) {
Stmt mid_stmt;
if (!analyzer_.CanProve(body_begin >= post_doubt_begin)) {
// [body_begin, post_doubt_begin)
Stmt simplified_body = ConditionEliminator(cond_set, cond_value)(body);
Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}});
mid_stmt = MakeFor(stmt.get(), post_doubt_begin - body_begin, new_body);
// Recurse for each non-empty subrange only if there are at least
// two non-empty subranges
if (pre_stmt.defined() || post_stmt.defined()) {
mid_stmt = VisitAndMutate(mid_stmt);
if (pre_stmt.defined() && pre_stmt_recurse) {
pre_stmt = VisitAndMutate(pre_stmt);
if (post_stmt.defined() && post_stmt_recurse) {
post_stmt = VisitAndMutate(post_stmt);
s = SeqStmt::Flatten(pre_stmt, mid_stmt, post_stmt);
} else {
PrimExpr cond = const_true();
if (!analyzer_.CanProve(body_begin == min)) cond = cond && (var >= body_begin);
if (!analyzer_.CanProve(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin);
s = ThreadPartitionInserter(cond_set, cond)(stmt);
s = ConvertSSA(s);
return s;
/*!
 * \brief Build a loop over [0, extent) that reuses the loop variable and
 *        attributes of the given For node.
 * \param node The original loop; must be a ForNode.
 * \param extent The extent of the new loop.
 * \param body The body of the new loop.
 * \return The new For, or, when extent is provably 1, the body with the loop
 *         variable substituted by zero.
 */
inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt body) {
  const ForNode* for_node = static_cast<const ForNode*>(node);
  CHECK(for_node);
  if (analyzer_.CanProve(extent == make_const(extent.dtype(), 1))) {
    // If the loop extent is 1, do not create the loop anymore.
    // Substitute a zero of the loop variable's own dtype (e.g. int64) so no
    // mixed-dtype expressions are introduced.
    return Substitute(body, {{Var{for_node->loop_var}, make_zero(for_node->loop_var.dtype())}});
  } else {
    return For(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent, for_node->for_type,
               for_node->device_api, body);
  }
}
// Strip likely() markers: each likely(x) call is replaced by (the mutated) x.
class RemoveLikelyTags : public StmtExprMutator {
 public:
  PrimExpr VisitExpr_(const CallNode* op) final {
    if (op->op.same_as(builtin::likely())) {
      CHECK_EQ(op->args.size(), 1);
      return StmtExprMutator::VisitExpr(op->args[0]);
    } else {
      return StmtExprMutator::VisitExpr_(op);
    }
  }
};
/*!
 * \brief Partition candidate loops in stmt, then remove all likely() markers.
 * \param stmt The statement to transform.
 * \param partition_const_loop Whether loops with constant bounds may be partitioned.
 * \return The transformed statement.
 */
Stmt LoopPartition(Stmt stmt, bool partition_const_loop) {
  stmt = LoopPartitioner(partition_const_loop).VisitAndMutate(std::move(stmt));
  stmt = RemoveLikelyTags()(std::move(stmt));
  return stmt;
}
namespace transform {
Pass LoopPartition() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
auto cfg = ctx->GetConfig<LoopPartitionConfig>("tir.LoopPartition");
if (!cfg.defined()) {
cfg = AttrsWithDefaultValues<LoopPartitionConfig>();
n->body = LoopPartition(std::move(n->body), cfg.value()->partition_const_loop);
return f;
return CreatePrimFuncPass(pass_func, 0, "tir.LoopPartition", {});
} // namespace transform
} // namespace tir
} // namespace tvm