blob: d74b95abebdb8a1c410f4e3c658ea25340451c25 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/ir/transform.cc
* \brief Infrastructure for transformation passes.
*/
#include <dmlc/thread_local.h>
#include <tvm/ir/transform.h>
#include <tvm/node/repr_printer.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <stack>
#include <unordered_set>
#include "../runtime/object_internal.h"
namespace tvm {
namespace transform {
using tvm::ReprPrinter;
using tvm::runtime::TVMArgs;
using tvm::runtime::TVMRetValue;
/*!
 * \brief Per-thread bookkeeping for PassContext scoping.
 *
 * Holds a fallback context, used when no scope is active, and the stack of
 * contexts pushed by PassContext::EnterWithScope().
 */
struct PassContextThreadLocalEntry {
  /*! \brief Context returned by PassContext::Current() when the stack is empty. */
  PassContext default_context = PassContext(make_object<PassContextNode>());
  /*! \brief Contexts entered on this thread, innermost on top. */
  std::stack<PassContext> context_stack;
};
/*! \brief Thread local store to hold the pass context. */
using RelayPassContextThreadLocalStore = dmlc::ThreadLocalStore<PassContextThreadLocalEntry>;
void PassContext::EnterWithScope() {
  // Make this context the innermost one on the calling thread; Current()
  // returns it until the matching ExitWithScope() pops it.
  RelayPassContextThreadLocalStore::Get()->context_stack.push(*this);
}
/*!
 * \brief Leave the scope opened by EnterWithScope() on this thread.
 *
 * Scopes must be exited in strict LIFO order: this context has to be the one
 * currently on top of the thread-local stack.
 */
void PassContext::ExitWithScope() {
  PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get();
  CHECK(!entry->context_stack.empty())
      << "PassContext::ExitWithScope called without a matching EnterWithScope";
  CHECK(entry->context_stack.top().same_as(*this))
      << "PassContext::ExitWithScope: the context being exited is not the innermost "
         "active PassContext (scopes must be exited in reverse order of entry)";
  entry->context_stack.pop();
}
PassContext PassContext::Current() {
  const PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get();
  // The innermost entered scope wins; with no active scope, use the thread's default.
  return entry->context_stack.empty() ? entry->default_context : entry->context_stack.top();
}
/*!
 * \brief Registry of the config options a PassContext accepts.
 *
 * Maps each option key to the type its value must have. It also validates,
 * and where possible converts, user-supplied config maps.
 */
class PassConfigManager {
 public:
  /*!
   * \brief Register a config option.
   * \param key The option name; registering the same key twice is a fatal error.
   * \param value_type_index Type index that values of this option must derive from.
   */
  void Register(std::string key, uint32_t value_type_index) {
    CHECK_EQ(key2vtype_.count(key), 0U)
        << "Config option \'" << key << "\' has already been registered";
    ValueTypeInfo info;
    info.type_index = value_type_index;
    info.type_key = runtime::Object::TypeIndex2Key(value_type_index);
    key2vtype_[std::move(key)] = std::move(info);
  }

  /*!
   * \brief Validate a config map and legalize its values in place.
   *
   * Every key must be a registered option.
   * - A value given as a Map<String, ObjectRef> is converted into an object of
   *   the option's registered type through the reflection table.
   * - Any other value must already derive from the registered type.
   * Violations abort via CHECK / LOG(FATAL).
   *
   * \param config The config to validate; converted entries are written back into it.
   */
  void Legalize(Map<String, ObjectRef>* config) {
    std::vector<std::pair<std::string, ObjectRef>> update;
    auto* reflection = ReflectionVTable::Global();
    for (const auto& kv : *config) {
      auto it = key2vtype_.find(kv.first);
      if (it == key2vtype_.end()) {
        std::ostringstream os;
        os << "AttributeError: Invalid config option \'" << kv.first << "\' candidates are:";
        // Render the candidates as " a, b, c".
        int counter = 0;
        for (const auto& candidate : key2vtype_) {
          if (counter++ != 0) os << ',';
          os << ' ' << candidate.first;
        }
        LOG(FATAL) << os.str();
      }
      const auto& info = it->second;
      CHECK(kv.second.defined()) << "AttributeError: " << kv.first << " is None";
      if (kv.second->IsInstance<Map<String, ObjectRef>::ContainerType>()) {
        // A dict-like value: build the registered object type from its fields.
        ObjectRef converted =
            reflection->CreateObject(info.type_key, Downcast<Map<String, ObjectRef>>(kv.second));
        update.emplace_back(kv.first, std::move(converted));
      } else if (!runtime::ObjectInternal::DerivedFrom(kv.second.get(), info.type_index)) {
        LOG(FATAL) << "AttributeError: expect config " << kv.first << " to have type "
                   << info.type_key << " but get " << kv.second->GetTypeKey();
      }
    }
    // Write conversions back only after the scan so the map is not mutated mid-iteration.
    for (auto&& kv : update) {
      config->Set(kv.first, kv.second);
    }
  }

  /*!
   * \brief Process-wide singleton.
   *
   * The instance is heap-allocated and never freed (presumably to sidestep
   * static-destruction-order issues; confirm before changing).
   */
  static PassConfigManager* Global() {
    static auto* inst = new PassConfigManager();
    return inst;
  }

 private:
  /*! \brief The expected value type of a registered option. */
  struct ValueTypeInfo {
    std::string type_key;
    uint32_t type_index;
  };
  /*! \brief Option key -> expected value type. */
  std::unordered_map<std::string, ValueTypeInfo> key2vtype_;
};
void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_index) {
PassConfigManager::Global()->Register(key, value_type_index);
}
PassContext PassContext::Create() {
  // Every call yields a reference to a brand-new, independent node.
  auto node = make_object<PassContextNode>();
  return PassContext(std::move(node));
}
void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_before) const {
  const auto* node = this->operator->();
  // Tracing is optional; without an installed callback there is nothing to do.
  if (node->trace_func == nullptr) return;
  node->trace_func(module, info, is_before);
}
// Forward declaration of the reference type (defined below) for ModulePassNode.
class ModulePass;
/*!
 * \brief Module-level passes are designed to implement global
 * analysis/optimizations, e.g. interprocedural optimizations (IPO). Passes
 * at this level have full control over a given Relay program, including
 * the addition and deletion of functions.
 */
class ModulePassNode : public PassNode {
public:
/*! \brief The pass metadata. */
PassInfo pass_info;
/*! \brief The function that performs the actual optimization. For example,
 * to perform dead code elimination at the module level, the algorithm could
 * be implemented in `pass_func` and run on a module; it would then remove the
 * dead code, including unused functions, from the module.
 */
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func;
ModulePassNode() = default;
// Only pass_info is exposed to reflection; pass_func is not visited.
void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); }
/*!
 * \brief Run a module pass on the given pass context.
 *
 * \param mod The module that the optimization pass is applied on.
 * \param pass_ctx The context that the optimization pass executes on.
 *
 * \return The updated module.
 */
IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
/*!
 * \brief Get the pass information/metadata.
 */
PassInfo Info() const override { return pass_info; }
static constexpr const char* _type_key = "transform.ModulePass";
TVM_DECLARE_FINAL_OBJECT_INFO(ModulePassNode, PassNode);
};
/*!
 * \brief Managed reference to ModulePassNode.
 */
class ModulePass : public Pass {
public:
/*!
 * \brief Construct a module pass.
 * \param pass_func The function that transforms a module under a pass context.
 * \param pass_info The pass metadata.
 */
ModulePass(runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func,
PassInfo pass_info);
TVM_DEFINE_OBJECT_REF_METHODS(ModulePass, Pass, ModulePassNode);
};
/*!
 * \brief The SequentialNode contains a set of passes that transform Relay
 * programs from one AST to another, semantically equivalent one.
 *
 * One example of this level of pass is that the pass manager needs to correctly
 * perform a host of optimizations with a given optimization level and disabled
 * passes.
 */
class SequentialNode : public PassNode {
public:
/*! \brief The pass metadata. */
PassInfo pass_info;
/*! \brief The list of passes that make up this sequential pass, in execution order. */
tvm::Array<Pass> passes;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("pass_info", &pass_info);
v->Visit("passes", &passes);
}
/*!
 * \brief Get the pass information/metadata.
 */
PassInfo Info() const override { return pass_info; }
/*!
 * \brief Check whether a pass is enabled.
 *
 * The decision is made against the calling thread's current PassContext
 * (PassContext::Current()), not against a context passed in by the caller.
 *
 * \param info The pass information.
 *
 * \return true if the pass is enabled, otherwise false.
 */
bool PassEnabled(const PassInfo& info) const;
/*!
 * \brief Resolve the pass dependency. It globs all passes required by
 * a given pass and executes them.
 *
 * \note Not implemented yet: the current definition aborts with LOG(FATAL).
 *
 * \param mod The module that an optimization pass runs on.
 *
 * \return The updated module after resolving pass dependencies.
 *
 * TODO(zhiics) Build a dependency graph among the passes using provided
 * metadata, e.g. required_passes. Likely, we can have a data structure, e.g.
 * PassInfo, to store the relevant information including the parent passes.
 */
void ResolveDependency(const IRModule& mod);
/*!
 * \brief Perform optimizations on a series of passes. The aforementioned
 * typical pass manager jobs could be done by it. This function could
 * be overloaded to focus on different metrics, e.g. performance,
 * memory footprint, etc.
 *
 * \param mod The module that these passes are applied on.
 * \param pass_ctx The context that these passes execute on.
 *
 * \return The updated module.
 */
IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
static constexpr const char* _type_key = "transform.Sequential";
TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode);
};
PassInfo::PassInfo(int opt_level, String name, tvm::Array<runtime::String> required) {
  // Populate a fresh node, then hand ownership to this reference.
  auto node = make_object<PassInfoNode>();
  node->name = std::move(name);
  node->required = std::move(required);
  node->opt_level = opt_level;
  data_ = std::move(node);
}
ModulePass::ModulePass(runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func,
                       PassInfo pass_info) {
  // Move both pieces into a new node owned by this reference.
  auto node = make_object<ModulePassNode>();
  node->pass_info = std::move(pass_info);
  node->pass_func = std::move(pass_func);
  data_ = std::move(node);
}
// Module -> Module optimizations: run pass_func, bracketed by the context's
// before/after trace hooks.
IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
  const PassInfo& info = Info();
  DLOG(INFO) << "Executing module pass : " << info->name
             << " with opt level: " << info->opt_level;
  CHECK(mod.defined());
  pass_ctx.Trace(mod, info, true);
  mod = pass_func(std::move(mod), pass_ctx);
  // The pass function must hand back a module.
  CHECK(mod.defined());
  pass_ctx.Trace(mod, info, false);
  return mod;
}
Sequential::Sequential(tvm::Array<Pass> passes, PassInfo pass_info) {
  // Adopt the caller-supplied metadata as-is.
  auto node = make_object<SequentialNode>();
  node->pass_info = std::move(pass_info);
  node->passes = std::move(passes);
  data_ = std::move(node);
}
Sequential::Sequential(tvm::Array<Pass> passes, String name) {
  // A sequence built from just a name gets opt_level 2 and no required passes.
  auto node = make_object<SequentialNode>();
  node->pass_info = PassInfo(2, std::move(name), {});
  node->passes = std::move(passes);
  data_ = std::move(node);
}
const SequentialNode* Sequential::operator->() const {
  // Unchecked downcast: the constructors in this file always store a SequentialNode.
  const Object* raw = get();
  return static_cast<const SequentialNode*>(raw);
}
void SequentialNode::ResolveDependency(const IRModule& mod) {
  // TODO(zhiics) Implement it:
  //   1. Consider the required passes for each pass.
  //   2. Only resolve the enabled passes.
  //   3. Build a dependency graph. Probably we need to update the pass list.
  // Until then, reaching this function is a hard error.
  LOG(FATAL) << "Pass dependency has not been resolved yet.\n";
}
// Linear scan: report whether pass_name occurs anywhere in pass_array.
inline bool PassArrayContains(const Array<runtime::String>& pass_array,
                              const std::string& pass_name) {
  bool found = false;
  for (const auto& candidate : pass_array) {
    if (candidate == pass_name) {
      found = true;
      break;
    }
  }
  return found;
}
bool SequentialNode::PassEnabled(const PassInfo& info) const {
  PassContext ctx = PassContext::Current();
  // Precedence: explicitly disabled > explicitly required > opt-level threshold.
  if (PassArrayContains(ctx->disabled_pass, info->name)) return false;
  return PassArrayContains(ctx->required_pass, info->name) || ctx->opt_level >= info->opt_level;
}
/*!
 * \brief Look up a pass-creating global function by name and invoke it.
 *
 * If \p pass_name already contains "transform." it is used verbatim.
 * Otherwise "transform.<name>" is tried first, then "relay._transform.<name>".
 *
 * \param pass_name The (possibly qualified) name of the pass.
 * \return The pass produced by calling the registered function with no arguments.
 */
Pass GetPass(const String& pass_name) {
  using tvm::runtime::Registry;
  const std::string name = pass_name;
  const runtime::PackedFunc* f = nullptr;
  if (name.find("transform.") != std::string::npos) {
    // Already qualified (e.g. "transform.X" or "relay._transform.X").
    f = Registry::Get(name);
  } else {
    f = Registry::Get("transform." + name);
    if (f == nullptr) {
      f = Registry::Get("relay._transform." + name);
    }
  }
  CHECK(f != nullptr) << "Cannot use " << pass_name << " to create the pass";
  return (*f)();
}
// TODO(zhiics): we currently only execute the passes of a Sequential one after
// another, without considering their order. The phase ordering problem needs
// to be handled in the future.
IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
  for (const Pass& pass : passes) {
    CHECK(pass.defined()) << "Found undefined pass for optimization.";
    const PassInfo& info = pass->Info();
    if (PassEnabled(info)) {
      // Run every pass this one declares as required, then the pass itself.
      for (const auto& required_name : info->required) {
        mod = GetPass(required_name)(std::move(mod), pass_ctx);
      }
      mod = pass(std::move(mod), pass_ctx);
    }
  }
  return mod;
}
Pass CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
                      int opt_level, String name, tvm::Array<String> required) {
  // Bundle the metadata and wrap pass_func in a ModulePass.
  return ModulePass(pass_func, PassInfo(opt_level, name, required));
}
TVM_REGISTER_NODE_TYPE(PassInfoNode);
// FFI constructor for PassInfo.
TVM_REGISTER_GLOBAL("transform.PassInfo")
    .set_body_typed([](int level, String pass_name, tvm::Array<String> required_passes) {
      return PassInfo(level, pass_name, required_passes);
    });
// Return the metadata of the pass passed as the only argument.
TVM_REGISTER_GLOBAL("transform.Info").set_body([](TVMArgs args, TVMRetValue* rv) {
  const Pass pass = args[0];
  *rv = pass->Info();
});
// Printer for PassInfoNode, e.g.
//   "The meta data of the pass: pass name: X, opt_level: 2, required passes: [A, B]\n"
// Fields are separated so they no longer run together, and the required list
// has no stray newline inside the brackets or trailing separator.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<PassInfoNode>([](const ObjectRef& ref, tvm::ReprPrinter* p) {
      auto* node = static_cast<const PassInfoNode*>(ref.get());
      p->stream << "The meta data of the pass: ";
      p->stream << "pass name: " << node->name << ", ";
      p->stream << "opt_level: " << node->opt_level << ", ";
      p->stream << "required passes: [";
      bool first = true;
      for (const auto& it : node->required) {
        if (!first) p->stream << ", ";
        first = false;
        p->stream << it;
      }
      p->stream << "]\n";
    });
TVM_REGISTER_NODE_TYPE(ModulePassNode);
// FFI constructor for ModulePass.
TVM_REGISTER_GLOBAL("transform.MakeModulePass")
    .set_body_typed([](runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> func,
                       PassInfo info) { return ModulePass(func, info); });
// Apply a pass to a module using Pass's single-argument call operator.
TVM_REGISTER_GLOBAL("transform.RunPass").set_body_typed([](Pass p, IRModule m) {
  return p(std::move(m));
});
// Printer for ModulePassNode: name and optimization level.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<ModulePassNode>([](const ObjectRef& ref, ReprPrinter* p) {
      const PassInfo info = static_cast<const ModulePassNode*>(ref.get())->Info();
      p->stream << "Run Module pass: " << info->name << " at the optimization level "
                << info->opt_level;
    });
TVM_REGISTER_NODE_TYPE(SequentialNode);
// FFI constructor: (passes, opt_level, name, required) -> Sequential.
TVM_REGISTER_GLOBAL("transform.Sequential").set_body([](TVMArgs args, TVMRetValue* rv) {
  tvm::Array<Pass> passes = args[0];
  int opt_level = args[1];
  std::string name = args[2];
  tvm::Array<runtime::String> required = args[3];
  *rv = Sequential(passes, PassInfo(opt_level, name, required));
});
// Printer for SequentialNode: its own metadata followed by its member pass names.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<SequentialNode>([](const ObjectRef& ref, ReprPrinter* p) {
      const auto* seq = static_cast<const SequentialNode*>(ref.get());
      const PassInfo seq_info = seq->Info();
      std::ostream& os = p->stream;
      os << "Run Sequential pass: " << seq_info->name << " at the optimization level "
         << seq_info->opt_level << ". ";
      os << "The passes will be executed are: [";
      for (const auto& sub : seq->passes) {
        os << sub->Info()->name << " ";
      }
      os << "]";
    });
TVM_REGISTER_NODE_TYPE(PassContextNode);
// FFI constructor for PassContext. The config (empty if not given) is
// validated against the registered options before the context is returned.
TVM_REGISTER_GLOBAL("transform.PassContext")
    .set_body_typed([](int opt_level, Array<String> required, Array<String> disabled,
                       TraceFunc trace_func, Optional<Map<String, ObjectRef>> config) {
      PassContext ctx = PassContext::Create();
      ctx->opt_level = opt_level;
      ctx->required_pass = std::move(required);
      ctx->disabled_pass = std::move(disabled);
      ctx->trace_func = std::move(trace_func);
      if (config.defined()) ctx->config = config.value();
      PassConfigManager::Global()->Legalize(&(ctx->config));
      return ctx;
    });
// Printer for PassContextNode: opt level, required/disabled pass lists, config.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<PassContextNode>([](const ObjectRef& ref, ReprPrinter* p) {
      const auto* ctx = static_cast<const PassContextNode*>(ref.get());
      std::ostream& os = p->stream;
      // Print each name followed by a single space, enclosed in brackets.
      auto print_names = [&os](const auto& names) {
        os << "[";
        for (const auto& name : names) os << name << " ";
        os << "]\n";
      };
      os << "Pass context information: " << "\n";
      os << "\topt_level: " << ctx->opt_level << "\n";
      os << "\trequired passes: ";
      print_names(ctx->required_pass);
      os << "\tdisabled passes: ";
      print_names(ctx->disabled_pass);
      os << "\tconfig: " << ctx->config;
    });
/*! \brief Static adapters so the FFI registrations below can enter/exit PassContext scopes. */
class PassContext::Internal {
 public:
  static void EnterScope(PassContext ctx) { ctx.EnterWithScope(); }
  static void ExitScope(PassContext ctx) { ctx.ExitWithScope(); }
};
TVM_REGISTER_GLOBAL("transform.GetCurrentPassContext").set_body_typed(PassContext::Current);
TVM_REGISTER_GLOBAL("transform.EnterPassContext").set_body_typed(PassContext::Internal::EnterScope);
TVM_REGISTER_GLOBAL("transform.ExitPassContext").set_body_typed(PassContext::Internal::ExitScope);
/*!
 * \brief Create a module pass (opt_level 0) that logs the module's text form
 *        and returns the module unchanged.
 * \param header Label included in the log line.
 * \param show_meta_data Forwarded to AsText.
 */
Pass PrintIR(String header, bool show_meta_data) {
  return CreateModulePass(
      [header, show_meta_data](IRModule mod, const PassContext& /*ctx*/) {
        LOG(INFO) << "PrintIR(" << header << "):\n" << AsText(mod, show_meta_data);
        return mod;
      },
      0, "PrintIR", {});
}
TVM_REGISTER_GLOBAL("transform.PrintIR").set_body_typed(PrintIR);
} // namespace transform
} // namespace tvm