blob: 716ec625d5a8a645cc6663ad277a7eb69e3c2819 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file coproc_sync.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_map>
#include <unordered_set>
#include "ir_util.h"
#include "storage_access.h"
namespace tvm {
namespace tir {
// Visitor to find touched set by co-processor scope.
class CoProcTouchedBuffer : public StmtExprVisitor {
public:
void VisitExpr_(const LoadNode* op) final {
if (in_scope_) {
touched_[op->buffer_var.get()].coproc = true;
} else {
touched_[op->buffer_var.get()].normal = true;
}
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const StoreNode* op) final {
if (in_scope_) {
touched_[op->buffer_var.get()].coproc = true;
} else {
touched_[op->buffer_var.get()].normal = true;
}
StmtExprVisitor::VisitStmt_(op);
}
void VisitExpr_(const CallNode* op) final {
if (op->op.same_as(builtin::tvm_access_ptr())) {
const VarNode* buffer = op->args[1].as<VarNode>();
if (in_scope_) {
touched_[buffer].coproc = true;
} else {
touched_[buffer].normal = true;
}
}
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::coproc_scope && !in_scope_) {
in_scope_ = true;
IterVar iv = Downcast<IterVar>(op->node);
coproc_.insert(iv);
StmtExprVisitor::VisitStmt_(op);
in_scope_ = false;
} else {
StmtExprVisitor::VisitStmt_(op);
}
}
// Touch Entry
struct TouchEntry {
bool normal{false};
bool coproc{false};
};
std::unordered_map<const VarNode*, TouchEntry> touched_;
std::unordered_set<IterVar> coproc_;
private:
bool in_scope_{false};
};
// Synchronization planning with co-processor.
// Places "tir.<coproc>.coproc_sync" calls so that a normal-side access never
// conflicts with a co-processor access that has not yet been synchronized.
class CoProcSyncPlanner : public StorageAccessVisitor {
 public:
  /*!
   * \param touched Buffer vars whose accesses are tracked (see Enabled).
   * \param coproc_name Prefix used to build the sync intrinsic name.
   */
  explicit CoProcSyncPlanner(const std::unordered_set<const VarNode*>& touched,
                             const std::string& coproc_name)
      : touched_(touched), coproc_name_(coproc_name) {}
  /*!
   * \brief Plan syncs for stmt.
   * The top-level sequence is planned with a forced sync at its end; if no
   * sync was planned anywhere, one sync is placed after stmt as a whole.
   */
  void Plan(const Stmt& stmt) {
    this->VisitStmt(stmt);
    PlanSync(scope_.back(), nullptr, true);
    if (sync_.empty()) {
      sync_[stmt.get()] = GetSync(coproc_name_ + ".coproc_sync");
    }
  }
  // Sync statements to be inserted after the keyed stmt.
  std::unordered_map<const Object*, std::vector<Stmt> > sync_;

 protected:
  bool Enabled(const VarNode* buf, const StorageScope& scope) const final {
    return touched_.count(buf);
  }
  // Plan the sync for a nested sequence (loop body when loop != nullptr).
  std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq, const ForNode* loop) final {
    return PlanSync(std::move(seq), loop, false);
  }

 private:
  /*!
   * \brief Plan write synchronization if write is not coherent.
   *
   * Walks seq in order, collecting co-processor accesses (entries with
   * non-empty threads). When a normal access (empty threads) conflicts with a
   * pending co-processor access, a sync is placed after the previous stmt and
   * the pending set is cleared. kSync entries from nested scopes also clear
   * the pending set.
   *
   * \param seq The statements of one scope.
   * \param loop The enclosing loop, or nullptr; enables loop-carry checking.
   * \param force_sync_at_end Always sync pending co-processor accesses at the end.
   * \return Pending co-processor accesses, prefixed by a kSync entry when a
   *         sync point exists in this sequence.
   */
  std::vector<AccessEntry> PlanSync(std::vector<StmtEntry> seq, const ForNode* loop,
                                    bool force_sync_at_end) {
    // Co-processor accesses issued since the last sync point.
    std::vector<AccessEntry> co_access;
    // Whether a sync point was planned or observed in this sequence.
    bool contain_sync = false;
    // A normal access conflicts with a pending co-processor access on the same
    // buffer if it writes, or if it reads what the co-processor wrote.
    auto find_conflict = [&](const AccessEntry& acc) {
      for (const AccessEntry& x : co_access) {
        if (x.buffer.same_as(acc.buffer) &&
            ((acc.type == kRead && x.type == kWrite) || acc.type == kWrite)) {
          return true;
        }
      }
      return false;
    };
    for (size_t i = 0; i < seq.size(); ++i) {
      const StmtEntry& s = seq[i];
      bool sync_write = false;
      for (const AccessEntry& acc : s.access) {
        if (acc.threads.size() == 0 && find_conflict(acc)) {
          sync_write = true;
          break;
        }
        // A sync inside a nested scope retires every pending access.
        if (acc.type == kSync) {
          co_access.clear();
          contain_sync = true;
        }
      }
      if (sync_write) {
        // co_access is still empty while i == 0, so a conflict has a predecessor.
        CHECK_NE(i, 0U);
        sync_[seq[i - 1].stmt] = GetSync(co_access);
        co_access.clear();
        contain_sync = true;
      }
      for (const AccessEntry& acc : s.access) {
        if (acc.threads.size() != 0) {
          co_access.push_back(acc);
        }
      }
    }
    bool sync_at_end = force_sync_at_end;
    if (loop != nullptr && !sync_at_end) {
      // Loop carry dependency: normal accesses at the start of the next
      // iteration (up to the first planned sync) against what is still pending.
      for (size_t i = 0; i < seq.size(); ++i) {
        const StmtEntry& s = seq[i];
        for (const AccessEntry& acc : s.access) {
          if (acc.threads.size() == 0 && find_conflict(acc)) {
            sync_at_end = true;
            break;
          }
        }
        if (sync_.count(s.stmt) || sync_at_end) break;
      }
    }
    if (sync_at_end && co_access.size() != 0) {
      CHECK_NE(seq.size(), 0U);
      contain_sync = true;
      sync_[seq.back().stmt] = GetSync(co_access);
      co_access.clear();
    }
    if (contain_sync) {
      // Reported to the enclosing scope, where it is seen as a kSync entry.
      AccessEntry e;
      e.type = kSync;
      co_access.insert(co_access.begin(), e);
    }
    return co_access;
  }
  // Build the sync for a non-empty set of pending co-processor accesses.
  std::vector<Stmt> GetSync(const std::vector<AccessEntry>& co_access) {
    // Does not consider memory coherence, need runtime.
    CHECK_NE(co_access.size(), 0U);
    CHECK_EQ(co_access[0].threads.size(), 1U);
    return GetSync(coproc_name_ + ".coproc_sync");
  }
  // A single Evaluate(call) of the intrinsic "tir." + sync_name.
  std::vector<Stmt> GetSync(const std::string& sync_name) {
    return {Evaluate(Call(DataType::Int(32), Op::Get("tir." + sync_name), {}))};
  }
  const std::unordered_set<const VarNode*>& touched_;
  std::string coproc_name_;
};
// Detect memory barriers when coproc read/write memory.
// The same detector can run both passes: PlanReadBarrier then PlanWriteBarrier.
class CoProcBarrierDetector : public StorageAccessVisitor {
 public:
  explicit CoProcBarrierDetector(const std::unordered_set<const VarNode*>& touched,
                                 const std::string& coproc_name)
      : touched_(touched) {
    read_barrier_name_ = "tir." + coproc_name + ".coproc_read_barrier";
    write_barrier_name_ = "tir." + coproc_name + ".coproc_write_barrier";
  }
  // Plan read barriers (recorded in barrier_before_) for stmt.
  void PlanReadBarrier(const Stmt& stmt) {
    read_barrier_ = true;
    // Drop top-level entries left by an earlier pass on this detector so they
    // are not planned a second time (duplicate / wrap-around barriers).
    // NOTE(review): relies on VisitStmt appending into the root scope; the
    // clear is a no-op on a fresh detector.
    scope_.back().clear();
    this->VisitStmt(stmt);
    PlanReadBarrier(scope_.back(), nullptr);
  }
  // Plan write barriers (recorded in barrier_after_) for stmt.
  void PlanWriteBarrier(const Stmt& stmt) {
    read_barrier_ = false;
    // See PlanReadBarrier: reset the root scope before re-visiting.
    scope_.back().clear();
    this->VisitStmt(stmt);
    PlanWriteBarrier(scope_.back(), nullptr);
  }
  std::unordered_map<const Object*, std::vector<Stmt> > barrier_before_;
  std::unordered_map<const Object*, std::vector<Stmt> > barrier_after_;

 protected:
  bool Enabled(const VarNode* buf, const StorageScope& scope) const final {
    return touched_.count(buf);
  }
  // Plan the barriers of a nested sequence according to the current pass.
  std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq, const ForNode* loop) final {
    if (read_barrier_) {
      return PlanReadBarrier(std::move(seq), loop);
    } else {
      return PlanWriteBarrier(std::move(seq), loop);
    }
  }

 private:
  // Plan write barrier at read-after-write point: a co-processor write
  // followed by a normal read of the same buffer gets a barrier after the
  // statement preceding the read.
  std::vector<AccessEntry> PlanWriteBarrier(std::vector<StmtEntry> seq, const ForNode* loop) {
    std::vector<AccessEntry> read_seq;
    std::unordered_map<const VarNode*, std::vector<AccessEntry> > write_set;
    auto fupdate = [&](size_t i, const AccessEntry& acc) {
      auto it = write_set.find(acc.buffer.get());
      if (it != write_set.end()) {
        CHECK_NE(i, 0U);
        barrier_after_[seq[i - 1].stmt].push_back(MakeBarrier(write_barrier_name_, it->second));
        write_set.erase(it);
      }
    };
    for (size_t i = 0; i < seq.size(); ++i) {
      const StmtEntry& s = seq[i];
      for (const AccessEntry& acc : s.access) {
        if (acc.threads.size() == 0 && acc.type == kRead) {
          fupdate(i, acc);
          read_seq.push_back(acc);
        }
      }
      for (const AccessEntry& acc : s.access) {
        if (acc.threads.size() != 0 && acc.type == kWrite) {
          write_set[acc.buffer.get()].push_back(acc);
        }
      }
    }
    // loop carry: reads of the next iteration against writes still pending.
    if (loop != nullptr) {
      for (const AccessEntry& acc : read_seq) {
        fupdate(seq.size(), acc);
      }
    }
    for (const auto& kv : write_set) {
      read_seq.insert(read_seq.end(), kv.second.begin(), kv.second.end());
    }
    return read_seq;
  }
  // Plan read barrier at write-after-read point: scanning backwards, a normal
  // write that follows a co-processor read of the same buffer gets a barrier
  // before the statement following that write's statement.
  std::vector<AccessEntry> PlanReadBarrier(std::vector<StmtEntry> seq, const ForNode* loop) {
    std::vector<AccessEntry> write_seq;
    std::unordered_map<const VarNode*, std::vector<AccessEntry> > read_set;
    auto fupdate = [&](size_t i, const AccessEntry& acc) {
      auto it = read_set.find(acc.buffer.get());
      if (it != read_set.end()) {
        CHECK_NE(i, seq.size());
        barrier_before_[seq[i].stmt].push_back(MakeBarrier(read_barrier_name_, it->second));
        read_set.erase(it);
      }
    };
    for (size_t i = seq.size(); i != 0; --i) {
      const StmtEntry& s = seq[i - 1];
      for (const AccessEntry& acc : s.access) {
        if (acc.threads.size() == 0 && acc.type == kWrite) {
          fupdate(i, acc);
          write_seq.push_back(acc);
        }
      }
      for (const AccessEntry& acc : s.access) {
        if (acc.threads.size() != 0 && acc.type == kRead) {
          read_set[acc.buffer.get()].push_back(acc);
        }
      }
    }
    // loop carry: writes of the previous iteration against reads still pending.
    if (loop != nullptr) {
      for (const AccessEntry& acc : write_seq) {
        fupdate(0, acc);
      }
    }
    for (const auto& kv : read_set) {
      write_seq.insert(write_seq.end(), kv.second.begin(), kv.second.end());
    }
    return write_seq;
  }
  // Build func(buffer, dtype bits, min, extent) covering the union of the
  // touched regions in wvec; all entries must share one dtype.
  Stmt MakeBarrier(const std::string& func, const std::vector<AccessEntry>& wvec) {
    Array<arith::IntSet> wset;
    for (const AccessEntry& acc : wvec) {
      CHECK(acc.dtype == wvec[0].dtype);
      wset.push_back(acc.touched);
    }
    Range none;
    Range r = arith::Union(wset).CoverRange(none);
    CHECK(r.defined()) << "Cannot deduce write range of " << wvec[0].buffer;
    return Evaluate(Call(DataType::Int(32), Op::Get(func),
                         {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent}));
  }
  // Current pass: true while planning read barriers, false for write barriers.
  bool read_barrier_{false};
  std::string read_barrier_name_;
  std::string write_barrier_name_;
  const std::unordered_set<const VarNode*>& touched_;
};
// Plans instruction dependency calls between co-processor contexts.
// Each coproc_scope attribute on coproc_axis_ carries an integer context id
// (its value). When control can pass from one context to a different one, a
// "tir.<coproc>.coproc_dep_push(from, to)" call is placed after the earlier
// region and a matching "coproc_dep_pop(from, to)" before the later region.
// Results are reported through insert_before_ / insert_after_.
class CoProcInstDepDetector : public StmtVisitor {
 public:
  explicit CoProcInstDepDetector(const IterVar& coproc_axis, const std::string& coproc_name)
      : coproc_axis_(coproc_axis) {
    sync_push_op_ = Op::Get("tir." + coproc_name + ".coproc_dep_push");
    sync_pop_op_ = Op::Get("tir." + coproc_name + ".coproc_dep_pop");
  }
  // Visit stmt and plan push/pop calls. Pops still pending at the entry of
  // the first state and pushes still pending at the exit of the last state
  // are balanced locally, since nothing precedes/follows them at this level.
  void Plan(const Stmt& stmt) {
    this->VisitStmt(stmt);
    if (last_state_.node != nullptr) {
      MatchFixEnterPop(first_state_);
      MatchFixExitPush(last_state_);
    }
  }
  // A coproc_scope on coproc_axis_ becomes one leaf state (keyed on the
  // attribute's body) whose entry and exit context is the attribute value.
  // The body itself is not visited.
  void VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::coproc_scope && op->node.same_as(coproc_axis_)) {
      const IntImmNode* ctx_id = op->value.as<IntImmNode>();
      CHECK(ctx_id != nullptr);
      curr_state_.clear();
      curr_state_.node = op->body.get();
      curr_state_.enter_ctx.insert(ctx_id->value);
      curr_state_.exit_ctx.insert(ctx_id->value);
      UpdateState();
    } else {
      StmtVisitor::VisitStmt_(op);
    }
  }
  // The loop body is planned in isolation (outer first/last states are
  // swapped out and restored). A loop-carried dependency from the body's last
  // state to its first state is then injected; the events left unmatched on
  // that back edge become the loop state's exit_push / enter_pop.
  void VisitStmt_(const ForNode* op) final {
    SyncState temp_first, temp_last;
    std::swap(first_state_, temp_first);
    std::swap(last_state_, temp_last);
    this->VisitStmt(op->body);
    curr_state_.clear();
    if (last_state_.node != nullptr) {
      curr_state_.node = op;
      CHECK(first_state_.node != nullptr);
      // loop carry dependency
      InjectSync(last_state_, first_state_, &(curr_state_.exit_push), &(curr_state_.enter_pop));
      curr_state_.enter_ctx = first_state_.enter_ctx;
      curr_state_.exit_ctx = last_state_.exit_ctx;
    }
    std::swap(first_state_, temp_first);
    std::swap(last_state_, temp_last);
    if (curr_state_.node != nullptr) {
      UpdateState();
    }
  }
  // Each branch is planned in isolation and its pending enter-pops and
  // exit-pushes are balanced inside the branch. The if statement then acts as
  // one state whose entry/exit context sets are the union over its branches.
  void VisitStmt_(const IfThenElseNode* op) final {
    SyncState temp_first, temp_last, curr_state;
    std::swap(first_state_, temp_first);
    std::swap(last_state_, temp_last);
    {
      // then stmt
      this->VisitStmt(op->then_case);
      if (last_state_.node != nullptr) {
        curr_state.node = op;
        MatchFixEnterPop(first_state_);
        MatchFixExitPush(last_state_);
        curr_state.enter_ctx.insert(first_state_.enter_ctx.begin(), first_state_.enter_ctx.end());
        curr_state.exit_ctx.insert(last_state_.exit_ctx.begin(), last_state_.exit_ctx.end());
      }
      first_state_.clear();
      last_state_.clear();
    }
    if (op->else_case.defined()) {
      this->VisitStmt(op->else_case);
      if (last_state_.node != nullptr) {
        curr_state.node = op;
        MatchFixEnterPop(first_state_);
        MatchFixExitPush(last_state_);
        curr_state.enter_ctx.insert(first_state_.enter_ctx.begin(), first_state_.enter_ctx.end());
        curr_state.exit_ctx.insert(last_state_.exit_ctx.begin(), last_state_.exit_ctx.end());
      }
    }
    // update in the trace.
    std::swap(first_state_, temp_first);
    std::swap(last_state_, temp_last);
    std::swap(curr_state_, curr_state);
    if (curr_state_.node != nullptr) {
      UpdateState();
    }
  }
  // Statements to emit before / after the keyed node.
  // NOTE(review): an older note said insert_before_ is "stored in reverse
  // order" with the first element closest to the node. CoProcSyncInserter
  // passes each vector in order to SeqStmt::Flatten ahead of the node, so the
  // last element appears to be the closest one. Confirm before relying on it.
  std::unordered_map<const Object*, std::vector<Stmt> > insert_before_;
  std::unordered_map<const Object*, std::vector<Stmt> > insert_after_;

 private:
  // state in the sync entry
  struct SyncState {
    // The statement of the state.
    const Object* node{nullptr};
    // Set of all possible contexts in the entering moment.
    std::unordered_set<int> enter_ctx;
    // Set of all possible contexts in the exit moment.
    std::unordered_set<int> exit_ctx;
    // existing pop performed at enter
    std::vector<std::pair<int, int> > enter_pop;
    // existing push performed at exit
    std::vector<std::pair<int, int> > exit_push;
    // clear the state
    void clear() {
      node = nullptr;
      enter_ctx.clear();
      exit_ctx.clear();
      enter_pop.clear();
      exit_push.clear();
    }
  };
  // Inject proper sync between prev (earlier) and next (later).
  // Records the push/pop sequence that could possibly be un-matched, and
  // returns in *prev_exit_push / *next_enter_pop the push/pop events at the
  // exit/enter of the pair after considering existing unmatched and added events.
  void InjectSync(const SyncState& prev, const SyncState& next,
                  std::vector<std::pair<int, int> >* prev_exit_push,
                  std::vector<std::pair<int, int> >* next_enter_pop) {
    prev_exit_push->clear();
    next_enter_pop->clear();
    // quick path: one context on each side and no pending events; a push/pop
    // pair is needed only when the two contexts differ.
    if (prev.exit_push.size() == 0 && next.enter_pop.size() == 0 && prev.exit_ctx.size() == 1 &&
        next.enter_ctx.size() == 1) {
      int from = *prev.exit_ctx.begin();
      int to = *next.enter_ctx.begin();
      if (from != to) {
        insert_after_[prev.node].emplace_back(MakePush(from, to));
        insert_before_[next.node].emplace_back(MakePop(from, to));
        prev_exit_push->emplace_back(std::make_pair(from, to));
        next_enter_pop->emplace_back(std::make_pair(from, to));
      }
      return;
    }
    // complicated path: every (exit ctx, enter ctx) pair with differing
    // contexts may need a dependency.
    std::vector<std::pair<int, int> > vpush = prev.exit_push;
    std::vector<std::pair<int, int> > vpop = next.enter_pop;
    std::vector<std::pair<int, int> > pending;
    for (int from : prev.exit_ctx) {
      for (int to : next.enter_ctx) {
        if (from != to) {
          pending.emplace_back(std::make_pair(from, to));
        }
      }
    }
    // policy 1: add a push after prev / pop before next for each pending pair
    // not already pushed / popped.
    std::vector<Stmt> prev_after, next_before;
    for (const std::pair<int, int>& p : pending) {
      if (std::find(prev.exit_push.begin(), prev.exit_push.end(), p) == prev.exit_push.end()) {
        vpush.push_back(p);
        prev_after.emplace_back(MakePush(p.first, p.second));
      }
      if (std::find(next.enter_pop.begin(), next.enter_pop.end(), p) == next.enter_pop.end()) {
        vpop.push_back(p);
        next_before.emplace_back(MakePop(p.first, p.second));
      }
    }
    // fix pending: a push with no matching pop is popped right after prev, a
    // pop with no matching push gets a push right before next; matched pairs
    // are reported through the output vectors.
    for (const std::pair<int, int>& p : vpush) {
      if (std::find(vpop.begin(), vpop.end(), p) == vpop.end()) {
        prev_after.emplace_back(MakePop(p.first, p.second));
      } else {
        prev_exit_push->push_back(p);
      }
    }
    for (const std::pair<int, int>& p : vpop) {
      if (std::find(vpush.begin(), vpush.end(), p) == vpush.end()) {
        next_before.emplace_back(MakePush(p.first, p.second));
      } else {
        next_enter_pop->push_back(p);
      }
    }
    if (prev_after.size() != 0) {
      auto& v1 = insert_after_[prev.node];
      v1.insert(v1.end(), prev_after.begin(), prev_after.end());
    }
    if (next_before.size() != 0) {
      auto& v2 = insert_before_[next.node];
      v2.insert(v2.end(), next_before.begin(), next_before.end());
    }
  }
  // Balance pops the state performs at entry by pushing right before it.
  void MatchFixEnterPop(const SyncState& state) {
    if (state.enter_pop.size() == 0) return;
    auto& vec = insert_before_[state.node];
    for (const std::pair<int, int>& p : state.enter_pop) {
      vec.push_back(MakePush(p.first, p.second));
    }
  }
  // Balance pushes the state performs at exit by popping right after it.
  void MatchFixExitPush(const SyncState& state) {
    if (state.exit_push.size() == 0) return;
    auto& vec = insert_after_[state.node];
    for (const std::pair<int, int>& p : state.exit_push) {
      vec.push_back(MakePop(p.first, p.second));
    }
  }
  // Append curr_state_ to the trace: inject sync from last_state_ to it, or
  // make it both first and last state when it is the first one seen.
  // NOTE(review): the events returned by InjectSync (t1, t2) are discarded here.
  void UpdateState() {
    if (last_state_.node != nullptr) {
      std::vector<std::pair<int, int> > t1, t2;
      InjectSync(last_state_, curr_state_, &t1, &t2);
      std::swap(last_state_, curr_state_);
    } else {
      CHECK(first_state_.node == nullptr);
      first_state_ = curr_state_;
      last_state_ = curr_state_;
    }
  }
  // Evaluate(coproc_dep_push(from, to)).
  Stmt MakePush(int from, int to) {
    return Evaluate(Call(DataType::Int(32), sync_push_op_,
                         {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}));
  }
  // Evaluate(coproc_dep_pop(from, to)).
  Stmt MakePop(int from, int to) {
    return Evaluate(Call(DataType::Int(32), sync_pop_op_,
                         {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}));
  }
  // sync states.
  SyncState first_state_, last_state_, curr_state_;
  // Variables
  IterVar coproc_axis_;
  Op sync_push_op_, sync_pop_op_;
};
class CoProcSyncInserter : public StmtMutator {
public:
Stmt Insert(Stmt stmt) {
CoProcTouchedBuffer visitor;
visitor(stmt);
if (visitor.coproc_.size() == 0) return stmt;
std::unordered_set<const VarNode*> touched;
for (const auto& kv : visitor.touched_) {
if (kv.second.normal && kv.second.coproc) {
touched.insert(kv.first);
}
}
CHECK_EQ(visitor.coproc_.size(), 1U);
std::string coproc_name = (*visitor.coproc_.begin())->var->name_hint;
// plan sync.
CoProcSyncPlanner sync_planner(touched, coproc_name);
sync_planner.Plan(stmt);
for (const auto& kv : sync_planner.sync_) {
auto& vec = insert_after_[kv.first];
vec.insert(vec.end(), kv.second.begin(), kv.second.end());
}
// Detect barrier
CoProcBarrierDetector barrier_detector(touched, coproc_name);
barrier_detector.PlanReadBarrier(stmt);
barrier_detector.PlanWriteBarrier(stmt);
for (const auto& kv : barrier_detector.barrier_before_) {
auto& vec = insert_before_[kv.first];
vec.insert(vec.end(), kv.second.begin(), kv.second.end());
}
for (const auto& kv : barrier_detector.barrier_after_) {
auto& vec = insert_after_[kv.first];
vec.insert(vec.end(), kv.second.begin(), kv.second.end());
}
// Detect barrier
CoProcInstDepDetector sync_detector(*visitor.coproc_.begin(), coproc_name);
sync_detector.Plan(stmt);
for (const auto& kv : sync_detector.insert_before_) {
auto& vec = insert_before_[kv.first];
vec.insert(vec.end(), kv.second.begin(), kv.second.end());
}
for (const auto& kv : sync_detector.insert_after_) {
auto& vec = insert_after_[kv.first];
vec.insert(vec.end(), kv.second.begin(), kv.second.end());
}
return operator()(std::move(stmt));
}
Stmt VisitStmt(const Stmt& stmt) final {
auto it_before = insert_before_.find(stmt.get());
auto it_after = insert_after_.find(stmt.get());
Stmt new_stmt = StmtMutator::VisitStmt(stmt);
return SeqStmt::Flatten(
it_before != insert_before_.end() ? it_before->second : std::vector<Stmt>(), new_stmt,
it_after != insert_after_.end() ? it_after->second : std::vector<Stmt>());
}
private:
// insert before is stored in reverse order
// the first element is closest to the node.
std::unordered_map<const Object*, std::vector<Stmt> > insert_before_;
std::unordered_map<const Object*, std::vector<Stmt> > insert_after_;
};
// Insert co-processor synchronization into stmt (see CoProcSyncInserter::Insert).
Stmt CoProcSync(Stmt stmt) {
  CoProcSyncInserter inserter;
  return inserter.Insert(std::move(stmt));
}
namespace transform {
// PrimFunc pass that rewrites each function body with co-processor
// synchronization inserted; registered as "tir.transform.CoProcSync".
Pass CoProcSync() {
  auto insert_sync = [](PrimFunc func, IRModule, PassContext) {
    auto* func_node = func.CopyOnWrite();
    func_node->body = CoProcSyncInserter().Insert(std::move(func_node->body));
    return func;
  };
  const int opt_level = 0;
  return CreatePrimFuncPass(insert_sync, opt_level, "tir.CoProcSync", {});
}
TVM_REGISTER_GLOBAL("tir.transform.CoProcSync").set_body_typed(CoProcSync);
} // namespace transform
} // namespace tir
} // namespace tvm