/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
 * \file inject_virtual_thread.cc
 * \brief Rewrite the IR to materialize virtual thread (vthread) loops.
 */
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_set>
#include "../../arith/ir_mutator_with_analyzer.h"
#include "ir_utils.h"
namespace tvm {
namespace tir {
// Checks whether an expression reads (or, when check_write is set, writes) any of the touched variables.
class ExprTouched final : public StmtExprVisitor {
public:
explicit ExprTouched(const std::unordered_set<const VarNode*>& touched, bool check_write)
: touched_var_(touched), check_write_(check_write) {}
void VisitExpr(const PrimExpr& n) final {
// early stopping
if (expr_touched_ && !check_write_) return;
StmtExprVisitor::VisitExpr(n);
}
void VisitStmt(const Stmt& n) final {
// early stopping
if (expr_touched_ && !check_write_) return;
StmtExprVisitor::VisitStmt(n);
}
void VisitExpr_(const BufferLoadNode* op) final {
HandleUseVar(op->buffer->data.get());
StmtExprVisitor::VisitExpr_(op);
}
void VisitExpr_(const VarNode* op) final { HandleUseVar(op); }
void VisitExpr_(const CallNode* op) final {
if (op->op.same_as(builtin::tvm_access_ptr())) {
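// args: {dtype placeholder, buffer_var, offset, extent, rw_mask}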
const auto* rw_mask = op->args[4].as<IntImmNode>();
const VarNode* buffer_var = op->args[1].as<VarNode>();
ICHECK(buffer_var);
ICHECK(rw_mask);
// read
if (rw_mask->value & 1) {
HandleUseVar(buffer_var);
}
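// write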
if (rw_mask->value & 2) {
HandleWriteVar(buffer_var);
}
this->VisitExpr(op->args[2]);
} else {
StmtExprVisitor::VisitExpr_(op);
}
}
void HandleUseVar(const VarNode* var) {
auto it = touched_var_.find(var);
if (it != touched_var_.end()) {
expr_touched_ = true;
}
// remember the used vars
// in case the var gets touched later in a loop.
if (!expr_touched_) {
used_vars_.push_back(var);
}
}
void HandleWriteVar(const VarNode* var) { write_vars_.push_back(var); }
// the fields.
bool expr_touched_{false};
std::vector<const VarNode*> used_vars_;
std::vector<const VarNode*> write_vars_;
const std::unordered_set<const VarNode*>& touched_var_;
bool check_write_;
};
// Analyze which variables and buffers are (transitively) affected by the value of a given var.
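// Touch propagates transitively. For example (sketch):
//   let x = vthread * 2
//   B[0] = x
// `x` is touched because its definition reads vthread, and buffer `B`
// is touched because a value stored into it depends on `x`.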
class VarTouchedAnalysis : public StmtVisitor {
public:
void VisitStmt_(const LetStmtNode* op) final {
ExprTouched tc(touched_var_, false);
tc(op->value);
Record(op->var.get(), tc);
this->VisitStmt(op->body);
}
void VisitStmt_(const BufferStoreNode* op) final {
ExprTouched tc(touched_var_, false);
tc(op->value);
for (const auto& index : op->indices) {
tc(index);
}
Record(op->buffer->data.get(), tc);
}
void VisitStmt_(const ForNode* op) final {
ExprTouched tc(touched_var_, false);
tc(op->min);
tc(op->extent);
Record(op->loop_var.get(), tc);
this->VisitStmt(op->body);
}
// external function call; use check_write so buffers written
// through tvm_access_ptr are recorded as writes.
void VisitStmt_(const EvaluateNode* op) final {
ExprTouched tc(touched_var_, true);
tc(op->value);
for (const VarNode* var : tc.write_vars_) {
Record(var, tc);
}
}
void VisitStmt_(const AllocateNode* op) final {
ExprTouched tc(touched_var_, false);
for (size_t i = 0; i < op->extents.size(); ++i) {
tc(op->extents[i]);
}
tc.VisitExpr(op->condition);
Record(op->buffer_var.get(), tc);
this->VisitStmt(op->body);
}
void Record(const VarNode* var, const ExprTouched& tc) {
if (touched_var_.count(var)) return;
if (tc.expr_touched_) {
touched_var_.insert(var);
} else {
for (const VarNode* r : tc.used_vars_) {
if (r != var) {
affect_[r].push_back(var);
}
}
}
}
std::unordered_set<const VarNode*> TouchedVar(const Stmt& stmt, const VarNode* var) {
touched_var_.insert(var);
this->VisitStmt(stmt);
// do a DFS to propagate the touched property along dependencies.
std::vector<const VarNode*> pending(touched_var_.begin(), touched_var_.end());
while (!pending.empty()) {
const VarNode* v = pending.back();
pending.pop_back();
for (const VarNode* r : affect_[v]) {
if (!touched_var_.count(r)) {
touched_var_.insert(r);
pending.push_back(r);
}
}
}
return std::move(touched_var_);
}
private:
// The set of variables touched by the thread variable.
std::unordered_set<const VarNode*> touched_var_;
// x -> all the variables whose value depends on x
std::unordered_map<const VarNode*, std::vector<const VarNode*>> affect_;
};
// Inject the virtual thread loop and
// rewrite the buffer access pattern when necessary.
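// For example (sketch, with vthread extent 2 and an allocation of 16
// elements that each virtual thread writes):
//   allocate B[16]; B[i] = f(vthread)
// becomes
//   allocate B[32]; for (vthread.s, 0, 2) { B[i + vthread.s*16] = f(vthread.s) }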
class VTInjector : public arith::IRMutatorWithAnalyzer {
public:
using IRMutatorWithAnalyzer::VisitExpr_;
using IRMutatorWithAnalyzer::VisitStmt_;
// constructor
VTInjector(arith::Analyzer* analyzer, Var var, int num_threads,
const std::unordered_set<const VarNode*>& touched_var, bool allow_share)
: IRMutatorWithAnalyzer(analyzer),
var_(var),
num_threads_(num_threads),
touched_var_(touched_var),
allow_share_(allow_share) {}
// Inject VTLoop when needed.
Stmt VisitStmt(const Stmt& s) final {
ICHECK(!visit_touched_var_);
auto stmt = StmtExprMutator::VisitStmt(s);
if (visit_touched_var_ || trigger_base_inject_) {
if (!vt_loop_injected_) {
return InjectVTLoop(stmt, false);
}
visit_touched_var_ = false;
trigger_base_inject_ = false;
}
return stmt;
}
// Variable
PrimExpr VisitExpr_(const VarNode* op) final {
ICHECK(!alloc_remap_.count(op)) << "Buffer address may get rewritten in virtual thread";
if (touched_var_.count(op)) {
visit_touched_var_ = true;
}
return ffi::GetRef<PrimExpr>(op);
}
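// Rewrite a flat index so each virtual thread addresses its own slice
// of the enlarged allocation. For example (sketch), with alloc_extent 16,
// index i becomes Simplify(i + var_ * 16).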
PrimExpr RewriteIndex(PrimExpr index, PrimExpr alloc_extent) const {
return analyzer_->Simplify(index + var_ * alloc_extent);
}
// Expression.
PrimExpr VisitExpr_(const CallNode* op) final {
if (op->op.same_as(builtin::tvm_access_ptr())) {
ICHECK_EQ(op->args.size(), 5U);
DataType dtype = op->args[0].dtype();
const VarNode* buffer = op->args[1].as<VarNode>();
auto it = alloc_remap_.find(buffer);
if (it == alloc_remap_.end()) return StmtExprMutator::VisitExpr_(op);
visit_touched_var_ = true;
PrimExpr offset = this->VisitExpr(op->args[2]);
PrimExpr extent = this->VisitExpr(op->args[3]);
PrimExpr stride = it->second / make_const(offset.dtype(), dtype.lanes());
offset = RewriteIndex(offset, stride);
return Call(op->dtype, op->op, {op->args[0], op->args[1], offset, extent, op->args[4]});
} else if (op->op.same_as(builtin::tvm_context_id())) {
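// Without sharing, each virtual thread must observe a distinct context id,
// so the call is replaced by the vthread variable itself.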
return allow_share_ ? ffi::GetRef<PrimExpr>(op) : var_;
} else {
return StmtExprMutator::VisitExpr_(op);
}
}
Stmt VisitStmt_(const EvaluateNode* op) final {
trigger_base_inject_ = !allow_share_;
return StmtExprMutator::VisitStmt_(op);
}
// BufferLoad
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
return VisitBufferAccess(std::move(node));
}
// BufferStore
Stmt VisitStmt_(const BufferStoreNode* op) final {
auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
trigger_base_inject_ = !allow_share_;
return VisitBufferAccess(std::move(node));
}
template <typename Node>
Node VisitBufferAccess(Node node) {
if (touched_var_.count(node->buffer->data.get())) {
visit_touched_var_ = true;
}
auto it = alloc_remap_.find(node->buffer->data.get());
if (it != alloc_remap_.end()) {
ICHECK_EQ(node->indices.size(), 1)
<< "InjectVirtualThread expects rewritten allocations to be flat memory.";
auto writer = node.CopyOnWrite();
writer->buffer = GetRemappedBuffer(node->buffer, it->second);
writer->indices = {RewriteIndex(node->indices[0], it->second)};
}
return node;
}
Buffer GetRemappedBuffer(Buffer buf, PrimExpr alloc_extent) {
auto key = buf.get();
auto it = buf_remap_.find(key);
if (it != buf_remap_.end()) {
return it->second;
}
ICHECK_EQ(buf->shape.size(), 1) << "Expected buffers being rewritten to already be flattened.";
auto writer = buf.CopyOnWrite();
writer->shape = {buf->shape[0] * alloc_extent};
buf_remap_[key] = buf;
return buf;
}
// Attribute
Stmt VisitStmt_(const AttrStmtNode* op) final {
PrimExpr value = this->VisitExpr(op->value);
if (visit_touched_var_ && !vt_loop_injected_) {
return InjectVTLoop(ffi::GetRef<Stmt>(op), true);
} else if (!allow_share_ && !vt_loop_injected_ &&
(op->attr_key == attr::coproc_uop_scope || op->attr_key == attr::coproc_scope)) {
return InjectVTLoop(ffi::GetRef<Stmt>(op), true);
} else {
Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) && body.same_as(op->body)) {
return ffi::GetRef<Stmt>(op);
} else {
return AttrStmt(op->node, op->attr_key, value, body);
}
}
}
// LetStmt
Stmt VisitStmt_(const LetStmtNode* op) final {
PrimExpr value = this->VisitExpr(op->value);
if (visit_touched_var_ && !vt_loop_injected_) {
return InjectVTLoop(ffi::GetRef<Stmt>(op), true);
}
visit_touched_var_ = false;
Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) && body.same_as(op->body)) {
return ffi::GetRef<Stmt>(op);
} else {
return LetStmt(op->var, value, body);
}
}
// For
Stmt VisitStmt_(const ForNode* op) final {
ICHECK(is_zero(op->min));
PrimExpr extent = this->VisitExpr(op->extent);
if (visit_touched_var_ && !vt_loop_injected_) {
Stmt stmt = InjectVTLoop(ffi::GetRef<Stmt>(op), true);
++max_loop_depth_;
return stmt;
}
visit_touched_var_ = false;
Stmt body = this->VisitStmt(op->body);
++max_loop_depth_;
if (extent.same_as(op->extent) && body.same_as(op->body)) {
return ffi::GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
n->extent = std::move(extent);
n->body = std::move(body);
return Stmt(n);
}
}
// IfThenElse
Stmt VisitStmt_(const IfThenElseNode* op) final {
PrimExpr condition = this->VisitExpr(op->condition);
if (visit_touched_var_ && !vt_loop_injected_) {
return InjectVTLoop(ffi::GetRef<Stmt>(op), true);
}
visit_touched_var_ = false;
ICHECK_EQ(max_loop_depth_, 0);
Stmt then_case = this->VisitStmt(op->then_case);
ffi::Optional<Stmt> else_case = std::nullopt;
if (op->else_case) {
int temp = max_loop_depth_;
max_loop_depth_ = 0;
else_case = this->VisitStmt(op->else_case.value());
max_loop_depth_ = std::max(temp, max_loop_depth_);
}
if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return ffi::GetRef<Stmt>(op);
} else {
return IfThenElse(condition, then_case, else_case);
}
}
// While
Stmt VisitStmt_(const WhileNode* op) final {
// TODO(masahi): What should we do for While nodes?
LOG(FATAL) << "WhileNode in InjectVirtualThread not supported yet";
}
// Seq
Stmt VisitStmt_(const SeqStmtNode* op) final {
ICHECK_EQ(max_loop_depth_, 0);
auto fmutate = [this](const Stmt& s) {
int temp = max_loop_depth_;
max_loop_depth_ = 0;
Stmt ret = this->VisitStmt(s);
max_loop_depth_ = std::max(max_loop_depth_, temp);
return ret;
};
return StmtMutator::VisitSeqStmt_(op, false, fmutate);
}
// Allocate
Stmt VisitStmt_(const AllocateNode* op) final {
Allocate node = ffi::GetRef<Allocate>(op);
PrimExpr condition = this->VisitExpr(op->condition);
ffi::Array<PrimExpr> extents =
op->extents.Map([this](const PrimExpr& extent) { return this->VisitExpr(extent); });
if (visit_touched_var_ && !vt_loop_injected_) {
return InjectVTLoop(ffi::GetRef<Stmt>(op), true);
}
visit_touched_var_ = false;
// Rewrite the buffer if its shape or any value stored in it
// depends on the virtual thread var. If `allow_share_` is false,
// then the buffer is always rewritten, even if separate virtual
// threads only read from the buffer.
if (touched_var_.count(op->buffer_var.get()) || !allow_share_) {
// place the vthread index on the highest dimension.
// TODO(Lunderberg): Move pass to apply before
// FlattenBuffer. Would rewrite the Buffer to
// add the injected virtual thread as the first index.
ICHECK_EQ(extents.size(), 1)
<< "InjectVirtualThread expects rewritten allocations to be flat memory.";
PrimExpr stride = extents[0];
extents = {stride * num_threads_};
// Mark the buffer var as touched. BufferLoad/BufferStore should
// access locations at `current_index + stride*vthread_var`.
alloc_remap_[op->buffer_var.get()] = stride;
}
// Mutate the body. Depends on alloc_remap_.
auto body = this->VisitStmt(op->body);
if (extents.same_as(op->extents) && body.same_as(op->body) &&
condition.same_as(op->condition)) {
return ffi::GetRef<Stmt>(op);
} else {
return Allocate(op->buffer_var, op->dtype, extents, condition, body);
}
}
// inject vthread loop
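// Either unrolls the body by substituting the vthread var with each
// constant 0..num_threads_-1, or wraps the body in a serial loop over a
// fresh var named `<name_hint>.s`. For example (sketch, num_threads_ = 2,
// no inner loops):
//   A[vthread] = 1   =>   A[0] = 1; A[1] = 1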
Stmt InjectVTLoop(Stmt stmt, bool before_mutation) {
ICHECK(!vt_loop_injected_);
// reset the flags
visit_touched_var_ = false;
trigger_base_inject_ = false;
vt_loop_injected_ = true;
if (before_mutation) {
stmt = this->VisitStmt(stmt);
}
// reset the flags after processing.
vt_loop_injected_ = false;
visit_touched_var_ = false;
// only unroll if the number of vthreads is small
if (max_loop_depth_ == 0 && num_threads_ < 16) {
// unroll the body since we are at the innermost scope.
ffi::Array<Stmt> seq;
for (int i = 0; i < num_threads_; ++i) {
seq.push_back(Substitute(stmt, {{var_, make_const(var_.dtype(), i)}}));
}
return SeqStmt::Flatten(seq);
} else {
// insert a for loop
Var idx(var_->name_hint + ".s", var_->dtype);
stmt = Substitute(stmt, {{var_, idx}});
return For(idx, make_zero(idx.dtype()), make_const(idx.dtype(), num_threads_),
ForKind::kSerial, stmt);
}
}
private:
// vthread variable
Var var_;
// the number of virtual threads
int num_threads_;
// whether the vthread loop has already been injected.
bool vt_loop_injected_{false};
// whether the current expression touches a vthread-dependent var.
bool visit_touched_var_{false};
// whether the current statement triggers a base injection.
bool trigger_base_inject_{false};
// the maximum loop depth encountered after mutation.
int max_loop_depth_{0};
// The variables that get touched.
const std::unordered_set<const VarNode*>& touched_var_;
// Whether sharing of read-only buffers across virtual threads is allowed.
bool allow_share_;
/*! \brief The allocations that get touched -> extent
*
* Maps from the buffer_var of an allocate node to the original
* extent of the allocation. Used when rewriting the indices of
* BufferLoad/BufferStore.
*/
std::unordered_map<const VarNode*, PrimExpr> alloc_remap_;
/*! \brief Map of buffers that are modified.
*
* Buffers allocated or written to within the virtual thread loop
* must have one copy per virtual thread. This is done by enlarging
* the allocated buffer size, then modifying the indices at which
* each virtual thread accesses the buffer.
*/
std::unordered_map<const BufferNode*, Buffer> buf_remap_;
};
class VirtualThreadInjector : public arith::IRMutatorWithAnalyzer {
public:
using IRMutatorWithAnalyzer::IRMutatorWithAnalyzer;
using IRMutatorWithAnalyzer::VisitStmt_;
Stmt VisitStmt_(const AttrStmtNode* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<AttrStmtNode>();
if (op->attr_key == attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node);
bool allow_share = std::string(iv->thread_tag).substr(0, 7) == "vthread";
int nthread = static_cast<int>(op->value.as<IntImmNode>()->value);
VarTouchedAnalysis vs;
auto touched = vs.TouchedVar(op->body, iv->var.get());
VTInjector injector(analyzer_, iv->var, nthread, touched, allow_share);
return injector(op->body);
} else {
return stmt;
}
}
};
namespace transform {
Pass InjectVirtualThread() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
arith::Analyzer analyzer;
n->body = VirtualThreadInjector(&analyzer)(std::move(n->body));
n->body = ConvertSSA(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.InjectVirtualThread", {});
}
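// Usage sketch (assuming `mod` is an IRModule already lowered to flat memory):
//   IRModule mod = ...;
//   mod = InjectVirtualThread()(std::move(mod));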
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tir.transform.InjectVirtualThread", InjectVirtualThread);
}
} // namespace transform
} // namespace tir
} // namespace tvm