/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file codegen_nvptx.cc
* \brief NVPTX code generator.
*/
#ifdef TVM_LLVM_VERSION
#include <llvm/ADT/SmallString.h>
#include <llvm/IR/Attributes.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/GlobalValue.h>
#include <llvm/IR/InlineAsm.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/Intrinsics.h>
#include <llvm/IR/IntrinsicsNVPTX.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/IR/Metadata.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Type.h>
#include <llvm/IRReader/IRReader.h>
#include <llvm/Support/Alignment.h>
#include <llvm/Support/CodeGen.h>
#include <llvm/Support/SourceMgr.h>
#include <llvm/Support/raw_ostream.h>
#include <llvm/Target/TargetMachine.h>
#include <tvm/ffi/reflection/registry.h>
#if TVM_LLVM_VERSION < 170
#include <llvm/Transforms/IPO/PassManagerBuilder.h>
#endif
#include <tvm/runtime/device_api.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "../../build_common.h"
#include "../../llvm/codegen_llvm.h"
#include "../../llvm/llvm_instance.h"
#include "../cuda_fallback_module.h"
namespace tvm {
namespace codegen {
// NVPTX code generator.
class CodeGenNVPTX : public CodeGenLLVM {
public:
llvm::Function* DeclareFunction(const GlobalVar& gvar, const PrimFunc& f) final {
    // declare the function with a void return value
return CodeGenLLVM::DeclareFunctionInternal(gvar, f);
}
void AddFunction(const GlobalVar& gvar, const PrimFunc& f) final {
// add function as void return value
CodeGenLLVM::AddFunctionInternal(gvar, f);
// annotate as kernel function
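    // NVVM identifies kernel entry points via !nvvm.annotations metadata:
    // the (function, "kernel", 1) triple below marks the function as a CUDA
    // kernel rather than an ordinary device function.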
llvm::LLVMContext* ctx = llvm_target_->GetContext();
module_->getOrInsertNamedMetadata("nvvm.annotations")
->addOperand(llvm::MDNode::get(
*ctx, {llvm::ValueAsMetadata::get(function_), llvm::MDString::get(*ctx, "kernel"),
llvm::ValueAsMetadata::get(ConstInt32(1))}));
}
void VisitStmt_(const AllocBufferNode* op) final {
llvm::Value* buf = nullptr;
StorageInfo& info = alloc_storage_info_[op->buffer->data.get()];
    // cap the alignment at 16 bytes, the maximum necessary alignment on NVIDIA devices
if (info.alignment > 16) {
info.alignment = 16;
}
auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data));
DataType dtype = op->buffer->dtype;
if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") {
      // Dynamic shared memory ("shared.dyn"): address space 3; the size is
      // supplied at kernel launch, so the static size here is zero.
buf = AllocateSharedMemory(dtype, 0, 3, info.alignment, llvm::GlobalValue::ExternalLinkage);
} else {
      // Compute constant_size from the buffer shape (flattened to 1-D by this point)
const IntImmNode* dim_imm = op->buffer->shape[0].as<IntImmNode>();
      TVM_FFI_ICHECK(dim_imm) << "Can only handle constant-size stack allocation on GPU";
      size_t constant_size = static_cast<size_t>(dim_imm->value);
      TVM_FFI_ICHECK_GT(constant_size, 0)
          << "Can only handle constant-size stack allocation on GPU";
if (constant_size % 4 == 0 && info.alignment == 0) {
info.alignment = GetTempAllocaAlignment(dtype, constant_size);
}
if (storage_scope.rank == runtime::StorageRank::kLocal) {
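        // Local scope: emit the alloca in the function entry block, the
        // canonical position that lets LLVM promote or fold it.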
llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
return builder_->CreateAlloca(DTypeToLLVMType(dtype), ConstInt32(constant_size));
});
auto alignment = static_cast<unsigned>(alloca->getAlign().value());
if (alignment < static_cast<unsigned>(info.alignment)) {
alloca->setAlignment(llvm::Align(info.alignment));
}
buf = alloca;
} else {
TVM_FFI_ICHECK(storage_scope.rank == runtime::StorageRank::kShared)
<< "Can only allocate shared or local memory inside kernel";
buf = AllocateSharedMemory(dtype, constant_size, 3, info.alignment,
llvm::GlobalValue::ExternalLinkage);
}
}
buf = builder_->CreatePointerCast(
buf, llvmGetPointerTo(DTypeToLLVMType(dtype), buf->getType()->getPointerAddressSpace()));
TVM_FFI_ICHECK(!var_map_.count(op->buffer->data.get()));
var_map_[op->buffer->data.get()] = buf;
    if (op->annotations.count(tir::attr::kVolatile)) {
volatile_buf_.insert(op->buffer->data.get());
}
}
// Return the thread index via intrinsics.
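  // ts.rank == 1 maps to threadIdx.{x,y,z} (tid sregs);
  // ts.rank == 0 maps to blockIdx.{x,y,z} (ctaid sregs).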
llvm::Value* GetThreadIndex(const IterVar& iv) final {
runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag);
llvm::Intrinsic::ID intrin_id = llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x;
if (ts.rank == 1) {
switch (ts.dim_index) {
case 0:
intrin_id = llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x;
break;
case 1:
intrin_id = llvm::Intrinsic::nvvm_read_ptx_sreg_tid_y;
break;
case 2:
intrin_id = llvm::Intrinsic::nvvm_read_ptx_sreg_tid_z;
break;
default:
TVM_FFI_THROW(InternalError) << "unknown thread idx";
}
} else {
TVM_FFI_ICHECK_EQ(ts.rank, 0);
switch (ts.dim_index) {
case 0:
intrin_id = llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x;
break;
case 1:
intrin_id = llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_y;
break;
case 2:
intrin_id = llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_z;
break;
default:
TVM_FFI_THROW(InternalError) << "unknown thread idx";
}
}
#if TVM_LLVM_VERSION >= 200
llvm::Function* f = llvm::cast<llvm::Function>(
llvm::Intrinsic::getOrInsertDeclaration(module_.get(), intrin_id, {}));
#else
llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), intrin_id);
#endif
return builder_->CreateCall(f, {});
}
llvm::Value* CreateStorageSync(const CallNode* op) final {
const std::string& sync = op->args[0].as<StringImmNode>()->value;
if (sync == "warp") {
// TODO(tqchen) warp sync in CUDA9
return nullptr;
} else if (sync == "shared" || sync == "shared.dyn") {
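      // Equivalent to __syncthreads(): the NVVM barrier intrinsic lowers to
      // the PTX CTA-wide barrier (bar.sync 0 / barrier.cta.sync.aligned).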
#if TVM_LLVM_VERSION >= 200
llvm::Function* f = llvm::cast<llvm::Function>(llvm::Intrinsic::getOrInsertDeclaration(
#if TVM_LLVM_VERSION >= 210
module_.get(), llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all, {}));
#else
module_.get(), llvm::Intrinsic::nvvm_barrier0, {}));
#endif
#else
llvm::Function* f =
llvm::Intrinsic::getDeclaration(module_.get(), llvm::Intrinsic::nvvm_barrier0);
#endif
return builder_->CreateCall(f, {});
} else {
TVM_FFI_THROW(InternalError) << "Do not support sync " << sync;
}
}
#if TVM_LLVM_VERSION < 160
// This function only works with the legacy pass manager.
void InitPassManagerBuilder(llvm::PassManagerBuilder* builder) final {
// Additional optimization hook to tweak the builder.
}
#endif
void Optimize() final {
for (auto& f : *module_) {
auto fname = static_cast<std::string>(f.getName());
if (fname.substr(0, 4) != "__nv") continue;
// This is to strip off unused __nv_* functions from the final module
// The one that is actually used will be inlined at call site
// Adapted from Halide's runtime linker
if (!f.isDeclaration() && !f.hasFnAttribute(llvm::Attribute::NoInline)) {
f.setLinkage(llvm::GlobalValue::AvailableExternallyLinkage);
}
}
CodeGenLLVM::Optimize();
}
llvm::Value* CreateIntrinsic(const CallNode* op) override;
protected:
void InitTarget() final {
    // Maximum vector width is float4: 4 lanes x 32 bits = 128 bits.
native_vector_bits_ = 4 * 32;
CodeGenLLVM::InitTarget();
}
};
// Check if this is a warp shuffle intrinsic call and match its
// corresponding nvvm intrinsic. Return true if the match is successful.
static bool GetWarpShuffleIntrinsic(const CallNode* op, llvm::Intrinsic::ID* id) {
  // Only scalar 32-bit data types are supported.
if (op->dtype.is_fixed_length_vector() || op->dtype.bits() != 32) {
return false;
}
// Intrinsic lookup table.
  // It is difficult to emit the _sync version that works on Pascal.
// We ignore the mask and only emit the non-sync version for nvptx.
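  // Table layout: {idx, up, down} x {i32, f32}; the selected entry is
  // ids[offset + is_float], with offset 0/2/4 per shuffle kind.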
llvm::Intrinsic::ID ids[] = {
llvm::Intrinsic::nvvm_shfl_idx_i32, llvm::Intrinsic::nvvm_shfl_idx_f32,
llvm::Intrinsic::nvvm_shfl_up_i32, llvm::Intrinsic::nvvm_shfl_up_f32,
llvm::Intrinsic::nvvm_shfl_down_i32, llvm::Intrinsic::nvvm_shfl_down_f32};
int offset = 0;
if (op->op.same_as(builtin::tvm_warp_shuffle())) {
offset = 0;
} else if (op->op.same_as(builtin::tvm_warp_shuffle_up())) {
offset = 2;
} else if (op->op.same_as(builtin::tvm_warp_shuffle_down())) {
offset = 4;
} else {
return false;
}
*id = ids[offset + op->dtype.is_float()];
return true;
}
llvm::Value* CodeGenNVPTX::CreateIntrinsic(const CallNode* op) {
llvm::Intrinsic::ID id = llvm::Intrinsic::not_intrinsic;
if (GetWarpShuffleIntrinsic(op, &id)) {
std::vector<llvm::Value*> arg_value;
std::vector<llvm::Type*> arg_type;
    // Ignore the first mask operand and drop the last,
    // redundant warp_size argument.
size_t n_args = op->args.size() - 1;
for (size_t i = 1; i < n_args; ++i) {
arg_value.push_back(MakeValue(op->args[i]));
arg_type.push_back(arg_value.back()->getType());
}
llvm::Type* return_type = arg_type[0];
llvm::Function* func = GetIntrinsicDecl(id, return_type, arg_type);
return builder_->CreateCall(func, arg_value);
} else if (op->op.same_as(builtin::tvm_warp_activemask())) {
// Only nvptx target may keep this intrinsic at this point.
    // PTX inline assembly: "activemask.b32 %0"
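    // "=r" constrains the single 32-bit result to a register; the trailing
    // 'true' marks the inline asm as having side effects.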
auto fty = llvm::FunctionType::get(t_int32_, false);
auto val = llvm::InlineAsm::get(fty, "activemask.b32 %0", "=r", true);
return builder_->CreateCall(val);
} else if (op->op.same_as(builtin::atomic_add())) {
TVM_FFI_ICHECK(op->args[1]->dtype.bits() == 32) << "Only supports 32 bit atomic for now";
llvm::Value* v0 = MakeValue(op->args[0]);
llvm::Value* v1 = MakeValue(op->args[1]);
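    // Use monotonic (relaxed) ordering, matching the semantics of CUDA's
    // atomicAdd.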
if (op->args[1]->dtype.is_float()) {
return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::FAdd, v0, v1, llvm::MaybeAlign(),
llvm::AtomicOrdering::Monotonic);
}
return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::Add, v0, v1, llvm::MaybeAlign(),
llvm::AtomicOrdering::Monotonic);
}
return CodeGenLLVM::CreateIntrinsic(op);
}
int GetCUDAComputeVersion(const Target& target) {
ffi::Optional<ffi::String> mcpu = target->GetAttr<ffi::String>("mcpu");
TVM_FFI_CHECK(mcpu.has_value(), InternalError) << "\"-mcpu\" is undefined in the NVPTX target";
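  // mcpu has the form "sm_<arch>", e.g. "sm_70"; strip the "sm_" prefix and
  // parse the numeric compute capability.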
std::string sm_version = mcpu.value();
return std::stoi(sm_version.substr(3));
}
ffi::Module BuildNVPTX(IRModule mod, Target target) {
LLVMInstance llvm_instance;
With<LLVMTarget> llvm_target(llvm_instance, target);
int compute_ver = GetCUDAComputeVersion(target);
auto cg = std::make_unique<CodeGenNVPTX>();
cg->Init("TVMPTXModule", llvm_target.get(), std::nullopt, false, false);
cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end());
llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine();
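  // If the libdevice-path callback is registered (typically from the Python
  // side), link in the libdevice bitcode for this compute capability so that
  // CUDA math functions (__nv_*) resolve.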
const auto flibdevice_path = tvm::ffi::Function::GetGlobal("tvm_callback_libdevice_path");
if (flibdevice_path.has_value()) {
std::string path = (*flibdevice_path)(compute_ver).cast<std::string>();
if (path.length() != 0) {
std::unique_ptr<llvm::Module> mlib = llvm_instance.LoadIR(path);
#if TVM_LLVM_VERSION >= 210
mlib->setTargetTriple(llvm::Triple(llvm_target->GetTargetTriple()));
#else
mlib->setTargetTriple(llvm_target->GetTargetTriple());
#endif
mlib->setDataLayout(tm->createDataLayout());
cg->AddLinkModule(std::move(mlib));
}
}
std::unique_ptr<llvm::Module> module = cg->Finish();
llvm::SmallString<8> data_ptx, data_ll;
llvm::raw_svector_ostream dest_ptx(data_ptx), dest_ll(data_ll);
dest_ptx.SetUnbuffered();
dest_ll.SetUnbuffered();
// print ll
module->print(dest_ll, nullptr);
std::string ll(data_ll.begin(), data_ll.end());
// emit ptx
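  // Assembly emission still goes through the legacy PassManager, even with
  // newer LLVM versions.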
llvm::legacy::PassManager pass;
#if TVM_LLVM_VERSION <= 170
  TVM_FFI_ICHECK(tm->addPassesToEmitFile(pass, dest_ptx, nullptr, llvm::CGFT_AssemblyFile) == 0)
      << "Cannot emit target CGFT_AssemblyFile";
#else
  TVM_FFI_ICHECK(
      tm->addPassesToEmitFile(pass, dest_ptx, nullptr, llvm::CodeGenFileType::AssemblyFile) == 0)
      << "Cannot emit target CodeGenFileType::AssemblyFile";
#endif
pass.run(*module);
std::string ptx(data_ptx.begin(), data_ptx.end());
  // BuildNVPTX produces PTX directly via the LLVM NVPTX backend; hand it to
// the fallback-aware factory. Source map is `{"ll": ll}` so InspectSource
// can recover the LLVM IR even when the receiver only has a fallback module.
ffi::Map<ffi::String, ffi::String> source_map;
source_map.Set("ll", ll);
return target::CUDAModuleCreateWithFallback(ffi::Bytes(ptx.data(), ptx.size()),
ffi::String("ptx"), ExtractFuncInfo(mod), source_map);
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("target.build.nvptx", BuildNVPTX)
.def_packed("tvm.codegen.llvm.target_nvptx", [](const ffi::PackedArgs& targs, ffi::Any* rv) {
*rv = static_cast<void*>(new CodeGenNVPTX());
});
}
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION