| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| /*! |
| * Compile executable modules. |
| * \file src/target/target.cc |
| */ |
| #include <dmlc/thread_local.h> |
| #include <tvm/runtime/registry.h> |
| #include <tvm/target/tag.h> |
| #include <tvm/target/target.h> |
| #include <tvm/target/target_kind.h> |
| #include <tvm/tir/expr.h> |
| |
| #include <algorithm> |
| #include <stack> |
| |
| #include "../runtime/object_internal.h" |
| |
| namespace tvm { |
| |
| TVM_REGISTER_NODE_TYPE(TargetNode); |
| |
| class TargetInternal { |
| public: |
| static void EnterScope(Target target) { target.EnterWithScope(); } |
| static void ExitScope(Target target) { target.ExitWithScope(); } |
| static Map<String, ObjectRef> Export(Target target) { return target->Export(); } |
| static const TargetKindNode::ValueTypeInfo& FindTypeInfo(const TargetKind& kind, |
| const std::string& key); |
| static Optional<String> StringifyAttrsToRaw(const Map<String, ObjectRef>& attrs); |
| static ObjectRef ParseType(const std::string& str, const TargetKindNode::ValueTypeInfo& info); |
| static ObjectRef ParseType(const ObjectRef& obj, const TargetKindNode::ValueTypeInfo& info); |
| static ObjectPtr<Object> FromString(const String& tag_or_config_or_target_str); |
| static ObjectPtr<Object> FromConfigString(const String& config_str); |
| static ObjectPtr<Object> FromRawString(const String& target_str); |
| static ObjectPtr<Object> FromConfig(std::unordered_map<String, ObjectRef> config); |
| static void ConstructorDispatcher(TVMArgs args, TVMRetValue* rv); |
| }; |
| |
| /********** Helper functions **********/ |
| |
| static std::vector<String> DeduplicateKeys(const std::vector<String>& keys) { |
| std::vector<String> new_keys; |
| for (size_t i = 0; i < keys.size(); ++i) { |
| bool found = false; |
| for (size_t j = 0; j < i; ++j) { |
| if (keys[i] == keys[j]) { |
| found = true; |
| break; |
| } |
| } |
| if (!found) { |
| new_keys.push_back(keys[i]); |
| } |
| } |
| return new_keys; |
| } |
| |
| template <class TObj> |
| static const TObj* ObjTypeCheck(const ObjectRef& obj, const std::string& expected_type) { |
| const TObj* ptr = obj.as<TObj>(); |
| if (ptr == nullptr) { |
| std::ostringstream os; |
| os << ": Expects type \"" << expected_type << "\", but gets \"" << obj->GetTypeKey() |
| << "\" for object: " << obj; |
| throw dmlc::Error(os.str()); |
| } |
| return ptr; |
| } |
| |
| static TargetKind GetTargetKind(const String& name) { |
| Optional<TargetKind> kind = TargetKind::Get(name); |
| if (!kind.defined()) { |
| throw dmlc::Error(": Target kind \"" + name + "\" is not defined"); |
| } |
| return kind.value(); |
| } |
| |
| static std::string RemovePrefixDashes(const std::string& s) { |
| int n_dashes = 0; |
| int len = s.length(); |
| for (; n_dashes < len && s[n_dashes] == '-'; ++n_dashes) { |
| } |
| if (n_dashes == 0) { |
| throw dmlc::Error(": Attribute keys should start with '-', not an attribute key: " + s); |
| } |
| if (n_dashes >= len) { |
| throw dmlc::Error(": Not an attribute key: " + s); |
| } |
| return s.substr(n_dashes); |
| } |
| |
| static int FindFirstSubstr(const std::string& str, const std::string& substr) { |
| size_t pos = str.find_first_of(substr); |
| return pos == std::string::npos ? -1 : pos; |
| } |
| |
| static Optional<String> JoinString(const std::vector<String>& array, char separator) { |
| if (array.empty()) { |
| return NullOpt; |
| } |
| std::ostringstream os; |
| os << array[0]; |
| for (size_t i = 1; i < array.size(); ++i) { |
| os << separator << array[i]; |
| } |
| return String(os.str()); |
| } |
| |
| static int ParseKVPair(const std::string& s, const std::string& s_next, std::string* key, |
| std::string* value) { |
| int pos; |
| std::string& result_k = *key; |
| std::string& result_v = *value; |
| if ((pos = FindFirstSubstr(s, "=")) != -1) { |
| // case 1. --key=value |
| result_k = s.substr(0, pos); |
| result_v = s.substr(pos + 1); |
| if (result_k.empty() || result_v.empty()) { |
| throw dmlc::Error(": Empty attribute key or value in \"" + s + "\""); |
| } |
| return 1; |
| } else if (!s_next.empty() && s_next[0] != '-') { |
| // case 2. --key value |
| result_k = s; |
| result_v = s_next; |
| return 2; |
| } |
| // case 3. --boolean-key |
| result_k = s; |
| result_v = "1"; |
| return 1; |
| } |
| |
| const TargetKindNode::ValueTypeInfo& TargetInternal::FindTypeInfo(const TargetKind& kind, |
| const std::string& key) { |
| auto it = kind->key2vtype_.find(key); |
| if (it == kind->key2vtype_.end()) { |
| std::ostringstream os; |
| os << ": Cannot recognize \'" << key << "\'. Candidates are: "; |
| bool is_first = true; |
| for (const auto& kv : kind->key2vtype_) { |
| if (is_first) { |
| is_first = false; |
| } else { |
| os << ", "; |
| } |
| os << kv.first; |
| } |
| throw dmlc::Error(os.str()); |
| } |
| return it->second; |
| } |
| |
| /********** Parsing **********/ |
| |
| ObjectRef TargetInternal::ParseType(const std::string& str, |
| const TargetKindNode::ValueTypeInfo& info) { |
| std::istringstream is(str); |
| if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { |
| // Parsing integer |
| int v; |
| if (!(is >> v)) { |
| throw dmlc::Error(": Cannot parse into type \"Integer\" from string: " + str); |
| } |
| return Integer(v); |
| } else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) { |
| // Parsing string |
| std::string v; |
| if (!(is >> v)) { |
| throw dmlc::Error(": Cannot parse into type \"String\" from string: " + str); |
| } |
| return String(v); |
| } else if (info.type_index == Target::ContainerType::_GetOrAllocRuntimeTypeIndex()) { |
| // Parsing target |
| return Target(TargetInternal::FromString(str)); |
| } else if (info.type_index == ArrayNode::_GetOrAllocRuntimeTypeIndex()) { |
| // Parsing array |
| std::vector<ObjectRef> result; |
| for (std::string substr; std::getline(is, substr, ',');) { |
| try { |
| ObjectRef parsed = TargetInternal::ParseType(substr, *info.key); |
| result.push_back(parsed); |
| } catch (const dmlc::Error& e) { |
| std::string index = "[" + std::to_string(result.size()) + "]"; |
| throw dmlc::Error(index + e.what()); |
| } |
| } |
| return Array<ObjectRef>(result); |
| } |
| throw dmlc::Error(": Unsupported type \"" + info.type_key + "\" for parsing from string: " + str); |
| } |
| |
| ObjectRef TargetInternal::ParseType(const ObjectRef& obj, |
| const TargetKindNode::ValueTypeInfo& info) { |
| if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { |
| // Parsing integer |
| return GetRef<Integer>(ObjTypeCheck<IntImmNode>(obj, "Integer")); |
| } else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) { |
| // Parsing string |
| return GetRef<String>(ObjTypeCheck<StringObj>(obj, "String")); |
| } else if (info.type_index == Target::ContainerType::_GetOrAllocRuntimeTypeIndex()) { |
| // Parsing target |
| if (const auto* ptr = obj.as<TargetNode>()) { |
| return GetRef<Target>(ptr); |
| } else if (const auto* ptr = obj.as<StringObj>()) { |
| return Target(TargetInternal::FromString(GetRef<String>(ptr))); |
| } else if (const auto* ptr = obj.as<MapNode>()) { |
| for (const auto& kv : *ptr) { |
| if (!kv.first->IsInstance<StringObj>()) { |
| throw dmlc::Error(": Target object requires key of dict to be str, but get: " + |
| kv.first->GetTypeKey()); |
| } |
| } |
| Map<String, ObjectRef> config = GetRef<Map<String, ObjectRef>>(ptr); |
| return Target(TargetInternal::FromConfig({config.begin(), config.end()})); |
| } |
| throw dmlc::Error(": Expect type 'dict' or 'str' to construct Target, but get: " + |
| obj->GetTypeKey()); |
| } else if (info.type_index == ArrayNode::_GetOrAllocRuntimeTypeIndex()) { |
| // Parsing array |
| const auto* array = ObjTypeCheck<ArrayNode>(obj, "Array"); |
| std::vector<ObjectRef> result; |
| for (const ObjectRef& e : *array) { |
| try { |
| result.push_back(TargetInternal::ParseType(e, *info.key)); |
| } catch (const dmlc::Error& e) { |
| std::string index = '[' + std::to_string(result.size()) + ']'; |
| throw dmlc::Error(index + e.what()); |
| } |
| } |
| return Array<ObjectRef>(result); |
| } else if (info.type_index == MapNode::_GetOrAllocRuntimeTypeIndex()) { |
| // Parsing map |
| const auto* map = ObjTypeCheck<MapNode>(obj, "Map"); |
| std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> result; |
| for (const auto& kv : *map) { |
| ObjectRef key, val; |
| try { |
| key = TargetInternal::ParseType(kv.first, *info.key); |
| } catch (const dmlc::Error& e) { |
| std::ostringstream os; |
| os << "'s key \"" << key << "\"" << e.what(); |
| throw dmlc::Error(os.str()); |
| } |
| try { |
| val = TargetInternal::ParseType(kv.second, *info.val); |
| } catch (const dmlc::Error& e) { |
| std::ostringstream os; |
| os << "[\"" << key << "\"]" << e.what(); |
| throw dmlc::Error(os.str()); |
| } |
| result[key] = val; |
| } |
| return Map<ObjectRef, ObjectRef>(result); |
| } |
| if (info.type_index != obj->type_index()) { |
| std::ostringstream os; |
| os << ": Parsing type \"" << info.type_key |
| << "\" is not supported for the given object of type \"" << obj->GetTypeKey() |
| << "\". The object is: " << obj; |
| throw dmlc::Error(os.str()); |
| } |
| return obj; |
| } |
| |
| /********** Stringifying **********/ |
| |
| static inline Optional<String> StringifyAtomicType(const ObjectRef& obj) { |
| if (const auto* p = obj.as<IntImmNode>()) { |
| return String(std::to_string(p->value)); |
| } |
| if (const auto* p = obj.as<StringObj>()) { |
| return GetRef<String>(p); |
| } |
| return NullOpt; |
| } |
| |
| Optional<String> TargetInternal::StringifyAttrsToRaw(const Map<String, ObjectRef>& attrs) { |
| std::ostringstream os; |
| std::vector<String> keys; |
| for (const auto& kv : attrs) { |
| keys.push_back(kv.first); |
| } |
| std::sort(keys.begin(), keys.end()); |
| std::vector<String> result; |
| for (const auto& key : keys) { |
| const ObjectRef& obj = attrs[key]; |
| Optional<String> value = NullOpt; |
| if (const auto* array = obj.as<ArrayNode>()) { |
| std::vector<String> items; |
| for (const ObjectRef& item : *array) { |
| Optional<String> str = StringifyAtomicType(item); |
| if (str.defined()) { |
| items.push_back(str.value()); |
| } else { |
| items.clear(); |
| break; |
| } |
| } |
| value = JoinString(items, ','); |
| } else { |
| value = StringifyAtomicType(obj); |
| } |
| if (value.defined()) { |
| result.push_back("-" + key + "=" + value.value()); |
| } |
| } |
| return JoinString(result, ' '); |
| } |
| |
| const std::string& TargetNode::str() const { |
| if (str_repr_.empty()) { |
| std::ostringstream os; |
| os << kind->name; |
| if (!this->keys.empty()) { |
| os << " -keys="; |
| bool is_first = true; |
| for (const String& s : keys) { |
| if (is_first) { |
| is_first = false; |
| } else { |
| os << ','; |
| } |
| os << s; |
| } |
| } |
| if (Optional<String> attrs_str = TargetInternal::StringifyAttrsToRaw(attrs)) { |
| os << ' ' << attrs_str.value(); |
| } |
| str_repr_ = os.str(); |
| } |
| return str_repr_; |
| } |
| |
| /********** Small member methods **********/ |
| |
| Target::Target(const String& tag_or_config_or_target_str) { |
| ObjectPtr<Object> target; |
| try { |
| target = TargetInternal::FromString(tag_or_config_or_target_str); |
| } catch (const dmlc::Error& e) { |
| LOG(FATAL) << "ValueError" << e.what() |
| << ". Target creation from string failed: " << tag_or_config_or_target_str; |
| } |
| data_ = std::move(target); |
| } |
| |
| Target::Target(const Map<String, ObjectRef>& config) { |
| ObjectPtr<Object> target; |
| try { |
| target = TargetInternal::FromConfig({config.begin(), config.end()}); |
| } catch (const dmlc::Error& e) { |
| LOG(FATAL) << "ValueError" << e.what() |
| << ". Target creation from config dict failed: " << config; |
| } |
| data_ = std::move(target); |
| } |
| |
| std::vector<std::string> TargetNode::GetKeys() const { |
| std::vector<std::string> result; |
| for (auto& expr : keys) { |
| result.push_back(expr); |
| } |
| return result; |
| } |
| |
| std::unordered_set<std::string> TargetNode::GetLibs() const { |
| Optional<Array<String>> libs = this->GetAttr<Array<String>>("libs"); |
| if (!libs.defined()) { |
| return {}; |
| } |
| std::unordered_set<std::string> result; |
| for (const auto& item : libs.value()) { |
| result.insert(item); |
| } |
| return result; |
| } |
| |
| Map<String, ObjectRef> TargetNode::Export() const { |
| Map<String, ObjectRef> result = { |
| {"kind", this->kind->name}, |
| {"tag", this->tag}, |
| {"keys", this->keys}, |
| }; |
| for (const auto& kv : attrs) { |
| result.Set(kv.first, kv.second); |
| } |
| return result; |
| } |
| |
| /*! \brief Entry to hold the Target context stack. */ |
| struct TVMTargetThreadLocalEntry { |
| /*! \brief The current target context */ |
| std::stack<Target> context_stack; |
| }; |
| |
| /*! \brief Thread local store to hold the Target context stack. */ |
| using TVMTargetThreadLocalStore = dmlc::ThreadLocalStore<TVMTargetThreadLocalEntry>; |
| |
| void Target::EnterWithScope() { |
| TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStore::Get(); |
| entry->context_stack.push(*this); |
| } |
| |
| void Target::ExitWithScope() { |
| TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStore::Get(); |
| CHECK(!entry->context_stack.empty()); |
| CHECK(entry->context_stack.top().same_as(*this)); |
| entry->context_stack.pop(); |
| } |
| |
| Target Target::Current(bool allow_not_defined) { |
| TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStore::Get(); |
| if (entry->context_stack.size() > 0) { |
| return entry->context_stack.top(); |
| } |
| CHECK(allow_not_defined) |
| << "Target context required. Please set it by constructing a TargetContext"; |
| |
| return Target(); |
| } |
| |
| /********** Creation **********/ |
| |
| void TargetInternal::ConstructorDispatcher(TVMArgs args, TVMRetValue* rv) { |
| if (args.num_args == 1) { |
| const auto& arg = args[0]; |
| if (arg.IsObjectRef<Target>()) { |
| *rv = Target(arg.AsObjectRef<Target>()); |
| } else if (String::CanConvertFrom(arg)) { |
| *rv = Target(arg.operator String()); |
| } else if (arg.IsObjectRef<Map<String, ObjectRef>>()) { |
| *rv = Target(arg.operator Map<String, ObjectRef>()); |
| } else if (arg.type_code() == kTVMObjectHandle) { |
| ObjectRef obj = arg; |
| LOG(FATAL) << "TypeError: Cannot create target with type: " << obj->GetTypeKey(); |
| } else { |
| LOG(FATAL) << "TypeError: Cannot create target with type: " |
| << runtime::ArgTypeCode2Str(arg.type_code()); |
| } |
| return; |
| } |
| LOG(FATAL) << "ValueError: Invalid number of arguments. Expect 1, but gets: " << args.num_args; |
| } |
| |
| ObjectPtr<Object> TargetInternal::FromString(const String& tag_or_config_or_target_str) { |
| if (Optional<Target> target = TargetTag::Get(tag_or_config_or_target_str)) { |
| Target value = target.value(); |
| return runtime::ObjectInternal::MoveObjectPtr(&value); |
| } |
| if (!tag_or_config_or_target_str.empty() && tag_or_config_or_target_str.data()[0] == '{') { |
| return TargetInternal::FromConfigString(tag_or_config_or_target_str); |
| } |
| return TargetInternal::FromRawString(tag_or_config_or_target_str); |
| } |
| |
| ObjectPtr<Object> TargetInternal::FromConfigString(const String& config_str) { |
| const auto* loader = tvm::runtime::Registry::Get("target._load_config_dict"); |
| CHECK(loader) << "AttributeError: \"target._load_config_dict\" is not registered. Please check " |
| "if the python module is properly loaded"; |
| Optional<Map<String, ObjectRef>> config = (*loader)(config_str); |
| if (!config.defined()) { |
| throw dmlc::Error(": Cannot load config dict with python JSON loader"); |
| } |
| return TargetInternal::FromConfig({config.value().begin(), config.value().end()}); |
| } |
| |
| ObjectPtr<Object> TargetInternal::FromRawString(const String& target_str) { |
| // Split the string by empty spaces |
| std::string name; |
| std::vector<std::string> options; |
| std::string str; |
| for (std::istringstream is(target_str); is >> str;) { |
| if (name.empty()) { |
| name = str; |
| } else { |
| options.push_back(str); |
| } |
| } |
| if (name.empty()) { |
| throw dmlc::Error(": Cannot parse empty target string"); |
| } |
| // Create the target config |
| std::unordered_map<String, ObjectRef> config = {{"kind", String(name)}}; |
| TargetKind kind = GetTargetKind(name); |
| for (size_t iter = 0, end = options.size(); iter < end;) { |
| std::string key, value; |
| try { |
| // Parse key-value pair |
| std::string s_next = (iter + 1 < options.size()) ? options[iter + 1] : ""; |
| iter += ParseKVPair(RemovePrefixDashes(options[iter]), s_next, &key, &value); |
| } catch (const dmlc::Error& e) { |
| throw dmlc::Error(": Error when parsing target" + std::string(e.what())); |
| } |
| try { |
| // check if `key` has been used |
| if (config.count(key)) { |
| throw dmlc::Error(": The key \"" + key + "\" appears more than once"); |
| } |
| config[key] = TargetInternal::ParseType(value, TargetInternal::FindTypeInfo(kind, key)); |
| } catch (const dmlc::Error& e) { |
| throw dmlc::Error(": Error when parsing target[\"" + key + "\"]" + e.what()); |
| } |
| } |
| return TargetInternal::FromConfig(config); |
| } |
| |
| ObjectPtr<Object> TargetInternal::FromConfig(std::unordered_map<String, ObjectRef> config) { |
| const String kKind = "kind"; |
| const String kTag = "tag"; |
| const String kKeys = "keys"; |
| const String kDeviceName = "device"; |
| ObjectPtr<TargetNode> target = make_object<TargetNode>(); |
| // parse 'kind' |
| if (config.count(kKind)) { |
| if (const auto* kind = config[kKind].as<StringObj>()) { |
| target->kind = GetTargetKind(GetRef<String>(kind)); |
| config.erase(kKind); |
| } else { |
| throw dmlc::Error(": Expect type of field \"kind\" is String, but get type: " + |
| config[kKind]->GetTypeKey()); |
| } |
| } else { |
| throw dmlc::Error(": Field \"kind\" is not found"); |
| } |
| // parse "tag" |
| if (config.count(kTag)) { |
| if (const auto* tag = config[kTag].as<StringObj>()) { |
| target->tag = GetRef<String>(tag); |
| config.erase(kTag); |
| } else { |
| throw dmlc::Error(": Expect type of field \"tag\" is String, but get type: " + |
| config[kTag]->GetTypeKey()); |
| } |
| } else { |
| target->tag = ""; |
| } |
| // parse "keys" |
| { |
| std::vector<String> keys; |
| if (config.count(kKeys)) { |
| // user provided keys |
| if (const auto* cfg_keys = config[kKeys].as<ArrayNode>()) { |
| for (const ObjectRef& e : *cfg_keys) { |
| if (const auto* key = e.as<StringObj>()) { |
| keys.push_back(GetRef<String>(key)); |
| } else { |
| throw dmlc::Error( |
| ": Expect 'keys' to be an array of strings, but it " |
| "contains an element of type: " + |
| e->GetTypeKey()); |
| } |
| } |
| } else { |
| throw dmlc::Error(": Expect type of field \"keys\" is Array, but get type: " + |
| config[kKeys]->GetTypeKey()); |
| } |
| } |
| // add device name |
| if (config.count(kDeviceName)) { |
| if (const auto* device = config.at(kDeviceName).as<StringObj>()) { |
| keys.push_back(GetRef<String>(device)); |
| } |
| } |
| // add default keys |
| for (const auto& key : target->kind->default_keys) { |
| keys.push_back(key); |
| } |
| // de-duplicate keys |
| target->keys = DeduplicateKeys(keys); |
| config.erase(kKeys); |
| } |
| // parse attrs |
| std::unordered_map<String, ObjectRef> attrs; |
| for (const auto& cfg_kv : config) { |
| const String& key = cfg_kv.first; |
| const ObjectRef& value = cfg_kv.second; |
| try { |
| const TargetKindNode::ValueTypeInfo& info = TargetInternal::FindTypeInfo(target->kind, key); |
| attrs[key] = TargetInternal::ParseType(value, info); |
| } catch (const dmlc::Error& e) { |
| throw dmlc::Error(": Error when parsing target[\"" + key + "\"]" + e.what()); |
| } |
| } |
| // set default attribute values if they do not exist |
| for (const auto& kv : target->kind->key2default_) { |
| if (!attrs.count(kv.first)) { |
| attrs[kv.first] = kv.second; |
| } |
| } |
| // do extra pre-processing |
| if (target->kind->preprocessor != nullptr) { |
| target->attrs = target->kind->preprocessor(Map<String, ObjectRef>(attrs)); |
| } else { |
| target->attrs = attrs; |
| } |
| return target; |
| } |
| |
| /********** Registry **********/ |
| |
| TVM_REGISTER_GLOBAL("target.Target").set_body(TargetInternal::ConstructorDispatcher); |
| TVM_REGISTER_GLOBAL("target.TargetEnterScope").set_body_typed(TargetInternal::EnterScope); |
| TVM_REGISTER_GLOBAL("target.TargetExitScope").set_body_typed(TargetInternal::ExitScope); |
| TVM_REGISTER_GLOBAL("target.TargetCurrent").set_body_typed(Target::Current); |
| TVM_REGISTER_GLOBAL("target.TargetExport").set_body_typed(TargetInternal::Export); |
| |
| TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
| .set_dispatch<TargetNode>([](const ObjectRef& obj, ReprPrinter* p) { |
| p->stream << Downcast<Target>(obj)->str(); |
| }); |
| |
| } // namespace tvm |