| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
/*!
 * \brief Plan where buffers are to be allocated and update the AST accordingly.
 * \file plan_update_buffer_allocation_location.cc
 */
| |
| #include <tvm/ffi/reflection/registry.h> |
| #include <tvm/s_tir/transform.h> |
| #include <tvm/tir/analysis.h> |
| #include <tvm/tir/stmt_functor.h> |
| #include <tvm/tir/var.h> |
| |
| #include "../../tir/transform/ir_utils.h" |
| |
| namespace tvm { |
| namespace s_tir { |
| using namespace tvm::tir; |
| |
| class CollectManagedAllocations : public StmtExprVisitor { |
| public: |
| void VisitStmt_(const SBlockNode* op) final { |
| for (const auto& buf : op->alloc_buffers) { |
| managed_allocations.insert(buf->data.get()); |
| } |
| for (const auto& buf : op->match_buffers) { |
| managed_allocations.insert(buf->buffer->data.get()); |
| } |
| StmtExprVisitor::VisitStmt_(op); |
| } |
| |
| /*! \brief Buffers that are allocated outside of the BlockNode, and should not be moved by |
| * BufferAllocationLocator. */ |
| std::unordered_set<const VarNode*> managed_allocations; |
| }; |
| |
| /*! \brief Collect the allocate buffer order. */ |
| class BufferAllocateOrderCollector : public StmtExprVisitor { |
| public: |
| static ffi::Array<Buffer> Collect(const PrimFunc& func) { |
| BufferAllocateOrderCollector collector; |
| for (const auto& kv : func->buffer_map) { |
| collector.buffer_alloc_recorder_.push_back(kv.second); |
| } |
| collector(func->body); |
| return std::move(collector.buffer_alloc_recorder_); |
| } |
| |
| private: |
| bool find(const Buffer& buf) { |
| return std::find(buffer_alloc_recorder_.begin(), buffer_alloc_recorder_.end(), buf) != |
| buffer_alloc_recorder_.end(); |
| } |
| |
| void VisitStmt_(const SBlockNode* op) final { |
| for (const Buffer& buffer : op->alloc_buffers) { |
| buffer_alloc_recorder_.push_back(buffer); |
| } |
| // Also visit match_buffers to collect buffers that only appear in read and match_buffer |
| // regions. |
| for (const auto& region : op->match_buffers) { |
| if (!find(region->source->buffer)) { |
| buffer_alloc_recorder_.push_back(region->source->buffer); |
| } |
| } |
| |
| StmtExprVisitor::VisitStmt_(op); |
| } |
| |
| void VisitExpr_(const BufferLoadNode* op) final { |
| if (!find(op->buffer)) { |
| buffer_alloc_recorder_.push_back(op->buffer); |
| } |
| StmtExprVisitor::VisitExpr_(op); |
| } |
| |
| void VisitStmt_(const BufferStoreNode* op) final { |
| if (!find(op->buffer)) { |
| buffer_alloc_recorder_.push_back(op->buffer); |
| } |
| StmtExprVisitor::VisitStmt_(op); |
| } |
| |
| /*! \brief The buffer allocated order recorder. */ |
| ffi::Array<Buffer> buffer_alloc_recorder_; |
| }; |
| |
| class BufferAllocationLocator : public StmtExprMutator { |
| public: |
| explicit BufferAllocationLocator(const PrimFunc& func) { |
| ffi::Map<Buffer, ffi::Optional<Stmt>> buffer_lca = DetectBufferAccessLCA(func); |
| // The buffer_alloc_recorder Array is used to keep the buffer allocation order |
| // since the buffer_lca Map is unordered. |
| ffi::Array<Buffer> buffer_alloc_recorder = BufferAllocateOrderCollector::Collect(func); |
| std::unordered_set<const VarNode*> arg_buffer_vars; |
| CollectManagedAllocations collector; |
| collector(func->body); |
| managed_allocations_ = collector.managed_allocations; |
| |
| for (const auto& kv : func->buffer_map) { |
| const Buffer& buffer = kv.second; |
| arg_buffer_vars.emplace(buffer->data.get()); |
| buffer_data_to_buffer_.Set(buffer->data, buffer); |
| } |
| // create buffers to be allocated at each stmts |
| for (const auto& buffer : buffer_alloc_recorder) { |
| auto it = buffer_lca.find(buffer); |
| if (it != buffer_lca.end()) { |
| const StmtNode* stmt = (*it).second.get(); |
| if (arg_buffer_vars.count(buffer->data.get())) { |
| continue; |
| } |
| if (managed_allocations_.count(buffer->data.get())) { |
| alloc_buffers_[stmt].push_back(buffer); |
| } |
| buffer_data_to_buffer_.Set(buffer->data, buffer); |
| } |
| } |
| } |
| |
| private: |
| Stmt VisitStmt_(const ForNode* op) final { |
| auto it = alloc_buffers_.find(op); |
| if (it == alloc_buffers_.end()) { |
| return StmtMutator::VisitStmt_(op); |
| } |
| for (const Buffer& buf : it->second) { |
| buffer_data_to_buffer_.Set(buf->data, buf); |
| } |
| auto node = Downcast<For>(StmtMutator::VisitStmt_(op)); |
| |
| ffi::Array<Buffer> new_block_alloc_bufs; |
| for (const Buffer& buf : it->second) { |
| if (managed_allocations_.count(buf->data.get())) { |
| buffer_data_to_buffer_.erase(buf->data); |
| new_block_alloc_bufs.push_back(buf); |
| } |
| } |
| |
| if (new_block_alloc_bufs.size()) { |
| node.CopyOnWrite()->body = InjectOpaqueBlock(node->body, new_block_alloc_bufs); |
| } |
| |
| return node; |
| } |
| |
| Stmt VisitStmt_(const SBlockNode* op) final { |
| ICHECK(!op->init.defined()); |
| ffi::Array<Buffer> alloc_buffers; |
| auto it = alloc_buffers_.find(op); |
| if (it != alloc_buffers_.end()) { |
| alloc_buffers = it->second; |
| for (const Buffer& buf : it->second) { |
| buffer_data_to_buffer_.Set(buf->data, buf); |
| } |
| } |
| for (const MatchBufferRegion match_buffer : op->match_buffers) { |
| const Var& target_var = match_buffer->buffer->data; |
| const Var& source_var = match_buffer->source->buffer->data; |
| ICHECK(buffer_data_to_buffer_.count(source_var)); |
| buffer_data_to_buffer_.Set(target_var, match_buffer->buffer); |
| } |
| Stmt stmt = StmtMutator::VisitStmt_(op); |
| op = stmt.as<SBlockNode>(); |
| ICHECK(op != nullptr); |
| |
| // No longer consider buffers created by match_buffer inside the block when updating access |
| // region. |
| for (const MatchBufferRegion match_buffer : op->match_buffers) { |
| const Var& target_var = match_buffer->buffer->data; |
| buffer_data_to_buffer_.erase(target_var); |
| } |
| // No longer consider buffers allocated inside the block when updating access region. |
| if (it != alloc_buffers_.end()) { |
| for (const Buffer& buf : it->second) { |
| buffer_data_to_buffer_.erase(buf->data); |
| } |
| } |
| |
| ObjectPtr<SBlockNode> n = CopyOnWrite(op); |
| n->alloc_buffers = std::move(alloc_buffers); |
| // Erase buffer allocated inside the block from access region. |
| n->reads = RemoveRedundantBufferRegion(n->reads); |
| n->writes = RemoveRedundantBufferRegion(n->writes); |
| return Stmt(n); |
| } |
| |
| Stmt InjectOpaqueBlock(Stmt body, const ffi::Array<Buffer>& alloc_buffers) { |
| ICHECK(!alloc_buffers.empty()); |
| SBlock opaque_block(/*iter_vars=*/{}, |
| /*reads=*/{}, |
| /*writes=*/{}, |
| /*name_hint=*/"", |
| /*body=*/std::move(body), |
| /*init=*/std::nullopt, |
| /*alloc_buffers=*/alloc_buffers); |
| ObjectPtr<SBlockNode> n = CopyOnWrite(opaque_block.get()); |
| ffi::Array<ffi::Array<BufferRegion>> access = |
| GetSBlockReadWriteRegion(opaque_block, buffer_data_to_buffer_); |
| n->reads = access[0]; |
| n->writes = access[1]; |
| SBlockRealize realize({}, Bool(true), SBlock(n)); |
| return realize; |
| } |
| |
| ffi::Array<BufferRegion> RemoveRedundantBufferRegion( |
| const ffi::Array<BufferRegion>& region) const { |
| ffi::Array<BufferRegion> result; |
| for (const BufferRegion& buffer_region : region) { |
| if (buffer_data_to_buffer_.count(buffer_region->buffer->data)) { |
| result.push_back(buffer_region); |
| } |
| } |
| return result; |
| } |
| |
| /*! \brief The map from stmt to the buffers to be allocated under it. */ |
| std::unordered_map<const StmtNode*, ffi::Array<Buffer>> alloc_buffers_; |
| /*! \brief The buffer already allocated during recursive visiting. */ |
| ffi::Map<Var, Buffer> buffer_data_to_buffer_; |
| /*! \brief Buffers that are allocated within a BlockNode, and may be moved. */ |
| std::unordered_set<const VarNode*> managed_allocations_; |
| }; |
| |
| namespace transform { |
| |
| Pass PlanAndUpdateBufferAllocationLocation() { |
| auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { |
| auto fptr = f.CopyOnWrite(); |
| BufferAllocationLocator locator(f); |
| fptr->body = locator(fptr->body); |
| return f; |
| }; |
| return CreatePrimFuncPass(pass_func, 0, "s_tir.PlanAndUpdateBufferAllocationLocation", {}); |
| } |
| |
// Register the pass with the FFI reflection registry so it is reachable as
// "s_tir.transform.PlanAndUpdateBufferAllocationLocation" from the frontends.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("s_tir.transform.PlanAndUpdateBufferAllocationLocation",
                        PlanAndUpdateBufferAllocationLocation);
}
| |
| } // namespace transform |
| |
| } // namespace s_tir |
| } // namespace tvm |