blob: 91a5854b3934e3f4f2d40eab303212ce3f0fef93 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Compile executable modules.
* \file src/target/target.cc
*/
#include <tvm/ffi/extra/json.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/transform.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/logging.h>
#include <tvm/target/tag.h>
#include <tvm/target/target.h>
#include <tvm/target/target_kind.h>
#include <tvm/tir/expr.h>
#include <algorithm>
#include <sstream>
#include <stack>
#include <string>
#include <unordered_map>
#include <vector>
namespace tvm {
TVM_FFI_STATIC_INIT_BLOCK() { TargetNode::RegisterReflection(); }
class TargetInternal {
public:
static void EnterScope(Target target) { target.EnterWithScope(); }
static void ExitScope(Target target) { target.ExitWithScope(); }
static ffi::Map<ffi::String, ffi::Any> ToConfig(Target target) { return target->ToConfig(); }
static ObjectPtr<TargetNode> FromString(const ffi::String& tag_or_config_or_target_str);
static ObjectPtr<TargetNode> FromConfigString(const ffi::String& config_str);
static ObjectPtr<TargetNode> FromConfig(ffi::Map<ffi::String, ffi::Any> config);
static void ConstructorDispatcher(ffi::PackedArgs args, ffi::Any* rv);
static Target WithHost(const Target& target, const Target& target_host) {
ObjectPtr<TargetNode> n = ffi::make_object<TargetNode>(*target.get());
n->host = target_host;
return (Target)n;
}
private:
static std::unordered_map<ffi::String, ffi::Any> QueryDevice(int device_id,
const TargetNode* target);
};
/********** Helper functions **********/
Target Target::WithHost(const Target& target, const Target& host) {
return TargetInternal::WithHost(target, host);
}
void CheckAndUpdateHostConsistency(Target* target, Target* host) {
*target = Target(*target, *host);
*host = (*target)->GetHost().value_or(Target());
}
static std::vector<ffi::String> DeduplicateKeys(const std::vector<ffi::String>& keys) {
std::vector<ffi::String> new_keys;
for (size_t i = 0; i < keys.size(); ++i) {
bool found = false;
for (size_t j = 0; j < i; ++j) {
if (keys[i] == keys[j]) {
found = true;
break;
}
}
if (!found) {
new_keys.push_back(keys[i]);
}
}
return new_keys;
}
static TargetKind GetTargetKind(const ffi::String& name) {
ffi::Optional<TargetKind> kind = TargetKind::Get(name);
if (!kind.defined()) {
TVM_FFI_THROW(TypeError) << "Target kind \"" + name + "\" is not defined";
}
return kind.value();
}
const std::string& TargetNode::str() const {
if (str_repr_.empty()) {
str_repr_ = std::string(ffi::json::Stringify(ToConfig()));
}
return str_repr_;
}
/********** Small member methods **********/
Target::Target(const ffi::String& tag_or_config_or_target_str) {
ObjectPtr<Object> target;
try {
target = TargetInternal::FromString(tag_or_config_or_target_str);
} catch (const Error& e) {
std::ostringstream os;
os << ". Target creation from string failed: " << tag_or_config_or_target_str;
throw Error("ValueError", e.message() + os.str(), e.backtrace());
}
data_ = std::move(target);
}
Target::Target(const ffi::Map<ffi::String, ffi::Any>& config) {
ObjectPtr<Object> target;
try {
target = TargetInternal::FromConfig(config);
} catch (const Error& e) {
std::ostringstream os;
os << ". Target creation from config dict failed: " << config;
throw Error("ValueError", std::string(e.message()) + os.str(), e.backtrace());
}
data_ = std::move(target);
}
Target::Target(Target target, Target host) {
ObjectPtr<TargetNode> n = ffi::make_object<TargetNode>(*target.get());
n->host = std::move(host);
data_ = std::move(n);
}
Target::Target(TargetKind kind, ffi::Optional<ObjectRef> host, ffi::String tag,
ffi::Array<ffi::String> keys, ffi::Map<ffi::String, ffi::Any> attrs) {
auto data = ffi::make_object<TargetNode>();
data->kind = std::move(kind);
data->host = std::move(host);
data->tag = std::move(tag);
data->keys = std::move(keys);
data->attrs = std::move(attrs);
data_ = std::move(data);
}
ffi::Map<ffi::String, ffi::Any> TargetNode::ToConfig() const {
ffi::Map<ffi::String, ffi::Any> result = {
{"kind", this->kind->name},
{"tag", this->tag},
{"keys", this->keys},
};
if (this->host.defined()) {
result.Set("host", this->GetHost().value_or(Target())->ToConfig());
}
for (const auto& kv : attrs) {
result.Set(kv.first, kv.second);
}
return result;
}
ffi::Optional<Target> TargetNode::GetHost() const { return this->host.as<Target>(); }
Target Target::WithoutHost() const {
if ((*this)->GetHost()) {
auto output = ffi::make_object<TargetNode>(*get());
output->host = std::nullopt;
return Target(output);
} else {
return *this;
}
}
int TargetNode::GetTargetDeviceType() const {
if (ffi::Optional<Integer> device_type = GetAttr<Integer>("target_device_type")) {
return Downcast<Integer>(device_type)->value;
}
return kind->default_device_type;
}
bool TargetNode::HasKey(const std::string& query_key) const {
return std::any_of(keys.begin(), keys.end(),
[&query_key](const auto& key) { return key == query_key; });
}
/*! \brief Entry to hold the Target context stack. */
struct TVMTargetThreadLocalEntry {
/*! \brief The current target context */
std::stack<Target> context_stack;
};
/*! \brief Thread local store to hold the Target context stack. */
static TVMTargetThreadLocalEntry* TVMTargetThreadLocalStoreGet() {
static thread_local TVMTargetThreadLocalEntry inst;
return &inst;
}
void Target::EnterWithScope() {
TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStoreGet();
entry->context_stack.push(*this);
}
void Target::ExitWithScope() {
TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStoreGet();
TVM_FFI_ICHECK(!entry->context_stack.empty());
TVM_FFI_ICHECK(entry->context_stack.top().same_as(*this));
entry->context_stack.pop();
}
Target Target::Current(bool allow_not_defined) {
TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStoreGet();
if (entry->context_stack.size() > 0) {
return entry->context_stack.top();
}
TVM_FFI_ICHECK(allow_not_defined)
<< "Target context required. Please set it by constructing a TargetContext";
return Target();
}
/********** Creation **********/
void TargetInternal::ConstructorDispatcher(ffi::PackedArgs args, ffi::Any* rv) {
if (args.size() == 1) {
const auto& arg = args[0];
if (auto opt_target = arg.as<Target>()) {
*rv = Target(opt_target.value());
} else if (auto opt_str = arg.try_cast<ffi::String>()) {
*rv = Target(opt_str.value());
} else if (auto opt_map = arg.try_cast<ffi::Map<ffi::String, ffi::Any>>()) {
*rv = Target(opt_map.value());
} else {
TVM_FFI_THROW(TypeError) << "Cannot create target with type: " << args[0].GetTypeKey();
}
return;
} else if (args.size() == 2) {
if (args[0].as<Target>().has_value() && args[1].as<Target>().has_value()) {
Target target = args[0].cast<Target>();
Target host = args[1].cast<Target>();
*rv = Target(target, host);
} else {
TVM_FFI_THROW(ValueError) << "Invalid type of arguments. Expect 2 Target arguments.";
}
return;
}
TVM_FFI_THROW(ValueError) << "Invalid number of arguments. Expect 1 or 2, but gets: "
<< args.size();
}
ObjectPtr<TargetNode> TargetInternal::FromString(const ffi::String& tag_or_config_or_target_str) {
if (ffi::Optional<Target> target = TargetTag::Get(tag_or_config_or_target_str)) {
Target value = target.value();
return ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef<TargetNode>(value);
}
if (!tag_or_config_or_target_str.empty() && tag_or_config_or_target_str.data()[0] == '{') {
return TargetInternal::FromConfigString(tag_or_config_or_target_str);
}
// Treat as bare kind name (e.g. "llvm", "cuda"). Reject strings with spaces.
std::string s(tag_or_config_or_target_str);
if (s.find(' ') != std::string::npos) {
TVM_FFI_THROW(ValueError)
<< "Cannot parse target string \"" << s
<< "\". CLI target string form (e.g. \"llvm -mcpu=xxx\") is no longer supported. "
<< "Please use JSON dict form (e.g. {\"kind\": \"llvm\", \"mcpu\": \"xxx\"}) instead.";
}
return TargetInternal::FromConfig(ffi::Map<ffi::String, ffi::Any>{{"kind", ffi::String(s)}});
}
ObjectPtr<TargetNode> TargetInternal::FromConfigString(const ffi::String& config_str) {
ffi::String error_msg;
ffi::json::Value parsed = ffi::json::Parse(config_str, &error_msg);
if (error_msg.size() > 0) {
TVM_FFI_THROW(ValueError) << "Failed to parse target JSON config: " << error_msg;
}
auto config = parsed.as<ffi::Map<ffi::String, ffi::Any>>();
if (!config.has_value()) {
TVM_FFI_THROW(ValueError) << "Target JSON config must be a dict, got: " << config_str;
}
return TargetInternal::FromConfig(config.value());
}
ObjectPtr<TargetNode> TargetInternal::FromConfig(ffi::Map<ffi::String, ffi::Any> config) {
const ffi::String kKind = "kind";
const ffi::String kTag = "tag";
const ffi::String kKeys = "keys";
const ffi::String kDeviceName = "device";
const ffi::String kHost = "host";
const ffi::String kFromDevice = "from_device";
ObjectPtr<TargetNode> target = ffi::make_object<TargetNode>();
// Step 0: If "tag" is present without "kind", look up the tag config and merge overrides on top
if (!config.count(kKind) && config.count(kTag)) {
auto tag_name = config[kTag].try_cast<ffi::String>();
TVM_FFI_ICHECK(tag_name.has_value())
<< "Expect type of field \"tag\" is String, but get type: " << config[kTag].GetTypeKey();
auto tag_config = TargetTag::GetConfig(tag_name.value());
TVM_FFI_ICHECK(tag_config.has_value()) << "Unknown target tag: " << tag_name.value();
// Start from the tag's base config, then apply user overrides
ffi::Map<ffi::String, ffi::Any> merged = tag_config.value();
for (const auto& kv : config) {
if (kv.first != kTag) {
merged.Set(kv.first, kv.second);
}
}
merged.Set(kTag, ffi::String(tag_name.value()));
config = std::move(merged);
}
// Step 1: Parse 'kind' (needed to look up the schema, but kept in config for canonicalizer)
if (config.count(kKind)) {
if (auto kind = config[kKind].try_cast<ffi::String>()) {
target->kind = GetTargetKind(kind.value());
} else {
TVM_FFI_THROW(TypeError) << "Expect type of field \"kind\" is String, but get type: "
<< config[kKind].GetTypeKey();
}
} else {
TVM_FFI_THROW(ValueError) << "Field \"kind\" is not found";
}
// Step 2: Extract "host" before schema validation (needs special recursive parsing)
if (config.count(kHost)) {
target->host = ffi::Function(ConstructorDispatcher)(config[kHost]).cast<Target>();
config.erase(kHost);
} else {
target->host = std::nullopt;
}
// Step 3: Use ConfigSchema to validate types, apply defaults, and run canonicalizer
// Note: structural keys (kind, tag, keys, device) pass through to canonicalizer
ffi::Map<ffi::String, ffi::Any> resolved = target->kind->schema_.Resolve(config);
// Step 4: Extract structural fields from resolved config
if (resolved.count(kTag)) {
if (auto tag = resolved[kTag].try_cast<ffi::String>()) {
target->tag = tag.value();
}
resolved.erase(kTag);
} else {
target->tag = "";
}
{
std::vector<ffi::String> keys;
bool has_keys = resolved.count(kKeys);
if (has_keys) {
ffi::Array<ffi::String> cfg_keys = Downcast<ffi::Array<ffi::String>>(resolved.at(kKeys));
for (const ffi::String& key : cfg_keys) {
keys.push_back(key);
}
}
if (resolved.count(kDeviceName)) {
if (auto device = resolved.at(kDeviceName).try_cast<ffi::String>()) {
keys.push_back(device.value());
}
}
if (!has_keys) {
for (const auto& key : target->kind->default_keys) {
keys.push_back(key);
}
}
target->keys = DeduplicateKeys(keys);
resolved.erase(kKeys);
}
// Step 5: Build attrs from resolved entries (excluding structural keys)
resolved.erase(kKind);
std::unordered_map<ffi::String, ffi::Any> attrs;
for (const auto& kv : resolved) {
attrs[kv.first] = kv.second;
}
// Step 6: If requested, query attributes from the device. User-specified
// parameters take precedence over queried parameters.
int64_t from_device_id = -1;
if (auto it = attrs.find(kFromDevice); it != attrs.end()) {
from_device_id = it->second.cast<int64_t>();
attrs.erase(it);
}
if (from_device_id >= 0) {
auto device_params = QueryDevice(from_device_id, target.get());
for (const auto& kv : device_params) {
if (attrs.count(kv.first) == 0) {
attrs[kv.first] = kv.second;
}
}
}
target->attrs = attrs;
return target;
}
std::unordered_map<ffi::String, ffi::Any> TargetInternal::QueryDevice(int device_id,
const TargetNode* target) {
std::unordered_map<ffi::String, ffi::Any> output;
Device device{static_cast<DLDeviceType>(target->GetTargetDeviceType()), device_id};
auto api = runtime::DeviceAPI::Get(device, true);
if (!api) {
LOG(INFO) << "Requested reading the parameters for " << target->kind->name << " from device_id "
<< device_id << ", but support for this runtime wasn't enabled at compile-time. "
<< "Using default target parameters.";
return output;
}
ffi::Any ret;
api->GetAttr(device, runtime::kExist, &ret);
bool device_exists = ret.cast<bool>();
if (!device_exists) {
TVM_FFI_ICHECK(device_exists) << "Requested reading the parameters for " << target->kind->name
<< " from device_id " << device_id << ", but device_id "
<< device_id
<< " doesn't exist. Using default target parameters.";
return output;
}
for (const auto& e : target->kind->schema_.ListOptions()) {
ffi::Any ret;
api->GetTargetProperty(device, e.key, &ret);
if (ret.type_index() != kTVMFFINone) {
output[e.key] = ret;
}
}
return output;
}
/********** Registry **********/
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def_packed("target.Target", TargetInternal::ConstructorDispatcher)
.def("target.TargetEnterScope", TargetInternal::EnterScope)
.def("target.TargetExitScope", TargetInternal::ExitScope)
.def("target.TargetCurrent", Target::Current)
.def("target.TargetExport", TargetInternal::ToConfig)
.def("target.WithHost", TargetInternal::WithHost)
.def("target.TargetGetDeviceType",
[](const Target& target) { return target->GetTargetDeviceType(); })
.def("target.TargetGetFeature",
[](const Target& target, const ffi::String& feature_key) -> Any {
ffi::String full_key = "feature." + std::string(feature_key);
auto it = target->attrs.find(full_key);
if (it != target->attrs.end()) {
return (*it).second;
}
return Any();
});
}
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TargetNode>([](const ObjectRef& obj, ReprPrinter* p) {
p->stream << Downcast<Target>(obj)->str();
});
} // namespace tvm