| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
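/*!
 * \file thread_storage_sync.cc
 * \brief Plan and insert storage synchronization barriers (tvm_storage_sync)
 *        between conflicting reads and writes in a given storage scope.
 */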
| #include <tvm/runtime/registry.h> |
| #include <tvm/tir/analysis.h> |
| #include <tvm/tir/builtin.h> |
| #include <tvm/tir/expr.h> |
| #include <tvm/tir/stmt_functor.h> |
| #include <tvm/tir/transform.h> |
| |
| #include <unordered_map> |
| #include <unordered_set> |
| |
| #include "../../runtime/thread_storage_scope.h" |
| #include "ir_util.h" |
| #include "storage_access.h" |
| |
| namespace tvm { |
| namespace tir { |
| |
| class ThreadSyncPlanner : public StorageAccessVisitor { |
| public: |
| explicit ThreadSyncPlanner(StorageScope sync_scope) : sync_scope_(sync_scope) {} |
| |
| // The syncs inserted before each statement |
| std::unordered_set<const Object*> syncs_inserted_; |
| |
| protected: |
| bool Enabled(const VarNode* buf, const StorageScope& scope) const final { |
| return in_device_env() && scope == sync_scope_; |
| } |
| // Plan the sync |
| std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq, const ForNode* loop) final { |
| // Unsynced reads and writes |
| std::vector<AccessEntry> reads; |
| std::vector<AccessEntry> writes; |
| // if it is a loop, rotate two times to consider effect of loop. |
| // simulation based approach to find dependenceies |
| for (size_t i = 0; i < seq.size(); ++i) { |
| const StmtEntry& s = seq[i]; |
| // check if sync before statement is needed. |
| bool sync_before_stmt = (syncs_inserted_.count(s.stmt) != 0); |
| // Apply the syncs added already. |
| if (sync_before_stmt) { |
| reads.clear(); |
| writes.clear(); |
| } |
| for (const AccessEntry& acc : s.access) { |
| if (acc.type == kRead) { |
| if (FindConflict(writes, acc, false)) { |
| sync_before_stmt = true; |
| break; |
| } |
| } else if (acc.type == kWrite) { |
| if (FindConflict(reads, acc, false)) { |
| sync_before_stmt = true; |
| break; |
| } |
| } else if (acc.type == kSync) { |
| reads.clear(); |
| writes.clear(); |
| } |
| } |
| // If sync is inserted. remove the irrelevant things. |
| if (sync_before_stmt) { |
| reads.clear(); |
| writes.clear(); |
| } |
| // Add the read/write of current statement |
| for (const AccessEntry& acc : s.access) { |
| if (acc.type == kRead) { |
| reads.push_back(acc); |
| } else if (acc.type == kWrite) { |
| writes.push_back(acc); |
| } else if (acc.type == kSync) { |
| reads.clear(); |
| writes.clear(); |
| } |
| } |
| if (sync_before_stmt) { |
| CHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside condition"; |
| syncs_inserted_.insert(s.stmt); |
| } |
| } |
| if (loop != nullptr) { |
| for (size_t i = 0; i < seq.size(); ++i) { |
| const StmtEntry& s = seq[i]; |
| if (syncs_inserted_.count(s.stmt) != 0) break; |
| if (reads.empty() && writes.empty()) break; |
| bool sync_before_stmt = false; |
| for (const AccessEntry& acc : s.access) { |
| if (acc.type == kRead) { |
| if (FindConflict(writes, acc, true)) { |
| sync_before_stmt = true; |
| break; |
| } |
| } else if (acc.type == kWrite) { |
| if (FindConflict(reads, acc, true)) { |
| sync_before_stmt = true; |
| break; |
| } |
| } else if (acc.type == kSync) { |
| reads.clear(); |
| writes.clear(); |
| } |
| } |
| if (sync_before_stmt) { |
| CHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside condition"; |
| syncs_inserted_.insert(s.stmt); |
| break; |
| } |
| } |
| } |
| // return the exposed entries, remove unecessary ones. |
| int sync_count = 0; |
| // head are before first sync, tail are after last sync |
| std::vector<AccessEntry> head, tail; |
| AccessEntry esync; |
| esync.threads = this->env_threads(); |
| esync.type = kSync; |
| esync.scope = sync_scope_; |
| |
| for (const StmtEntry& s : seq) { |
| if (syncs_inserted_.count(s.stmt)) { |
| if (sync_count != 0) { |
| tail.clear(); |
| } else { |
| head.push_back(esync); |
| } |
| ++sync_count; |
| } |
| for (const AccessEntry& acc : s.access) { |
| if (acc.type == kSync) { |
| if (sync_count != 0) { |
| tail.clear(); |
| } else { |
| head.push_back(esync); |
| } |
| ++sync_count; |
| } else { |
| if (sync_count != 0) { |
| tail.push_back(acc); |
| } else { |
| head.push_back(acc); |
| } |
| } |
| } |
| } |
| head.insert(head.end(), tail.begin(), tail.end()); |
| if (loop != nullptr) { |
| // clear double buffer flag after a loop is finished. |
| for (AccessEntry& e : head) { |
| e.double_buffer_write = false; |
| } |
| } |
| return head; |
| } |
| |
| private: |
| // find conflicting entry in vec. |
| bool FindConflict(const std::vector<AccessEntry>& vec, const AccessEntry& e, bool loop_carry) { |
| for (const AccessEntry& x : vec) { |
| if (x.buffer.same_as(e.buffer)) { |
| // Assumes no race between threads |
| // Same index value means no conflicts |
| // TODO(tqchen) more standard set based testing. |
| if (e.touched.IsSinglePoint() && x.touched.IsSinglePoint()) { |
| if (ExprDeepEqual()(e.touched.PointValue(), x.touched.PointValue())) continue; |
| } |
| if (x.double_buffer_write && e.type == kRead && !loop_carry) continue; |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| private: |
| // synchronization scope |
| StorageScope sync_scope_; |
| }; |
| |
| class ThreadSyncInserter : public StmtExprMutator { |
| public: |
| ThreadSyncInserter(StorageScope sync_scope, const std::unordered_set<const Object*>& syncs) |
| : sync_scope_(sync_scope), syncs_(syncs) {} |
| |
| Stmt VisitStmt(const Stmt& stmt) final { |
| if (syncs_.size() == 0) return stmt; |
| if (syncs_.count(stmt.get())) { |
| Stmt barrier; |
| if (sync_scope_.rank == StorageRank::kGlobal) { |
| barrier = MakeGlobalBarrier(); |
| } else { |
| barrier = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), |
| {StringImm(sync_scope_.to_string())})); |
| } |
| // Mutate after query, to avoid stmt change. |
| auto ret = StmtExprMutator::VisitStmt(stmt); |
| ret = SeqStmt({barrier, ret}); |
| return ret; |
| } else { |
| return StmtExprMutator::VisitStmt(stmt); |
| } |
| } |
| PrimExpr VisitExpr_(const LoadNode* op) final { |
| if (sync_scope_.rank == StorageRank::kGlobal && |
| GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) { |
| ++rw_stats_[op->buffer_var].read_count; |
| } |
| return StmtExprMutator::VisitExpr_(op); |
| } |
| Stmt VisitStmt_(const StoreNode* op) final { |
| if (sync_scope_.rank == StorageRank::kGlobal && |
| GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) { |
| ++rw_stats_[op->buffer_var].write_count; |
| } |
| return StmtExprMutator::VisitStmt_(op); |
| } |
| Stmt VisitStmt_(const AttrStmtNode* op) final { |
| if (op->attr_key == attr::thread_extent) { |
| bool temp = true; |
| std::swap(temp, in_thread_env_); |
| thread_extents_.push_back(op); |
| Stmt ret = StmtExprMutator::VisitStmt_(op); |
| thread_extents_.pop_back(); |
| std::swap(temp, in_thread_env_); |
| // first thread scope. |
| if (!in_thread_env_ && sync_scope_.rank == StorageRank::kGlobal) { |
| ret = InitGlobalBarrier(ret.as<AttrStmtNode>()); |
| num_blocks_ = PrimExpr(); |
| is_lead_ = PrimExpr(); |
| } |
| return ret; |
| } else if (op->attr_key == attr::storage_scope) { |
| const VarNode* buf = op->node.as<VarNode>(); |
| storage_scope_[buf] = StorageScope::Create(op->value.as<StringImmNode>()->value); |
| return StmtExprMutator::VisitStmt_(op); |
| } else { |
| return StmtExprMutator::VisitStmt_(op); |
| } |
| } |
| |
| PrimExpr VisitExpr_(const CallNode* op) final { |
| if (op->op.same_as(builtin::tvm_access_ptr())) { |
| PrimExpr expr = StmtExprMutator::VisitExpr_(op); |
| op = expr.as<CallNode>(); |
| CHECK_EQ(op->args.size(), 5U); |
| const VarNode* buffer_var = op->args[1].as<VarNode>(); |
| Var var(GetRef<Var>(buffer_var)); |
| const IntImmNode* flag = op->args[4].as<IntImmNode>(); |
| if ((flag->value & 1) && sync_scope_.rank == StorageRank::kGlobal && |
| GetScope(buffer_var).rank == StorageRank::kGlobal) { |
| ++rw_stats_[var].read_count; |
| } |
| if (flag->value & 2 && sync_scope_.rank == StorageRank::kGlobal && |
| GetScope(buffer_var).rank == StorageRank::kGlobal) { |
| ++rw_stats_[var].write_count; |
| } |
| return expr; |
| } else { |
| return StmtExprMutator::VisitExpr_(op); |
| } |
| } |
| |
| private: |
| // RW statistics about data |
| struct Entry { |
| int read_count{0}; |
| int write_count{0}; |
| }; |
| // Get current storage scope. |
| StorageScope GetScope(const VarNode* buf) const { |
| auto it = storage_scope_.find(buf); |
| StorageScope s; |
| s.rank = StorageRank::kGlobal; |
| if (it == storage_scope_.end()) return s; |
| return it->second; |
| } |
| // private functions. |
| Stmt InitGlobalBarrier(const AttrStmtNode* op) { |
| CHECK(op != nullptr); |
| Array<PrimExpr> pargs = {StringImm(runtime::symbol::tvm_prepare_global_barrier)}; |
| Stmt prep = Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs)); |
| Stmt body = op->body; |
| for (const auto& kv : rw_stats_) { |
| const auto& e = kv.second; |
| if (e.read_count != 0 && e.write_count != 0) { |
| body = AttrStmt(kv.first, attr::volatile_scope, 1, body); |
| } |
| } |
| rw_stats_.clear(); |
| Stmt kinit = Evaluate(Call(DataType::Int(32), builtin::tvm_global_barrier_kinit(), {})); |
| body = SeqStmt({kinit, body}); |
| body = AttrStmt(op->node, op->attr_key, op->value, body); |
| return SeqStmt({prep, body}); |
| } |
| Stmt MakeGlobalBarrier() { |
| CHECK(sync_scope_.rank == StorageRank::kGlobal); |
| if (!num_blocks_.defined()) { |
| CHECK(!is_lead_.defined()); |
| num_work_dim_ = thread_extents_.size(); |
| for (const AttrStmtNode* attr : thread_extents_) { |
| IterVar iv = Downcast<IterVar>(attr->node); |
| runtime::ThreadScope s = runtime::ThreadScope::Create(iv->thread_tag); |
| if (s.rank == 0) { |
| num_blocks_ = (num_blocks_.defined() ? attr->value * num_blocks_ : attr->value); |
| } else if (s.rank == 1) { |
| PrimExpr cond = iv->var == make_zero(iv->var.dtype()); |
| is_lead_ = is_lead_.defined() ? (is_lead_ && cond) : cond; |
| } |
| } |
| } else { |
| CHECK_EQ(num_work_dim_, thread_extents_.size()); |
| } |
| return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), |
| {StringImm(sync_scope_.to_string()), is_lead_, num_blocks_})); |
| } |
| // data structure. |
| StorageScope sync_scope_; |
| const std::unordered_set<const Object*>& syncs_; |
| // The storage scope of each buffer |
| std::unordered_map<const VarNode*, StorageScope> storage_scope_; |
| // The read write statistics of storage |
| std::unordered_map<Var, Entry, ObjectPtrHash, ObjectPtrEqual> rw_stats_; |
| // The statistics for global barrier |
| bool in_thread_env_{false}; |
| // memorized results |
| std::vector<const AttrStmtNode*> thread_extents_; |
| size_t num_work_dim_{0}; |
| PrimExpr num_blocks_; |
| PrimExpr is_lead_; |
| }; |
| |
| Stmt ThreadSync(Stmt stmt, std::string storage_scope) { |
| StorageScope sync_scope = StorageScope::Create(storage_scope); |
| ThreadSyncPlanner planner(sync_scope); |
| planner(stmt); |
| return ThreadSyncInserter(sync_scope, planner.syncs_inserted_)(std::move(stmt)); |
| } |
| |
| namespace transform { |
| |
| Pass ThreadSync(String storage_scope) { |
| auto pass_func = [storage_scope](PrimFunc f, IRModule m, PassContext ctx) { |
| auto* n = f.CopyOnWrite(); |
| n->body = ThreadSync(std::move(n->body), storage_scope); |
| return f; |
| }; |
| return CreatePrimFuncPass(pass_func, 0, "tir.ThreadSync", {}); |
| } |
| |
| TVM_REGISTER_GLOBAL("tir.transform.ThreadSync").set_body_typed(ThreadSync); |
| |
| } // namespace transform |
| } // namespace tir |
| } // namespace tvm |