blob: 1384ea1a89ac5d7d92c119630b0759ae7e734f33 [file] [log] [blame]
/*!
* Copyright (c) 2017 by Contributors
*
* \brief Inject double buffering optimization for data fetch.
* \file inject_double_buffer.cc
*/
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include "ir_util.h"
#include "../arithmetic/compute_expr.h"
namespace tvm {
namespace ir {
// Detect double buffer variables.
class DoubleBufferDetector : public IRVisitor {
public:
void Visit_(const AttrStmt* op) final {
if (op->attr_key == attr::double_buffer_scope) {
touched_.insert(op->node.as<Variable>());
IRVisitor::Visit_(op);
} else {
IRVisitor::Visit_(op);
}
}
void Visit_(const Variable* op) final {
if (touched_.count(op)) {
touched_.erase(op);
}
}
// The set of touched variable.
std::unordered_set<const Variable*> touched_;
};
class StripDoubleBufferWrite : public IRMutator {
public:
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->attr_key == attr::double_buffer_write) {
return Mutate(op->body);
} else {
return IRMutator::Mutate_(op, s);
}
}
};
class DoubleBufferInjector : public IRMutator {
public:
explicit DoubleBufferInjector(int split_loop)
: split_loop_(split_loop) {}
Stmt Inject(const Stmt& stmt) {
DoubleBufferDetector detector;
detector.Visit(stmt);
if (detector.touched_.empty()) return stmt;
for (const Variable* v : detector.touched_) {
dbuffer_info_[v] = StorageEntry();
}
return ConvertSSA(this->Mutate(stmt));
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>();
auto it = dbuffer_info_.find(buf);
if (it != dbuffer_info_.end()) {
it->second.scope = op->value.as<StringImm>()->value;
return Mutate(op->body);
} else {
return IRMutator::Mutate_(op, s);
}
} else if (op->attr_key == attr::double_buffer_scope) {
return MakeProducer(op, s);
} else {
return IRMutator::Mutate_(op, s);
}
}
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
it->second.stride = arith::ComputeReduce<Mul>
(op->extents, Expr()) * op->type.lanes();
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Allocate>();
Array<Expr> new_extents{make_const(op->extents[0].type(), 2)};
for (Expr e : op->extents) {
new_extents.push_back(e);
}
CHECK(it->second.loop != nullptr);
auto& alloc_nest = loop_allocs_[it->second.loop];
alloc_nest.emplace_back(AttrStmt::make(
op->buffer_var, attr::storage_scope,
StringImm::make(it->second.scope),
Evaluate::make(0)));
alloc_nest.emplace_back(Allocate::make(
op->buffer_var, op->type, new_extents, op->condition,
Evaluate::make(0)));
return op->body;
} else {
return IRMutator::Mutate_(op, s);
}
}
Stmt Mutate_(const For* op, const Stmt& s) final {
loop_nest_.push_back(op);
Stmt stmt = IRMutator::Mutate_(op, s);
auto it = loop_pre_.find(op);
if (it != loop_pre_.end()) {
const For* old_loop = stmt.as<For>();
if (split_loop_ != 0) {
// Explicitly unroll the loop
CHECK(split_loop_ % 2 == 0 || split_loop_ == 1)
<< "It is better to split with multiple of 2";
CHECK(is_zero(old_loop->min));
Expr zero = old_loop->min;
Expr new_ext = arith::ComputeExpr<Sub>(
old_loop->extent, make_const(old_loop->loop_var.type(), 1));
Expr factor = make_const(new_ext.type(), split_loop_);
Expr outer_ext = arith::ComputeExpr<Div>(new_ext, factor);
Expr tail_base = arith::ComputeExpr<Mul>(outer_ext, factor);
Var outer_var(old_loop->loop_var->name_hint + ".outer", old_loop->loop_var.type());
std::unordered_map<const Variable*, Expr> vmap;
std::vector<Stmt> loop_seq;
for (int32_t i = 0; i < split_loop_; ++i) {
vmap[old_loop->loop_var.get()] = outer_var * factor + make_const(factor.type(), i);
loop_seq.emplace_back(Substitute(old_loop->body, vmap));
}
Stmt loop = For::make(
outer_var, zero, outer_ext, old_loop->for_type, old_loop->device_api,
MergeSeq(loop_seq));
// tail
std::vector<Stmt> tail_seq;
Stmt tail_body = StripDoubleBufferWrite().Mutate(old_loop->body);
for (int32_t i = 0; i < split_loop_; ++i) {
Expr idx = tail_base + make_const(tail_base.type(), i);
vmap[old_loop->loop_var.get()] = idx;
tail_seq.emplace_back(
IfThenElse::make(idx < old_loop->extent,
Substitute(tail_body, vmap)));
}
stmt = Block::make(loop, MergeSeq(tail_seq));
}
stmt = Block::make(MergeSeq(it->second), stmt);
}
it = loop_allocs_.find(op);
if (it != loop_allocs_.end()) {
stmt = MergeNest(it->second, stmt);
}
loop_nest_.pop_back();
return stmt;
}
Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Store>();
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
const StorageEntry& e = it->second;
CHECK(in_double_buffer_scope_);
CHECK(e.stride.defined());
return Store::make(op->buffer_var,
op->value,
e.switch_write_var * e.stride + op->index,
op->predicate);
} else {
return stmt;
}
}
Expr Mutate_(const Load* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Load>();
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
const StorageEntry& e = it->second;
CHECK(e.stride.defined());
CHECK(e.switch_read_var.defined());
return Load::make(op->type,
op->buffer_var,
e.switch_read_var * e.stride + op->index,
op->predicate);
} else {
return expr;
}
}
Expr Mutate_(const Variable* op, const Expr& e) final {
CHECK(!dbuffer_info_.count(op));
return e;
}
private:
Stmt MakeProducer(const AttrStmt* op, const Stmt& s) {
const VarExpr buffer(op->node.node_);
CHECK_NE(loop_nest_.size(), 0U)
<< "Double buffer scope must be inside a loop";
auto it = dbuffer_info_.find(buffer.get());
if (it == dbuffer_info_.end()) {
LOG(WARNING) << "Skip double buffer scope " << op->node;
return Mutate(op->body);
}
StorageEntry& e = it->second;
e.loop = loop_nest_.back();
Expr zero = make_const(e.loop->loop_var.type(), 0);
Expr one = make_const(e.loop->loop_var.type(), 1);
Expr two = make_const(e.loop->loop_var.type(), 2);
Expr loop_shift = e.loop->loop_var + one;
e.switch_write_var = Var(e.loop->loop_var->name_hint + ".db",
e.loop->loop_var.type());
e.switch_read_var = e.loop->loop_var % two;
in_double_buffer_scope_ = true;
Stmt body = Mutate(op->body);
in_double_buffer_scope_ = false;
std::unordered_map<const Variable*, Expr> vmap;
vmap[e.switch_write_var.get()] = zero;
vmap[e.loop->loop_var.get()] = zero;
loop_pre_[e.loop].emplace_back(Substitute(body, vmap));
vmap[e.loop->loop_var.get()] = loop_shift;
vmap[e.switch_write_var.get()] = loop_shift % two;
body = Substitute(body, vmap);
body = AttrStmt::make(buffer, attr::double_buffer_write, 1, body);
body = IfThenElse::make(loop_shift < e.loop->extent, body);
return body;
}
// Storage entry for those who need double buffering.
struct StorageEntry {
// The size of the buffer
Expr stride;
// The loop we need
const For* loop{nullptr};
// The switch variable.
VarExpr switch_write_var;
// The switch variable for reading.
Expr switch_read_var;
// The storage scope.
std::string scope;
};
// Whether split loop
int32_t split_loop_;
// Whether we are inside double buffer scope.
bool in_double_buffer_scope_{false};
// The current loop next
std::vector<const For*> loop_nest_;
// The allocs to be appended before the loop
std::unordered_map<const For*, std::vector<Stmt> > loop_allocs_;
// The stmt to be appended before the loop
std::unordered_map<const For*, std::vector<Stmt> > loop_pre_;
// The allocation size of the buffer
std::unordered_map<const Variable*, StorageEntry> dbuffer_info_;
};
Stmt InjectDoubleBuffer(Stmt stmt, int split_loop) {
return DoubleBufferInjector(split_loop).Inject(stmt);
}
} // namespace ir
} // namespace tvm