// blob: bf29556768c55ce645baf38c4ff8ed5f057008c7 [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/relax/backend/vm/codegen_vm.cc
* \brief A codegen to generate VM executable from a Relax IRModule.
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/relax/exec_builder.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/op_attr_types.h>
#include <tvm/runtime/vm/bytecode.h>
#include <tvm/target/target.h>
#include <tvm/tir/function.h>

#include <string>
#include <unordered_map>
#include <vector>

#include "../../../runtime/const_loader_module.h"
#include "../../../target/source/codegen_source_base.h"
#include "../../transform/utils.h"
namespace tvm {
namespace relax {
namespace codegen_vm {
using tvm::Target;
using namespace relax;
using namespace tvm::runtime;
using namespace tvm::runtime::vm;
/*!
 * \brief A class to generate VM executable for Relax functions.
 *
 * Entry point is the static Run(): it codegens every relax::Function in the
 * module into the ExecBuilder and returns the module with those functions
 * removed.  Each visitor returns the Instruction::Arg (usually a register)
 * that holds the value of the visited expression.
 */
class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
 public:
  /*!
   * \brief Construct a codegen instance.
   * \param builder ExecBuilder that accumulates the emitted VM code.
   * \param ctx_mod Context module used to resolve GlobalVar references.
   */
  explicit CodeGenVM(relax::ExecBuilder builder, IRModule ctx_mod)
      : builder_(builder), ctx_mod_(ctx_mod) {}

  /*!
   * \brief Generate VM code for every relax::Function in `mod`.
   * \param builder Builder to collect the emitted VM functions.
   * \param mod The input module.
   * \return A copy of `mod` with the relax functions removed.
   */
  static IRModule Run(relax::ExecBuilder builder, IRModule mod) {
    IRModule res_mod = mod;
    res_mod.CopyOnWrite();
    CodeGenVM codegen(builder, mod);
    // Remove relax function and turn into TIR func.
    for (const auto& [gvar, f] : mod->functions) {
      if (auto* func = f.as<FunctionNode>()) {
        codegen.Codegen(ffi::GetRef<Function>(func));
        res_mod->Remove(gvar);
      }
    }
    return res_mod;
  }

 protected:
  /*! \brief Allocate the next virtual register index for the current function. */
  size_t NewRegister() { return registers_num_++; }

  // Convert Arg value to a register, trigger copy if needed
  Instruction::Arg EnsureReg(Instruction::Arg arg) {
    if (arg.kind() == Instruction::ArgKind::kRegister) {
      return arg;
    } else {
      // Non-register args (constants, function handles) are materialized
      // into a fresh register via the VM's builtin copy.
      RegName dst_reg = NewRegister();
      builder_->EmitCall("vm.builtin.copy", {arg}, dst_reg);
      return Instruction::Arg::Register(dst_reg);
    }
  }

  /*!
   * \brief Emit one VM function for a single Relax function.
   *
   * Parameters are assigned registers 0..n-1 in declaration order (checked
   * by the ICHECK_EQ below).  Per-function state (register counter, var map)
   * is reset at the end so the next function starts fresh.
   */
  void Codegen(const Function& func) {
    ffi::Optional<ffi::String> gsymbol = func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
    TVM_FFI_ICHECK(gsymbol.has_value())
        << "there should be no local functions in Relax VM codegen phase. "
           "Did you forget to apply LambdaLift or AttachGlobalSymbol Pass?";
    ffi::Array<ffi::String> param_names;
    for (Var param : func->params) {
      param_names.push_back(param->name_hint());
    }
    builder_->EmitFunction(gsymbol.value(), func->params.size(), param_names);
    for (size_t i = 0; i < func->params.size(); ++i) {
      RegName r = NewRegister();
      TVM_FFI_ICHECK_EQ(r, static_cast<RegName>(i));
      this->var_arg_map_.insert({func->params[i], Instruction::Arg::Register(r)});
    }
    Instruction::Arg ret = ExprFunctor::VisitExpr(func->body);
    builder_->EmitRet(EnsureReg(ret));
    builder_->EndFunction(gsymbol.value());
    // reset register number to be 0;
    registers_num_ = 0;
    var_arg_map_.clear();
  }

  /*!
   * \brief Visit a SeqExpr: emit each binding in order, record the register
   * (or constant arg) holding each bound var, then visit the body.
   */
  Instruction::Arg VisitExpr_(const SeqExprNode* op) final {
    for (auto block : op->blocks) {
      for (Binding binding : block->bindings) {
        Expr expr = GetBoundValue(binding);
        Instruction::Arg value = VisitExpr(expr);
        if (expr.as<VarNode>()) {
          // For a normalized relax module, there should be one
          // register for each relax::Binding. This makes the Relax
          // semantics of R.vm.kill_* operate the same as the Python
          // "del" operator. These bindings may be removable by using
          // relax.transform.CanonicalizeBindings earlier in lowering.
          RegName new_reg = NewRegister();
          builder_->EmitCall("vm.builtin.copy", {value}, new_reg);
          value = Instruction::Arg::Register(new_reg);
        }
        this->var_arg_map_.insert({binding->var, value});
      }
    }
    Instruction::Arg ret_reg = this->VisitExpr(op->body);
    return ret_reg;
  }

  /*!
   * \brief Visit a Call: dispatch VM intrinsic ops to their dedicated
   * emitters, everything else to EmitNormalCall.
   */
  Instruction::Arg VisitExpr_(const CallNode* call_node) final {
    Call call = ffi::GetRef<Call>(call_node);
    if (call_node->op == null_value_op_) {
      return Instruction::Arg::Register(Instruction::kVoidRegister);
    }
    // allocate dst register.
    RegName dst_reg = HasVoidStructInfo(call) ? Instruction::kVoidRegister : NewRegister();
    if (call->op.as<OpNode>()) {
      if (call_node->op == call_builtin_with_ctx_op_) {
        // TODO(relax-team) migrate most handling of op to
        // directly map to call_builtin_with_ctx before codegen and simplify vm codegen.
        EmitCallBuiltinWithCtx(call, dst_reg);
      } else if (call_node->op == alloc_storage_op_) {
        EmitAllocStorage(call, dst_reg);
      } else if (call_node->op == alloc_tensor_op_) {
        EmitAllocTensor(call, dst_reg);
      } else if (call_node->op == kill_object_op_) {
        // kill_object reuses the killed object's register as its result.
        dst_reg = EmitKillObject(call);
      } else {
        // every "normal" operator is lowered to a global var in the IRModule. The Attrs for those
        // ops are handled in a pass when lowering them to TIR.
        TVM_FFI_THROW(InternalError) << "CodeGenVM cannot handle this intrinsic now:\n"
                                     << call_node->op;
      }
    } else {
      EmitNormalCall(call, dst_reg);
    }
    return Instruction::Arg::Register(dst_reg);
  }

  /*!
   * \brief Visit an If: emit both branches with a shared merge register and
   * back-patch the If/Goto jump offsets once branch sizes are known.
   */
  Instruction::Arg VisitExpr_(const IfNode* op) final {
    const If& ife = ffi::GetRef<If>(op);
    Instruction::Arg cond_value = this->VisitExpr(ife->cond);
    // Reserve a register for cond
    RegName cond_reg = NewRegister();
    builder_->EmitCall("vm.builtin.read_if_cond", {cond_value}, cond_reg);
    // obtain the temp exec in progress.
    vm::VMExecutable* exec = builder_->exec();
    // Record the offset of If instruction
    size_t if_offset = exec->instr_offset.size();
    // The false-branch offset (3) is a placeholder; it is patched below via
    // SetInstructionData once the true branch has been emitted.
    builder_->EmitIf(Instruction::Arg::Register(cond_reg), 3);
    size_t num_instr = exec->instr_offset.size();
    Instruction::Arg true_value = this->VisitExpr(ife->true_branch);
    // Reserve a register for return
    size_t merge_register = NewRegister();
    // Copy the output from true branch to merge register
    builder_->EmitCall("vm.builtin.copy", {true_value}, merge_register);
    // Record the offset of Goto instruction
    size_t goto_offset = exec->instr_offset.size();
    // The Goto target (1) is also a placeholder, patched after the false branch.
    builder_->EmitGoto(1);
    // Calculate the false offset of If
    size_t false_offset = exec->instr_offset.size() - num_instr + 1;
    Instruction::Arg false_value = this->VisitExpr(ife->false_branch);
    // Copy the output data of false branch to merge register
    builder_->EmitCall("vm.builtin.copy", {false_value}, merge_register);
    // Update the offsets of the If instruction emitted above
    // Jump to the behind of the next goto instruction
    exec->SetInstructionData(if_offset, 2, static_cast<ExecWord>(false_offset));
    // Update the pc_offset of Goto instruction
    // Jump over the false branch
    size_t pc_offset = exec->instr_offset.size() - goto_offset;
    exec->SetInstructionData(goto_offset, 1, static_cast<ExecWord>(pc_offset));
    return Instruction::Arg::Register(merge_register);
  }

  /*! \brief Look up the arg previously recorded for a bound var. */
  Instruction::Arg VisitExpr_(const VarNode* op) final {
    Var var = ffi::GetRef<Var>(op);
    auto it = this->var_arg_map_.find(var);
    TVM_FFI_ICHECK(it != this->var_arg_map_.end()) << "Var " << var << " is not defined";
    return it->second;
  }

  /*!
   * \brief Convert a constant tensor into an executable constant; if its
   * struct info carries a vdevice, record the memory scope with the builder.
   */
  Instruction::Arg VisitExpr_(const ConstantNode* op) final {
    auto arg = builder_->ConvertConstant(op->data);
    if (auto tsinfo = op->struct_info_.as<TensorStructInfoNode>()) {
      if (tsinfo->vdevice.defined()) {
        VDevice vdev = tsinfo->vdevice.value();
        builder_->SaveMemoryScope(arg, vdev->memory_scope);
      }
    }
    return arg;
  }

  /*! \brief Convert a fully-constant ShapeExpr into a shape constant. */
  Instruction::Arg VisitExpr_(const ShapeExprNode* op) final {
    std::vector<int64_t> shape;
    for (PrimExpr e : op->values) {
      if (auto* int_value = e.as<IntImmNode>()) {
        shape.push_back(int_value->value);
      } else {
        // Symbolic dims must have been eliminated by shape lowering.
        TVM_FFI_THROW(InternalError)
            << "Should only use constant shape after shape lowering: " << op->values;
      }
    }
    return builder_->ConvertConstant(ffi::Shape(shape));
  }

  /*! \brief Convert an integer/float PrimValue into an executable constant. */
  Instruction::Arg VisitExpr_(const PrimValueNode* op) final {
    if (auto* int_imm = op->value.as<IntImmNode>()) {
      return builder_->ConvertConstant(int_imm->value);
    } else if (auto* float_imm = op->value.as<FloatImmNode>()) {
      return builder_->ConvertConstant(float_imm->value);
    } else {
      TVM_FFI_THROW(InternalError)
          << "PrimValue should only contain constant after VMShapeLower, "
          << "but received " << ffi::GetRef<Expr>(op) << " with type " << op->value->GetTypeKey();
      TVM_FFI_UNREACHABLE();
    }
  }

  /*! \brief Convert a string immediate into an executable constant. */
  Instruction::Arg VisitExpr_(const StringImmNode* op) final {
    return builder_->ConvertConstant(op->value);
  }

  /*! \brief Convert a dtype immediate into an executable constant. */
  Instruction::Arg VisitExpr_(const DataTypeImmNode* op) final {
    return builder_->ConvertConstant(op->value);
  }

  /*! \brief Build a runtime tuple from the visited field values. */
  Instruction::Arg VisitExpr_(const TupleNode* op) final {
    Tuple tuple = ffi::GetRef<Tuple>(op);
    std::vector<Instruction::Arg> args;
    for (Expr arg : tuple->fields) {
      args.push_back(this->VisitExpr(arg));
    }
    size_t dst_register = NewRegister();
    builder_->EmitCall("vm.builtin.make_tuple", args, dst_register);
    return Instruction::Arg::Register(dst_register);
  }

  /*! \brief Extract a tuple field via the tuple_getitem builtin. */
  Instruction::Arg VisitExpr_(const TupleGetItemNode* op) final {
    TupleGetItem expr = ffi::GetRef<TupleGetItem>(op);
    std::vector<Instruction::Arg> args = {this->VisitExpr(expr->tuple)};
    args.push_back(builder_->ConvertConstant(expr->index));
    size_t dst_register = NewRegister();
    builder_->EmitCall("vm.builtin.tuple_getitem", args, dst_register);
    return Instruction::Arg::Register(dst_register);
  }

  /*!
   * \brief Resolve a GlobalVar to a function handle: an ExternFunc in the
   * context module maps to a packed func under its global symbol, a relax
   * Function maps to a VM function; anything else falls back to a packed
   * func named after the GlobalVar.
   */
  Instruction::Arg VisitExpr_(const GlobalVarNode* op) final {
    GlobalVar gvar = ffi::GetRef<GlobalVar>(op);
    ffi::Optional<ffi::String> symbol;
    VMFuncInfo::FuncKind kind = VMFuncInfo::FuncKind::kPackedFunc;
    // Run a look up in the env to see if it maps to an extern func.
    auto it = ctx_mod_->functions.find(gvar);
    if (it != ctx_mod_->functions.end()) {
      BaseFunc func = (*it).second;
      if (auto* efunc = func.as<ExternFuncNode>()) {
        symbol = efunc->global_symbol;
        kind = VMFuncInfo::FuncKind::kPackedFunc;
      } else if (func.as<FunctionNode>()) {
        symbol = gvar->name_hint;
        kind = VMFuncInfo::FuncKind::kVMFunc;
      }
    }
    // GlobalVar can be reference to a Relax function or a TIR primfunc
    // At this point: all global var must corresponds to the right symbol.
    // TODO(relax-team): switch everything to extern before splitting TIR/relax
    // so we do not have idle global var here.
    if (!symbol.has_value()) {
      symbol = gvar->name_hint;
      kind = VMFuncInfo::FuncKind::kPackedFunc;
    }
    // declare the function to be safe.
    TVM_FFI_ICHECK(symbol.has_value());
    builder_->DeclareFunction(symbol.value(), kind);
    return builder_->GetFunction(symbol.value());
  }

  /*!
   * \brief Resolve an ExternFunc to a packed-function handle.  If the node
   * carries inline C source via the "c_source" attribute, a CSource module
   * is created and imported into the executable first.
   */
  Instruction::Arg VisitExpr_(const ExternFuncNode* op) final {
    static const constexpr char* kCSource = "c_source";
    static const constexpr char* kCSourceFmt = "c_source_fmt";
    if (ffi::Optional<ffi::String> opt_code = op->attrs.GetAttr<ffi::String>(kCSource)) {
      ffi::String sym = op->global_symbol;
      ffi::String fmt = op->attrs.GetAttr<ffi::String>(kCSourceFmt).value_or("c");
      ffi::String code = opt_code.value();
      ffi::Module c_source_module =
          codegen::CSourceModuleCreate(/*code=*/code, /*fmt=*/fmt, /*func_names=*/{sym},
                                       /*const_vars=*/{});
      builder_->exec()->ImportModule(c_source_module);
    }
    builder_->DeclareFunction(op->global_symbol, VMFuncInfo::FuncKind::kPackedFunc);
    return builder_->GetFunction(op->global_symbol);
  }

  /*!
   * \brief Emit relax.vm.alloc_storage: the VM register is prepended so the
   * builtin can access the VM context, followed by the 4 call args.
   */
  void EmitAllocStorage(const Call& call_node, RegName dst_reg) {
    TVM_FFI_ICHECK_EQ(call_node->args.size(), 4);
    // Handle args of the call
    std::vector<Instruction::Arg> args;
    args.push_back(Instruction::Arg::Register(Instruction::kVMRegister));
    // buffer size, dtype, device index
    for (auto arg : call_node->args) {
      args.push_back(this->VisitExpr(arg));
    }
    builder_->EmitCall("vm.builtin.alloc_storage", args, dst_reg);
  }

  /*!
   * \brief Emit relax.vm.alloc_tensor: forwards the first 4 args; the 5th
   * arg is a vdevice index used to look up a memory scope in the context
   * module, appended as an extra string arg when the vdevice is defined.
   */
  void EmitAllocTensor(const Call& call_node, RegName dst_reg) {
    TVM_FFI_ICHECK_EQ(call_node->args.size(), 5);
    std::vector<Instruction::Arg> args;
    for (int i = 0; i < 4; ++i) {
      args.push_back(this->VisitExpr(call_node->args[i]));
    }
    int64_t vdevice_index = -1;
    if (auto* prim_value_node = call_node->args[4].as<PrimValueNode>()) {
      vdevice_index = prim_value_node->value.as<IntImmNode>()->value;
    }
    auto vdevice = GetGlobalVDevice(ctx_mod_, vdevice_index);
    if (vdevice.defined()) {
      args.push_back(this->VisitExpr(StringImm(vdevice.value()->memory_scope)));
    }
    builder_->EmitCall("vm.builtin.alloc_tensor", args, dst_reg);
  }

  /*!
   * \brief Emit relax.vm.kill_object: overwrites the register currently
   * holding the object with a null value, releasing the object.
   * \return The register that was killed (reused as the call's dst).
   */
  RegName EmitKillObject(const Call& call_node) {
    TVM_FFI_ICHECK_EQ(call_node->args.size(), 1);
    Instruction::Arg arg = this->VisitExpr(call_node->args[0]);
    TVM_FFI_ICHECK(arg.kind() == Instruction::ArgKind::kRegister)
        << "Expected the object to be killed to be stored in a register, "
        << "but argument " << call_node->args[0] << " produced VM instruction of type "
        << arg.kind();
    RegName dst_reg = arg.value();
    builder_->EmitCall("vm.builtin.null_value", {}, dst_reg);
    return dst_reg;
  }

  /*!
   * \brief Emit relax.call_builtin_with_ctx: args[0] is the callee, args[1]
   * a Tuple of call arguments; the VM register is prepended as context.
   */
  void EmitCallBuiltinWithCtx(const Call& call_node, RegName dst_reg) {
    std::vector<Instruction::Arg> args;
    args.push_back(Instruction::Arg::Register(Instruction::kVMRegister));
    auto func = this->VisitExpr(call_node->args[0]);
    auto tuple_arg = Downcast<Tuple>(call_node->args[1]);
    // Handle args of the call
    for (Expr arg : tuple_arg->fields) {
      args.push_back(this->VisitExpr(arg));
    }
    builder_->EmitCall(func, args, dst_reg);
  }

  /*!
   * \brief Emit a regular call: direct when the callee resolves to a
   * function index, otherwise via the invoke_closure builtin.
   */
  void EmitNormalCall(const Call& call_node, RegName dst_reg) {
    Instruction::Arg func = VisitExpr(call_node->op);
    std::vector<Instruction::Arg> args = VisitArray(call_node->args);
    if (func.kind() == vm::Instruction::ArgKind::kFuncIdx) {
      builder_->EmitCall(func, args, dst_reg);
    } else {
      // Callee is a runtime value (e.g. a closure in a register);
      // route through the closure-invocation builtin with the VM context.
      std::vector<Instruction::Arg> closure_args = {
          Instruction::Arg::Register(Instruction::kVMRegister), func};
      std::copy(args.begin(), args.end(), std::back_inserter(closure_args));
      builder_->EmitCall("vm.builtin.invoke_closure", closure_args, dst_reg);
    }
  }

  // Emits call to packed function `name` with arguments copied over from `call_node` args
  void EmitPackedFuncCall(const Call& call_node, const FCallPacked& name, RegName dst_reg) {
    std::vector<Instruction::Arg> args = VisitArray(call_node->args);
    builder_->EmitCall(name, args, dst_reg);
  }

  /*! \brief Visit each expr in `arr` and collect the resulting args in order. */
  std::vector<Instruction::Arg> VisitArray(const ffi::Array<Expr>& arr) {
    std::vector<Instruction::Arg> ret;
    for (size_t i = 0; i < arr.size(); ++i) {
      ret.push_back(this->VisitExpr(arr[i]));
    }
    return ret;
  }

  /*! \brief Internal ExecBuilder. */
  relax::ExecBuilder builder_;
  /*!
   * \brief Total number of virtual registers allocated.
   * \note The first two registers are reserved for special registers.
   */
  size_t registers_num_ = 0;
  /*! \brief Map from var to register number. */
  std::unordered_map<Var, Instruction::Arg> var_arg_map_;
  /*! \brief the context module. */
  IRModule ctx_mod_;
  /*! \brief Cache ops that need to be frequently used later to reduce lookup overhead. */
  const Op& alloc_storage_op_ = Op::Get("relax.vm.alloc_storage");
  const Op& alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor");
  const Op& kill_object_op_ = Op::Get("relax.vm.kill_object");
  const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx");
  const Op& null_value_op_ = Op::Get("relax.null_value");
};
/*!
 * \brief Create the Relax VM executable from all relax.Function in mod
 * and add them to exec_builder.
 * \param exec_builder Builder that collects the generated executable code.
 * \param mod The input module.
 * \return The left-over IRModule that may contain other (non-Relax) functions.
 */
IRModule VMCodeGen(ExecBuilder exec_builder, IRModule mod) {
  // Delegate to the codegen visitor; it strips the relax functions it emits.
  IRModule residual = CodeGenVM::Run(exec_builder, mod);
  return residual;
}
// Expose VMCodeGen to the FFI under the global name "relax.VMCodeGen".
TVM_FFI_STATIC_INIT_BLOCK() {
  tvm::ffi::reflection::GlobalDef().def("relax.VMCodeGen", VMCodeGen);
}
/*!
 * \brief Link the modules together, possibly create a constant module.
 *
 * \param exec The executable that imports the linked result.
 * \param params The metadata for initialization of all modules.
 * \param lib the internal module that is compiled by tvm.
 * \param ext_libs The external modules that needs to be imported inside the metadata
 * module(s).
 */
void LinkModules(ObjectPtr<VMExecutable> exec, const ffi::Map<ffi::String, runtime::Tensor>& params,
                 const tvm::ffi::Module& lib, const ffi::Array<ffi::Module>& ext_libs) {
  // Gather, per external module that exposes get_symbol/get_const_vars,
  // the constant variables it expects the const loader to supply.
  ffi::Map<ffi::String, ffi::Array<ffi::String>> const_vars_by_symbol;
  for (tvm::ffi::Module ext_mod : ext_libs) {
    auto get_symbol = ext_mod->GetFunction("get_symbol");
    auto get_const_vars = ext_mod->GetFunction("get_const_vars");
    if (!get_symbol.has_value() || !get_const_vars.has_value()) {
      // Module does not participate in const loading; nothing to record.
      continue;
    }
    ffi::String symbol = (*get_symbol)().cast<ffi::String>();
    auto const_vars = (*get_const_vars)().cast<ffi::Array<ffi::String>>();
    TVM_FFI_ICHECK(!const_vars_by_symbol.count(symbol)) << "Found duplicated symbol: " << symbol;
    const_vars_by_symbol.Set(symbol, const_vars);
  }
  const bool need_const_loader = !const_vars_by_symbol.empty() || !params.empty();
  if (need_const_loader) {
    // Wrap everything behind a const-loader module that owns the constants.
    ffi::Module loader = runtime::ConstLoaderModuleCreate(params, const_vars_by_symbol);
    loader->ImportModule(lib);
    for (const auto& ext_mod : ext_libs) {
      loader->ImportModule(ext_mod);
    }
    exec->ImportModule(loader);
  } else {
    // No runtime constants needed: import the libraries directly.
    exec->ImportModule(lib);
    for (const auto& ext_mod : ext_libs) {
      exec->ImportModule(ext_mod);
    }
  }
}
/*!
 * \brief Link the libraries together.
 * \param builder Builder holding the VM executable to link into.
 * \param target Compilation target (currently unused here).
 * \param lib Optional host library; a stub C source module is used if absent.
 * \param ext_libs External modules to import.
 * \param params Constant parameters to make available at load time.
 * \return The linked runtime module wrapping the executable.
 */
ffi::Module VMLink(ExecBuilder builder, Target target, ffi::Optional<ffi::Module> lib,
                   ffi::Array<ffi::Module> ext_libs,
                   ffi::Map<ffi::String, runtime::Tensor> params) {
  ObjectPtr<VMExecutable> exec = builder->Get();
  // Fall back to an (almost) empty C source module when no host library was
  // produced, so the linking path below stays uniform.
  ffi::Module host_lib = lib.defined()
                             ? lib.value()
                             : codegen::CSourceModuleCreate(";", "c", ffi::Array<ffi::String>{});
  LinkModules(exec, params, host_lib, ext_libs);
  return ffi::Module(exec);
}
// Expose VMLink to the FFI under the global name "relax.VMLink".
TVM_FFI_STATIC_INIT_BLOCK() {
  tvm::ffi::reflection::GlobalDef().def("relax.VMLink", VMLink);
}
} // namespace codegen_vm
} // namespace relax
} // namespace tvm