blob: b5d2bf7ceb8520cb0dc362e7102e362105242f5b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/target/target_kind.cc
* \brief Target kind registry
*/
#include <tvm/ir/expr.h>
#include <tvm/runtime/device_api.h>
#include <tvm/target/target.h>
#include <tvm/target/target_kind.h>

#include <algorithm>
#include <cctype>
#include <limits>
#include <string>
#include <utility>

#include "../node/attr_registry.h"
namespace tvm {
// Register TargetKindNode as a TVM node type.
TVM_REGISTER_NODE_TYPE(TargetKindNode);
// Textual representation of a TargetKind: only its registered name is printed.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<TargetKindNode>([](const ObjectRef& node, ReprPrinter* printer) {
      printer->stream << Downcast<TargetKind>(node)->name;
    });
/********** Registry-related code **********/
// Shorthand for the AttrRegistry instantiation holding target kind entries (looked up by
// name in RegisterOrGet/Get below) and their per-kind attribute maps (GetAttrMap below).
using TargetKindRegistry = AttrRegistry<TargetKindRegEntry, TargetKind>;
TargetKindRegEntry& TargetKindRegEntry::RegisterOrGet(const String& target_kind_name) {
  // Thin forwarder: the global registry owns all target kind entries.
  auto registry = TargetKindRegistry::Global();
  return registry->RegisterOrGet(target_kind_name);
}
// Forward an attribute update for this entry's kind to the global registry.
// `value` is a by-value sink parameter, so it is moved rather than copied a second time.
void TargetKindRegEntry::UpdateAttr(const String& key, TVMRetValue value, int plevel) {
  TargetKindRegistry::Global()->UpdateAttr(key, kind_, std::move(value), plevel);
}
const AttrRegistryMapContainerMap<TargetKind>& TargetKind::GetAttrMapContainer(
    const String& attr_name) {
  // The returned reference comes from the global registry; presumably it stays valid
  // for as long as that registry does.
  auto registry = TargetKindRegistry::Global();
  return registry->GetAttrMap(attr_name);
}
Optional<TargetKind> TargetKind::Get(const String& target_kind_name) {
  // Look the name up in the global registry; unknown names yield an empty Optional.
  if (const TargetKindRegEntry* entry = TargetKindRegistry::Global()->Get(target_kind_name)) {
    return entry->kind_;
  }
  return NullOpt;
}
/********** Utility functions **********/
/*!
 * \brief Extract the non-negative integer that follows a given prefix.
 *  For example, with `str` = "sm_20" and `prefix` = "sm_", this returns 20.
 *  The check fails (and -1 is returned) when `str` does not start with `prefix`,
 *  when nothing follows the prefix, when any character after the prefix is not a
 *  decimal digit, or when the number does not fit in an `int`.
 * \param str The string to be extracted
 * \param prefix The prefix to be checked
 * \return The extracted number, or -1 if the check fails
 */
static int ExtractIntWithPrefix(const std::string& str, const std::string& prefix) {
  // `compare` avoids allocating a substring and also handles `str` shorter than `prefix`.
  if (str.compare(0, prefix.size(), prefix) != 0) {
    return -1;
  }
  // A bare prefix such as "sm_" carries no number and is rejected.
  if (str.size() == prefix.size()) {
    return -1;
  }
  int result = 0;
  for (size_t i = prefix.size(); i < str.size(); ++i) {
    // Cast first: passing a negative `char` to isdigit is undefined behavior.
    const unsigned char c = static_cast<unsigned char>(str[i]);
    if (!std::isdigit(c)) {
      return -1;
    }
    const int digit = c - '0';
    // Reject values that would overflow `int` (signed overflow is undefined behavior).
    if (result > (std::numeric_limits<int>::max() - digit) / 10) {
      return -1;
    }
    result = result * 10 + digit;
  }
  return result;
}
/*!
 * \brief Query a device attribute through the TVM DeviceAPI.
 * \param device The device to be queried
 * \param flag The device attribute to read once the device is known to exist
 * \param val Output slot. On success it holds the value of `flag`; on failure it may
 *        hold the intermediate kExist result, or be left untouched if no DeviceAPI
 *        is available.
 * \return true if this build has a DeviceAPI for `device` and the device reports that
 *         it exists; false otherwise
 */
static bool DetectDeviceFlag(TVMContext device, runtime::DeviceAttrKind flag, TVMRetValue* val) {
  using runtime::DeviceAPI;
  // Passing `true` lets Get return nullptr when this build lacks the corresponding
  // device API (checked immediately below) instead of failing.
  DeviceAPI* api = DeviceAPI::Get(device, true);
  if (api == nullptr) {
    return false;
  }
  // Check if the device exists
  api->GetAttr(device, runtime::kExist, val);
  int exists = *val;
  if (!exists) {
    return false;
  }
  // Reuse the handle validated above rather than performing a second lookup.
  api->GetAttr(device, flag, val);
  return true;
}
/*!
 * \brief Ensure that attribute `name` in `attrs` equals the string `value`.
 *  If the attribute is absent, it is inserted with `value`. If it is present, it must be
 *  a String equal to `value`; otherwise a CHECK failure with a "ValueError" message fires.
 * \param attrs The attribute map, updated in place
 * \param name The attribute key
 * \param value The required string value
 */
void CheckOrSetAttr(Map<String, ObjectRef>* attrs, const String& name, const String& value) {
  auto it = attrs->find(name);
  if (it == attrs->end()) {
    // Not provided by the user: fill in the required value.
    attrs->Set(name, value);
    return;
  }
  // Provided by the user: it must already agree with the required value.
  ObjectRef existing = (*it).second;
  const StringObj* existing_str = existing.as<StringObj>();
  CHECK(existing_str != nullptr && GetRef<String>(existing_str) == value)
      << "ValueError: Expects \"" << name << "\" to be \"" << value
      << "\", but gets: " << existing;
}
/********** Target kind attribute updaters **********/
/*!
 * \brief Attribute preprocessor for the "nvptx" target kind.
 *  Forces "mtriple" to "nvptx64-nvidia-cuda". A user-supplied "mcpu" must have the form
 *  "sm_<digits>". Without one, "mcpu" is derived from the compute version reported for
 *  GPU 0 (e.g. "7.5" -> "sm_75"), falling back to "sm_20" when detection fails.
 * \param attrs The attributes as given by the user
 * \return The attributes after validation and defaulting
 */
Map<String, ObjectRef> UpdateNVPTXAttrs(Map<String, ObjectRef> attrs) {
  CheckOrSetAttr(&attrs, "mtriple", "nvptx64-nvidia-cuda");
  if (attrs.count("mcpu")) {
    // An explicit -mcpu is only validated, never rewritten.
    const String mcpu = Downcast<String>(attrs.at("mcpu"));
    CHECK(ExtractIntWithPrefix(mcpu, "sm_") != -1)
        << "ValueError: NVPTX target gets an invalid CUDA arch: -mcpu=" << mcpu;
    return attrs;
  }
  int detected_arch = 20;
  TVMRetValue compute_version;
  if (DetectDeviceFlag({kDLGPU, 0}, runtime::kComputeVersion, &compute_version)) {
    // "major.minor" -> major*10 + minor; the +0.1 guards against values like 7.4999...
    const std::string version_str = compute_version.operator std::string();
    detected_arch = static_cast<int>(std::stod(version_str) * 10 + 0.1);
  } else {
    LOG(WARNING) << "Unable to detect CUDA version, default to \"-mcpu=sm_20\" instead";
  }
  attrs.Set("mcpu", String("sm_") + std::to_string(detected_arch));
  return attrs;
}
/*!
 * \brief Attribute preprocessor for the "rocm" target kind.
 *  Forces "mtriple" to "amdgcn-amd-amdhsa-hcc". A user-supplied "mcpu" must have the form
 *  "gfx<digits>". Without one, "mcpu" is taken from the GCN arch of ROCm device 0,
 *  falling back to "gfx900". When the detected ROCm API version is below 305
 *  (i.e. ROCm < 3.5), "-code-object-v3" is appended to "mattr".
 * \param attrs The attributes as given by the user
 * \return The attributes after validation and defaulting
 */
Map<String, ObjectRef> UpdateROCmAttrs(Map<String, ObjectRef> attrs) {
  CheckOrSetAttr(&attrs, "mtriple", "amdgcn-amd-amdhsa-hcc");
  // Resolve -mcpu=gfxNNN: validate an explicit value, otherwise query the device.
  if (attrs.count("mcpu")) {
    const String mcpu = Downcast<String>(attrs.at("mcpu"));
    CHECK(ExtractIntWithPrefix(mcpu, "gfx") != -1)
        << "ValueError: ROCm target gets an invalid GFX version: -mcpu=" << mcpu;
  } else {
    int gfx_arch = 900;
    TVMRetValue gcn_arch;
    if (DetectDeviceFlag({kDLROCM, 0}, runtime::kGcnArch, &gcn_arch)) {
      gfx_arch = gcn_arch.operator int();
    } else {
      LOG(WARNING) << "Unable to detect ROCm compute arch, default to \"-mcpu=gfx900\" instead";
    }
    attrs.Set("mcpu", String("gfx") + std::to_string(gfx_arch));
  }
  // ROCm before 3.5 needs code object v2; starting with 3.5, v3 is used.
  // "-code-object-v3" disables v3, so it is only added for the older releases.
  int rocm_version = 305;
  TVMRetValue api_version;
  if (DetectDeviceFlag({kDLROCM, 0}, runtime::kApiVersion, &api_version)) {
    rocm_version = api_version.operator int();
  } else {
    LOG(WARNING) << "Unable to detect ROCm version, assuming >= 3.5";
  }
  if (rocm_version >= 305) {
    return attrs;
  }
  Array<String> mattr =
      attrs.count("mattr") ? Downcast<Array<String>>(attrs.at("mattr")) : Array<String>();
  mattr.push_back("-code-object-v3");
  attrs.Set("mattr", mattr);
  return attrs;
}
/********** Register Target kinds and attributes **********/
// CPU code generation kinds, both on kDLCPU with default key "cpu".
// Both accept "mcpu", "system-lib" and "runtime"; "llvm" additionally accepts
// "mattr", "mtriple" and "mfloat-abi".
TVM_REGISTER_TARGET_KIND("llvm", kDLCPU)
.add_attr_option<Array<String>>("mattr")
.add_attr_option<String>("mcpu")
.add_attr_option<String>("mtriple")
.add_attr_option<String>("mfloat-abi")
.add_attr_option<Bool>("system-lib")
.add_attr_option<String>("runtime")
.set_default_keys({"cpu"});
TVM_REGISTER_TARGET_KIND("c", kDLCPU)
.add_attr_option<Bool>("system-lib")
.add_attr_option<String>("runtime")
.add_attr_option<String>("mcpu")
.set_default_keys({"cpu"});
// CUDA family: "cuda" and "nvptx" both target kDLGPU with default keys {"cuda", "gpu"}
// and defaults max_num_threads = 1024, thread_warp_size = 32.
// "nvptx" runs UpdateNVPTXAttrs, which forces "mtriple" and fills in "mcpu" if absent.
TVM_REGISTER_TARGET_KIND("cuda", kDLGPU)
.add_attr_option<String>("mcpu")
.add_attr_option<String>("arch")
.add_attr_option<Bool>("system-lib")
.add_attr_option<Integer>("max_num_threads", Integer(1024))
.add_attr_option<Integer>("thread_warp_size", Integer(32))
.set_default_keys({"cuda", "gpu"});
TVM_REGISTER_TARGET_KIND("nvptx", kDLGPU)
.add_attr_option<String>("mcpu")
.add_attr_option<String>("mtriple")
.add_attr_option<Bool>("system-lib")
.add_attr_option<Integer>("max_num_threads", Integer(1024))
.add_attr_option<Integer>("thread_warp_size", Integer(32))
.set_default_keys({"cuda", "gpu"})
.set_attrs_preprocessor(UpdateNVPTXAttrs);
// ROCm: defaults max_num_threads = 256, thread_warp_size = 64. UpdateROCmAttrs forces
// "mtriple", fills in "mcpu" if absent and may append to "mattr" for ROCm < 3.5.
TVM_REGISTER_TARGET_KIND("rocm", kDLROCM)
.add_attr_option<String>("mcpu")
.add_attr_option<String>("mtriple")
.add_attr_option<Bool>("system-lib")
.add_attr_option<Integer>("max_num_threads", Integer(256))
.add_attr_option<Integer>("thread_warp_size", Integer(64))
.set_default_keys({"rocm", "gpu"})
.set_attrs_preprocessor(UpdateROCmAttrs);
// Other GPU runtimes. Each registers max_num_threads with a default of 256 and carries
// the "gpu" key. "opencl" also declares thread_warp_size, but without a default value.
TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL)
.add_attr_option<Bool>("system-lib")
.add_attr_option<Integer>("max_num_threads", Integer(256))
.add_attr_option<Integer>("thread_warp_size")
.set_default_keys({"opencl", "gpu"});
TVM_REGISTER_TARGET_KIND("metal", kDLMetal)
.add_attr_option<Bool>("system-lib")
.add_attr_option<Integer>("max_num_threads", Integer(256))
.set_default_keys({"metal", "gpu"});
TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)
.add_attr_option<Bool>("system-lib")
.add_attr_option<Integer>("max_num_threads", Integer(256))
.set_default_keys({"vulkan", "gpu"});
TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU)
.add_attr_option<Bool>("system-lib")
.add_attr_option<Integer>("max_num_threads", Integer(256))
.set_default_keys({"webgpu", "gpu"});
// Kinds carrying the "hls" key. "sdaccel" is registered on kDLOpenCL; "aocl" and
// "aocl_sw_emu" share kDLAOCL and the same default keys {"aocl", "hls"}.
TVM_REGISTER_TARGET_KIND("sdaccel", kDLOpenCL)
.add_attr_option<Bool>("system-lib")
.set_default_keys({"sdaccel", "hls"});
TVM_REGISTER_TARGET_KIND("aocl", kDLAOCL)
.add_attr_option<Bool>("system-lib")
.set_default_keys({"aocl", "hls"});
TVM_REGISTER_TARGET_KIND("aocl_sw_emu", kDLAOCL)
.add_attr_option<Bool>("system-lib")
.set_default_keys({"aocl", "hls"});
// Hexagon accepts "mattr", "mcpu", "mtriple" and a list of extra "llvm-options" strings.
TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon)
.add_attr_option<Array<String>>("mattr")
.add_attr_option<String>("mcpu")
.add_attr_option<String>("mtriple")
.add_attr_option<Bool>("system-lib")
.add_attr_option<Array<String>>("llvm-options")
.set_default_keys({"hexagon"});
// Kinds that only accept "system-lib" and register no default keys. The trailing
// "// line break" comments presumably keep clang-format from joining each chain onto one line.
TVM_REGISTER_TARGET_KIND("stackvm", kDLCPU) // line break
.add_attr_option<Bool>("system-lib");
TVM_REGISTER_TARGET_KIND("ext_dev", kDLExtDev) // line break
.add_attr_option<Bool>("system-lib");
TVM_REGISTER_TARGET_KIND("hybrid", kDLCPU) // line break
.add_attr_option<Bool>("system-lib");
// "composite" takes a "target_host" Target and a list of "devices" Targets.
TVM_REGISTER_TARGET_KIND("composite", kDLCPU)
.add_attr_option<Target>("target_host")
.add_attr_option<Array<Target>>("devices");
} // namespace tvm