blob: 4c128280a71433c8bb52917576a120f0f0d481ff [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/relax/backend/vm/exec_builder.cc
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/relax/exec_builder.h>

#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace tvm {
namespace relax {
using namespace vm;
// Register ExecBuilderNode's reflection metadata during static initialization.
TVM_FFI_STATIC_INIT_BLOCK() {
  ExecBuilderNode::RegisterReflection();
}
/*!
 * \brief Create a new ExecBuilder that owns a freshly allocated, empty VMExecutable.
 * \return The builder reference.
 */
ExecBuilder ExecBuilderNode::Create() {
  ffi::ObjectPtr<ExecBuilderNode> node = ffi::make_object<ExecBuilderNode>();
  node->exec_ = ffi::make_object<VMExecutable>();
  return ExecBuilder(node);
}
/*! \brief Return a non-owning raw pointer to the executable being built. */
VMExecutable* ExecBuilderNode::exec() const {
  VMExecutable* raw = exec_.get();
  return raw;
}
/*!
 * \brief Finalize and return the executable.
 *
 * Registers are compacted by Formalize() first, so that CheckExecutable() validates the
 * final register numbering.
 * \return The (shared) executable owned by this builder.
 */
ffi::ObjectPtr<VMExecutable> ExecBuilderNode::Get() {
  Formalize();
  CheckExecutable();
  return exec_;
}
/*!
 * \brief Record a memory scope string keyed by the value of \p idx.
 * \param idx Argument whose decoded value is used as the key in memory_scopes.
 * \param scope The memory scope to store (overwrites any existing entry for the key).
 */
void ExecBuilderNode::SaveMemoryScope(vm::Instruction::Arg idx, ffi::String scope) {
  const auto key = idx.value();
  exec_->memory_scopes[key] = scope;
}
/*!
 * \brief Turn a constant value into an instruction argument.
 *
 * Integers that fit within [kValueMinLimit, kValueMaxLimit] become immediates. Any other value
 * is placed in the executable's constant pool. It is deduplicated through const_dedup_map_,
 * which the original notes uses structural equality. The resulting pool index is returned.
 */
vm::Instruction::Arg ExecBuilderNode::ConvertConstant_(Any cvalue) {
  if (auto maybe_int = cvalue.as<int64_t>()) {
    const int64_t v = maybe_int.value();
    const bool fits = v >= vm::Instruction::kValueMinLimit && v <= vm::Instruction::kValueMaxLimit;
    // Small integers are encoded directly in the argument word; no pool slot needed.
    if (fits) return vm::Instruction::Arg::Immediate(v);
  }
  // Reuse an existing pool slot if an equal constant was registered before.
  auto found = const_dedup_map_.find(cvalue);
  if (found != const_dedup_map_.end()) return vm::Instruction::Arg::ConstIdx(found->second);
  // Otherwise append to the pool and remember the slot for future dedup.
  const vm::Index slot = exec_->constants.size();
  exec_->constants.push_back(cvalue);
  const_dedup_map_[cvalue] = slot;
  return vm::Instruction::Arg::ConstIdx(slot);
}
void ExecBuilderNode::DeclareFunction(const std::string& func_name, VMFuncInfo::FuncKind kind) {
auto it = exec_->func_map.find(func_name);
if (it != exec_->func_map.end()) {
TVM_FFI_ICHECK(kind == exec_->func_table[it->second].kind)
<< "Function " << func_name << "already declared in a different kind";
return;
}
VMFuncInfo vmfunc;
vmfunc.kind = kind;
vmfunc.name = func_name;
// use num args to mark undefined.
vmfunc.start_instr = 0;
vmfunc.num_args = -2;
vmfunc.register_file_size = 0;
exec_->func_map[func_name] = exec_->func_table.size();
exec_->func_table.push_back(vmfunc);
}
/*!
 * \brief Look up a declared function and return it as a function-index argument.
 * \param func_name Name of a previously declared function; a missing name is fatal.
 */
vm::Instruction::Arg ExecBuilderNode::GetFunction(const std::string& func_name) {
  const auto found = exec_->func_map.find(func_name);
  TVM_FFI_ICHECK(found != exec_->func_map.end()) << "Cannot find function " << func_name;
  const auto func_index = found->second;
  return vm::Instruction::Arg::FuncIdx(func_index);
}
void ExecBuilderNode::EmitFunction(const std::string& func_name, int64_t num_inputs,
ffi::Optional<ffi::Array<ffi::String>> param_names,
vm::VMFuncInfo::FuncKind kind, int64_t init_register_size) {
auto it = exec_->func_map.find(func_name);
if (it == exec_->func_map.end()) {
this->DeclareFunction(func_name, kind);
}
auto& vmfunc = exec_->func_table.at(exec_->func_map.at(func_name));
TVM_FFI_ICHECK_EQ(vmfunc.name, func_name);
TVM_FFI_ICHECK_EQ(vmfunc.num_args, -2) << "Function " << func_name << " already defined";
vmfunc.num_args = num_inputs;
if (param_names.defined()) {
TVM_FFI_ICHECK_EQ(num_inputs, param_names.value().size())
<< "Function " << func_name << " defined with " << num_inputs << " arguments, "
<< "but the list of parameter names has " << param_names.value().size() << " names ("
<< param_names << ")";
std::vector<std::string> names;
for (auto name : param_names.value()) {
names.push_back(name);
}
vmfunc.param_names = names;
}
vmfunc.register_file_size = init_register_size;
if (kind == vm::VMFuncInfo::FuncKind::kVMFunc) {
vmfunc.start_instr = exec_->instr_offset.size();
}
}
/*!
 * \brief Finish the definition of a function.
 *
 * For kVMFunc, records end_instr as the index one past the last emitted instruction. Other
 * kinds keep end_instr at 0.
 * \param func_name Name of a previously declared function.
 */
void ExecBuilderNode::EndFunction(const std::string& func_name) {
  auto it = exec_->func_map.find(func_name);
  TVM_FFI_ICHECK(it != exec_->func_map.end())
      << "EndFunction called on undeclared function " << func_name;
  VMFuncInfo& vmfunc = exec_->func_table.at(it->second);
  TVM_FFI_ICHECK_EQ(vmfunc.end_instr, 0)
      << "EndFunction can only be called once (function " << func_name << ")";
  if (vmfunc.kind == vm::VMFuncInfo::FuncKind::kVMFunc) {
    vmfunc.end_instr = exec_->instr_offset.size();
  }
}
/*!
 * \brief Emit a Call instruction.
 *
 * Word layout: [Opcode::Call, dst, func_idx, num_args, arg_0.data(), ..., arg_{n-1}.data()].
 * \param func Callee; must be a function-index argument.
 * \param args Encoded call arguments.
 * \param dst Destination register for the result.
 */
void ExecBuilderNode::EmitCall(vm::Instruction::Arg func, std::vector<vm::Instruction::Arg> args,
                               vm::RegName dst) {
  TVM_FFI_ICHECK(func.kind() == vm::Instruction::ArgKind::kFuncIdx);
  auto& words = exec_->instr_data;
  // Remember where this instruction begins before appending its words.
  exec_->instr_offset.push_back(words.size());
  words.push_back(static_cast<ExecWord>(Opcode::Call));
  words.push_back(dst);
  words.push_back(func.value());
  words.push_back(args.size());
  for (const vm::Instruction::Arg& a : args) words.push_back(a.data());
}
void ExecBuilderNode::EmitCall(const std::string& func, std::vector<Instruction::Arg> args,
RegName dst) {
auto it = exec_->func_map.find(func);
if (it == exec_->func_map.end()) {
this->DeclareFunction(func, VMFuncInfo::FuncKind::kPackedFunc);
}
Index func_idx = exec_->func_map.at(func);
EmitCall(vm::Instruction::Arg::FuncIdx(func_idx), args, dst);
}
/*!
 * \brief Emit a Ret instruction returning the value held in a register.
 * Word layout: [Opcode::Ret, register index].
 * \param result The register to return; must be a register argument.
 */
void ExecBuilderNode::EmitRet(vm::Instruction::Arg result) {
  TVM_FFI_ICHECK(result.kind() == vm::Instruction::ArgKind::kRegister);
  auto& words = exec_->instr_data;
  exec_->instr_offset.push_back(words.size());
  words.push_back(static_cast<ExecWord>(Opcode::Ret));
  words.push_back(result.value());
}
/*!
 * \brief Emit a Goto instruction.
 * Word layout: [Opcode::Goto, pc_offset].
 * \param pc_offset Jump offset (presumably relative to this instruction — CheckExecutable
 *        rejects 0).
 */
void ExecBuilderNode::EmitGoto(Index pc_offset) {
  auto& words = exec_->instr_data;
  exec_->instr_offset.push_back(words.size());
  words.push_back(static_cast<ExecWord>(Opcode::Goto));
  words.push_back(pc_offset);
}
/*!
 * \brief Emit a conditional branch.
 * Word layout: [Opcode::If, cond register, false_offset].
 * \param cond Condition register; must be a register argument.
 * \param false_offset Offset taken when the condition is false (presumably relative;
 *        CheckExecutable requires it to be > 1).
 */
void ExecBuilderNode::EmitIf(vm::Instruction::Arg cond, vm::Index false_offset) {
  TVM_FFI_ICHECK(cond.kind() == vm::Instruction::ArgKind::kRegister);
  auto& words = exec_->instr_data;
  exec_->instr_offset.push_back(words.size());
  words.push_back(static_cast<ExecWord>(Opcode::If));
  words.push_back(cond.value());
  words.push_back(false_offset);
}
void ExecBuilderNode::CheckExecutable() {
for (auto it = exec_->func_table.cbegin(); it != exec_->func_table.cend(); ++it) {
if (it->kind == VMFuncInfo::FuncKind::kPackedFunc) continue;
if (it->kind == VMFuncInfo::FuncKind::kVMTIRFunc) {
TVM_FFI_ICHECK_GE(it->register_file_size, it->num_args + 1)
<< "Function " << it->name << " do not meet register file constraint.";
continue;
}
Index num_inputs = it->num_args;
std::unordered_set<RegName> dst_registers;
std::unordered_set<RegName> arg_registers;
size_t start_instr = it->start_instr;
size_t end_instr = it->end_instr;
TVM_FFI_ICHECK_LT(start_instr, end_instr)
<< "Function " << it->name << " EndFunction has not be been called";
auto check_reg_defined = [&](Instruction::Arg arg) {
if (arg.kind() != Instruction::ArgKind::kRegister) return;
if (arg.value() >= Instruction::kBeginSpecialReg) return;
if (arg.value() < num_inputs) return;
if (dst_registers.find(arg.value()) == dst_registers.end()) {
TVM_FFI_THROW(InternalError) << "register r(" << arg.value() << ") in VM function \""
<< it->name << "\" is used as input while it is never defined"
<< " as a destination. Dump:\n"
<< exec_->AsText();
}
};
auto check_const_defined = [&](Instruction::Arg arg) {
if (arg.kind() != Instruction::ArgKind::kConstIdx) return;
TVM_FFI_ICHECK_LT(arg.value(), exec_->constants.size())
<< "Constant index " << arg.value() << " exceed size of constant pool. Dump:\n"
<< exec_->AsText();
};
auto check_func_defined = [&](Instruction::Arg arg) {
if (arg.kind() != Instruction::ArgKind::kFuncIdx) return;
TVM_FFI_ICHECK_LT(arg.value(), exec_->func_table.size())
<< "Func index " << arg.value() << " exceed size of fun_table. Dump:\n"
<< exec_->AsText();
};
for (size_t idx = start_instr; idx < end_instr; ++idx) {
Instruction instr = exec_->GetInstruction(idx);
switch (instr.op) {
case Opcode::Call: {
check_func_defined(Instruction::Arg::FuncIdx(instr.func_idx));
for (int i = 0; i < instr.num_args; ++i) {
check_reg_defined(instr.args[i]);
check_const_defined(instr.args[i]);
check_func_defined(instr.args[i]);
arg_registers.emplace(instr.args[i].value());
}
if (instr.dst != Instruction::kVoidRegister) {
dst_registers.emplace(instr.dst);
}
break;
}
case Opcode::Ret: {
arg_registers.emplace(instr.result);
check_reg_defined(Instruction::Arg::Register(instr.result));
break;
}
case Opcode::Goto: {
TVM_FFI_ICHECK_NE(instr.pc_offset, 0);
break;
}
case Opcode::If: {
TVM_FFI_ICHECK_GT(instr.false_offset, 1);
check_reg_defined(Instruction::Arg::Register(instr.cond));
arg_registers.emplace(instr.cond);
break;
}
default:
TVM_FFI_THROW(InternalError)
<< "should never hit this case: " << static_cast<int>(instr.op);
break;
}
}
}
}
void ExecBuilderNode::Formalize() {
// a pass to formalize user-specified register indexes in the order of use
// and decide the number of registers to allocate for each VMFunction in the VMExecutable
for (auto it = this->exec_->func_table.begin(); it != this->exec_->func_table.end(); ++it) {
if (it->kind == VMFuncInfo::FuncKind::kPackedFunc) continue;
if (it->kind == VMFuncInfo::FuncKind::kVMTIRFunc) continue;
Index num_inputs = it->num_args;
RegName register_idx = num_inputs;
std::unordered_map<RegName, RegName> register_map;
size_t start_instr = it->start_instr;
size_t end_instr = it->end_instr;
for (size_t idx = start_instr; idx < end_instr; ++idx) {
Instruction instr = this->exec_->GetInstruction(idx);
switch (instr.op) {
case Opcode::Call: {
// rewrite args
for (int i = 0; i < instr.num_args; ++i) {
if (instr.args[i].kind() == Instruction::ArgKind::kRegister &&
instr.args[i].value() >= num_inputs &&
instr.args[i].value() < Instruction::kBeginSpecialReg &&
register_map.find(instr.args[i].value()) != register_map.end()) {
this->exec_->instr_data[this->exec_->instr_offset[idx] + 4 + i] =
register_map[instr.args[i].value()];
}
}
if (instr.dst >= num_inputs && instr.dst < Instruction::kBeginSpecialReg) {
auto it = register_map.find(instr.dst);
if (it != register_map.end()) {
this->exec_->instr_data[this->exec_->instr_offset[idx] + 1] = it->second;
} else {
this->exec_->instr_data[this->exec_->instr_offset[idx] + 1] = register_idx;
register_map[instr.dst] = register_idx++;
}
}
break;
}
case Opcode::Ret: {
if (register_map.find(instr.result) != register_map.end()) {
this->exec_->instr_data[this->exec_->instr_offset[idx] + 1] =
register_map[instr.result];
}
break;
}
case Opcode::Goto: {
break;
}
case Opcode::If: {
if (register_map.find(instr.cond) != register_map.end()) {
this->exec_->instr_data[this->exec_->instr_offset[idx] + 1] = register_map[instr.cond];
}
break;
}
default:
TVM_FFI_THROW(InternalError)
<< "should never hit this case: " << static_cast<int>(instr.op);
break;
}
}
it->register_file_size = register_idx;
}
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("relax.ExecBuilderCreate", ExecBuilderNode::Create)
.def_packed("relax.ExecBuilderConvertConstant",
[](ffi::PackedArgs args, ffi::Any* ret) {
ExecBuilder builder = args[0].cast<ExecBuilder>();
ffi::Any rt;
rt = args[1];
*ret = builder->ConvertConstant(rt).data();
})
.def("relax.ExecBuilderEmitFunction",
[](ExecBuilder builder, ffi::String func, int64_t num_inputs,
ffi::Optional<ffi::Array<ffi::String>> param_names) {
builder->EmitFunction(func, num_inputs, param_names);
})
.def_method("relax.ExecBuilderEndFunction", &ExecBuilderNode::EndFunction)
.def("relax.ExecBuilderDeclareFunction",
[](ExecBuilder builder, ffi::String name, int32_t kind) {
builder->DeclareFunction(name, static_cast<VMFuncInfo::FuncKind>(kind));
})
.def("relax.ExecBuilderEmitCall",
[](ExecBuilder builder, ffi::String name, ffi::Array<IntImm> args, int64_t dst) {
std::vector<Instruction::Arg> args_;
for (size_t i = 0; i < args.size(); ++i) {
args_.push_back(Instruction::Arg::FromData(args[i]->value));
}
auto dst_ = Instruction::Arg::Register(dst);
builder->EmitCall(name, args_, dst_.value());
})
.def("relax.ExecBuilderEmitRet",
[](ExecBuilder builder, int64_t data) {
builder->EmitRet(Instruction::Arg::FromData(data));
})
.def_method("relax.ExecBuilderEmitGoto", &ExecBuilderNode::EmitGoto)
.def("relax.ExecBuilderEmitIf",
[](ExecBuilder builder, int64_t data, vm::Index false_offset) {
builder->EmitIf(Instruction::Arg::FromData(data), false_offset);
})
.def("relax.ExecBuilderR",
[](ExecBuilder builder, int64_t value) {
return Instruction::Arg::Register(value).data();
})
.def("relax.ExecBuilderImm",
[](ExecBuilder builder, int64_t value) {
return Instruction::Arg::Immediate(value).data();
})
.def("relax.ExecBuilderC",
[](ExecBuilder builder, int64_t value) {
return Instruction::Arg::ConstIdx(value).data();
})
.def(
"relax.ExecBuilderF",
[](ExecBuilder builder, ffi::String value) { return builder->GetFunction(value).data(); })
.def("relax.ExecBuilderGet", [](ExecBuilder builder) {
ffi::ObjectPtr<VMExecutable> p_exec = builder->Get();
return ffi::Module(p_exec);
});
}
} // namespace relax
} // namespace tvm