blob: 35f1e49e595ded3c4965dbe58f09000b76dbcdc4 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/ir/transform.cc
* \brief Infrastructure for transformation passes.
*/
#include <dmlc/thread_local.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/rvalue_ref.h>
#include <tvm/ir/transform.h>
#include <tvm/node/repr_printer.h>
#include <tvm/node/structural_hash.h>
#include <tvm/relax/expr.h>
#include <tvm/runtime/device_api.h>
#include <chrono>
#include <iomanip>
#include <stack>
#include <unordered_set>
#include "../runtime/regex.h"
namespace tvm {
namespace transform {
using tvm::ReprPrinter;
using tvm::ffi::Any;
using tvm::ffi::PackedArgs;
TVM_REGISTER_PASS_CONFIG_OPTION("testing.immutable_module", Bool);
struct PassContextThreadLocalEntry {
/*! \brief The default pass context. */
PassContext default_context;
/*! \brief The current pass context. */
std::stack<PassContext> context_stack;
PassContextThreadLocalEntry() {
default_context = PassContext(ffi::make_object<PassContextNode>());
}
};
/*! \brief Thread local store to hold the pass context. */
typedef dmlc::ThreadLocalStore<PassContextThreadLocalEntry> RelayPassContextThreadLocalStore;
void PassContext::EnterWithScope() {
InstrumentEnterPassContext();
PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get();
entry->context_stack.push(*this);
}
void PassContext::ExitWithScope() {
PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get();
ICHECK(!entry->context_stack.empty());
ICHECK(entry->context_stack.top().same_as(*this));
entry->context_stack.pop();
InstrumentExitPassContext();
}
PassContext PassContext::Current() {
PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get();
if (!entry->context_stack.empty()) {
return entry->context_stack.top();
} else {
return entry->default_context;
}
}
// linearly scan the pass array to match pass_name
bool PassArrayContains(const ffi::Array<ffi::String>& pass_array, const std::string& pass_name) {
for (auto x : pass_array) {
if (x == pass_name) return true;
}
return false;
}
bool PassContext::PassEnabled(const PassInfo& info) const {
if (PassArrayContains(operator->()->disabled_pass, info->name)) {
return false;
}
if (PassArrayContains(operator->()->required_pass, info->name)) {
return true;
}
return operator->()->opt_level >= info->opt_level;
}
class PassConfigManager {
public:
void Register(std::string key, ffi::String value_type_str,
std::function<ffi::Any(ffi::Any)> legalization) {
ICHECK_EQ(key2vtype_.count(key), 0U);
ValueTypeInfo info;
info.type_str = value_type_str;
info.legalization = legalization;
key2vtype_[key] = info;
}
// Trying to validate and legalize a config.
void Legalize(ffi::Map<ffi::String, ffi::Any>* config) {
std::vector<std::pair<std::string, ffi::Any>> update;
for (auto [key, value] : *config) {
auto it = key2vtype_.find(key);
if (it == key2vtype_.end()) {
std::ostringstream os;
os << "AttributeError: Invalid config option \'" << key << "\' candidates are:";
int counter = 0;
for (const auto& [key, value] : key2vtype_) {
os << ' ';
if (counter++ != 0) os << ',';
os << key;
}
LOG(FATAL) << os.str();
}
const auto& info = it->second;
ICHECK(value != nullptr) << "AttributeError: " << key << " is None";
ICHECK(info.legalization) << "AttributeError: "
<< "Config option \'" << key
<< "\' was defined without a legalization function.";
auto legalized = info.legalization(value);
if (!legalized.same_as(value)) {
update.emplace_back(key, legalized);
}
}
for (auto&& kv : update) {
config->Set(kv.first, kv.second);
}
}
ffi::Map<ffi::String, ffi::Map<ffi::String, ffi::String>> ListConfigs() {
ffi::Map<ffi::String, ffi::Map<ffi::String, ffi::String>> configs;
for (const auto& kv : key2vtype_) {
ffi::Map<ffi::String, ffi::String> metadata;
metadata.Set("type", kv.second.type_str);
configs.Set(kv.first, metadata);
}
return configs;
}
static PassConfigManager* Global() {
static auto* inst = new PassConfigManager();
return inst;
}
private:
struct ValueTypeInfo {
std::string type_str;
std::function<ffi::Any(ffi::Any)> legalization;
};
std::unordered_map<std::string, ValueTypeInfo> key2vtype_;
};
void PassContext::RegisterConfigOption(const char* key, ffi::String value_type_str,
std::function<ffi::Any(ffi::Any)> legalization) {
PassConfigManager::Global()->Register(key, value_type_str, legalization);
}
ffi::Map<ffi::String, ffi::Map<ffi::String, ffi::String>> PassContext::ListConfigs() {
return PassConfigManager::Global()->ListConfigs();
}
PassContext PassContext::Create() { return PassContext(ffi::make_object<PassContextNode>()); }
namespace {
struct ClearOnError {
ffi::Array<instrument::PassInstrument>* instruments{nullptr};
~ClearOnError() {
if (instruments) {
LOG(INFO) << "Pass instrumentation enter/exti failed.";
LOG(INFO) << "Disabling pass instrumentation.";
instruments->clear();
}
}
};
struct ExitContextOnError {
std::vector<instrument::PassInstrument> successes;
~ExitContextOnError() {
for (auto it = successes.rbegin(); it != successes.rend(); it++) {
LOG(INFO) << (*it)->name << " exiting PassContext ...";
(*it)->ExitPassContext();
LOG(INFO) << (*it)->name << " exited PassContext.";
}
}
};
} // namespace
void PassContext::InstrumentEnterPassContext() {
auto pass_ctx_node = this->operator->();
if (pass_ctx_node->instruments.defined()) {
ClearOnError clear_context{&pass_ctx_node->instruments};
ExitContextOnError exit_context;
for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
pi->EnterPassContext();
exit_context.successes.push_back(pi);
}
exit_context.successes.clear();
clear_context.instruments = nullptr;
}
}
namespace {
struct ExitPassSuccesses {
~ExitPassSuccesses() {
if (all_initialized) {
return;
}
LOG(INFO) << "Pass instrumentation entering pass context failed.";
LOG(INFO) << "Disable pass instrumentation.";
instruments->clear();
for (auto it = successes.rbegin(); it != successes.rend(); it++) {
LOG(INFO) << (*it)->name << " exiting PassContext ...";
(*it)->ExitPassContext();
LOG(INFO) << (*it)->name << " exited PassContext.";
}
}
bool all_initialized{false};
std::vector<instrument::PassInstrument> successes;
ffi::Array<instrument::PassInstrument>* instruments{nullptr};
};
} // namespace
void PassContext::InstrumentExitPassContext() {
auto pass_ctx_node = this->operator->();
if (pass_ctx_node->instruments.defined()) {
ClearOnError clear_context{&pass_ctx_node->instruments};
for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
pi->ExitPassContext();
}
clear_context.instruments = nullptr;
}
}
bool PassContext::InstrumentBeforePass(const IRModule& ir_module, const PassInfo& pass_info) const {
auto pass_ctx_node = this->operator->();
if (!pass_ctx_node->instruments.defined()) {
return true;
}
const bool pass_required = PassArrayContains(pass_ctx_node->required_pass, pass_info->name);
bool should_run = true;
if (!pass_required) {
for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
should_run &= pi->ShouldRun(ir_module, pass_info);
}
}
if (should_run) {
for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
pi->RunBeforePass(ir_module, pass_info);
}
}
return should_run;
}
void PassContext::InstrumentAfterPass(const IRModule& ir_module, const PassInfo& pass_info) const {
auto pass_ctx_node = this->operator->();
if (pass_ctx_node->instruments.defined()) {
for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
pi->RunAfterPass(ir_module, pass_info);
}
}
}
IRModule Pass::operator()(IRModule mod) const {
return this->operator()(std::move(mod), PassContext::Current());
}
IRModule Pass::operator()(IRModule mod, const PassContext& pass_ctx) const {
const PassNode* node = operator->();
ICHECK(node != nullptr);
const PassInfo& pass_info = node->Info();
if (!pass_ctx.InstrumentBeforePass(mod, pass_info)) {
DLOG(INFO) << "Skipping pass : " << pass_info->name
<< " with opt level: " << pass_info->opt_level;
return mod;
}
IRModule ret;
if (pass_ctx->GetConfig<Bool>("testing.immutable_module", Bool(false)).value()) {
ret = Pass::AssertImmutableModule(mod, node, pass_ctx);
} else {
ret = node->operator()(std::move(mod), pass_ctx);
}
pass_ctx.InstrumentAfterPass(ret, pass_info);
return ret;
}
IRModule Pass::AssertImmutableModule(const IRModule& mod, const PassNode* node,
const PassContext& pass_ctx) {
size_t before_pass_hash = tvm::StructuralHash()(mod);
IRModule copy_mod = mod;
IRModule ret = node->operator()(mod, pass_ctx);
size_t after_pass_hash = tvm::StructuralHash()(copy_mod);
if (before_pass_hash != after_pass_hash) {
// The chance of getting a hash conflict between a module and the same module but mutated
// must be very low.
LOG_FATAL << "Immutable module has been modified in pass: " << node->Info()->name;
}
return ret;
}
/*!
* \brief Module-level passes are designed to implement global
* analysis/optimizations, i.e. interprocedural optimizations (IPO), etc. Passes
* at this level have the full control of a given Relax program including
* addition and deletion of functions.
*/
class ModulePassNode : public PassNode {
public:
/* \brief The pass meta data.*/
PassInfo pass_info;
/*! \brief The pass function sketches the real optimization. For example,
* we may need to perform dead code elimination on the module level. We could
* implement the algorithm in the `pass_func` and let it run on a module. It
* will then remove the dead code including the unused functions in the module.
*/
std::function<IRModule(IRModule, PassContext)> pass_func;
ModulePassNode() = default;
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<ModulePassNode>().def_ro("pass_info", &ModulePassNode::pass_info);
}
/*!
* \brief Run a module pass on given pass context.
*
* \param mod The module that an optimization pass is applied on.
* \param mod The context that an optimization pass executes on.
*
* \return Return the updated module.
*/
IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
/*!
* \brief Get the pass information/meta data.
*/
PassInfo Info() const override { return pass_info; }
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("transform.ModulePass", ModulePassNode, PassNode);
};
class ModulePass : public Pass {
public:
ModulePass(std::function<IRModule(IRModule, PassContext)> pass_func, PassInfo pass_info);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ModulePass, Pass, ModulePassNode);
};
PassInfo::PassInfo(int opt_level, ffi::String name, tvm::ffi::Array<ffi::String> required,
bool traceable) {
auto pass_info = ffi::make_object<PassInfoNode>();
pass_info->opt_level = opt_level;
pass_info->name = std::move(name);
pass_info->required = std::move(required);
pass_info->traceable = std::move(traceable);
data_ = std::move(pass_info);
}
ModulePass::ModulePass(std::function<IRModule(IRModule, PassContext)> pass_func,
PassInfo pass_info) {
auto n = ffi::make_object<ModulePassNode>();
n->pass_func = std::move(pass_func);
n->pass_info = std::move(pass_info);
data_ = std::move(n);
}
// Module -> Module optimizations.
IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
DiagnosticContext previous = DiagnosticContext::Default(mod);
if (pass_ctx->diag_ctx) {
DiagnosticContext tmp = pass_ctx->diag_ctx.value();
pass_ctx->diag_ctx = previous;
previous = tmp;
} else {
pass_ctx->diag_ctx = previous;
}
ICHECK(pass_ctx->diag_ctx)
<< "The diagnostic context was set at the top of this block this is a bug.";
const PassInfo& pass_info = Info();
ICHECK(mod.defined()) << "The input module must be set.";
VLOG_CONTEXT << pass_info->name;
VLOG(0) << "Executing module pass with opt level: " << pass_info->opt_level;
mod = pass_func(std::move(mod), pass_ctx);
ICHECK(mod.defined()) << "The return value of a module pass must be set.";
ICHECK(pass_ctx->diag_ctx)
<< "The diagnostic context was set at the top of this block this is a bug.";
pass_ctx->diag_ctx.value().Render();
pass_ctx->diag_ctx = previous;
return mod;
}
Sequential::Sequential(tvm::ffi::Array<Pass> passes, PassInfo pass_info) {
auto n = ffi::make_object<SequentialNode>();
n->passes = std::move(passes);
n->pass_info = std::move(pass_info);
data_ = std::move(n);
}
Sequential::Sequential(tvm::ffi::Array<Pass> passes, ffi::String name) {
auto n = ffi::make_object<SequentialNode>();
n->passes = std::move(passes);
PassInfo pass_info = PassInfo(0, std::move(name), {}, /* traceable */ false);
n->pass_info = std::move(pass_info);
data_ = std::move(n);
}
const SequentialNode* Sequential::operator->() const {
return static_cast<const SequentialNode*>(get());
}
void SequentialNode::ResolveDependency(const IRModule& mod) {
// TODO(zhiics) Implement it.
// 1. Consider the required passes for each pass.
// 2. Only resolve the enabled passes.
// 3. Build a dependency graph. Probably we need to update the pass list.
LOG(FATAL) << "Pass dependency has not been resolved yet."
<< "\n";
}
Pass GetPass(const ffi::String& pass_name) {
std::optional<tvm::ffi::Function> f;
if (pass_name.operator std::string().find("transform.") != std::string::npos) {
f = tvm::ffi::Function::GetGlobal(pass_name);
} else {
f = tvm::ffi::Function::GetGlobal("transform." + pass_name);
}
ICHECK(f.has_value()) << "Cannot use " << pass_name << " to create the pass";
return (*f)().cast<Pass>();
}
// TODO(zhiics): we currently only sequentially execute each pass in
// a Sequential without the consideration of their orders. The phase
// ordering problem needs to be handled in the future.
IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
for (const Pass& pass : passes) {
VLOG(0) << "Running pass " << pass->Info()->name;
ICHECK(pass.defined()) << "Found undefined pass for optimization.";
const PassInfo& pass_info = pass->Info();
if (!pass_ctx.PassEnabled(pass_info)) {
VLOG(0) << "skipping disabled pass '" << pass_info->name << "'";
continue;
}
// resolve dependencies
for (const auto& it : pass_info->required) {
mod = GetPass(it)(std::move(mod), pass_ctx);
}
mod = pass(std::move(mod), pass_ctx);
}
return mod;
}
Pass CreateModulePass(std::function<IRModule(IRModule, PassContext)> pass_func, int opt_level,
ffi::String name, tvm::ffi::Array<ffi::String> required, bool traceable) {
PassInfo pass_info = PassInfo(opt_level, name, required, traceable);
return ModulePass(std::move(pass_func), pass_info);
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("transform.PassInfo",
[](int opt_level, ffi::String name, tvm::ffi::Array<ffi::String> required,
bool traceable) { return PassInfo(opt_level, name, required, traceable); })
.def_packed("transform.Info", [](ffi::PackedArgs args, ffi::Any* ret) {
Pass pass = args[0].cast<Pass>();
*ret = pass->Info();
});
}
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PassInfoNode>([](const ObjectRef& ref, tvm::ReprPrinter* p) {
auto* node = static_cast<const PassInfoNode*>(ref.get());
p->stream << "The meta data of the pass - ";
p->stream << "pass name: " << node->name;
p->stream << ", opt_level: " << node->opt_level;
if (node->required.empty()) {
p->stream << ", required passes: []\n";
} else {
p->stream << ", required passes: ["
<< "\n";
for (const auto& it : node->required) {
p->stream << it << ", ";
}
p->stream << "]\n";
}
});
TVM_FFI_STATIC_INIT_BLOCK() {
PassContextNode::RegisterReflection();
PassInfoNode::RegisterReflection();
SequentialNode::RegisterReflection();
ModulePassNode::RegisterReflection();
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("transform.MakeModulePass",
[](ffi::TypedFunction<IRModule(ffi::RValueRef<IRModule>, PassContext)> pass_func,
PassInfo pass_info) {
auto wrapped_pass_func = [pass_func](IRModule mod, PassContext ctx) {
return pass_func(ffi::RValueRef<IRModule>(std::move(mod)), ctx);
};
return ModulePass(wrapped_pass_func, pass_info);
})
.def("transform.RunPass",
[](Pass pass, ffi::RValueRef<IRModule> mod) { return pass(*std::move(mod)); });
}
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ModulePassNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const ModulePassNode*>(ref.get());
const PassInfo info = node->Info();
p->stream << "Run Module pass: " << info->name << " at the optimization level "
<< info->opt_level;
});
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def_packed("transform.Sequential", [](ffi::PackedArgs args, ffi::Any* ret) {
auto passes = args[0].cast<tvm::ffi::Array<Pass>>();
int opt_level = args[1].cast<int>();
std::string name = args[2].cast<std::string>();
auto required = args[3].cast<tvm::ffi::Array<ffi::String>>();
bool traceable = args[4].cast<bool>();
PassInfo pass_info = PassInfo(opt_level, name, required, /* traceable */ traceable);
*ret = Sequential(passes, pass_info);
});
}
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SequentialNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const SequentialNode*>(ref.get());
const PassInfo info = node->Info();
p->stream << "Run Sequential pass: " << info->name << " at the optimization level "
<< info->opt_level << ". ";
p->stream << "The passes will be executed are: [";
for (const auto& it : node->passes) {
const PassInfo pass_info = it->Info();
p->stream << pass_info->name << " ";
}
p->stream << "]";
});
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def(
"transform.PassContext",
[](int opt_level, ffi::Array<ffi::String> required, ffi::Array<ffi::String> disabled,
ffi::Array<instrument::PassInstrument> instruments,
ffi::Optional<ffi::Map<ffi::String, ffi::Any>> config) {
auto pctx = PassContext::Create();
pctx->opt_level = opt_level;
pctx->required_pass = std::move(required);
pctx->disabled_pass = std::move(disabled);
pctx->instruments = std::move(instruments);
if (config.defined()) {
pctx->config = config.value();
}
PassConfigManager::Global()->Legalize(&(pctx->config));
return pctx;
});
}
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PassContextNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const PassContextNode*>(ref.get());
p->stream << "Pass context information: "
<< "\n";
p->stream << "\topt_level: " << node->opt_level << "\n";
p->stream << "\trequired passes: " << node->required_pass << "\n";
p->stream << "\tdisabled passes: " << node->disabled_pass << "\n";
p->stream << "\tinstruments: " << node->instruments << "\n";
p->stream << "\tconfig: " << node->config << "\n";
});
class PassContext::Internal {
public:
static void EnterScope(PassContext pass_ctx) { pass_ctx.EnterWithScope(); }
static void ExitScope(PassContext pass_ctx) { pass_ctx.ExitWithScope(); }
};
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("transform.GetCurrentPassContext", PassContext::Current)
.def("transform.EnterPassContext", PassContext::Internal::EnterScope)
.def("transform.ExitPassContext", PassContext::Internal::ExitScope)
.def("transform.OverrideInstruments",
[](PassContext pass_ctx, ffi::Array<instrument::PassInstrument> instruments) {
pass_ctx.InstrumentExitPassContext();
pass_ctx->instruments = instruments;
pass_ctx.InstrumentEnterPassContext();
});
}
Pass PrintIR(ffi::String header, bool show_meta_data) {
auto pass_func = [header, show_meta_data](IRModule mod, const PassContext& ctx) {
LOG(INFO) << "PrintIR(" << header << "):\n" << mod;
return mod;
};
return CreateModulePass(pass_func, 0, "PrintIR", {}, /* traceable */ false);
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("transform.PrintIR", PrintIR)
.def("transform.ListConfigs", PassContext::ListConfigs);
}
} // namespace transform
} // namespace tvm