blob: 4dad2fc4b3ec81488a8c628a8ea8e69da0233c1b [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file codegen_llvm.cc
*/
#ifdef TVM_LLVM_VERSION
// Part of the code are adapted from Halide's CodeGen_LLVM
#include "codegen_llvm.h"
#include <llvm/ADT/ArrayRef.h>
#include <llvm/ADT/SmallVector.h>
#include <llvm/ADT/StringRef.h>
#include <tvm/ffi/cast.h>
#include <tvm/ffi/reflection/registry.h>
#if LLVM_VERSION_MAJOR >= 17
#include <llvm/TargetParser/Triple.h>
#else
#include <llvm/ADT/Triple.h>
#endif
#include <llvm/Analysis/TargetTransformInfo.h>
#include <llvm/BinaryFormat/Dwarf.h>
#include <llvm/CodeGen/TargetSubtargetInfo.h>
#include <llvm/IR/Argument.h>
#include <llvm/IR/Attributes.h>
#include <llvm/IR/BasicBlock.h>
#include <llvm/IR/CallingConv.h>
#include <llvm/IR/Constants.h>
#include <llvm/IR/DIBuilder.h>
#include <llvm/IR/DataLayout.h>
#include <llvm/IR/DebugInfoMetadata.h>
#include <llvm/IR/DerivedTypes.h>
#include <llvm/IR/FMF.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/GlobalVariable.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/Intrinsics.h>
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/MDBuilder.h>
#include <llvm/IR/Metadata.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Type.h>
#include <llvm/IR/Verifier.h>
#include <llvm/IRReader/IRReader.h>
#include <llvm/Linker/Linker.h>
#include <llvm/Pass.h>
#if TVM_LLVM_VERSION >= 160
#include <llvm/IR/Verifier.h> // For VerifierPass
#include <llvm/Passes/PassBuilder.h>
#include <llvm/Passes/StandardInstrumentations.h>
#include <llvm/TargetParser/Host.h>
#else
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/Support/Host.h>
#include <llvm/Transforms/IPO/PassManagerBuilder.h>
#endif
#include <llvm/Support/Alignment.h>
#include <llvm/Support/CodeGen.h>
#include <llvm/Support/MemoryBuffer.h>
#include <llvm/Support/SourceMgr.h>
#include <llvm/Support/TypeSize.h>
#include <llvm/Target/TargetMachine.h>
#include <llvm/Transforms/IPO.h>
#include <llvm/Transforms/Utils/ModuleUtils.h>
#include <tvm/runtime/base.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/logging.h>
#include <tvm/tirx/op.h>
#include <algorithm>
#include <functional>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "../../arith/pattern_match.h"
#include "../build_common.h"
#include "codegen_params.h"
#include "llvm_instance.h"
namespace tvm {
namespace codegen {
// CodeGenLLVM has members of type std::unique_ptr<T>. These members will be
// instantiated in the constructor, which will requre that the type T is
// complete at that point. Put the constructor (and destructor) here, since
// all types should be complete here.
CodeGenLLVM::CodeGenLLVM() = default;
CodeGenLLVM::~CodeGenLLVM() = default;
CodeGenLLVM::DebugInfo::~DebugInfo() = default;
std::unique_ptr<CodeGenLLVM> CodeGenLLVM::Create(LLVMTarget* llvm_target) {
std::string target = llvm_target->GetOrCreateTargetMachine()->getTarget().getName();
std::string factory_template = "tvm.codegen.llvm.target_";
void* handle = nullptr;
if (auto f = tvm::ffi::Function::GetGlobal(factory_template + target)) {
handle = (*f)().cast<void*>();
} else if (auto f = tvm::ffi::Function::GetGlobal(factory_template + "cpu")) {
handle = (*f)().cast<void*>();
} else {
TVM_FFI_THROW(InternalError) << "no factory function for codegen for target " << target;
}
if (handle) {
return std::unique_ptr<CodeGenLLVM>(static_cast<CodeGenLLVM*>(handle));
} else {
TVM_FFI_THROW(InternalError) << "unable to create codegen for target " << target;
}
}
void CodeGenLLVM::Init(const std::string& module_name, LLVMTarget* llvm_target,
ffi::Optional<ffi::String> system_lib_prefix, bool dynamic_lookup,
bool target_c_runtime) {
llvm_target_ = llvm_target;
llvm::LLVMContext* ctx = llvm_target_->GetContext();
builder_.reset(new IRBuilder(*ctx));
module_.reset(new llvm::Module(module_name, *ctx));
md_builder_.reset(new llvm::MDBuilder(*ctx));
// types
t_void_ = llvm::Type::getVoidTy(*ctx);
t_void_p_ = llvmGetPointerTo(llvm::Type::getInt8Ty(*ctx), GetGlobalAddressSpace());
t_int1_ = llvm::Type::getInt1Ty(*ctx);
t_int_ = llvm::Type::getInt32Ty(*ctx);
t_char_ = llvm::Type::getInt8Ty(*ctx);
t_int8_ = llvm::Type::getInt8Ty(*ctx);
t_int16_ = llvm::Type::getInt16Ty(*ctx);
t_int32_ = llvm::Type::getInt32Ty(*ctx);
t_int64_ = llvm::Type::getInt64Ty(*ctx);
t_float64_ = llvm::Type::getDoubleTy(*ctx);
// CUTensorMap is a 128 byte struct, so we use a 128 byte array to represent it.
t_tvm_tensormap_ = llvm::ArrayType::get(t_char_, 128);
// meta data
md_very_likely_branch_ = md_builder_->createBranchWeights(1 << 20, 1);
md_tbaa_root_ = md_builder_->createTBAARoot("tvm-tbaa");
md_tbaa_alias_set_ = md_builder_->createTBAANode("tvm-alias", md_tbaa_root_);
InitTarget();
}
void CodeGenLLVM::SetFastMathFlags(llvm::FastMathFlags fmf) { builder_->setFastMathFlags(fmf); }
void CodeGenLLVM::InitTarget() {
llvm::TargetMachine* tm = llvm_target_->GetOrCreateTargetMachine();
#if TVM_LLVM_VERSION >= 210
module_->setTargetTriple(tm->getTargetTriple());
#else
module_->setTargetTriple(tm->getTargetTriple().str());
#endif
module_->setDataLayout(tm->createDataLayout());
#if TVM_LLVM_VERSION >= 200
data_layout_.reset(new llvm::DataLayout(module_.get()->getDataLayout()));
#else
data_layout_.reset(new llvm::DataLayout(module_.get()));
#endif
if (native_vector_bits_ == 0) {
native_vector_bits_ = llvm_target_->GetVectorWidth();
}
bool use_float16_abi = false;
// For conversions between _Float16 and float, LLVM uses runtime functions
// __extendhfsf2 and __truncsfhf2. On X86 up until version 14, LLVM used
// "uint16_t" for representing _Float16. Starting with LLVM 15, half-precision
// values can be passed in XMM registers (i.e. as floating-point). This happens
// when the compilation target has SSE2 enabled (either directly, or by enabling
// a feature that implies SSE2).
// Because the names of the conversion functions remain unchanged, it is impossible
// for TVM to provide them in the runtime, and have them work in both cases.
// To alleviate this issue, emit these functions directly into the target module
// after detecting whether or not to use floating-point ABI. To allow the linker
// to remove potential duplicates (or if they are unused), they are weak and
// reside in a separate section (ELF).
llvm::Triple::ArchType arch_type = tm->getTargetTriple().getArch();
if (arch_type == llvm::Triple::x86 || arch_type == llvm::Triple::x86_64) {
// Detect if SSE2 is enabled. This determines whether float16 ABI is used.
std::stringstream os;
const char fname[] = "test_sse2";
os << "target triple = \"" << llvm_target_->GetTargetTriple() << "\"\n"
<< "define void @" << fname << "() #0 { ret void } attributes #0 = { \"target-cpu\"=\""
<< llvm_target_->GetCPU() << "\" ";
if (auto&& fs = llvm_target_->GetTargetFeatureString(); !fs.empty()) {
os << "\"target-features\"=\"" << fs << "\" ";
}
os << "}\n";
auto mod = llvm_target_->GetInstance().ParseIR(os.str());
auto* test_sse2 = mod->getFunction(fname);
TVM_FFI_ICHECK(test_sse2 != nullptr) << "Module creation error";
use_float16_abi = tm->getSubtargetImpl(*test_sse2)->checkFeatures("+sse2");
}
// Call this function only with LLVM >= 6.0. The code it emits uses "dso_local"
// which was introduced in LLVM 6.
EmitFloat16ConversionBuiltins(use_float16_abi);
}
llvm::Function* CodeGenLLVM::DeclareFunction(const GlobalVar& gvar, const PrimFunc& f) {
return this->DeclareFunctionInternal(gvar, f);
}
void CodeGenLLVM::AddFunction(const GlobalVar& gvar, const PrimFunc& f) {
this->AddFunctionInternal(gvar, f);
}
void CodeGenLLVM::InitFuncState() {
var_map_.clear();
alias_var_set_.clear();
alloc_storage_info_.clear();
volatile_buf_.clear();
analyzer_.reset(new arith::Analyzer());
}
std::tuple<std::string, llvm::Function::LinkageTypes> CodeGenLLVM::GetLinkage(
const GlobalVar& gvar, const PrimFunc& func) {
if (auto global_symbol = func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol)) {
return {global_symbol.value(), llvm::Function::ExternalLinkage};
}
std::string symbol_name = [&]() {
std::stringstream ss;
ss << "_internal_";
ss << gvar->name_hint;
return ss.str();
}();
return {symbol_name, llvm::Function::PrivateLinkage};
}
llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const GlobalVar& gvar, const PrimFunc& func) {
if (auto it = functions_.find(gvar.get()); it != functions_.end()) {
return it->second;
}
TVM_FFI_ICHECK_EQ(func->buffer_map.size(), 0U)
<< "Cannot codegen function with buffer_map, please lower them first";
std::vector<llvm::Type*> param_types;
is_restricted_ = func->HasNonzeroAttr(tirx::attr::kNoAlias);
for (Var param : func->params) {
param_types.push_back(GetLLVMType(param));
if (!is_restricted_ && param.dtype().is_handle()) {
alias_var_set_.insert(param.get());
}
}
llvm::FunctionType* ftype =
llvm::FunctionType::get(GetLLVMType(func->ret_type), param_types, false);
auto [symbol_name, linkage_type] = GetLinkage(gvar, func);
auto function = module_->getFunction(MakeStringRef(symbol_name));
if (function == nullptr) {
function =
llvm::Function::Create(ftype, linkage_type, MakeStringRef(symbol_name), module_.get());
}
function->setCallingConv(llvm::CallingConv::C);
function->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
SetTargetAttributes(function);
functions_[gvar.get()] = function;
return function;
}
void CodeGenLLVM::AddFunctionInternal(const GlobalVar& gvar, const PrimFunc& f) {
this->InitFuncState();
function_ = DeclareFunctionInternal(gvar, f);
// set var map and align information
auto arg_it = function_->arg_begin();
for (size_t i = 0; i < f->params.size(); ++i, ++arg_it) {
llvm::Argument* v = &(*arg_it);
const Var& var = f->params[i];
var_map_[var.get()] = v;
v->setName(std::string(var->name_hint));
if (is_restricted_) {
if (var.dtype().is_handle() && !alias_var_set_.count(var.get())) {
// set non alias.
function_->addParamAttr(i, llvm::Attribute::NoAlias);
}
}
}
llvm::LLVMContext* ctx = llvm_target_->GetContext();
llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx, "entry", function_);
builder_->SetInsertPoint(entry);
this->VisitStmt(f->body);
// Add alignment attribute if needed.
for (size_t i = 0; i < f->params.size(); ++i) {
const Var& var = f->params[i];
auto f = alloc_storage_info_.find(var.get());
if (f != alloc_storage_info_.end()) {
unsigned align = f->second.alignment;
if (align > 1) {
auto attr = llvm::Attribute::get(*ctx, llvm::Attribute::Alignment, align);
function_->addParamAttr(i, attr);
}
}
}
EmitDebugLocation(f->span);
if (IsVoidType(f->ret_type)) {
// All other return types are handled when encountering
// builtin::ret().
builder_->CreateRetVoid();
} else {
builder_->CreateRet(ConstInt32(0));
}
}
void CodeGenLLVM::Verify() const {
std::string verify_errors_storage;
llvm::raw_string_ostream verify_errors(verify_errors_storage);
if (llvm::verifyModule(*module_, &verify_errors)) {
TVM_FFI_THROW(InternalError) << "LLVM module verification failed with the following errors: \n"
<< verify_errors.str();
}
}
std::unique_ptr<llvm::Module> CodeGenLLVM::Finish() {
this->AddStartupFunction();
for (size_t i = 0; i < link_modules_.size(); ++i) {
TVM_FFI_ICHECK(!llvm::Linker::linkModules(*module_, std::move(link_modules_[i])))
<< "Failed to link modules";
}
link_modules_.clear();
this->Verify();
this->Optimize();
this->Verify();
return std::move(module_);
}
void CodeGenLLVM::HandleImport(const std::string& code) {
llvm::StringRef code_str(code);
std::unique_ptr<llvm::Module> mlib;
#if TVM_LLVM_VERSION >= 180
if (code_str.ends_with(".ll") || code_str.ends_with(".bc")) {
#else
if (code_str.endswith(".ll") || code_str.endswith(".bc")) {
#endif
mlib = llvm_target_->GetInstance().LoadIR(code);
} else {
mlib = llvm_target_->GetInstance().ParseIR(code);
}
#if TVM_LLVM_VERSION >= 210
mlib->setTargetTriple(llvm::Triple(llvm_target_->GetTargetTriple()));
#else
mlib->setTargetTriple(llvm_target_->GetTargetTriple());
#endif
mlib->setDataLayout(llvm_target_->GetOrCreateTargetMachine()->createDataLayout());
// mark all the functions as force inline
for (llvm::Function& f : mlib->functions()) {
f.removeFnAttr(llvm::Attribute::OptimizeNone);
f.removeFnAttr(llvm::Attribute::NoInline);
f.addFnAttr(llvm::Attribute::AlwaysInline);
f.setLinkage(llvm::GlobalValue::AvailableExternallyLinkage);
}
// add to linker libraries.
this->AddLinkModule(std::move(mlib));
}
void CodeGenLLVM::AddLinkModule(std::unique_ptr<llvm::Module>&& mod) {
link_modules_.emplace_back(std::move(mod));
}
void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) {
TVM_FFI_THROW(InternalError) << "not implemented";
}
llvm::Value* CodeGenLLVM::GetThreadIndex(const IterVar& iv) {
TVM_FFI_THROW(InternalError) << "not implemented";
}
llvm::Value* CodeGenLLVM::CreateStorageSync(const CallNode* op) {
TVM_FFI_THROW(InternalError) << "not implemented";
}
#if TVM_LLVM_VERSION >= 160
// Use new pass manager
void CodeGenLLVM::Optimize() {
llvm::TargetMachine* tm = llvm_target_->GetOrCreateTargetMachine();
bool debug_logging = false;
bool verify_each = false;
llvm::PipelineTuningOptions pto = llvm::PipelineTuningOptions();
llvm::PassInstrumentationCallbacks pic;
llvm::PassBuilder builder(tm, pto, std::nullopt, &pic);
llvm::LoopAnalysisManager lam;
llvm::FunctionAnalysisManager fam;
llvm::CGSCCAnalysisManager cgam;
llvm::ModuleAnalysisManager mam;
builder.registerLoopAnalyses(lam);
builder.registerFunctionAnalyses(fam);
builder.registerCGSCCAnalyses(cgam);
builder.registerModuleAnalyses(mam);
builder.crossRegisterProxies(lam, fam, cgam, mam);
// Construct the default pass pipeline depending on the opt level.
std::string pipeline;
#if TVM_LLVM_VERSION <= 170
switch (llvm_target_->GetOptLevel()) {
case llvm::CodeGenOpt::Level::None:
pipeline = "default<O0>";
break;
case llvm::CodeGenOpt::Level::Less:
pipeline = "default<O1>";
break;
case llvm::CodeGenOpt::Level::Default:
pipeline = "default<O2>";
break;
default:
// CodeGenOpt::Level::Aggressive
pipeline = "default<O3>";
break;
}
#else
switch (llvm_target_->GetOptLevel()) {
case llvm::CodeGenOptLevel::None:
pipeline = "default<O0>";
break;
case llvm::CodeGenOptLevel::Less:
pipeline = "default<O1>";
break;
case llvm::CodeGenOptLevel::Default:
pipeline = "default<O2>";
break;
default:
// CodeGenOptLevel::Aggressive
pipeline = "default<O3>";
break;
}
#endif
llvm::StandardInstrumentations si(*llvm_target_->GetContext(), debug_logging, verify_each);
#if LLVM_VERSION_MAJOR >= 17
si.registerCallbacks(pic, &mam);
#else
si.registerCallbacks(pic, &fam);
#endif
llvm::ModulePassManager mpass;
if (verify_each) {
mpass.addPass(llvm::VerifierPass());
}
if (auto err = builder.parsePassPipeline(mpass, pipeline)) {
TVM_FFI_THROW(InternalError) << "error parsing pass pipeline '" << pipeline
<< "':" << llvm::toString(std::move(err)) << '\n';
}
mpass.run(*module_, mam);
}
#else // TVM_LLVM_VERSION
class FPassManager : public llvm::legacy::FunctionPassManager {
public:
explicit FPassManager(llvm::Module* m) : llvm::legacy::FunctionPassManager(m) {}
// override add to allow messaging
void add(llvm::Pass* p) final { llvm::legacy::FunctionPassManager::add(p); }
};
class MPassManager : public llvm::legacy::PassManager {
public:
// override add to allow messaging
void add(llvm::Pass* p) final { llvm::legacy::PassManager::add(p); }
};
void CodeGenLLVM::InitPassManagerBuilder(llvm::PassManagerBuilder* builder) {}
void CodeGenLLVM::Optimize() {
// pass manager
FPassManager fpass(module_.get());
MPassManager mpass;
llvm::TargetMachine* tm = llvm_target_->GetOrCreateTargetMachine();
mpass.add(llvm::createTargetTransformInfoWrapperPass(tm->getTargetIRAnalysis()));
fpass.add(llvm::createTargetTransformInfoWrapperPass(tm->getTargetIRAnalysis()));
// place optimization pass
llvm::PassManagerBuilder builder;
// Use the same opt-level as specified in TargetMachine for running passes
llvm::CodeGenOpt::Level opt_level = llvm_target_->GetOptLevel();
switch (opt_level) {
case llvm::CodeGenOpt::Level::None:
builder.OptLevel = 0;
break;
case llvm::CodeGenOpt::Level::Less:
builder.OptLevel = 1;
break;
case llvm::CodeGenOpt::Level::Default:
builder.OptLevel = 2;
break;
default:
// CodeGenOpt::Level::Aggressive
builder.OptLevel = 3;
}
builder.Inliner = llvm::createFunctionInliningPass(builder.OptLevel, 0, false);
builder.LoopVectorize = true;
builder.SLPVectorize = true;
this->InitPassManagerBuilder(&builder);
tm->adjustPassManager(builder);
builder.populateFunctionPassManager(fpass);
builder.populateModulePassManager(mpass);
fpass.doInitialization();
for (auto it = module_->begin(); it != module_->end(); ++it) {
fpass.run(*it);
}
fpass.doFinalization();
mpass.run(*module_);
}
#endif // TVM_LLVM_VERSION
int CodeGenLLVM::NativeVectorBits(const runtime::StorageScope& storage_scope) const {
return native_vector_bits_;
}
unsigned CodeGenLLVM::GetGlobalAddressSpace() const { return 0; }
llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
if (dtype.is_handle()) {
TVM_FFI_ICHECK_EQ(dtype.lanes(), 1);
return t_void_p_;
}
if (dtype.is_void()) {
return t_void_;
}
llvm::Type* etype = nullptr;
llvm::LLVMContext* ctx = llvm_target_->GetContext();
if (dtype.is_int() || dtype.is_uint()) {
etype = llvm::Type::getIntNTy(*ctx, dtype.bits());
} else if (dtype.is_bool()) {
etype = t_int1_;
} else if (dtype.is_float()) {
switch (dtype.bits()) {
case 16:
etype = llvm::Type::getHalfTy(*ctx);
break;
case 32:
etype = llvm::Type::getFloatTy(*ctx);
break;
case 64:
etype = llvm::Type::getDoubleTy(*ctx);
break;
default:
TVM_FFI_THROW(InternalError) << "do not support " << dtype;
}
} else if (dtype.code() == DataType::kFloat8_e3m4 || dtype.code() == DataType::kFloat8_e4m3 ||
dtype.code() == DataType::kFloat8_e4m3b11fnuz ||
dtype.code() == DataType::kFloat8_e4m3fn ||
dtype.code() == DataType::kFloat8_e4m3fnuz || dtype.code() == DataType::kFloat8_e5m2 ||
dtype.code() == DataType::kFloat8_e5m2fnuz ||
dtype.code() == DataType::kFloat8_e8m0fnu) {
etype = llvm::Type::getInt8Ty(*ctx);
} else if (dtype.code() == DataType::kFloat6_e2m3fn || dtype.code() == DataType::kFloat6_e3m2fn) {
etype = llvm::Type::getIntNTy(*ctx, 6);
} else if (dtype.code() == DataType::kFloat4_e2m1fn) {
etype = llvm::Type::getIntNTy(*ctx, 4);
}
if (!dtype.is_scalar()) {
if (dtype.is_scalable_vector()) {
return llvm::VectorType::get(etype, dtype.vscale_factor(), true);
} else {
return llvm::FixedVectorType::get(etype, dtype.lanes());
}
} else {
return etype;
}
} // namespace codegen
llvm::Type* CodeGenLLVM::GetLLVMType(const Type& type) const {
if (auto* ptr = type.as<PrimTypeNode>()) {
return DTypeToLLVMType(ptr->dtype);
} else if (auto* ptr = type.as<PointerTypeNode>()) {
// LLVM IR doesn't allow void*, nor do we require custom datatypes
// to have LLVM equivalents, so we need to recognize these
// patterns explicitly.
if (auto* primtype = ptr->element_type.as<PrimTypeNode>()) {
if (primtype->dtype.is_void() || primtype->dtype.code() >= DataType::kCustomBegin) {
return t_void_p_;
}
} else if (ptr->element_type->IsInstance<TensorMapTypeNode>()) {
return llvmGetPointerTo(t_tvm_tensormap_, 0);
}
// TODO(tvm-team) consider put storage scope into the pointer type.
return llvmGetPointerTo(GetLLVMType(ptr->element_type), GetGlobalAddressSpace());
} else if (IsVoidType(type)) {
return t_void_;
} else if (type->IsInstance<TensorMapTypeNode>()) {
return t_tvm_tensormap_;
} else {
TVM_FFI_THROW(InternalError) << "Type " << type << " does not have a corresponding LLVM Type";
}
}
llvm::Type* CodeGenLLVM::GetLLVMType(const PrimExpr& expr) const {
return GetLLVMType(GetType(expr));
}
// Add tbaa alias information for load
//
// use a binary tree typed system to declare information
// and allow alias to be distinguished across nodes.
//
// This trick comes from Halide's CodeGen_LLVM
//
void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, const VarNode* buffer_var, PrimExpr index,
DataType access_dtype) {
if (alias_var_set_.count(buffer_var) != 0) {
// Mark all possibly aliased pointer as same type.
llvm::MDNode* meta = md_tbaa_alias_set_;
inst->setMetadata("tbaa", md_builder_->createTBAAStructTagNode(meta, meta, 0));
return;
}
int64_t base = 0, width = 0;
arith::PVar<IntImm> pbase, pstride;
arith::PVar<IntImm> planes;
// create meta-data for alias analysis
// Use a group of binary tree ranges of memory banks.
int64_t xwith = 0;
if (arith::ramp(pbase, pstride, planes).Match(index)) {
base = pbase.Eval()->value;
xwith = planes.Eval()->value * pstride.Eval()->value;
} else if (auto* ptr = index.as<tirx::IntImmNode>()) {
base = ptr->value;
xwith = 1;
}
// adjust address index unit to byte
const int64_t unit_bit_width = 8;
const int64_t access_elem_bits = access_dtype.bits() * access_dtype.lanes();
base = base * access_elem_bits / unit_bit_width;
xwith = (xwith * access_elem_bits + unit_bit_width - 1) / unit_bit_width;
if (xwith > 0) {
width = 1;
while (width < xwith) {
width *= 2;
}
while (base % width) {
base -= base % width;
width *= 2;
}
}
llvm::MDNode* meta = md_tbaa_root_;
std::ostringstream buffer_addr;
buffer_addr << buffer_var;
meta = md_builder_->createTBAAScalarTypeNode(buffer_addr.str(), meta);
// create a tree-shape access structure.
if (width != 0) {
for (int64_t w = 1024; w >= width; w /= 2) {
int64_t b = (base / w) * w;
std::stringstream os;
os << buffer_var << ".w" << w << ".b" << b;
meta = md_builder_->createTBAAScalarTypeNode(os.str(), meta);
}
}
inst->setMetadata("tbaa", md_builder_->createTBAAStructTagNode(meta, meta, 0));
}
void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExpr& index,
int* p_alignment, int* p_native_bits) {
int max_align_bits = t.bits();
auto it = alloc_storage_info_.find(buf_var);
if (it != alloc_storage_info_.end()) {
const StorageInfo& info = it->second;
*p_native_bits = NativeVectorBits(
runtime::StorageScope::Create(GetPtrStorageScope(ffi::GetRef<Var>(buf_var))));
max_align_bits = info.alignment * 8;
} else {
*p_native_bits = native_vector_bits_;
}
arith::ModularSet me = analyzer_->modular_set(index);
int64_t base = me->base;
int64_t coeff = me->coeff;
int align_bits = t.bits();
while (align_bits < max_align_bits && base % 2 == 0 && coeff % 2 == 0) {
base = base / 2;
coeff = coeff / 2;
align_bits *= 2;
}
if (align_bits < 8) {
align_bits = 8;
}
*p_alignment = align_bits / 8;
}
llvm::GlobalVariable* CodeGenLLVM::AllocateSharedMemory(DataType dtype, size_t size,
unsigned int shared_address_space,
int alignment,
llvm::GlobalValue::LinkageTypes linkage) {
llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(dtype), size);
llvm::GlobalVariable* global =
new llvm::GlobalVariable(*module_, type, false, linkage, llvm::UndefValue::get(type), "shmem",
nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space);
global->setAlignment(llvm::MaybeAlign(alignment));
return global;
}
std::unique_ptr<CodeGenLLVM::DebugInfo> CodeGenLLVM::CreateDebugInfo(llvm::Module* module) {
auto debug_info = std::make_unique<CodeGenLLVM::DebugInfo>();
debug_info->di_builder_ = std::make_unique<llvm::DIBuilder>(*module);
// TODO(tulloch): pass this information through Span classes to the IRModule instance?
debug_info->file_ = debug_info->di_builder_->createFile("IRModule.CodeGenLLVM", ".");
const int runtime_version = 0;
const bool is_optimized = false;
const char* compiler_flags = "";
debug_info->compilation_unit_ = debug_info->di_builder_->createCompileUnit(
/*Lang=*/llvm::dwarf::DW_LANG_C, /*File=*/debug_info->file_, /*Producer=*/"TVM", is_optimized,
compiler_flags, runtime_version);
return debug_info;
}
void CodeGenLLVM::PushLoopFrame(llvm::BasicBlock* backedge_tgt, llvm::BasicBlock* exit_tgt) {
loop_frame_jump_tgts_.emplace_back(backedge_tgt, exit_tgt);
}
void CodeGenLLVM::PopLoopFrame() { loop_frame_jump_tgts_.pop_back(); }
llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) {
int num_elems = GetVectorNumElements(vec);
if (extent == num_elems && begin == 0) return vec;
TVM_FFI_ICHECK(begin >= 0 && extent <= num_elems) << "Slicing out of bound!\n";
std::vector<llvm::Constant*> indices;
indices.reserve(extent);
for (int i = 0; i < extent; ++i) {
if (begin + i >= 0 && begin + i < num_elems) {
indices.push_back(llvm::ConstantInt::get(t_int32_, begin + i));
} else {
indices.push_back(llvm::UndefValue::get(t_int32_));
}
}
return builder_->CreateShuffleVector(vec, vec, llvm::ConstantVector::get(indices));
}
llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) {
int num_elems = GetVectorNumElements(vec);
std::vector<int> indices;
for (int i = 0; i < num_elems; ++i) {
indices.push_back(num_elems - i - 1);
}
return builder_->CreateShuffleVector(vec, vec, indices);
}
llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) {
llvm::Value* mask = llvm::UndefValue::get(DTypeToLLVMType(DataType::Int(32, target_lanes)));
int num_elems = GetVectorNumElements(vec);
if (num_elems == target_lanes) return vec;
TVM_FFI_ICHECK_LT(num_elems, target_lanes);
for (int i = 0; i < num_elems; ++i) {
mask = builder_->CreateInsertElement(mask, ConstInt32(i), ConstInt32(i));
}
return builder_->CreateShuffleVector(vec, vec, mask);
}
llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector<llvm::Value*> vecs) {
// To allow creating vectors from scalars, convert any scalars in "vecs" to single-lane
// LLVM vector types.
for (size_t i = 0, e = vecs.size(); i != e; ++i) {
llvm::Value* v = vecs[i];
if (!v->getType()->isVectorTy()) {
llvm::Type* vec_ty = llvm::FixedVectorType::get(v->getType(), 1);
vecs[i] = builder_->CreateInsertElement(llvm::UndefValue::get(vec_ty), v, ConstInt32(0));
}
}
// concat vector, tree shape reduction
int total_lanes = 0;
for (llvm::Value* v : vecs) {
total_lanes += GetVectorNumElements(v);
}
while (vecs.size() > 1) {
std::vector<llvm::Value*> new_vecs;
for (size_t i = 0; i < vecs.size() - 1; i += 2) {
llvm::Value* lhs = vecs[i];
llvm::Value* rhs = vecs[i + 1];
const size_t lhs_lanes = GetVectorNumElements(lhs);
const size_t rhs_lanes = GetVectorNumElements(rhs);
if (lhs_lanes < rhs_lanes) {
lhs = CreateVecPad(lhs, rhs_lanes);
} else if (rhs_lanes < lhs_lanes) {
rhs = CreateVecPad(rhs, lhs_lanes);
}
const size_t shared_lanes = std::max(lhs_lanes, rhs_lanes);
std::vector<int> mask;
for (size_t i = 0; i < lhs_lanes; ++i) {
mask.push_back(i);
}
for (size_t i = 0; i < rhs_lanes; ++i) {
mask.push_back(shared_lanes + i);
}
new_vecs.push_back(builder_->CreateShuffleVector(lhs, rhs, mask));
}
if (vecs.size() % 2 != 0) {
new_vecs.push_back(vecs.back());
}
vecs.swap(new_vecs);
}
return CreateVecSlice(vecs[0], 0, total_lanes);
}
void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride,
const Var& loop_var, const Stmt& body) {
llvm::BasicBlock* pre_block = builder_->GetInsertBlock();
std::string loop_var_name = loop_var->name_hint;
llvm::LLVMContext* ctx = llvm_target_->GetContext();
auto* for_begin = llvm::BasicBlock::Create(*ctx, "for_begin_" + loop_var_name, function_);
auto* for_body = llvm::BasicBlock::Create(*ctx, "for_body_" + loop_var_name, function_);
auto* for_end = llvm::BasicBlock::Create(*ctx, "for_end_" + loop_var_name, function_);
auto* for_next = llvm::BasicBlock::Create(*ctx, "for_next_" + loop_var_name, function_);
builder_->CreateBr(for_begin);
builder_->SetInsertPoint(for_begin);
llvm::PHINode* loop_value = builder_->CreatePHI(begin->getType(), 2);
AddDebugInformation(loop_value, loop_var);
loop_value->addIncoming(begin, pre_block);
TVM_FFI_ICHECK(!var_map_.count(loop_var.get()));
var_map_[loop_var.get()] = loop_value;
auto lt = CreateLT(loop_var.dtype(), loop_value, end);
builder_->CreateCondBr(lt, for_body, for_end, md_very_likely_branch_);
builder_->SetInsertPoint(for_body);
EmitDebugLocation(body->span);
PushLoopFrame(for_next, for_end);
this->VisitStmt(body);
PopLoopFrame();
var_map_.erase(loop_var.get());
builder_->CreateBr(for_next);
builder_->SetInsertPoint(for_next);
llvm::Value* loop_next = CreateAdd(loop_var.dtype(), loop_value, stride);
loop_value->addIncoming(loop_next, builder_->GetInsertBlock());
builder_->CreateBr(for_begin);
builder_->SetInsertPoint(for_end);
}
// cast operatpr
llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* value) {
llvm::Type* target = DTypeToLLVMType(to);
if (value->getType() == target) return value;
// TODO(tvm-team): consider add native support
TVM_FFI_ICHECK(!from.is_bfloat16()) << "BF16 needs to be storaged lowered first";
TVM_FFI_ICHECK(!to.is_bfloat16()) << "BF16 needs to be storaged lowered first";
if (to.is_handle()) {
return builder_->CreateBitCast(value, target);
} else if (to.is_bool()) {
if (from.is_float()) {
llvm::Constant* zero = llvm::ConstantFP::get(DTypeToLLVMType(from), 0.);
return builder_->CreateFCmpUNE(value, zero);
} else {
llvm::Constant* zero = llvm::ConstantInt::get(DTypeToLLVMType(from), 0);
return builder_->CreateICmpNE(value, zero);
}
} else if (!from.is_float() && !to.is_float()) {
return builder_->CreateIntCast(value, target, from.is_int());
} else if (from.is_float() && to.is_int()) {
return builder_->CreateFPToSI(value, target);
} else if (from.is_float() && to.is_uint()) {
if (to.bits() < 8) {
value = builder_->CreateFPToUI(value, DTypeToLLVMType(to.with_bits(8)));
return builder_->CreateIntCast(value, target, false);
} else {
return builder_->CreateFPToUI(value, target);
}
} else if (from.is_int() && to.is_float()) {
return builder_->CreateSIToFP(value, target);
} else if ((from.is_uint() || from.is_bool()) && to.is_float()) {
return builder_->CreateUIToFP(value, target);
} else {
TVM_FFI_ICHECK(from.is_float() && to.is_float());
return builder_->CreateFPCast(value, target);
}
}
llvm::Constant* CodeGenLLVM::GetGlobalConstant(llvm::Constant* const_data, const std::string& name,
llvm::GlobalValue::LinkageTypes linkage_type) {
llvm::Type* ty = const_data->getType();
llvm::GlobalVariable* global =
new llvm::GlobalVariable(*module_, ty, true, linkage_type, const_data, name);
global->setAlignment(llvm::Align(1));
llvm::Constant* zero = ConstInt32(0);
llvm::Constant* indices[] = {zero, zero};
llvm::Constant* ptr = llvm::ConstantExpr::getGetElementPtr(ty, global, indices);
return ptr;
}
llvm::Constant* CodeGenLLVM::GetConstString(const std::string& str) {
if (auto it = str_map_.find(str); it != str_map_.end()) {
return it->second;
}
auto llvm_str = llvm::ConstantDataArray::getString(*llvm_target_->GetContext(), str);
auto ptr = GetGlobalConstant(llvm_str, ".str", llvm::GlobalValue::PrivateLinkage);
str_map_[str] = ptr;
return ptr;
}
CodeGenLLVM::TypedPointer CodeGenLLVM::CreateBufferPtr(llvm::Value* buffer_ptr,
DataType buffer_element_dtype,
llvm::ArrayRef<llvm::Value*> indices,
DataType value_dtype) {
TVM_FFI_ICHECK_EQ(indices.size(), 1)
<< "CodeGenLLVM requires all buffers to be flat 1-d buffers.";
llvm::Value* index = indices[0];
llvm::PointerType* buffer_ptr_type = llvm::dyn_cast<llvm::PointerType>(buffer_ptr->getType());
TVM_FFI_ICHECK(buffer_ptr_type != nullptr);
auto address_space = buffer_ptr_type->getAddressSpace();
llvm::Type* element_type = DTypeToLLVMType(buffer_element_dtype);
llvm::PointerType* element_ptr_type =
llvmGetPointerTo(DTypeToLLVMType(buffer_element_dtype), address_space);
llvm::Type* value_type = DTypeToLLVMType(value_dtype);
llvm::PointerType* value_ptr_type = llvmGetPointerTo(value_type, address_space);
TVM_FFI_ICHECK(index->getType()->isIntegerTy()) << "Expected buffer index to be an integer";
if (buffer_ptr_type != element_ptr_type) {
buffer_ptr = builder_->CreatePointerCast(buffer_ptr, element_ptr_type);
}
TVM_FFI_ICHECK(!HasAlignmentPadding(buffer_element_dtype))
<< "DType " << buffer_element_dtype
<< " has padding for alignment. TVM data arrays are expected to be densely packed, with no "
"padding for alignment.";
llvm::Value* value_ptr = builder_->CreateInBoundsGEP(element_type, buffer_ptr, index);
if (element_ptr_type != value_ptr_type) {
value_ptr = builder_->CreatePointerCast(value_ptr, value_ptr_type);
}
return TypedPointer(value_type, value_ptr);
}
llvm::Value* CodeGenLLVM::GetVarValue(const VarNode* v) const {
auto it = var_map_.find(v);
TVM_FFI_ICHECK(it != var_map_.end()) << "cannot find variable " << v->name_hint;
return it->second;
}
void CodeGenLLVM::CreatePrintf(const std::string& format,
llvm::ArrayRef<llvm::Value*> format_args) {
EmitDebugLocation();
llvm::Function* func_printf = module_->getFunction("printf");
if (func_printf == nullptr) {
llvm::FunctionType* ftype = llvm::FunctionType::get(t_int32_, true);
func_printf =
llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, "printf", module_.get());
}
llvm::Function* func_fflush = module_->getFunction("fflush");
if (!func_fflush) {
llvm::FunctionType* ftype = llvm::FunctionType::get(t_int32_, {t_void_p_}, false);
func_fflush =
llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, "fflush", module_.get());
}
#if TVM_LLVM_VERSION >= 200
llvm::Value* str = builder_->CreateGlobalString(format);
#else
llvm::Value* str = builder_->CreateGlobalStringPtr(format);
#endif
str->setName("printf_format_str");
std::vector<llvm::Value*> printf_args = {str};
printf_args.insert(printf_args.end(), format_args.begin(), format_args.end());
builder_->CreateCall(func_printf, printf_args);
// Call fflush() immediately, as this utility is intended for debug
// purposes. A segfault occurring within the generated LLVM code
// would otherwise leave the stdout buffer unflushed.
llvm::Value* null_stream = llvm::ConstantPointerNull::get(t_void_p_);
null_stream->setName("null_stream");
builder_->CreateCall(func_fflush, {null_stream});
}
llvm::Value* CodeGenLLVM::CreateLookupReturnAddress(unsigned int level) {
EmitDebugLocation();
llvm::Value* level_val = llvm::ConstantInt::get(t_int32_, level);
#if TVM_LLVM_VERSION >= 200
llvm::Function* builtin = llvm::cast<llvm::Function>(
llvm::Intrinsic::getOrInsertDeclaration(module_.get(), llvm::Intrinsic::returnaddress, {}));
#else
llvm::Function* builtin =
llvm::Intrinsic::getDeclaration(module_.get(), llvm::Intrinsic::returnaddress);
#endif
llvm::Value* call = builder_->CreateCall(builtin, level_val);
call->setName("return_addr");
return call;
}
llvm::Value* CodeGenLLVM::CreateCallExtern(Type ret_type, ffi::String global_symbol,
const ffi::Array<PrimExpr>& args, bool skip_first_arg) {
std::vector<llvm::Value*> arg_value;
std::vector<llvm::Type*> arg_type;
for (size_t i = static_cast<size_t>(skip_first_arg); i < args.size(); ++i) {
arg_value.push_back(MakeValue(args[i]));
arg_type.push_back(arg_value.back()->getType());
}
llvm::FunctionType* ftype = llvm::FunctionType::get(GetLLVMType(ret_type), arg_type, false);
llvm::Function* f = module_->getFunction(MakeStringRef(global_symbol));
if (f == nullptr) {
f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, MakeStringRef(global_symbol),
module_.get());
}
llvm::CallInst* call = builder_->CreateCall(f, arg_value);
return call;
}
llvm::Function* CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type* ret_type,
llvm::ArrayRef<llvm::Type*> arg_types) {
llvm::Module* module = module_.get();
if (!llvm::Intrinsic::isOverloaded(id)) {
#if TVM_LLVM_VERSION >= 200
return llvm::cast<llvm::Function>(llvm::Intrinsic::getOrInsertDeclaration(module, id, {}));
#else
return llvm::Intrinsic::getDeclaration(module, id, {});
#endif
}
llvm::SmallVector<llvm::Intrinsic::IITDescriptor, 4> infos;
llvm::Intrinsic::getIntrinsicInfoTableEntries(id, infos);
llvm::SmallVector<llvm::Type*, 4> overload_types;
auto try_match = [&](llvm::FunctionType* f_ty, bool var_arg) {
overload_types.clear();
llvm::ArrayRef<llvm::Intrinsic::IITDescriptor> ref(infos);
auto match = llvm::Intrinsic::matchIntrinsicSignature(f_ty, ref, overload_types);
if (match == llvm::Intrinsic::MatchIntrinsicTypes_Match) {
bool error = llvm::Intrinsic::matchIntrinsicVarArg(var_arg, ref);
if (error) {
return llvm::Intrinsic::MatchIntrinsicTypes_NoMatchArg;
}
}
return match;
};
// First, try matching the signature assuming non-vararg case.
auto* fn_ty = llvm::FunctionType::get(ret_type, arg_types, false);
switch (try_match(fn_ty, false)) {
case llvm::Intrinsic::MatchIntrinsicTypes_NoMatchRet:
// The return type doesn't match, there is nothing else to do.
return nullptr;
case llvm::Intrinsic::MatchIntrinsicTypes_Match:
#if TVM_LLVM_VERSION >= 200
return llvm::cast<llvm::Function>(
llvm::Intrinsic::getOrInsertDeclaration(module, id, overload_types));
#else
return llvm::Intrinsic::getDeclaration(module, id, overload_types);
#endif
case llvm::Intrinsic::MatchIntrinsicTypes_NoMatchArg:
break;
}
// Keep adding one type at a time (starting from empty list), and
// try matching the vararg signature.
llvm::SmallVector<llvm::Type*, 4> var_types;
for (int i = 0, e = arg_types.size(); i <= e; ++i) {
if (i > 0) var_types.push_back(arg_types[i - 1]);
auto* ft = llvm::FunctionType::get(ret_type, var_types, true);
if (try_match(ft, true) == llvm::Intrinsic::MatchIntrinsicTypes_Match) {
#if TVM_LLVM_VERSION >= 200
return llvm::cast<llvm::Function>(
llvm::Intrinsic::getOrInsertDeclaration(module, id, overload_types));
#else
return llvm::Intrinsic::getDeclaration(module, id, overload_types);
#endif
}
}
// Failed to identify the type.
return nullptr;
}
void CodeGenLLVM::SetTargetAttributes(llvm::Function* func) {
const std::string& cpu = llvm_target_->GetCPU();
if (!cpu.empty()) {
func->addFnAttr("target-cpu", cpu);
}
const std::string& features = llvm_target_->GetTargetFeatureString();
if (!features.empty()) {
func->addFnAttr("target-features", features);
}
}
void CodeGenLLVM::EmitFloat16ConversionBuiltins(bool use_float16_abi) {
// The LLVM IR for these function was obtained by compiling
//
// For integer ABI:
// __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(a);
// __extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(a);
// For floating-point ABI:
// __truncXfYf2__<float, uint32_t, 23, _Float16, uint16_t, 10>(a);
// __extendXfYf2__<_Float16, uint16_t, 10, float, uint32_t, 23>(a);
static const char trunc_body[] = // __truncsfhf2
" %v0 = bitcast float %a0 to i32\n"
" %v1 = and i32 %v0, 2147483647\n"
" %v2 = add nsw i32 %v1, -947912704\n"
" %v3 = add nsw i32 %v1, -1199570944\n"
" %v4 = icmp ult i32 %v2, %v3\n"
" br i1 %v4, label %b1, label %b5\n"
"b1:\n"
" %v5 = lshr i32 %v0, 13\n"
" %v6 = and i32 %v5, 65535\n"
" %v7 = add nuw nsw i32 %v6, -114688\n"
" %v8 = and i32 %v0, 8191\n"
" %v9 = icmp ugt i32 %v8, 4096\n"
" br i1 %v9, label %b2, label %b3\n"
"b2:\n"
" %v10 = add nuw nsw i32 %v6, -114687\n"
" br label %b13\n"
"b3:\n"
" %v11 = icmp eq i32 %v8, 4096\n"
" br i1 %v11, label %b4, label %b13\n"
"b4:\n"
" %v12 = and i32 %v7, 65535\n"
" %v13 = and i32 %v5, 1\n"
" %v14 = add nuw nsw i32 %v12, %v13\n"
" br label %b13\n"
"b5:\n"
" %v15 = icmp ugt i32 %v1, 2139095040\n"
" br i1 %v15, label %b6, label %b7\n"
"b6:\n"
" %v16 = lshr i32 %v0, 13\n"
" %v17 = and i32 %v16, 511\n"
" %v18 = or i32 %v17, 32256\n"
" br label %b13\n"
"b7:\n"
" %v19 = icmp ugt i32 %v1, 1199570943\n"
" br i1 %v19, label %b13, label %b8\n"
"b8:\n"
" %v20 = icmp ult i32 %v1, 754974720\n"
" br i1 %v20, label %b13, label %b9\n"
"b9:\n"
" %v21 = lshr i32 %v1, 23\n"
" %v22 = sub nsw i32 113, %v21\n"
" %v23 = and i32 %v0, 8388607\n"
" %v24 = or i32 %v23, 8388608\n"
" %v25 = add nsw i32 %v21, -81\n"
" %v26 = shl i32 %v24, %v25\n"
" %v27 = icmp ne i32 %v26, 0\n"
" %v28 = lshr i32 %v24, %v22\n"
" %v29 = zext i1 %v27 to i32\n"
" %v30 = lshr i32 %v28, 13\n"
" %v31 = and i32 %v28, 8191\n"
" %v32 = or i32 %v31, %v29\n"
" %v33 = icmp ugt i32 %v32, 4096\n"
" br i1 %v33, label %b10, label %b11\n"
"b10:\n"
" %v34 = add nuw nsw i32 %v30, 1\n"
" br label %b13\n"
"b11:\n"
" %v35 = icmp eq i32 %v32, 4096\n"
" br i1 %v35, label %b12, label %b13\n"
"b12:\n"
" %v36 = and i32 %v30, 1\n"
" %v37 = add nuw nsw i32 %v36, %v30\n"
" br label %b13\n"
"b13:\n"
" %v38 = phi i32 [ %v18, %b6 ], [ %v10, %b2 ], [ %v14, %b4 ], [ %v7, %b3 ],\n"
" [ 31744, %b7 ], [ 0, %b8 ], [ %v34, %b10 ], [ %v37, %b12 ],\n"
" [ %v30, %b11 ]\n"
" %v39 = lshr i32 %v0, 16\n"
" %v40 = and i32 %v39, 32768\n"
" %v41 = or i32 %v38, %v40\n"
" %vlast = trunc i32 %v41 to i16\n";
static const char extend_body[] = // __extendhfsf2
" %v1 = and i16 %vinp, 32767\n"
" %v2 = zext i16 %v1 to i32\n"
" %v3 = add nsw i16 %v1, -1024\n"
" %v4 = icmp ult i16 %v3, 30720\n"
" br i1 %v4, label %b1, label %b2\n"
"b1:\n"
" %v5 = shl nuw nsw i32 %v2, 13\n"
" %v6 = add nuw nsw i32 %v5, 939524096\n"
" br label %b6\n"
"b2:\n"
" %v7 = icmp ugt i16 %v1, 31743\n"
" br i1 %v7, label %b3, label %b4\n"
"b3:\n"
" %v8 = shl nuw nsw i32 %v2, 13\n"
" %v9 = or i32 %v8, 2139095040\n"
" br label %b6\n"
"b4:\n"
" %v10 = icmp eq i16 %v1, 0\n"
" br i1 %v10, label %b6, label %b5\n"
"b5:\n"
" %v11 = icmp ult i16 %v1, 256\n"
" %v12 = lshr i32 %v2, 8\n"
" %v13 = select i1 %v11, i32 %v2, i32 %v12\n"
" %v14 = select i1 %v11, i32 32, i32 24\n"
" %v15 = icmp ult i32 %v13, 16\n"
" %v16 = lshr i32 %v13, 4\n"
" %v17 = add nsw i32 %v14, -4\n"
" %v18 = select i1 %v15, i32 %v13, i32 %v16\n"
" %v19 = select i1 %v15, i32 %v14, i32 %v17\n"
" %v20 = icmp ult i32 %v18, 4\n"
" %v21 = lshr i32 %v18, 2\n"
" %v22 = add nsw i32 %v19, -2\n"
" %v23 = select i1 %v20, i32 %v18, i32 %v21\n"
" %v24 = select i1 %v20, i32 %v19, i32 %v22\n"
" %v25 = icmp ult i32 %v23, 2\n"
" %v26 = sub nsw i32 0, %v23\n"
" %v27 = select i1 %v25, i32 %v26, i32 -2\n"
" %v28 = add nsw i32 %v27, %v24\n"
" %v29 = add nsw i32 %v28, -8\n"
" %v30 = shl i32 %v2, %v29\n"
" %v31 = xor i32 %v30, 8388608\n"
" %v32 = shl i32 %v28, 23\n"
" %v33 = sub i32 1124073472, %v32\n"
" %v34 = or i32 %v31, %v33\n"
" br label %b6\n"
"b6:\n"
" %v35 = phi i32 [ %v6, %b1 ], [ %v9, %b3 ], [ %v34, %b5 ], [ 0, %b4 ]\n"
" %v36 = and i16 %vinp, -32768\n"
" %v37 = zext i16 %v36 to i32\n"
" %v38 = shl nuw i32 %v37, 16\n"
" %v39 = or i32 %v35, %v38\n"
" %v40 = bitcast i32 %v39 to float\n"
" ret float %v40\n"
"}\n";
std::string short_type = use_float16_abi ? "half" : "i16";
std::string short_cast_in, short_cast_out;
if (use_float16_abi) {
short_cast_in = " %vinp = bitcast half %a0 to i16\n";
short_cast_out = " %vres = bitcast i16 %vlast to half\n";
} else {
// No-ops that preserve the i16 values.
short_cast_in = " %vinp = add i16 %a0, 0\n";
short_cast_out = " %vres = add i16 %vlast, 0\n";
}
llvm::Triple triple(llvm_target_->GetTargetTriple());
static const char elf_section_name[] = ".text.tvm.fp16.conv";
std::string section = triple.getObjectFormat() == llvm::Triple::ELF
? std::string("section \"") + elf_section_name + "\" "
: "";
std::string trunc_header = "define weak dso_local " + short_type +
" @__truncsfhf2(float %a0) local_unnamed_addr #0 " + section +
"{\nb0:\n";
std::string trunc_return = " ret " + short_type + " %vres\n}\n";
std::string extend_header = "define weak dso_local float @__extendhfsf2(" + short_type +
" %a0) local_unnamed_addr #0 " + section + "{\nb0:\n";
// truncate = trunc_header + trunc_body + short_cast_out + trunc_return
// extend = extend_header + short_cast_in + extend_body
std::string attributes = "attributes #0 = { nounwind readnone \"target-cpu\"=\"" +
llvm_target_->GetCPU() + "\" \"target-features\"=\"" +
llvm_target_->GetTargetFeatureString() + "\" }\n";
auto data_layout = llvm_target_->GetOrCreateTargetMachine()->createDataLayout();
std::string module_ir = "target triple = \"" + llvm_target_->GetTargetTriple() + "\"\n" +
"target datalayout = \"" + data_layout.getStringRepresentation() +
"\"\n" + trunc_header + trunc_body + short_cast_out + trunc_return +
extend_header + short_cast_in + extend_body + attributes;
auto builtins_module = llvm_target_->GetInstance().ParseIR(module_ir);
link_modules_.push_back(std::move(builtins_module));
}
llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
if (op->op.same_as(builtin_call_llvm_intrin_) || op->op.same_as(builtin_call_llvm_pure_intrin_)) {
TVM_FFI_ICHECK_GE(op->args.size(), 1U);
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
std::vector<llvm::Value*> arg_value;
std::vector<llvm::Type*> arg_type;
for (size_t i = 1; i < op->args.size(); ++i) {
arg_value.push_back(MakeValue(op->args[i]));
arg_type.push_back(arg_value.back()->getType());
}
llvm::Type* return_type = GetLLVMType(ffi::GetRef<PrimExpr>(op));
llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type);
TVM_FFI_ICHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: "
<< llvmGetIntrinName(id);
// In earlier versions of LLVM's, the prefetch intrinsic is not
// overloaded, and always takes the first argument as i8*. If
// this is the case, this argument should insert a cast to i8*.
if (id == llvm::Intrinsic::prefetch) {
llvm::Type* param_type = f->arg_begin()->getType();
if (param_type != arg_value[0]->getType()) {
unsigned addrspace =
llvm::dyn_cast<llvm::PointerType>(arg_value[0]->getType())->getAddressSpace();
arg_value[0] =
builder_->CreatePointerCast(arg_value[0], llvmGetPointerTo(t_char_, addrspace));
}
}
return builder_->CreateCall(f, arg_value);
} else if (op->op.same_as(builtin::bitwise_and())) {
return builder_->CreateAnd(MakeValue(op->args[0]), MakeValue(op->args[1]));
} else if (op->op.same_as(builtin::bitwise_or())) {
return builder_->CreateOr(MakeValue(op->args[0]), MakeValue(op->args[1]));
} else if (op->op.same_as(builtin::bitwise_not())) {
return builder_->CreateNot(MakeValue(op->args[0]));
} else if (op->op.same_as(builtin::bitwise_xor())) {
return builder_->CreateXor(MakeValue(op->args[0]), MakeValue(op->args[1]));
} else if (op->op.same_as(builtin::shift_left())) {
return builder_->CreateShl(MakeValue(op->args[0]), MakeValue(op->args[1]));
} else if (op->op.same_as(builtin::shift_right())) {
if (op->args[0].dtype().is_int()) {
return builder_->CreateAShr(MakeValue(op->args[0]), MakeValue(op->args[1]));
} else {
return builder_->CreateLShr(MakeValue(op->args[0]), MakeValue(op->args[1]));
}
} else if (op->op.same_as(builtin::tvm_storage_sync())) {
return CreateStorageSync(op);
} else if (op->op.same_as(builtin::address_of())) {
const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
TVM_FFI_ICHECK(op->args.size() == 1 && load);
ffi::Array<PrimExpr> indices = load->indices;
if (const RampNode* r = indices[indices.size() - 1].as<RampNode>()) {
indices.Set(indices.size() - 1, r->base);
}
std::vector<llvm::Value*> indices_val;
for (const auto& index : indices) {
indices_val.push_back(MakeValue(index));
}
TypedPointer buffer_ptr = CreateBufferPtr(MakeValue(load->buffer->data), load->buffer->dtype,
indices_val, load->dtype);
return buffer_ptr.addr;
} else if (op->op.same_as(builtin::reinterpret()) && is_zero(op->args[0])) {
return llvm::Constant::getNullValue(t_void_p_);
} else if (op->op.same_as(builtin::isnullptr())) {
return builder_->CreateIsNull(MakeValue(op->args[0]));
} else if (op->op.same_as(builtin::handle_add_byte_offset())) {
llvm::Value* ptr = MakeValue(op->args[0]);
llvm::Value* offset = MakeValue(op->args[1]);
return builder_->CreateInBoundsGEP(t_int8_, ptr, offset);
} else if (op->op.same_as(builtin::large_uint_imm())) {
TVM_FFI_ICHECK_EQ(op->args.size(), 2U);
uint64_t low = static_cast<uint64_t>(Downcast<IntImm>(op->args[0])->value);
uint64_t high = static_cast<uint64_t>(Downcast<IntImm>(op->args[1])->value);
uint64_t val = (high << 32U) | low;
return llvm::ConstantInt::get(DTypeToLLVMType(op->dtype), val);
} else if (op->op.same_as(builtin::if_then_else())) {
TVM_FFI_ICHECK_EQ(op->args[0].dtype().lanes(), 1)
<< "if_then_else can only take scalar condition";
llvm::LLVMContext* ctx = llvm_target_->GetContext();
auto* then_block = llvm::BasicBlock::Create(*ctx, "if_then", function_);
auto* else_block = llvm::BasicBlock::Create(*ctx, "if_else", function_);
auto* end_block = llvm::BasicBlock::Create(*ctx, "if_end", function_);
builder_->CreateCondBr(MakeValue(op->args[0]), then_block, else_block);
builder_->SetInsertPoint(then_block);
llvm::Value* then_value = MakeValue(op->args[1]);
llvm::BasicBlock* then_value_block = builder_->GetInsertBlock();
builder_->CreateBr(end_block);
builder_->SetInsertPoint(else_block);
llvm::Value* else_value = MakeValue(op->args[2]);
llvm::BasicBlock* else_value_block = builder_->GetInsertBlock();
builder_->CreateBr(end_block);
builder_->SetInsertPoint(end_block);
llvm::PHINode* value = builder_->CreatePHI(then_value->getType(), 2);
value->addIncoming(then_value, then_value_block);
value->addIncoming(else_value, else_value_block);
return value;
} else if (op->op.same_as(builtin::ret())) {
auto const* val = op->args[0].as<IntImmNode>();
TVM_FFI_ICHECK(val) << "the tirx.ret should be transformed to return zero "
<< "before the llvm code generation.";
TVM_FFI_ICHECK_EQ(val->value, 0) << "the tirx.ret should be transformed to "
<< "return zero before the llvm code generation.";
builder_->CreateRet(ConstInt32(0));
// LLVM allows exactly one terminator in a single basic block
// append a new dummy basic block to avoid error.
llvm::BasicBlock* ret_dummy =
llvm::BasicBlock::Create(*llvm_target_->GetContext(), "ret_dummy", function_);
builder_->SetInsertPoint(ret_dummy);
return ret_dummy;
} else if (op->op.same_as(builtin::continue_loop())) {
TVM_FFI_ICHECK(!loop_frame_jump_tgts_.empty())
<< "the tirx.continue_loop should be inserted under at least one For or While stmts.";
builder_->CreateBr(loop_frame_jump_tgts_.back().first);
// LLVM allows exactly one terminator in a single basic block
// append a new dummy basic block to avoid error.
llvm::BasicBlock* post_dummy =
llvm::BasicBlock::Create(*llvm_target_->GetContext(), "post_cont_dummy", function_);
builder_->SetInsertPoint(post_dummy);
return post_dummy;
} else if (op->op.same_as(builtin::break_loop())) {
TVM_FFI_ICHECK(!loop_frame_jump_tgts_.empty())
<< "the tirx.break_loop should be inserted under at least one For or While stmts.";
builder_->CreateBr(loop_frame_jump_tgts_.back().second);
// LLVM allows exactly one terminator in a single basic block
// append a new dummy basic block to avoid error.
llvm::BasicBlock* post_dummy =
llvm::BasicBlock::Create(*llvm_target_->GetContext(), "post_break_dummy", function_);
builder_->SetInsertPoint(post_dummy);
return post_dummy;
} else if (op->op.same_as(builtin::reinterpret())) {
llvm::Type* target = DTypeToLLVMType(op->dtype);
llvm::Value* value = MakeValue(op->args[0]);
if (value->getType()->isPointerTy() && target->isIntegerTy()) {
return builder_->CreatePtrToInt(value, target);
} else if (value->getType()->isIntegerTy() && target->isPointerTy()) {
return builder_->CreateIntToPtr(value, target);
}
return builder_->CreateBitCast(value, target);
} else if (op->op.same_as(builtin::isnan())) {
// TODO(hgt312): set fast math flag
llvm::Value* a = MakeValue(op->args[0]);
return builder_->CreateFCmpUNO(a, a);
} else if (op->op.same_as(builtin::vectorlow())) {
llvm::Value* v = MakeValue(op->args[0]);
int l = GetVectorNumElements(v);
return CreateVecSlice(v, 0, l / 2);
} else if (op->op.same_as(builtin::vectorhigh())) {
llvm::Value* v = MakeValue(op->args[0]);
int l = GetVectorNumElements(v);
return CreateVecSlice(v, l / 2, l / 2);
} else if (op->op.same_as(builtin::vectorcombine())) {
llvm::Value* v0 = MakeValue(op->args[0]);
llvm::Value* v1 = MakeValue(op->args[1]);
int num_elems = GetVectorNumElements(v0) * 2;
std::vector<int> indices;
for (int i = 0; i < num_elems; ++i) {
indices.push_back(i);
}
return builder_->CreateShuffleVector(v0, v1, indices);
} else if (op->op.same_as(builtin::atomic_add())) {
// TODO(masahi): Support atomic for CPU backend
TVM_FFI_THROW(InternalError) << "CPU backend does not support atomic add yet.";
} else if (op->op.same_as(builtin::start_profile_intrinsic()) ||
op->op.same_as(builtin::end_profile_intrinsic())) {
LOG(INFO) << "Ignoring profile_intrinsic ... " << op->op;
return nullptr;
} else if (op->op.same_as(builtin::assume())) {
llvm::Value* cond = MakeValue(op->args[0]);
return builder_->CreateAssumption(cond);
} else if (op->op.same_as(builtin::tvm_thread_invariant())) {
return MakeValue(op->args[0]);
} else if (op->op.same_as(builtin::vscale())) {
llvm::Intrinsic::ID id = llvm::Intrinsic::vscale;
llvm::Function* f = GetIntrinsicDecl(id, builder_->getInt32Ty(), {});
return builder_->CreateCall(f);
} else if (op->op.same_as(builtin::get_active_lane_mask())) {
llvm::Intrinsic::ID id = llvm::Intrinsic::get_active_lane_mask;
llvm::Function* f = GetIntrinsicDecl(id, DTypeToLLVMType(op->dtype),
{builder_->getInt32Ty(), builder_->getInt32Ty()});
return builder_->CreateCall(f, {MakeValue(op->args[0]), MakeValue(op->args[1])});
} else {
TVM_FFI_THROW(InternalError) << "unknown intrinsic " << op->op;
}
}
void CodeGenLLVM::Scalarize(const PrimExpr& e, std::function<void(int i, llvm::Value* v)> f) {
if (const RampNode* ramp = e.as<RampNode>()) {
for (int i = 0; i < ramp->dtype.lanes(); ++i) {
PrimExpr offset = ramp->base + (ramp->stride * i);
f(i, MakeValue(offset));
}
} else {
llvm::Value* value = MakeValue(e);
for (int i = 0; i < e.dtype().lanes(); ++i) {
f(i, builder_->CreateExtractElement(value, i));
}
}
}
// Visitors
llvm::Value* CodeGenLLVM::VisitExpr_(const VarNode* op) { return GetVarValue(op); }
llvm::Value* CodeGenLLVM::VisitExpr_(const CastNode* op) {
return CreateCast(op->value.dtype(), op->dtype, MakeValue(op->value));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const IntImmNode* op) {
return llvm::ConstantInt::getSigned(DTypeToLLVMType(op->dtype), op->value);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImmNode* op) {
return llvm::ConstantFP::get(DTypeToLLVMType(op->dtype), op->value);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) { return GetConstString(op->value); }
#define DEFINE_CODEGEN_BINARY_OP(Op) \
llvm::Value* CodeGenLLVM::Create##Op(DataType t, llvm::Value* a, llvm::Value* b) { \
if (t.is_int()) { \
if (t.bits() >= 32) { \
return builder_->CreateNSW##Op(a, b); \
} else { \
return builder_->Create##Op(a, b); \
} \
} else if (t.is_uint()) { \
if (t.bits() >= 32) { \
return builder_->CreateNUW##Op(a, b); \
} else { \
return builder_->Create##Op(a, b); \
} \
} else { \
TVM_FFI_ICHECK(t.is_float()); \
return builder_->CreateF##Op(a, b); \
} \
} \
llvm::Value* CodeGenLLVM::VisitExpr_(const Op##Node* op) { \
return Create##Op(op->dtype, MakeValue(op->a), MakeValue(op->b)); \
}
DEFINE_CODEGEN_BINARY_OP(Add);
DEFINE_CODEGEN_BINARY_OP(Sub);
DEFINE_CODEGEN_BINARY_OP(Mul);
#define DEFINE_CODEGEN_CMP_OP(Op) \
llvm::Value* CodeGenLLVM::Create##Op(DataType t, llvm::Value* a, llvm::Value* b) { \
if (t.is_int()) { \
return builder_->CreateICmpS##Op(a, b); \
} else if (t.is_uint()) { \
return builder_->CreateICmpU##Op(a, b); \
} else { \
TVM_FFI_ICHECK(t.is_float()); \
return builder_->CreateFCmpO##Op(a, b); \
} \
} \
llvm::Value* CodeGenLLVM::VisitExpr_(const Op##Node* op) { \
return Create##Op(op->a.dtype(), MakeValue(op->a), MakeValue(op->b)); \
}
DEFINE_CODEGEN_CMP_OP(LT);
DEFINE_CODEGEN_CMP_OP(LE);
DEFINE_CODEGEN_CMP_OP(GT);
DEFINE_CODEGEN_CMP_OP(GE);
llvm::Value* CodeGenLLVM::VisitExpr_(const DivNode* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
if (op->dtype.is_int()) {
return builder_->CreateSDiv(a, b);
} else if (op->dtype.is_uint()) {
return builder_->CreateUDiv(a, b);
} else {
TVM_FFI_ICHECK(op->dtype.is_float());
return builder_->CreateFDiv(a, b);
}
}
llvm::Value* CodeGenLLVM::VisitExpr_(const ModNode* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
if (op->dtype.is_int()) {
return builder_->CreateSRem(a, b);
} else if (op->dtype.is_uint()) {
return builder_->CreateURem(a, b);
} else {
TVM_FFI_ICHECK(op->dtype.is_float());
return builder_->CreateFRem(a, b);
}
}
llvm::Value* CodeGenLLVM::VisitExpr_(const MinNode* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
return builder_->CreateSelect(CreateLT(op->a.dtype(), a, b), a, b);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const MaxNode* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
return builder_->CreateSelect(CreateGT(op->a.dtype(), a, b), a, b);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const EQNode* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
if (op->a.dtype().is_int() || op->a.dtype().is_uint()) {
return builder_->CreateICmpEQ(a, b);
} else {
return builder_->CreateFCmpOEQ(a, b);
}
}
llvm::Value* CodeGenLLVM::VisitExpr_(const NENode* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
if (op->a.dtype().is_int() || op->a.dtype().is_uint()) {
return builder_->CreateICmpNE(a, b);
} else {
return builder_->CreateFCmpONE(a, b);
}
}
llvm::Value* CodeGenLLVM::VisitExpr_(const AndNode* op) {
return builder_->CreateAnd(MakeValue(op->a), MakeValue(op->b));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const OrNode* op) {
return builder_->CreateOr(MakeValue(op->a), MakeValue(op->b));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const NotNode* op) {
return builder_->CreateNot(MakeValue(op->a));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const SelectNode* op) {
return builder_->CreateSelect(MakeValue(op->condition), MakeValue(op->true_value),
MakeValue(op->false_value));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const LetNode* op) {
auto it = let_binding_.find(op->var);
if (it != let_binding_.end()) {
TVM_FFI_ICHECK(deep_equal_(it->second->value, op->value))
<< "Let cannot bind the same var to two different values";
} else {
let_binding_[op->var] = op;
}
auto var_value = MakeValue(op->value);
var_map_[op->var.get()] = var_value;
AddDebugInformation(var_value, op->var);
analyzer_->Bind(op->var, op->value);
return MakeValue(op->body);
}
bool CodeGenLLVM::HasAlignmentPadding(DataType dtype) {
const llvm::DataLayout& data_layout = module_->getDataLayout();
int bytes = data_layout.getTypeAllocSize(DTypeToLLVMType(dtype));
int bytes_scalar = data_layout.getTypeAllocSize(DTypeToLLVMType(dtype.element_of()));
return bytes != bytes_scalar * dtype.lanes();
}
void CodeGenLLVM::BufferAccessHelper(
Buffer buffer, ffi::Array<PrimExpr> indices, ffi::Optional<PrimExpr> predicate,
DataType value_dtype,
std::function<llvm::Instruction*(TypedPointer buffer_ptr, int subelement_i,
llvm::Value* predicate, int alignment, bool is_volatile)>
make_instruction) {
DataType buffer_element_dtype = buffer->dtype;
TVM_FFI_ICHECK_GE(indices.size(), 1)
<< "Buffer " << buffer->name << " is accessed with no indices. "
<< "0-d scalar buffers are expected to be flattened to 1-d buffers prior to codegen.";
// Only the last index is allowed to be multi-lane. All earlier
// indices must be scalar. This only matters for subclasses of
// CodeGenLLVM, because the default implementation of GetBufferPtr
// requires 1-d indices.
std::vector<llvm::Value*> earlier_index_values;
for (size_t i = 0; i < indices.size() - 1; i++) {
TVM_FFI_ICHECK_EQ(indices[i].dtype().lanes(), 1)
<< "Buffer " << buffer->name << " is accessed with a multi-lane index at position " << i
<< ". Multi-lane indices are only supported as the last index.";
earlier_index_values.push_back(MakeValue(indices[i]));
}
PrimExpr last_index = indices[indices.size() - 1];
TVM_FFI_ICHECK_EQ(value_dtype.get_lanes_or_vscale_factor(),
last_index.dtype().get_lanes_or_vscale_factor() * buffer_element_dtype.lanes());
// Record index and elemtype in original form used for alias info
PrimExpr last_index_origin = last_index;
DataType buffer_element_dtype_origin = buffer_element_dtype;
bool is_volatile = volatile_buf_.count(buffer->data.get());
// If the buffer index is a contiguous ramp node, we only need to
// access the first element, then cast to the value type.
if (const RampNode* ramp_index = last_index.as<RampNode>()) {
if (is_one(ramp_index->stride)) {
last_index = ramp_index->base;
}
}
// All TVM arrays are densely packed. If the vectorized LLVM type
// contains padding for alignment, we need to index based on the
// size of the scalar type to avoid introducing that padding.
if (last_index.dtype().lanes() == 1 && HasAlignmentPadding(buffer_element_dtype)) {
last_index = buffer_element_dtype.lanes() * last_index;
buffer_element_dtype = buffer_element_dtype.element_of();
}
int alignment;
if (last_index.dtype().lanes() == 1) {
// If we are accessing with a single index, then the vectorized
// element being accessed may require more alignment than the
// underlying data type.
int native_bits;
GetAlignment(value_dtype, buffer->data.get(), last_index, &alignment, &native_bits);
} else {
// Otherwise, alignment is based on the return value's scalar
// type.
TVM_FFI_ICHECK_GE(value_dtype.bits(), 8);
alignment = value_dtype.bits() / 8;
}
llvm::Value* cached_vector_index = nullptr;
for (int i = 0; i < last_index.dtype().lanes(); ++i) {
llvm::Value* last_index_value;
int subelement_i = i;
if (const RampNode* ramp = last_index.as<RampNode>()) {
PrimExpr offset = ramp->base + (ramp->stride * i);
last_index_value = MakeValue(offset);
} else if (last_index.dtype().is_vector()) {
if (i == 0) {
cached_vector_index = MakeValue(last_index);
}
last_index_value = builder_->CreateExtractElement(cached_vector_index, i);
} else {
last_index_value = MakeValue(last_index);
subelement_i = -1;
}
std::vector<llvm::Value*> all_index_values = earlier_index_values;
all_index_values.push_back(last_index_value);
llvm::Value* predicate_value = nullptr;
if (predicate.defined()) {
predicate_value = MakeValue(predicate.value());
}
TypedPointer buffer_ptr =
value_dtype.is_scalable_vector()
? CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, all_index_values,
value_dtype.with_scalable_vscale_factor(value_dtype.vscale_factor() /
last_index.dtype().lanes()))
: CreateBufferPtr(
MakeValue(buffer->data), buffer_element_dtype, all_index_values,
value_dtype.with_lanes(value_dtype.lanes() / last_index.dtype().lanes()));
auto instruction =
make_instruction(buffer_ptr, subelement_i, predicate_value, alignment, is_volatile);
AddAliasInfo(instruction, buffer->data.get(), last_index_origin, buffer_element_dtype_origin);
}
}
llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) {
DataType value_dtype = op->dtype;
std::vector<llvm::Value*> loads;
auto make_load = [this, &loads](TypedPointer buffer_ptr, int /* subelement_i */,
llvm::Value* predicate, int alignment, bool is_volatile) {
llvm::Instruction* load = nullptr;
if (predicate != nullptr) {
TVM_FFI_ICHECK(!is_volatile)
<< "The masked load intrinsic does not support declaring load as volatile.";
load = builder_->CreateMaskedLoad(buffer_ptr.type, buffer_ptr.addr, llvm::Align(alignment),
predicate);
} else {
load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, llvm::Align(alignment),
is_volatile);
}
loads.push_back(load);
return load;
};
// Pass all indices into BufferAccessHelper. In CodeGenLLVM,
// non-flat indices will result in an error in CreateBufferPtr, but
// a subclass may override CreateBufferPtr.
BufferAccessHelper(op->buffer, op->indices, op->predicate, value_dtype, make_load);
if (loads.size() == 1) {
return loads[0];
} else {
llvm::Value* ret = llvm::UndefValue::get(DTypeToLLVMType(value_dtype));
for (size_t i = 0; i < loads.size(); i++) {
ret = builder_->CreateInsertElement(ret, loads[i], ConstInt32(i));
}
return ret;
}
}
llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
if (auto opt_call_op = op->op.as<Op>()) {
auto call_op = opt_call_op.value();
if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) {
// call extern intrinsic
TVM_FFI_ICHECK_GE(op->args.size(), 1U);
auto global_symbol = Downcast<StringImm>(op->args[0]);
return this->CreateCallExtern(GetType(ffi::GetRef<PrimExpr>(op)), global_symbol->value,
op->args, true);
} else if (op_attr_global_symbol_.count(call_op)) {
// call extern if the op itself have a global symbol.
return this->CreateCallExtern(GetType(ffi::GetRef<PrimExpr>(op)),
op_attr_global_symbol_[call_op], op->args, false);
} else {
VLOG(2) << "CreateIntrinsic: " << ffi::GetRef<Call>(op);
auto x = CreateIntrinsic(op);
VLOG(2) << "CreateIntrinsic done";
return x;
}
} else if (auto* ptr_gvar = op->op.as<GlobalVarNode>()) {
auto gvar = ffi::GetRef<GlobalVar>(ptr_gvar);
auto it = functions_.find(ptr_gvar);
TVM_FFI_ICHECK(it != functions_.end()) << "Call to undefined GlobalVar \"" << gvar << "\"";
llvm::Function* callee = it->second;
std::vector<llvm::Value*> arg_value;
for (const auto& arg : op->args) {
arg_value.push_back(MakeValue(arg));
}
return builder_->CreateCall(callee, arg_value);
} else {
TVM_FFI_THROW(InternalError) << "Unsupported operation in CallNode: " << op->op;
}
}
llvm::Value* CodeGenLLVM::VisitExpr_(const RampNode* op) {
llvm::Value* vec = llvm::UndefValue::get(DTypeToLLVMType(op->dtype));
// TODO(ekalda): P4 in https://github.com/apache/tvm/issues/16455
TVM_FFI_ICHECK(!op->dtype.is_scalable_vector());
int lanes = op->dtype.lanes();
for (int i = 0; i < lanes; ++i) {
vec = builder_->CreateInsertElement(
vec, MakeValue(op->base + op->stride * make_const(op->stride.dtype(), i)), ConstInt32(i));
}
return vec;
}
llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) {
std::vector<llvm::Value*> vecs(op->vectors.size());
int total_lanes = 0;
for (int i = 0, e = op->vectors.size(); i < e; ++i) {
vecs[i] = VisitExpr(op->vectors[i]);
total_lanes += op->vectors[i].dtype().lanes();
}
llvm::Value* v0 = CreateVecConcat(vecs);
std::vector<uint32_t> idx(op->indices.size());
for (int i = 0, e = op->indices.size(); i < e; ++i) {
const int64_t* val = as_const_int(op->indices[i]);
TVM_FFI_ICHECK(val && *val >= 0 && *val < total_lanes)
<< "Shuffled indeces are suppose to be int, "
<< "but get " << op->indices[i] << "\n";
idx[i] = *val;
}
llvm::Value* mask = llvm::ConstantDataVector::get(builder_->getContext(), idx);
auto res = builder_->CreateShuffleVector(v0, llvm::UndefValue::get(v0->getType()), mask);
// If the output is a single-element vector, convert it back to a scalar.
if (idx.size() == 1) {
res = builder_->CreateExtractElement(res, ConstInt32(0));
}
return res;
}
llvm::Value* CodeGenLLVM::VisitExpr_(const BroadcastNode* op) {
DataType dtype = op->dtype;
llvm::Value* value = MakeValue(op->value);
llvm::Type* type = DTypeToLLVMType(dtype);
llvm::Constant* undef = llvm::UndefValue::get(type);
llvm::Constant* zero = ConstInt32(0);
value = builder_->CreateInsertElement(undef, value, zero);
llvm::ElementCount ec =
llvm::ElementCount::get(dtype.get_lanes_or_vscale_factor(), dtype.is_scalable_vector());
llvm::Constant* mask = llvm::ConstantVector::getSplat(ec, zero);
return builder_->CreateShuffleVector(value, undef, mask);
}
void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) {
EmitDebugLocation(op);
DataType value_dtype = op->value.dtype();
Var buffer_var = op->buffer->data;
llvm::Value* value = MakeValue(op->value);
auto make_store = [this, value](TypedPointer buffer_ptr, int subelement_i, llvm::Value* predicate,
int alignment, bool is_volatile) {
llvm::Value* to_store = value;
llvm::Instruction* store;
if (subelement_i != -1) {
to_store = builder_->CreateExtractElement(value, subelement_i);
}
if (predicate != nullptr) {
TVM_FFI_ICHECK(!is_volatile)
<< "The masked store intrinsic does not support declaring store as volatile.";
store =
builder_->CreateMaskedStore(to_store, buffer_ptr.addr, llvm::Align(alignment), predicate);
} else {
store = builder_->CreateAlignedStore(to_store, buffer_ptr.addr, llvm::Align(alignment),
is_volatile);
}
return store;
};
// Pass all indices into BufferAccessHelper. In CodeGenLLVM,
// non-flat indices will result in an error in CreateBufferPtr, but
// a subclass may override CreateBufferPtr.
BufferAccessHelper(op->buffer, op->indices, op->predicate, value_dtype, make_store);
}
void CodeGenLLVM::VisitStmt_(const ForNode* op) {
EmitDebugLocation(op);
analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
if (op->kind == ForKind::kUnrolled) {
LOG(WARNING) << "Unroll hint get ignore at CodeGenLLVM backend, "
<< " consider set unroll_explicit=True";
} else {
TVM_FFI_ICHECK(op->kind == ForKind::kSerial);
}
PrimExpr step = op->step.value_or(make_const(op->extent->dtype, 1));
PrimExpr end = is_zero(op->min) ? op->extent : analyzer_->Simplify(op->min + op->extent);
llvm::Value* begin_value = MakeValue(op->min);
llvm::Value* end_value = MakeValue(end);
CreateSerialFor(begin_value, end_value, MakeValue(step), op->loop_var, op->body);
}
void CodeGenLLVM::VisitStmt_(const WhileNode* op) {
EmitDebugLocation(op);
llvm::LLVMContext* ctx = llvm_target_->GetContext();
auto* while_cond = llvm::BasicBlock::Create(*ctx, "while_cond", function_);
auto* while_body = llvm::BasicBlock::Create(*ctx, "while_body", function_);
auto* while_merge = llvm::BasicBlock::Create(*ctx, "while_merge", function_);
builder_->CreateBr(while_cond);
builder_->SetInsertPoint(while_cond);
builder_->CreateCondBr(MakeValue(op->condition), while_body, while_merge);
builder_->SetInsertPoint(while_body);
PushLoopFrame(while_cond, while_merge);
this->VisitStmt(op->body);
PopLoopFrame();
builder_->CreateBr(while_cond);
builder_->SetInsertPoint(while_merge);
}
void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) {
EmitDebugLocation(op);
llvm::Value* cond = MakeValue(op->condition);
llvm::LLVMContext* ctx = llvm_target_->GetContext();
auto* then_block = llvm::BasicBlock::Create(*ctx, "if_then", function_);
auto* end_block = llvm::BasicBlock::Create(*ctx, "if_end", function_);
if (op->else_case) {
auto* else_block = llvm::BasicBlock::Create(*ctx, "if_else", function_);
builder_->CreateCondBr(cond, then_block, else_block);
builder_->SetInsertPoint(then_block);
this->VisitStmt(op->then_case);
builder_->CreateBr(end_block);
builder_->SetInsertPoint(else_block);
this->VisitStmt(op->else_case.value());
builder_->CreateBr(end_block);
} else {
builder_->CreateCondBr(cond, then_block, end_block, md_very_likely_branch_);
builder_->SetInsertPoint(then_block);
this->VisitStmt(op->then_case);
builder_->CreateBr(end_block);
}
builder_->SetInsertPoint(end_block);
}
void CodeGenLLVM::VisitStmt_(const AllocBufferNode* op) {
EmitDebugLocation(op);
TVM_FFI_ICHECK_EQ(op->buffer->shape.size(), 1)
<< "LLVM codegen only supports flat 1-d buffer allocation, but allocation of "
<< op->buffer->name << " is " << op->buffer->shape << "-d";
llvm::Value* buf = nullptr;
const IntImmNode* dim_imm = op->buffer->shape[0].as<IntImmNode>();
TVM_FFI_ICHECK(dim_imm) << "Can only handle constant size stack allocation";
int32_t constant_size = static_cast<int32_t>(dim_imm->value);
TVM_FFI_ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation";
StorageInfo& info = alloc_storage_info_[op->buffer->data.get()];
// Use buffer's data_alignment if specified, otherwise compute from shape.
if (op->buffer->data_alignment > 0) {
info.alignment = op->buffer->data_alignment;
} else if (constant_size % 4 == 0 && info.alignment == 0) {
info.alignment = GetTempAllocaAlignment(op->buffer->dtype, constant_size);
}
// maximum necessary alignment in the NV devices
if (info.alignment > 16) {
info.alignment = 16;
}
llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
return builder_->CreateAlloca(DTypeToLLVMType(op->buffer->dtype), ConstInt32(constant_size));
});
auto alignment = static_cast<unsigned>(alloca->getAlign().value());
if (alignment < static_cast<unsigned>(info.alignment)) {
alloca->setAlignment(llvm::Align(info.alignment));
}
info.alignment = static_cast<unsigned>(alloca->getAlign().value());
buf = alloca;
buf =
builder_->CreatePointerCast(buf, llvmGetPointerTo(DTypeToLLVMType(op->buffer->dtype),
buf->getType()->getPointerAddressSpace()));
AddDebugInformation(buf, op->buffer->data);
TVM_FFI_ICHECK(!var_map_.count(op->buffer->data.get()));
var_map_[op->buffer->data.get()] = buf;
if (op->annotations.count(tirx::attr::kVolatile)) {
volatile_buf_.insert(op->buffer->data.get());
}
}
void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) {
EmitDebugLocation(op);
if (op->attr_key == tirx::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag.length() != 0) {
if (!var_map_.count(iv->var.get())) {
var_map_[iv->var.get()] = GetThreadIndex(iv);
analyzer_->Bind(iv->var, Range::FromMinExtent(0, op->value));
}
}
} else if (op->attr_key == tirx::attr::storage_alignment) {
const VarNode* v = op->node.as<VarNode>();
TVM_FFI_ICHECK(v);
alloc_storage_info_[v].alignment = static_cast<int>(op->value.as<IntImmNode>()->value);
if (var_map_.count(v) && alloc_storage_info_[v].alignment > 1) {
builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v),
alloc_storage_info_[v].alignment);
}
}
this->VisitStmt(op->body);
}
void CodeGenLLVM::VisitStmt_(const AssertStmtNode* op) {
EmitDebugLocation(op);
// AssertStmt is a leaf — no body to visit.
// Constraint scoping is handled by ScopeStack in analysis passes.
}
void CodeGenLLVM::VisitStmt_(const BindNode* op) {
EmitDebugLocation(op);
const VarNode* v = op->var.get();
TVM_FFI_ICHECK(!var_map_.count(v));
if (v->dtype.is_handle()) {
if (!is_restricted_) {
alias_var_set_.insert(v);
}
}
llvm::Value* value = MakeValue(op->value);
// TIR has type-annotations on variables, but not on each PrimExpr.
// Therefore, to have the correct LLVM type for pointers, we may
// need to introduce a pointer-cast, even though pointer-to-pointer
// casts are not expressible with the `tirx::CastNode`.
if (v->dtype.is_handle() && v->type_annotation.defined()) {
TVM_FFI_ICHECK(op->value->dtype.is_handle())
<< "Variable " << op->var << " is a pointer with type " << op->value
<< ", but is being bound to expression with type " << op->value->dtype;
auto* llvm_type = GetLLVMType(v->type_annotation);
if (llvm_type != value->getType()) {
value->setName((v->name_hint + "_void_ptr").c_str());
value = builder_->CreatePointerCast(value, llvm_type);
}
}
AddDebugInformation(value, op->var);
var_map_[v] = value;
analyzer_->Bind(op->var, op->value);
if (alloc_storage_info_.count(v) && alloc_storage_info_[v].alignment > 1) {
builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v),
alloc_storage_info_[v].alignment);
}
AddDebugInformation(value, op->var);
}
void CodeGenLLVM::VisitStmt_(const SeqStmtNode* op) {
EmitDebugLocation(op);
for (Stmt stmt : op->seq) {
this->VisitStmt(stmt);
}
}
void CodeGenLLVM::VisitStmt_(const DeclBufferNode* op) { EmitDebugLocation(op); }
void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) {
EmitDebugLocation(op);
MakeValue(op->value);
}
void CodeGenLLVM::EmitDebugLocation(const ffi::Optional<Span>& span) {
if (di_subprogram_ == nullptr) {
// debug info is not always generated outside of CPU codegen
return;
}
llvm::LLVMContext* ctx = llvm_target_->GetContext();
int line = 0;
int column = 0;
if (span) {
auto ptr = span.as<SpanNode>();
line = ptr->line;
column = ptr->column;
}
auto loc = llvm::DebugLoc(llvm::DILocation::get(*ctx, line, column, di_subprogram_));
builder_->SetCurrentDebugLocation(loc);
}
void CodeGenLLVM::EmitDebugLocation() { builder_->SetCurrentDebugLocation(nullptr); }
void CodeGenLLVM::EmitDebugLocation(const StmtNode* op) { EmitDebugLocation(op->span); }
// Following Glow |DebugInfo::generateFunctionDebugInfo|, https://git.io/fjadv
void CodeGenLLVM::AddDebugInformation(llvm::Function* f_llvm,
const ffi::Array<Type>& tvm_param_types) {
TVM_FFI_ICHECK(di_subprogram_);
f_llvm->setSubprogram(di_subprogram_);
TVM_FFI_ICHECK_EQ(f_llvm->getSubprogram(), di_subprogram_);
IRBuilder builder(&f_llvm->getEntryBlock());
if (!f_llvm->getEntryBlock().empty()) {
builder.SetInsertPoint(&f_llvm->getEntryBlock().front());
}
llvm::DebugLoc DL;
builder.SetCurrentDebugLocation(DL);
llvm::LLVMContext* ctx = llvm_target_->GetContext();
TVM_FFI_ICHECK_EQ(f_llvm->arg_size(), tvm_param_types.size());
for (auto iter_param = f_llvm->arg_begin(); iter_param != f_llvm->arg_end(); iter_param++) {
size_t i = std::distance(f_llvm->arg_begin(), iter_param);
auto* paramAlloca = builder.CreateAlloca(iter_param->getType());
auto param = dbg_info_->di_builder_->createParameterVariable(
di_subprogram_, iter_param->getName(), i + 1, dbg_info_->file_, 0,
GetDebugType(tvm_param_types[i], iter_param->getType()),
/*alwaysPreserve=*/true);
auto* store = builder.CreateStore(iter_param, paramAlloca);
auto* di_loc = llvm::DILocation::get(*ctx, 0, 0, di_subprogram_);
#if TVM_LLVM_DIBUILDER_USES_ITERATOR
dbg_info_->di_builder_->insertDeclare(
paramAlloca, param, dbg_info_->di_builder_->createExpression(), llvm::DebugLoc(di_loc),
llvm::BasicBlock::iterator(store));
#else
dbg_info_->di_builder_->insertDeclare(paramAlloca, param,
dbg_info_->di_builder_->createExpression(),
llvm::DebugLoc(di_loc), store);
#endif
}
dbg_info_->di_builder_->finalizeSubprogram(f_llvm->getSubprogram());
auto* scope = f_llvm->getSubprogram();
if (!scope) {
return;
}
for (auto& BB : *f_llvm) {
for (auto& I : BB) {
if (I.getDebugLoc()) {
continue;
}
auto* di_loc = llvm::DILocation::get(*ctx, 0, 0, scope);
I.setDebugLoc(llvm::DebugLoc(di_loc));
}
}
}
void CodeGenLLVM::AddDebugInformation(llvm::Value* llvm_value, const Var& tir_var,
llvm::Instruction* insert_before) {
llvm_value->setName(tir_var->name_hint.c_str());
if (!di_subprogram_) return;
auto dbg_dtype = GetDebugType(GetType(tir_var));
// no invalid dtypes
if (!dbg_dtype) return;
auto local_var = dbg_info_->di_builder_->createAutoVariable(
di_subprogram_, std::string(tir_var->name_hint), dbg_info_->file_, 0, dbg_dtype);
auto* di_loc = llvm::DILocation::get(*llvm_target_->GetContext(), 0, 0, di_subprogram_);
// LLVM 15+ requires dbg_declare to reference pointer or integer types only.
// For non-pointer types (floats, vectors), use dbg_value instead to track
// the SSA value directly rather than a memory location.
if (!llvm_value->getType()->isPointerTy()) {
if (insert_before) {
// Upstream LLVM 20+ changed insertDbgValueIntrinsic to take
// BasicBlock::iterator; ROCm-bundled LLVM 20 retains Instruction*.
// TVM_LLVM_DIBUILDER_USES_ITERATOR is set by CMake feature detection.
#if TVM_LLVM_DIBUILDER_USES_ITERATOR
dbg_info_->di_builder_->insertDbgValueIntrinsic(
llvm_value, local_var, dbg_info_->di_builder_->createExpression(), llvm::DebugLoc(di_loc),
llvm::BasicBlock::iterator(insert_before));
#else
dbg_info_->di_builder_->insertDbgValueIntrinsic(llvm_value, local_var,
dbg_info_->di_builder_->createExpression(),
llvm::DebugLoc(di_loc), insert_before);
#endif
} else {
dbg_info_->di_builder_->insertDbgValueIntrinsic(
llvm_value, local_var, dbg_info_->di_builder_->createExpression(), llvm::DebugLoc(di_loc),
builder_->GetInsertBlock());
}
return;
}
if (insert_before) {
#if TVM_LLVM_DIBUILDER_USES_ITERATOR
dbg_info_->di_builder_->insertDeclare(
llvm_value, local_var, dbg_info_->di_builder_->createExpression(), llvm::DebugLoc(di_loc),
llvm::BasicBlock::iterator(insert_before));
#else
dbg_info_->di_builder_->insertDeclare(llvm_value, local_var,
dbg_info_->di_builder_->createExpression(),
llvm::DebugLoc(di_loc), insert_before);
#endif
} else {
dbg_info_->di_builder_->insertDeclare(llvm_value, local_var,
dbg_info_->di_builder_->createExpression(),
llvm::DebugLoc(di_loc), builder_->GetInsertBlock());
}
}
llvm::DIType* CodeGenLLVM::GetDebugType(const Type& ty_tir) {
return GetDebugType(ty_tir, GetLLVMType(ty_tir));
}
llvm::DIType* CodeGenLLVM::GetDebugType(const Type& ty_tir, llvm::Type* ty_llvm) {
if (ty_llvm == nullptr || ty_llvm == t_void_ || ty_llvm == t_tvm_tensormap_) {
return nullptr;
} else if (ty_llvm->isPointerTy()) {
auto* ptr_type = ty_tir.as<PointerTypeNode>();
TVM_FFI_ICHECK(ptr_type != nullptr || GetRuntimeDataType(ty_tir).is_handle())
<< "Got LLVM pointer type from non-pointer IR type: " << ty_tir;
auto* pointee_type = ptr_type != nullptr ? GetDebugType(ptr_type->element_type,
GetLLVMType(ptr_type->element_type))
: nullptr;
return dbg_info_->di_builder_->createPointerType(pointee_type,
ty_llvm->getPrimitiveSizeInBits());
} else if (auto* prim_type = ty_tir.as<PrimTypeNode>()) {
DataType dtype = prim_type->dtype;
llvm::dwarf::TypeKind dwarf_type;
if (dtype.is_bool()) {
dwarf_type = llvm::dwarf::DW_ATE_boolean;
} else if (dtype.is_float()) {
dwarf_type = llvm::dwarf::DW_ATE_float;
} else if (dtype.is_int()) {
dwarf_type = llvm::dwarf::DW_ATE_signed;
} else if (dtype.is_uint()) {
dwarf_type = llvm::dwarf::DW_ATE_unsigned;
} else {
return nullptr;
}
if (dtype.is_scalable_vector()) return nullptr;
return dbg_info_->di_builder_->createBasicType(
ffi::DLDataTypeToString(dtype).operator std::string(), dtype.bits() * dtype.lanes(),
dwarf_type);
} else {
std::string type_str;
llvm::raw_string_ostream rso(type_str);
ty_llvm->print(rso);
TVM_FFI_THROW(InternalError) << "Unknown LLVM type:" << rso.str();
}
return nullptr;
}
static void CodegenLLVMRegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("tvm.codegen.llvm.GetDefaultTargetTriple",
[]() -> std::string { return llvm::sys::getDefaultTargetTriple(); })
.def("tvm.codegen.llvm.GetProcessTriple",
[]() -> std::string { return llvm::sys::getProcessTriple(); })
.def("tvm.codegen.llvm.GetHostCPUName",
[]() -> std::string { return llvm::sys::getHostCPUName().str(); })
.def("tvm.codegen.llvm.GetHostCPUFeatures", []() -> ffi::Map<ffi::String, IntImm> {
#if TVM_LLVM_VERSION >= 190
ffi::Map<ffi::String, IntImm> ret;
auto features = llvm::sys::getHostCPUFeatures();
for (auto it = features.begin(); it != features.end(); ++it) {
std::string name = it->getKey().str();
bool value = it->getValue();
ret.Set(name, IntImm(DataType::Bool(), value));
}
return ret;
#else
llvm::StringMap<bool> features;
if (llvm::sys::getHostCPUFeatures(features)) {
ffi::Map<ffi::String, IntImm> ret;
for (auto it = features.begin(); it != features.end(); ++it) {
std::string name = it->getKey().str();
bool value = it->getValue();
ret.Set(name, IntImm(DataType::Bool(), value));
}
return ret;
}
#endif
LOG(WARNING) << "Current version of LLVM does not support feature detection on your CPU";
return {};
});
}
TVM_FFI_STATIC_INIT_BLOCK() { CodegenLLVMRegisterReflection(); }
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION