blob: 43fea7897f243bda4e5c27c86ad1b38337a38523 [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relax/transform/normalize.cc
* \brief Pass for transforming Relax IR to normal form, i.e., the expressions are normalized (no
* nesting, hence the AST is in ANF), and the struct_info_ of every expression is
* available.
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/struct_info.h>
#include <tvm/relax/transform.h>
namespace tvm {
namespace relax {
// TODO(@altanh): LCA binding lifting
class NormalizeMutator : public ExprMutatorBase {
public:
NormalizeMutator() { builder_ = BlockBuilder::Create(std::nullopt); }
Expr VisitExpr(const Expr& expr) override {
return builder_->Normalize(ExprMutatorBase::VisitExpr(expr));
}
Expr VisitExpr_(const FunctionNode* op) final {
Expr body = this->VisitWithNewScope(op->body, op->params);
if (body.same_as(op->body)) {
return ffi::GetRef<Expr>(op);
} else {
return Function(op->params, body, op->ret_struct_info, op->is_pure, op->attrs);
}
}
Expr VisitExpr_(const IfNode* op) final {
Expr guard = this->VisitExpr(op->cond);
Expr true_b = this->VisitWithNewScope(op->true_branch);
Expr false_b = this->VisitWithNewScope(op->false_branch);
if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) &&
op->false_branch.same_as(false_b)) {
return ffi::GetRef<Expr>(op);
} else {
return If(guard, true_b, false_b, op->span);
}
}
Expr VisitWithNewScope(const Expr& expr, ffi::Optional<ffi::Array<Var>> params = std::nullopt) {
builder_->BeginBindingBlock();
if (params.defined()) {
builder_->BeginScope(params);
} else {
builder_->BeginInnerScope();
}
Expr ret = this->VisitExpr(expr);
BindingBlock prologue = builder_->EndBlock();
if (!prologue->bindings.empty()) {
ret = SeqExpr({prologue}, ret);
}
builder_->EndScope();
return ret;
}
Expr VisitExpr_(const SeqExprNode* op) final {
bool all_blocks_unchanged = true;
ffi::Array<BindingBlock> blocks;
for (auto block : op->blocks) {
BindingBlock new_block = this->VisitBindingBlock(block);
if (!new_block->bindings.empty()) {
blocks.push_back(new_block);
}
all_blocks_unchanged &= block.same_as(new_block);
}
builder_->BeginBindingBlock();
Expr body = this->VisitExpr(op->body);
BindingBlock prologue = builder_->EndBlock();
if (!prologue->bindings.empty()) {
blocks.push_back(prologue);
all_blocks_unchanged = false;
}
if (all_blocks_unchanged && body.same_as(op->body)) {
return ffi::GetRef<Expr>(op);
} else {
return SeqExpr(blocks, body);
}
}
BindingBlock VisitBindingBlock(const BindingBlock& block) final {
BindingBlock ret;
if (const auto* node = block.as<DataflowBlockNode>()) {
ret = VisitBindingBlock_(node);
} else if (const auto* node = block.as<BindingBlockNode>()) {
ret = VisitBindingBlock_(node);
} else {
TVM_FFI_THROW(TypeError) << "Invalid type: " << block->GetTypeKey();
}
return ret;
}
BindingBlock VisitBindingBlock_(const BindingBlockNode* block) {
builder_->BeginBindingBlock();
for (Binding binding : block->bindings) {
this->VisitBinding(binding);
}
return builder_->EndBlock();
}
BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) {
builder_->BeginDataflowBlock();
for (Binding binding : block->bindings) {
this->VisitBinding(binding);
}
return builder_->EndBlock();
}
void VisitBinding(const Binding& binding) {
if (const auto* node = binding.as<VarBindingNode>()) {
VisitBinding_(node);
} else if (const auto* node = binding.as<MatchCastNode>()) {
VisitBinding_(node);
} else {
TVM_FFI_THROW(TypeError) << "Invalid type: " << binding->GetTypeKey();
}
}
void VisitBinding_(const VarBindingNode* binding) {
Expr new_value = this->VisitExpr(binding->value);
if (!binding->var->struct_info_.defined()) {
UpdateStructInfo(binding->var, GetStructInfo(new_value));
}
if (new_value.same_as(binding->value)) {
builder_->EmitNormalized(ffi::GetRef<VarBinding>(binding));
} else {
builder_->EmitNormalized(VarBinding(binding->var, new_value));
}
}
void VisitBinding_(const MatchCastNode* binding) {
Expr new_value = this->VisitExpr(binding->value);
if (new_value.same_as(binding->value)) {
builder_->EmitNormalized(ffi::GetRef<MatchCast>(binding));
} else {
builder_->EmitNormalized(
MatchCast(binding->var, builder_->NormalizeArgument(new_value), binding->struct_info));
}
}
private:
/*! \brief Internal block builder to emit bindings during rewriting. */
BlockBuilder builder_;
}; // namespace relax
/*!
 * \brief Normalize an expression with a freshly constructed NormalizeMutator.
 * \param e The expression to normalize.
 * \return The expression returned by the mutator's VisitExpr.
 */
Expr Normalize(const Expr& e) {
  NormalizeMutator mutator;
  return mutator.VisitExpr(e);
}
/*!
 * \brief Rebinds the functions of an IRModule to new GlobalVars.
 *
 * Functions that carry a "global_symbol" attribute are re-added under exactly
 * that name; all other functions are re-added under a name obtained from the
 * name supply. References to the old GlobalVars inside Relax functions are
 * then replaced using the old-to-new map.
 *
 * NOTE(review): `builder_` is not declared in this class; presumably it is
 * inherited from ExprMutator -- confirm.
 */
class GlobalVarNormalizer : private ExprMutator {
 public:
  /*!
   * \brief Entry point: run the renaming on a module.
   * \param m The module to process.
   * \return The module returned by RenameModule().
   */
  static IRModule Normalize(const IRModule& m) {
    GlobalVarNormalizer normalizer(m);
    return normalizer.RenameModule();
  }

 private:
  explicit GlobalVarNormalizer(const IRModule& m) : ExprMutator(), module_(m) {}

  using ExprMutator::VisitExpr_;

  /*!
   * \brief Perform the renaming, or hand back the module untouched when
   *  NeedRename() reports that nothing has to change.
   */
  IRModule RenameModule() {
    if (NeedRename()) {
      // Step 1. Public functions (those with a global_symbol attribute) go first.
      AddPublicFunctions();

      // Step 2. Private functions get their names from the name supply.
      AddPrivateFunctions();

      // Step 3. Rewrite GlobalVar references inside Relax functions; anything
      // that is not a FunctionNode is left as added in steps 1 and 2.
      for (auto [old_gvar, old_func] : module_->functions) {
        if (old_func->IsInstance<FunctionNode>()) {
          auto rewritten = Downcast<BaseFunc>(this->VisitExpr(old_func));
          builder_->UpdateFunction(gvar_map_[old_gvar], rewritten);
        }
      }

      // Step 4. Patch the original module in place instead of returning the
      // builder's module, because we do not want to copy all metadata over.
      IRModule rebuilt = builder_->GetContextIRModule();
      auto* writable = module_.CopyOnWrite();
      writable->functions = rebuilt->functions;
      writable->global_var_map_ = rebuilt->global_var_map_;
    }
    return std::move(module_);
  }

  /*!
   * \brief Check if any function needs to be renamed, i.e. whether some
   *  function has a global_symbol that differs from its GlobalVar's name_hint.
   */
  bool NeedRename() {
    for (const auto& [gvar, func] : module_->functions) {
      auto symbol = func->GetAttr<ffi::String>("global_symbol");
      if (!symbol) {
        continue;
      }
      if (symbol.value() != gvar->name_hint) {
        return true;
      }
    }
    return false;
  }

  /*!
   * \brief Add public functions to the builder, and update the name supplier.
   *
   * Each global symbol is reserved in the name supply; encountering a symbol
   * that the supply already contains fails the internal check below.
   */
  void AddPublicFunctions() {
    for (const auto& [gvar, func] : module_->functions) {
      auto symbol = func->GetAttr<ffi::String>("global_symbol");
      if (symbol) {
        auto public_name = symbol.value();
        TVM_FFI_ICHECK(!name_supply_->ContainsName(public_name))
            << "IRModule contains duplicate global symbol: " << public_name;
        name_supply_->ReserveName(public_name);
        auto replacement = builder_->AddFunction(func, public_name);
        gvar_map_.Set(gvar, replacement);
      }
    }
  }

  /*!
   * \brief Add private functions to the builder with names provided by name supplier. Renaming may
   *  happen if the name of any function conflicts with the name of a public function.
   */
  void AddPrivateFunctions() {
    for (auto [gvar, func] : module_->functions) {
      auto symbol = func->GetAttr<ffi::String>("global_symbol");
      if (!symbol) {
        auto private_name = name_supply_->FreshName(gvar->name_hint, false, false);
        auto replacement = builder_->AddFunction(func, private_name);
        gvar_map_.Set(gvar, replacement);
      }
    }
  }

  /*!
   * \brief Replace a GlobalVar by its counterpart recorded in gvar_map_.
   *  The internal check fails if the GlobalVar has no entry in the map.
   */
  Expr VisitExpr_(const GlobalVarNode* op) final {
    GlobalVar original = ffi::GetRef<GlobalVar>(op);
    TVM_FFI_ICHECK(gvar_map_.count(original));
    return gvar_map_[original];
  }

  /*! \brief Working copy of the input module; patched in place and returned. */
  IRModule module_;
  /*! \brief Supplies and tracks the function names handed out during renaming. */
  NameSupply name_supply_;
  /*! \brief Maps each original GlobalVar to the GlobalVar that replaces it. */
  ffi::Map<GlobalVar, GlobalVar> gvar_map_;
};
namespace transform {
/*!
 * \brief Build the function pass named "Normalize".
 * \return A pass created with CreateFunctionPass that applies Normalize to
 *  each function and casts the result back to Function.
 */
Pass Normalize() {
  // The module and pass-context arguments are part of the callback signature
  // but are not used by the normalization itself.
  auto normalize_function = [](Function f, IRModule m, PassContext pc) {
    Expr normalized = Normalize(f);
    return Downcast<Function>(normalized);
  };
  return CreateFunctionPass(normalize_function, 1, "Normalize", {});
}
// Registers Normalize under the global name "relax.transform.Normalize"
// through the reflection GlobalDef interface.
TVM_FFI_STATIC_INIT_BLOCK() {
  tvm::ffi::reflection::GlobalDef().def("relax.transform.Normalize", Normalize);
}
/*!
 * \brief Build the module pass named "NormalizeGlobalVar".
 * \return A pass created with CreateModulePass (opt level 0, no required
 *  passes) that forwards the module to GlobalVarNormalizer::Normalize.
 */
Pass NormalizeGlobalVar() {
  // The pass context is part of the callback signature but is not used here.
  auto rename_global_vars = [](IRModule mod, PassContext pc) {
    IRModule renamed = GlobalVarNormalizer::Normalize(mod);
    return renamed;
  };
  return CreateModulePass(/*pass_function=*/rename_global_vars, /*opt_level=*/0,
                          /*pass_name=*/"NormalizeGlobalVar", /*required=*/{});
}
// Registers NormalizeGlobalVar under the global name
// "relax.transform.NormalizeGlobalVar" through the reflection GlobalDef interface.
TVM_FFI_STATIC_INIT_BLOCK() {
  tvm::ffi::reflection::GlobalDef().def("relax.transform.NormalizeGlobalVar", NormalizeGlobalVar);
}
} // namespace transform
} // namespace relax
} // namespace tvm