| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| #include <tvm/ffi/reflection/registry.h> |
| |
| #include "../utils.h" |
| |
| namespace tvm { |
| namespace tir { |
| |
| /*! |
| * \brief Check whether a loop carries any annotation or thread binding |
| * \param loop The loop to be checked |
| * \return Whether the loop has any annotation or is bound to a thread axis |
| */ |
| inline bool HasAnnOrBinding(const ForNode* loop) { |
| return loop->kind == ForKind::kThreadBinding || !loop->annotations.empty(); |
| } |
| |
| /*! \brief The visitor for extracting the stride of a var in a PrimExpr. */ |
| class StrideExtractor : public ExprVisitor { |
| public: |
| /*! |
| * \brief Extract the stride of a var in a PrimExpr. |
| * e.g., the stride of `x` in `(x * 2 + 1) * 3 + 1` is 6 |
| * \param expr The given PrimExpr. |
| * \param var The target var. |
| * \return The stride of the var, or 0 if no stride can be inferred |
| */ |
| static int64_t Extract(const PrimExpr& expr, const Var& var) { |
| StrideExtractor extractor(var); |
| extractor.VisitExpr(expr); |
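| // operator[] value-initializes an absent entry to 0, i.e. `var` does not occur in `expr` |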
| return extractor.strides_[expr.get()]; |
| } |
| |
| private: |
| explicit StrideExtractor(const Var& var) : var_(var) {} |
| |
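| // A constant factor scales the stride: stride(c * e) = c * stride(e) |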
| void VisitExpr_(const MulNode* node) final { |
| ExprVisitor::VisitExpr_(node); |
| |
| if (const auto* a = node->a.as<IntImmNode>()) { |
| if (strides_.count(node->b.get())) { |
| strides_[node] = strides_[node->b.get()] * a->value; |
| } |
| } else if (const auto* b = node->b.as<IntImmNode>()) { |
| if (strides_.count(node->a.get())) { |
| strides_[node] = strides_[node->a.get()] * b->value; |
| } |
| } |
| } |
| |
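| // The stride of a sum is the minimum stride among the operands in which `var` occurs |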
| void VisitExpr_(const AddNode* node) final { |
| ExprVisitor::VisitExpr_(node); |
| int64_t stride_a, stride_b; |
| if (strides_.count(node->a.get())) { |
| stride_a = strides_[node->a.get()]; |
| } else { |
| stride_a = INT64_MAX; |
| } |
| if (strides_.count(node->b.get())) { |
| stride_b = strides_[node->b.get()]; |
| } else { |
| stride_b = INT64_MAX; |
| } |
| if (stride_a != INT64_MAX || stride_b != INT64_MAX) { |
| strides_[node] = std::min(stride_a, stride_b); |
| } |
| } |
| |
| void VisitExpr_(const VarNode* node) final { |
| if (node == var_.get()) { |
| strides_[node] = 1; |
| } |
| } |
| |
| const Var& var_; |
| std::unordered_map<const PrimExprNode*, int64_t> strides_; |
| }; |
| |
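| /*! \brief The parsed results of a block's meta-schedule annotations */ |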
| struct ParsedAnnotation { |
| int max_parallel_extent; // From attr::meta_schedule_parallel, -1 if absent |
| int max_vectorize_extent; // From attr::meta_schedule_vectorize, -1 if absent |
| int unroll_explicit; // From attr::meta_schedule_unroll_explicit, -1 if absent |
| int unroll_implicit; // From attr::meta_schedule_unroll_implicit, -1 if absent |
| int num_parallel_loops; // Number of loops to fuse and parallelize, -1 if not applicable |
| int num_vectorize_loops; // Number of loops to fuse and vectorize, -1 if not applicable |
| }; |
| |
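| /*! |
| * \brief Parse the parallel/vectorize/unroll annotations attached to a block, |
| * e.g. attr::meta_schedule_parallel = 256 yields max_parallel_extent = 256 |
| * \param block The block whose annotations are parsed |
| * \param parsed The parsed results, with -1 for every absent annotation |
| * \return Whether at least one of the annotations is found |
| */ |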
| bool ParseAnnotation(const Block& block, ParsedAnnotation* parsed) { |
| bool found = false; |
| *parsed = ParsedAnnotation{-1, -1, -1, -1, -1, -1}; |
| for (const auto& ann : block->annotations) { |
| if (ann.first == attr::meta_schedule_parallel) { |
| found = true; |
| if (auto opt_int_imm = ann.second.try_cast<IntImm>()) { |
| parsed->max_parallel_extent = (*opt_int_imm)->value; |
| } |
| } else if (ann.first == attr::meta_schedule_vectorize) { |
| found = true; |
| if (auto opt_int_imm = ann.second.try_cast<IntImm>()) { |
| parsed->max_vectorize_extent = (*opt_int_imm)->value; |
| } |
| } else if (ann.first == attr::meta_schedule_unroll_explicit) { |
| found = true; |
| if (auto opt_int_imm = ann.second.try_cast<IntImm>()) { |
| parsed->unroll_explicit = (*opt_int_imm)->value; |
| } |
| } else if (ann.first == attr::meta_schedule_unroll_implicit) { |
| found = true; |
| if (auto opt_int_imm = ann.second.try_cast<IntImm>()) { |
| parsed->unroll_implicit = (*opt_int_imm)->value; |
| } |
| } |
| } |
| return found; |
| } |
| |
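| /*! \brief Remove the parsed annotations from the given block */ |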
| void RemoveParsedAnn(const Schedule& sch, const BlockRV& block_rv, const ParsedAnnotation& parsed) { |
| if (parsed.max_parallel_extent != -1) { |
| sch->Unannotate(block_rv, attr::meta_schedule_parallel); |
| } |
| if (parsed.max_vectorize_extent != -1) { |
| sch->Unannotate(block_rv, attr::meta_schedule_vectorize); |
| } |
| if (parsed.unroll_explicit != -1) { |
| sch->Unannotate(block_rv, attr::meta_schedule_unroll_explicit); |
| } |
| if (parsed.unroll_implicit != -1) { |
| sch->Unannotate(block_rv, attr::meta_schedule_unroll_implicit); |
| } |
| } |
| |
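| /*! |
| * \brief Count the loops that are eligible for rewriting, i.e. data-parallel loops |
| * with a known integer extent, a single-statement body, and no annotation or thread binding |
| * \param loop_srefs The srefs of the loops |
| * \param loop_types The iterator type of each loop |
| * \return The number of rewritable loops |
| */ |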
| int CalculateNumRewritableLoops(const ffi::Array<StmtSRef>& loop_srefs, |
| const std::vector<int>& loop_types) { |
| int rw_loops_num = 0; |
| ICHECK_EQ(loop_srefs.size(), loop_types.size()); |
| for (size_t i = 0; i < loop_srefs.size(); ++i) { |
| const StmtSRef& loop_sref = loop_srefs[i]; |
| const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); |
| if (HasAnnOrBinding(loop)) { |
| continue; |
| } |
| // Only data-parallel axes can be parallelized or vectorized |
| if (loop_types[i] != IterVarType::kDataPar) { |
| continue; |
| } |
| // Cannot fuse with a loop with multiple children |
| if (!IsSingleStmt(loop->body)) { |
| continue; |
| } |
| // Check if the loop extent is valid |
| if (GetLoopIntExtent(loop_sref) == nullptr) { |
| continue; |
| } |
| ++rw_loops_num; |
| } |
| return rw_loops_num; |
| } |
| |
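| /*! |
| * \brief Adjust the parsed annotations to the loop structure of the given block, |
| * filling in num_parallel_loops and num_vectorize_loops, i.e. how many outer loops |
| * can be fused and parallelized and how many inner loops can be fused and vectorized |
| */ |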
| void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv, |
| const ffi::Array<LoopRV>& loop_rvs, ParsedAnnotation* parsed) { |
| StmtSRef block_sref = sch->GetSRef(block_rv); |
| if (parsed->max_parallel_extent == -1 && parsed->max_vectorize_extent == -1) { |
| return; |
| } |
| const int n_loops = loop_rvs.size(); |
| if (n_loops == 0) { |
| parsed->max_parallel_extent = -1; |
| parsed->max_vectorize_extent = -1; |
| return; |
| } |
| // Extract loop_srefs, and calculate the iterator types |
| ffi::Array<StmtSRef> loop_srefs; |
| std::vector<int> loop_types; |
| { |
| loop_srefs.reserve(n_loops); |
| loop_types.reserve(n_loops); |
| for (const LoopRV& loop_rv : loop_rvs) { |
| loop_srefs.push_back(sch->GetSRef(loop_rv)); |
| loop_types.push_back(GetLoopIterType(loop_srefs.back())); |
| } |
| } |
| // Check the maximal number of axes that are vectorizable, i.e. accessed contiguously |
| BlockRealize realize = GetBlockRealize(sch->state(), block_sref); |
| ffi::Array<BufferRegion> buffer_access(realize->block->reads); |
| buffer_access.insert(buffer_access.end(), realize->block->writes.begin(), |
| realize->block->writes.end()); |
| std::unordered_map<const VarNode*, PrimExpr> binding_map; |
| for (size_t i = 0; i < realize->iter_values.size(); i++) { |
| binding_map[realize->block->iter_vars[i]->var.get()] = realize->iter_values[i]; |
| } |
| int max_fusible = INT32_MAX; |
| // For each buffer region read/written by the block, compute the stride of each loop var |
| // and count the fusible (vectorizable) axes |
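| // e.g. for an access A[i, j] with A of shape [n, 16], stride(j) = 1 and stride(i) = 16, |
| // so loops i and j are fusible exactly when extent(j) == 16 |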
| for (const BufferRegion& access : buffer_access) { |
| int fusible = 0; |
| std::vector<int64_t> strides; |
| // get strides for each loop var |
| for (const StmtSRef& loop_sref : loop_srefs) { |
| int64_t stride = 0, buffer_stride = 1; |
| const auto* loop = loop_sref->StmtAs<ForNode>(); |
| arith::Analyzer analyzer; |
| for (int i = access->region.size() - 1; i >= 0; i--) { |
| PrimExpr idx = analyzer.Simplify(Substitute(access->region[i]->min, binding_map)); |
| int64_t coef = StrideExtractor::Extract(idx, loop->loop_var); |
| if (coef != 0) { |
| stride = coef * buffer_stride; |
| break; |
| } |
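| // The buffer is assumed to be statically shaped, so each dimension is an IntImm |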
| buffer_stride *= access->buffer->shape[i].as<IntImmNode>()->value; |
| } |
| strides.push_back(stride); |
| } |
| int prev_used_iter = -1; |
| // check the number of fusible loops |
| for (int i = strides.size() - 1; i >= 0; i--) { |
| if (strides[i] == 0) { |
| // not used in the buffer access, safe to fuse |
| fusible++; |
| continue; |
| } else if (prev_used_iter == -1) { |
| // A non-unit stride on the innermost used axis means the access is not contiguous |
| if (strides[i] != 1 && fusible != 0) { |
| break; |
| } |
| fusible++; |
| prev_used_iter = i; |
| } else { |
| // Contiguous iff this axis's stride equals the previous used axis's stride times its extent |
| const auto* prev_loop = loop_srefs[prev_used_iter]->StmtAs<ForNode>(); |
| int64_t prev_used_iter_extent = prev_loop->extent.as<IntImmNode>()->value; |
| if (strides[i] == strides[prev_used_iter] * prev_used_iter_extent) { |
| fusible++; |
| prev_used_iter = i; |
| } else { |
| break; |
| } |
| } |
| } |
| max_fusible = std::min(max_fusible, fusible); |
| } |
| |
| // Calculate how many loops are rewritable, i.e. valid for vectorization and parallelization. |
| int max_rw_loops = CalculateNumRewritableLoops(loop_srefs, loop_types); |
| |
| // Calculate the number of outer loops to fuse and parallelize |
| if (parsed->max_parallel_extent != -1) { |
| int max_extent = parsed->max_parallel_extent; |
| int& num_fusible = parsed->num_parallel_loops = 0; |
| int64_t prod_extent = 1; |
| for (int i = 0; i < n_loops && loop_types[i] == IterVarType::kDataPar; ++i) { |
| const StmtSRef& loop_sref = loop_srefs[i]; |
| const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); |
| if (HasAnnOrBinding(loop)) { |
| break; |
| } |
| // Check if the loop extent is valid |
| const int64_t* extent = GetLoopIntExtent(loop_sref); |
| if (extent == nullptr) { |
| break; |
| } |
| // Then we can fuse it in |
| ++num_fusible; |
| // Stop once the fused extent exceeds the cap or the loop body has multiple children |
| prod_extent *= *extent; |
| if (prod_extent > max_extent || !IsSingleStmt(loop->body)) { |
| break; |
| } |
| } |
| if (prod_extent == 1) { |
| num_fusible = -1; |
| } |
| } |
| // Calculate the number of inner loops to fuse and vectorize |
| if (parsed->max_vectorize_extent != -1) { |
| int max_extent = parsed->max_vectorize_extent; |
| int& num_fusible = parsed->num_vectorize_loops = 0; |
| int64_t prod_extent = 1; |
| for (int i = n_loops - 1; |
| i >= 0 && loop_types[i] == IterVarType::kDataPar && num_fusible < max_fusible; --i) { |
| const StmtSRef& loop_sref = loop_srefs[i]; |
| const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); |
| if (HasAnnOrBinding(loop)) { |
| break; |
| } |
| // Cannot vectorize reduce axis |
| if (GetLoopIterType(loop_sref) != IterVarType::kDataPar) { |
| break; |
| } |
| // Cannot fuse with a loop with multiple children |
| if (!IsSingleStmt(loop->body)) { |
| break; |
| } |
| // Check if the loop extent is valid |
| const int64_t* extent = GetLoopIntExtent(loop_sref); |
| if (extent == nullptr) { |
| break; |
| } |
| // Stop before the fused extent exceeds the cap |
| prod_extent *= *extent; |
| if (prod_extent > max_extent) { |
| break; |
| } |
| ++num_fusible; |
| } |
| if (prod_extent == 1) { |
| num_fusible = -1; |
| } |
| } |
| |
| if (parsed->num_parallel_loops != -1 && parsed->num_vectorize_loops != -1) { |
| if (max_rw_loops == n_loops && max_fusible == n_loops) { |
| // All loops can be fused, parallelized and vectorized |
| parsed->num_parallel_loops = n_loops; |
| parsed->num_vectorize_loops = n_loops; |
| } else { |
| // Prefer num_vectorize to num_parallel |
| parsed->num_parallel_loops = |
| std::min(parsed->num_parallel_loops, n_loops - parsed->num_vectorize_loops); |
| } |
| } |
| } |
| |
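| /*! |
| * \brief Find a PrimFunc whose root block still carries meta-schedule annotations, |
| * parse and remove those annotations, and return the root block |
| * \return Whether such an annotated root block is found |
| */ |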
| bool FindAnnotatedRootBlock(const Schedule& sch, ParsedAnnotation* parsed, BlockRV* root_rv) { |
| IRModule mod = sch->mod(); |
| for (const auto& kv : mod->functions) { |
| const GlobalVar& g_var = kv.first; |
| const BaseFunc& base_func = kv.second; |
| if (const auto* prim_func = base_func.as<PrimFuncNode>()) { |
| const BlockRealizeNode* block_realize = prim_func->body.as<BlockRealizeNode>(); |
| if (block_realize != nullptr) { |
| Block block = block_realize->block; |
| if (ParseAnnotation(block, parsed)) { |
| *root_rv = sch->GetBlock(block->name_hint, g_var->name_hint); |
| RemoveParsedAnn(sch, *root_rv, *parsed); |
| return true; |
| } |
| } |
| } |
| } |
| return false; |
| } |
| |
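| /*! |
| * \brief Fuse all the loops, split out the innermost `vec_len` lanes, |
| * then parallelize the outer result and vectorize the inner one |
| */ |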
| void RewriteFuseSplitParallelVectorize(const Schedule& sch, ffi::Array<LoopRV>* loop_rvs, |
| int vec_len) { |
| size_t n_loops = loop_rvs->size(); |
| LoopRV fused = sch->Fuse({loop_rvs->begin(), loop_rvs->end()}); |
| ffi::Array<LoopRV> split = sch->Split(fused, {std::nullopt, Integer(vec_len)}); |
| ICHECK_EQ(split.size(), 2); |
| const LoopRV& outer = split[0]; |
| const LoopRV& inner = split[1]; |
| sch->Parallel(outer); |
| sch->Vectorize(inner); |
| for (size_t i = 0; i < n_loops - 1; ++i) { |
| loop_rvs->Set(i, outer); |
| } |
| loop_rvs->Set(n_loops - 1, inner); |
| } |
| |
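| /*! \brief Fuse the first `n` loops and parallelize the fused loop */ |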
| void RewriteParallel(const Schedule& sch, size_t n, ffi::Array<LoopRV>* loop_rvs) { |
| ICHECK_LE(n, loop_rvs->size()); |
| LoopRV fused = sch->Fuse({loop_rvs->begin(), loop_rvs->begin() + n}); |
| sch->Parallel(fused); |
| for (size_t i = 0; i < n; ++i) { |
| loop_rvs->Set(i, fused); |
| } |
| } |
| |
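| /*! \brief Fuse the last `n` loops and vectorize the fused loop */ |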
| void RewriteVectorize(const Schedule& sch, size_t n, ffi::Array<LoopRV>* loop_rvs) { |
| size_t n_loops = loop_rvs->size(); |
| ICHECK_LE(n, n_loops); |
| LoopRV fused = sch->Fuse({loop_rvs->end() - n, loop_rvs->end()}); |
| sch->Vectorize(fused); |
| for (size_t i = n_loops - n; i < n_loops; ++i) { |
| loop_rvs->Set(i, fused); |
| } |
| } |
| |
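| /*! \brief Annotate the loop with auto-unroll pragmas, skipping pure-spatial blocks */ |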
| void RewriteUnroll(const Schedule& sch, int unroll_explicit, int max_step, const BlockRV& block, |
| const LoopRV& loop) { |
| // Do not unroll a pure spatial block. |
| if (max_step <= 0 || IsSpatial(sch->GetSRef(block))) { |
| return; |
| } |
| |
| sch->Annotate(loop, attr::pragma_auto_unroll_max_step, IntImm(DataType::Int(32), max_step)); |
| sch->Annotate(loop, attr::pragma_unroll_explicit, IntImm(DataType::Int(32), unroll_explicit)); |
| } |
| |
| } // namespace tir |
| |
| namespace meta_schedule { |
| |
| using tir::Schedule; |
| |
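| /*! \brief The postprocessor that rewrites the annotated loops into parallel, vectorized |
| * and unrolled form */ |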
| class RewriteParallelVectorizeUnrollNode : public PostprocNode { |
| public: |
| void InitializeWithTuneContext(const TuneContext& context) final {} |
| |
| bool Apply(const Schedule& sch) final { |
| tir::ParsedAnnotation parsed_root; |
| tir::BlockRV root_rv{ffi::UnsafeInit()}; |
| while (tir::FindAnnotatedRootBlock(sch, &parsed_root, &root_rv)) { |
| for (tir::BlockRV block_rv : sch->GetChildBlocks(root_rv)) { |
| ffi::Array<tir::LoopRV> loop_rvs = sch->GetLoops(block_rv); |
| if (loop_rvs.empty()) { |
| continue; |
| } |
| tir::ParsedAnnotation parsed = parsed_root; |
| tir::AdjustParallelVectorize(sch, block_rv, loop_rvs, &parsed); |
| const int loops_num = loop_rvs.size(); |
| if (parsed.num_parallel_loops == loops_num && parsed.num_vectorize_loops == loops_num) { |
| // Fuse, split, vectorize and parallelize |
| tir::RewriteFuseSplitParallelVectorize(sch, &loop_rvs, parsed.max_vectorize_extent); |
| } else { |
| // Parallel |
| if (parsed.num_parallel_loops > 0) { |
| tir::RewriteParallel(sch, parsed.num_parallel_loops, &loop_rvs); |
| } |
| // Vectorize |
| if (parsed.num_vectorize_loops > 0) { |
| tir::RewriteVectorize(sch, parsed.num_vectorize_loops, &loop_rvs); |
| } |
| } |
| // AutoUnroll |
| if (parsed.unroll_explicit != -1 || parsed.unroll_implicit != -1) { |
| // At most one of the explicit/implicit unroll annotations may be set |
| ICHECK(parsed.unroll_explicit == -1 || parsed.unroll_implicit == -1); |
| int unroll_explicit = parsed.unroll_explicit != -1; |
| // The absent annotation is -1 and cancels the +1, leaving the max step that was set |
| int max_step = parsed.unroll_explicit + parsed.unroll_implicit + 1; |
| tir::RewriteUnroll(sch, unroll_explicit, max_step, block_rv, loop_rvs[0]); |
| } |
| } |
| } |
| return true; |
| } |
| |
| Postproc Clone() const { |
| ObjectPtr<RewriteParallelVectorizeUnrollNode> n = |
| ffi::make_object<RewriteParallelVectorizeUnrollNode>(*this); |
| return Postproc(n); |
| } |
| |
| static void RegisterReflection() { |
| namespace refl = tvm::ffi::reflection; |
| refl::ObjectDef<RewriteParallelVectorizeUnrollNode>(); |
| } |
| |
| TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.RewriteParallelVectorizeUnroll", |
| RewriteParallelVectorizeUnrollNode, PostprocNode); |
| }; |
| |
| Postproc Postproc::RewriteParallelVectorizeUnroll() { |
| ObjectPtr<RewriteParallelVectorizeUnrollNode> n = |
| ffi::make_object<RewriteParallelVectorizeUnrollNode>(); |
| return Postproc(n); |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| RewriteParallelVectorizeUnrollNode::RegisterReflection(); |
| refl::GlobalDef().def("meta_schedule.PostprocRewriteParallelVectorizeUnroll", |
| Postproc::RewriteParallelVectorizeUnroll); |
| } |
| |
| } // namespace meta_schedule |
| } // namespace tvm |