blob: db17d1862846d81abf3d24e60fc8b37d5204ea08 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/target/metadata_utils.cc
* \brief Defines utility functions and classes for emitting metadata.
*/
#include "metadata_utils.h"
namespace tvm {
namespace codegen {
namespace metadata {
/*!
 * \brief Join address components into a single identifier, separated by '_'.
 * \param parts Components in order, outermost first.
 * \return The joined string; an empty vector yields "".
 */
std::string AddressFromParts(const std::vector<std::string>& parts) {
  std::string address;
  // The first component gets no leading separator; every later one is preceded by '_'.
  const char* separator = "";
  for (const std::string& part : parts) {
    address += separator;
    address += part;
    separator = "_";
  }
  return address;
}
/*!
 * \brief Construct a visitor that appends each discovered array to *queue.
 * \param queue Output list. Only the pointer is stored, so the caller must keep the vector
 *        alive for as long as the visitor is used.
 */
DiscoverArraysVisitor::DiscoverArraysVisitor(std::vector<DiscoveredArray>* queue)
    : queue_(queue) {}
// These overloads do nothing: DiscoverArraysVisitor only inspects ObjectRef-valued
// attributes (see the ObjectRef overload), so every other attribute kind is ignored.
void DiscoverArraysVisitor::Visit(const char* /*key*/, double* /*value*/) {}
void DiscoverArraysVisitor::Visit(const char* /*key*/, int64_t* /*value*/) {}
void DiscoverArraysVisitor::Visit(const char* /*key*/, uint64_t* /*value*/) {}
void DiscoverArraysVisitor::Visit(const char* /*key*/, int* /*value*/) {}
void DiscoverArraysVisitor::Visit(const char* /*key*/, bool* /*value*/) {}
void DiscoverArraysVisitor::Visit(const char* /*key*/, std::string* /*value*/) {}
void DiscoverArraysVisitor::Visit(const char* /*key*/, DataType* /*value*/) {}
void DiscoverArraysVisitor::Visit(const char* /*key*/, runtime::NDArray* /*value*/) {}
void DiscoverArraysVisitor::Visit(const char* /*key*/, void** /*value*/) {}
void DiscoverArraysVisitor::Visit(const char* key, ObjectRef* value) {
address_parts_.push_back(key);
if (value->as<runtime::metadata::MetadataBaseNode>() != nullptr) {
auto metadata = Downcast<runtime::metadata::MetadataBase>(*value);
const runtime::metadata::MetadataArrayNode* arr =
value->as<runtime::metadata::MetadataArrayNode>();
if (arr != nullptr) {
for (unsigned int i = 0; i < arr->array.size(); i++) {
ObjectRef o = arr->array[i];
if (o.as<runtime::metadata::MetadataBaseNode>() != nullptr) {
std::stringstream ss;
ss << i;
address_parts_.push_back(ss.str());
runtime::metadata::MetadataBase metadata = Downcast<runtime::metadata::MetadataBase>(o);
ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this);
address_parts_.pop_back();
}
}
queue_->push_back(std::make_tuple(AddressFromParts(address_parts_),
Downcast<runtime::metadata::MetadataArray>(metadata)));
} else {
ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this);
}
}
address_parts_.pop_back();
}
// These overloads do nothing: DiscoverComplexTypesVisitor only inspects ObjectRef-valued
// attributes (see the ObjectRef overload), so every other attribute kind is ignored.
void DiscoverComplexTypesVisitor::Visit(const char* /*key*/, double* /*value*/) {}
void DiscoverComplexTypesVisitor::Visit(const char* /*key*/, int64_t* /*value*/) {}
void DiscoverComplexTypesVisitor::Visit(const char* /*key*/, uint64_t* /*value*/) {}
void DiscoverComplexTypesVisitor::Visit(const char* /*key*/, int* /*value*/) {}
void DiscoverComplexTypesVisitor::Visit(const char* /*key*/, bool* /*value*/) {}
void DiscoverComplexTypesVisitor::Visit(const char* /*key*/, std::string* /*value*/) {}
void DiscoverComplexTypesVisitor::Visit(const char* /*key*/, DataType* /*value*/) {}
void DiscoverComplexTypesVisitor::Visit(const char* /*key*/, runtime::NDArray* /*value*/) {}
void DiscoverComplexTypesVisitor::Visit(const char* /*key*/, void** /*value*/) {}
/*!
 * \brief Reserve a slot in *queue_ for the type named by type_key.
 *
 * The first call for a given key appends an undefined MetadataBase placeholder to *queue_ and
 * records its index in type_key_to_position_. DiscoverInstance later fills that slot with a
 * concrete instance.
 *
 * \param type_key Type key of the metadata type.
 * \return true if this call reserved a new slot; false if the key was already known.
 */
bool DiscoverComplexTypesVisitor::DiscoverType(std::string type_key) {
  VLOG(2) << "DiscoverType " << type_key;
  const bool already_known = type_key_to_position_.count(type_key) != 0;
  if (already_known) {
    return false;
  }
  // The placeholder is about to be appended, so its index is the current size.
  const auto slot = queue_->size();
  queue_->emplace_back(tvm::runtime::metadata::MetadataBase());
  type_key_to_position_[type_key] = slot;
  return true;
}
/*!
 * \brief Record md as the representative instance of its type, if none is recorded yet.
 *
 * The slot for md's type must already have been reserved with DiscoverType; this is
 * ICHECK-enforced. An already-defined slot is left untouched, so the first instance seen wins.
 *
 * \param md Metadata instance. An undefined reference is ignored.
 */
void DiscoverComplexTypesVisitor::DiscoverInstance(runtime::metadata::MetadataBase md) {
  // Test defined() before any dereference: md->GetTypeKey() on an undefined reference would
  // dereference a null pointer, making a later defined() check too late to help.
  if (!md.defined()) {
    return;
  }
  const std::string type_key = md->GetTypeKey();
  auto position_it = type_key_to_position_.find(type_key);
  ICHECK(position_it != type_key_to_position_.end())
      << "DiscoverInstance requires that DiscoverType has already been called: type_key="
      << type_key;
  const int queue_position = position_it->second;
  // Only fill the placeholder left by DiscoverType; never overwrite an earlier instance.
  if (!(*queue_)[queue_position].defined()) {
    VLOG(2) << "DiscoverInstance " << type_key << ":" << md;
    (*queue_)[queue_position] = md;
  }
}
/*!
 * \brief Visit an ObjectRef-valued attribute and record the complex metadata types it reaches.
 *
 * value must hold a MetadataBaseNode; this is enforced by ICHECK_NOTNULL, so a non-metadata or
 * undefined value aborts.
 *
 * - Non-array metadata: its attributes are traversed first. Then its type key is reserved via
 *   DiscoverType and this object is offered to DiscoverInstance.
 * - MetadataArray whose kind is not kMetadata: ignored entirely. Presumably these arrays hold
 *   non-struct elements; confirm against the MetadataKind definition.
 * - MetadataArray of kind kMetadata: the element type (arr->type_key) is reserved BEFORE any
 *   element is traversed. If this call newly reserved it, the first element is offered to
 *   DiscoverInstance. Every element's attributes are then traversed.
 *
 * \param key Attribute name (not used by this visitor).
 * \param value The attribute value.
 */
void DiscoverComplexTypesVisitor::Visit(const char* key, ObjectRef* value) {
ICHECK_NOTNULL(value->as<runtime::metadata::MetadataBaseNode>());
auto metadata = Downcast<runtime::metadata::MetadataBase>(*value);
const runtime::metadata::MetadataArrayNode* arr =
value->as<runtime::metadata::MetadataArrayNode>();
if (arr == nullptr) {
// Non-array metadata: visit children first, then register this object's own type.
VLOG(2) << "No array, object-traversing " << metadata->GetTypeKey();
ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this);
DiscoverType(metadata->GetTypeKey());
DiscoverInstance(metadata);
return;
}
// Only arrays whose elements are metadata objects contribute complex types.
if (arr->kind != tvm::runtime::metadata::MetadataKind::kMetadata) {
return;
}
// true only when this call created the slot, i.e. it still needs an instance.
bool needs_instance = DiscoverType(arr->type_key);
for (unsigned int i = 0; i < arr->array.size(); i++) {
// NOTE(review): elements are not checked with defined(); an undefined element would reach
// DiscoverInstance and VisitAttrs through a null pointer. Confirm elements are always defined.
tvm::runtime::metadata::MetadataBase o =
Downcast<tvm::runtime::metadata::MetadataBase>(arr->array[i]);
if (needs_instance) {
// The first element serves as the representative instance of the element type.
DiscoverInstance(o);
needs_instance = false;
}
ReflectionVTable::Global()->VisitAttrs(o.operator->(), this);
}
}
/*!
 * \brief Discover every complex metadata type reachable from metadata, including its own type.
 *
 * The root's attributes are traversed before the root's type is reserved. As a result, any
 * type first encountered beneath the root occupies an earlier position in *queue_ than the
 * root's own type.
 *
 * \param metadata Root metadata object to scan.
 */
void DiscoverComplexTypesVisitor::Discover(runtime::metadata::MetadataBase metadata) {
  auto* vtable = ReflectionVTable::Global();
  vtable->VisitAttrs(metadata.operator->(), this);
  const std::string root_type_key = metadata->GetTypeKey();
  DiscoverType(root_type_key);
  DiscoverInstance(metadata);
}
} // namespace metadata
} // namespace codegen
} // namespace tvm