| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file inject_software_pipeline.cc |
 * \brief Transform annotated loops into pipelined ones that parallelize producers and consumers
| */ |
| #include <tvm/target/target.h> |
| #include <tvm/tir/builtin.h> |
| #include <tvm/tir/transform.h> |
| |
| #include <unordered_set> |
| |
| #include "../../support/utils.h" |
| #include "../schedule/utils.h" |
| #include "./ir_utils.h" |
| |
| namespace tvm { |
| namespace tir { |
| |
| namespace software_pipeline { |
| |
| /*! |
 * \brief Create a block with the given body and infer its access regions.
 *
 * The result is an opaque block that doesn't contain any block iter vars. If the body is a
 * block realize without a predicate, there is no need to create a new block; the block of the
 * block realize is returned.
| * |
| * \param body The body of the block. |
| * \param buffer_data_to_buffer The map from buffer data to buffer. |
| * \return The result block. |
| */ |
| Block MakeBlock(const Stmt& body, const Map<Var, Buffer>& buffer_data_to_buffer) { |
| if (const BlockRealizeNode* block_realize = body.as<BlockRealizeNode>()) { |
| if (is_one(block_realize->predicate)) { |
| // no need to create a new block |
| return block_realize->block; |
| } |
| } |
  Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", /*body=*/body);
| Array<Array<BufferRegion>> access = GetBlockReadWriteRegion(block, buffer_data_to_buffer); |
| BlockNode* n = block.CopyOnWrite(); |
| n->reads = access[0]; |
| n->writes = access[1]; |
| return block; |
| } |
| |
/*! Structure that represents the annotation provided for each block or loop. */
| struct PipelineAnnotation { |
| int stage; |
| int order; |
| bool async; |
| }; |
| |
| using PipelineInfo = std::unordered_map<Block, PipelineAnnotation, ObjectPtrHash, ObjectPtrEqual>; |
| |
| struct BufferAccessInfo { |
| int def = -1; // the defining stage of the buffer |
| int use = -1; // the last using stage of the buffer |
| }; |
| |
| class PipelineOpaqueAccessRewriter { |
| public: |
| /*! |
| * \brief Constructor |
| * \param buffer_data_to_buffer The map from buffer data to buffer. |
| * \param buffer_remap The map from original buffer to the buffer with updated shape for |
| * multi-versioning in the software pipeline. |
| * \param pipeline_loop The original loop to be software pipelined. |
   * \param fragment_info Information about tensor core fragments.
| */ |
| PipelineOpaqueAccessRewriter( |
| const Map<Var, Buffer>& buffer_data_to_buffer, const Map<Buffer, Buffer>& buffer_remap, |
| const For& pipeline_loop, |
| const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info) |
| : buffer_data_to_buffer_(buffer_data_to_buffer), |
| buffer_remap_(buffer_remap), |
| pipeline_loop_(pipeline_loop), |
| fragment_info_(fragment_info) {} |
| |
| PrimExpr Rewrite(const Call& call) { |
    // Intrinsic calls should be handled explicitly here as they are opaque accesses to
    // buffers.
| static const auto& load_matrix_sync = builtin::tvm_load_matrix_sync(); |
| static const auto& store_matrix_sync = builtin::tvm_store_matrix_sync(); |
| static const auto& mma_sync = builtin::tvm_mma_sync(); |
| static const auto& access_ptr = builtin::tvm_access_ptr(); |
| static const auto& ptx_ldmatrix = builtin::ptx_ldmatrix(); |
| static const auto& ptx_mma = builtin::ptx_mma(); |
| if (call->op.same_as(load_matrix_sync) || call->op.same_as(store_matrix_sync)) { |
| const Buffer& buffer = buffer_data_to_buffer_.at(Downcast<Var>(call->args[0])); |
| auto it = buffer_remap_.find(buffer); |
| if (it != buffer_remap_.end()) { |
| Array<PrimExpr> new_args = call->args; |
| const Buffer& new_buffer = (*it).second; |
| new_args.Set(4, RewriteWmmaFragmentIndex(buffer, new_buffer, call->args[4])); |
| return Call(call->dtype, call->op, new_args, call->span); |
| } |
| } else if (call->op.same_as(mma_sync)) { |
| Array<PrimExpr> new_args = call->args; |
| for (int i = 0; i < 4; i++) { |
| const Var& buffer_var = Downcast<Var>(call->args[i * 2]); |
| const PrimExpr& index = call->args[i * 2 + 1]; |
| const Buffer& buffer = buffer_data_to_buffer_.at(buffer_var); |
| auto it = buffer_remap_.find(buffer); |
| if (it != buffer_remap_.end()) { |
| PrimExpr new_index = RewriteWmmaFragmentIndex(buffer, (*it).second, index); |
| new_args.Set(i * 2 + 1, new_index); |
| } |
| } |
| return Call(call->dtype, call->op, new_args, call->span); |
| } else if (call->op.same_as(access_ptr)) { |
| return RewriteBufferAccess(call, {1}); |
| } else if (call->op.same_as(ptx_mma)) { |
| return RewriteBufferAccess(call, {6, 8, 10}); |
| } else if (call->op.same_as(ptx_ldmatrix)) { |
| return RewriteBufferAccess(call, {3}); |
| } |
| return call; |
| } |
| |
| private: |
| int GetWmmaFragmentSize(const Buffer& buffer) { |
| auto it = fragment_info_.find(buffer->data.get()); |
| ICHECK(it != fragment_info_.end()); |
| const FragmentInfo& info = (*it).second; |
| return info.GetSize(); |
| } |
| |
| PrimExpr RewriteWmmaFragmentIndex(const Buffer& old_buffer, const Buffer& new_buffer, |
| const PrimExpr& old_index) { |
| PrimExpr new_buffer_offset = old_index; |
| |
| int fragment_size = GetWmmaFragmentSize(old_buffer); |
| PrimExpr offset = |
| floordiv(foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, |
| make_const(DataType::Int(32), 1), old_buffer->shape), |
| fragment_size); |
| new_buffer_offset += |
| floormod(pipeline_loop_->loop_var - pipeline_loop_->min, new_buffer->shape[0]) * offset; |
| return new_buffer_offset; |
| } |
| |
  PrimExpr RewriteBufferAccess(const Call& call, const std::vector<int>& arg_indices) {
| auto product = [](const Array<PrimExpr>& input) { |
| return foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, |
| make_const(DataType::Int(32), 1), input); |
| }; |
| Array<PrimExpr> new_args = call->args; |
| for (int i : arg_indices) { |
| const Buffer& buffer = buffer_data_to_buffer_.at(Downcast<Var>(call->args[i])); |
| auto it = buffer_remap_.find(buffer); |
| if (it != buffer_remap_.end()) { |
| const Buffer& new_buffer = (*it).second; |
| const PrimExpr& old_index = call->args[i + 1]; |
| PrimExpr offset; |
| if (new_buffer->strides.empty()) { |
| offset = product(buffer->shape); |
| } else { |
| offset = new_buffer->strides[0]; |
| } |
| PrimExpr new_index = |
| old_index + floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset; |
| new_args.Set(i + 1, new_index); |
| } |
| } |
| return Call(call->dtype, call->op, new_args, call->span); |
| } |
| |
| const Map<Var, Buffer>& buffer_data_to_buffer_; |
| const Map<Buffer, Buffer>& buffer_remap_; |
| const For& pipeline_loop_; |
| const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info_; |
| }; |
| |
| /*! |
 * \brief Rewriter for the body of the software pipeline. This rewriter adds `floormod` to the
 * indices of the remapped buffers to select the version corresponding to the pipeline stage.
| */ |
| class PipelineBodyRewriter : public StmtExprMutator { |
| public: |
| /*! |
| * \brief Constructor of PipelineBodyRewriter. |
| * \param buffer_data_to_buffer The map from buffer data to buffer. |
| * \param buffer_remap The map from original buffer to the buffer with updated shape for |
| * multi-versioning in the software pipeline. |
| * \param pipeline_loop The original loop to be software pipelined. |
   * \param access_all_versions Whether all versions of the buffers in the software pipeline are
   * accessed. This will be used to update the block access regions. In the prologue and epilogue
   * of a two-stage software pipeline, only one version of these buffers is accessed.
   * \param fragment_info Information about tensor core fragments.
| */ |
| PipelineBodyRewriter(const Map<Var, Buffer>& buffer_data_to_buffer, |
| const Map<Buffer, Buffer>& buffer_remap, For pipeline_loop, |
| bool access_all_versions, |
| const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info) |
| : buffer_data_to_buffer_(buffer_data_to_buffer), |
| buffer_remap_(buffer_remap), |
| pipeline_loop_(pipeline_loop), |
| access_all_versions_(access_all_versions), |
| opaque_access_rewriter_(buffer_data_to_buffer_, buffer_remap_, pipeline_loop_, |
| fragment_info) {} |
| |
| private: |
| BufferRegion RewritePipelineBufferRegion(const BufferRegion& buffer_region) const { |
| auto it = buffer_remap_.find(buffer_region->buffer); |
| if (it != buffer_remap_.end()) { |
| Region new_region = buffer_region->region; |
| const Buffer& new_buffer = (*it).second; |
| // For pipeline buffers, relax the access region of the first dimension to full extent |
| // if access_all_versions == true |
| Range accessed_version = |
| access_all_versions_ |
| ? Range::FromMinExtent(0, new_buffer->shape[0]) |
| : Range::FromMinExtent(floormod((pipeline_loop_->loop_var - pipeline_loop_->min), |
| new_buffer->shape[0]), |
| Integer(1)); |
| new_region.insert(new_region.begin(), accessed_version); |
| return BufferRegion(new_buffer, new_region); |
| } |
| return buffer_region; |
| } |
| |
| Stmt VisitStmt_(const BlockNode* op) final { |
| for (const Buffer& alloc_buffer : op->alloc_buffers) { |
| buffer_data_to_buffer_.Set(alloc_buffer->data, alloc_buffer); |
| } |
| Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op)); |
| BlockNode* n = block.CopyOnWrite(); |
| n->reads.MutateByApply([this](const BufferRegion& buffer_region) { |
| return RewritePipelineBufferRegion(buffer_region); |
| }); |
| n->writes.MutateByApply([this](const BufferRegion& buffer_region) { |
| return RewritePipelineBufferRegion(buffer_region); |
| }); |
| for (const Buffer& alloc_buffer : op->alloc_buffers) { |
| buffer_data_to_buffer_.erase(alloc_buffer->data); |
| } |
| return std::move(block); |
| } |
| |
| Stmt VisitStmt_(const BufferStoreNode* op) final { |
| BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op)); |
| auto it = buffer_remap_.find(store->buffer); |
| if (it == buffer_remap_.end()) { |
| return std::move(store); |
| } |
| const Buffer& new_buffer = (*it).second; |
| auto* n = store.CopyOnWrite(); |
| n->buffer = new_buffer; |
| PrimExpr version = |
| floormod((pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]); |
| n->indices.insert(n->indices.begin(), version); |
| return std::move(store); |
| } |
| |
| PrimExpr VisitExpr_(const BufferLoadNode* op) final { |
| BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op)); |
| auto it = buffer_remap_.find(load->buffer); |
| if (it == buffer_remap_.end()) { |
| return std::move(load); |
| } |
| const Buffer& new_buffer = (*it).second; |
| auto* n = load.CopyOnWrite(); |
| n->buffer = new_buffer; |
| PrimExpr version = |
| floormod((pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]); |
| n->indices.insert(n->indices.begin(), version); |
| return std::move(load); |
| } |
| |
| PrimExpr VisitExpr_(const CallNode* op) final { |
| Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op)); |
| return opaque_access_rewriter_.Rewrite(call); |
| } |
| |
| Map<Var, Buffer> buffer_data_to_buffer_; |
| Map<Buffer, Buffer> buffer_remap_; |
| For pipeline_loop_; |
| bool access_all_versions_; |
| PipelineOpaqueAccessRewriter opaque_access_rewriter_; |
| }; |
| |
| /*! |
 * \brief Rewriter for the software pipeline that rewrites a loop into a pipelined one.
| */ |
| class PipelineRewriter : public StmtExprMutator { |
| public: |
| static Stmt Rewrite( |
| Map<Var, Buffer> buffer_data_to_buffer, |
| const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>& double_buffers, |
      const Array<Buffer>& pipeline_allocs, const For& pipeline_loop,
| const PipelineInfo& pipeline_info, |
| const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info) { |
| PipelineRewriter rewriter(buffer_data_to_buffer, double_buffers, pipeline_allocs, pipeline_loop, |
| pipeline_info, fragment_info); |
| return rewriter.BuildPipeline(); |
| } |
| |
| private: |
| PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer, |
| const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>& double_buffers, |
| const Array<Buffer>& pipeline_allocs, const For& pipeline_loop, |
| const PipelineInfo& pipeline_info, |
| const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info) |
| |
| : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)), |
| double_buffers_(double_buffers), |
| pipeline_allocs_(pipeline_allocs), |
| pipeline_loop_(pipeline_loop), |
| pipeline_info_(pipeline_info), |
| fragment_info_(fragment_info) {} |
| |
| Stmt BuildPipeline() { |
    // Step 1: Analyze accesses to the buffers in the pipeline and compute the number of versions
    // that need to be maintained for each buffer.
| std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual> infos = |
| GetBufferAccessInfo(); |
| for (const Buffer& buffer : pipeline_allocs_) { |
| int num_versions = ComputeBufferVersions(buffer, infos.at(buffer)); |
| if (num_versions > 1) { |
| buffer_remap_.Set(buffer, RewriteAllocBuffer(buffer, num_versions)); |
| } |
| } |
| |
| ordered_stmts_.resize(pipeline_info_.size()); |
| for (const auto& pair : pipeline_info_) { |
| const Block& block = pair.first; |
| int order = pair.second.order; |
| ordered_stmts_.Set(order, block); |
| } |
| |
| // Step 2: Emit the pipeline prologue, body and epilogue. |
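    // For example (a sketch), pipelining `for i in [0, n)` with max_stage_ == 1 emits:
    //   prologue: i in [0, 1)      -- warm-up, stage-0 producers only
    //   body:     i in [1, n)      -- steady state, all stages active
    //   epilogue: i in [n, n + 1)  -- drain, remaining stage-1 consumers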
| Stmt prologue = EmitImpl(pipeline_loop_->min, pipeline_loop_->min + max_stage_, true); |
| Stmt body = EmitImpl(pipeline_loop_->min + max_stage_, |
| pipeline_loop_->min + pipeline_loop_->extent, false); |
| Stmt epilogue = EmitImpl(pipeline_loop_->min + pipeline_loop_->extent, |
| pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true); |
| |
| SeqStmt stmt = SeqStmt({prologue, body, epilogue}); |
| |
| // Step 3: Make a new block that contains new buffer allocations after pipeline rewriting. |
| Array<Buffer> alloc_buffers; |
| for (const auto& alloc : pipeline_allocs_) { |
| alloc_buffers.push_back(buffer_remap_.Get(alloc).value_or(alloc)); |
| buffer_data_to_buffer_.erase(alloc->data); |
| } |
| Block block = MakeBlock(stmt, buffer_data_to_buffer_); |
| block.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers); |
| return BlockRealize({}, Bool(true), block); |
| } |
| |
| private: |
| /*! |
| * \brief Analyze accesses to the buffers in the software pipeline. |
| * |
   * This method checks the 'define' and 'use' stages of the buffers in the software pipeline,
   * which can be used to compute the number of versions needed after rewriting.
| */ |
| std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual> |
| GetBufferAccessInfo() { |
| std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual> infos; |
| for (const auto& pair : pipeline_info_) { |
| const Block& block = pair.first; |
| int stage = pair.second.stage; |
| max_stage_ = std::max(max_stage_, stage); |
| |
| for (const BufferRegion& write : block->writes) { |
| if (!infos.count(write->buffer)) { |
| infos.emplace(write->buffer, BufferAccessInfo{}); |
| } |
| auto& info = infos.at(write->buffer); |
| if (info.def == -1) { |
| info.def = stage; |
| } else { |
| info.def = std::min(info.def, stage); |
| } |
| } |
| |
| for (const BufferRegion& read : block->reads) { |
| if (!infos.count(read->buffer)) { |
| infos.emplace(read->buffer, BufferAccessInfo{}); |
| } |
| auto& info = infos.at(read->buffer); |
| info.use = std::max(info.use, stage); |
| } |
| } |
| return infos; |
| } |
| |
| /*! |
| * \brief Check whether two regions have intersections. |
| * \param region1 The first region. |
| * \param region2 The second region. |
| * \return Whether region1 and region2 have intersections. |
| */ |
| bool MayConflict(Region region1, Region region2) { |
| ICHECK(region1.size() == region2.size()); |
| for (size_t i = 0; i < region1.size(); i++) { |
| Range dim1 = region1[i]; |
| Range dim2 = region2[i]; |
| auto int_set1 = arith::IntSet::FromRange(dim1); |
| auto int_set2 = arith::IntSet::FromRange(dim2); |
| if (arith::Intersect({int_set1, int_set2}).IsNothing()) { |
| return false; |
| } |
| } |
| return true; |
| } |
| |
| /*! |
   * \brief Compute the number of versions needed for a buffer accessed in the software
   * pipeline.
   *
   * This method applies liveness analysis to the target buffer to compute the number of
   * versions needed during the software pipeline.
   * The annotation `attr::double_buffer_scope` is handled here, which provides a way to
   * override the result of the analysis. Additional double buffering in the software pipeline
   * can be useful to eliminate synchronizations on GPU devices.
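   *
   * For example (illustrative): a buffer written at stage 0 and last read at stage 1 gives
   * `use - def + 1 == 2`, i.e. up to two versions (double buffering); the analysis below
   * shrinks this to one version when no overlapping writer/reader pair actually requires it.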
| * |
| * \param buffer The target buffer |
| * \param buffer_info The access information of the target buffer. |
| * \return The number of versions required for the target buffer. |
| */ |
| int ComputeBufferVersions(const Buffer& buffer, const BufferAccessInfo& buffer_info) { |
| if (buffer_info.def == -1) { |
| // Keep the original number of versions as buffers defined outside the software pipeline |
| // should not be mutated. |
| return 1; |
| } |
| |
    // `use - def + 1` is an upper bound on the number of needed versions.
    // We optimize a few cases where the number of versions can be smaller than the upper bound.
| int num_versions = buffer_info.use - buffer_info.def + 1; |
| if (num_versions == 2) { |
      // A special case when `use - def + 1 == 2`. Double buffering is only needed in this case
      // when there exists a writer block_i and a reader block_j such that
      // order(block_i) < order(block_j) and stage(block_i) < stage(block_j) and the access regions
      // of block_i and block_j overlap.
| bool need_multi_version = false; |
| for (const auto& pair1 : pipeline_info_) { |
| const Block& writer_block = pair1.first; |
| const auto& writer_info = pair1.second; |
| |
| auto it1 = std::find_if(writer_block->writes.begin(), writer_block->writes.end(), |
| [&](const BufferRegion& buffer_region) { |
| return buffer_region->buffer.same_as(buffer); |
| }); |
| if (it1 == writer_block->writes.end()) { |
| continue; |
| } |
| |
| for (const auto& pair2 : pipeline_info_) { |
| const Block& reader_block = pair2.first; |
| const auto& reader_info = pair2.second; |
| auto it2 = std::find_if(reader_block->reads.begin(), reader_block->reads.end(), |
| [&](const BufferRegion& buffer_region) { |
| return buffer_region->buffer.same_as(buffer); |
| }); |
| if (it2 == reader_block->reads.end()) { |
| continue; |
| } |
| if (writer_info.order < reader_info.order && writer_info.stage < reader_info.stage && |
| MayConflict((*it1)->region, (*it2)->region)) { |
| need_multi_version = true; |
| break; |
| } |
| } |
| } |
| if (!need_multi_version) { |
| num_versions = 1; |
| } |
| } |
| if (num_versions == 1 && double_buffers_.count(buffer)) { |
| num_versions = 2; |
| } |
| return num_versions; |
| } |
| |
| /*! |
   * \brief Rewrite a buffer allocation to keep multiple versions of the original buffer for
   * pipelined accesses.
| * \param buffer The buffer to be resized. |
| * \param num_versions The number of versions to keep. |
| * \return The resized buffer. |
| */ |
| Buffer RewriteAllocBuffer(const Buffer& buffer, int num_versions) { |
| ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get())); |
| new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions)); |
| if (new_buffer->strides.size()) { |
| ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size()); |
| PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1]; |
| new_buffer->strides.insert(new_buffer->strides.begin(), stride_0); |
| } |
| return Buffer(new_buffer); |
| } |
| |
| // Per-stage states that need to be tracked across pipeline prologue, body, and epilogue. |
| struct AsyncStateGlobal { |
| // Buffers that this stage asynchronously writes. |
| std::unordered_set<const BufferNode*> dst_buffers; |
| // An imaginary index that the latest async operation associated with this stage has written |
| // into. Only valid if all associated predicates are true, so that we can count the number of |
| // async invocations exactly. When it is valid, it is the "sum of extents of loops that have |
| // been executed" - 1, e.g. for epilogue it is prologue extent + body extent - 1. This |
| // is only needed to compute wait count for epilogue without async producers. |
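    // For example (a sketch): producer_head starts at -1; after a prologue of extent 1 and a
    // body of extent n - 1 it becomes -1 + 1 + (n - 1) = n - 1, the index of the last element
    // written asynchronously, which the epilogue uses to compute wait counts.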
| Optional<PrimExpr> producer_head{PrimExpr(-1)}; |
| |
| bool writes(Buffer buf) const { return dst_buffers.count(buf.get()) > 0; } |
| }; |
| |
| // Per-stage states that are local to each of pipeline prologue, body, and epilogue. |
| struct AsyncStateLocal { |
| struct { |
| // The index into a list of blocks, where async_wait_queue should be attached at the |
| // beginning. |
| int insert_before; |
| // in_flight_count would be a more precise name, but the implementation uses wait_count for |
| // brevity. |
| PrimExpr wait_count{nullptr}; |
| |
| bool valid() const { return wait_count.defined(); } |
| } pending_wait; |
| |
| // Destination buffers of async operations that have been encountered so far in the loop |
| // |
| // for (size_t i = 0; i < new_blocks.size(); ++i) { |
| // ... |
| // } |
| // |
| // This is for tracking which async operations have been issued at the "current" iteration, up |
| // until a point where we encounter a consumer of async result buffers. This is used to decide |
| // if the producer_head of each buffer points to a copy written in the current or previous |
| // iteration. |
| std::unordered_set<const BufferNode*> seen; |
| |
| // A symbolic expression representing the index the latest async operation associated with this |
| // stage has written into, at the "current" iteration. |
| Optional<PrimExpr> producer_head; |
| // The predicate of BlockRealize containing the async operation of this stage. |
| Optional<PrimExpr> predicate; |
| // Indices into a list of blocks, where async_commit_queue scope should be attached. |
| // If multiple async producers are interleaved with their consumer in between, we need separate |
| // async_commit_queue for each producer. Thus, we need multiple sets of indices. |
| std::vector<std::vector<size_t>> commit_groups; |
| |
| // This is set to true when we reach a stage that consumes this async stage. |
| bool consumed{false}; |
| }; |
| |
| /*! Structure holding intermediate information for pipeline loop rewriting. */ |
| struct RewrittenBlockInfo { |
| int stage; |
| PrimExpr predicate; |
| Block block; |
| PrimExpr access_index; |
| bool is_async; |
| }; |
| |
| // Determine where to insert async_wait and the corresponding wait count. |
| void PopulateWaitCounts(const std::vector<RewrittenBlockInfo>& new_blocks, |
| arith::Analyzer* ana_normalized, |
| const std::unordered_map<const BufferNode*, int>& buffer_to_commit_group, |
| std::map<int, AsyncStateLocal>* async_states_local) { |
| for (size_t i = 0; i < new_blocks.size(); ++i) { |
| if (new_blocks[i].is_async) { |
| // Record the fact that we have encountered these write buffers. |
| for (auto write_region : new_blocks[i].block->writes) { |
| (*async_states_local)[new_blocks[i].stage].seen.insert(write_region->buffer.get()); |
| } |
| } |
| |
| int producer_stage_idx = -1; |
| for (auto read_region : new_blocks[i].block->reads) { |
        for (const auto& kv : async_states) {
| if (kv.first <= new_blocks[i].stage && kv.second.writes(read_region->buffer)) { |
| // Found an earlier stage where read_region->buffer was asynchronously written |
| ICHECK(producer_stage_idx == -1 || producer_stage_idx == kv.first) |
| << "A dependency on multiple async stages is not supported"; |
| producer_stage_idx = kv.first; |
| } |
| } |
| } |
| |
| if (producer_stage_idx == -1) continue; |
| |
    // The following logic has become complicated to handle cases like the following:
| // |
| // for i in range(13): |
| // # Stage 0 |
| // async_commit_queue(0): |
| // async_scope: |
| // A_shared[(i + 3) % 4] = A[...] |
| // |
| // |
| // # Stage 1 |
| // async_wait_queue(0, 5): |
| // compute(A_shared[i], B_shared[i]) |
| // |
| // # Stage 0 |
| // async_commit_queue(0) |
| // async_scope: |
| // B_shared[(i + 3) % 4] = B[...] |
| // |
| // |
    // Here, multiple async producers in the same stage are interleaved with their consumer in
    // between. Since each buffer is associated with a different commit group, the wait_count
    // before the consumer should be bigger than in the simpler case:
| // |
| // for i in range(13): |
| // # Stage 0 |
| // async_commit_queue(0): |
| // async_scope: |
| // A_shared[(i + 3) % 4] = A[...] |
| // B_shared[(i + 3) % 4] = B[...] |
| // |
| // # Stage 1 |
| // async_wait_queue(0, 3): |
| // compute(A_shared[i], B_shared[i]) |
| // |
| // The correct wait_count can be determined by considering each commit group separately, and |
| // summing "per-commit" wait_counts. |
| // |
| // From A_shared's perspective, it allows for (i + 3) - i async commit groups to be in |
| // flight while from B_shared's perspective, the producer head at compute points to the copy |
| // done by the previous iteration, so its wait_count is calculated as ((i - 1) + 3) - i. The |
| // sum of the two wait_counts gives 5. |
| |
| auto& dep_local_state = (*async_states_local)[producer_stage_idx]; |
| const auto num_commit_group = dep_local_state.commit_groups.size(); |
| std::vector<Optional<PrimExpr>> producer_head_per_commit; |
| |
| if (num_commit_group == 0) { |
| // Epilogue, no async producer. Since "local" producer_head is not available, use |
| // "global" producer_head. |
| ICHECK(!dep_local_state.producer_head); |
| producer_head_per_commit.push_back(async_states[producer_stage_idx].producer_head); |
| } else { |
| ICHECK(dep_local_state.producer_head); |
| std::vector<bool> need_wait_count(num_commit_group, true); |
| |
| for (auto read_region : new_blocks[i].block->reads) { |
| if (!async_states[producer_stage_idx].writes(read_region->buffer)) continue; |
| auto commit_group_id = buffer_to_commit_group.at(read_region->buffer.get()); |
| if (!need_wait_count[commit_group_id]) continue; |
| |
| if (!dep_local_state.seen.count(read_region->buffer.get())) { |
| // Multiple async producers interleaved: The most recent async write is from the |
| // previous iteration. This is the B_shared case above. |
| producer_head_per_commit.push_back(dep_local_state.producer_head.value() - 1); |
| } else { |
| // Normal case |
| producer_head_per_commit.push_back(dep_local_state.producer_head.value()); |
| } |
| |
| need_wait_count[commit_group_id] = false; |
| } |
| } |
| |
| auto wait_count = [=, &ana_normalized]() { |
| auto sum = PrimExpr(0); |
| for (auto producer_head : producer_head_per_commit) { |
| if (producer_head && ana_normalized->CanProve(producer_head.value() >= 0)) { |
| // Here, new_blocks[i].access_index corresponds to "consumer_head". |
| // The difference of producer_head and consumer_head is precisely the number of |
| // async commit groups that can still be in flight after this wait. |
| sum += analyzer_.Simplify(producer_head.value() - new_blocks[i].access_index); |
| } else { |
| // The precise count cannot be determined, give up. |
| return PrimExpr(0); |
| } |
| } |
| return sum; |
| }(); |
| |
| auto& pending_wait = dep_local_state.pending_wait; |
| |
| if (!pending_wait.valid()) { |
| pending_wait = {static_cast<int>(i), wait_count}; |
| } else if (analyzer_.CanProve(wait_count < pending_wait.wait_count)) { |
| // Coalesce multiple wait_queue if the later one allows fewer in-flight ops. |
| pending_wait = {pending_wait.insert_before, wait_count}; |
| } |
| } |
| } |
| |
| // Given pipelined blocks and async-related information, generate final loop statements with async |
| // scopes (if any). |
| Array<Stmt> CompletePipelineLoopStatements( |
| const std::vector<RewrittenBlockInfo>& blocks, |
| const std::map<int, AsyncStateLocal>& async_states_local, |
| arith::Analyzer* ana_normalized) const { |
| std::vector<RewrittenBlockInfo> new_blocks = blocks; |
| std::vector<int> commit_group_indices(new_blocks.size(), -1); |
| for (const auto& kv : async_states_local) { |
| const int stage_id = kv.first; |
| const AsyncStateLocal& state = kv.second; |
| |
| if (!state.commit_groups.empty()) { |
| for (size_t i = 0; i < state.commit_groups.size(); ++i) { |
| for (size_t j = 0; j < state.commit_groups[i].size(); ++j) { |
| ICHECK(state.commit_groups[i][0] + j < new_blocks.size()); |
| commit_group_indices[state.commit_groups[i][0] + j] = stage_id; |
| } |
| } |
| } |
| |
| if (state.pending_wait.valid()) { |
| auto attach_wait_scope = [&new_blocks](int i, int stage_id, PrimExpr wait_count) { |
| auto& block = new_blocks[i].block; |
| BlockNode* n = block.CopyOnWrite(); |
| auto zero = make_zero(DataType::Int(32)); |
| n->body = |
| AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id, |
| AttrStmt(zero, tir::attr::async_wait_inflight_count, wait_count, n->body)); |
| }; |
| |
| if (state.predicate && !ana_normalized->CanProve(state.predicate.value())) { |
          // If the async operation that this wait_queue is waiting on is predicated, and we
          // cannot prove that the predicate is always true, the precise wait count is only
          // valid at iterations where the predicate is true.
| auto wait_count = Call(DataType::Int(32), builtin::if_then_else(), |
| {state.predicate.value(), state.pending_wait.wait_count, 0}); |
| attach_wait_scope(state.pending_wait.insert_before, stage_id, wait_count); |
| } else { |
| attach_wait_scope(state.pending_wait.insert_before, stage_id, |
| state.pending_wait.wait_count); |
| } |
| } |
| } |
| |
| Array<Stmt> stmts; |
| |
| for (size_t i = 0; i < new_blocks.size();) { |
| if (commit_group_indices[i] == -1) { |
        // A synchronous block, not part of any commit group
| stmts.push_back(BlockRealize({}, new_blocks[i].predicate, new_blocks[i].block)); |
| ++i; |
| } else { |
| Array<Stmt> group_bodies; |
| auto stage_id = commit_group_indices[i]; |
| auto predicate = new_blocks[i].predicate; |
| for (; i < commit_group_indices.size() && commit_group_indices[i] == stage_id; ++i) { |
| ICHECK(tvm::StructuralEqual()(predicate, new_blocks[i].predicate)) |
| << "Predicates in the same stage are expected to be identical"; |
| group_bodies.push_back(new_blocks[i].block->body); |
| } |
| auto body = group_bodies.size() > 1 ? SeqStmt(group_bodies) : group_bodies[0]; |
| auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)), |
| tir::attr::async_commit_queue_scope, stage_id, body); |
| auto new_block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_); |
| stmts.push_back(BlockRealize({}, predicate, new_block)); |
| } |
| } |
| |
| return stmts; |
| } |
| |
| /*! |
| * \brief Emit the pipeline loop in the given range. |
| * \param start The start of the range |
| * \param end The end of the range |
| * \param unroll_loop Whether the loop should be unrolled. |
| * \return The result loop. |
| */ |
| Stmt EmitImpl(PrimExpr start, PrimExpr end, bool unroll_loop) { |
| PrimExpr new_loop_var; |
| PrimExpr extent = end - start; |
| |
| auto make_nop = []() { return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {})); }; |
| |
| if (!analyzer_.CanProve(extent > 0)) { |
| return make_nop(); |
| } |
| bool is_unit_loop = analyzer_.CanProveEqual(extent, 1); |
| if (is_unit_loop) { |
      new_loop_var = start;  // use the start value as the loop var for unit loops
| } else { |
| new_loop_var = pipeline_loop_->loop_var.copy_with_suffix(""); |
| analyzer_.Bind(Downcast<Var>(new_loop_var), Range(start, end)); |
| } |
| |
| // In contrast to analyzer_ which is bound to [start, end), this one is bound to |
| // the "normalized" range, [pipeline_loop_->min, extent). |
| arith::Analyzer ana_normalized; |
| if (!is_unit_loop) { |
| ana_normalized.Bind(Downcast<Var>(new_loop_var), Range(pipeline_loop_->min, extent)); |
| } |
| |
| std::vector<RewrittenBlockInfo> new_blocks; |
| |
| // Async related |
| std::map<int, AsyncStateLocal> async_states_local; |
| std::unordered_map<const BufferNode*, int> buffer_to_commit_group; |
| |
| for (const Block& block : ordered_stmts_) { |
| int stage = pipeline_info_.at(block).stage; |
| PrimExpr skewed_loop_var = new_loop_var - stage; |
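      // An illustrative sketch of the skewing: at body iteration i, a stage-0 block works on
      // original iteration i while a stage-1 block works on original iteration i - 1, which is
      // what overlaps producer and consumer stages.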
      PrimExpr inbound = (pipeline_loop_->min <= skewed_loop_var) &&
                         (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent);
      inbound = analyzer_.Simplify(inbound);
| if (analyzer_.CanProve(!inbound)) { |
| continue; |
| } |
| Block new_block = Downcast<Block>(PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_, |
| pipeline_loop_, max_stage_ != 1, |
| fragment_info_)(block)); |
| |
| PrimExpr delta = start - pipeline_loop_->min; |
| // This variable corresponds to |
| // - "producer_head" if this stage is an async producer |
| // - "consumer_head" if this stage reads from asynchronously written buffers. |
| PrimExpr normalized_access_index = is_unit_loop ? skewed_loop_var : skewed_loop_var + delta; |
| |
| // Adjust the block predicate and the body according to the final loop bound |
| // [pipeline_loop_->min, extent). |
| if (!is_unit_loop) { |
| Var loop_iter = Downcast<Var>(new_loop_var); |
| inbound = Substitute(inbound, {{loop_iter, loop_iter + delta}}); |
| } |
| |
| new_block = Downcast<Block>( |
| Substitute(new_block, {{pipeline_loop_->loop_var, normalized_access_index}})); |
| |
| if (pipeline_info_[block].async) { |
| auto& local_state = async_states_local[stage]; |
| |
| int commit_group_id = -1; |
| if (local_state.commit_groups.empty() || local_state.consumed) { |
| // consumed == true means there is already a consumer stage waiting for an |
        // earlier async operation of this stage. In such cases, we make multiple commit_queue
        // groups for this stage.
| commit_group_id = local_state.commit_groups.size(); |
| local_state.commit_groups.push_back({new_blocks.size()}); |
| } else { |
| // This is the case when one commit_queue groups multiple async blocks. |
| // with commit_queue(stage): |
| // async_scope: |
| // A_shared[...] = ... |
| // async_scope: |
| // B_shared[...] = ... |
| |
| commit_group_id = local_state.commit_groups.size() - 1; |
| local_state.commit_groups.back().push_back(new_blocks.size()); |
| } |
| |
| for (auto write_region : new_block->writes) { |
| async_states[stage].dst_buffers.insert(write_region->buffer.get()); |
| buffer_to_commit_group[write_region->buffer.get()] = commit_group_id; |
| } |
| |
| local_state.producer_head = normalized_access_index; |
| |
        if (!local_state.predicate || ana_normalized.CanProve(local_state.predicate.value())) {
          local_state.predicate = inbound;
        } else {
          local_state.predicate = ana_normalized.Simplify(local_state.predicate.value() & inbound);
        }
| |
| BlockNode* n = new_block.CopyOnWrite(); |
| n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope, 1, n->body); |
| } |
| |
| new_blocks.push_back( |
| {stage, inbound, new_block, normalized_access_index, pipeline_info_[block].async}); |
| |
| for (auto read_region : new_block->reads) { |
      for (const auto& kv : async_states) {
| int producer_stage_id = kv.first; |
| if (producer_stage_id <= stage && kv.second.writes(read_region->buffer)) { |
| async_states_local[producer_stage_id].consumed = true; |
| } |
| } |
| } |
| } |
| |
| PopulateWaitCounts(new_blocks, &ana_normalized, buffer_to_commit_group, &async_states_local); |
| auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local, &ana_normalized); |
| |
| Stmt new_loop{nullptr}; |
| |
| if (stmts.empty()) { |
| return make_nop(); |
| } |
| if (stmts.size() == 1) { |
| new_loop = stmts[0]; |
| } else { |
| new_loop = SeqStmt(stmts); |
| } |
| |
| if (!is_unit_loop) { |
| new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent, |
| unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind, std::move(new_loop)); |
| } |
| |
| // Update producer heads in the global async states. |
| for (const auto& kv : async_states_local) { |
| const int stage_id = kv.first; |
| const AsyncStateLocal& state = kv.second; |
| |
| if (state.predicate && ana_normalized.CanProve(state.predicate.value()) && |
| async_states[stage_id].producer_head) { |
| // Advance the "global" producer head if it is still valid and we know exactly how much we |
| // can increment |
| async_states[stage_id].producer_head = |
| async_states[stage_id].producer_head.value() + extent; |
| } else { |
| // Otherwise, invalidate the global producer head |
| async_states[stage_id].producer_head = NullOpt; |
| } |
| } |
| |
| return BlockRealize({}, Bool(true), MakeBlock(std::move(new_loop), buffer_data_to_buffer_)); |
| } |
| |
| arith::Analyzer analyzer_; |
| Map<Var, Buffer> buffer_data_to_buffer_; |
| const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>& double_buffers_; |
| Array<Buffer> pipeline_allocs_; |
| For pipeline_loop_; |
| PipelineInfo pipeline_info_; |
| const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info_; |
| int max_stage_ = -1; |
| Map<Buffer, Buffer> buffer_remap_; |
| Array<Block> ordered_stmts_; |
| std::map<int, AsyncStateGlobal> async_states; |
| }; |
| |
| /*! |
 * \brief Build the dependency graph among an array of blocks.
| * \param[in] blocks The array of blocks. |
| * \param[out] dep_src2dst Optional, a map to store dependency edges from the source to the |
| * destination. |
| * \param[out] dep_dst2src Optional, a map to store dependency edges from the |
| * destination to the source. |
| */ |
| void BuildDependencyGraph( |
| const Array<Block>& blocks, |
| std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>* dep_src2dst, |
| std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>* dep_dst2src) { |
| std::unordered_map<Var, Array<Block>, ObjectPtrHash, ObjectPtrEqual> buffer_writers; |
| |
| for (const Block& block : blocks) { |
| for (const BufferRegion& read : block->reads) { |
| auto it = buffer_writers.find(read->buffer->data); |
| if (it != buffer_writers.end()) { |
| for (const Block& writer : it->second) { |
| if (dep_src2dst != nullptr) { |
| (*dep_src2dst)[writer].push_back(block); |
| } |
| if (dep_dst2src != nullptr) { |
| (*dep_dst2src)[block].push_back(writer); |
| } |
| } |
| } |
| } |
| for (const BufferRegion& write : block->writes) { |
| buffer_writers[write->buffer->data].push_back(block); |
| } |
| } |
| } |
| |
| class PipelineInjector : private StmtExprMutator { |
| public: |
| static Stmt Inject(const PrimFunc& func) { |
| PipelineInjector injector; |
| for (const auto& kv : func->buffer_map) { |
| const Buffer& buffer = kv.second; |
| injector.buffer_data_to_buffer_.Set(buffer->data, buffer); |
| } |
| injector.fragment_info_ = GetTensorCoreFragmentInfo(func->body); |
| return injector(func->body); |
| } |
| |
| private: |
| PipelineInjector() = default; |
| |
| /*! |
   * \brief Check that the pipeline satisfies the following conditions:
   * 1. No conflicting order: The order of each statement should be unique.
   * 2. Reordering of statements doesn't break buffer access dependencies. Specifically, for each
   * dependency (e.g. read-after-write) from statement A to statement B, one of the following
   * must hold:
| * case 1: stage(A) < stage(B) |
| * case 2: stage(A) == stage(B) and order(A) < order(B) |
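   *
   * For example (illustrative): with a read-after-write dependency from A to B, placing A at
   * stage 0 and B at stage 1 is valid for any orders, while placing both in the same stage
   * requires order(A) < order(B).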
| */ |
| void ValidatePipelineBody(const PipelineInfo& pipeline_info, const Array<Block>& original_order) { |
| std::unordered_set<int> used_orders; |
| for (const Block& block : original_order) { |
| const auto& stmt_info = pipeline_info.at(block); |
| int order = stmt_info.order; |
| CHECK(!used_orders.count(order)) |
| << "ValueError: Two statements in the software pipeline cannot have the same order"; |
| used_orders.insert(order); |
| } |
| |
| std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual> dep_src2dst; |
| BuildDependencyGraph(original_order, &dep_src2dst, nullptr); |
| |
| for (const auto& pair : dep_src2dst) { |
| const Block& src = pair.first; |
| const auto& src_info = pipeline_info.at(src); |
| const Array<Block>& dsts = pair.second; |
| for (const Block& dst : dsts) { |
| const auto& dst_info = pipeline_info.at(dst); |
| CHECK_LE(src_info.stage, dst_info.stage) |
| << "ValueError: statement " << dst << " in stage " << dst_info.stage |
| << " cannot depends on statement " << src << " in a later stage " << src_info.stage; |
| if (src_info.stage == dst_info.stage) { |
| CHECK_LT(src_info.order, dst_info.order) << "ValueError: two statements with buffer " |
| "access dependency in the same stage of the " |
| "software pipeline cannot be reordered"; |
| } |
| } |
| } |
| } |
| |
| Stmt VisitStmt_(const ForNode* op) final { |
| // Step 1: Recursively rewrite the children first. |
| For for_node = Downcast<For>(StmtExprMutator::VisitStmt_(op)); |
| if (!HasPipelineAnnotation(op)) { |
| return std::move(for_node); |
| } |
    // Step 2: Find the body and buffer allocations of the pipeline. The body can be a direct
    // child of the for-loop. If the for-loop has a BlockRealize as its child, the pipeline body
    // will be the child of the block.
| Stmt pipeline_body{nullptr}; |
| Array<Buffer> pipeline_allocs; |
| if (const auto* realize = for_node->body.as<BlockRealizeNode>()) { |
| const auto& block = realize->block; |
| for (const auto& buffer : block->alloc_buffers) { |
| ICHECK(buffer->IsInstance<BufferNode>()); |
| buffer_data_to_buffer_.Set(buffer->data, buffer); |
| } |
| pipeline_body = block->body; |
| pipeline_allocs = block->alloc_buffers; |
| } else { |
| pipeline_body = for_node->body; |
| } |
| |
| const SeqStmtNode* pipeline_body_seq = pipeline_body.as<SeqStmtNode>(); |
| CHECK(pipeline_body_seq) |
| << "ValueError: The body of the software pipeline should be SeqStmt, got " |
| << pipeline_body->GetTypeKey(); |
| |
| // Step 3: Blockize the components of the pipeline. Each child of the pipelined loop will be |
| // converted into a block. |
| PipelineInfo pipeline_info; |
| Array<Block> original_order; // pipeline body blocks in the original order |
| |
| auto f_add_child = [&](const Stmt& child) { |
| original_order.push_back(MakeBlock(child, buffer_data_to_buffer_)); |
| }; |
| for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) { |
| const auto* nested_block_realize = pipeline_body_seq->seq[i].as<BlockRealizeNode>(); |
| if (nested_block_realize && is_one(nested_block_realize->predicate) && |
| nested_block_realize->block->body->IsInstance<SeqStmtNode>()) { |
| const Block& nested_pipeline_block = nested_block_realize->block; |
| ICHECK( |
| nested_pipeline_block->match_buffers.empty()); // match_buffer should have been lowered |
| for (const auto& buffer : nested_pipeline_block->alloc_buffers) { |
| pipeline_allocs.push_back(buffer); |
| buffer_data_to_buffer_.Set(buffer->data, buffer); |
| } |
| const auto* nested_seq = nested_pipeline_block->body.as<SeqStmtNode>(); |
| for (size_t j = 0; j < nested_seq->seq.size(); j++) { |
| f_add_child(nested_seq->seq[j]); |
| } |
| } else { |
| f_add_child(pipeline_body_seq->seq[i]); |
| } |
| } |
| |
| auto pipeline_stages = |
| Downcast<Array<Integer>>(op->annotations.at(attr::software_pipeline_stage)); |
| auto pipeline_orders = |
| Downcast<Array<Integer>>(op->annotations.at(attr::software_pipeline_order)); |
| CHECK_EQ(pipeline_stages.size(), original_order.size()); |
| CHECK_EQ(pipeline_orders.size(), original_order.size()); |
| |
| std::unordered_set<int> pipeline_async_stages; |
| if (auto annot = op->annotations.Get(attr::software_pipeline_async_stages)) { |
| for (auto s : Downcast<Array<Integer>>(annot)) { |
| pipeline_async_stages.insert(s->value); |
| } |
| } |
| |
| for (size_t i = 0; i < pipeline_stages.size(); i++) { |
| int stage = static_cast<int>(pipeline_stages[i]->value); |
| bool is_async = pipeline_async_stages.find(stage) != pipeline_async_stages.end(); |
| PipelineAnnotation stage_order{stage, |
| /*order=*/static_cast<int>(pipeline_orders[i]->value), |
| is_async}; |
| pipeline_info.emplace(original_order[i], stage_order); |
| } |
| |
| ValidatePipelineBody(pipeline_info, original_order); |
| |
| // Step 4: Rewrite the pipeline body. |
| Stmt pipeline = |
| PipelineRewriter::Rewrite(buffer_data_to_buffer_, double_buffers, pipeline_allocs, |
| GetRef<For>(op), pipeline_info, fragment_info_); |
| |
| if (const auto* realize = op->body.as<BlockRealizeNode>()) { |
| const auto& block = realize->block; |
| for (const auto& buffer : block->alloc_buffers) { |
| buffer_data_to_buffer_.erase(buffer->data); |
| } |
| } |
| return pipeline; |
| } |
| |
| /*! |
| * \brief Add buffer allocations to a block and update the write region of the block. |
| * \param n The block pointer to which the buffer allocations are added. |
| * \param alloc_buffers The buffer allocations to be added. |
| */ |
  void AddAllocBuffers(BlockNode* n, const Array<Buffer>& alloc_buffers) {
| for (const Buffer& alloc_buffer : alloc_buffers) { |
| n->alloc_buffers.push_back(alloc_buffer); |
| Region region; |
| region.reserve(alloc_buffer->shape.size()); |
| for (const PrimExpr& dim : alloc_buffer->shape) { |
| region.push_back(Range::FromMinExtent(0, dim)); |
| } |
| n->writes.push_back(BufferRegion(alloc_buffer, region)); |
| } |
| } |
| |
| Stmt VisitStmt_(const BlockNode* op) final { |
| for (const auto& buffer : op->alloc_buffers) { |
| buffer_data_to_buffer_.Set(buffer->data, buffer); |
| } |
| |
| auto it = op->annotations.find(attr::double_buffer_scope); |
| if (it != op->annotations.end()) { |
| int buffer_index = Downcast<Integer>((*it).second).IntValue(); |
| CHECK(buffer_index >= 0 && static_cast<size_t>(buffer_index) < op->writes.size()) |
| << "ValueError: Index of the buffer exceeds the size of the write regions of the block. (" |
| << buffer_index << " vs. " << op->writes.size() << ")"; |
| double_buffers.insert(op->writes[buffer_index]->buffer); |
| } |
| Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op)); |
| |
| for (const auto& buffer : op->alloc_buffers) { |
| buffer_data_to_buffer_.erase(buffer->data); |
| } |
| return std::move(block); |
| } |
| |
| bool HasPipelineAnnotation(const ForNode* op) const { |
| auto it1 = op->annotations.find(attr::software_pipeline_stage); |
| auto it2 = op->annotations.find(attr::software_pipeline_order); |
| bool has_stage = it1 != op->annotations.end(); |
| bool has_order = it2 != op->annotations.end(); |
| if (has_stage && has_order) { |
| return true; |
| } |
| if (has_stage) { |
| LOG(FATAL) << "ValueError: Order of the software pipeline is not defined."; |
| } |
| if (has_order) { |
| LOG(FATAL) << "ValueError: Stage of the software pipeline is not defined."; |
| } |
| return false; |
| } |
| |
| Map<Var, Buffer> buffer_data_to_buffer_; |
| std::unordered_map<const VarNode*, FragmentInfo> fragment_info_; |
| std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> double_buffers; |
| }; |
| |
| } // namespace software_pipeline |
| |
| namespace transform { |
| |
| /*! |
 * \brief Transform annotated loops into pipelined ones that parallelize producers and consumers.
| * \return The IR transform pass. |
| */ |
| Pass InjectSoftwarePipeline() { |
| auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { |
| auto* fptr = f.CopyOnWrite(); |
| fptr->body = software_pipeline::PipelineInjector::Inject(f); |
| fptr->body = ConvertSSA(std::move(fptr->body)); |
| return f; |
| }; |
| return CreatePrimFuncPass(pass_func, 0, "tir.InjectSoftwarePipeline", {}); |
| } |
| |
| TVM_REGISTER_GLOBAL("tir.transform.InjectSoftwarePipeline").set_body_typed(InjectSoftwarePipeline); |
| |
| } // namespace transform |
| |
| } // namespace tir |
| } // namespace tvm |