| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file src/relax/backend/vm/codegen_tir.cc |
 * \brief A codegen that generates compilable VMTIR functions from the Relax functions of a module.
| */ |
| #include <tvm/ffi/cast.h> |
| #include <tvm/ffi/reflection/registry.h> |
| #include <tvm/ir/module.h> |
| #include <tvm/relax/exec_builder.h> |
| #include <tvm/relax/expr_functor.h> |
| #include <tvm/relax/op_attr_types.h> |
| #include <tvm/runtime/logging.h> |
| #include <tvm/runtime/vm/executable.h> |
| #include <tvm/target/target.h> |
| #include <tvm/tirx/builtin.h> |
| #include <tvm/tirx/expr.h> |
| #include <tvm/tirx/function.h> |
| #include <tvm/tirx/stmt.h> |
| |
| #include <cctype> |
| #include <string> |
| #include <unordered_map> |
| #include <vector> |
| |
| #include "../../transform/utils.h" |
| |
| namespace tvm { |
| namespace relax { |
| namespace codegen_vm { |
| |
| using vm::VMFuncInfo; |
| |
| /*! |
| * \brief A class to generate VMTIR for Relax functions. |
| * |
| * \note Skip CallPacked with special attrs for now, as they can be |
| * further simplified with PrimExpr. |
| */ |
| class CodeGenVMTIR : public ExprFunctor<ffi::Optional<PrimExpr>(const Expr&)> { |
| public: |
| explicit CodeGenVMTIR(relax::ExecBuilder builder, IRModule ctx_mod) |
| : builder_(builder), ctx_mod_(ctx_mod) { |
| system_lib_prefix_ = ctx_mod_->GetAttr<ffi::String>(tvm::attr::kSystemLibPrefix); |
| } |
| |
| static IRModule Run(relax::ExecBuilder builder, IRModule mod) { |
| // create a new copy |
| IRModule res_mod = mod; |
| res_mod.CopyOnWrite(); |
| |
| CodeGenVMTIR codegen(builder, mod); |
| // Remove relax function and turn into TIR func. |
| for (auto& p : mod->functions) { |
| if (auto* func = p.second.as<FunctionNode>()) { |
| auto tir_func = codegen.Codegen(ffi::GetRef<Function>(func)); |
| auto gsymbol = tir_func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol); |
| res_mod->Add(GlobalVar(gsymbol.value()), tir_func); |
| res_mod->Remove(p.first); |
| } |
| } |
| return res_mod; |
| } |
| |
| private: |
| int64_t NewRegister() { return registers_num_++; } |
| |
| static IntImm ConstInt64(int64_t value) { return IntImm::Int64(value); } |
| |
| static IntImm ConstInt32(int64_t value) { return IntImm::Int32(value); } |
| |
| PrimExpr RegListGet(int64_t slot) const { |
| // use 128 bits to represent any |
| return tvm::Call(tvm::PrimType::Handle(), tirx::builtin::anylist_getitem(), |
| {reg_anylist_handle_, ConstInt32(slot)}) |
| .as_or_throw<PrimExpr>(); |
| } |
| |
| PrimExpr ConstListGet(int64_t slot) const { |
| // use 128 bits to represent any |
| return tvm::Call(tvm::PrimType::Handle(), tirx::builtin::anylist_getitem(), |
| {const_anylist_handle_, ConstInt32(slot)}) |
| .as_or_throw<PrimExpr>(); |
| } |
| |
| PrimExpr FuncListGet(int64_t slot) const { |
| // use 128 bits to represent any |
| return tvm::Call(tvm::PrimType::Handle(), tirx::builtin::anylist_getitem(), |
| {func_anylist_handle_, ConstInt32(slot)}) |
| .as_or_throw<PrimExpr>(); |
| } |
| |
| void EmitStmt(tirx::Stmt stmt) { |
| TVM_FFI_ICHECK(!stmt_stack_.empty()); |
| stmt_stack_.back().emplace_back(stmt); |
| } |
| |
| void EmitCallPacked(ffi::String name, const ffi::Array<PrimExpr>& args, |
| int64_t dst_anylist_slot = -1) { |
| ffi::Array<PrimExpr> all_args; |
| // negative index indicate return value can be discarded, emit call_packed |
| if (dst_anylist_slot >= 0) { |
| all_args = {reg_anylist_handle_, ConstInt32(dst_anylist_slot)}; |
| } |
| all_args.push_back(tirx::StringImm(name)); |
| for (PrimExpr arg : args) { |
| all_args.push_back(arg); |
| } |
| if (dst_anylist_slot >= 0) { |
| this->EmitStmt(tirx::Evaluate( |
| tvm::Call(tvm::PrimType::Int(32), tirx::builtin::anylist_setitem_call_packed(), all_args) |
| .as_or_throw<PrimExpr>())); |
| } else { |
| this->EmitStmt(tirx::Evaluate( |
| tvm::Call(tvm::PrimType::Int(32), tirx::builtin::tvm_call_packed(), all_args) |
| .as_or_throw<PrimExpr>())); |
| } |
| } |
| |
| void EmitCallCPacked(const tirx::PrimFunc& prim_func, const ffi::Array<PrimExpr>& args, |
| int64_t dst_anylist_slot = -1) { |
| ffi::Optional<ffi::String> gsymbol = prim_func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol); |
| TVM_FFI_ICHECK(gsymbol.has_value()) << "All functions must have global symbol at this phase"; |
| ffi::Array<PrimExpr> all_args; |
| // negative index indicate return value can be discarded, emit call_packed |
| if (dst_anylist_slot >= 0) { |
| all_args = {reg_anylist_handle_, ConstInt32(dst_anylist_slot)}; |
| } |
| all_args.push_back(tirx::StringImm(gsymbol.value())); |
| for (PrimExpr arg : args) { |
| all_args.push_back(arg); |
| } |
| if (dst_anylist_slot >= 0) { |
| this->EmitStmt(tirx::Evaluate( |
| tvm::Call(tvm::PrimType::Int(32), tirx::builtin::anylist_setitem_call_cpacked(), all_args) |
| .as_or_throw<PrimExpr>())); |
| } else { |
| this->EmitStmt(tirx::Evaluate( |
| tvm::Call(tvm::PrimType::Int(32), tirx::builtin::tvm_call_cpacked(), all_args) |
| .as_or_throw<PrimExpr>())); |
| } |
| } |
| |
| tirx::PrimFunc Codegen(const Function& func) { |
| ffi::Optional<ffi::String> gsymbol = func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol); |
| TVM_FFI_ICHECK(gsymbol.has_value()) |
| << "there should be no local functions in Relax VM codegen phase. " |
| "Did you forget to apply LambdaLift or AttachGlobalSymbol Pass?"; |
| // initialize the state |
| stmt_stack_ = {}; |
| registers_num_ = 0; |
| var_map_.clear(); |
| ctx_ptr_ = tirx::Var("ctx_ptr", PrimType::Handle()); |
| reg_anylist_handle_ = tirx::Var("r", PrimType::Handle()); |
| func_anylist_handle_ = tirx::Var("f", PrimType::Handle()); |
| const_anylist_handle_ = tirx::Var("c", PrimType::Handle()); |
| |
| ffi::Array<ffi::String> param_names; |
| for (Var param : func->params) { |
| param_names.push_back(param->name_hint()); |
| } |
| // declare this function. |
| builder_->DeclareFunction(gsymbol.value(), vm::VMFuncInfo::FuncKind::kVMTIRFunc); |
| |
| for (size_t i = 0; i < func->params.size(); ++i) { |
| int64_t r = NewRegister(); |
| TVM_FFI_ICHECK_EQ(static_cast<size_t>(r), i); |
| this->var_map_.insert({func->params[i], RegListGet(r)}); |
| } |
| size_t ret_reg = NewRegister(); |
| |
| tirx::Stmt body = WithNewScope([&]() { |
| ffi::Optional<PrimExpr> ret = ExprFunctor::VisitExpr(func->body); |
| if (ret.defined()) { |
| this->EmitCallPacked("vm.builtin.copy", {ret.value()}, ret_reg); |
| } |
| }); |
| |
| // Mark the function entry internally. |
| builder_->EmitFunction(gsymbol.value(), param_names.size(), param_names, |
| VMFuncInfo::FuncKind::kVMTIRFunc, registers_num_); |
| builder_->EndFunction(gsymbol.value()); |
| |
| Type ret_type = VoidType(); |
| ffi::Array<tirx::Var> tir_params = {ctx_ptr_, reg_anylist_handle_, const_anylist_handle_, |
| func_anylist_handle_}; |
| ffi::String tir_func_name = system_lib_prefix_.value_or("") + "__vmtir__" + gsymbol.value(); |
| tirx::PrimFunc tir_func(tir_params, body, ret_type, {}); |
| tir_func = WithAttr(tir_func, "global_symbol", tir_func_name); |
| tir_func = WithAttr(tir_func, tvm::attr::kSTir, true); |
| registers_num_ = 0; |
| var_map_.clear(); |
| stmt_stack_.clear(); |
| return tir_func; |
| } |
| |
| ffi::Optional<PrimExpr> VisitExpr_(const SeqExprNode* op) final { |
| for (auto block : op->blocks) { |
| for (Binding binding : block->bindings) { |
| Expr expr = GetBoundValue(binding); |
| ffi::Optional<PrimExpr> value = VisitExpr(expr); |
| |
| if (expr.as<Var>() && value.defined()) { |
| // For a normalized relax module, there should be one |
| // register for each relax::Binding. This makes the Relax |
| // semantics of R.vm.kill_* operate the same as the Python |
| // "del" operator. These bindings may be removable by using |
| // relax.transform.CanonicalizeBindings earlier in lowering. |
| auto new_reg = NewRegister(); |
| EmitCallPacked("vm.builtin.copy", {value.value()}, new_reg); |
| value = RegListGet(new_reg); |
| } |
| |
| this->var_map_.insert({binding->var, value}); |
| } |
| } |
| return this->VisitExpr(op->body); |
| } |
| |
| ffi::Optional<PrimExpr> VisitExpr_(const CallNode* call_node) final { |
| Call call = ffi::GetRef<Call>(call_node); |
| |
| if (call_node->op == null_value_op_) { |
| return tvm::Call(tvm::PrimType::Handle(), tirx::builtin::reinterpret(), {IntImm::Int64(0)}) |
| .as_or_throw<PrimExpr>(); |
| } |
| int64_t dst_reg = HasVoidType(call) ? -1 : NewRegister(); |
| if (call->op.as<OpNode>()) { |
| if (call_node->op == call_builtin_with_ctx_op_) { |
| EmitCallBuiltinWithCtx(call, dst_reg); |
| } else if (call_node->op == alloc_storage_op_) { |
| EmitAllocStorage(call, dst_reg); |
| } else if (call_node->op == alloc_tensor_op_) { |
| EmitAllocTensor(call, dst_reg); |
| } else if (call_node->op == kill_object_op_) { |
| dst_reg = EmitKillObject(call); |
| } else { |
| // every "normal" operator is lowered to a global var in the IRModule. The Attrs for those |
| // ops are handled in a pass when lowering them to TIR. |
| TVM_FFI_THROW(InternalError) << "CodeGenVMTIR cannot handle this intrinsic now:\n" |
| << call_node->op; |
| } |
| } else { |
| EmitNormalCall(call, dst_reg); |
| } |
| if (dst_reg >= 0) { |
| return RegListGet(dst_reg); |
| } else { |
| return std::nullopt; |
| } |
| } |
| |
| ffi::Optional<PrimExpr> VisitExpr_(const IfNode* op) final { |
| // Reserve a register for return |
| size_t merge_register = NewRegister(); |
| PrimExpr cond_value = this->VisitExpr(op->cond).value(); |
| |
| cond_value = tvm::Call(tvm::PrimType::Bool(), tirx::builtin::tvm_call_packed(), |
| {tirx::StringImm("vm.builtin.read_if_cond"), cond_value}) |
| .as_or_throw<PrimExpr>(); |
| |
| tirx::Stmt true_branch = WithNewScope([&]() { |
| PrimExpr true_value = this->VisitExpr(op->true_branch).value(); |
| this->EmitCallPacked("vm.builtin.copy", {true_value}, merge_register); |
| }); |
| tirx::Stmt false_branch = WithNewScope([&]() { |
| PrimExpr false_value = this->VisitExpr(op->false_branch).value(); |
| this->EmitCallPacked("vm.builtin.copy", {false_value}, merge_register); |
| }); |
| this->EmitStmt(tirx::IfThenElse(cond_value, true_branch, false_branch)); |
| return RegListGet(merge_register); |
| } |
| |
| ffi::Optional<PrimExpr> VisitExpr_(const VarNode* op) final { |
| Var var = ffi::GetRef<Var>(op); |
| auto it = this->var_map_.find(var); |
| TVM_FFI_ICHECK(it != this->var_map_.end()) << "Var " << var << " is not defined"; |
| return it->second; |
| } |
| |
| ffi::Optional<PrimExpr> VisitExpr_(const ConstantNode* op) final { |
| return ConstListGet(builder_->ConvertConstant(op->data).value()); |
| } |
| |
| ffi::Optional<PrimExpr> VisitExpr_(const ShapeExprNode* op) final { |
| std::vector<int64_t> shape; |
| for (PrimExpr e : op->values) { |
| if (auto* int_value = e.as<IntImmNode>()) { |
| shape.push_back(int_value->value); |
| } else { |
| TVM_FFI_THROW(InternalError) |
| << "Should only use constant shape after shape lowering: " << op->values; |
| } |
| } |
| return ConstListGet(builder_->ConvertConstant(ffi::Shape(shape)).value()); |
| } |
| |
| ffi::Optional<PrimExpr> VisitExprFallback_(const ExprNode* op) final { |
| return ffi::GetRef<Expr>(op).as_or_throw<PrimExpr>(); |
| } |
| |
| ffi::Optional<PrimExpr> VisitExpr_(const StringImmNode* op) final { |
| return ConstListGet(builder_->ConvertConstant(op->value).value()); |
| } |
| |
| ffi::Optional<PrimExpr> VisitExpr_(const DataTypeImmNode* op) final { |
| return ConstListGet(builder_->ConvertConstant(op->value).value()); |
| } |
| |
| ffi::Optional<PrimExpr> VisitExpr_(const TupleNode* op) final { |
| Tuple tuple = ffi::GetRef<Tuple>(op); |
| ffi::Array<PrimExpr> args; |
| for (auto arg : tuple->fields) { |
| args.push_back(this->VisitExpr(arg).value()); |
| } |
| int32_t dst_register = NewRegister(); |
| this->EmitCallPacked("vm.builtin.make_tuple", args, dst_register); |
| return RegListGet(dst_register); |
| } |
| |
| ffi::Optional<PrimExpr> VisitExpr_(const TupleGetItemNode* op) final { |
| TupleGetItem expr = ffi::GetRef<TupleGetItem>(op); |
| ffi::Array<PrimExpr> args = {this->VisitExpr(expr->tuple).value()}; |
| |
| args.push_back(ConstInt64(expr->index)); |
| |
| int64_t dst_register = NewRegister(); |
| this->EmitCallPacked("vm.builtin.tuple_getitem", args, dst_register); |
| return RegListGet(dst_register); |
| } |
| |
| // Lookup the function and see if it matches |
| ffi::Optional<ffi::String> LookupFunction(const Expr& expr, VMFuncInfo::FuncKind* kind) { |
| if (auto* ext_func = expr.as<ExternFuncNode>()) { |
| *kind = VMFuncInfo::FuncKind::kPackedFunc; |
| return ext_func->global_symbol; |
| } else if (auto* gvar_ptr = expr.as<GlobalVarNode>()) { |
| GlobalVar gvar = ffi::GetRef<GlobalVar>(gvar_ptr); |
| // Run a look up in the env to see if it maps to an extern func. |
| auto it = ctx_mod_->functions.find(gvar); |
| if (it != ctx_mod_->functions.end()) { |
| BaseFunc func = (*it).second; |
| if (auto* efunc = func.as<ExternFuncNode>()) { |
| *kind = VMFuncInfo::FuncKind::kPackedFunc; |
| return efunc->global_symbol; |
| } else if (func.as<FunctionNode>()) { |
| *kind = VMFuncInfo::FuncKind::kVMTIRFunc; |
| return gvar->name_hint; |
| } else if (func.as<tirx::PrimFuncNode>()) { |
| *kind = VMFuncInfo::FuncKind::kPackedFunc; |
| return gvar->name_hint; |
| } else { |
| *kind = VMFuncInfo::FuncKind::kPackedFunc; |
| return gvar->name_hint; |
| } |
| } |
| LOG(WARNING) << "Undefined global var " << gvar->name_hint; |
| // undefined global var, consider eliminate later. |
| *kind = VMFuncInfo::FuncKind::kPackedFunc; |
| return gvar->name_hint; |
| } else { |
| return std::nullopt; |
| } |
| } |
| // Lookup PrimFunc in the same module |
| // We can do direct PrimFunc call in such cases |
| ffi::Optional<tirx::PrimFunc> LookupPrimFunc(const ffi::String& name) { |
| if (!ctx_mod_->ContainGlobalVar(name)) return std::nullopt; |
| |
| GlobalVar gvar = ctx_mod_->GetGlobalVar(name); |
| auto it = ctx_mod_->functions.find(gvar); |
| if (it != ctx_mod_->functions.end()) { |
| BaseFunc func = (*it).second; |
| if (auto* prim_func = func.as<tirx::PrimFuncNode>()) { |
| return ffi::GetRef<tirx::PrimFunc>(prim_func); |
| } |
| } |
| return std::nullopt; |
| } |
| |
| ffi::Optional<PrimExpr> VisitExpr_(const GlobalVarNode* op) final { |
| VMFuncInfo::FuncKind kind; |
| auto symbol = LookupFunction(ffi::GetRef<Expr>(op), &kind); |
| TVM_FFI_ICHECK(symbol.has_value()); |
| builder_->DeclareFunction(symbol.value(), kind); |
| return FuncListGet(builder_->GetFunction(symbol.value()).value()); |
| } |
| |
| ffi::Optional<PrimExpr> VisitExpr_(const ExternFuncNode* op) final { |
| builder_->DeclareFunction(op->global_symbol, VMFuncInfo::FuncKind::kPackedFunc); |
| return FuncListGet(builder_->GetFunction(op->global_symbol).value()); |
| } |
| |
| void EmitAllocStorage(const Call& call_node, int64_t dst_reg) { |
| // Handle args of the call |
| ffi::Array<PrimExpr> args; |
| args.push_back(ctx_ptr_); |
| for (Expr arg : call_node->args) { |
| args.push_back(this->VisitExpr(arg).value()); |
| } |
| this->EmitCallPacked("vm.builtin.alloc_storage", args, dst_reg); |
| } |
| |
| void EmitAllocTensor(const Call& call_node, int64_t dst_reg) { |
| TVM_FFI_ICHECK_EQ(call_node->args.size(), 5); |
| ffi::Array<PrimExpr> args; |
| for (int i = 0; i < 4; ++i) { |
| args.push_back(this->VisitExpr(call_node->args[i]).value()); |
| } |
| int64_t vdevice_index = -1; |
| if (const auto* int_imm = call_node->args[4].as<IntImmNode>()) { |
| vdevice_index = int_imm->value; |
| } |
| auto vdevice = GetGlobalVDevice(ctx_mod_, vdevice_index); |
| |
| if (vdevice.defined()) { |
| args.push_back(tirx::StringImm(vdevice.value()->memory_scope)); |
| } |
| |
| this->EmitCallPacked("vm.builtin.alloc_tensor", args, dst_reg); |
| } |
| |
| int64_t EmitKillObject(const Call& call_node) { |
| TVM_FFI_ICHECK_EQ(call_node->args.size(), 1); |
| PrimExpr arg = this->VisitExpr(call_node->args[0]).value(); |
| |
| // Check the arg is a register. |
| const auto* tir_call = arg.as<CallNode>(); |
| TVM_FFI_ICHECK(tir_call != nullptr); |
| TVM_FFI_ICHECK(tir_call->op == tirx::builtin::anylist_getitem()); |
| TVM_FFI_ICHECK(tir_call->args.size() == 2); |
| TVM_FFI_ICHECK(tir_call->args[0].same_as(reg_anylist_handle_)); |
| const auto* p_dst_reg = tir_call->args[1].as<tirx::IntImmNode>(); |
| TVM_FFI_ICHECK(p_dst_reg != nullptr); |
| TVM_FFI_ICHECK( |
| p_dst_reg->ty.as_or_throw<PrimType>().MatchesElementType(DLDataTypeCode::kDLInt, 32)); |
| |
| int64_t dst_reg = p_dst_reg->value; |
| this->EmitCallPacked("vm.builtin.null_value", {}, dst_reg); |
| return dst_reg; |
| } |
| |
| void EmitCallBuiltinWithCtx(const Call& call_node, int64_t dst_reg) { |
| ffi::Array<PrimExpr> args; |
| // if context is required, pass as first argument. |
| args.push_back(ctx_ptr_); |
| auto* func = call_node->args[0].as<ExternFuncNode>(); |
| TVM_FFI_ICHECK(func) << "CallBuiltin comes with extern func"; |
| |
| auto tuple_arg = call_node->args[1].as_or_throw<Tuple>(); |
| |
| // Handle args of the call |
| for (Expr arg : tuple_arg->fields) { |
| args.push_back(this->VisitExpr(arg).value()); |
| } |
| |
| this->EmitCallPacked(func->global_symbol, args, dst_reg); |
| } |
| |
| void EmitNormalCall(const Call& call_node, int64_t dst_reg) { |
| ffi::Array<PrimExpr> args = VisitArray(call_node->args); |
| // A function can be a closure that comes from parent |
| // Do call closure to be safe. |
| VMFuncInfo::FuncKind kind; |
| auto symbol = LookupFunction(call_node->op, &kind); |
| |
| if (symbol.has_value() && kind == VMFuncInfo::FuncKind::kPackedFunc) { |
| // primfunc in the same module. |
| // use cpacked to directly invoke without named based lookup |
| if (ffi::Optional<tirx::PrimFunc> prim_func = LookupPrimFunc(symbol.value())) { |
| this->EmitCallCPacked(prim_func.value(), args, dst_reg); |
| } else { |
| this->EmitCallPacked(symbol.value(), args, dst_reg); |
| } |
| } else { |
| // Default path, leverage function table and invoke as closure |
| ffi::Array<PrimExpr> all_args; |
| all_args.push_back(ctx_ptr_); |
| all_args.push_back(this->VisitExpr(call_node->op).value()); |
| for (auto arg : args) { |
| all_args.push_back(arg); |
| } |
| this->EmitCallPacked("vm.builtin.invoke_closure", all_args, dst_reg); |
| } |
| } |
| |
| template <typename FLambda> |
| tirx::Stmt WithNewScope(const FLambda& callback) { |
| stmt_stack_.push_back({}); |
| callback(); |
| tirx::Stmt stmt = tirx::SeqStmt::Flatten(stmt_stack_.back()); |
| stmt_stack_.pop_back(); |
| return stmt; |
| } |
| |
| ffi::Array<PrimExpr> VisitArray(const ffi::Array<Expr>& arr) { |
| ffi::Array<PrimExpr> ret; |
| for (size_t i = 0; i < arr.size(); ++i) { |
| ret.push_back(this->VisitExpr(arr[i]).value()); |
| } |
| return ret; |
| } |
| /*! \brief Internal ExecBuilder. */ |
| relax::ExecBuilder builder_; |
| /*! \brief List to ctx_ptr */ |
| tirx::Var ctx_ptr_; |
| /*! \brief List to store temp object registers */ |
| tirx::Var reg_anylist_handle_; |
| /*! \brief List to store closures */ |
| tirx::Var func_anylist_handle_; |
| /*! \brief List to store constants */ |
| tirx::Var const_anylist_handle_; |
| /*! |
| * \brief Total number of virtual registers allocated. |
| * \note The first two registers are reserved for special registers. |
| */ |
| int64_t registers_num_ = 0; |
| /*! \brief Stack to build up statements */ |
| std::vector<std::vector<tirx::Stmt>> stmt_stack_; |
| /*! \brief Map from var to Expr. */ |
| std::unordered_map<Var, ffi::Optional<PrimExpr>> var_map_; |
| /*! \brief the context module. */ |
| IRModule ctx_mod_; |
| /*! \brief system lib prefix */ |
| ffi::Optional<ffi::String> system_lib_prefix_; |
| /*! \brief Cache ops that need to be frequently used later to reduce lookup overhead. */ |
| const Op& alloc_storage_op_ = Op::Get("relax.vm.alloc_storage"); |
| const Op& alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor"); |
| const Op& kill_object_op_ = Op::Get("relax.vm.kill_object"); |
| const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx"); |
| const Op& null_value_op_ = Op::Get("relax.null_value"); |
| }; |
| |
| /*! |
| * \brief Create the Relax VM executable from all relax.Function in mod. |
| * and add them to exec_builder. Create extra TIR functions. |
| * |
| * \param exec_builder Builder to collect executables. |
| * \param mod Input module. |
| * \return Extra TIR module created. |
| */ |
| IRModule VMTIRCodeGen(ExecBuilder exec_builder, IRModule mod) { |
| return CodeGenVMTIR::Run(exec_builder, mod); |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef().def("relax.VMTIRCodeGen", VMTIRCodeGen); |
| } |
| |
| } // namespace codegen_vm |
| } // namespace relax |
| } // namespace tvm |