blob: 676f162076ce0ea21002bf6e838f131c6b2fa05c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/block_scope.h>
#include <tvm/tir/utils.h>
namespace tvm {
namespace tir {
// Register reflection metadata for the node types implemented in this file by
// calling each type's static `RegisterReflection()` hook.
// NOTE(review): the macro name suggests the body runs once during static
// initialization -- confirm against the macro's definition.
TVM_FFI_STATIC_INIT_BLOCK() {
StmtSRefNode::RegisterReflection();
DependencyNode::RegisterReflection();
BlockScopeNode::RegisterReflection();
}
/******** Utility functions ********/
/*!
 * \brief Shorthand for an unordered map that hashes and compares its keys with
 *  `ObjectPtrHash` and `ObjectPtrEqual`.
 * \tparam Key The key type
 * \tparam Value The mapped type
 */
template <typename Key, typename Value>
using SMap = std::unordered_map<Key, Value, ObjectPtrHash, ObjectPtrEqual>;
/*!
 * \brief Record a dependency edge between two blocks of a block scope.
 * \param self The block scope whose `src2deps` and `dst2deps` tables are updated
 * \param src The source of the dependency
 * \param dst The destination of the dependency
 * \param kind Type of the dependency
 * \note Self-loops are ignored: when `src` and `dst` are the same sref, nothing is recorded
 */
void AddDependency(BlockScopeNode* self, const StmtSRef& src, const StmtSRef& dst, DepKind kind) {
  if (src.same_as(dst)) {
    return;
  }
  // The same Dependency object is appended to both lookup tables.
  Dependency edge(src, dst, kind);
  self->src2deps[src].push_back(edge);
  self->dst2deps[dst].push_back(edge);
}
/******** Constructors ********/
/*!
 * \brief Construct an sref by allocating a new StmtSRefNode and filling in its fields.
 * \param stmt The statement the sref points to; stored as the given raw pointer
 * \param parent The parent sref node; stored as the given raw pointer
 * \param seq_index The value stored in the node's `seq_index` field
 */
StmtSRef::StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index) {
  auto node = ffi::make_object<StmtSRefNode>();
  node->seq_index = seq_index;
  node->parent = parent;
  node->stmt = stmt;
  data_ = std::move(node);
}
/*!
 * \brief Get the sentinel sref that serves as the "inline" mark.
 * \return The function-local static sref, constructed once with a null statement, a null parent
 *  and a `seq_index` of -1. It is a separate object from the one returned by RootMark().
 */
StmtSRef StmtSRef::InlineMark() {
  static const StmtSRef kInlineMark(/*stmt=*/nullptr, /*parent=*/nullptr, /*seq_index=*/-1);
  return kInlineMark;
}
/*!
 * \brief Get the sentinel sref that serves as the "root" mark.
 * \return The function-local static sref, constructed once with a null statement, a null parent
 *  and a `seq_index` of -1. It is a separate object from the one returned by InlineMark().
 */
StmtSRef StmtSRef::RootMark() {
  static const StmtSRef kRootMark(/*stmt=*/nullptr, /*parent=*/nullptr, /*seq_index=*/-1);
  return kRootMark;
}
/*!
 * \brief Construct a dependency edge by allocating a new DependencyNode.
 * \param src The source of the dependency; moved into the node
 * \param dst The destination of the dependency; moved into the node
 * \param kind Type of the dependency
 */
Dependency::Dependency(StmtSRef src, StmtSRef dst, DepKind kind) {
  auto dep = ffi::make_object<DependencyNode>();
  dep->kind = kind;
  dep->src = std::move(src);
  dep->dst = std::move(dst);
  data_ = std::move(dep);
}
/*! \brief Construct a block scope backed by a newly allocated BlockScopeNode with no entries added. */
BlockScope::BlockScope() {
  auto scope = ffi::make_object<BlockScopeNode>();
  data_ = std::move(scope);
}
/*!
 * \brief Construct a block scope from its child blocks and derive the dependencies among them.
 *
 * Children are processed in the order given. Each child is first registered as a reader and/or
 * writer of the buffers in its `reads` / `writes` regions; then edges are added to the child from
 * the blocks registered so far:
 *  - RAW: from every recorded writer of a buffer the child reads
 *  - WAW: from every recorded writer of a buffer the child writes
 *  - WAR: from every recorded reader of a buffer the child writes
 * Since the child registers itself before the lookups, it shows up in its own candidate lists;
 * such self-edges are dropped by AddDependency.
 *
 * \param child_block_srefs The srefs of the child blocks, converted with TVM_SREF_TO_BLOCK
 */
BlockScope::BlockScope(const ffi::Array<StmtSRef>& child_block_srefs) {
  using BufferAccessMap = SMap<Buffer, ffi::Array<StmtSRef>>;
  ObjectPtr<BlockScopeNode> scope = ffi::make_object<BlockScopeNode>();
  // Readers are tracked locally; writers are stored on the node itself.
  BufferAccessMap readers;
  BufferAccessMap& writers = scope->buffer_writers;
  for (const StmtSRef& child_block_sref : child_block_srefs) {
    const BlockNode* blk = TVM_SREF_TO_BLOCK(child_block_sref);
    // Adds an edge of `kind` to the current child from every block that `table` lists for `buffer`.
    auto depend_on = [&](const BufferAccessMap& table, const Buffer& buffer, DepKind kind) {
      auto found = table.find(buffer);
      if (found == table.end()) {
        return;
      }
      for (const StmtSRef& from : found->second) {
        AddDependency(scope.get(), from, child_block_sref, kind);
      }
    };
    // Step 1. Register the child as reader / writer of every buffer it touches
    for (const BufferRegion& read : blk->reads) {
      readers[read->buffer].push_back(child_block_sref);
    }
    for (const BufferRegion& write : blk->writes) {
      writers[write->buffer].push_back(child_block_sref);
    }
    // Steps 2-4 stay as separate passes so that edges are appended in the order RAW, WAW, WAR.
    // Step 2. RAW: writers of the buffers this child reads
    for (const BufferRegion& read : blk->reads) {
      depend_on(writers, read->buffer, DepKind::kRAW);
    }
    // Step 3. WAW: writers of the buffers this child writes
    for (const BufferRegion& write : blk->writes) {
      depend_on(writers, write->buffer, DepKind::kWAW);
    }
    // Step 4. WAR: readers of the buffers this child writes
    for (const BufferRegion& write : blk->writes) {
      depend_on(readers, write->buffer, DepKind::kWAR);
    }
  }
  data_ = std::move(scope);
}
/******** Dependency ********/
/*!
 * \brief Look up the dependencies whose source is the given block.
 * \param block_sref The sref of the source block
 * \return The entry stored in `src2deps` for `block_sref`, or an empty array if there is none
 */
ffi::Array<Dependency> BlockScopeNode::GetDepsBySrc(const StmtSRef& block_sref) const {
  const auto found = src2deps.find(block_sref);
  if (found == src2deps.end()) {
    return {};
  }
  return found->second;
}
/*!
 * \brief Look up the dependencies whose destination is the given block.
 * \param block_sref The sref of the destination block
 * \return The entry stored in `dst2deps` for `block_sref`, or an empty array if there is none
 */
ffi::Array<Dependency> BlockScopeNode::GetDepsByDst(const StmtSRef& block_sref) const {
  const auto found = dst2deps.find(block_sref);
  if (found == dst2deps.end()) {
    return {};
  }
  return found->second;
}
/*!
 * \brief Create an sref for a statement and push it onto the stack, making it the current scope
 * \param stmt A for-loop statement or a block statement
 * \note The new sref's parent is the sref currently on top of the stack, or null when the stack
 *  is empty
 */
void SRefTreeCreator::PushSRef(const StmtNode* stmt) {
  StmtSRefNode* parent = srefs_.empty() ? nullptr : srefs_.back().get();
  // `seq_index` will be set properly in SetSeqIndex
  srefs_.push_back(StmtSRef(stmt, parent, /*seq_index=*/-1));
}
/*! \brief Pop the top of the scope and record it in stmt2ref map */
void SRefTreeCreator::PopAndRecordSRef() {
StmtSRef sref = std::move(srefs_.back());
stmt2ref_[sref->stmt] = sref;
srefs_.pop_back();
}
/*!
 * \brief Visit a for-loop; the loop gets its own sref only when `include_loops_` is set
 * \param loop The loop being visited; its body is always traversed
 */
void SRefTreeCreator::VisitStmt_(const ForNode* loop) {
  const bool track_loop = include_loops_;
  if (track_loop) {
    PushSRef(loop);
  }
  VisitStmt(loop->body);
  if (track_loop) {
    PopAndRecordSRef();
  }
}
void SRefTreeCreator::VisitStmt_(const BlockRealizeNode* realize) {
const BlockNode* block = realize->block.get();
PushSRef(block);
VisitStmt(block->body); // `block->init` is not visited
PopAndRecordSRef();
}
// Visit a statement sequence with the default StmtVisitor traversal, then call
// SetSeqIndexInChildren with the sref map, the sequence and `include_loops_`.
// NOTE(review): the traversal presumably has to run first so that the children's
// srefs are already recorded in `stmt2ref_` -- confirm against SetSeqIndexInChildren.
void SRefTreeCreator::VisitStmt_(const SeqStmtNode* seq_stmt) {
// Set `seq_index` information for SeqStmtNode
StmtVisitor::VisitStmt_(seq_stmt);
SetSeqIndexInChildren(stmt2ref_, seq_stmt, include_loops_);
}
/******** FFI ********/
// Register the sref and block-scope accessors as named global functions / methods
// through the FFI reflection registry.
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
// "tir.StmtSRefStmt": the statement the sref points to, returned as an optional Stmt.
// NOTE(review): a null `sref->stmt` presumably yields an empty Optional -- confirm
// against ffi::GetRef.
.def("tir.StmtSRefStmt",
[](StmtSRef sref) -> ffi::Optional<Stmt> {
return ffi::GetRef<ffi::Optional<Stmt>>(sref->stmt);
})
// "tir.StmtSRefParent": the parent of the sref, returned as an optional StmtSRef.
.def("tir.StmtSRefParent",
[](StmtSRef sref) -> ffi::Optional<StmtSRef> {
return ffi::GetRef<ffi::Optional<StmtSRef>>(sref->parent);
})
// Accessors for the two static sentinel srefs.
.def("tir.StmtSRefRootMark", StmtSRef::RootMark)
.def("tir.StmtSRefInlineMark", StmtSRef::InlineMark)
// Dependency lookups on a block scope, by source block and by destination block.
.def_method("tir.BlockScopeGetDepsBySrc", &BlockScopeNode::GetDepsBySrc)
.def_method("tir.BlockScopeGetDepsByDst", &BlockScopeNode::GetDepsByDst);
}
} // namespace tir
} // namespace tvm