| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| /*! |
| * Compile executable modules. |
| * \file src/target/target.cc |
| */ |
| #include <dmlc/thread_local.h> |
| #include <tvm/runtime/device_api.h> |
| #include <tvm/runtime/logging.h> |
| #include <tvm/runtime/registry.h> |
| #include <tvm/target/tag.h> |
| #include <tvm/target/target.h> |
| #include <tvm/target/target_kind.h> |
| #include <tvm/tir/expr.h> |
| |
| #include <algorithm> |
| #include <cctype> |
| #include <ios> |
| #include <sstream> |
| #include <stack> |
| #include <string> |
| #include <unordered_map> |
| #include <utility> |
| #include <vector> |
| |
| #include "../runtime/object_internal.h" |
| |
| namespace tvm { |
| |
/*! \brief Invoke the node-type registration macro for TargetNode. */
| TVM_REGISTER_NODE_TYPE(TargetNode); |
| |
| class TargetInternal { |
| public: |
| static void EnterScope(Target target) { target.EnterWithScope(); } |
| static void ExitScope(Target target) { target.ExitWithScope(); } |
| static Map<String, ObjectRef> Export(Target target) { return target->Export(); } |
| static const TargetKindNode::ValueTypeInfo& FindTypeInfo(const TargetKind& kind, |
| const std::string& key); |
| static Optional<String> StringifyAttrsToRaw(const Map<String, ObjectRef>& attrs); |
| static ObjectRef ParseType(const std::string& str, const TargetKindNode::ValueTypeInfo& info); |
| static ObjectRef ParseType(const ObjectRef& obj, const TargetKindNode::ValueTypeInfo& info); |
| static ObjectPtr<Object> FromString(const String& tag_or_config_or_target_str); |
| static ObjectPtr<Object> FromConfigString(const String& config_str); |
| static ObjectPtr<Object> FromRawString(const String& target_str); |
| static ObjectPtr<Object> FromConfig(Map<String, ObjectRef> config); |
| static void ConstructorDispatcher(TVMArgs args, TVMRetValue* rv); |
| static Target WithHost(const Target& target, const Target& target_host) { |
| ObjectPtr<TargetNode> n = make_object<TargetNode>(*target.get()); |
| n->host = target_host; |
| return (Target)n; |
| } |
| |
| private: |
| static std::unordered_map<String, ObjectRef> QueryDevice(int device_id, const TargetNode* target); |
| static bool IsQuoted(const std::string& str); |
| static std::string Quote(const std::string& str); |
| static std::string JoinString(const std::vector<std::string>& array, char separator); |
| static std::vector<std::string> SplitString(const std::string& str, char separator); |
| static std::string Interpret(const std::string& str); |
| static std::string Uninterpret(const std::string& str); |
| static std::string StringifyAtomicType(const ObjectRef& obj); |
| static std::string StringifyArray(const ArrayNode& array); |
| |
| static constexpr char quote = '\''; |
| static constexpr char escape = '\\'; |
| }; |
| |
| /********** Helper functions **********/ |
| Target Target::WithHost(const Target& target, const Target& host) { |
| return TargetInternal::WithHost(target, host); |
| } |
| |
| void CheckAndUpdateHostConsistency(Target* target, Target* host) { |
| *target = Target(*target, *host); |
| *host = (*target)->GetHost().value_or(Target()); |
| } |
| |
| void CheckAndUpdateHostConsistency(Map<Target, IRModule>* targets, Target* host) { |
| Map<Target, IRModule> new_targets; |
| for (auto& it : *targets) { |
| auto target = it.first; |
| CheckAndUpdateHostConsistency(&target, host); |
| new_targets.Set(target, it.second); |
| } |
| *targets = new_targets; |
| } |
| |
| static std::vector<String> DeduplicateKeys(const std::vector<String>& keys) { |
| std::vector<String> new_keys; |
| for (size_t i = 0; i < keys.size(); ++i) { |
| bool found = false; |
| for (size_t j = 0; j < i; ++j) { |
| if (keys[i] == keys[j]) { |
| found = true; |
| break; |
| } |
| } |
| if (!found) { |
| new_keys.push_back(keys[i]); |
| } |
| } |
| return new_keys; |
| } |
| |
| template <class TObj> |
| static const TObj* ObjTypeCheck(const ObjectRef& obj, const std::string& expected_type) { |
| const TObj* ptr = obj.as<TObj>(); |
| if (ptr == nullptr) { |
| std::ostringstream os; |
| os << ": Expects type \"" << expected_type << "\", but gets \"" << obj->GetTypeKey() |
| << "\" for object: " << obj; |
| throw Error(os.str()); |
| } |
| return ptr; |
| } |
| |
| static TargetKind GetTargetKind(const String& name) { |
| Optional<TargetKind> kind = TargetKind::Get(name); |
| if (!kind.defined()) { |
| throw Error(": Target kind \"" + name + "\" is not defined"); |
| } |
| return kind.value(); |
| } |
| |
| static std::string RemovePrefixDashes(const std::string& s) { |
| int n_dashes = 0; |
| int len = s.length(); |
| for (; n_dashes < len && s[n_dashes] == '-'; ++n_dashes) { |
| } |
| if (n_dashes == 0) { |
| throw Error(": Attribute keys should start with '-', not an attribute key: " + s); |
| } |
| if (n_dashes >= len) { |
| throw Error(": Not an attribute key: " + s); |
| } |
| return s.substr(n_dashes); |
| } |
| |
| bool TargetInternal::IsQuoted(const std::string& str) { |
| std::string::size_type start = 0, end = str.size(); |
| if (end < 2 || str[start] != quote || str[end - 1] != quote) { |
| return false; |
| } |
| bool escaping = false; |
| for (auto i = start + 1, e = end - 1; i < e; ++i) { |
| if (escaping) { |
| escaping = false; |
| } else if (str[i] == escape) { |
| escaping = true; |
| } else if (str[i] == quote) { |
| return false; |
| } |
| } |
| // If the reduced string ends with \, then the terminating quote is escaped. |
| return !escaping; |
| } |
| |
| std::string TargetInternal::Quote(const std::string& str) { |
| std::string result(1, quote); |
| result.append(str); |
| result.push_back(quote); |
| return result; |
| } |
| |
| std::string TargetInternal::JoinString(const std::vector<std::string>& array, char separator) { |
| std::string result; |
| ICHECK(separator != quote && separator != escape) |
| << "string join separator cannot be " << quote << " or " << escape; |
| |
| bool is_first = true; |
| for (const auto& s : array) { |
| if (!is_first) { |
| result.push_back(separator); |
| } |
| result.append(s); |
| is_first = false; |
| } |
| |
| return result; |
| } |
| |
| std::vector<std::string> TargetInternal::SplitString(const std::string& str, char separator) { |
| std::vector<std::string> output; |
| |
| const char* start = str.data(); |
| const char* end = start + str.size(); |
| const char* pos = start; |
| |
| std::stringstream current_word; |
| |
| auto finish_word = [&]() { |
| std::string word = current_word.str(); |
| if (word.size()) { |
| output.push_back(word); |
| current_word.str(""); |
| } |
| }; |
| |
| bool pos_quoted = false; |
| |
| while (pos < end) { |
| if ((*pos == separator) && !pos_quoted) { |
| finish_word(); |
| pos++; |
| } else if (*pos == escape && pos + 1 < end) { |
| current_word << escape; |
| current_word << pos[1]; |
| pos += 2; |
| } else if (*pos == quote) { |
| current_word << quote; |
| pos_quoted = !pos_quoted; |
| pos++; |
| } else { |
| current_word << *pos; |
| pos++; |
| } |
| } |
| |
| ICHECK(!pos_quoted) << "Mismatched quotes '' in string"; |
| |
| finish_word(); |
| |
| return output; |
| } |
| |
| std::string TargetInternal::Interpret(const std::string& str) { |
| // String interpretation deals with quotes (') and escapes(\). |
| // - An escape character must be followed by another character forming an |
| // "escape sequence". (Trailing escape is not allowed.) An escape prevents |
| // interpretation of the character that follows. This happens regardless of |
| // whether the escape sequence appears within quoted substring or not. |
| // - A quote character, when interpreted, marks the beginning or the end of a |
| // quoted substring. (A quoted substring cannot contain unescaped quotes.) |
| // - Any other character, when interpreted, represents itself. |
| // |
| // Interpretation happens in two steps: |
| // 1. If the entire string is quoted, the quotes are removed first, and the |
| // resulting string is treated as unquoted. |
| // 2. Each character or escape sequence is interpreted, and the result is copied |
| // to the result. When not inside a quoted substring, the interpretation of an |
| // escape sequence is the escaped character, otherwise it is the entire escape |
| // sequence. |
| // |
| // Examples: |
| // blah -> blah Nothing happened |
| // 'blah' -> blah Enclosing quotes removed |
| // 'bl'ah -> 'bl'ah Non-enclosing quotes remain |
| // '\'blah\'' -> 'blah' Enclosing quotes removed, escaped quotes |
| // interpreted. |
| // '\'\\\'blah\\\'\'' -> '\'blah\'' Same as above. |
| // |
| // Note that |
| // '\'\\\'blah\\\'\'' -> '\'blah\'' -> 'blah' |
| |
| std::string result; |
| if (str.empty()) { |
| return result; |
| } |
| |
| // Check if the entire string is enclosed in quotes ''. If so, strip the quotes |
| // and treat the string as unquoted (so that escapes are interpreted). Doing that |
| // will allow '\'foo\'' to become 'foo', instead of \'foo\'. |
| std::string::size_type start = 0, end = str.size(); |
| if (IsQuoted(str)) { |
| start++; |
| end--; |
| } |
| |
| bool inside_quote = false; |
| bool escaping = false; |
| |
| for (auto i = start, e = end; i < e; ++i) { |
| std::string::value_type c = str[i]; |
| if (escaping) { |
| escaping = false; |
| } else if (c == escape) { |
| escaping = true; |
| if (!inside_quote) { |
| continue; |
| } |
| } else if (c == quote) { |
| inside_quote = !inside_quote; |
| } |
| result.push_back(c); |
| } |
| |
| return result; |
| } |
| |
| std::string TargetInternal::Uninterpret(const std::string& str) { |
| // Do the opposite to `Interpret`, so that Interpret(Uninterpret(str)) == str. |
| std::string result; |
| |
| for (std::string::size_type i = 0, e = str.size(); i < e; ++i) { |
| std::string::value_type c = str[i]; |
| if (c == escape || c == quote) { |
| result.push_back(escape); |
| } |
| result.push_back(c); |
| } |
| |
| return result; |
| } |
| |
| static int ParseKVPair(const std::string& s, const std::string& s_next, std::string* key, |
| std::string* value) { |
| std::string::size_type pos; |
| std::string& result_k = *key; |
| std::string& result_v = *value; |
| if ((pos = s.find_first_of('=')) != std::string::npos) { |
| // case 1. --key=value |
| result_k = s.substr(0, pos); |
| result_v = s.substr(pos + 1); |
| if (result_k.empty() || result_v.empty()) { |
| throw Error(": Empty attribute key or value in \"" + s + "\""); |
| } |
| return 1; |
| } else if (!s_next.empty() && s_next[0] != '-') { |
| // case 2. --key value |
| result_k = s; |
| result_v = s_next; |
| return 2; |
| } |
| // case 3. --boolean-key |
| result_k = s; |
| result_v = "1"; |
| return 1; |
| } |
| |
| const TargetKindNode::ValueTypeInfo& TargetInternal::FindTypeInfo(const TargetKind& kind, |
| const std::string& key) { |
| auto it = kind->key2vtype_.find(key); |
| if (it == kind->key2vtype_.end()) { |
| std::ostringstream os; |
| os << ": Cannot recognize \'" << key << "\'. Candidates are: "; |
| bool is_first = true; |
| for (const auto& kv : kind->key2vtype_) { |
| if (is_first) { |
| is_first = false; |
| } else { |
| os << ", "; |
| } |
| os << kv.first; |
| } |
| throw Error(os.str()); |
| } |
| return it->second; |
| } |
| |
| /********** Parsing **********/ |
| |
| ObjectRef TargetInternal::ParseType(const std::string& str, |
| const TargetKindNode::ValueTypeInfo& info) { |
| std::string interp_str = Interpret(str); |
| if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { |
| // Parsing integer |
| std::istringstream is(interp_str); |
| int v; |
| if (!(is >> v)) { |
| std::string lower(interp_str.size(), '\x0'); |
| std::transform(interp_str.begin(), interp_str.end(), lower.begin(), |
| [](unsigned char c) { return std::tolower(c); }); |
| // Bool is a subclass of IntImm, so allow textual boolean values. |
| if (lower == "true") { |
| v = 1; |
| } else if (lower == "false") { |
| v = 0; |
| } else { |
| throw Error(": Cannot parse into type \"Integer\" from string: " + interp_str); |
| } |
| } |
| return Integer(v); |
| } else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) { |
| // Parsing string, strip leading/trailing spaces, and enclosing quotes if any |
| auto start = interp_str.find_first_not_of(' '); |
| auto end = interp_str.find_last_not_of(' '); |
| if (start == std::string::npos || end == std::string::npos) { |
| // The whole string is made of spaces. |
| return String(); |
| } |
| return String(interp_str.substr(start, (end - start + 1))); |
| |
| } else if (info.type_index == Target::ContainerType::_GetOrAllocRuntimeTypeIndex()) { |
| // Parsing target |
| return Target(TargetInternal::FromString(interp_str)); |
| } else if (info.type_index == ArrayNode::_GetOrAllocRuntimeTypeIndex()) { |
| // Parsing array |
| std::vector<ObjectRef> result; |
| for (const std::string& substr : SplitString(interp_str, ',')) { |
| try { |
| ObjectRef parsed = TargetInternal::ParseType(substr, *info.key); |
| result.push_back(parsed); |
| } catch (const Error& e) { |
| std::string index = "[" + std::to_string(result.size()) + "]"; |
| throw Error(index + e.what()); |
| } |
| } |
| return Array<ObjectRef>(result); |
| } |
| throw Error(": Unsupported type \"" + info.type_key + |
| "\" for parsing from string: " + interp_str); |
| } |
| |
| ObjectRef TargetInternal::ParseType(const ObjectRef& obj, |
| const TargetKindNode::ValueTypeInfo& info) { |
| if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { |
| // Parsing integer |
| return GetRef<Integer>(ObjTypeCheck<IntImmNode>(obj, "Integer")); |
| } else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) { |
| // Parsing string |
| return GetRef<String>(ObjTypeCheck<StringObj>(obj, "String")); |
| } else if (info.type_index == Target::ContainerType::_GetOrAllocRuntimeTypeIndex()) { |
| // Parsing target |
| if (const auto* ptr = obj.as<TargetNode>()) { |
| return GetRef<Target>(ptr); |
| } else if (const auto* ptr = obj.as<StringObj>()) { |
| return Target(TargetInternal::FromString(GetRef<String>(ptr))); |
| } else if (const auto* ptr = obj.as<MapNode>()) { |
| for (const auto& kv : *ptr) { |
| if (!kv.first->IsInstance<StringObj>()) { |
| throw Error(": Target object requires key of dict to be str, but get: " + |
| kv.first->GetTypeKey()); |
| } |
| } |
| Map<String, ObjectRef> config = GetRef<Map<String, ObjectRef>>(ptr); |
| return Target(TargetInternal::FromConfig({config.begin(), config.end()})); |
| } |
| throw Error(": Expect type 'dict' or 'str' to construct Target, but get: " + obj->GetTypeKey()); |
| } else if (info.type_index == ArrayNode::_GetOrAllocRuntimeTypeIndex()) { |
| // Parsing array |
| const auto* array = ObjTypeCheck<ArrayNode>(obj, "Array"); |
| std::vector<ObjectRef> result; |
| for (const ObjectRef& e : *array) { |
| try { |
| result.push_back(TargetInternal::ParseType(e, *info.key)); |
| } catch (const Error& e) { |
| std::string index = '[' + std::to_string(result.size()) + ']'; |
| throw Error(index + e.what()); |
| } |
| } |
| return Array<ObjectRef>(result); |
| } else if (info.type_index == MapNode::_GetOrAllocRuntimeTypeIndex()) { |
| // Parsing map |
| const auto* map = ObjTypeCheck<MapNode>(obj, "Map"); |
| std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> result; |
| for (const auto& kv : *map) { |
| ObjectRef key, val; |
| try { |
| key = TargetInternal::ParseType(kv.first, *info.key); |
| } catch (const Error& e) { |
| std::ostringstream os; |
| os << "'s key \"" << key << "\"" << e.what(); |
| throw Error(os.str()); |
| } |
| try { |
| val = TargetInternal::ParseType(kv.second, *info.val); |
| } catch (const Error& e) { |
| std::ostringstream os; |
| os << "[\"" << key << "\"]" << e.what(); |
| throw Error(os.str()); |
| } |
| result[key] = val; |
| } |
| return Map<ObjectRef, ObjectRef>(result); |
| } |
| if (info.type_index != obj->type_index()) { |
| std::ostringstream os; |
| os << ": Parsing type \"" << info.type_key |
| << "\" is not supported for the given object of type \"" << obj->GetTypeKey() |
| << "\". The object is: " << obj; |
| throw Error(os.str()); |
| } |
| return obj; |
| } |
| |
| /********** Stringifying **********/ |
| |
| std::string TargetInternal::StringifyAtomicType(const ObjectRef& obj) { |
| if (const auto* p = obj.as<IntImmNode>()) { |
| return std::to_string(p->value); |
| } |
| if (const auto* p = obj.as<StringObj>()) { |
| auto s = static_cast<std::string>(GetRef<String>(p)); |
| auto u = Uninterpret(s); |
| if (u.find_first_of(' ') != std::string::npos && !IsQuoted(u)) { |
| u = Quote(u); |
| } |
| return u; |
| } |
| LOG(FATAL) << "Cannot stringify this object"; |
| } |
| |
| std::string TargetInternal::StringifyArray(const ArrayNode& array) { |
| std::vector<std::string> elements; |
| |
| for (const ObjectRef& item : array) { |
| std::string s = StringifyAtomicType(item); |
| std::string u = Uninterpret(s); |
| if (u.find_first_of(',') != std::string::npos && !IsQuoted(u)) { |
| u = Quote(u); |
| } |
| elements.push_back(u); |
| } |
| |
| return JoinString(elements, ','); |
| } |
| |
| Optional<String> TargetInternal::StringifyAttrsToRaw(const Map<String, ObjectRef>& attrs) { |
| std::ostringstream os; |
| std::vector<String> keys; |
| for (const auto& kv : attrs) { |
| keys.push_back(kv.first); |
| } |
| std::sort(keys.begin(), keys.end()); |
| std::vector<std::string> result; |
| |
| for (const auto& key : keys) { |
| const ObjectRef& obj = attrs[key]; |
| std::string value; |
| if (const auto* array = obj.as<ArrayNode>()) { |
| value = String(StringifyArray(*array)); |
| } else { |
| value = StringifyAtomicType(obj); |
| } |
| if (!value.empty()) { |
| result.push_back("-" + key + "=" + value); |
| } |
| } |
| return String(JoinString(result, ' ')); |
| } |
| |
| const std::string& TargetNode::str() const { |
| if (str_repr_.empty()) { |
| std::ostringstream os; |
| os << kind->name; |
| if (!this->keys.empty()) { |
| os << " -keys="; |
| bool is_first = true; |
| for (const String& s : keys) { |
| if (is_first) { |
| is_first = false; |
| } else { |
| os << ','; |
| } |
| os << s; |
| } |
| } |
| if (Optional<String> attrs_str = TargetInternal::StringifyAttrsToRaw(attrs)) { |
| os << ' ' << attrs_str.value(); |
| } |
| |
| str_repr_ = os.str(); |
| } |
| return str_repr_; |
| } |
| |
| /********** Small member methods **********/ |
| |
| Target::Target(const String& tag_or_config_or_target_str) { |
| ObjectPtr<Object> target; |
| try { |
| target = TargetInternal::FromString(tag_or_config_or_target_str); |
| } catch (const Error& e) { |
| LOG(FATAL) << "ValueError" << e.what() |
| << ". Target creation from string failed: " << tag_or_config_or_target_str; |
| } |
| data_ = std::move(target); |
| } |
| |
| Target::Target(const Map<String, ObjectRef>& config) { |
| ObjectPtr<Object> target; |
| try { |
| target = TargetInternal::FromConfig({config.begin(), config.end()}); |
| } catch (const Error& e) { |
| LOG(FATAL) << "ValueError" << e.what() |
| << ". Target creation from config dict failed: " << config; |
| } |
| data_ = std::move(target); |
| } |
| |
| Target::Target(Target target, Target host) { |
| ObjectPtr<TargetNode> n = make_object<TargetNode>(*target.get()); |
| n->host = std::move(host); |
| data_ = std::move(n); |
| } |
| |
| Target::Target(TargetKind kind, Optional<ObjectRef> host, String tag, Array<String> keys, |
| Map<String, ObjectRef> attrs) { |
| auto data = runtime::make_object<TargetNode>(); |
| data->kind = std::move(kind); |
| data->host = std::move(host); |
| data->tag = std::move(tag); |
| data->keys = std::move(keys); |
| data->attrs = std::move(attrs); |
| data_ = std::move(data); |
| } |
| |
| bool Target::IsExternalCodegen() const { |
| TargetKindAttrMap<Bool> is_external_codegen_map = |
| TargetKind::GetAttrMap<Bool>(tvm::attr::kIsExternalCodegen); |
| TargetKindAttrMap<FTVMRelayToTIR> relay_to_tir_map = |
| TargetKind::GetAttrMap<FTVMRelayToTIR>(tvm::attr::kRelayToTIR); |
| return is_external_codegen_map.get(get()->kind, Bool(false)) || |
| relay_to_tir_map.count(get()->kind); |
| } |
| |
| bool Target::IsExternalCodegenFor(const Target& that) const { |
| return get()->GetTargetDeviceType() == that->GetTargetDeviceType() && IsExternalCodegen() && |
| !that.IsExternalCodegen(); |
| } |
| |
| std::vector<std::string> TargetNode::GetKeys() const { |
| std::vector<std::string> result; |
| for (auto& expr : keys) { |
| result.push_back(expr); |
| } |
| return result; |
| } |
| |
| std::unordered_set<std::string> TargetNode::GetLibs() const { |
| Optional<Array<String>> libs = this->GetAttr<Array<String>>("libs"); |
| if (!libs.defined()) { |
| return {}; |
| } |
| std::unordered_set<std::string> result; |
| for (const auto& item : libs.value()) { |
| result.insert(item); |
| } |
| return result; |
| } |
| |
| Map<String, ObjectRef> TargetNode::Export() const { |
| Map<String, ObjectRef> result = { |
| {"kind", this->kind->name}, |
| {"tag", this->tag}, |
| {"keys", this->keys}, |
| }; |
| if (this->host.defined()) { |
| result.Set("host", this->GetHost().value_or(Target())->Export()); |
| } |
| for (const auto& kv : attrs) { |
| result.Set(kv.first, kv.second); |
| } |
| return result; |
| } |
| |
| Optional<Target> TargetNode::GetHost() const { |
| return GetRef<Optional<Target>>(this->host.as<TargetNode>()); |
| } |
| |
| int TargetNode::GetTargetDeviceType() const { |
| if (Optional<Integer> device_type = GetAttr<Integer>("target_device_type")) { |
| return Downcast<Integer>(device_type)->value; |
| } |
| return kind->default_device_type; |
| } |
| |
| String TargetNode::ToDebugString() const { |
| std::ostringstream os; |
| os << "Target("; |
| os << "id=" << std::hex << reinterpret_cast<size_t>(this); |
| os << ", kind='" << kind->name << "'"; |
| if (!tag.empty()) { |
| os << ", tag='" << tag << "'"; |
| } |
| if (!keys.empty()) { |
| os << ", keys={"; |
| bool first = true; |
| for (const auto& key : keys) { |
| if (!first) { |
| os << ", "; |
| } |
| os << "'" << key << "'"; |
| first = false; |
| } |
| os << "}"; |
| } |
| if (!attrs.empty()) { |
| os << ", attrs={"; |
| bool first = true; |
| for (const auto& pair : attrs) { |
| if (!first) { |
| os << ", "; |
| } |
| os << "'" << pair.first << "': " << pair.second; |
| first = false; |
| } |
| os << "}"; |
| } |
| if (host.defined()) { |
| os << ", host=" << GetHost().value()->ToDebugString(); |
| } |
| os << ")"; |
| return os.str(); |
| } |
| |
| bool TargetNode::SEqualReduce(const TargetNode* other, SEqualReducer equal) const { |
| return equal(kind.get(), other->kind.get()) && equal(host, other->host) && |
| equal(tag, other->tag) && equal(keys, other->keys) && equal(attrs, other->attrs); |
| } |
| |
| void TargetNode::SHashReduce(SHashReducer hash_reduce) const { |
| hash_reduce(kind.get()); |
| hash_reduce(host); |
| hash_reduce(tag); |
| hash_reduce(keys); |
| hash_reduce(attrs); |
| } |
| |
| /*! \brief Record holding the stack of Target scopes; accessed through TVMTargetThreadLocalStore. */ |
| struct TVMTargetThreadLocalEntry { |
| /*! \brief Targets pushed by Target::EnterWithScope; Target::Current returns the top element. */ |
| std::stack<Target> context_stack; |
| }; |
| |
| /*! \brief dmlc thread-local store type whose Get() yields the TVMTargetThreadLocalEntry used by the scope functions. */ |
| using TVMTargetThreadLocalStore = dmlc::ThreadLocalStore<TVMTargetThreadLocalEntry>; |
| |
| void Target::EnterWithScope() { |
| TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStore::Get(); |
| entry->context_stack.push(*this); |
| } |
| |
/*!
 * \brief Pop this target from the target context stack obtained from the
 *        thread-local store.
 *
 * An ICHECK fails if the stack is empty, or if its top element does not
 * satisfy same_as(*this), i.e. scopes must be exited in reverse order of entry.
 */
| void Target::ExitWithScope() { |
| TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStore::Get(); |
| ICHECK(!entry->context_stack.empty()); |
| ICHECK(entry->context_stack.top().same_as(*this)); |
| entry->context_stack.pop(); |
| } |
| |
| Target Target::Current(bool allow_not_defined) { |
| TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStore::Get(); |
| if (entry->context_stack.size() > 0) { |
| return entry->context_stack.top(); |
| } |
| ICHECK(allow_not_defined) |
| << "Target context required. Please set it by constructing a TargetContext"; |
| |
| return Target(); |
| } |
| |
| /********** Creation **********/ |
| |
| void TargetInternal::ConstructorDispatcher(TVMArgs args, TVMRetValue* rv) { |
| if (args.num_args == 1) { |
| const auto& arg = args[0]; |
| if (arg.IsObjectRef<Target>()) { |
| *rv = Target(arg.AsObjectRef<Target>()); |
| } else if (String::CanConvertFrom(arg)) { |
| *rv = Target(arg.operator String()); |
| } else if (arg.IsObjectRef<Map<String, ObjectRef>>()) { |
| *rv = Target(arg.operator Map<String, ObjectRef>()); |
| } else if (arg.type_code() == kTVMObjectHandle) { |
| ObjectRef obj = arg; |
| LOG(FATAL) << "TypeError: Cannot create target with type: " << obj->GetTypeKey(); |
| } else { |
| LOG(FATAL) << "TypeError: Cannot create target with type: " |
| << runtime::ArgTypeCode2Str(arg.type_code()); |
| } |
| return; |
| } else if (args.num_args == 2) { |
| if (args[0].IsObjectRef<Target>() && args[1].IsObjectRef<Target>()) { |
| Target target = args[0]; |
| Target host = args[1]; |
| *rv = Target(target, host); |
| } else { |
| LOG(FATAL) << "ValueError: Invalid type of arguments. Expect 2 Target arguments."; |
| } |
| return; |
| } |
| LOG(FATAL) << "ValueError: Invalid number of arguments. Expect 1 or 2, but gets: " |
| << args.num_args; |
| } |
| |
| ObjectPtr<Object> TargetInternal::FromString(const String& tag_or_config_or_target_str) { |
| if (Optional<Target> target = TargetTag::Get(tag_or_config_or_target_str)) { |
| Target value = target.value(); |
| return runtime::ObjectInternal::MoveObjectPtr(&value); |
| } |
| if (!tag_or_config_or_target_str.empty() && tag_or_config_or_target_str.data()[0] == '{') { |
| return TargetInternal::FromConfigString(tag_or_config_or_target_str); |
| } |
| return TargetInternal::FromRawString(tag_or_config_or_target_str); |
| } |
| |
| ObjectPtr<Object> TargetInternal::FromConfigString(const String& config_str) { |
| const auto* loader = tvm::runtime::Registry::Get("target._load_config_dict"); |
| ICHECK(loader) << "AttributeError: \"target._load_config_dict\" is not registered. Please check " |
| "if the python module is properly loaded"; |
| Optional<Map<String, ObjectRef>> config = (*loader)(config_str); |
| if (!config.defined()) { |
| throw Error(": Cannot load config dict with python JSON loader"); |
| } |
| return TargetInternal::FromConfig({config.value().begin(), config.value().end()}); |
| } |
| |
| ObjectPtr<Object> TargetInternal::FromRawString(const String& target_str) { |
| ICHECK_GT(target_str.length(), 0) << "Cannot parse empty target string"; |
| // Split the string by empty spaces |
| std::vector<std::string> options = SplitString(std::string(target_str), ' '); |
| std::string name = options[0]; |
| // Create the target config |
| std::unordered_map<String, ObjectRef> config = {{"kind", String(name)}}; |
| TargetKind kind = GetTargetKind(name); |
| for (size_t iter = 1, end = options.size(); iter < end;) { |
| std::string key, value; |
| try { |
| // Parse key-value pair |
| std::string s_next = (iter + 1 < options.size()) ? options[iter + 1] : ""; |
| iter += ParseKVPair(RemovePrefixDashes(options[iter]), s_next, &key, &value); |
| } catch (const Error& e) { |
| throw Error(": Error when parsing target" + std::string(e.what())); |
| } |
| try { |
| // check if `key` has been used |
| if (config.count(key)) { |
| throw Error(": The key \"" + key + "\" appears more than once"); |
| } |
| config[key] = TargetInternal::ParseType(value, TargetInternal::FindTypeInfo(kind, key)); |
| } catch (const Error& e) { |
| throw Error(": Error when parsing target[\"" + key + "\"]" + e.what()); |
| } |
| } |
| return TargetInternal::FromConfig(config); |
| } |
| |
| ObjectPtr<Object> TargetInternal::FromConfig(Map<String, ObjectRef> config) { |
| const String kKind = "kind"; |
| const String kTag = "tag"; |
| const String kKeys = "keys"; |
| const String kDeviceName = "device"; |
| const String kHost = "host"; |
| const String kFeatures = "features"; |
| ObjectPtr<TargetNode> target = make_object<TargetNode>(); |
| |
| ICHECK(!config.count(kFeatures)) << "Target Features should be generated by Target parser"; |
| |
| // parse 'kind' |
| if (config.count(kKind)) { |
| if (const auto* kind = config[kKind].as<StringObj>()) { |
| target->kind = GetTargetKind(GetRef<String>(kind)); |
| ICHECK(!(target->kind->preprocessor != nullptr && target->kind->target_parser != nullptr)) |
| << "Cannot use both set_attrs_preprocessor and set_target_parser"; |
| |
| // Run JSON Parser over JSON input |
| if (target->kind->target_parser != nullptr) { |
| VLOG(9) << "TargetInternal::FromConfig - Running target_parser"; |
| config = target->kind->target_parser(config); |
| if (config.count(kFeatures)) { |
| target->features = Downcast<Map<String, ObjectRef>>(config[kFeatures]); |
| config.erase(kFeatures); |
| } |
| } |
| |
| config.erase(kKind); |
| } else { |
| throw Error(": Expect type of field \"kind\" is String, but get type: " + |
| config[kKind]->GetTypeKey()); |
| } |
| } else { |
| throw Error(": Field \"kind\" is not found"); |
| } |
| // parse "tag" |
| if (config.count(kTag)) { |
| if (const auto* tag = config[kTag].as<StringObj>()) { |
| target->tag = GetRef<String>(tag); |
| config.erase(kTag); |
| } else { |
| throw Error(": Expect type of field \"tag\" is String, but get type: " + |
| config[kTag]->GetTypeKey()); |
| } |
| } else { |
| target->tag = ""; |
| } |
| // parse "keys" |
| { |
| std::vector<String> keys; |
| bool has_user_keys = config.count(kKeys); |
| if (has_user_keys) { |
| // user provided keys |
| if (const auto* cfg_keys = config[kKeys].as<ArrayNode>()) { |
| for (const ObjectRef& e : *cfg_keys) { |
| if (const auto* key = e.as<StringObj>()) { |
| keys.push_back(GetRef<String>(key)); |
| } else { |
| throw Error( |
| ": Expect 'keys' to be an array of strings, but it " |
| "contains an element of type: " + |
| e->GetTypeKey()); |
| } |
| } |
| } else { |
| throw Error(": Expect type of field \"keys\" is Array, but get type: " + |
| config[kKeys]->GetTypeKey()); |
| } |
| } |
| // add device name |
| if (config.count(kDeviceName)) { |
| if (const auto* device = config.at(kDeviceName).as<StringObj>()) { |
| keys.push_back(GetRef<String>(device)); |
| } |
| } |
| if (!has_user_keys) { |
| // add default keys |
| for (const auto& key : target->kind->default_keys) { |
| keys.push_back(key); |
| } |
| } |
| // de-duplicate keys |
| target->keys = DeduplicateKeys(keys); |
| config.erase(kKeys); |
| } |
| // parse host |
| if (config.count(kHost)) { |
| target->host = PackedFunc(ConstructorDispatcher)(config[kHost]).AsObjectRef<Target>(); |
| config.erase(kHost); |
| } else { |
| target->host = NullOpt; |
| } |
| // parse attrs |
| std::unordered_map<String, ObjectRef> attrs; |
| for (const auto& cfg_kv : config) { |
| const String& key = cfg_kv.first; |
| const ObjectRef& value = cfg_kv.second; |
| try { |
| const TargetKindNode::ValueTypeInfo& info = TargetInternal::FindTypeInfo(target->kind, key); |
| attrs[key] = TargetInternal::ParseType(value, info); |
| } catch (const Error& e) { |
| throw Error(": Error when parsing target[\"" + key + "\"]" + e.what()); |
| } |
| } |
| |
| // If requested, query attributes from the device. User-specified |
| // parameters take precedence over queried parameters. |
| if (attrs.count("from_device")) { |
| int device_id = Downcast<Integer>(attrs.at("from_device")).IntValue(); |
| attrs.erase("from_device"); |
| auto device_params = QueryDevice(device_id, target.get()); |
| |
| for (const auto& kv : device_params) { |
| if (attrs.count(kv.first) == 0) { |
| attrs[kv.first] = kv.second; |
| } |
| } |
| } |
| |
| // set default attribute values if they do not exist |
| for (const auto& kv : target->kind->key2default_) { |
| if (!attrs.count(kv.first)) { |
| attrs[kv.first] = kv.second; |
| } |
| } |
| // do extra pre-processing |
| if (target->kind->preprocessor != nullptr) { |
| target->attrs = target->kind->preprocessor(Map<String, ObjectRef>(attrs)); |
| } else { |
| target->attrs = attrs; |
| } |
| |
| return target; |
}
| |
| std::unordered_map<String, ObjectRef> TargetInternal::QueryDevice(int device_id, |
| const TargetNode* target) { |
| std::unordered_map<String, ObjectRef> output; |
| |
| Device device{static_cast<DLDeviceType>(target->GetTargetDeviceType()), device_id}; |
| |
| auto api = runtime::DeviceAPI::Get(device, true); |
| if (!api) { |
| LOG(INFO) << "Requested reading the parameters for " << target->kind->name << " from device_id " |
| << device_id << ", but support for this runtime wasn't enabled at compile-time. " |
| << "Using default target parameters."; |
| return output; |
| } |
| |
| TVMRetValue ret; |
| api->GetAttr(device, runtime::kExist, &ret); |
| bool device_exists = ret; |
| if (!device_exists) { |
| ICHECK(device_exists) << "Requested reading the parameters for " << target->kind->name |
| << " from device_id " << device_id << ", but device_id " << device_id |
| << " doesn't exist. Using default target parameters."; |
| return output; |
| } |
| |
| for (const auto& kv : target->kind->key2vtype_) { |
| const String& key = kv.first; |
| const TargetKindNode::ValueTypeInfo& type_info = kv.second; |
| |
| TVMRetValue ret; |
| api->GetTargetProperty(device, key, &ret); |
| |
| switch (ret.type_code()) { |
| case kTVMNullptr: |
| // Nothing returned for this parameter, move on to the next one. |
| continue; |
| |
| case kTVMArgInt: |
| if (type_info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { |
| output[key] = Integer(static_cast<int64_t>(ret)); |
| } else if (type_info.type_index == Bool::ContainerType::_GetOrAllocRuntimeTypeIndex()) { |
| output[key] = Bool(static_cast<bool>(ret)); |
| } else { |
| LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key |
| << "', but received integer from device api"; |
| } |
| break; |
| |
| case kTVMStr: |
| ICHECK_EQ(type_info.type_index, String::ContainerType::_GetOrAllocRuntimeTypeIndex()) |
| << "Expected " << type_info.type_key << " parameter for attribute '" << key |
| << "', but received string from device api"; |
| output[key] = String(ret.operator std::string()); |
| break; |
| |
| default: |
| LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key |
| << "', but received TVMArgTypeCode(" << ret.type_code() << ") from device api"; |
| break; |
| } |
| } |
| |
| return output; |
| } |
| |
| /********** Registry **********/ |
| |
| TVM_REGISTER_GLOBAL("target.Target").set_body(TargetInternal::ConstructorDispatcher); |
| TVM_REGISTER_GLOBAL("target.TargetEnterScope").set_body_typed(TargetInternal::EnterScope); |
| TVM_REGISTER_GLOBAL("target.TargetExitScope").set_body_typed(TargetInternal::ExitScope); |
| TVM_REGISTER_GLOBAL("target.TargetCurrent").set_body_typed(Target::Current); |
| TVM_REGISTER_GLOBAL("target.TargetExport").set_body_typed(TargetInternal::Export); |
| TVM_REGISTER_GLOBAL("target.WithHost").set_body_typed(TargetInternal::WithHost); |
| TVM_REGISTER_GLOBAL("target.TargetGetDeviceType").set_body_typed([](const Target& target) { |
| return target->GetTargetDeviceType(); |
| }); |
| TVM_REGISTER_GLOBAL("target.TargetGetFeature") |
| .set_body_typed([](const Target& target, const String& feature_key) { |
| return target->GetFeature<ObjectRef>(feature_key); |
| }); |
| |
| TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
| .set_dispatch<TargetNode>([](const ObjectRef& obj, ReprPrinter* p) { |
| p->stream << Downcast<Target>(obj)->str(); |
| }); |
| |
| } // namespace tvm |