/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
 * \file thread_storage_sync.cc
 * \brief Insert storage synchronization (e.g. syncthreads) between thread accesses.
 */
#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_map>
#include <unordered_set>
#include "../../runtime/thread_storage_scope.h"
#include "ir_utils.h"
#include "storage_access.h"
namespace tvm {
namespace tir {
class ThreadSyncPlanner : public StorageAccessVisitor {
public:
explicit ThreadSyncPlanner(StorageScope sync_scope) : sync_scope_(sync_scope) {}
// Statements before which a sync needs to be inserted.
std::unordered_set<const Object*> syncs_inserted_;
protected:
bool Enabled(const VarNode* buf, const StorageScope& scope) const final {
return in_device_env() && scope == sync_scope_;
}
// Plan the sync
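// Given the sequence of statements (and the loop that contains them, if any),
// decide before which statements a sync must be inserted:
//  1. scan the sequence and request a sync wherever a read conflicts with an
//     earlier unsynced write, or a write conflicts with an earlier unsynced read;
//  2. if the sequence is a loop body, scan the start of the sequence once more
//     against the accesses left over from step 1 to catch loop-carried conflicts;
//  3. return the accesses that remain exposed (those before the first sync and
//     after the last sync) so the enclosing scope can plan around them.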
std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq, const ForNode* loop) final {
// Unsynced reads and writes
std::vector<AccessEntry> reads;
std::vector<AccessEntry> writes;
// If it is a loop, rotate two times to consider the effect of the loop.
// Simulation-based approach to find dependencies.
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry& s = seq[i];
// check if sync before statement is needed.
bool sync_before_stmt = (syncs_inserted_.count(s.stmt) != 0);
// Apply the syncs added already.
if (sync_before_stmt) {
reads.clear();
writes.clear();
}
for (const AccessEntry& acc : s.access) {
if (acc.type == kRead) {
if (FindConflict(writes, acc, false)) {
sync_before_stmt = true;
break;
}
} else if (acc.type == kWrite) {
if (FindConflict(reads, acc, false)) {
sync_before_stmt = true;
break;
}
} else if (acc.type == kSync) {
reads.clear();
writes.clear();
}
}
// If a sync is inserted, remove the now-irrelevant entries.
if (sync_before_stmt) {
reads.clear();
writes.clear();
}
// Add the reads/writes of the current statement.
for (const AccessEntry& acc : s.access) {
if (acc.type == kRead) {
reads.push_back(acc);
} else if (acc.type == kWrite) {
writes.push_back(acc);
} else if (acc.type == kSync) {
reads.clear();
writes.clear();
}
}
if (sync_before_stmt) {
ICHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside condition";
syncs_inserted_.insert(s.stmt);
}
}
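// For a loop body, run one more pass against the accesses left over from the
// end of the previous iteration to detect loop-carried conflicts (e.g. a write
// at the end of iteration i read at the start of iteration i+1). One sync at
// the top of the loop body is enough, so stop at the first conflict or at an
// already-planned sync.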
if (loop != nullptr) {
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry& s = seq[i];
if (syncs_inserted_.count(s.stmt) != 0) break;
if (reads.empty() && writes.empty()) break;
bool sync_before_stmt = false;
for (const AccessEntry& acc : s.access) {
if (acc.type == kRead) {
if (FindConflict(writes, acc, true)) {
sync_before_stmt = true;
break;
}
} else if (acc.type == kWrite) {
if (FindConflict(reads, acc, true)) {
sync_before_stmt = true;
break;
}
} else if (acc.type == kSync) {
reads.clear();
writes.clear();
}
}
if (sync_before_stmt) {
ICHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside condition";
syncs_inserted_.insert(s.stmt);
break;
}
}
}
// Return the exposed entries, removing unnecessary ones.
int sync_count = 0;
// head holds accesses before the first sync; tail holds accesses after the last sync.
std::vector<AccessEntry> head, tail;
AccessEntry esync;
esync.threads = this->env_threads();
esync.type = kSync;
esync.scope = sync_scope_;
for (const StmtEntry& s : seq) {
if (syncs_inserted_.count(s.stmt)) {
if (sync_count != 0) {
tail.clear();
} else {
head.push_back(esync);
}
++sync_count;
}
for (const AccessEntry& acc : s.access) {
if (acc.type == kSync) {
if (sync_count != 0) {
tail.clear();
} else {
head.push_back(esync);
}
++sync_count;
} else {
if (sync_count != 0) {
tail.push_back(acc);
} else {
head.push_back(acc);
}
}
}
}
head.insert(head.end(), tail.begin(), tail.end());
if (loop != nullptr) {
// Clear the double-buffer flags after the loop is finished.
for (AccessEntry& e : head) {
e.double_buffer_write = false;
}
}
return head;
}
private:
// Find a conflicting entry in prev.
bool FindConflict(const std::vector<AccessEntry>& prev, const AccessEntry& curr,
bool loop_carry) {
for (const AccessEntry& x : prev) {
if (FindConflict(x, curr, loop_carry)) {
return true;
}
}
return false;
}
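// Check whether two accesses to the same buffer may conflict; loop_carry
// indicates that curr happens in a later loop iteration than prev.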
bool FindConflict(const AccessEntry& prev, const AccessEntry& curr, bool loop_carry) {
// Accesses to different buffers do not conflict.
if (!prev.buffer.same_as(curr.buffer)) {
return false;
}
// Assumes no race between threads.
// The same index value means no conflict.
// TODO(tqchen) more standard set based testing.
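// Conservative check: only when every touched dimension is a provably identical
// single index point in both accesses do we conclude they hit the same address;
// anything less certain falls through to the conflict path below.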
bool has_same_index = true;
for (size_t i = 0; i < prev.touched.size(); i++) {
const auto& prev_intset = prev.touched[i];
const auto& curr_intset = curr.touched[i];
bool provably_same_index =
prev_intset.IsSinglePoint() && curr_intset.IsSinglePoint() &&
ExprDeepEqual()(prev_intset.PointValue(), curr_intset.PointValue());
if (!provably_same_index) {
has_same_index = false;
break;
}
}
if (has_same_index) {
return false;
}
// A read that follows a double-buffer write does not conflict within the
// same iteration: the read targets the half that was previously swapped out.
if (prev.double_buffer_write && curr.type == kRead && !loop_carry) {
return false;
}
// If nothing else allows sharing the same buffer, then they are
// in conflict.
return true;
}
private:
// synchronization scope
StorageScope sync_scope_;
};
// There are cases where a necessary syncthreads is not inserted by ThreadSyncInserter.
// For example, a syncthreads is needed after async_wait_queue in the second loop below,
// but since ThreadSyncInserter is not aware of the asynchronous semantics, it cannot tell
// that the syncthreads is needed there.
//
// // Pipeline prologue
// for i in range(125):
// async_commit_queue(0):
// async_scope:
// shared[(i + 3) % 4] = ...
// ...
//
// // Pipeline Epilogue
// for i in range(3):
// async_wait_queue(0, 2 - i):
// local[...] = shared[(i + 125) % 4]
// This class adds syncthreads after all async_wait_queue. That includes syncthreads that
// can be inserted by ThreadSyncInserter as well, but ThreadSyncInserter will not insert
// duplicate syncthreads if it finds an existing one at the synchronization point.
class ThreadSyncAfterWaitQueueInserter : public StmtExprMutator {
public:
explicit ThreadSyncAfterWaitQueueInserter(StorageScope sync_scope) : sync_scope_(sync_scope) {}
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::async_wait_queue_scope) {
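// Rebuild the async_wait_queue_scope / async_wait_inflight_count attribute
// nest, prepending a storage sync to the innermost body so that the sync
// executes right after the wait completes.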
auto sync = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(),
{StringImm(sync_scope_.to_string())}));
auto inner = op->body.as<AttrStmtNode>();
ICHECK(inner && inner->attr_key == tir::attr::async_wait_inflight_count);
auto zero = make_zero(DataType::Int(32));
auto new_body = SeqStmt({sync, inner->body});
return AttrStmt(zero, tir::attr::async_wait_queue_scope, op->value,
AttrStmt(zero, tir::attr::async_wait_inflight_count, inner->value, new_body));
}
return StmtExprMutator::VisitStmt_(op);
}
private:
StorageScope sync_scope_;
};
class ThreadSyncInserter : public StmtExprMutator {
public:
ThreadSyncInserter(StorageScope sync_scope, const std::unordered_set<const Object*>& syncs)
: sync_scope_(sync_scope), syncs_(syncs) {}
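// Prepend a barrier to every statement the planner marked: a plain
// tvm_storage_sync for non-global sync scopes, or a global barrier
// (see MakeGlobalBarrier) when the sync scope is global.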
Stmt VisitStmt(const Stmt& stmt) final {
if (syncs_.size() == 0) return stmt;
if (syncs_.count(stmt.get())) {
Stmt barrier;
if (sync_scope_.rank == StorageRank::kGlobal) {
barrier = MakeGlobalBarrier();
} else {
barrier = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(),
{StringImm(sync_scope_.to_string())}));
}
// Mutate only after the query above, since mutation may replace the stmt object.
auto ret = StmtExprMutator::VisitStmt(stmt);
ret = SeqStmt({barrier, ret});
return ret;
} else {
return StmtExprMutator::VisitStmt(stmt);
}
}
PrimExpr VisitExpr_(const LoadNode* op) final {
LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
return PrimExpr();
}
Stmt VisitStmt_(const StoreNode* op) final {
LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
return Stmt();
}
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
if (sync_scope_.rank == StorageRank::kGlobal &&
GetScope(op->buffer->data).rank == StorageRank::kGlobal) {
++rw_stats_[op->buffer->data].read_count;
}
return StmtExprMutator::VisitExpr_(op);
}
Stmt VisitStmt_(const BufferStoreNode* op) final {
if (sync_scope_.rank == StorageRank::kGlobal &&
GetScope(op->buffer->data).rank == StorageRank::kGlobal) {
++rw_stats_[op->buffer->data].write_count;
}
return StmtExprMutator::VisitStmt_(op);
}
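// Track the nesting of thread_extent scopes. When leaving the outermost
// thread scope with a global sync scope, wrap the kernel with the
// global-barrier initialization built by InitGlobalBarrier.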
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) {
bool temp = true;
std::swap(temp, in_thread_env_);
thread_extents_.push_back(op);
Stmt ret = StmtExprMutator::VisitStmt_(op);
thread_extents_.pop_back();
std::swap(temp, in_thread_env_);
// Leaving the outermost thread scope.
if (!in_thread_env_ && sync_scope_.rank == StorageRank::kGlobal) {
ret = InitGlobalBarrier(ret.as<AttrStmtNode>());
num_blocks_ = PrimExpr();
is_lead_ = PrimExpr();
}
return ret;
} else {
return StmtExprMutator::VisitStmt_(op);
}
}
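// tvm_access_ptr(dtype, buffer, offset, extent, rw_mask): bit 0 of rw_mask
// marks a read, bit 1 marks a write. Count both toward the global read/write
// statistics used to decide which buffers must be marked volatile.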
PrimExpr VisitExpr_(const CallNode* op) final {
if (op->op.same_as(builtin::tvm_access_ptr())) {
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
ICHECK_EQ(op->args.size(), 5U);
Var buffer_var(GetRef<Var>(op->args[1].as<VarNode>()));
const IntImmNode* flag = op->args[4].as<IntImmNode>();
if ((flag->value & 1) && sync_scope_.rank == StorageRank::kGlobal &&
GetScope(buffer_var).rank == StorageRank::kGlobal) {
++rw_stats_[buffer_var].read_count;
}
if (flag->value & 2 && sync_scope_.rank == StorageRank::kGlobal &&
GetScope(buffer_var).rank == StorageRank::kGlobal) {
++rw_stats_[buffer_var].write_count;
}
return expr;
} else {
return StmtExprMutator::VisitExpr_(op);
}
}
private:
// RW statistics about data
struct Entry {
int read_count{0};
int write_count{0};
};
// Get the storage scope of a buffer variable.
StorageScope GetScope(Var buffer_var) const {
return StorageScope::Create(GetPtrStorageScope(buffer_var));
}
// private functions.
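// Prepare the kernel for a global barrier: emit a packed call to
// tvm_prepare_global_barrier before the kernel, mark buffers that are both
// read and written in global memory as volatile, and prepend
// tvm_global_barrier_kinit to the kernel body.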
Stmt InitGlobalBarrier(const AttrStmtNode* op) {
ICHECK(op != nullptr);
Array<PrimExpr> pargs = {StringImm(runtime::symbol::tvm_prepare_global_barrier)};
Stmt prep = Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs));
Stmt body = op->body;
for (const auto& kv : rw_stats_) {
const auto& e = kv.second;
if (e.read_count != 0 && e.write_count != 0) {
body = AttrStmt(kv.first, attr::volatile_scope, 1, body);
}
}
rw_stats_.clear();
Stmt kinit = Evaluate(Call(DataType::Int(32), builtin::tvm_global_barrier_kinit(), {}));
body = SeqStmt({kinit, body});
body = AttrStmt(op->node, op->attr_key, op->value, body);
return SeqStmt({prep, body});
}
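// Build the global barrier call. On first use, derive num_blocks_ as the
// product of the blockIdx extents (rank-0 thread scopes) and is_lead_ as the
// predicate selecting the leading thread (all threadIdx vars equal to zero).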
Stmt MakeGlobalBarrier() {
ICHECK(sync_scope_.rank == StorageRank::kGlobal);
if (!num_blocks_.defined()) {
ICHECK(!is_lead_.defined());
num_work_dim_ = thread_extents_.size();
for (const AttrStmtNode* attr : thread_extents_) {
IterVar iv = Downcast<IterVar>(attr->node);
runtime::ThreadScope s = runtime::ThreadScope::Create(iv->thread_tag);
if (s.rank == 0) {
num_blocks_ = (num_blocks_.defined() ? attr->value * num_blocks_ : attr->value);
} else if (s.rank == 1) {
PrimExpr cond = iv->var == make_zero(iv->var.dtype());
is_lead_ = is_lead_.defined() ? (is_lead_ && cond) : cond;
}
}
} else {
ICHECK_EQ(num_work_dim_, thread_extents_.size());
}
return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(),
{StringImm(sync_scope_.to_string()), is_lead_, num_blocks_}));
}
// data structure.
StorageScope sync_scope_;
const std::unordered_set<const Object*>& syncs_;
// The read write statistics of storage
std::unordered_map<Var, Entry, ObjectPtrHash, ObjectPtrEqual> rw_stats_;
// State used to build the global barrier.
bool in_thread_env_{false};
// memoized results
std::vector<const AttrStmtNode*> thread_extents_;
size_t num_work_dim_{0};
PrimExpr num_blocks_;
PrimExpr is_lead_;
};
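// Entry point of the pass: optionally insert syncs after async wait queues
// (only for the untagged "shared" scope), then plan the remaining syncs and
// insert them.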
Stmt ThreadSync(Stmt stmt, std::string storage_scope) {
StorageScope sync_scope = StorageScope::Create(storage_scope);
if (sync_scope.rank == StorageRank::kShared && sync_scope.tag == "") {
stmt = ThreadSyncAfterWaitQueueInserter(sync_scope)(stmt);
}
ThreadSyncPlanner planner(sync_scope);
planner(stmt);
return ThreadSyncInserter(sync_scope, planner.syncs_inserted_)(std::move(stmt));
}
namespace transform {
Pass ThreadSync(String storage_scope) {
auto pass_func = [storage_scope](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = ThreadSync(std::move(n->body), storage_scope);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.ThreadSync", {});
}
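// Typical use from Python (a sketch; the exact lowering pipeline and the list
// of scopes may differ between targets):
//   mod = tvm.tir.transform.ThreadSync("shared")(mod)
//   mod = tvm.tir.transform.ThreadSync("warp")(mod)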
TVM_REGISTER_GLOBAL("tir.transform.ThreadSync").set_body_typed(ThreadSync);
} // namespace transform
} // namespace tir
} // namespace tvm