blob: bf04ce88319c1cf07901eb6fef2e8ce0c6e3cb9e [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/relax/backend/vm/codegen_tir.cc
* \brief A codegen to generate VMTIR function(that can be compiled) from executable.
*/
#include <tvm/ffi/cast.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/module.h>
#include <tvm/relax/exec_builder.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/op_attr_types.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/vm/executable.h>
#include <tvm/target/target.h>
#include <tvm/tirx/builtin.h>
#include <tvm/tirx/expr.h>
#include <tvm/tirx/function.h>
#include <tvm/tirx/stmt.h>
#include <cctype>
#include <string>
#include <unordered_map>
#include <vector>
#include "../../transform/utils.h"
namespace tvm {
namespace relax {
namespace codegen_vm {
using vm::VMFuncInfo;
/*!
 * \brief A class to generate VMTIR for Relax functions.
 *
 * Each relax::Function is lowered into a TIR PrimFunc whose parameters are
 * (ctx_ptr, r, c, f): the VM context pointer, the register anylist, the
 * constant anylist and the function-table anylist. Relax values are
 * materialized as `anylist_getitem` reads from one of those lists, and calls
 * are emitted as packed / cpacked intrinsics; non-void results are written
 * into a freshly allocated register slot.
 *
 * \note Skip CallPacked with special attrs for now, as they can be
 *       further simplified with PrimExpr.
 */
class CodeGenVMTIR : public ExprFunctor<ffi::Optional<PrimExpr>(const Expr&)> {
 public:
  /*!
   * \param builder Executable builder that receives function declarations and constants.
   * \param ctx_mod The module being compiled; used to resolve GlobalVars and module attrs.
   */
  explicit CodeGenVMTIR(relax::ExecBuilder builder, IRModule ctx_mod)
      : builder_(builder), ctx_mod_(ctx_mod) {
    system_lib_prefix_ = ctx_mod_->GetAttr<ffi::String>(tvm::attr::kSystemLibPrefix);
  }

  /*!
   * \brief Lower every relax::Function in \p mod into a VMTIR PrimFunc.
   * \param builder Executable builder to populate.
   * \param mod Input module; edits are applied to a copy-on-write copy.
   * \return A module in which each relax::Function has been removed and replaced by
   *         its generated PrimFunc, keyed by a GlobalVar named after the PrimFunc's
   *         global symbol.
   */
  static IRModule Run(relax::ExecBuilder builder, IRModule mod) {
    // create a new copy
    IRModule res_mod = mod;
    res_mod.CopyOnWrite();
    CodeGenVMTIR codegen(builder, mod);
    // Remove relax function and turn into TIR func.
    for (const auto& kv : mod->functions) {
      if (const auto* relax_func = kv.second.as<FunctionNode>()) {
        tirx::PrimFunc tir_func = codegen.Codegen(ffi::GetRef<Function>(relax_func));
        // Codegen always attaches kGlobalSymbol, so value() is safe here.
        ffi::Optional<ffi::String> tir_symbol =
            tir_func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
        res_mod->Add(GlobalVar(tir_symbol.value()), tir_func);
        res_mod->Remove(kv.first);
      }
    }
    return res_mod;
  }

 private:
  /*! \brief Allocate the next virtual register index. */
  int64_t NewRegister() { return registers_num_++; }
  static IntImm ConstInt64(int64_t value) { return IntImm::Int64(value); }
  static IntImm ConstInt32(int64_t value) { return IntImm::Int32(value); }

  /*!
   * \brief Emit a read of \p slot from the anylist bound to \p list_handle.
   *
   * Anylist slots hold type-erased Any values (128 bits each); the read result is
   * typed as an opaque handle.
   */
  static PrimExpr AnyListGet(const tirx::Var& list_handle, int64_t slot) {
    return tvm::Call(tvm::PrimType::Handle(), tirx::builtin::anylist_getitem(),
                     {list_handle, ConstInt32(slot)})
        .as_or_throw<PrimExpr>();
  }
  /*! \brief Read a virtual register. */
  PrimExpr RegListGet(int64_t slot) const { return AnyListGet(reg_anylist_handle_, slot); }
  /*! \brief Read an entry of the constant pool. */
  PrimExpr ConstListGet(int64_t slot) const { return AnyListGet(const_anylist_handle_, slot); }
  /*! \brief Read an entry of the function table. */
  PrimExpr FuncListGet(int64_t slot) const { return AnyListGet(func_anylist_handle_, slot); }

  /*! \brief Append \p stmt to the innermost open statement scope. */
  void EmitStmt(tirx::Stmt stmt) {
    TVM_FFI_ICHECK(!stmt_stack_.empty());
    stmt_stack_.back().emplace_back(stmt);
  }

  /*!
   * \brief Assemble the argument list shared by the packed-call intrinsics.
   *
   * When \p dst_anylist_slot is non-negative, the list starts with the register
   * anylist handle and the destination slot; those leading arguments are what the
   * anylist_setitem_call_* variants emitted by the callers consume. The callee name
   * and the call arguments follow.
   */
  ffi::Array<PrimExpr> MakePackedCallArgs(const ffi::String& callee,
                                          const ffi::Array<PrimExpr>& args,
                                          int64_t dst_anylist_slot) const {
    ffi::Array<PrimExpr> all_args;
    if (dst_anylist_slot >= 0) {
      all_args.push_back(reg_anylist_handle_);
      all_args.push_back(ConstInt32(dst_anylist_slot));
    }
    all_args.push_back(tirx::StringImm(callee));
    for (const PrimExpr& arg : args) {
      all_args.push_back(arg);
    }
    return all_args;
  }

  /*!
   * \brief Emit a by-name packed call to \p name.
   * \param dst_anylist_slot Register receiving the result; negative discards it.
   */
  void EmitCallPacked(ffi::String name, const ffi::Array<PrimExpr>& args,
                      int64_t dst_anylist_slot = -1) {
    ffi::Array<PrimExpr> all_args = MakePackedCallArgs(name, args, dst_anylist_slot);
    if (dst_anylist_slot >= 0) {
      this->EmitStmt(tirx::Evaluate(
          tvm::Call(tvm::PrimType::Int(32), tirx::builtin::anylist_setitem_call_packed(), all_args)
              .as_or_throw<PrimExpr>()));
    } else {
      this->EmitStmt(tirx::Evaluate(
          tvm::Call(tvm::PrimType::Int(32), tirx::builtin::tvm_call_packed(), all_args)
              .as_or_throw<PrimExpr>()));
    }
  }

  /*!
   * \brief Emit a direct (cpacked) call to a PrimFunc of the same module.
   * \param dst_anylist_slot Register receiving the result; negative discards it.
   */
  void EmitCallCPacked(const tirx::PrimFunc& prim_func, const ffi::Array<PrimExpr>& args,
                       int64_t dst_anylist_slot = -1) {
    ffi::Optional<ffi::String> gsymbol = prim_func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
    TVM_FFI_ICHECK(gsymbol.has_value()) << "All functions must have global symbol at this phase";
    ffi::Array<PrimExpr> all_args = MakePackedCallArgs(gsymbol.value(), args, dst_anylist_slot);
    if (dst_anylist_slot >= 0) {
      this->EmitStmt(tirx::Evaluate(
          tvm::Call(tvm::PrimType::Int(32), tirx::builtin::anylist_setitem_call_cpacked(), all_args)
              .as_or_throw<PrimExpr>()));
    } else {
      this->EmitStmt(tirx::Evaluate(
          tvm::Call(tvm::PrimType::Int(32), tirx::builtin::tvm_call_cpacked(), all_args)
              .as_or_throw<PrimExpr>()));
    }
  }

  /*!
   * \brief Lower a single global relax::Function into a VMTIR PrimFunc.
   *
   * Also declares, emits and ends the matching kVMTIRFunc entry in the builder.
   * The returned PrimFunc is named "<system_lib_prefix>__vmtir__<global_symbol>".
   */
  tirx::PrimFunc Codegen(const Function& func) {
    ffi::Optional<ffi::String> gsymbol = func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
    TVM_FFI_ICHECK(gsymbol.has_value())
        << "there should be no local functions in Relax VM codegen phase. "
           "Did you forget to apply LambdaLift or AttachGlobalSymbol Pass?";
    // initialize the state
    stmt_stack_ = {};
    registers_num_ = 0;
    var_map_.clear();
    ctx_ptr_ = tirx::Var("ctx_ptr", PrimType::Handle());
    reg_anylist_handle_ = tirx::Var("r", PrimType::Handle());
    func_anylist_handle_ = tirx::Var("f", PrimType::Handle());
    const_anylist_handle_ = tirx::Var("c", PrimType::Handle());

    ffi::Array<ffi::String> param_names;
    for (const Var& param : func->params) {
      param_names.push_back(param->name_hint());
    }
    // declare this function.
    builder_->DeclareFunction(gsymbol.value(), VMFuncInfo::FuncKind::kVMTIRFunc);

    // Parameters must occupy registers [0, num_params) in declaration order.
    for (size_t i = 0; i < func->params.size(); ++i) {
      int64_t r = NewRegister();
      TVM_FFI_ICHECK_EQ(static_cast<size_t>(r), i);
      this->var_map_.insert({func->params[i], RegListGet(r)});
    }
    // Reserve a register that receives the function's result.
    int64_t ret_reg = NewRegister();
    tirx::Stmt body = WithNewScope([&]() {
      ffi::Optional<PrimExpr> ret = ExprFunctor::VisitExpr(func->body);
      if (ret.defined()) {
        this->EmitCallPacked("vm.builtin.copy", {ret.value()}, ret_reg);
      }
    });

    // Mark the function entry internally.
    builder_->EmitFunction(gsymbol.value(), param_names.size(), param_names,
                           VMFuncInfo::FuncKind::kVMTIRFunc, registers_num_);
    builder_->EndFunction(gsymbol.value());

    Type ret_type = VoidType();
    ffi::Array<tirx::Var> tir_params = {ctx_ptr_, reg_anylist_handle_, const_anylist_handle_,
                                        func_anylist_handle_};
    ffi::String tir_func_name = system_lib_prefix_.value_or("") + "__vmtir__" + gsymbol.value();
    tirx::PrimFunc tir_func(tir_params, body, ret_type, {});
    // Use the same attribute key that Run() reads back.
    tir_func = WithAttr(tir_func, tvm::attr::kGlobalSymbol, tir_func_name);
    tir_func = WithAttr(tir_func, tvm::attr::kSTir, true);

    // Reset per-function state.
    registers_num_ = 0;
    var_map_.clear();
    stmt_stack_.clear();
    return tir_func;
  }

  ffi::Optional<PrimExpr> VisitExpr_(const SeqExprNode* op) final {
    for (const auto& block : op->blocks) {
      for (const Binding& binding : block->bindings) {
        Expr expr = GetBoundValue(binding);
        ffi::Optional<PrimExpr> value = VisitExpr(expr);

        if (expr.as<Var>() && value.defined()) {
          // For a normalized relax module, there should be one
          // register for each relax::Binding. This makes the Relax
          // semantics of R.vm.kill_* operate the same as the Python
          // "del" operator. These bindings may be removable by using
          // relax.transform.CanonicalizeBindings earlier in lowering.
          int64_t new_reg = NewRegister();
          EmitCallPacked("vm.builtin.copy", {value.value()}, new_reg);
          value = RegListGet(new_reg);
        }
        this->var_map_.insert({binding->var, value});
      }
    }
    return this->VisitExpr(op->body);
  }

  ffi::Optional<PrimExpr> VisitExpr_(const CallNode* call_node) final {
    Call call = ffi::GetRef<Call>(call_node);

    if (call_node->op == null_value_op_) {
      return tvm::Call(tvm::PrimType::Handle(), tirx::builtin::reinterpret(), {IntImm::Int64(0)})
          .as_or_throw<PrimExpr>();
    }
    // Void-typed calls discard their result (negative slot).
    int64_t dst_reg = HasVoidType(call) ? -1 : NewRegister();
    if (call->op.as<OpNode>()) {
      if (call_node->op == call_builtin_with_ctx_op_) {
        EmitCallBuiltinWithCtx(call, dst_reg);
      } else if (call_node->op == alloc_storage_op_) {
        EmitAllocStorage(call, dst_reg);
      } else if (call_node->op == alloc_tensor_op_) {
        EmitAllocTensor(call, dst_reg);
      } else if (call_node->op == kill_object_op_) {
        // kill_object overwrites the killed variable's own register.
        dst_reg = EmitKillObject(call);
      } else {
        // every "normal" operator is lowered to a global var in the IRModule. The Attrs for those
        // ops are handled in a pass when lowering them to TIR.
        TVM_FFI_THROW(InternalError) << "CodeGenVMTIR cannot handle this intrinsic now:\n"
                                     << call_node->op;
      }
    } else {
      EmitNormalCall(call, dst_reg);
    }
    if (dst_reg >= 0) {
      return RegListGet(dst_reg);
    } else {
      return std::nullopt;
    }
  }

  ffi::Optional<PrimExpr> VisitExpr_(const IfNode* op) final {
    // Reserve a register that both branches write their result into.
    int64_t merge_register = NewRegister();
    PrimExpr cond_value = this->VisitExpr(op->cond).value();

    cond_value = tvm::Call(tvm::PrimType::Bool(), tirx::builtin::tvm_call_packed(),
                           {tirx::StringImm("vm.builtin.read_if_cond"), cond_value})
                     .as_or_throw<PrimExpr>();

    tirx::Stmt true_branch = WithNewScope([&]() {
      PrimExpr true_value = this->VisitExpr(op->true_branch).value();
      this->EmitCallPacked("vm.builtin.copy", {true_value}, merge_register);
    });
    tirx::Stmt false_branch = WithNewScope([&]() {
      PrimExpr false_value = this->VisitExpr(op->false_branch).value();
      this->EmitCallPacked("vm.builtin.copy", {false_value}, merge_register);
    });
    this->EmitStmt(tirx::IfThenElse(cond_value, true_branch, false_branch));
    return RegListGet(merge_register);
  }

  ffi::Optional<PrimExpr> VisitExpr_(const VarNode* op) final {
    Var var = ffi::GetRef<Var>(op);
    auto it = this->var_map_.find(var);
    TVM_FFI_ICHECK(it != this->var_map_.end()) << "Var " << var << " is not defined";
    return it->second;
  }

  ffi::Optional<PrimExpr> VisitExpr_(const ConstantNode* op) final {
    return ConstListGet(builder_->ConvertConstant(op->data).value());
  }

  ffi::Optional<PrimExpr> VisitExpr_(const ShapeExprNode* op) final {
    std::vector<int64_t> shape;
    for (PrimExpr e : op->values) {
      if (auto* int_value = e.as<IntImmNode>()) {
        shape.push_back(int_value->value);
      } else {
        TVM_FFI_THROW(InternalError)
            << "Should only use constant shape after shape lowering: " << op->values;
      }
    }
    return ConstListGet(builder_->ConvertConstant(ffi::Shape(shape)).value());
  }

  ffi::Optional<PrimExpr> VisitExprFallback_(const ExprNode* op) final {
    return ffi::GetRef<Expr>(op).as_or_throw<PrimExpr>();
  }

  ffi::Optional<PrimExpr> VisitExpr_(const StringImmNode* op) final {
    return ConstListGet(builder_->ConvertConstant(op->value).value());
  }

  ffi::Optional<PrimExpr> VisitExpr_(const DataTypeImmNode* op) final {
    return ConstListGet(builder_->ConvertConstant(op->value).value());
  }

  ffi::Optional<PrimExpr> VisitExpr_(const TupleNode* op) final {
    ffi::Array<PrimExpr> args;
    for (auto arg : op->fields) {
      args.push_back(this->VisitExpr(arg).value());
    }
    // int64_t like every other register index (NewRegister returns int64_t).
    int64_t dst_register = NewRegister();
    this->EmitCallPacked("vm.builtin.make_tuple", args, dst_register);
    return RegListGet(dst_register);
  }

  ffi::Optional<PrimExpr> VisitExpr_(const TupleGetItemNode* op) final {
    TupleGetItem expr = ffi::GetRef<TupleGetItem>(op);
    ffi::Array<PrimExpr> args = {this->VisitExpr(expr->tuple).value()};

    args.push_back(ConstInt64(expr->index));

    int64_t dst_register = NewRegister();
    this->EmitCallPacked("vm.builtin.tuple_getitem", args, dst_register);
    return RegListGet(dst_register);
  }

  /*!
   * \brief Resolve the callee symbol of \p expr.
   * \param kind Set to how the callee should be invoked.
   * \return The symbol, or nullopt if \p expr is neither an ExternFunc nor a GlobalVar.
   */
  ffi::Optional<ffi::String> LookupFunction(const Expr& expr, VMFuncInfo::FuncKind* kind) {
    if (auto* ext_func = expr.as<ExternFuncNode>()) {
      *kind = VMFuncInfo::FuncKind::kPackedFunc;
      return ext_func->global_symbol;
    }
    auto* gvar_ptr = expr.as<GlobalVarNode>();
    if (gvar_ptr == nullptr) {
      return std::nullopt;
    }
    GlobalVar gvar = ffi::GetRef<GlobalVar>(gvar_ptr);
    // Run a look up in the env to see if it maps to an extern func.
    auto it = ctx_mod_->functions.find(gvar);
    if (it == ctx_mod_->functions.end()) {
      LOG(WARNING) << "Undefined global var " << gvar->name_hint;
      // undefined global var, consider eliminate later.
      *kind = VMFuncInfo::FuncKind::kPackedFunc;
      return gvar->name_hint;
    }
    BaseFunc func = (*it).second;
    if (auto* efunc = func.as<ExternFuncNode>()) {
      *kind = VMFuncInfo::FuncKind::kPackedFunc;
      return efunc->global_symbol;
    }
    // Relax functions become VMTIR functions; everything else (PrimFuncs and other
    // function kinds) is called as a packed function by name.
    if (func.as<FunctionNode>()) {
      *kind = VMFuncInfo::FuncKind::kVMTIRFunc;
    } else {
      *kind = VMFuncInfo::FuncKind::kPackedFunc;
    }
    return gvar->name_hint;
  }

  // Lookup PrimFunc in the same module
  // We can do direct PrimFunc call in such cases
  ffi::Optional<tirx::PrimFunc> LookupPrimFunc(const ffi::String& name) {
    if (!ctx_mod_->ContainGlobalVar(name)) return std::nullopt;

    GlobalVar gvar = ctx_mod_->GetGlobalVar(name);
    auto it = ctx_mod_->functions.find(gvar);
    if (it != ctx_mod_->functions.end()) {
      BaseFunc func = (*it).second;
      if (auto* prim_func = func.as<tirx::PrimFuncNode>()) {
        return ffi::GetRef<tirx::PrimFunc>(prim_func);
      }
    }
    return std::nullopt;
  }

  ffi::Optional<PrimExpr> VisitExpr_(const GlobalVarNode* op) final {
    VMFuncInfo::FuncKind kind;
    auto symbol = LookupFunction(ffi::GetRef<Expr>(op), &kind);
    TVM_FFI_ICHECK(symbol.has_value());
    builder_->DeclareFunction(symbol.value(), kind);
    return FuncListGet(builder_->GetFunction(symbol.value()).value());
  }

  ffi::Optional<PrimExpr> VisitExpr_(const ExternFuncNode* op) final {
    builder_->DeclareFunction(op->global_symbol, VMFuncInfo::FuncKind::kPackedFunc);
    return FuncListGet(builder_->GetFunction(op->global_symbol).value());
  }

  /*! \brief Lower relax.vm.alloc_storage; the VM context is passed as first argument. */
  void EmitAllocStorage(const Call& call_node, int64_t dst_reg) {
    // Handle args of the call
    ffi::Array<PrimExpr> args;
    args.push_back(ctx_ptr_);
    for (Expr arg : call_node->args) {
      args.push_back(this->VisitExpr(arg).value());
    }
    this->EmitCallPacked("vm.builtin.alloc_storage", args, dst_reg);
  }

  /*!
   * \brief Lower relax.vm.alloc_tensor.
   *
   * The first four args are forwarded; the fifth (an IntImm virtual-device index,
   * or -1 when absent) is resolved to the device's memory scope, which is appended
   * when a global virtual device is found.
   */
  void EmitAllocTensor(const Call& call_node, int64_t dst_reg) {
    TVM_FFI_ICHECK_EQ(call_node->args.size(), 5);
    ffi::Array<PrimExpr> args;
    for (int i = 0; i < 4; ++i) {
      args.push_back(this->VisitExpr(call_node->args[i]).value());
    }
    int64_t vdevice_index = -1;
    if (const auto* int_imm = call_node->args[4].as<IntImmNode>()) {
      vdevice_index = int_imm->value;
    }
    auto vdevice = GetGlobalVDevice(ctx_mod_, vdevice_index);
    if (vdevice.defined()) {
      args.push_back(tirx::StringImm(vdevice.value()->memory_scope));
    }
    this->EmitCallPacked("vm.builtin.alloc_tensor", args, dst_reg);
  }

  /*!
   * \brief Lower relax.vm.kill_object by writing a null value into the argument's register.
   * \return The register index that was cleared.
   */
  int64_t EmitKillObject(const Call& call_node) {
    TVM_FFI_ICHECK_EQ(call_node->args.size(), 1);
    PrimExpr arg = this->VisitExpr(call_node->args[0]).value();

    // Check the arg is a register.
    const auto* tir_call = arg.as<CallNode>();
    TVM_FFI_ICHECK(tir_call != nullptr);
    TVM_FFI_ICHECK(tir_call->op == tirx::builtin::anylist_getitem());
    TVM_FFI_ICHECK(tir_call->args.size() == 2);
    TVM_FFI_ICHECK(tir_call->args[0].same_as(reg_anylist_handle_));
    const auto* p_dst_reg = tir_call->args[1].as<tirx::IntImmNode>();
    TVM_FFI_ICHECK(p_dst_reg != nullptr);
    TVM_FFI_ICHECK(
        p_dst_reg->ty.as_or_throw<PrimType>().MatchesElementType(DLDataTypeCode::kDLInt, 32));

    int64_t dst_reg = p_dst_reg->value;
    this->EmitCallPacked("vm.builtin.null_value", {}, dst_reg);
    return dst_reg;
  }

  /*!
   * \brief Lower relax.call_builtin_with_ctx: args[0] is the ExternFunc callee and
   *        args[1] a Tuple of arguments; the VM context is prepended.
   */
  void EmitCallBuiltinWithCtx(const Call& call_node, int64_t dst_reg) {
    ffi::Array<PrimExpr> args;
    // if context is required, pass as first argument.
    args.push_back(ctx_ptr_);
    auto* func = call_node->args[0].as<ExternFuncNode>();
    TVM_FFI_ICHECK(func) << "CallBuiltin comes with extern func";

    auto tuple_arg = call_node->args[1].as_or_throw<Tuple>();

    // Handle args of the call
    for (Expr arg : tuple_arg->fields) {
      args.push_back(this->VisitExpr(arg).value());
    }

    this->EmitCallPacked(func->global_symbol, args, dst_reg);
  }

  /*!
   * \brief Lower a call whose op is not an Op node.
   *
   * Packed-function callees are invoked directly (cpacked when the symbol names a
   * PrimFunc of this module, packed by name otherwise); everything else goes
   * through vm.builtin.invoke_closure.
   */
  void EmitNormalCall(const Call& call_node, int64_t dst_reg) {
    ffi::Array<PrimExpr> args = VisitArray(call_node->args);
    // A function can be a closure that comes from parent
    // Do call closure to be safe.
    VMFuncInfo::FuncKind kind;
    auto symbol = LookupFunction(call_node->op, &kind);

    if (symbol.has_value() && kind == VMFuncInfo::FuncKind::kPackedFunc) {
      // primfunc in the same module.
      // use cpacked to directly invoke without named based lookup
      if (ffi::Optional<tirx::PrimFunc> prim_func = LookupPrimFunc(symbol.value())) {
        this->EmitCallCPacked(prim_func.value(), args, dst_reg);
      } else {
        this->EmitCallPacked(symbol.value(), args, dst_reg);
      }
    } else {
      // Default path, leverage function table and invoke as closure
      ffi::Array<PrimExpr> all_args;
      all_args.push_back(ctx_ptr_);
      all_args.push_back(this->VisitExpr(call_node->op).value());
      for (auto arg : args) {
        all_args.push_back(arg);
      }
      this->EmitCallPacked("vm.builtin.invoke_closure", all_args, dst_reg);
    }
  }

  /*! \brief Run \p callback in a fresh statement scope and return what it emitted. */
  template <typename FLambda>
  tirx::Stmt WithNewScope(const FLambda& callback) {
    stmt_stack_.push_back({});
    callback();
    tirx::Stmt stmt = tirx::SeqStmt::Flatten(stmt_stack_.back());
    stmt_stack_.pop_back();
    return stmt;
  }

  /*! \brief Visit each expression in \p arr; every one must produce a value. */
  ffi::Array<PrimExpr> VisitArray(const ffi::Array<Expr>& arr) {
    ffi::Array<PrimExpr> ret;
    for (size_t i = 0; i < arr.size(); ++i) {
      ret.push_back(this->VisitExpr(arr[i]).value());
    }
    return ret;
  }

  /*! \brief Internal ExecBuilder. */
  relax::ExecBuilder builder_;
  /*! \brief TIR variable bound to the VM context pointer ("ctx_ptr" parameter). */
  tirx::Var ctx_ptr_;
  /*! \brief List to store temp object registers */
  tirx::Var reg_anylist_handle_;
  /*! \brief List to store closures */
  tirx::Var func_anylist_handle_;
  /*! \brief List to store constants */
  tirx::Var const_anylist_handle_;
  /*!
   * \brief Total number of virtual registers allocated for the current function.
   * \note Registers [0, num_params) hold the parameters (asserted in Codegen); the
   *       next one is reserved for the return value.
   */
  int64_t registers_num_ = 0;
  /*! \brief Stack to build up statements */
  std::vector<std::vector<tirx::Stmt>> stmt_stack_;
  /*! \brief Map from var to Expr. */
  std::unordered_map<Var, ffi::Optional<PrimExpr>> var_map_;
  /*! \brief the context module. */
  IRModule ctx_mod_;
  /*! \brief system lib prefix */
  ffi::Optional<ffi::String> system_lib_prefix_;
  /*! \brief Cache ops that need to be frequently used later to reduce lookup overhead. */
  const Op& alloc_storage_op_ = Op::Get("relax.vm.alloc_storage");
  const Op& alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor");
  const Op& kill_object_op_ = Op::Get("relax.vm.kill_object");
  const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx");
  const Op& null_value_op_ = Op::Get("relax.null_value");
};
/*!
 * \brief Lower every relax::Function in \p mod into a VMTIR PrimFunc.
 *
 * The corresponding VM function entries, function-table symbols and constants
 * are recorded in \p exec_builder as a side effect.
 *
 * \param exec_builder Builder that collects the executable being produced.
 * \param mod Input module.
 * \return A module in which each relax::Function of \p mod has been replaced by
 *         its generated TIR function; other functions are carried over.
 */
IRModule VMTIRCodeGen(ExecBuilder exec_builder, IRModule mod) {
  IRModule lowered = CodeGenVMTIR::Run(exec_builder, mod);
  return lowered;
}
TVM_FFI_STATIC_INIT_BLOCK() {
  // Expose the VMTIR code generator through the global FFI registry.
  tvm::ffi::reflection::GlobalDef().def("relax.VMTIRCodeGen", VMTIRCodeGen);
}
} // namespace codegen_vm
} // namespace relax
} // namespace tvm