blob: 5fc01e1517608afd1133fcb8dae140a632baf8ce [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
*
* \file to_basic_block_normal_form.cc
*
* \brief Turn an expression to the basic normal form.
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/support/logging.h>
#include "../../support/arena.h"
#include "../analysis/dependency_graph.h"
#include "let_list.h"
#include "pass_util.h"
namespace tvm {
namespace relay {
Expr ToBasicBlockNormalFormAux(const Expr& e) {
// calculate all the dependency between nodes.
support::Arena arena;
DependencyGraph dg = DependencyGraph::Create(&arena, e);
/* The scope of the whole expr is global.
* The scope of any subexpr, is the lowest common ancestor of all incoming edge.
* We also record the set of expressions whose scope is lifted.
*/
std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg);
return Fill::ToBasicBlockNormalForm(e, dg, &scopes.first, &scopes.second);
}
IRModule ToBasicBlockNormalForm(const IRModule& mod) {
DLOG(INFO) << "ToBBlock:" << std::endl << mod;
tvm::Map<GlobalVar, Function> updates;
auto funcs = mod->functions;
for (const auto& it : funcs) {
CHECK_EQ(FreeVars(it.second).size(), 0) << "Expected no free variables";
if (const auto* n = it.second.as<FunctionNode>()) {
if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
}
Expr ret = TransformF([&](const Expr& e) { return ToBasicBlockNormalFormAux(e); }, it.second);
updates.Set(it.first, Downcast<Function>(ret));
}
for (auto pair : updates) {
mod->Add(pair.first, pair.second, true);
}
DLOG(INFO) << "ToBBlock: transformed" << std::endl << mod;
return mod;
}
bool BasicBlockNormalFormCheck(const Expr& e) {
// calculate all the dependency between nodes.
support::Arena arena;
DependencyGraph dg = DependencyGraph::Create(&arena, e);
std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg);
for (auto expr : scopes.second) {
LOG(FATAL) << "The expression below violates the basic block normal form in that "
<< "its scope should be lifted:\n"
<< expr;
}
return scopes.second.size() == 0;
}
TVM_REGISTER_GLOBAL("relay.analysis.check_basic_block_normal_form")
.set_body_typed(BasicBlockNormalFormCheck);
namespace transform {
Pass ToBasicBlockNormalForm() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule m, PassContext pc) { return relay::ToBasicBlockNormalForm(m); };
return CreateModulePass(pass_func, 1, "ToBasicBlockNormalForm", {});
}
TVM_REGISTER_GLOBAL("relay._transform.ToBasicBlockNormalForm")
.set_body_typed(ToBasicBlockNormalForm);
} // namespace transform
} // namespace relay
} // namespace tvm