// blob: 7faf7d1b173e16932a97e4f7c8855d04000238af [file] [log] [blame]
/*!
* Copyright (c) 2017 by Contributors
* \file narrow_channel_access.cc
* \brief Narrow channel access to a smaller range
* when possible by bringing it to the internal loop.
*/
#include <tvm/ir.h>
#include <tvm/expr.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/arithmetic.h>
#include <tvm/channel.h>
#include "ir_util.h"
namespace tvm {
namespace ir {
using namespace arith;
// Bound deducer for channel access.
class ChannelAccessBound : public IRVisitor {
public:
ChannelAccessBound(const Variable* buf_var, bool read_access)
: buf_var_(buf_var), read_access_(read_access) {}
void Visit_(const Store* op) final {
if (!read_access_ && buf_var_ == op->buffer_var.get()) {
ret_.emplace_back(EvalSet(op->index, dom_map_));
}
IRVisitor::Visit_(op);
}
void Visit_(const For* op) final {
CHECK(is_zero(op->min));
// We know that the extent of the loop won't depend on relaxed scope.
// TODO(tqchen) have a verification pass.
dom_map_[op->loop_var.get()] = IntSet::interval(op->min, op->extent - 1);
IRVisitor::Visit_(op);
}
void Visit_(const Load* op) final {
if (read_access_ && buf_var_ == op->buffer_var.get()) {
ret_.emplace_back(EvalSet(op->index, dom_map_));
}
IRVisitor::Visit_(op);
}
void Visit_(const Let* op) final {
LOG(FATAL) << "cannot pass through let";
}
void Visit_(const LetStmt* op) final {
LOG(FATAL) << "cannot pass through let";
}
IntSet Eval(const Stmt& stmt) {
Visit(stmt);
return Union(ret_);
}
private:
// The buffer variable.
const Variable* buf_var_;
// read or write
bool read_access_{true};
// Box
std::vector<IntSet> ret_;
// Domain map.
std::unordered_map<const Variable*, IntSet> dom_map_;
};
// Mutator that rebases the index of the accesses made through one buffer
// variable: every matching index becomes Simplify(index - min). Loads are
// rewritten when read_access is true, Stores otherwise; all other nodes are
// returned as produced by the default mutator.
class ChannelAccessIndexRewriter : public IRMutator {
 public:
  // buf_var: the buffer variable whose accesses are rewritten.
  // min: the offset subtracted from each matching index.
  // read_access: true to rewrite Loads, false to rewrite Stores.
  ChannelAccessIndexRewriter(const Variable* buf_var,
                             Expr min,
                             bool read_access)
      : buf_var_(buf_var), min_(min), read_access_(read_access) {}

  // Mutate the children first, then rebase the index of a matching Store.
  Stmt Mutate_(const Store* op, const Stmt& s) final {
    Stmt mutated = IRMutator::Mutate_(op, s);
    const Store* store = mutated.as<Store>();
    if (read_access_ || store->buffer_var.get() != buf_var_) {
      return mutated;
    }
    Expr rebased = ir::Simplify(store->index - min_);
    return Store::make(
        store->buffer_var, store->value, rebased, store->predicate);
  }

  // Mutate the children first, then rebase the index of a matching Load.
  Expr Mutate_(const Load* op, const Expr& e) final {
    Expr mutated = IRMutator::Mutate_(op, e);
    const Load* load = mutated.as<Load>();
    if (!read_access_ || load->buffer_var.get() != buf_var_) {
      return mutated;
    }
    Expr rebased = ir::Simplify(load->index - min_);
    return Load::make(
        load->type, load->buffer_var, rebased, load->predicate);
  }

 private:
  // The buffer variable.
  const Variable* buf_var_;
  // The min bound subtracted from matching indices.
  Expr min_;
  // Rewrite reads (true) or writes (false).
  bool read_access_{true};
};
// Rewrite channel access pattern.
//
// A "window" is a channel_read_scope / channel_write_scope AttrStmt whose
// body is directly the matching channel_*_advance AttrStmt. Every window met
// during the walk is queued as a task. When a For loop is reached, each queued
// task is tried against that loop (see RewriteAccess); on success a narrower
// window is emitted inside the loop and the original window and advance
// attributes are dropped around it.
class ChannelAccessRewriter : public IRMutator {
 public:
  // Queue a task for window attributes, mutate the subtree, and strip the
  // window/advance pair if some loop underneath managed to absorb it.
  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    const AttrStmt* adv = op->body.as<AttrStmt>();
    const bool read_window =
        op->attr_key == ir::attr::channel_read_scope &&
        adv != nullptr &&
        adv->attr_key == ir::attr::channel_read_advance;
    const bool write_window =
        op->attr_key == ir::attr::channel_write_scope &&
        adv != nullptr &&
        adv->attr_key == ir::attr::channel_write_advance;
    if (!read_window && !write_window) {
      return IRMutator::Mutate_(op, s);
    }

    RewriteEntry entry;
    entry.window = op;
    entry.advance = adv;
    entry.read_access = op->attr_key == ir::attr::channel_read_scope;
    tasks_.push_back(entry);

    Stmt result = IRMutator::Mutate_(op, s);
    if (tasks_.back().rewrite_success) {
      // Keep only what sits below the window and its advance attribute.
      const AttrStmt* outer = result.as<AttrStmt>();
      const AttrStmt* inner = outer->body.as<AttrStmt>();
      result = inner->body;
    }
    tasks_.pop_back();
    return result;
  }

  // Try every queued task against this loop. The queue is emptied while the
  // loop is processed and restored before returning.
  Stmt Mutate_(const For* op, const Stmt& s) final {
    std::vector<RewriteEntry> pending;
    std::swap(tasks_, pending);

    std::vector<Stmt> nest;
    Stmt new_body = op->body;
    for (RewriteEntry& entry : pending) {
      new_body = RewriteAccess(op, new_body, &entry, &nest);
    }

    Stmt result;
    if (new_body.same_as(op->body)) {
      // Nothing was rewritten, so nothing may have been queued outside.
      CHECK_EQ(nest.size(), 0U);
      result = IRMutator::Mutate_(op, s);
    } else {
      // Continue the walk on the rewritten body, rebuild the loop around it
      // and wrap it with the statements collected for the outside.
      new_body = Mutate(new_body);
      Stmt loop = For::make(
          op->loop_var, op->min, op->extent,
          op->for_type, op->device_api, new_body);
      result = MergeNest(nest, loop);
    }

    std::swap(tasks_, pending);
    return result;
  }

 private:
  // One pending window that may be narrowed by a loop below it.
  struct RewriteEntry {
    // True for a read window, false for a write window.
    bool read_access;
    // The scope attribute; its value is the window size.
    const AttrStmt* window;
    // The advance attribute found directly under the window.
    const AttrStmt* advance;
    // Set once a loop has absorbed this window.
    bool rewrite_success{false};
  };

  // Try to move the window described by e inside the loop for_op.
  //
  // body is returned untouched when any of these holds:
  //  - the extent of the range covering the accesses uses the loop variable;
  //  - the min of that range is not detected as linear in the loop variable;
  //  - the constant term of that linear form is not zero;
  //  - advance - coeff * loop_extent cannot be proven non-negative.
  //
  // Otherwise the matching access indices in body are shifted by
  // loop_var * coeff, body is wrapped in a scope attribute (value: the range
  // extent) holding an advance attribute (value: coeff), a non-zero remainder
  // is appended to outer_nest as an advance attribute over a no-op, and
  // e->rewrite_success is set.
  Stmt RewriteAccess(const For* for_op,
                     Stmt body,
                     RewriteEntry* e,
                     std::vector<Stmt>* outer_nest) {
    const AttrStmt* adv_op = e->advance;
    const bool read_access = e->read_access;
    const Expr& window = e->window->value;
    Var var(for_op->loop_var);
    Channel ch(adv_op->node.node_);

    // Range covering the tracked accesses made in the original loop body.
    ChannelAccessBound acc(ch->handle_var.get(), read_access);
    IntSet iset = acc.Eval(for_op->body);
    Range r = iset.cover_range(Range::make_by_min_extent(0, window));
    r = Range::make_by_min_extent(
        ir::Simplify(r->min), ir::Simplify(r->extent));
    if (ExprUseVar(r->extent, var)) {
      return body;
    }

    // The range min has to be exactly coeff * loop_var.
    Array<Expr> linear_eq = DetectLinearEquation(r->min, {var});
    if (linear_eq.size() == 0) {
      return body;
    }
    Expr coeff = linear_eq[0];
    Expr base = linear_eq[1];
    if (!is_zero(base)) {
      return body;
    }

    // What remains of the original advance once the loop has advanced
    // coeff per iteration.
    Expr left = ir::Simplify(adv_op->value - coeff * for_op->extent);
    if (!can_prove(left >= 0)) {
      return body;
    }

    // Rewrite access index.
    ChannelAccessIndexRewriter rw(
        ch->handle_var.get(), var * coeff, read_access);
    body = rw.Mutate(body);

    // Attribute keys for the kind of access being rewritten.
    const auto scope_key = read_access
        ? ir::attr::channel_read_scope
        : ir::attr::channel_write_scope;
    const auto advance_key = read_access
        ? ir::attr::channel_read_advance
        : ir::attr::channel_write_advance;

    Stmt advance_stmt = AttrStmt::make(ch, advance_key, coeff, body);
    body = AttrStmt::make(ch, scope_key, r->extent, advance_stmt);

    if (!is_zero(left)) {
      Stmt no_op = Evaluate::make(0);
      outer_nest->emplace_back(
          AttrStmt::make(ch, advance_key, left, no_op));
    }
    e->rewrite_success = true;
    return body;
  }

  // Windows seen on the way down that no loop has processed yet.
  std::vector<RewriteEntry> tasks_;
};
// Entry point of the pass: run a fresh ChannelAccessRewriter over stmt and
// return the statement it produces.
Stmt NarrowChannelAccess(Stmt stmt) {
  ChannelAccessRewriter rewriter;
  return rewriter.Mutate(stmt);
}
} // namespace ir
} // namespace tvm