// blob: d5405790a15aa7bd0c928620172cf5383590cac6 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file inject_virtual_thread.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_set>
#include "ir_util.h"
namespace tvm {
namespace tir {
// Visitor answering "does this expression read any variable of a given set?".
// While walking it also collects
//   - used_vars_:  variables read before the first touched variable was seen,
//   - write_vars_: buffers written through tvm_access_ptr.
// When check_write is false the walk stops as soon as a touched variable is
// found; when it is true the whole expression is walked so that every write
// is collected.
class ExprTouched final : public StmtExprVisitor {
 public:
  // touched is stored by reference and must outlive this visitor.
  explicit ExprTouched(const std::unordered_set<const VarNode*>& touched, bool check_write)
      : touched_var_(touched), check_write_(check_write) {}

  void VisitExpr(const PrimExpr& n) final {
    // Continue only while the answer is still open, or when writes must be gathered.
    if (check_write_ || !expr_touched_) {
      StmtExprVisitor::VisitExpr(n);
    }
  }

  void VisitStmt(const Stmt& n) final {
    // Same early-exit rule as for expressions.
    if (check_write_ || !expr_touched_) {
      StmtExprVisitor::VisitStmt(n);
    }
  }

  // A load reads its buffer variable in addition to whatever its children read.
  void VisitExpr_(const LoadNode* op) final {
    HandleUseVar(op->buffer_var.get());
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitExpr_(const VarNode* op) final { HandleUseVar(op); }

  void VisitExpr_(const CallNode* op) final {
    if (!op->op.same_as(builtin::tvm_access_ptr())) {
      StmtExprVisitor::VisitExpr_(op);
      return;
    }
    // tvm_access_ptr: args[1] is the buffer variable, args[4] the read/write mask.
    const auto* rw_mask = op->args[4].as<IntImmNode>();
    const VarNode* buffer_var = op->args[1].as<VarNode>();
    CHECK(buffer_var);
    CHECK(rw_mask);
    const bool is_read = (rw_mask->value & 1) != 0;
    const bool is_write = (rw_mask->value & 2) != 0;
    if (is_read) {
      HandleUseVar(buffer_var);
    }
    if (is_write) {
      HandleWriteVar(buffer_var);
    }
    // Only the offset argument is inspected further.
    this->VisitExpr(op->args[2]);
  }

  // Register a read of var.
  void HandleUseVar(const VarNode* var) {
    if (touched_var_.count(var) != 0) {
      expr_touched_ = true;
    }
    if (expr_touched_) {
      return;
    }
    // Remember the variable: it may only become touched later on (e.g. in a
    // loop), in which case the dependency has to be propagated afterwards.
    used_vars_.push_back(var);
  }

  // Register a write to var.
  void HandleWriteVar(const VarNode* var) { write_vars_.push_back(var); }

  // Whether a touched variable has been read so far.
  bool expr_touched_{false};
  // Variables read while expr_touched_ was still false.
  std::vector<const VarNode*> used_vars_;
  // Buffers written through tvm_access_ptr.
  std::vector<const VarNode*> write_vars_;
  // The set of variables currently known to be touched.
  const std::unordered_set<const VarNode*>& touched_var_;
  // Whether the traversal must continue after a touch to collect writes.
  bool check_write_;
};
// Analyze if the buffers are invariant to value of var
class VarTouchedAnalysis : public StmtVisitor {
public:
void VisitStmt_(const LetStmtNode* op) final {
ExprTouched tc(touched_var_, false);
tc(op->value);
Record(op->var.get(), tc);
this->VisitStmt(op->body);
}
void VisitStmt_(const StoreNode* op) final {
ExprTouched tc(touched_var_, false);
tc(op->value);
tc(op->index);
Record(op->buffer_var.get(), tc);
}
void VisitStmt_(const ForNode* op) final {
ExprTouched tc(touched_var_, false);
tc(op->min);
tc(op->extent);
Record(op->loop_var.get(), tc);
this->VisitStmt(op->body);
}
// external function call
void VisitStmt_(const EvaluateNode* op) final {
ExprTouched tc(touched_var_, true);
tc(op->value);
for (const VarNode* var : tc.write_vars_) {
Record(var, tc);
}
}
void VisitStmt_(const AllocateNode* op) final {
ExprTouched tc(touched_var_, false);
for (size_t i = 0; i < op->extents.size(); ++i) {
tc(op->extents[i]);
}
tc.VisitExpr(op->condition);
Record(op->buffer_var.get(), tc);
this->VisitStmt(op->body);
}
void Record(const VarNode* var, const ExprTouched& tc) {
if (touched_var_.count(var)) return;
if (tc.expr_touched_) {
touched_var_.insert(var);
} else {
for (const VarNode* r : tc.used_vars_) {
if (r != var) {
affect_[r].push_back(var);
}
}
}
}
std::unordered_set<const VarNode*> TouchedVar(const Stmt& stmt, const VarNode* var) {
touched_var_.insert(var);
this->VisitStmt(stmt);
// do a DFS to push affect around dependency.
std::vector<const VarNode*> pending(touched_var_.begin(), touched_var_.end());
while (!pending.empty()) {
const VarNode* v = pending.back();
pending.pop_back();
for (const VarNode* r : affect_[v]) {
if (!touched_var_.count(r)) {
touched_var_.insert(r);
pending.push_back(r);
}
}
}
return std::move(touched_var_);
}
private:
// Whether variable is touched by the thread variable.
std::unordered_set<const VarNode*> touched_var_;
// x -> all the buffers x read from
std::unordered_map<const VarNode*, std::vector<const VarNode*> > affect_;
};
// Mutator that lowers the body of one virtual-thread scope.
//
// While walking the body it tracks whether the statement currently being
// processed reads a variable from `touched_var`. The first such statement that
// is not yet inside an injected loop is wrapped by InjectVTLoop, which either
// emits one copy of the statement per thread or a serial loop over the
// threads. Buffers that must be private per thread get an extra leading
// extent of `num_threads`, and every Load / Store / tvm_access_ptr on them is
// re-indexed with the thread variable.
class VTInjector : public StmtExprMutator {
 public:
  // Constructor.
  //   var:         the virtual thread variable to eliminate.
  //   num_threads: number of virtual threads.
  //   touched_var: variables that depend on var; stored by reference, so the
  //                set must outlive this object.
  //   allow_share: when false, every allocation is replicated per thread and
  //                Store / Evaluate statements always request an injection.
  VTInjector(Var var, int num_threads, const std::unordered_set<const VarNode*>& touched_var,
             bool allow_share)
      : var_(var),
        num_threads_(num_threads),
        touched_var_(touched_var),
        allow_share_(allow_share) {}

  // Inject VTLoop when needed.
  // The touched flag must be clear on entry. After the statement has been
  // mutated, a raised flag means the statement depends on the thread
  // variable: wrap it unless we already are inside an injected loop, in which
  // case the flags are simply cleared.
  Stmt VisitStmt(const Stmt& s) final {
    CHECK(!visit_touched_var_);
    auto stmt = StmtExprMutator::VisitStmt(s);
    if (visit_touched_var_ || trigger_base_inject_) {
      if (!vt_loop_injected_) {
        // stmt is already mutated, so it must not be visited again.
        return InjectVTLoop(stmt, false);
      }
      visit_touched_var_ = false;
      trigger_base_inject_ = false;
    }
    return stmt;
  }

  // Variable
  // A replicated buffer must never appear as a bare variable: only Load,
  // Store and tvm_access_ptr know how to re-index it.
  PrimExpr VisitExpr_(const VarNode* op) final {
    CHECK(!alloc_remap_.count(op)) << "Buffer address may get rewritten in virtual thread";
    if (touched_var_.count(op)) {
      visit_touched_var_ = true;
    }
    return GetRef<PrimExpr>(op);
  }

  // Shift index into the slice of a replicated buffer that belongs to the
  // current thread; alloc_extent is the per-thread stride from alloc_remap_.
  PrimExpr RewriteIndex(PrimExpr index, PrimExpr alloc_extent) const {
    return index + var_ * alloc_extent;
  }

  // Load
  PrimExpr VisitExpr_(const LoadNode* op) final {
    // Mutate the children first, then look at the buffer itself.
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<LoadNode>();
    if (touched_var_.count(op->buffer_var.get())) {
      visit_touched_var_ = true;
    }
    auto it = alloc_remap_.find(op->buffer_var.get());
    if (it != alloc_remap_.end()) {
      return Load(op->dtype, op->buffer_var, RewriteIndex(op->index, it->second), op->predicate);
    } else {
      return expr;
    }
  }

  // Expression.
  PrimExpr VisitExpr_(const CallNode* op) final {
    if (op->op.same_as(builtin::tvm_access_ptr())) {
      // Arguments used here: [0] value carrying the element dtype,
      // [1] buffer variable, [2] offset, [3] extent; [4] is passed through.
      CHECK_EQ(op->args.size(), 5U);
      DataType dtype = op->args[0].dtype();
      const VarNode* buffer = op->args[1].as<VarNode>();
      auto it = alloc_remap_.find(buffer);
      // Buffers that are not replicated take the default path.
      if (it == alloc_remap_.end()) return StmtExprMutator::VisitExpr_(op);
      visit_touched_var_ = true;
      PrimExpr offset = this->VisitExpr(op->args[2]);
      PrimExpr extent = this->VisitExpr(op->args[3]);
      // The recorded stride includes the lanes of the allocation dtype;
      // dividing by the lanes of the access dtype converts it to the unit in
      // which the offset is expressed.
      PrimExpr stride = it->second / make_const(offset.dtype(), dtype.lanes());
      offset = stride * var_ + offset;
      return Call(op->dtype, op->op, {op->args[0], op->args[1], offset, extent, op->args[4]});
    } else if (op->op.same_as(builtin::tvm_context_id())) {
      // Without sharing, the context id is the thread variable itself.
      return allow_share_ ? GetRef<PrimExpr>(op) : var_;
    } else {
      return StmtExprMutator::VisitExpr_(op);
    }
  }

  // Evaluate
  // When sharing is not allowed every evaluated expression requests an
  // injection, whether or not it reads a touched variable.
  Stmt VisitStmt_(const EvaluateNode* op) final {
    trigger_base_inject_ = !allow_share_;
    return StmtExprMutator::VisitStmt_(op);
  }

  // Store
  Stmt VisitStmt_(const StoreNode* op) final {
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<StoreNode>();
    if (touched_var_.count(op->buffer_var.get())) {
      visit_touched_var_ = true;
    }
    // Same rule as for Evaluate.
    trigger_base_inject_ = !allow_share_;
    auto it = alloc_remap_.find(op->buffer_var.get());
    if (it != alloc_remap_.end()) {
      return Store(op->buffer_var, op->value, RewriteIndex(op->index, it->second), op->predicate);
    } else {
      return stmt;
    }
  }

  // Attribute
  Stmt VisitStmt_(const AttrStmtNode* op) final {
    PrimExpr value = this->VisitExpr(op->value);
    if (visit_touched_var_ && !vt_loop_injected_) {
      // The attribute value depends on the thread: wrap the whole original
      // statement and let InjectVTLoop visit it again.
      return InjectVTLoop(GetRef<Stmt>(op), true);
    } else if (!allow_share_ && !vt_loop_injected_ &&
               (op->attr_key == attr::coproc_uop_scope || op->attr_key == attr::coproc_scope)) {
      // Co-processor scopes are replicated as a whole when sharing is off.
      return InjectVTLoop(GetRef<Stmt>(op), true);
    } else {
      // If the value read a touched variable while already inside an injected
      // loop the flag is still raised here. It has to be cleared before the
      // body is visited, otherwise the CHECK at the top of VisitStmt fails.
      // The LetStmt / For / IfThenElse / Allocate handlers do the same.
      visit_touched_var_ = false;
      Stmt body = this->VisitStmt(op->body);
      if (value.same_as(op->value) && body.same_as(op->body)) {
        return GetRef<Stmt>(op);
      } else {
        return AttrStmt(op->node, op->attr_key, value, body);
      }
    }
  }

  // LetStmt
  Stmt VisitStmt_(const LetStmtNode* op) final {
    PrimExpr value = this->VisitExpr(op->value);
    if (visit_touched_var_ && !vt_loop_injected_) {
      return InjectVTLoop(GetRef<Stmt>(op), true);
    }
    visit_touched_var_ = false;
    Stmt body = this->VisitStmt(op->body);
    if (value.same_as(op->value) && body.same_as(op->body)) {
      return GetRef<Stmt>(op);
    } else {
      return LetStmt(op->var, value, body);
    }
  }

  // For
  // Only loops starting at zero are supported. Every processed loop
  // increases max_loop_depth_, which later decides between unrolling and a
  // serial loop in InjectVTLoop.
  Stmt VisitStmt_(const ForNode* op) final {
    CHECK(is_zero(op->min));
    PrimExpr extent = this->VisitExpr(op->extent);
    if (visit_touched_var_ && !vt_loop_injected_) {
      Stmt stmt = InjectVTLoop(GetRef<Stmt>(op), true);
      ++max_loop_depth_;
      return stmt;
    }
    visit_touched_var_ = false;
    Stmt body = this->VisitStmt(op->body);
    ++max_loop_depth_;
    if (extent.same_as(op->extent) && body.same_as(op->body)) {
      return GetRef<Stmt>(op);
    } else {
      return For(op->loop_var, op->min, extent, op->for_type, op->device_api, body);
    }
  }

  // IfThenElse
  Stmt VisitStmt_(const IfThenElseNode* op) final {
    PrimExpr condition = this->VisitExpr(op->condition);
    if (visit_touched_var_ && !vt_loop_injected_) {
      return InjectVTLoop(GetRef<Stmt>(op), true);
    }
    visit_touched_var_ = false;
    CHECK_EQ(max_loop_depth_, 0);
    Stmt then_case = this->VisitStmt(op->then_case);
    Stmt else_case;
    if (op->else_case.defined()) {
      // Measure the else branch on its own and keep the deeper of the two.
      int temp = max_loop_depth_;
      max_loop_depth_ = 0;
      else_case = this->VisitStmt(op->else_case);
      max_loop_depth_ = std::max(temp, max_loop_depth_);
    }
    if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
        else_case.same_as(op->else_case)) {
      return GetRef<Stmt>(op);
    } else {
      return IfThenElse(condition, then_case, else_case);
    }
  }

  // Seq
  // Each child is measured with a fresh loop depth; the maximum over all
  // children is kept.
  Stmt VisitStmt_(const SeqStmtNode* op) final {
    CHECK_EQ(max_loop_depth_, 0);
    auto fmutate = [this](const Stmt& s) {
      int temp = max_loop_depth_;
      max_loop_depth_ = 0;
      Stmt ret = this->VisitStmt(s);
      max_loop_depth_ = std::max(max_loop_depth_, temp);
      return ret;
    };
    return StmtMutator::VisitSeqStmt_(op, false, fmutate);
  }

  // Allocate
  Stmt VisitStmt_(const AllocateNode* op) final {
    PrimExpr condition = this->VisitExpr(op->condition);
    if (visit_touched_var_ && !vt_loop_injected_) {
      return InjectVTLoop(GetRef<Stmt>(op), true);
    }
    bool changed = false;
    Array<PrimExpr> extents;
    for (size_t i = 0; i < op->extents.size(); i++) {
      PrimExpr new_ext = this->VisitExpr(op->extents[i]);
      // A thread-dependent shape forces the whole allocation into the loop.
      if (visit_touched_var_ && !vt_loop_injected_) {
        return InjectVTLoop(GetRef<Stmt>(op), true);
      }
      if (!new_ext.same_as(op->extents[i])) changed = true;
      extents.push_back(new_ext);
    }
    visit_touched_var_ = false;
    Stmt body;
    // always rewrite if not allow sharing.
    if (touched_var_.count(op->buffer_var.get()) || !allow_share_) {
      // place v on highest dimension.
      // stride = product of the original extents times the lanes of the
      // element type, i.e. the footprint of one thread.
      auto fmul = [](PrimExpr a, PrimExpr b) { return a * b; };
      PrimExpr stride =
          foldl(fmul, make_const(DataType::Int(32), 1), op->extents) * op->dtype.lanes();
      Array<PrimExpr> other;
      other.push_back(make_const(op->extents[0].dtype(), num_threads_));
      for (PrimExpr e : extents) {
        other.push_back(e);
      }
      extents = other;
      changed = true;
      // mark this buffer get touched.
      // Must happen before the body is visited so that its accesses are
      // re-indexed.
      alloc_remap_[op->buffer_var.get()] = stride;
      // Mutate the body.
      body = this->VisitStmt(op->body);
    } else {
      // Mutate the body.
      body = this->VisitStmt(op->body);
    }
    if (!changed && body.same_as(op->body) && condition.same_as(op->condition)) {
      return GetRef<Stmt>(op);
    } else {
      return Allocate(op->buffer_var, op->dtype, extents, condition, body);
    }
  }

  // inject vthread loop
  //   stmt:            statement to replicate over the virtual threads.
  //   before_mutation: true when stmt is still the original statement; it is
  //                    then mutated here with vt_loop_injected_ raised so that
  //                    no nested injection happens.
  // Returns either num_threads_ substituted copies of stmt in sequence, or a
  // serial loop over a fresh loop variable.
  Stmt InjectVTLoop(Stmt stmt, bool before_mutation) {
    CHECK(!vt_loop_injected_);
    // reset the flags
    visit_touched_var_ = false;
    trigger_base_inject_ = false;
    vt_loop_injected_ = true;
    if (before_mutation) {
      stmt = this->VisitStmt(stmt);
    }
    // reset the flags after processing.
    vt_loop_injected_ = false;
    visit_touched_var_ = false;
    // only unroll if number of vthreads are small
    if (max_loop_depth_ == 0 && num_threads_ < 16) {
      // do unrolling if it is inside innermost content.
      Array<Stmt> seq;
      for (int i = 0; i < num_threads_; ++i) {
        seq.push_back(Substitute(stmt, {{var_, make_const(var_.dtype(), i)}}));
      }
      return SeqStmt::Flatten(seq);
    } else {
      // insert a for loop
      Var idx(var_->name_hint + ".s", var_->dtype);
      Map<Var, PrimExpr> values{{var_, idx}};
      stmt = Substitute(stmt, values);
      return For(idx, make_zero(idx.dtype()), make_const(idx.dtype(), num_threads_),
                 ForType::Serial, DeviceAPI::None, stmt);
    }
  }

 private:
  // vthread variable
  Var var_;
  // the threads/lanes
  int num_threads_;
  // whether the loop is already injected.
  bool vt_loop_injected_{false};
  // whether the current expression reads a touched variable.
  bool visit_touched_var_{false};
  // Request an injection around the current base statement (Store/Evaluate).
  bool trigger_base_inject_{false};
  // the number of loops counted after mutation.
  int max_loop_depth_{0};
  // The variables that get touched.
  const std::unordered_set<const VarNode*>& touched_var_;
  // Whether sharing is allowed.
  bool allow_share_;
  // The allocations that get replicated -> per-thread stride
  std::unordered_map<const VarNode*, PrimExpr> alloc_remap_;
};
// Top-level mutator: replaces every `virtual_thread` attribute scope by its
// body, rewritten so that the virtual thread variable no longer appears.
class VirtualThreadInjector : public StmtMutator {
 public:
  Stmt VisitStmt_(const AttrStmtNode* op) final {
    // Mutate the children first, so nested scopes are handled before this one.
    Stmt stmt = StmtMutator::VisitStmt_(op);
    op = stmt.as<AttrStmtNode>();
    if (op->attr_key != attr::virtual_thread) {
      return stmt;
    }
    IterVar iv = Downcast<IterVar>(op->node);
    // Only the plain "vthread" tag allows buffers to be shared across threads.
    bool allow_share = iv->thread_tag == "vthread";
    // The number of virtual threads must be known at compile time; report a
    // clear error instead of dereferencing a null pointer otherwise.
    const auto* extent = op->value.as<IntImmNode>();
    CHECK(extent != nullptr) << "The extent of a virtual thread must be a constant integer, but got "
                             << op->value;
    int nthread = static_cast<int>(extent->value);
    // Find everything that depends on the thread variable, then rewrite.
    VarTouchedAnalysis vs;
    auto touched = vs.TouchedVar(op->body, iv->var.get());
    VTInjector injecter(iv->var, nthread, touched, allow_share);
    return injecter(op->body);
  }

  // This pass only handles flattened storage; ProducerStore must be gone.
  Stmt VisitStmt_(const ProducerStoreNode* op) final {
    LOG(FATAL) << "Need to call StorageFlatten first";
    return GetRef<Stmt>(op);
  }
};
// Lower all virtual threads in stmt, then convert the result to SSA form.
Stmt InjectVirtualThread(Stmt stmt) {
  VirtualThreadInjector injector;
  Stmt lowered = injector(std::move(stmt));
  return ConvertSSA(std::move(lowered));
}
namespace transform {
// Create the PrimFunc pass "tir.InjectVirtualThread" (opt level 0), which
// lowers virtual threads in the function body and converts it to SSA form.
Pass InjectVirtualThread() {
  auto rewrite_body = [](PrimFunc func, IRModule mod, PassContext ctx) {
    auto* node = func.CopyOnWrite();
    Stmt lowered = VirtualThreadInjector()(std::move(node->body));
    node->body = ConvertSSA(std::move(lowered));
    return func;
  };
  return CreatePrimFuncPass(rewrite_body, 0, "tir.InjectVirtualThread", {});
}
// Register the pass factory above in the global function registry under the
// name "tir.transform.InjectVirtualThread".
TVM_REGISTER_GLOBAL("tir.transform.InjectVirtualThread").set_body_typed(InjectVirtualThread);
} // namespace transform
} // namespace tir
} // namespace tvm