| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file src/relax/transform/split_tir_layout_rewrite.cc |
| * \brief Use for rewriting the TIRs after meta_schedule layout rewrite post process. |
| */ |
| #include <tvm/ffi/reflection/registry.h> |
| #include <tvm/ir/transform.h> |
| #include <tvm/relax/expr_functor.h> |
| #include <tvm/relax/transform.h> |
| #include <tvm/s_tir/stmt.h> |
| #include <tvm/s_tir/transform.h> |
| #include <tvm/tirx/stmt_functor.h> |
| |
| #include <algorithm> |
| #include <cstddef> |
| |
| namespace tvm { |
| namespace tirx { |
| class SplitPrimFuncLayoutRewrite : public StmtMutator { |
| public: |
| explicit SplitPrimFuncLayoutRewrite(const PrimFunc& func) : original_func_(func) {} |
| std::tuple<ffi::Optional<PrimFunc>, PrimFunc> Transform(const PrimFunc& func) { |
| TVM_FFI_ICHECK(func->body.as<SBlockRealizeNode>()) |
| << "The body of the primfunc should be a root block."; |
| const auto& block = func->body.as<SBlockRealizeNode>()->block; |
| visit_root_block(block.get()); |
| if (layout_rewrite_preproc_stmts_.size() > 0) { |
| return std::make_tuple(create_layout_rewrite_preproc_func(), create_compute_func()); |
| } else { |
| return std::make_tuple(std::nullopt, func); |
| } |
| } |
| |
| private: |
| void sort_rewrite_infos() { |
| std::sort( |
| rewrite_infos_.begin(), rewrite_infos_.end(), |
| [](const RewriteInfo& a, const RewriteInfo& b) { return a.buffer_index < b.buffer_index; }); |
| } |
| |
| PrimFunc create_layout_rewrite_preproc_func() const { |
| // Step 1: Check the number of pre_rewrite_buffers and post_rewrite_buffers |
| TVM_FFI_ICHECK(rewrite_infos_.size() > 0) << "There should be at least one buffer rewrite."; |
| |
| // Step 2: Create the params for the new PrimFunc |
| ffi::Array<Var> params; |
| ffi::Map<Var, Buffer> buffer_map; |
| |
| for (const auto& info : rewrite_infos_) { |
| params.push_back(Var(info.pre_rewrite_buffer->name, DataType::Handle())); |
| buffer_map.Set(params.back(), info.pre_rewrite_buffer); |
| } |
| for (const auto& info : rewrite_infos_) { |
| params.push_back(Var(info.post_rewrite_buffer->name, DataType::Handle())); |
| buffer_map.Set(params.back(), info.post_rewrite_buffer); |
| } |
| |
| // Step 3: Create the body for the new PrimFunc |
| TVM_FFI_ICHECK(layout_rewrite_preproc_stmts_.size() > 0) |
| << "There should be at least one layout rewrite preproc stmt."; |
| Stmt body = layout_rewrite_preproc_stmts_.size() == 1 ? layout_rewrite_preproc_stmts_[0] |
| : SeqStmt(layout_rewrite_preproc_stmts_); |
| body = SBlockRealize( |
| /*iter_values=*/ffi::Array<PrimExpr>(), |
| /*predicate=*/const_true(), |
| /*block=*/ |
| SBlock(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, |
| /*name_hint=*/"root", body)); |
| |
| ffi::Map<ffi::String, ffi::Any> dict; |
| for (const auto& [key, original_value] : original_func_->attrs->dict) { |
| if (key == "global_symbol") { |
| dict.Set(key, Downcast<ffi::String>(original_value) + "_weight_prepack"); |
| } else if (key != "layout_free_buffers") { |
| dict.Set(key, original_value); |
| } |
| } |
| DictAttrs attrs(dict); |
| PrimFunc func = PrimFunc(params, body, VoidType(), buffer_map, attrs); |
| |
| return s_tir::RenewDefs(func); |
| } |
| |
| PrimFunc create_compute_func() const { |
| // Step 1: Create the params for the new PrimFunc |
| ffi::Array<Var> params = original_func_->params; |
| ffi::Map<Var, Buffer> buffer_map = original_func_->buffer_map; |
| for (const auto& info : rewrite_infos_) { |
| const Var& param = params[info.buffer_index]; |
| TVM_FFI_ICHECK(buffer_map[param] == info.pre_rewrite_buffer); |
| buffer_map.Set(param, info.post_rewrite_buffer); |
| } |
| |
| // Step 2: Create the body for the new PrimFunc |
| Stmt body = compute_stmts_.size() == 1 ? compute_stmts_[0] : SeqStmt(compute_stmts_); |
| SBlock original_block = original_func_->body.as<SBlockRealizeNode>()->block; |
| ffi::Array<Buffer> alloc_buffers; |
| for (const auto& buffer : original_block->alloc_buffers) { |
| auto it = |
| std::find_if(rewrite_infos_.begin(), rewrite_infos_.end(), |
| [&](const RewriteInfo& info) { return info.post_rewrite_buffer == buffer; }); |
| if (it == rewrite_infos_.end()) { |
| alloc_buffers.push_back(buffer); |
| } |
| } |
| |
| body = SBlockRealize( |
| /*iter_values=*/ffi::Array<PrimExpr>(), |
| /*predicate=*/const_true(), |
| /*block=*/ |
| SBlock(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, |
| /*name_hint=*/"root", body, |
| /*init=*/std::nullopt, |
| /*alloc_buffers=*/alloc_buffers)); |
| |
| ffi::Map<ffi::String, ffi::Any> dict; |
| for (const auto& [key, original_value] : original_func_->attrs->dict) { |
| if (key == "global_symbol") { |
| dict.Set(key, Downcast<ffi::String>(original_value) + "_prepacked"); |
| } else if (key != "layout_free_buffers") { |
| dict.Set(key, original_value); |
| } |
| } |
| DictAttrs attrs(dict); |
| PrimFunc func = PrimFunc(original_func_->params, body, VoidType(), buffer_map, attrs); |
| |
| return s_tir::RenewDefs(func); |
| } |
| |
| void visit_root_block(const SBlockNode* op) { |
| Stmt body = op->body; |
| if (const auto* seq_stmt = body.as<SeqStmtNode>()) { |
| for (const auto& stmt : seq_stmt->seq) { |
| current_subtree_ = 0; |
| Stmt new_stmt = this->VisitStmt(stmt); |
| TVM_FFI_ICHECK(current_subtree_ != 0) << "There should be at least a block in the subtree."; |
| if (current_subtree_ == 1) { |
| layout_rewrite_preproc_stmts_.push_back(new_stmt); |
| } else { |
| compute_stmts_.push_back(new_stmt); |
| } |
| } |
| } else { |
| current_subtree_ = 0; |
| this->VisitStmt(body); |
| TVM_FFI_ICHECK(current_subtree_ == -1) |
| << "There should be a compute block if there is only one subtree under the root."; |
| } |
| } |
| Stmt VisitStmt_(const SBlockNode* op) final { |
| SBlock block = Downcast<SBlock>(StmtMutator::VisitStmt_(op)); |
| auto it = op->annotations.find(s_tir::attr::meta_schedule_layout_rewrite_preproc); |
| bool is_layout_rewrite_preproc = |
| it != op->annotations.end() && is_one(Downcast<PrimExpr>((*it).second)); |
| |
| if (current_subtree_ == 0) { |
| current_subtree_ = is_layout_rewrite_preproc ? 1 : -1; |
| } else if (current_subtree_ == 1) { |
| TVM_FFI_ICHECK(is_layout_rewrite_preproc) |
| << "There is a layout rewrite block in the subtree, but meet a non-layout rewrite block."; |
| } else { |
| TVM_FFI_ICHECK(!is_layout_rewrite_preproc) |
| << "There is a non-layout rewrite block in the subtree, but meet a layout rewrite block."; |
| } |
| |
| if (is_layout_rewrite_preproc) { |
| TVM_FFI_ICHECK(op->reads.size() == 1) |
| << "There should be only one read buffer in the layout rewrite"; |
| TVM_FFI_ICHECK(op->writes.size() == 1) |
| << "There should be only one write buffer in the layout rewrite"; |
| TVM_FFI_ICHECK(op->alloc_buffers.empty()) |
| << "There should be no alloc buffer in the layout rewrite"; |
| TVM_FFI_ICHECK(op->match_buffers.empty()) |
| << "There should be no match buffer in the layout rewrite"; |
| const Buffer& preproc_buffer = op->reads[0]->buffer; |
| int buffer_index = -1; |
| for (size_t i = 0; i < original_func_->params.size(); ++i) { |
| const Buffer& buffer = original_func_->buffer_map[original_func_->params[i]]; |
| if (buffer == preproc_buffer) { |
| buffer_index = i; |
| break; |
| } |
| } |
| TVM_FFI_ICHECK(buffer_index != -1) |
| << "The preproc buffer is not found in the original primfunc."; |
| rewrite_infos_.push_back( |
| RewriteInfo{buffer_index, op->reads[0]->buffer, op->writes[0]->buffer}); |
| |
| auto new_annotations = op->annotations; |
| new_annotations.erase(s_tir::attr::meta_schedule_layout_rewrite_preproc); |
| auto n = ffi::make_object<SBlockNode>(*block.get()); |
| n->annotations = new_annotations; |
| return SBlock(n); |
| } |
| return block; |
| } |
| |
| public: |
| struct RewriteInfo { |
| int buffer_index; |
| Buffer pre_rewrite_buffer; |
| Buffer post_rewrite_buffer; |
| }; |
| std::vector<RewriteInfo> rewrite_infos_; |
| |
| private: |
| /*! \brief The stmts that are used for layout rewrite preproc*/ |
| ffi::Array<Stmt> layout_rewrite_preproc_stmts_; |
| /*! \brief The stmts that are other than layout rewrite preproc*/ |
| ffi::Array<Stmt> compute_stmts_; |
| /*! |
| \brief Whether the current subtree is a layout rewrite preproc subtree. |
| -1: visited a non-layout rewrite preproc block |
| 0: unsure, not visited any block |
| 1: visited a layout rewrite preproc block |
| */ |
| int current_subtree_; |
| /*! \brief The original primfunc*/ |
| PrimFunc original_func_; |
| }; |
| } // namespace tirx |
| |
| namespace relax { |
| class SplitLayoutRewritePreproc : public ExprMutator { |
| public: |
| static IRModule Transform(const IRModule& mod) { |
| SplitLayoutRewritePreproc mutator(mod); |
| |
| // Step 1: Split the primfunc into preproc and compute |
| for (auto [gv, func] : mod->functions) { |
| if (func->IsInstance<tirx::PrimFuncNode>()) { |
| tirx::SplitPrimFuncLayoutRewrite tir_rewriter(Downcast<tirx::PrimFunc>(func)); |
| auto [preproc_func, compute_func] = tir_rewriter.Transform(Downcast<tirx::PrimFunc>(func)); |
| if (preproc_func.defined()) { |
| mutator.split_funcs_.emplace(gv.get(), |
| std::make_tuple(preproc_func.value(), compute_func)); |
| mutator.rewrite_infos_.emplace(gv.get(), tir_rewriter.rewrite_infos_); |
| } |
| } |
| } |
| |
| for (auto [gv, func] : mod->functions) { |
| if (func->IsInstance<relax::FunctionNode>()) { |
| auto relax_func = Downcast<relax::Function>(func); |
| mutator.builder_->UpdateFunction(gv, Downcast<relax::Function>(mutator(relax_func))); |
| } |
| } |
| return mutator.builder_->GetContextIRModule(); |
| } |
| |
| private: |
| explicit SplitLayoutRewritePreproc(const IRModule& mod) : ExprMutator(mod) {} |
| using ExprMutator::VisitExpr_; |
| |
| Expr VisitExpr_(const CallNode* op) final { |
| static const Op& call_tir_op = Op::Get("relax.call_tir"); |
| Call call = Downcast<Call>(ExprMutator::VisitExpr_(op)); |
| |
| // Step 1: Skip call to other than `tirx.call_tir` |
| if (!call->op.same_as(call_tir_op)) { |
| return call; |
| } |
| |
| // Step 2: Skip if there is no preproc stage |
| const GlobalVar gv = Downcast<GlobalVar>(call->args[0]); |
| auto it = split_funcs_.find(gv.get()); |
| if (it == split_funcs_.end()) { |
| return call; |
| } |
| |
| // Step 3: Get the preproc and compute functions and update the module |
| const auto& [preproc_func, compute_func] = it->second; |
| GlobalVar preproc_gv = builder_->AddFunction(preproc_func, gv->name_hint + "_weight_prepack"); |
| GlobalVar compute_gv = builder_->AddFunction(compute_func, gv->name_hint + "_prepacked"); |
| // Step 4. Get rewrite infos |
| auto rewrite_infos_it = rewrite_infos_.find(gv.get()); |
| TVM_FFI_ICHECK(rewrite_infos_it != rewrite_infos_.end()) |
| << "Rewrite infos are not found for " << gv->name_hint; |
| const auto& rewrite_infos = rewrite_infos_it->second; |
| |
| // Step 5: Emit the preproc call |
| ffi::Array<Expr> call_tir_args = Downcast<Tuple>(call->args[1])->fields; |
| ffi::Array<Expr> preproc_args; |
| ffi::Array<StructInfo> preproc_sinfo_list; |
| for (const auto& info : rewrite_infos) { |
| preproc_args.push_back(call_tir_args[info.buffer_index]); |
| tirx::Buffer rewritten_buffer = info.post_rewrite_buffer; |
| for (const auto& shape_expr : rewritten_buffer->shape) { |
| TVM_FFI_ICHECK(shape_expr.as<tirx::IntImmNode>()) |
| << "Currently does not support rewrite buffer with " |
| "dynamic shape."; |
| } |
| preproc_sinfo_list.push_back( |
| TensorStructInfo(ShapeExpr(rewritten_buffer->shape), rewritten_buffer->dtype)); |
| } |
| StructInfo preproc_sinfo = preproc_sinfo_list.size() > 1 // |
| ? TupleStructInfo(preproc_sinfo_list) // |
| : preproc_sinfo_list[0]; |
| |
| // Step 6: Call the preproc function |
| Expr preproc_call = |
| builder_->Emit(Call(call_tir_op, {preproc_gv, Tuple(preproc_args)}, {}, {preproc_sinfo})); |
| if (rewrite_infos.size() == 1) { |
| call_tir_args.Set(rewrite_infos[0].buffer_index, preproc_call); |
| } else { |
| for (size_t i = 0; i < rewrite_infos.size(); ++i) { |
| call_tir_args.Set(rewrite_infos[i].buffer_index, TupleGetItem(preproc_call, i)); |
| } |
| } |
| Expr main_call = |
| builder_->Emit(Call(call_tir_op, {compute_gv, Tuple(call_tir_args)}, {}, call->sinfo_args)); |
| |
| return main_call; |
| } |
| |
| private: |
| std::unordered_map<const GlobalVarNode*, std::tuple<tirx::PrimFunc, tirx::PrimFunc>> split_funcs_; |
| std::unordered_map<const GlobalVarNode*, |
| std::vector<tirx::SplitPrimFuncLayoutRewrite::RewriteInfo>> |
| rewrite_infos_; |
| }; |
| |
| } // namespace relax |
| |
| namespace transform { |
| Pass SplitLayoutRewritePreproc() { |
| auto pass_func = [](IRModule mod, PassContext pc) { |
| return relax::SplitLayoutRewritePreproc::Transform(mod); |
| }; |
| auto pass = CreateModulePass(pass_func, 0, "SplitLayoutRewritePreproc", {}); |
| return tvm::transform::Sequential({pass, relax::transform::DeadCodeElimination()}, |
| "SplitLayoutRewritePreproc"); |
| } |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef().def("relax.transform.SplitLayoutRewritePreproc", SplitLayoutRewritePreproc); |
| } |
| } // namespace transform |
| } // namespace tvm |