blob: 906d79b176ce3b6b32f80608a76fc1a8b9ae64c8 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file codegen_cpu.cc
*/
#ifdef TVM_LLVM_VERSION
#include "codegen_cpu.h"
#include <llvm/ADT/SmallVector.h>
#include <llvm/ADT/StringRef.h>
#include <llvm/IR/Argument.h>
#include <llvm/IR/Attributes.h>
#include <llvm/IR/BasicBlock.h>
#include <llvm/IR/CallingConv.h>
#include <llvm/IR/Comdat.h>
#include <llvm/IR/Constants.h>
#include <llvm/IR/DIBuilder.h>
#include <llvm/IR/DebugInfoMetadata.h>
#include <llvm/IR/DebugLoc.h>
#include <llvm/IR/DerivedTypes.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/GlobalVariable.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/MDBuilder.h>
#include <llvm/IR/Metadata.h>
#include <llvm/IR/Module.h>
#include <tvm/ffi/reflection/registry.h>
#if TVM_LLVM_VERSION >= 100
#include <llvm/Support/Alignment.h>
#endif
#include <llvm/Support/raw_ostream.h>
#include <llvm/Target/TargetMachine.h>
#include <llvm/Transforms/Utils/ModuleUtils.h>
#include <tvm/runtime/base.h>
#include <tvm/runtime/module.h>
#include <tvm/tir/analysis.h>
#include <algorithm>
#include <memory>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include "llvm_instance.h"
namespace tvm {
namespace codegen {
// Make these non-inline because of std::unique_ptr. See comment in
// codegen_llvm.cc for more information.
CodeGenCPU::CodeGenCPU() = default;
CodeGenCPU::~CodeGenCPU() = default;
void CodeGenCPU::Init(const std::string& module_name, LLVMTarget* llvm_target,
ffi::Optional<ffi::String> system_lib_prefix, bool dynamic_lookup,
bool target_c_runtime) {
CodeGenLLVM::Init(module_name, llvm_target, system_lib_prefix, dynamic_lookup, target_c_runtime);
system_lib_prefix_ = system_lib_prefix;
dbg_info_ = CreateDebugInfo(module_.get());
func_handle_map_.clear();
export_system_symbols_.clear();
// Runtime types.
t_tvm_shape_index_ =
llvm::Type::getIntNTy(*llvm_target_->GetContext(), DataType::ShapeIndex().bits());
// Defined in 3rdparty/dlpack/include/dlpack/dlpack.h:
// typedef struct { DLDeviceType device_type; int device_id; } DLDevice;
t_tvm_device_ = llvm::StructType::create({t_int_, t_int_});
// Defined in 3rdparty/dlpack/include/dlpack/dlpack.h:
// typedef struct { uint8_t code; uint8_t bits; uint16_t lanes; } DLDataType;
t_tvm_type_ = llvm::StructType::create({t_int8_, t_int8_, t_int16_});
// Defined in include/tvm/runtime/base.h:
// typedef void* TVMFunctionHandle;
t_tvm_func_handle_ = t_void_p_;
// Defined in 3rdparty/dlpack/include/dlpack/dlpack.h:
// typedef struct { ... } DLTensor;
t_tvm_array_ = llvm::StructType::create({t_void_p_, t_tvm_device_, t_int_, t_tvm_type_,
llvmGetPointerTo(t_tvm_shape_index_, 0),
llvmGetPointerTo(t_tvm_shape_index_, 0), t_int64_});
// Defined in include/tvm/ffi/c_api.h:
t_tvm_ffi_any_ = llvm::StructType::create({t_int32_, t_int32_, t_float64_});
// Defined in include/tvm/runtime/c_backend_api.h:
// typedef struct { void* sync_handle; int32_t num_task; } TVMParallelGroupEnv;
t_tvm_parallel_group_env_ = llvm::StructType::create({llvmGetPointerTo(t_int32_, 0), t_int32_});
// Defined in include/tvm/ffi/c_api.h:
// typedef int (*)(void* self, const TVMFFIAny* args, int32_t num_args,
// TVMFFIAny* result);
ftype_tvm_ffi_c_func_ = llvm::FunctionType::get(
t_int_,
{t_void_p_, llvmGetPointerTo(t_tvm_ffi_any_, 0), t_int_, llvmGetPointerTo(t_tvm_ffi_any_, 0)},
false);
// Defined in include/tvm/runtime/c_backend_api.h:
// typedef int (*FTVMParallelLambda)(int task_id, TVMParallelGroupEnv* penv, void* cdata);
ftype_tvm_parallel_lambda_ = llvm::FunctionType::get(
t_int_, {t_int_, llvmGetPointerTo(t_tvm_parallel_group_env_, 0), t_void_p_}, false);
md_tbaa_ctx_ptr_ = md_builder_->createTBAAScalarTypeNode("ctx_ptr", md_tbaa_root_);
// Runtime functions.
// Defined in include/tvm/ffi/c_api.h:
// int TVMFFIFunctionCall(TVMFunctionHandle func, TVMFFIAny* args, int32_t num_args,
// TVMFFIAny* result);
ftype_tvm_ffi_func_call_ = ftype_tvm_ffi_c_func_;
// Defined in include/tvm/ffi/c_api.h:
// void TVMFFIErrorSetRaisedFromCStr(const char *kind, const char* msg);
ftype_tvm_ffi_error_set_raised_by_c_str_ = llvm::FunctionType::get(
t_void_, {llvmGetPointerTo(t_char_, 0), llvmGetPointerTo(t_char_, 0)}, false);
// Defined in include/tvm/runtime/c_backend_api.h:
// int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, TVMFunctionHandle* out);
ftype_tvm_get_func_from_env_ = llvm::FunctionType::get(
t_int_, {t_void_p_, llvmGetPointerTo(t_char_, 0), llvmGetPointerTo(t_tvm_func_handle_, 0)},
false);
// Defined in include/tvm/runtime/c_backend_api.h:
// int TVMBackendParallelLaunch(FTVMParallelLambda flambda, void* cdata, int num_task);
ftype_tvm_parallel_launch_ = llvm::FunctionType::get(
t_int_, {llvmGetPointerTo(ftype_tvm_parallel_lambda_, 0), t_void_p_, t_int_}, false);
// Defined in include/tvm/runtime/c_backend_api.h:
// int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv);
ftype_tvm_parallel_barrier_ = llvm::FunctionType::get(
t_int_, {t_int_, llvmGetPointerTo(t_tvm_parallel_group_env_, 0)}, false);
ftype_tvm_static_init_callback_ = llvm::FunctionType::get(t_int_, {t_void_p_}, false);
ftype_tvm_static_init_ = llvm::FunctionType::get(
t_int_,
{llvmGetPointerTo(t_void_p_, 0), llvmGetPointerTo(ftype_tvm_static_init_callback_, 0),
t_void_p_, t_int_},
false);
// initialize TVM runtime API
if (system_lib_prefix_.has_value() && !target_c_runtime) {
// We will need this in environment for backward registration.
// Defined in include/tvm/runtime/c_backend_api.h:
// int TVMFFIEnvModRegisterSystemLibSymbol(const char* name, void* ptr);
f_tvm_register_system_symbol_ = llvm::Function::Create(
llvm::FunctionType::get(t_int_, {llvmGetPointerTo(t_char_, 0), t_void_p_}, false),
llvm::Function::ExternalLinkage, "TVMFFIEnvModRegisterSystemLibSymbol", module_.get());
} else {
f_tvm_register_system_symbol_ = nullptr;
}
if (dynamic_lookup || system_lib_prefix_.has_value()) {
f_tvm_ffi_func_call_ =
llvm::Function::Create(ftype_tvm_ffi_func_call_, llvm::Function::ExternalLinkage,
"TVMFFIFunctionCall", module_.get());
f_tvm_ffi_set_raised_by_c_str_ = llvm::Function::Create(
ftype_tvm_ffi_error_set_raised_by_c_str_, llvm::Function::ExternalLinkage,
"TVMFFIErrorSetRaisedFromCStr", module_.get());
f_tvm_get_func_from_env_ =
llvm::Function::Create(ftype_tvm_get_func_from_env_, llvm::Function::ExternalLinkage,
"TVMBackendGetFuncFromEnv", module_.get());
f_tvm_parallel_launch_ =
llvm::Function::Create(ftype_tvm_parallel_launch_, llvm::Function::ExternalLinkage,
"TVMBackendParallelLaunch", module_.get());
f_tvm_parallel_barrier_ =
llvm::Function::Create(ftype_tvm_parallel_barrier_, llvm::Function::ExternalLinkage,
"TVMBackendParallelBarrier", module_.get());
}
target_c_runtime_ = target_c_runtime;
InitGlobalContext(dynamic_lookup);
}
llvm::DISubprogram* CodeGenCPU::CreateDebugFunction(llvm::StringRef name,
const ffi::Array<Type>& param_types,
const Type& return_type) {
#if TVM_LLVM_VERSION < 50
return nullptr;
#else
llvm::SmallVector<llvm::Metadata*, 4> paramTys;
paramTys.push_back(GetDebugType(return_type));
for (const auto& param_type : param_types) {
paramTys.push_back(GetDebugType(param_type));
}
auto* DIFunctionTy = dbg_info_->di_builder_->createSubroutineType(
dbg_info_->di_builder_->getOrCreateTypeArray(paramTys));
bool local_to_unit = llvm::GlobalVariable::isLocalLinkage(llvm::GlobalValue::InternalLinkage);
#if TVM_LLVM_VERSION >= 80
auto SPFlags = llvm::DISubprogram::toSPFlags(local_to_unit, /*IsDefinition=*/true,
/*IsOptimized=*/true);
#else
bool SPFlags = /*IsOptimized=*/true;
#endif
auto* DIFunction = dbg_info_->di_builder_->createFunction(
/*Scope=*/dbg_info_->file_, /*Name=*/name, /*LinkageName=*/"",
/*File=*/dbg_info_->file_, /*LineNo=*/0, /*Ty=*/DIFunctionTy,
/*ScopeLine=*/0, /*Flags=*/llvm::DINode::FlagPrototyped, /*SPFlags=*/SPFlags);
return DIFunction;
#endif
}
llvm::DISubprogram* CodeGenCPU::CreateDebugFunction(const GlobalVar& gvar, const PrimFunc& func) {
std::string name = func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol).value_or(gvar->name_hint);
return CreateDebugFunction(name, func->params.Map(GetType), func->ret_type);
}
void CodeGenCPU::AddFunction(const GlobalVar& gvar, const PrimFunc& func) {
di_subprogram_ = CreateDebugFunction(gvar, func);
EmitDebugLocation(func->span);
CodeGenLLVM::AddFunction(gvar, func);
if (f_tvm_register_system_symbol_ != nullptr) {
if (auto global_symbol = func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol)) {
export_system_symbols_.emplace_back(
std::make_pair(global_symbol.value().operator std::string(), function_));
}
}
AddDebugInformation(function_, func->params.Map(GetType));
}
void CodeGenCPU::AddMainFunction(const std::string& entry_func_name) {
if (module_->getFunction(ffi::symbol::tvm_ffi_main) != nullptr) {
// main already exists, no need to create a wrapper function
// main takes precedence over other entry functions
return;
}
// create a wrapper function with tvm_ffi_main name and redirects to the entry function
llvm::Function* target_func = module_->getFunction(entry_func_name);
TVM_FFI_ICHECK(target_func) << "Function " << entry_func_name << " does not exist in module";
// Create wrapper function
llvm::Function* wrapper_func =
llvm::Function::Create(target_func->getFunctionType(), llvm::Function::WeakAnyLinkage,
ffi::symbol::tvm_ffi_main, module_.get());
// Set attributes (Windows comdat, DLL export, etc.)
if (llvm_target_->GetOrCreateTargetMachine()->getTargetTriple().isOSWindows()) {
llvm::Comdat* comdat = module_->getOrInsertComdat(ffi::symbol::tvm_ffi_main);
comdat->setSelectionKind(llvm::Comdat::Any);
wrapper_func->setComdat(comdat);
}
wrapper_func->setCallingConv(llvm::CallingConv::C);
wrapper_func->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
// Create simple tail call
llvm::BasicBlock* entry =
llvm::BasicBlock::Create(*llvm_target_->GetContext(), "entry", wrapper_func);
builder_->SetInsertPoint(entry);
// Forward all arguments to target function
std::vector<llvm::Value*> call_args;
for (llvm::Value& arg : wrapper_func->args()) {
call_args.push_back(&arg);
}
llvm::Value* result = builder_->CreateCall(target_func, call_args);
if (target_func->getReturnType()->isVoidTy()) {
builder_->CreateRetVoid();
} else {
builder_->CreateRet(result);
}
}
std::unique_ptr<llvm::Module> CodeGenCPU::Finish() {
// link modules
if (dbg_info_ != nullptr) {
dbg_info_->di_builder_->finalize();
}
return CodeGenLLVM::Finish();
}
CodeGenLLVM::TypedPointer CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf,
llvm::Value* index, int kind) {
if (kind < builtin::kArrKindBound_) {
if (buf->getType() == t_void_p_) {
buf = builder_->CreatePointerCast(buf, llvmGetPointerTo(t_tvm_array_, 0));
} else {
TVM_FFI_ICHECK_EQ(buf->getType(), llvmGetPointerTo(t_tvm_array_, 0));
}
}
switch (kind) {
case builtin::kArrAddr: {
return TypedPointer(t_tvm_array_, builder_->CreateInBoundsGEP(t_tvm_array_, buf, index));
}
case builtin::kArrData: {
llvm::Type* member_type = t_tvm_array_->getStructElementType(0);
llvm::Value* member_addr =
builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(0)});
return TypedPointer(member_type, member_addr);
}
case builtin::kArrShape: {
llvm::Type* member_type = t_tvm_array_->getStructElementType(4);
llvm::Value* member_addr =
builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(4)});
return TypedPointer(member_type, member_addr);
}
case builtin::kArrStrides: {
llvm::Type* member_type = t_tvm_array_->getStructElementType(5);
llvm::Value* member_addr =
builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(5)});
return TypedPointer(member_type, member_addr);
}
case builtin::kArrNDim: {
llvm::Type* member_type = t_tvm_array_->getStructElementType(2);
llvm::Value* member_addr =
builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(2)});
return TypedPointer(member_type, member_addr);
}
case builtin::kArrTypeCode: {
llvm::Type* member_type = t_tvm_array_->getStructElementType(3)->getStructElementType(0);
llvm::Value* member_addr =
builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(3), ConstInt32(0)});
return TypedPointer(member_type, member_addr);
}
case builtin::kArrTypeBits: {
llvm::Type* member_type = t_tvm_array_->getStructElementType(3)->getStructElementType(1);
llvm::Value* member_addr =
builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(3), ConstInt32(1)});
return TypedPointer(member_type, member_addr);
}
case builtin::kArrTypeLanes: {
llvm::Type* member_type = t_tvm_array_->getStructElementType(3)->getStructElementType(2);
llvm::Value* member_addr =
builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(3), ConstInt32(2)});
return TypedPointer(member_type, member_addr);
}
case builtin::kArrByteOffset: {
llvm::Type* member_type = t_tvm_array_->getStructElementType(6);
llvm::Value* member_addr =
builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(6)});
return TypedPointer(member_type, member_addr);
}
case builtin::kArrDeviceId: {
llvm::Type* member_type = t_tvm_array_->getStructElementType(1)->getStructElementType(1);
llvm::Value* member_addr =
builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(1), ConstInt32(1)});
return TypedPointer(member_type, member_addr);
}
case builtin::kArrDeviceType: {
llvm::Type* member_type = t_tvm_array_->getStructElementType(1)->getStructElementType(0);
llvm::Value* member_addr =
builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(1), ConstInt32(0)});
return TypedPointer(member_type, member_addr);
}
case builtin::kTVMFFIAnyTypeIndex: {
buf = builder_->CreatePointerCast(buf, llvmGetPointerTo(t_tvm_ffi_any_, 0));
buf = builder_->CreateInBoundsGEP(t_tvm_ffi_any_, buf, {index, ConstInt32(0)});
return TypedPointer(t_int32_, buf);
}
case builtin::kTVMFFIAnyZeroPadding: {
buf = builder_->CreatePointerCast(buf, llvmGetPointerTo(t_tvm_ffi_any_, 0));
buf = builder_->CreateInBoundsGEP(t_tvm_ffi_any_, buf, {index, ConstInt32(1)});
return TypedPointer(t_int32_, buf);
}
case builtin::kTVMFFIAnyUnionValue: {
TVM_FFI_ICHECK_EQ(t.lanes(), 1);
buf = builder_->CreatePointerCast(buf, llvmGetPointerTo(t_tvm_ffi_any_, 0));
// field 2 is the union value
buf = builder_->CreateInBoundsGEP(t_tvm_ffi_any_, buf, {index, ConstInt32(2)});
if (t.is_bool()) {
// it should be safe to set the pointer to the first byte of the union value
buf = builder_->CreatePointerCast(buf, llvmGetPointerTo(DTypeToLLVMType(t), 0));
return TypedPointer(t_int8_, buf);
} else if (t.is_int() && t.bits() == 64) {
buf = builder_->CreatePointerCast(buf, llvmGetPointerTo(t_int64_, 0));
return TypedPointer(t_int64_, buf);
} else if (t.is_float() && t.bits() == 64) {
buf = builder_->CreatePointerCast(buf, llvmGetPointerTo(t_float64_, 0));
return TypedPointer(t_float64_, buf);
} else if (t.is_handle()) {
buf = builder_->CreatePointerCast(buf, llvmGetPointerTo(t_void_p_, 0));
return TypedPointer(t_void_p_, buf);
} else {
LOG(DEBUG) << "DataType " << t << " cannot be stored into a TVMFFIAny's value field";
}
}
default:
TVM_FFI_THROW(InternalError) << "unknown field code";
}
}
llvm::Value* CodeGenCPU::CreateCallExtern(Type ret_type, ffi::String global_symbol,
const ffi::Array<PrimExpr>& args, bool skip_first_arg) {
std::vector<llvm::Value*> arg_values;
for (size_t i = static_cast<size_t>(skip_first_arg); i < args.size(); ++i) {
arg_values.push_back(MakeValue(args[i]));
}
std::vector<llvm::Type*> arg_types;
for (llvm::Value* v : arg_values) {
arg_types.push_back(v->getType());
}
llvm::FunctionType* ftype = llvm::FunctionType::get(GetLLVMType(ret_type), arg_types, false);
// Check if it is available in global function table as injected function.
auto callee = [&]() -> llvm::Value* {
if (auto it = gv_func_map_.find(global_symbol); it != gv_func_map_.end()) {
if (it->second == nullptr) {
it->second = InitContextPtr(llvmGetPointerTo(ftype, 0), "__" + global_symbol);
}
return GetContextPtr(it->second);
} else if (llvm::Function* f = module_->getFunction(MakeStringRef(global_symbol))) {
return f;
} else {
return llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
MakeStringRef(global_symbol), module_.get());
}
}();
if (callee->getType() != llvmGetPointerTo(ftype, 0)) {
callee = builder_->CreatePointerCast(callee, llvmGetPointerTo(ftype, 0));
}
return builder_->CreateCall(ftype, callee, arg_values);
}
llvm::GlobalVariable* CodeGenCPU::InitContextPtr(llvm::Type* p_type, std::string name) {
llvm::GlobalVariable* gv = new llvm::GlobalVariable(
*module_, p_type, false, llvm::GlobalValue::LinkOnceAnyLinkage, nullptr, name);
#if TVM_LLVM_VERSION >= 100
gv->setAlignment(llvm::Align(data_layout_->getTypeAllocSize(p_type)));
#else
gv->setAlignment(data_layout_->getTypeAllocSize(p_type));
#endif
gv->setInitializer(llvm::Constant::getNullValue(p_type));
gv->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
// comdat is needed for windows select any linking to work
// set comdat to Any(weak linking)
if (llvm_target_->GetOrCreateTargetMachine()->getTargetTriple().isOSWindows()) {
llvm::Comdat* comdat = module_->getOrInsertComdat(name);
comdat->setSelectionKind(llvm::Comdat::Any);
gv->setComdat(comdat);
}
return gv;
}
llvm::Value* CodeGenCPU::GetContextPtr(llvm::GlobalVariable* gv) {
TVM_FFI_ICHECK(gv != nullptr);
#if TVM_LLVM_VERSION >= 110
llvm::LoadInst* faddr =
builder_->CreateAlignedLoad(gv->getValueType(), gv, llvm::Align(gv->getAlignment()));
#elif TVM_LLVM_VERSION >= 80
llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv->getValueType(), gv, gv->getAlignment());
#else
llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv, gv->getAlignment());
#endif
faddr->setMetadata("tbaa",
md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0));
return faddr;
}
void CodeGenCPU::InitGlobalContext(bool dynamic_lookup) {
std::string ctx_symbol = system_lib_prefix_.value_or("") + ffi::symbol::tvm_ffi_library_ctx;
// Module context
gv_mod_ctx_ = InitContextPtr(t_void_p_, ctx_symbol);
// Register back the locations.
if (f_tvm_register_system_symbol_ != nullptr && !target_c_runtime_) {
export_system_symbols_.emplace_back(std::make_pair(ctx_symbol, gv_mod_ctx_));
} else {
if (!dynamic_lookup) {
gv_tvm_ffi_func_call_ =
InitContextPtr(llvmGetPointerTo(ftype_tvm_ffi_func_call_, 0), "__TVMFFIFunctionCall");
gv_tvm_get_func_from_env_ = InitContextPtr(llvmGetPointerTo(ftype_tvm_get_func_from_env_, 0),
"__TVMBackendGetFuncFromEnv");
gv_tvm_ffi_set_last_error_c_str_ =
InitContextPtr(llvmGetPointerTo(ftype_tvm_ffi_error_set_raised_by_c_str_, 0),
"__TVMFFIErrorSetRaisedFromCStr");
gv_tvm_parallel_launch_ = InitContextPtr(llvmGetPointerTo(ftype_tvm_parallel_launch_, 0),
"__TVMBackendParallelLaunch");
gv_tvm_parallel_barrier_ = InitContextPtr(llvmGetPointerTo(ftype_tvm_parallel_barrier_, 0),
"__TVMBackendParallelBarrier");
// Mark as context functions
gv_func_map_["TVMBackendAllocWorkspace"] = nullptr;
gv_func_map_["TVMBackendFreeWorkspace"] = nullptr;
}
}
}
llvm::BasicBlock* CodeGenCPU::CheckCallSuccess(llvm::Value* retcode) {
// create emit codes that checks and load the function.
llvm::LLVMContext* ctx = llvm_target_->GetContext();
auto* fail_block = llvm::BasicBlock::Create(*ctx, "call_fail", function_);
auto* end_block = llvm::BasicBlock::Create(*ctx, "call_end", function_);
auto* succ = builder_->CreateICmpEQ(retcode, llvm::ConstantInt::get(t_int_, 0));
builder_->CreateCondBr(succ, end_block, fail_block, md_very_likely_branch_);
builder_->SetInsertPoint(fail_block);
// return the code.
builder_->CreateRet(retcode);
// otherwise set it to be new end.
builder_->SetInsertPoint(end_block);
return end_block;
}
void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) {
EmitDebugLocation(op);
/*! \brief maintain states that should be guarded when step into compute scope */
struct ComputeScopeStates {
explicit ComputeScopeStates(CodeGenCPU* parent) : parent_(parent) {}
void EnterWithScope() {
std::swap(function_, parent_->function_);
std::swap(analyzer_, parent_->analyzer_);
std::swap(var_map_, parent_->var_map_);
std::swap(di_subprogram_, parent_->di_subprogram_);
std::swap(loop_frame_jump_tgts_, parent_->loop_frame_jump_tgts_);
}
void ExitWithScope() {
std::swap(function_, parent_->function_);
std::swap(analyzer_, parent_->analyzer_);
std::swap(var_map_, parent_->var_map_);
std::swap(di_subprogram_, parent_->di_subprogram_);
std::swap(loop_frame_jump_tgts_, parent_->loop_frame_jump_tgts_);
}
llvm::Function* function_{nullptr};
llvm::DISubprogram* di_subprogram_{nullptr};
std::unordered_map<const VarNode*, llvm::Value*> var_map_;
std::vector<std::pair<llvm::BasicBlock*, llvm::BasicBlock*>> loop_frame_jump_tgts_;
std::unique_ptr<arith::Analyzer> analyzer_{std::make_unique<arith::Analyzer>()};
CodeGenCPU* parent_;
};
// There are two reasons why we create another function for compute_scope
// - Make sure the generated compute function is clearly separately(though it can get inlined)
// - Set noalias on all the pointer arguments, some of them are loaded from ffi::PackedArgs.
// This is easier than set the alias scope manually.
ffi::Array<Var> vargs = tir::UndefinedVars(op->body, {});
std::vector<llvm::Value*> arg_values;
std::vector<llvm::Type*> arg_types;
for (Var v : vargs) {
llvm::Value* value = MakeValue(v);
value->setName(v->name_hint.c_str());
arg_values.push_back(value);
arg_types.push_back(value->getType());
}
llvm::FunctionType* ftype = llvm::FunctionType::get(t_int_, arg_types, false);
// $xxx_compute_ functions are not global. They should be marked as static (via InternalLinkage)
// to call them correctly on MIPS platform (CALL16 issue)
// Linkage ld Error: CALL16 reloc at 0x290 not against global symbol
const StringImmNode* value = op->value.as<StringImmNode>();
TVM_FFI_ICHECK(value != nullptr);
llvm::Function* fcompute = llvm::Function::Create(ftype, llvm::Function::InternalLinkage,
MakeStringRef(value->value), module_.get());
SetTargetAttributes(fcompute);
for (auto it = fcompute->arg_begin(); it != fcompute->arg_end(); it++) {
const Var& var = vargs[std::distance(fcompute->arg_begin(), it)];
it->setName(std::string(var->name_hint));
}
llvm::BasicBlock* compute_call_end = CheckCallSuccess(builder_->CreateCall(fcompute, arg_values));
llvm::LLVMContext* ctx = llvm_target_->GetContext();
// enter compute scope and setup compute function.
With<ComputeScopeStates> scope_states_guard(this);
size_t idx = 0;
for (auto it = fcompute->arg_begin(); it != fcompute->arg_end(); ++it, ++idx) {
llvm::Argument* v = &(*it);
const Var& var = vargs[idx];
var_map_[var.get()] = v;
if (var.dtype().is_handle() && !alias_var_set_.count(var.get())) {
// set non alias.
#if TVM_LLVM_VERSION >= 50
fcompute->addParamAttr(idx, llvm::Attribute::NoAlias);
// always not inline compute function to make the code structure clean
#else
fcompute->setDoesNotAlias(idx + 1);
#endif
fcompute->addFnAttr(llvm::Attribute::NoInline);
}
// Add alignment attribute if needed.
#if TVM_LLVM_VERSION >= 50
auto f = alloc_storage_info_.find(var.get());
if (f != alloc_storage_info_.end()) {
unsigned align = f->second.alignment;
if (align > 1) {
auto attr = llvm::Attribute::get(*ctx, llvm::Attribute::Alignment, align);
fcompute->addParamAttr(idx, attr);
}
}
#endif
}
function_ = fcompute;
di_subprogram_ = CreateDebugFunction(MakeStringRef(value->value), vargs.Map(GetType),
PrimType(DataType::Int(32)));
auto* compute_entry = llvm::BasicBlock::Create(*ctx, "entry", function_);
builder_->SetInsertPoint(compute_entry);
this->VisitStmt(op->body);
builder_->CreateRet(ConstInt32(0));
builder_->SetInsertPoint(compute_call_end);
AddDebugInformation(fcompute, vargs.Map(GetType));
}
CodeGenLLVM::TypedPointer CodeGenCPU::PackClosureData(const ffi::Array<Var>& vfields,
uint64_t* num_bytes,
std::string struct_name) {
if (vfields.size() == 0) {
*num_bytes = 0U;
return TypedPointer(t_void_p_, llvm::Constant::getNullValue(t_void_p_));
}
std::vector<llvm::Type*> fields;
for (Var v : vfields) {
auto it = var_map_.find(v.get());
TVM_FFI_ICHECK(it != var_map_.end());
fields.push_back(it->second->getType());
}
llvm::StructType* ctype = struct_name.size() ? llvm::StructType::create(fields, struct_name)
: llvm::StructType::create(fields);
llvm::AllocaInst* cvalue =
WithFunctionEntry([&]() { return builder_->CreateAlloca(ctype, ConstInt32(1)); });
llvm::Value* zero = ConstInt32(0);
for (size_t i = 0; i < vfields.size(); ++i) {
builder_->CreateStore(var_map_.at(vfields[i].get()),
builder_->CreateInBoundsGEP(ctype, cvalue, {zero, ConstInt32(i)}));
}
*num_bytes = data_layout_->getTypeAllocSize(ctype);
return TypedPointer(ctype, cvalue);
}
void CodeGenCPU::UnpackClosureData(TypedPointer cdata, const ffi::Array<Var>& vfields,
std::unordered_map<const VarNode*, llvm::Value*>* vmap) {
for (size_t i = 0; i < vfields.size(); ++i) {
llvm::Type* field_type = cdata.type->getStructElementType(i);
llvm::Value* field_addr =
builder_->CreateInBoundsGEP(cdata.type, cdata.addr, {ConstInt32(0), ConstInt32(i)});
llvm::Value* load =
builder_->CreateLoad(field_type, field_addr, std::string(vfields[i]->name_hint));
(*vmap)[vfields[i].get()] = load;
}
}
void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task, std::string name) {
// closure data
llvm::Function* f =
llvm::Function::Create(ftype_tvm_parallel_lambda_, llvm::Function::PrivateLinkage,
"__tvm_parallel_lambda", module_.get());
SetTargetAttributes(f);
// allocate and setup the closure, call the closure.
ffi::Array<Var> vfields = tir::UndefinedVars(body, {});
uint64_t nbytes;
TypedPointer cdata = PackClosureData(vfields, &nbytes, "closure_" + name);
#if TVM_LLVM_VERSION >= 90
auto launch_callee = llvm::FunctionCallee(ftype_tvm_parallel_launch_, RuntimeTVMParallelLaunch());
#else
auto launch_callee = RuntimeTVMParallelLaunch();
#endif
llvm::BasicBlock* par_launch_end = CheckCallSuccess(builder_->CreateCall(
launch_callee,
{f, builder_->CreatePointerCast(cdata.addr, t_void_p_), ConstInt32(num_task)}));
// Setup the closure function.
auto* lambda_entry =
llvm::BasicBlock::Create(*llvm_target_->GetContext(), "parallel_closure_entry", f);
builder_->SetInsertPoint(lambda_entry);
auto it = f->arg_begin();
llvm::Value* task_id = &(*it++);
task_id->setName("task_id");
llvm::Value* penv = &(*it++);
cdata.addr = builder_->CreatePointerCast(&(*it++), cdata.addr->getType());
// setup new variable map, swap it with current var context.
std::unordered_map<const VarNode*, llvm::Value*> new_vmap;
UnpackClosureData(cdata, vfields, &new_vmap);
// setup parallel env
ParallelEnv par_env;
par_env.task_id = Var("task_id", DataType::Int(32));
par_env.num_task = Var("num_task", DataType::Int(32));
new_vmap[par_env.task_id.get()] = task_id;
new_vmap[par_env.num_task.get()] = builder_->CreateLoad(
t_int32_,
builder_->CreateInBoundsGEP(t_tvm_parallel_group_env_, penv, {ConstInt32(0), ConstInt32(1)}),
"num_task");
par_env.penv = penv;
auto new_analyzer = std::make_unique<arith::Analyzer>();
std::swap(function_, f);
std::swap(parallel_env_, par_env);
std::swap(analyzer_, new_analyzer);
std::swap(var_map_, new_vmap);
this->VisitStmt(body);
builder_->CreateRet(ConstInt32(0));
// swap the var map back, now we are back on track.
std::swap(var_map_, new_vmap);
std::swap(analyzer_, new_analyzer);
std::swap(parallel_env_, par_env);
std::swap(function_, f);
TVM_FFI_ICHECK_NE(par_env.parallel_loop_count, 0)
<< "Cannot find parallel loop within parallel launch";
builder_->SetInsertPoint(par_launch_end);
}
llvm::Value* CodeGenCPU::CreateStaticHandle() {
llvm::GlobalVariable* gv =
new llvm::GlobalVariable(*module_, t_void_p_, false, llvm::GlobalValue::PrivateLinkage,
nullptr, "__tvm_static_handle");
#if TVM_LLVM_VERSION >= 100
gv->setAlignment(llvm::Align(data_layout_->getTypeAllocSize(t_void_p_)));
#else
gv->setAlignment(data_layout_->getTypeAllocSize(t_void_p_));
#endif
gv->setInitializer(llvm::Constant::getNullValue(t_void_p_));
return gv;
}
void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& body) {
// closure data
llvm::Function* f =
llvm::Function::Create(ftype_tvm_static_init_callback_, llvm::Function::PrivateLinkage,
"__tvm_static_init_lambda", module_.get());
SetTargetAttributes(f);
llvm::Value* gv = CreateStaticHandle();
llvm::Function* finit = module_->getFunction(init_fname);
if (finit == nullptr) {
finit = llvm::Function::Create(ftype_tvm_static_init_, llvm::Function::ExternalLinkage,
init_fname, module_.get());
}
// allocate and setup the closure, call the closure.
uint64_t nbytes;
ffi::Array<Var> vfields = tir::UndefinedVars(body, {});
TypedPointer cdata = PackClosureData(vfields, &nbytes);
llvm::BasicBlock* init_end = CheckCallSuccess(builder_->CreateCall(
finit, {gv, f, builder_->CreatePointerCast(cdata.addr, t_void_p_), ConstInt32(nbytes)}));
// Setup the closure function.
auto* lambda_entry = llvm::BasicBlock::Create(*llvm_target_->GetContext(), "entry", f);
builder_->SetInsertPoint(lambda_entry);
auto it = f->arg_begin();
cdata.addr = builder_->CreatePointerCast(&(*it++), cdata.addr->getType());
// setup new variable map, swap it with current var context.
std::unordered_map<const VarNode*, llvm::Value*> new_vmap;
UnpackClosureData(cdata, vfields, &new_vmap);
TVM_FFI_ICHECK(parallel_env_.penv == nullptr);
auto new_analyzer = std::make_unique<arith::Analyzer>();
std::swap(function_, f);
std::swap(analyzer_, new_analyzer);
std::swap(var_map_, new_vmap);
this->VisitStmt(body);
builder_->CreateRet(ConstInt32(0));
// swap the var map back, now we are back on track.
std::swap(var_map_, new_vmap);
std::swap(analyzer_, new_analyzer);
std::swap(function_, f);
builder_->SetInsertPoint(init_end);
}
llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) {
// We will store the packed function handle in global space.
// Initialize it during the first call.
#if TVM_LLVM_VERSION >= 200
llvm::DataLayout layout(module_.get()->getDataLayout());
#else
llvm::DataLayout layout(module_.get());
#endif
uint64_t align = layout.getTypeAllocSize(t_tvm_func_handle_);
auto it = func_handle_map_.find(fname);
llvm::GlobalVariable* hptr;
if (it == func_handle_map_.end()) {
// create global location for the handle
// create the function handle
hptr =
new llvm::GlobalVariable(*module_, t_tvm_func_handle_, false,
llvm::GlobalValue::InternalLinkage, nullptr, ".tvm_func." + fname);
#if TVM_LLVM_VERSION >= 100
hptr->setAlignment(llvm::Align(align));
#else
hptr->setAlignment(align);
#endif
hptr->setInitializer(llvm::Constant::getNullValue(t_tvm_func_handle_));
func_handle_map_[fname] = hptr;
} else {
hptr = it->second;
}
// create emit codes that checks and load the function.
llvm::LLVMContext* ctx = llvm_target_->GetContext();
llvm::BasicBlock* pre_block = builder_->GetInsertBlock();
auto* init_block = llvm::BasicBlock::Create(*ctx, "handle_init", function_);
auto* end_block = llvm::BasicBlock::Create(*ctx, "handle_init_end", function_);
#if TVM_LLVM_VERSION >= 110
llvm::Value* handle = builder_->CreateAlignedLoad(hptr->getValueType(), hptr, llvm::Align(align));
#elif TVM_LLVM_VERSION >= 80
llvm::Value* handle = builder_->CreateAlignedLoad(hptr->getValueType(), hptr, align);
#else
llvm::Value* handle = builder_->CreateAlignedLoad(hptr, align);
#endif
llvm::Value* handle_not_null =
builder_->CreateICmpNE(handle, llvm::Constant::getNullValue(t_tvm_func_handle_));
builder_->CreateCondBr(handle_not_null, end_block, init_block, md_very_likely_branch_);
// Initialize the handle if needed.
builder_->SetInsertPoint(init_block);
llvm::Value* out =
WithFunctionEntry([&]() { return builder_->CreateAlloca(t_tvm_func_handle_); });
#if TVM_LLVM_VERSION >= 110
llvm::LoadInst* ctx_load = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_,
llvm::Align(gv_mod_ctx_->getAlignment()));
#elif TVM_LLVM_VERSION >= 80
llvm::LoadInst* ctx_load = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_,
gv_mod_ctx_->getAlignment());
#else
llvm::LoadInst* ctx_load = builder_->CreateAlignedLoad(gv_mod_ctx_, gv_mod_ctx_->getAlignment());
#endif
ctx_load->setMetadata(
"tbaa", md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0));
#if TVM_LLVM_VERSION >= 90
auto env_callee = llvm::FunctionCallee(ftype_tvm_get_func_from_env_, RuntimeTVMGetFuncFromEnv());
#else
auto env_callee = RuntimeTVMGetFuncFromEnv();
#endif
llvm::Value* retcode = builder_->CreateCall(env_callee, {ctx_load, GetConstString(fname), out});
init_block = CheckCallSuccess(retcode);
#if TVM_LLVM_VERSION >= 110
llvm::Value* loaded_handle =
builder_->CreateAlignedLoad(t_tvm_func_handle_, out, llvm::Align(align));
#elif TVM_LLVM_VERSION >= 80
llvm::Value* loaded_handle = builder_->CreateAlignedLoad(t_tvm_func_handle_, out, align);
#else
llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align);
#endif
// Store the handle
builder_->CreateStore(loaded_handle, hptr);
builder_->CreateBr(end_block);
// end block
builder_->SetInsertPoint(end_block);
llvm::PHINode* phi = builder_->CreatePHI(t_tvm_func_handle_, 2);
phi->addIncoming(handle, pre_block);
phi->addIncoming(loaded_handle, init_block);
return phi;
}
CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const ffi::Array<PrimExpr>& args,
const DataType& r_type,
const int64_t begin, const int64_t end,
bool use_env_lookup) {
std::string func_name = [&]() {
auto ptr = args[0].as<StringImmNode>();
TVM_FFI_ICHECK(ptr) << "Expected first argument of tir::Call to be "
<< "a string containing the callee's name, "
<< "but instead contained " << args[0];
return ptr->value;
}();
// call the function
int64_t nargs = end - begin;
TVM_FFI_ICHECK_GE(nargs, 0);
llvm::Value* stack_args = MakeValue(args[1]);
llvm::Value* packed_args = builder_->CreateInBoundsGEP(
t_tvm_ffi_any_, builder_->CreatePointerCast(stack_args, llvmGetPointerTo(t_tvm_ffi_any_, 0)),
ConstInt32(begin));
llvm::Value* result = builder_->CreateInBoundsGEP(
t_tvm_ffi_any_, builder_->CreatePointerCast(stack_args, llvmGetPointerTo(t_tvm_ffi_any_, 0)),
ConstInt32(end));
llvm::FunctionType* callee_ftype = nullptr;
llvm::Value* callee_value = nullptr;
std::vector<llvm::Value*> call_args;
if (use_env_lookup) {
callee_ftype = ftype_tvm_ffi_func_call_;
callee_value = RuntimeTVMFFIFunctionCall();
call_args.push_back(GetPackedFuncHandle(func_name));
call_args.insert(call_args.end(), {packed_args, ConstInt32(nargs), result});
} else {
// directly call into symbol, needs to prefix with tvm_ffi_symbol_prefix
callee_ftype = ftype_tvm_ffi_c_func_;
callee_value = module_->getFunction(ffi::symbol::tvm_ffi_symbol_prefix + func_name);
if (callee_value == nullptr) {
callee_value = llvm::Function::Create(ftype_tvm_ffi_c_func_, llvm::Function::ExternalLinkage,
func_name, module_.get());
}
call_args.push_back(llvm::ConstantPointerNull::get(t_void_p_));
call_args.insert(call_args.end(), {packed_args, ConstInt32(nargs), result});
}
#if TVM_LLVM_VERSION >= 90
auto call_callee = llvm::FunctionCallee(callee_ftype, callee_value);
#else
(void)callee_ftype; // use callee_ftype to avoid unused variable warning when using older LLVM.
auto call_callee = callee_value;
#endif
llvm::Value* call = builder_->CreateCall(call_callee, call_args);
llvm::BasicBlock* end_block = CheckCallSuccess(call);
PackedCall pc = {nullptr};
if (!r_type.is_void()) {
// Load the return value and cast it to the designated type (r_type).
DataType r_api_type = tir::APIType(r_type);
llvm::Type* llvm_r_api_type = DTypeToLLVMType(r_api_type);
llvm::Value* result_value =
builder_->CreateInBoundsGEP(t_tvm_ffi_any_, result, {ConstInt32(0), ConstInt32(2)});
llvm::Value* load_ptr =
builder_->CreatePointerCast(result_value, llvmGetPointerTo(llvm_r_api_type, 0));
#if TVM_LLVM_VERSION >= 110
llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, llvm::Align(8));
#elif TVM_LLVM_VERSION >= 80
llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, 8);
#else
llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, 8);
#endif
pc.ret_value = CreateCast(r_api_type, r_type, rvalue);
llvm::Value* result_type_index =
builder_->CreateInBoundsGEP(t_tvm_ffi_any_, result, {ConstInt32(0), ConstInt32(0)});
// Load the return type code.
#if TVM_LLVM_VERSION >= 110
pc.ret_type_index = builder_->CreateAlignedLoad(t_int32_, result_type_index, llvm::Align(4));
#elif TVM_LLVM_VERSION >= 80
pc.ret_type_index = builder_->CreateAlignedLoad(t_int32_, result_type_index, 8);
#else
pc.ret_type_index = builder_->CreateAlignedLoad(result_type_index, 8);
#endif
}
pc.end_block = end_block;
return pc;
}
llvm::Value* CodeGenCPU::CreateCallPacked(const CallNode* op) {
TVM_FFI_ICHECK_EQ(op->args.size(), 4U);
bool use_string_lookup = op->op.same_as(builtin::tvm_call_packed_lowered());
PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[2].as<IntImmNode>()->value,
op->args[3].as<IntImmNode>()->value, use_string_lookup);
return pc.ret_value;
}
llvm::Value* CodeGenCPU::CreateCallTracePacked(const CallNode* op) {
TVM_FFI_ICHECK_EQ(op->args.size(), 5U);
PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[2].as<IntImmNode>()->value,
op->args[3].as<IntImmNode>()->value, true);
llvm::LLVMContext* ctx = llvm_target_->GetContext();
// Get traced value.
llvm::Value* traced_value = MakeValue(op->args[4]);
// The update_block handles case when we need to update the return value.
llvm::BasicBlock* update_block = llvm::BasicBlock::Create(*ctx, "update_block", function_);
// The continue_block handles case when we need to return original
// traced value.
llvm::BasicBlock* continue_block = llvm::BasicBlock::Create(*ctx, "continue_block", function_);
// Check the ret_type_code and create cmp instruction.
llvm::Value* cmp = builder_->CreateICmpNE(
pc.ret_type_index, llvm::ConstantInt::get(t_int_, ffi::TypeIndex::kTVMFFINone));
builder_->CreateCondBr(cmp, update_block, continue_block);
builder_->SetInsertPoint(update_block);
builder_->CreateBr(continue_block);
builder_->SetInsertPoint(continue_block);
// The return value depends on from what bb we come from.
llvm::PHINode* phi_rvalue = builder_->CreatePHI(traced_value->getType(), 2);
phi_rvalue->addIncoming(pc.ret_value, update_block);
phi_rvalue->addIncoming(traced_value, pc.end_block);
return phi_rvalue;
}
llvm::Value* CodeGenCPU::RuntimeTVMFFIFunctionCall() {
if (f_tvm_ffi_func_call_ != nullptr) return f_tvm_ffi_func_call_;
return GetContextPtr(gv_tvm_ffi_func_call_);
}
llvm::Value* CodeGenCPU::RuntimeTVMGetFuncFromEnv() {
if (f_tvm_get_func_from_env_ != nullptr) return f_tvm_get_func_from_env_;
return GetContextPtr(gv_tvm_get_func_from_env_);
}
llvm::Value* CodeGenCPU::RuntimeTVMFFIErrorSetRaisedFromCStr() {
if (f_tvm_ffi_set_raised_by_c_str_ != nullptr) return f_tvm_ffi_set_raised_by_c_str_;
return GetContextPtr(gv_tvm_ffi_set_last_error_c_str_);
}
llvm::Value* CodeGenCPU::RuntimeTVMParallelLaunch() {
if (f_tvm_parallel_launch_ != nullptr) return f_tvm_parallel_launch_;
return GetContextPtr(gv_tvm_parallel_launch_);
}
llvm::Value* CodeGenCPU::RuntimeTVMParallelBarrier() {
if (f_tvm_parallel_barrier_ != nullptr) return f_tvm_parallel_barrier_;
return GetContextPtr(gv_tvm_parallel_barrier_);
}
void CodeGenCPU::AddStartupFunction() {
if (!target_c_runtime_) {
llvm::FunctionType* ftype = llvm::FunctionType::get(t_void_, {}, false);
function_ = llvm::Function::Create(ftype, llvm::Function::InternalLinkage,
"__tvm_module_startup", module_.get());
SetTargetAttributes(function_);
llvm::BasicBlock* startup_entry =
llvm::BasicBlock::Create(*llvm_target_->GetContext(), "entry", function_);
builder_->SetInsertPoint(startup_entry);
for (const auto& kv : export_system_symbols_) {
llvm::Value* name = GetConstString(kv.first);
builder_->CreateCall(f_tvm_register_system_symbol_,
{name, builder_->CreateBitCast(kv.second, t_void_p_)});
}
llvm::appendToGlobalCtors(*module_, function_, 65535);
builder_->CreateRet(nullptr);
}
}
llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) {
if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
return CreateCallPacked(op);
} else if (op->op.same_as(builtin::tvm_call_trace_packed_lowered())) {
return CreateCallTracePacked(op);
} else if (op->op.same_as(builtin::tvm_call_cpacked_lowered())) {
return CreateCallPacked(op);
} else if (op->op.same_as(builtin::tvm_static_handle())) {
return CreateStaticHandle();
} else if (op->op.same_as(builtin::tvm_throw_last_error())) {
builder_->CreateRet(ConstInt32(-1));
auto next_block = std::next(builder_->GetInsertBlock()->getIterator());
llvm::BasicBlock* new_bb =
llvm::BasicBlock::Create(*llvm_target_->GetContext(), "cont", function_, &*next_block);
builder_->SetInsertPoint(new_bb);
return ConstInt32(-1);
} else if (op->op.same_as(builtin::tvm_struct_get())) {
TVM_FFI_ICHECK_EQ(op->args.size(), 3U);
int kind = op->args[2].as<IntImm>().value()->value;
TypedPointer ref =
CreateStructRefPtr(op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind);
if (kind == builtin::kArrAddr) {
return builder_->CreatePointerCast(ref.addr, t_void_p_);
}
llvm::Value* struct_value = builder_->CreateLoad(ref.type, ref.addr);
if (op->dtype == DataType::Bool()) {
struct_value = CreateCast(DataType::Int(64), op->dtype, struct_value);
}
return struct_value;
} else if (op->op.same_as(builtin::tvm_struct_set())) {
TVM_FFI_ICHECK_EQ(op->args.size(), 4U);
int kind = op->args[2].as<IntImm>().value()->value;
llvm::Value* value = MakeValue(op->args[3]);
TypedPointer ref = CreateStructRefPtr(op->args[3].dtype(), MakeValue(op->args[0]),
MakeValue(op->args[1]), kind);
TVM_FFI_ICHECK(kind != builtin::kArrAddr);
if (value->getType()->isPointerTy()) {
value = builder_->CreatePointerCast(value, ref.type);
}
if (kind == builtin::kTVMFFIAnyUnionValue) {
// when we set any union value, we need to be careful to
// clear off the union value to zero if the set size is less than 64 bits
if (data_layout_->getTypeAllocSize(ref.type) != 8) {
llvm::Value* i64_addr =
builder_->CreatePointerCast(ref.addr, llvmGetPointerTo(t_int64_, 0));
builder_->CreateStore(ConstInt64(0), i64_addr);
}
}
builder_->CreateStore(value, ref.addr);
return ConstInt32(0);
} else if (op->op.same_as(builtin::tvm_stack_alloca())) {
TVM_FFI_ICHECK_EQ(op->args.size(), 2U);
std::string type = op->args[0].as<StringImm>().value()->value;
return WithFunctionEntry([&]() -> llvm::AllocaInst* {
const int64_t* pval = as_const_int(op->args[1]);
TVM_FFI_ICHECK(pval) << "require stack alloca to contain constant value";
llvm::Value* num = ConstInt32(pval[0]);
if (type == "shape") {
return builder_->CreateAlloca(t_tvm_shape_index_, num);
} else if (type == "tvm_ffi_any") {
return builder_->CreateAlloca(t_tvm_ffi_any_, num);
} else if (type == "array") {
return builder_->CreateAlloca(t_tvm_array_, num);
} else if (type == "tensormap") {
auto* alloca = builder_->CreateAlloca(t_tvm_tensormap_, num);
alloca->setAlignment(llvm::Align(64));
return alloca;
} else {
TVM_FFI_THROW(InternalError) << "Unknown stack alloca type " << type;
}
});
} else {
return CodeGenLLVM::CreateIntrinsic(op);
}
}
void CodeGenCPU::VisitStmt_(const AssertStmtNode* op) {
EmitDebugLocation(op);
llvm::Value* cond = MakeValue(op->condition);
std::ostringstream os;
os << "Assert fail: " << op->condition;
if (op->message.as<StringImmNode>()) {
os << ", " << op->message.as<StringImmNode>()->value;
}
llvm::Value* msg = GetConstString(os.str());
llvm::LLVMContext* ctx = llvm_target_->GetContext();
auto* fail_block = llvm::BasicBlock::Create(*ctx, "assert_fail", function_);
auto* end_block = llvm::BasicBlock::Create(*ctx, "assert_end", function_);
builder_->CreateCondBr(cond, end_block, fail_block, md_very_likely_branch_);
// fail condition.
builder_->SetInsertPoint(fail_block);
#if TVM_LLVM_VERSION >= 90
auto err_callee = llvm::FunctionCallee(ftype_tvm_ffi_error_set_raised_by_c_str_,
RuntimeTVMFFIErrorSetRaisedFromCStr());
#else
auto err_callee = RuntimeTVMFFIErrorSetRaisedFromCStr();
#endif
builder_->CreateCall(err_callee, {GetConstString("RuntimeError"), msg});
builder_->CreateRet(ConstInt32(-1));
// otherwise set it to be new end.
builder_->SetInsertPoint(end_block);
CodeGenLLVM::VisitStmt_(op);
}
void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) {
EmitDebugLocation(op);
if (op->attr_key == tir::attr::coproc_uop_scope) {
const StringImmNode* value = op->value.as<StringImmNode>();
TVM_FFI_ICHECK(value != nullptr);
this->CreateStaticInit(value->value, op->body);
} else if (op->attr_key == tir::attr::compute_scope) {
this->CreateComputeScope(op);
} else if (tir::attr::IsPragmaKey(op->attr_key)) {
if (op->attr_key == "pragma_parallel_stride_pattern") {
TVM_FFI_ICHECK(parallel_env_.penv != nullptr)
<< "Pragma parallel_stride_pattern only valid in parallel launch";
parallel_env_.stride_pattern = true;
this->VisitStmt(op->body);
} else if (op->attr_key == "pragma_parallel_launch_point") {
CreateParallelLaunch(op->body, 0, "pragma_parallel");
} else if (op->attr_key == "pragma_parallel_barrier_when_finish") {
TVM_FFI_ICHECK(parallel_env_.penv != nullptr)
<< "Cannot run barrier without parallel environment";
TVM_FFI_ICHECK(!parallel_env_.in_parallel_loop)
<< "Cannot not place within parallel loop as the workload may differ, "
<< " place it between parallel and parallel_launch_point";
this->VisitStmt(op->body);
#if TVM_LLVM_VERSION >= 90
auto bar_callee =
llvm::FunctionCallee(ftype_tvm_parallel_barrier_, RuntimeTVMParallelBarrier());
#else
auto bar_callee = RuntimeTVMParallelBarrier();
#endif
builder_->CreateCall(bar_callee, {MakeValue(parallel_env_.task_id), parallel_env_.penv});
} else if (op->attr_key == tir::attr::pragma_import_llvm) {
const StringImmNode* value = op->value.as<StringImmNode>();
TVM_FFI_ICHECK(value != nullptr);
this->HandleImport(value->value);
this->VisitStmt(op->body);
} else {
LOG(WARNING) << "Unknown pragma " << op->attr_key;
this->VisitStmt(op->body);
}
} else {
CodeGenLLVM::VisitStmt_(op);
}
}
void CodeGenCPU::VisitStmt_(const ForNode* op) {
EmitDebugLocation(op);
if (op->kind == ForKind::kSerial || op->kind == ForKind::kUnrolled) {
CodeGenLLVM::VisitStmt_(op);
} else if (op->kind == ForKind::kParallel) {
TVM_FFI_ICHECK(is_zero(op->min))
<< "Parallel launch require canonical loop with zero start index";
TVM_FFI_ICHECK(op->HasTrivialStep())
<< "Parallel launch require canonical loop with trivial loop step";
if (parallel_env_.penv == nullptr) {
auto copy_node = For(ffi::make_object<ForNode>(*op));
CreateParallelLaunch(copy_node, 0,
std::string("loop_parallel_") + op->loop_var->name_hint.c_str());
} else {
// already in parallel env.
TVM_FFI_ICHECK(parallel_env_.task_id.defined());
TVM_FFI_ICHECK(parallel_env_.num_task.defined());
TVM_FFI_ICHECK(parallel_env_.penv != nullptr);
DataType t = op->extent.dtype();
PrimExpr num_task = cast(t, parallel_env_.num_task);
PrimExpr task_id = cast(t, parallel_env_.task_id);
TVM_FFI_ICHECK(!parallel_env_.in_parallel_loop)
<< "Nested parallel loop is not supported by threadpool, try fuse them instead";
parallel_env_.in_parallel_loop = true;
PrimExpr end = is_zero(op->min) ? op->extent : analyzer_->Simplify(op->min + op->extent);
if (parallel_env_.stride_pattern) {
CreateSerialFor(MakeValue(task_id), MakeValue(end), MakeValue(num_task), op->loop_var,
op->body);
} else {
PrimExpr step = (op->extent + num_task - make_const(t, 1)) / num_task;
PrimExpr begin = min(task_id * step, op->extent);
end = min((task_id + make_const(t, 1)) * step, end);
CreateSerialFor(MakeValue(begin), MakeValue(end),
llvm::ConstantInt::getSigned(GetLLVMType(end), 1), op->loop_var, op->body);
}
parallel_env_.in_parallel_loop = false;
++parallel_env_.parallel_loop_count;
}
} else {
TVM_FFI_THROW(InternalError) << "cannot handle for type " << op->kind;
}
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def_packed("tvm.codegen.llvm.target_cpu",
[](const ffi::PackedArgs& targs, ffi::Any* rv) {
*rv = static_cast<void*>(new CodeGenCPU());
});
}
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION