// blob: 528f43bade77062936099b5caaa8e03fe1f195f9
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
 * \brief Plan where buffers are to be allocated and update the AST accordingly.
 * \file plan_update_buffer_allocation_location.cc
 */
#include <tvm/ffi/reflection/registry.h>
#include <tvm/s_tir/transform.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/var.h>
#include "../../tir/transform/ir_utils.h"
namespace tvm {
namespace s_tir {
using namespace tvm::tir;
class CollectManagedAllocations : public StmtExprVisitor {
public:
void VisitStmt_(const SBlockNode* op) final {
for (const auto& buf : op->alloc_buffers) {
managed_allocations.insert(buf->data.get());
}
for (const auto& buf : op->match_buffers) {
managed_allocations.insert(buf->buffer->data.get());
}
StmtExprVisitor::VisitStmt_(op);
}
/*! \brief Buffers that are allocated outside of the BlockNode, and should not be moved by
* BufferAllocationLocator. */
std::unordered_set<const VarNode*> managed_allocations;
};
/*! \brief Collect the allocate buffer order. */
class BufferAllocateOrderCollector : public StmtExprVisitor {
public:
static ffi::Array<Buffer> Collect(const PrimFunc& func) {
BufferAllocateOrderCollector collector;
for (const auto& kv : func->buffer_map) {
collector.buffer_alloc_recorder_.push_back(kv.second);
}
collector(func->body);
return std::move(collector.buffer_alloc_recorder_);
}
private:
bool find(const Buffer& buf) {
return std::find(buffer_alloc_recorder_.begin(), buffer_alloc_recorder_.end(), buf) !=
buffer_alloc_recorder_.end();
}
void VisitStmt_(const SBlockNode* op) final {
for (const Buffer& buffer : op->alloc_buffers) {
buffer_alloc_recorder_.push_back(buffer);
}
// Also visit match_buffers to collect buffers that only appear in read and match_buffer
// regions.
for (const auto& region : op->match_buffers) {
if (!find(region->source->buffer)) {
buffer_alloc_recorder_.push_back(region->source->buffer);
}
}
StmtExprVisitor::VisitStmt_(op);
}
void VisitExpr_(const BufferLoadNode* op) final {
if (!find(op->buffer)) {
buffer_alloc_recorder_.push_back(op->buffer);
}
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const BufferStoreNode* op) final {
if (!find(op->buffer)) {
buffer_alloc_recorder_.push_back(op->buffer);
}
StmtExprVisitor::VisitStmt_(op);
}
/*! \brief The buffer allocated order recorder. */
ffi::Array<Buffer> buffer_alloc_recorder_;
};
/*! \brief Mutator that moves block-managed buffer allocations to the LCA statement of
 * all their accesses, and updates blocks/loops (alloc_buffers, read/write regions). */
class BufferAllocationLocator : public StmtExprMutator {
 public:
  explicit BufferAllocationLocator(const PrimFunc& func) {
    // For each buffer, the lowest common ancestor statement of all of its accesses.
    ffi::Map<Buffer, ffi::Optional<Stmt>> buffer_lca = DetectBufferAccessLCA(func);
    // The buffer_alloc_recorder Array is used to keep the buffer allocation order
    // since the buffer_lca Map is unordered.
    ffi::Array<Buffer> buffer_alloc_recorder = BufferAllocateOrderCollector::Collect(func);
    std::unordered_set<const VarNode*> arg_buffer_vars;
    CollectManagedAllocations collector;
    collector(func->body);
    managed_allocations_ = collector.managed_allocations;
    // Argument buffers are never re-allocated; remember their vars so they are
    // skipped below, but keep them resolvable in buffer_data_to_buffer_.
    for (const auto& kv : func->buffer_map) {
      const Buffer& buffer = kv.second;
      arg_buffer_vars.emplace(buffer->data.get());
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    // create buffers to be allocated at each stmts
    for (const auto& buffer : buffer_alloc_recorder) {
      auto it = buffer_lca.find(buffer);
      if (it != buffer_lca.end()) {
        const StmtNode* stmt = (*it).second.get();
        if (arg_buffer_vars.count(buffer->data.get())) {
          continue;
        }
        // Only allocations managed by some block are re-homed at their access LCA.
        if (managed_allocations_.count(buffer->data.get())) {
          alloc_buffers_[stmt].push_back(buffer);
        }
        buffer_data_to_buffer_.Set(buffer->data, buffer);
      }
    }
  }

 private:
  Stmt VisitStmt_(const ForNode* op) final {
    auto it = alloc_buffers_.find(op);
    if (it == alloc_buffers_.end()) {
      // No buffer has this loop as its LCA: nothing to inject here.
      return StmtMutator::VisitStmt_(op);
    }
    // Make the re-homed buffers resolvable while mutating the loop body.
    for (const Buffer& buf : it->second) {
      buffer_data_to_buffer_.Set(buf->data, buf);
    }
    auto node = Downcast<For>(StmtMutator::VisitStmt_(op));
    ffi::Array<Buffer> new_block_alloc_bufs;
    for (const Buffer& buf : it->second) {
      if (managed_allocations_.count(buf->data.get())) {
        buffer_data_to_buffer_.erase(buf->data);
        new_block_alloc_bufs.push_back(buf);
      }
    }
    // A For node cannot carry alloc_buffers itself, so wrap the loop body in an
    // opaque block that owns them.
    if (new_block_alloc_bufs.size()) {
      node.CopyOnWrite()->body = InjectOpaqueBlock(node->body, new_block_alloc_bufs);
    }
    return node;
  }

  Stmt VisitStmt_(const SBlockNode* op) final {
    ICHECK(!op->init.defined());
    ffi::Array<Buffer> alloc_buffers;
    auto it = alloc_buffers_.find(op);
    if (it != alloc_buffers_.end()) {
      // These buffers get (re-)allocated at this block.
      alloc_buffers = it->second;
      for (const Buffer& buf : it->second) {
        buffer_data_to_buffer_.Set(buf->data, buf);
      }
    }
    // Register match_buffer targets so accesses through them resolve during the
    // recursive visit below.
    // NOTE(review): match_buffer binds by value each iteration; `const MatchBufferRegion&`
    // would avoid a refcount copy.
    for (const MatchBufferRegion match_buffer : op->match_buffers) {
      const Var& target_var = match_buffer->buffer->data;
      const Var& source_var = match_buffer->source->buffer->data;
      ICHECK(buffer_data_to_buffer_.count(source_var));
      buffer_data_to_buffer_.Set(target_var, match_buffer->buffer);
    }
    Stmt stmt = StmtMutator::VisitStmt_(op);
    op = stmt.as<SBlockNode>();
    ICHECK(op != nullptr);
    // No longer consider buffers created by match_buffer inside the block when updating access
    // region.
    for (const MatchBufferRegion match_buffer : op->match_buffers) {
      const Var& target_var = match_buffer->buffer->data;
      buffer_data_to_buffer_.erase(target_var);
    }
    // No longer consider buffers allocated inside the block when updating access region.
    if (it != alloc_buffers_.end()) {
      for (const Buffer& buf : it->second) {
        buffer_data_to_buffer_.erase(buf->data);
      }
    }
    ObjectPtr<SBlockNode> n = CopyOnWrite(op);
    n->alloc_buffers = std::move(alloc_buffers);
    // Erase buffer allocated inside the block from access region.
    n->reads = RemoveRedundantBufferRegion(n->reads);
    n->writes = RemoveRedundantBufferRegion(n->writes);
    return Stmt(n);
  }

  /*! \brief Wrap \p body in an opaque (no iter vars) block that owns \p alloc_buffers;
   * its read/write regions are recomputed from the wrapped body. */
  Stmt InjectOpaqueBlock(Stmt body, const ffi::Array<Buffer>& alloc_buffers) {
    ICHECK(!alloc_buffers.empty());
    SBlock opaque_block(/*iter_vars=*/{},
                        /*reads=*/{},
                        /*writes=*/{},
                        /*name_hint=*/"",
                        /*body=*/std::move(body),
                        /*init=*/std::nullopt,
                        /*alloc_buffers=*/alloc_buffers);
    ObjectPtr<SBlockNode> n = CopyOnWrite(opaque_block.get());
    ffi::Array<ffi::Array<BufferRegion>> access =
        GetSBlockReadWriteRegion(opaque_block, buffer_data_to_buffer_);
    n->reads = access[0];
    n->writes = access[1];
    SBlockRealize realize({}, Bool(true), SBlock(n));
    return realize;
  }

  /*! \brief Keep only the regions whose buffer is still present in
   * buffer_data_to_buffer_, i.e. declared outside the current block. */
  ffi::Array<BufferRegion> RemoveRedundantBufferRegion(
      const ffi::Array<BufferRegion>& region) const {
    ffi::Array<BufferRegion> result;
    for (const BufferRegion& buffer_region : region) {
      if (buffer_data_to_buffer_.count(buffer_region->buffer->data)) {
        result.push_back(buffer_region);
      }
    }
    return result;
  }

  /*! \brief The map from stmt to the buffers to be allocated under it. */
  std::unordered_map<const StmtNode*, ffi::Array<Buffer>> alloc_buffers_;
  /*! \brief The buffer already allocated during recursive visiting. */
  ffi::Map<Var, Buffer> buffer_data_to_buffer_;
  /*! \brief Buffers that are allocated within a BlockNode, and may be moved. */
  std::unordered_set<const VarNode*> managed_allocations_;
};
namespace transform {
/*!
 * \brief Create the pass that plans buffer allocation locations (hoisting each
 * managed allocation to the LCA of its accesses) and updates the AST.
 * \return The PrimFunc-level pass.
 */
Pass PlanAndUpdateBufferAllocationLocation() {
  // The lambda reads no ambient state, so no capture-default is needed
  // (the original `[=]` copied nothing but invited accidental captures).
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    auto fptr = f.CopyOnWrite();
    BufferAllocationLocator locator(f);
    fptr->body = locator(fptr->body);
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "s_tir.PlanAndUpdateBufferAllocationLocation", {});
}
// Register the pass constructor with the FFI under its canonical global name so it
// can be invoked from bindings (e.g. Python) at module load time.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("s_tir.transform.PlanAndUpdateBufferAllocationLocation",
                        PlanAndUpdateBufferAllocationLocation);
}
} // namespace transform
} // namespace s_tir
} // namespace tvm