blob: 66bce0f6b8824255f075a91467a7971af196f06f [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file module.cc
* \brief The global module in Relay.
*/
#include <tvm/ir/module.h>
#include <tvm/node/structural_equal.h>
#include <tvm/runtime/registry.h>
// NOTE: reverse dependency on relay.
// These dependencies do not happen at the interface-level,
// and are only used in minimum cases where they are clearly marked.
//
// Rationale: We call into relay's analysis module to verify correctness.
#include <tvm/parser/parser.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>
#include <fstream>
#include <sstream>
#include <unordered_set>
namespace tvm {
/*!
 * \brief Build a module from already-populated function and type tables.
 *
 * The name-indexed lookup tables are derived from the two maps, and the
 * constructors of every type definition are registered for tag lookup.
 * A CHECK fails when two entries of the same map share a name hint.
 *
 * \param functions Global functions keyed by their global variable.
 * \param type_definitions Type definitions keyed by their global type variable.
 * \param import_set Initial contents of the module's import set.
 */
IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
                   tvm::Map<GlobalTypeVar, TypeData> type_definitions,
                   std::unordered_set<String> import_set) {
  auto n = make_object<IRModuleNode>();

  // Derived lookup tables start out empty and are filled from the maps below.
  n->global_var_map_ = {};
  n->global_type_var_map_ = {};
  n->constructor_tag_map_ = {};

  // Take ownership of the caller-provided tables.
  n->functions = std::move(functions);
  n->type_definitions = std::move(type_definitions);
  n->import_set_ = std::move(import_set);

  // Name -> GlobalVar index; every function name may appear only once.
  for (const auto& kv : n->functions) {
    CHECK(n->global_var_map_.count(kv.first->name_hint) == 0)
        << "Duplicate global function name " << kv.first->name_hint;
    n->global_var_map_.Set(kv.first->name_hint, kv.first);
  }

  // Name -> GlobalTypeVar index, plus tag registration for each definition's constructors.
  for (const auto& kv : n->type_definitions) {
    CHECK(n->global_type_var_map_.count(kv.first->name_hint) == 0)
        << "Duplicate global type definition name " << kv.first->name_hint;
    n->global_type_var_map_.Set(kv.first->name_hint, kv.first);
    n->RegisterConstructors(kv.first, kv.second);
  }

  data_ = std::move(n);
}
bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const {
if (functions.size() != other->functions.size()) return false;
for (const auto& kv : this->functions) {
if (!other->ContainGlobalVar(kv.first->name_hint)) return false;
if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false;
}
if (type_definitions.size() != other->type_definitions.size()) return false;
for (const auto& kv : this->type_definitions) {
if (!other->ContainGlobalTypeVar(kv.first->name_hint)) return false;
if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint))) return false;
}
return true;
}
/*!
 * \brief Structural hash of the module.
 *
 * Functions are hashed first, then type definitions. Within each group the
 * entries are ordered by name hint before hashing, and the number of entries
 * is hashed ahead of the (name, definition) pairs.
 *
 * \param hash_reduce Reducer that receives the values to hash.
 */
void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const {
  using NamedEntry = std::pair<std::string, ObjectRef>;

  // Sorts the entries by name, then feeds the count and every pair to the reducer.
  auto reduce_sorted = [&hash_reduce](std::vector<NamedEntry>* entries) {
    std::sort(entries->begin(), entries->end(),
              [](const NamedEntry& a, const NamedEntry& b) { return a.first < b.first; });
    hash_reduce(static_cast<uint64_t>(entries->size()));
    for (const NamedEntry& item : *entries) {
      hash_reduce(item.first);
      hash_reduce(item.second);
    }
  };

  std::vector<NamedEntry> function_entries;
  for (const auto& kv : this->functions) {
    function_entries.emplace_back(kv.first->name_hint, kv.second);
  }
  reduce_sorted(&function_entries);

  std::vector<NamedEntry> type_entries;
  for (const auto& kv : this->type_definitions) {
    type_entries.emplace_back(kv.first->name_hint, kv.second);
  }
  reduce_sorted(&type_entries);
}
/*!
 * \brief Tell whether a global variable with the given name is registered.
 * \param name The name to search for.
 * \return true if the name is present in the global variable name table.
 */
bool IRModuleNode::ContainGlobalVar(const String& name) const {
  return global_var_map_.count(name) != 0;
}
/*!
 * \brief Tell whether a global type variable with the given name is registered.
 * \param name The name to search for.
 * \return true if the name is present in the global type variable name table.
 */
bool IRModuleNode::ContainGlobalTypeVar(const String& name) const {
  return global_type_var_map_.count(name) != 0;
}
/*!
 * \brief Look up a global variable by name.
 * \param name The name to search for.
 * \return The global variable registered under \p name.
 * \note When the name is unknown, LOG(FATAL) is raised with a message that
 *       lists every registered name as a candidate.
 */
GlobalVar IRModuleNode::GetGlobalVar(const String& name) const {
  auto found = global_var_map_.find(name);
  if (found == global_var_map_.end()) {
    std::ostringstream err;
    err << "ValueError: Cannot find global var \"" << name << "\" in the Module\n"
        << "candidates are: [";
    // Empty before the first candidate, ", " before every following one.
    const char* separator = "";
    for (auto candidate : global_var_map_) {
      err << separator;
      err << "\"" << candidate.first << "\"";
      separator = ", ";
    }
    err << "]";
    LOG(FATAL) << err.str();
  }
  return (*found).second;
}
/*!
 * \brief Collect every global variable registered in the name table.
 * \return An array holding the variables in the table's iteration order.
 */
tvm::Array<GlobalVar> IRModuleNode::GetGlobalVars() const {
  std::vector<GlobalVar> collected;
  for (const auto& entry : global_var_map_) {
    collected.emplace_back(entry.second);
  }
  return tvm::Array<GlobalVar>(collected);
}
/*!
 * \brief Look up a global type variable by name.
 * \param name The name to search for.
 * \return The global type variable registered under \p name.
 * \note A CHECK fails when the name table is undefined or does not contain \p name.
 */
GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const String& name) const {
// The name table itself must exist before it can be searched.
CHECK(global_type_var_map_.defined());
auto it = global_type_var_map_.find(name);
CHECK(it != global_type_var_map_.end())
<< "Cannot find global type var " << name << " in the Module";
return (*it).second;
}
/*!
 * \brief Find a constructor of a type definition by name.
 * \param adt Name of the type definition to search.
 * \param cons Name hint of the wanted constructor.
 * \return The first constructor of \p adt whose name hint equals \p cons.
 * \note A missing constructor is reported with LOG(FATAL); the trailing throw
 *       covers the case where that statement returns.
 */
Constructor IRModuleNode::GetConstructor(const String& adt, const String& cons) const {
  TypeData type_data = this->LookupTypeDef(adt);
  for (Constructor candidate : type_data->constructors) {
    if (cons.compare(candidate->name_hint) != 0) {
      continue;
    }
    return candidate;
  }
  LOG(FATAL) << adt << " does not contain constructor " << cons;
  throw std::runtime_error("Constructor Not Found.");
}
/*!
 * \brief Collect every global type variable registered in the name table.
 * \return An array holding the type variables in the table's iteration order.
 */
tvm::Array<GlobalTypeVar> IRModuleNode::GetGlobalTypeVars() const {
  std::vector<GlobalTypeVar> collected;
  for (const auto& entry : global_type_var_map_) {
    collected.emplace_back(entry.second);
  }
  return tvm::Array<GlobalTypeVar>(collected);
}
/*!
 * \brief Concatenate two arrays.
 * \param l Elements placed first.
 * \param r Elements appended after those of \p l.
 * \return An array with the elements of \p l followed by those of \p r.
 */
template <typename T>
tvm::Array<T> concat(const tvm::Array<T>& l, const tvm::Array<T>& r) {
  tvm::Array<T> joined = l;
  for (auto it = r.begin(); it != r.end(); ++it) {
    joined.push_back(*it);
  }
  return joined;
}
/*!
 * \brief Helper that validates and type checks a Relay function for a module.
 *
 * relay::DeDup is applied to the function first. The result must contain no
 * free variables and no free type variables (a CHECK fails otherwise), and is
 * then passed through InferType.
 *
 * \param mod The module the function is checked against.
 * \param var The global variable the function will be bound to.
 * \param f The function to check.
 * \return The function returned by InferType.
 */
relay::Function RunTypeCheck(const IRModule& mod, const GlobalVar& var, relay::Function f) {
  auto func = Downcast<relay::Function>(relay::DeDup(std::move(f)));
  // The function must be closed before it is type checked.
  auto fv = relay::FreeVars(func);
  auto ftv = relay::FreeTypeVars(func, mod);
  CHECK_EQ(fv.size(), 0) << "There are free variables: " << fv
                         << " in function: " << AsText(func, false);
  // Report the offending type variables (ftv) here, not the term variables.
  CHECK_EQ(ftv.size(), 0) << "There are free type variables: " << ftv
                          << " in function: " << AsText(func, false);
  // Type check the item before we add it to the module.
  relay::Function checked_func = InferType(func, mod, var);
  return checked_func;
}
/*!
 * \brief Bind a function to a global variable, type checking Relay functions first.
 *
 * Relay functions go through RunTypeCheck against this module; any other
 * function is stored as given. The function's checked type is written to
 * \p var before the function is stored with AddUnchecked.
 *
 * \param var The global variable to bind.
 * \param f The function to add.
 * \param update Whether an existing definition for \p var may be replaced.
 *        A replacement must have a structurally equal type.
 */
void IRModuleNode::Add(const GlobalVar& var, const BaseFunc& f, bool update) {
  BaseFunc checked_func;
  if (const auto* relay_fn = f.as<relay::FunctionNode>()) {
    checked_func =
        RunTypeCheck(GetRef<IRModule>(this), var, GetRef<relay::Function>(relay_fn));
  } else {
    checked_func = f;
  }

  Type type = checked_func->checked_type();
  CHECK(type.as<relay::IncompleteTypeNode>() == nullptr);

  auto existing = functions.find(var);
  if (existing != functions.end()) {
    CHECK(update) << "Already have definition for " << var->name_hint;
    // An update may replace the body but never the type of the definition.
    auto old_type = (*existing).second->checked_type();
    CHECK(tvm::StructuralEqual()(type, old_type))
        << "Module#update changes type, not possible in this mode.";
  }

  var->checked_type_ = type;
  AddUnchecked(var, checked_func);
}
/*!
 * \brief Bind a function to a global variable without any type checking.
 *
 * If the variable's name hint is already registered, it must be registered to
 * this same variable; a CHECK fails otherwise.
 *
 * \param var The global variable to bind.
 * \param func The function to store.
 */
void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) {
  // Validate the name binding before touching any table, so that a failed
  // check leaves the module unmodified.
  auto it = global_var_map_.find(var->name_hint);
  if (it != global_var_map_.end()) {
    CHECK_EQ((*it).second, var);
  }
  this->functions.Set(var, func);
  global_var_map_.Set(var->name_hint, var);
}
void IRModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& type) {
// We hash the global type var name to use as a globally unique prefix for tags.
// The hash will be used as the most significant byte of the tag, with the index of
// the constructor in the less significant bytes
size_t hash = std::hash<std::string>()(var->name_hint);
int32_t prefix = static_cast<int32_t>(hash & 0xff) << 24;
for (size_t i = 0; i < type->constructors.size(); ++i) {
type->constructors[i]->tag = prefix | static_cast<int32_t>(i);
constructor_tag_map_[type->constructors[i]->tag] = type->constructors[i];
}
}
/*!
 * \brief Add a type definition and kind check it against this module.
 * \param var The global type variable naming the definition.
 * \param type The type definition to add.
 * \param update Forwarded to AddTypeDefUnchecked.
 * \note A CHECK fails if the kind check does not yield TypeKind::kTypeData.
 */
void IRModuleNode::AddTypeDef(const GlobalTypeVar& var, const TypeData& type, bool update) {
AddTypeDefUnchecked(var, type, update);
// The kind check runs after the definition is stored because the check
// can potentially look up a definition in this module.
CHECK(relay::KindCheck(type, GetRef<IRModule>(this)) == TypeKind::kTypeData)
<< "Invalid or malformed typedata given to module: " << type;
}
/*!
 * \brief Add a type definition without kind checking it.
 *
 * The definition is stored, its type variable is indexed by name, and its
 * constructors are registered for tag lookup.
 *
 * \param var The global type variable naming the definition.
 * \param type The type definition to store.
 * \param update When false, a CHECK fails if the name hint is already registered.
 */
void IRModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& type,
                                       bool update) {
  // Reject a duplicate name before mutating any table, so that a failed
  // check leaves the module unmodified.
  if (!update) {
    CHECK(global_type_var_map_.count(var->name_hint) == 0)
        << "Duplicate global type definition name " << var->name_hint;
  }
  this->type_definitions.Set(var, type);
  global_type_var_map_.Set(var->name_hint, var);
  RegisterConstructors(var, type);
}
void IRModuleNode::Update(const GlobalVar& var, const BaseFunc& func) {
this->Add(var, func, true);
}
void IRModuleNode::UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type) {
this->AddTypeDef(var, type, true);
}
/*!
 * \brief Remove a function and its name entry from the module.
 * \param var The global variable whose definition is erased.
 */
void IRModuleNode::Remove(const GlobalVar& var) {
  // Drop the definition first, then the name -> variable entry.
  this->functions.CopyOnWrite()->erase(var);
  global_var_map_.CopyOnWrite()->erase(var->name_hint);
}
/*!
 * \brief Fetch the function bound to a global variable.
 * \param var The global variable to look up.
 * \return The function stored for \p var.
 * \note A CHECK fails if the module has no definition for \p var.
 */
BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const {
auto it = functions.find(var);
CHECK(it != functions.end()) << "There is no definition of " << var->name_hint;
return (*it).second;
}
/*!
 * \brief Fetch a function by the name of its global variable.
 * \param name The name to resolve through GetGlobalVar.
 * \return The function stored for the resolved global variable.
 */
BaseFunc IRModuleNode::Lookup(const String& name) const {
  return this->Lookup(this->GetGlobalVar(name));
}
/*!
 * \brief Fetch the type definition bound to a global type variable.
 * \param var The global type variable to look up.
 * \return The type definition stored for \p var.
 * \note A CHECK fails if the module has no definition for \p var.
 */
TypeData IRModuleNode::LookupTypeDef(const GlobalTypeVar& var) const {
auto it = type_definitions.find(var);
CHECK(it != type_definitions.end()) << "There is no definition of " << var->name_hint;
return (*it).second;
}
/*!
 * \brief Fetch a type definition by the name of its global type variable.
 * \param name The name to resolve through GetGlobalTypeVar.
 * \return The type definition stored for the resolved type variable.
 */
TypeData IRModuleNode::LookupTypeDef(const String& name) const {
  return this->LookupTypeDef(this->GetGlobalTypeVar(name));
}
/*!
 * \brief Fetch the constructor registered under a tag.
 * \param tag The constructor tag, as assigned by RegisterConstructors.
 * \return The constructor stored for \p tag.
 * \note A CHECK fails if no constructor is registered under \p tag.
 */
Constructor IRModuleNode::LookupTag(const int32_t tag) {
auto it = constructor_tag_map_.find(tag);
CHECK(it != constructor_tag_map_.end()) << "There is no constructor with the tag " << tag;
return (*it).second;
}
/*!
 * \brief Merge the functions and type definitions of another module into this one.
 * \param mod The module whose definitions are added.
 */
void IRModuleNode::Update(const IRModule& mod) {
// First pass: add functions and type defs unchecked, so all definitions
// can reference each other, independent of the order in which they were defined.
for (auto pair : mod->functions) {
this->AddUnchecked(pair.first, pair.second);
}
for (auto pair : mod->type_definitions) {
// NOTE(review): the `update` argument is omitted here, so the default from the
// declaration (not visible in this file) applies — confirm it is the intended mode.
this->AddTypeDefUnchecked(pair.first, pair.second);
}
// Second pass: re-add every definition through the checked entry points
// (Update / UpdateTypeDef), now that all names are registered.
for (auto pair : mod->functions) {
this->Update(pair.first, pair.second);
}
for (auto pair : mod->type_definitions) {
this->UpdateTypeDef(pair.first, pair.second);
}
}
/*!
 * \brief Build a module whose entry function is derived from an expression.
 *
 * If \p expr is already a function it is used directly, and its
 * tvm::attr::kGlobalSymbol attribute (when present) names the global variable.
 * Otherwise the expression is wrapped in a Relay function whose parameters are
 * its free variables and whose type parameters are its free type variables,
 * and the global variable is named "main".
 *
 * \param expr The expression to turn into the entry function.
 * \param global_funcs Functions the new module starts with.
 * \param type_definitions Type definitions the new module starts with.
 * \return The module with the entry function added.
 */
IRModule IRModule::FromExpr(const RelayExpr& expr,
                            const tvm::Map<GlobalVar, BaseFunc>& global_funcs,
                            const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) {
  auto mod = IRModule(global_funcs, type_definitions);
  std::string gv_name = "main";
  BaseFunc func;

  const auto* func_node = expr.as<BaseFuncNode>();
  if (func_node == nullptr) {
    // Not a function: close the expression over its free (type) variables.
    func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {});
  } else {
    func = GetRef<BaseFunc>(func_node);
    // An explicit global symbol overrides the default name.
    if (auto opt = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
      gv_name = opt.value();
    }
  }

  mod->Add(GlobalVar(gv_name), func);
  return mod;
}
void IRModuleNode::Import(const String& path) {
if (this->import_set_.count(path) == 0) {
this->import_set_.insert(path);
DLOG(INFO) << "Importing: " << path;
std::fstream src_file(path, std::fstream::in);
std::string file_contents{std::istreambuf_iterator<char>(src_file),
std::istreambuf_iterator<char>()};
auto mod_to_import = IRModule::FromText(file_contents, path);
Update(mod_to_import);
}
}
/*!
 * \brief Import a file located relative to the registered Relay std path.
 *
 * The directory is obtained by calling the global function registered as
 * "tvm.relay.std_path"; a CHECK fails if that function is not registered.
 *
 * \param path File path appended to the std directory with a "/" separator.
 */
void IRModuleNode::ImportFromStd(const String& path) {
  auto* f = tvm::runtime::Registry::Get("tvm.relay.std_path");
  CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
  const std::string std_root = (*f)();
  Import(std_root + "/" + path);
}
std::unordered_set<String> IRModuleNode::Imports() const { return this->import_set_; }
/*!
 * \brief Parse source text into a module.
 * \param text The source text to parse.
 * \param source_path The path passed to the parser together with the text.
 * \return The module produced by tvm::parser::ParseModule.
 */
IRModule IRModule::FromText(const String& text, const String& source_path) {
  IRModule parsed = tvm::parser::ParseModule(source_path, text);
  return parsed;
}
// Register IRModuleNode as a node type.
// NOTE(review): the exact effects depend on the macro definition, which is not visible here.
TVM_REGISTER_NODE_TYPE(IRModuleNode);
// "ir.IRModule": construct a module from function and type tables with an empty import set.
TVM_REGISTER_GLOBAL("ir.IRModule")
    .set_body_typed([](tvm::Map<GlobalVar, BaseFunc> functions,
                       tvm::Map<GlobalTypeVar, TypeData> type_definitions) {
      return IRModule(functions, type_definitions, {});
    });
// "ir.Module_Add": bind a value to a global variable of a module and return the module.
// Positional arguments: module, global variable, value, update flag.
// The value must be a Relay expression and is handled in one of three ways:
//   - a function is added as is;
//   - a global variable is resolved to a function by running EtaExpand on a copy
//     of the module and looking the variable up by name in that copy;
//   - any other expression is wrapped in a parameterless Relay function.
TVM_REGISTER_GLOBAL("ir.Module_Add").set_body([](TVMArgs args, TVMRetValue* ret) {
  IRModule mod = args[0];
  GlobalVar var = args[1];
  ObjectRef val = args[2];
  bool update = args[3];
  CHECK(val->IsInstance<RelayExprNode>());

  if (val->IsInstance<BaseFuncNode>()) {
    mod->Add(var, Downcast<BaseFunc>(val), update);
  } else if (!val->IsInstance<GlobalVarNode>()) {
    auto wrapped = relay::Function({}, Downcast<RelayExpr>(val), Type(nullptr), {});
    mod->Add(var, wrapped, update);
  } else {
    GlobalVar target = Downcast<GlobalVar>(val);
    // Expand on a copy of the module node; `mod` is only changed by the Add below.
    auto expanded = IRModule(make_object<IRModuleNode>(*mod.operator->()));
    expanded = relay::transform::EtaExpand(
        /* expand_constructor */ false,
        /* expand_global_var */ true)(expanded);
    auto resolved = expanded->Lookup(target->name_hint);
    mod->Add(var, Downcast<relay::Function>(resolved), update);
  }
  *ret = mod;
});
// The registrations below expose IRModuleNode member functions directly
// under the corresponding "ir.Module_*" global names.

// Adds a type definition (IRModuleNode::AddTypeDef).
TVM_REGISTER_GLOBAL("ir.Module_AddDef").set_body_method<IRModule>(&IRModuleNode::AddTypeDef);
// Resolves a global variable by name.
TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVar")
.set_body_method<IRModule>(&IRModuleNode::GetGlobalVar);
// Lists all global variables.
TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVars")
.set_body_method<IRModule>(&IRModuleNode::GetGlobalVars);
// Lists all global type variables.
TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVars")
.set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVars);
// Tests whether a global variable name is registered.
TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalVar")
.set_body_method<IRModule>(&IRModuleNode::ContainGlobalVar);
// Resolves a global type variable by name.
TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVar")
.set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVar);
// Function lookup by global variable.
TVM_REGISTER_GLOBAL("ir.Module_Lookup").set_body_typed([](IRModule self, GlobalVar gvar) {
  return self->Lookup(gvar);
});

// Function lookup by name.
TVM_REGISTER_GLOBAL("ir.Module_Lookup_str").set_body_typed([](IRModule self, String name) {
  return self->Lookup(name);
});

// Type definition lookup by global type variable.
TVM_REGISTER_GLOBAL("ir.Module_LookupDef").set_body_typed([](IRModule self, GlobalTypeVar gtv) {
  return self->LookupTypeDef(gtv);
});

// Type definition lookup by name.
TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str").set_body_typed([](IRModule self, String name) {
  return self->LookupTypeDef(name);
});

// Constructor lookup by tag.
TVM_REGISTER_GLOBAL("ir.Module_LookupTag").set_body_typed([](IRModule self, int32_t tag) {
  return self->LookupTag(tag);
});
// Build a module from an expression plus initial function and type tables.
TVM_REGISTER_GLOBAL("ir.Module_FromExpr")
    .set_body_typed([](RelayExpr expr, tvm::Map<GlobalVar, BaseFunc> global_funcs,
                       tvm::Map<GlobalTypeVar, TypeData> type_definitions) {
      return IRModule::FromExpr(expr, global_funcs, type_definitions);
    });

// Merge the definitions of `source` into `self`.
TVM_REGISTER_GLOBAL("ir.Module_Update").set_body_typed([](IRModule self, IRModule source) {
  self->Update(source);
});

// Add or replace a single function.
TVM_REGISTER_GLOBAL("ir.Module_UpdateFunction")
    .set_body_typed([](IRModule self, GlobalVar gvar, BaseFunc fn) { self->Update(gvar, fn); });

// Import a source file by path.
TVM_REGISTER_GLOBAL("ir.Module_Import").set_body_typed([](IRModule self, String file_path) {
  self->Import(file_path);
});

// Import a source file located relative to the registered std path.
TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd").set_body_typed([](IRModule self, String file_path) {
  self->ImportFromStd(file_path);
});
// Repr printing of a module: writes the function table between "IRModuleNode( " and ")".
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<IRModuleNode>([](const ObjectRef& ref, ReprPrinter* p) {
      const auto* module_node = static_cast<const IRModuleNode*>(ref.get());
      p->stream << "IRModuleNode( " << module_node->functions << ")";
    });
} // namespace tvm