/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/target/target_kind.cc
* \brief Target kind registry
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/expr.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/logging.h>
#include <tvm/target/target.h>
#include <tvm/target/target_kind.h>
#include <algorithm>
#include <cctype>
#include <string>
#include "../ir/attr_registry.h"
#include "../support/utils.h"
#include "./canonicalizer/llvm/canonicalize.h"
namespace tvm {
namespace refl = ffi::reflection;
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
TargetKindNode::RegisterReflection();
refl::TypeAttrDef<TargetKindNode>()
.def("__data_to_json__",
[](const TargetKindNode* node) {
             // Serialize the target kind simply as its name string
return node->name;
})
.def("__data_from_json__", [](const ffi::String& name) {
auto kind = TargetKind::Get(name);
TVM_FFI_ICHECK(kind.has_value()) << "Cannot find target kind \'" << name << '\'';
return kind.value();
});
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::TypeAttrDef<TargetKindNode>().def(
refl::type_attr::kRepr,
[](TargetKind kind, ffi::Function) -> ffi::String { return kind->name; });
}
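// Together, the two blocks above let a TargetKind serialize to and from JSON
// as its bare name string, e.g. TargetKind::Get("cuda") <-> "cuda".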
/********** Registry-related code **********/
using TargetKindRegistry = AttrRegistry<TargetKindRegEntry, TargetKind>;
ffi::Array<ffi::String> TargetKindRegEntry::ListTargetKinds() {
return TargetKindRegistry::Global()->ListAllNames();
}
ffi::Map<ffi::String, ffi::String> TargetKindRegEntry::ListTargetKindOptions(
const TargetKind& target_kind) {
ffi::Map<ffi::String, ffi::String> options;
for (const auto& e : target_kind->schema_.ListOptions()) {
options.Set(e.key, e.type_str);
}
return options;
}
TargetKindRegEntry& TargetKindRegEntry::RegisterOrGet(const ffi::String& target_kind_name) {
return TargetKindRegistry::Global()->RegisterOrGet(target_kind_name);
}
void TargetKindRegEntry::UpdateAttr(const ffi::String& key, ffi::Any value, int plevel) {
TargetKindRegistry::Global()->UpdateAttr(key, kind_, value, plevel);
}
const AttrRegistryMapContainerMap<TargetKind>& TargetKind::GetAttrMapContainer(
const ffi::String& attr_name) {
return TargetKindRegistry::Global()->GetAttrMap(attr_name);
}
ffi::Optional<TargetKind> TargetKind::Get(const ffi::String& target_kind_name) {
const TargetKindRegEntry* reg = TargetKindRegistry::Global()->Get(target_kind_name);
if (reg == nullptr) {
return std::nullopt;
}
return reg->kind_;
}
/********** Utility functions **********/
/*!
 * \brief Extract the alphanumeric substring that immediately follows a prefix.
 * For example, when `str` is "sm_20" and `prefix` is "sm_", this function
 * first checks that `str` starts with `prefix`, then returns the string "20"
 * that follows it.
 * \param str The string to extract from
 * \param prefix The prefix to check for
 * \return The extracted string, or "" if `str` does not start with `prefix`
 */
std::string ExtractStringWithPrefix(const std::string& str, const std::string& prefix) {
  if (str.compare(0, prefix.length(), prefix) != 0) return "";
  std::size_t pos = prefix.length();
  while (pos < str.length() && std::isalnum(static_cast<unsigned char>(str[pos]))) {
++pos;
}
return str.substr(prefix.length(), pos - prefix.length());
}
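// Illustrative behavior (a comment-only sketch, not part of the build):
//   ExtractStringWithPrefix("sm_20", "sm_")  -> "20"
//   ExtractStringWithPrefix("gfx90a", "gfx") -> "90a"  (stops at the first
//                                                       non-alphanumeric char)
//   ExtractStringWithPrefix("x86-64", "sm_") -> ""     (prefix check fails)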
/*!
 * \brief Use the TVM DeviceAPI to query a device attribute
 * \param device The device to query
 * \param flag The device attribute to query
 * \param val The queried value, set on success
 * \return A boolean indicating whether the query succeeded
 */
static bool DetectDeviceFlag(Device device, runtime::DeviceAttrKind flag, ffi::Any* val) {
using runtime::DeviceAPI;
DeviceAPI* api = DeviceAPI::Get(device, true);
// Check if compiled with the corresponding device api
if (api == nullptr) {
return false;
}
// Check if the device exists
api->GetAttr(device, runtime::kExist, val);
int exists = val->cast<int>();
if (!exists) {
return false;
}
// Get the arch of the device
DeviceAPI::Get(device)->GetAttr(device, flag, val);
return true;
}
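// Example (a sketch mirroring the usage in the canonicalizers below):
//   ffi::Any version;
//   if (DetectDeviceFlag({kDLCUDA, 0}, runtime::kComputeVersion, &version)) {
//     std::string cc = version.cast<std::string>();  // e.g. "8.6"
//   }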
void CheckOrSetAttr(ffi::Map<ffi::String, ffi::Any>* attrs, const ffi::String& name,
const ffi::String& value) {
auto iter = attrs->find(name);
if (iter == attrs->end()) {
attrs->Set(name, value);
} else {
auto str = (*iter).second.try_cast<ffi::String>();
TVM_FFI_CHECK(str && str.value() == value, ValueError)
<< "Expects \"" << name << "\" to be \"" << value << "\", but gets: " << (*iter).second;
}
}
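// Example (a sketch): the canonicalizers below use this to pin an attribute.
//   ffi::Map<ffi::String, ffi::Any> attrs;
//   CheckOrSetAttr(&attrs, "mtriple", "nvptx64-nvidia-cuda");  // sets the value
//   CheckOrSetAttr(&attrs, "mtriple", "nvptx64-nvidia-cuda");  // matches, passes
//   // A different value for "mtriple" would now raise a ValueError.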
/********** Target kind attribute updaters **********/
/*!
* \brief Update the attributes in the CUDA target.
* \param target The Target to update
* \return The updated attributes
*/
ffi::Map<ffi::String, ffi::Any> UpdateCUDAAttrs(ffi::Map<ffi::String, ffi::Any> target) {
// Update -arch=sm_xx
if (target.count("arch")) {
// If -arch has been specified, validate the correctness
ffi::String archStr = Downcast<ffi::String>(target.at("arch"));
TVM_FFI_CHECK(support::StartsWith(archStr, "sm_"), ValueError)
<< "CUDA target gets an invalid CUDA arch: -arch=" << archStr;
} else {
// Use the compute version of the first CUDA GPU instead
int archInt;
ffi::Any version;
if (!DetectDeviceFlag({kDLCUDA, 0}, runtime::kComputeVersion, &version)) {
LOG(WARNING) << "Unable to detect CUDA version, default to \"-arch=sm_50\" instead";
archInt = 50;
} else {
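    // e.g. compute version "8.6" -> 86; the +0.1 guards against
    // floating-point truncation (8.6 * 10 may evaluate to 85.999...)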
archInt = std::stod(version.cast<std::string>()) * 10 + 0.1;
}
target.Set("arch", ffi::String("sm_") + std::to_string(archInt));
}
return target;
}
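// Example (a sketch): Target("cuda -arch=sm_86") keeps the given arch, while
// Target("cuda") fills in "arch" from the first CUDA device, falling back to
// "sm_50" when no device can be detected.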
/*!
* \brief Update the attributes in the LLVM NVPTX target.
* \param target The Target to update
* \return The updated attributes
*/
ffi::Map<ffi::String, ffi::Any> UpdateNVPTXAttrs(ffi::Map<ffi::String, ffi::Any> target) {
CheckOrSetAttr(&target, "mtriple", "nvptx64-nvidia-cuda");
// Update -mcpu=sm_xx
if (target.count("mcpu")) {
// If -mcpu has been specified, validate the correctness
ffi::String mcpu = Downcast<ffi::String>(target.at("mcpu"));
TVM_FFI_CHECK(support::StartsWith(mcpu, "sm_"), ValueError)
<< "NVPTX target gets an invalid CUDA arch: -mcpu=" << mcpu;
} else {
// Use the compute version of the first CUDA GPU instead
int arch;
ffi::Any version;
if (!DetectDeviceFlag({kDLCUDA, 0}, runtime::kComputeVersion, &version)) {
LOG(WARNING) << "Unable to detect CUDA version, default to \"-mcpu=sm_50\" instead";
arch = 50;
} else {
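    // e.g. compute version "8.6" -> 86; the +0.1 guards against
    // floating-point truncation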
arch = std::stod(version.cast<std::string>()) * 10 + 0.1;
}
target.Set("mcpu", ffi::String("sm_") + std::to_string(arch));
}
return target;
}
/*!
* \brief Update the attributes in the LLVM ROCm target.
* \param target The Target to update
* \return The updated attributes
*/
ffi::Map<ffi::String, ffi::Any> UpdateROCmAttrs(ffi::Map<ffi::String, ffi::Any> target) {
CheckOrSetAttr(&target, "mtriple", "amdgcn-amd-amdhsa-hcc");
// Update -mcpu=gfx
std::string arch = "gfx900";
if (target.count("mcpu")) {
ffi::String mcpu = Downcast<ffi::String>(target.at("mcpu"));
arch = ExtractStringWithPrefix(mcpu, "gfx");
TVM_FFI_CHECK(!arch.empty(), ValueError)
<< "ROCm target gets an invalid GFX version: -mcpu=" << mcpu;
} else {
ffi::Any val;
if (const auto f_get_rocm_arch = tvm::ffi::Function::GetGlobal("tvm_callback_rocm_get_arch")) {
arch = (*f_get_rocm_arch)().cast<std::string>();
}
target.Set("mcpu", ffi::String(arch));
}
// Update -mattr before ROCm 3.5:
// Before ROCm 3.5 we needed code object v2, starting
// with 3.5 we need v3 (this argument disables v3)
ffi::Any val;
int version;
if (!DetectDeviceFlag({kDLROCM, 0}, runtime::kApiVersion, &val)) {
LOG(WARNING) << "Unable to detect ROCm version, assuming >= 3.5";
version = 305;
} else {
version = val.cast<int>();
}
if (version < 305) {
ffi::Array<ffi::String> mattr;
if (target.count("mattr")) {
mattr = Downcast<ffi::Array<ffi::String>>(target.at("mattr"));
}
mattr.push_back("-code-object-v3");
target.Set("mattr", mattr);
}
return target;
}
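// Example (a sketch): with no "-mcpu" given, Target("rocm") queries the
// registered "tvm_callback_rocm_get_arch" packed function when available and
// otherwise falls back to "gfx900"; "mattr" gains "-code-object-v3" only when
// the detected ROCm version is older than 3.5.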
/*!
 * \brief Update WebGPU target attributes for subgroup-enabled lowering.
 * Runtime routing on the WebLLM side guarantees subgroup size == 32 and
 * maxComputeInvocationsPerWorkgroup >= 1024; this is intentionally
 * constrained for the subgroup-enabled WASM variant.
 * When supports_subgroups is true, canonicalize thread_warp_size to 32 so
 * TIR lowering can emit subgroup shuffle reductions.
* \param target The Target to update
* \return The updated attributes
*/
ffi::Map<ffi::String, ffi::Any> UpdateWebGPUAttrs(ffi::Map<ffi::String, ffi::Any> target) {
bool subgroups = false;
if (target.count("supports_subgroups")) {
subgroups = Downcast<Bool>(target.at("supports_subgroups"));
}
if (target.count("thread_warp_size")) {
int64_t thread_warp_size = Downcast<Integer>(target.at("thread_warp_size"))->value;
TVM_FFI_ICHECK(subgroups || thread_warp_size <= 1)
<< "WebGPU target with thread_warp_size=" << thread_warp_size
<< " requires supports_subgroups=true";
}
if (subgroups) {
target.Set("thread_warp_size", int64_t(32));
}
return target;
}
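// Example (a sketch): {"supports_subgroups": true} canonicalizes to
// {"supports_subgroups": true, "thread_warp_size": 32}, while
// {"thread_warp_size": 32} alone fails the check above.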
/*!
* \brief Test Target Parser
* \param target The Target to update
* \return The updated attributes
*/
ffi::Map<ffi::String, ffi::Any> TestTargetParser(ffi::Map<ffi::String, ffi::Any> target) {
target.Set("feature.is_test", true);
return target;
}
/********** Register Target kinds and attributes **********/
TVM_REGISTER_TARGET_KIND("llvm", kDLCPU)
.add_attr_option<ffi::Array<ffi::String>>("mattr")
.add_attr_option<ffi::String>("mcpu")
.add_attr_option<ffi::String>("mtriple")
.add_attr_option<ffi::String>("mfloat-abi")
.add_attr_option<ffi::String>("mabi")
.add_attr_option<int64_t>("num-cores")
// Fast math flags, see https://llvm.org/docs/LangRef.html#fast-math-flags
.add_attr_option<bool>("fast-math") // implies all the below
.add_attr_option<bool>("fast-math-nnan")
.add_attr_option<bool>("fast-math-ninf")
.add_attr_option<bool>("fast-math-nsz")
.add_attr_option<bool>("fast-math-arcp")
.add_attr_option<bool>("fast-math-contract")
.add_attr_option<bool>("fast-math-reassoc")
.add_attr_option<int64_t>("opt-level")
// LLVM command line flags, see below
.add_attr_option<ffi::Array<ffi::String>>("cl-opt")
// LLVM JIT engine mcjit/orcjit
.add_attr_option<ffi::String>("jit")
// TVM & LLVM custom vector bit width
.add_attr_option<int64_t>("vector-width")
.set_default_keys({"cpu"})
.set_target_canonicalizer(tvm::target::canonicalizer::llvm::Canonicalize);
// Note regarding the "cl-opt" attribute:
// Each string in the array has the format
// -optionname[[:type]=value]
// where
// * optionname is the actual LLVM option (e.g. "unroll-threshold")
// * type is one of "bool", "int", "uint", or "string"
//  * value is the corresponding option value (for "bool" type it can be 0 or "false"
//    for a false value, or 1 or "true" for a true value)
// If type is omitted, it is assumed to be "bool". If value is omitted, it is assumed
// to be "true".
//
// The type must match the option type in LLVM. To find the type, search the LLVM
// repository (https://github.com/llvm/llvm-project) for optionname, and look for
// its definition: it will be a declaration of a variable of type cl::opt<T> with
// optionname being an argument to the constructor. The T in the declaration is
// the type.
// For example, for unroll-threshold, we get the following declaration:
// static cl::opt<unsigned>
// UnrollThreshold("unroll-threshold", cl::Hidden,
// cl::desc("The cost threshold for loop unrolling"));
// Hence the type is "uint".
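// Example (a sketch of the format above, using the unroll-threshold option):
//   Target("llvm -cl-opt=-unroll-threshold:uint=100")
//   Target("llvm -cl-opt=-print-after-all")  // bare flag, treated as bool=true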
TVM_REGISTER_TARGET_KIND("c", kDLCPU)
.add_attr_option<ffi::String>("mcpu")
.add_attr_option<ffi::String>("march")
.add_attr_option<int64_t>("workspace-byte-alignment")
.add_attr_option<int64_t>("constants-byte-alignment")
.set_default_keys({"cpu"})
.set_target_canonicalizer(tvm::target::canonicalizer::llvm::Canonicalize);
TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA)
.add_attr_option<ffi::String>("mcpu")
.add_attr_option<ffi::String>("arch")
.add_attr_option<int64_t>("max_shared_memory_per_block")
.add_attr_option<int64_t>("max_threads_per_block")
.add_attr_option<int64_t>("thread_warp_size", refl::DefaultValue(32))
.add_attr_option<int64_t>("registers_per_block")
.add_attr_option<int64_t>("l2_cache_size_bytes")
.add_attr_option<int64_t>("max_num_threads",
refl::DefaultValue(1024)) // TODO(@zxybazh): deprecate it
.set_default_keys({"cuda", "gpu"})
.set_target_canonicalizer(UpdateCUDAAttrs);
TVM_REGISTER_TARGET_KIND("nvptx", kDLCUDA)
.add_attr_option<ffi::String>("mcpu")
.add_attr_option<ffi::String>("mtriple")
.add_attr_option<int64_t>("max_num_threads", refl::DefaultValue(1024))
.add_attr_option<int64_t>("thread_warp_size", refl::DefaultValue(32))
.set_default_keys({"cuda", "gpu"})
.set_target_canonicalizer(UpdateNVPTXAttrs);
TVM_REGISTER_TARGET_KIND("rocm", kDLROCM)
.add_attr_option<ffi::String>("mcpu")
.add_attr_option<ffi::String>("mtriple")
.add_attr_option<ffi::Array<ffi::String>>("mattr")
// TODO(masahi): Support querying from a target device
// On RDNA cards, thread_warp_size should be 32
.add_attr_option<int64_t>("max_num_threads", refl::DefaultValue(256))
.add_attr_option<int64_t>("max_threads_per_block", refl::DefaultValue(256))
.add_attr_option<int64_t>("max_shared_memory_per_block", refl::DefaultValue(65536))
.add_attr_option<int64_t>("thread_warp_size", refl::DefaultValue(64))
.set_default_keys({"rocm", "gpu"})
.set_target_canonicalizer(UpdateROCmAttrs);
TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL)
.add_attr_option<int64_t>("max_threads_per_block", refl::DefaultValue(256))
.add_attr_option<int64_t>("max_shared_memory_per_block", refl::DefaultValue(16384))
.add_attr_option<int64_t>("max_num_threads", refl::DefaultValue(256))
.add_attr_option<int64_t>("thread_warp_size", refl::DefaultValue(1))
.add_attr_option<int64_t>("texture_spatial_limit", refl::DefaultValue(16384))
.add_attr_option<int64_t>("texture_depth_limit", refl::DefaultValue(2048))
    // The Qualcomm OpenCL runtime has been observed to crash without any error
    // message when the number of kernel arguments gets large. OpenCL does not
    // specify any limit on the number of kernel arguments, and a
    // max_function_args of 128 looks like a reasonable bound.
.add_attr_option<int64_t>("max_function_args", refl::DefaultValue(128))
.add_attr_option<int64_t>("image_base_address_alignment", refl::DefaultValue(64))
.set_default_keys({"opencl", "gpu"});
// Metal has some limitations on the number of input parameters. This is why the
// attribute `max_function_args` was introduced. It specifies the maximum number
// of kernel arguments. More information about this limitation can be found here:
// https://developer.apple.com/documentation/metal/buffers/about_argument_buffers?language=objc
// See also https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
TVM_REGISTER_TARGET_KIND("metal", kDLMetal)
.add_attr_option<int64_t>("max_num_threads", refl::DefaultValue(256))
.add_attr_option<int64_t>("max_threads_per_block", refl::DefaultValue(256))
.add_attr_option<int64_t>("max_shared_memory_per_block", refl::DefaultValue(32768))
.add_attr_option<int64_t>("thread_warp_size", refl::DefaultValue(16))
.add_attr_option<int64_t>("max_function_args", refl::DefaultValue(31))
.set_default_keys({"metal", "gpu"});
TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)
.add_attr_option<ffi::Array<ffi::String>>("mattr")
// Feature support
.add_attr_option<bool>("supports_float16")
.add_attr_option<bool>("supports_float32", refl::DefaultValue(true))
.add_attr_option<bool>("supports_float64")
.add_attr_option<bool>("supports_int8")
.add_attr_option<bool>("supports_int16")
.add_attr_option<bool>("supports_int32", refl::DefaultValue(true))
.add_attr_option<bool>("supports_int64")
.add_attr_option<bool>("supports_8bit_buffer")
.add_attr_option<bool>("supports_16bit_buffer")
.add_attr_option<bool>("supports_storage_buffer_storage_class")
.add_attr_option<bool>("supports_push_descriptor")
.add_attr_option<bool>("supports_dedicated_allocation")
.add_attr_option<bool>("supports_integer_dot_product")
.add_attr_option<bool>("supports_cooperative_matrix")
.add_attr_option<int64_t>("supported_subgroup_operations")
// Physical device limits
.add_attr_option<int64_t>("max_num_threads", refl::DefaultValue(256))
.add_attr_option<int64_t>("max_threads_per_block", refl::DefaultValue(256))
.add_attr_option<int64_t>("thread_warp_size", refl::DefaultValue(1))
.add_attr_option<int64_t>("max_block_size_x")
.add_attr_option<int64_t>("max_block_size_y")
.add_attr_option<int64_t>("max_block_size_z")
.add_attr_option<int64_t>("max_push_constants_size")
.add_attr_option<int64_t>("max_uniform_buffer_range")
.add_attr_option<int64_t>("max_storage_buffer_range")
.add_attr_option<int64_t>("max_per_stage_descriptor_storage_buffer")
.add_attr_option<int64_t>("max_shared_memory_per_block")
// Other device properties
.add_attr_option<ffi::String>("device_type")
.add_attr_option<ffi::String>("device_name")
.add_attr_option<ffi::String>("driver_name")
.add_attr_option<int64_t>("driver_version")
.add_attr_option<int64_t>("vulkan_api_version")
.add_attr_option<int64_t>("max_spirv_version")
// Tags
.set_default_keys({"vulkan", "gpu"});
TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU)
.add_attr_option<int64_t>("max_num_threads", refl::DefaultValue(256))
.add_attr_option<bool>("supports_subgroups", refl::DefaultValue(false))
    // With thread_warp_size=1, is_subwarp_reduction and is_multiwarp_reduction
    // return false, so no subgroup ops are emitted.
.add_attr_option<int64_t>("thread_warp_size", refl::DefaultValue(1))
.set_target_canonicalizer(UpdateWebGPUAttrs)
.set_default_keys({"webgpu", "gpu"});
TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon)
.add_attr_option<ffi::Array<ffi::String>>("mattr")
.add_attr_option<ffi::String>("mcpu")
.add_attr_option<ffi::String>("mtriple")
.add_attr_option<ffi::Array<ffi::String>>("llvm-options")
.add_attr_option<int64_t>("num-cores")
.add_attr_option<int64_t>("vtcm-capacity")
.set_default_keys({"hexagon", "cpu"});
TVM_REGISTER_TARGET_KIND("ext_dev", kDLExtDev);
TVM_REGISTER_TARGET_KIND("composite", kDLCPU) // line break
.add_attr_option<ffi::Array<Target>>(
"devices",
ir::ConfigSchema::AttrValidator(ffi::TypedFunction<ffi::Any(ffi::Any)>( //
[](ffi::Any val) -> ffi::Any {
// Allow elements to be strings or dicts, converting them to Target objects.
if (val.try_cast<ffi::Array<Target>>().has_value()) return val;
auto arr = val.cast<ffi::Array<ffi::Any>>();
ffi::Array<Target> result;
for (const auto& elem : arr) {
if (auto t = elem.try_cast<Target>()) {
result.push_back(t.value());
} else if (auto s = elem.try_cast<ffi::String>()) {
result.push_back(Target(s.value()));
} else if (auto m = elem.try_cast<ffi::Map<ffi::String, ffi::Any>>()) {
result.push_back(Target(m.value()));
} else {
TVM_FFI_THROW(TypeError)
<< "Expected Target, string, or dict in 'devices' array, got '"
<< elem.GetTypeKey() << "'";
}
}
return ffi::Any(result);
})));
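// Example (a sketch): each "devices" entry may be a Target, a target string,
// or a config dict; the validator above normalizes all of them.
//   Target composite(ffi::Map<ffi::String, ffi::Any>{
//       {"kind", ffi::String("composite")},
//       {"devices", ffi::Array<ffi::Any>{ffi::String("llvm"), ffi::String("cuda")}}});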
TVM_REGISTER_TARGET_KIND("test", kDLCPU) // line break
.set_target_canonicalizer(TestTargetParser);
/********** Registry **********/
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("target.TargetKindGetAttr",
[](TargetKind kind, ffi::String attr_name) -> ffi::Any {
auto target_attr_map = TargetKind::GetAttrMap<ffi::Any>(attr_name);
ffi::Any rv;
if (target_attr_map.count(kind)) {
rv = target_attr_map[kind];
}
return rv;
})
.def("target.ListTargetKinds", TargetKindRegEntry::ListTargetKinds)
.def("target.ListTargetKindOptions", TargetKindRegEntry::ListTargetKindOptions)
.def("target.ListTargetKindOptionsFromName", [](ffi::String target_kind_name) {
TargetKind kind = TargetKind::Get(target_kind_name).value();
return TargetKindRegEntry::ListTargetKindOptions(kind);
});
}
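// Example (a sketch, mirroring the GetGlobal usage above): these hooks are
// reachable through the FFI registry from any frontend, e.g. in C++:
//   auto list = tvm::ffi::Function::GetGlobal("target.ListTargetKinds");
//   auto kinds = (*list)().cast<ffi::Array<ffi::String>>();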
} // namespace tvm