blob: adb757b9de0cdcd42da50456d4b318e1664b4c5d [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
*
* \file to_a_normal_form.cc
*
* \brief Turn implicit sharing into observable sharing.
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/support/logging.h>
#include "../../support/arena.h"
#include "../analysis/dependency_graph.h"
#include "let_list.h"
#include "pass_util.h"
namespace tvm {
namespace relay {
// Lowest common ancestor of two scopes in the scope tree.
// Whichever side sits at the greater `level` climbs to its parent; when both are at the same
// level they climb together, until both handles refer to the same scope.
// NOTE(review): assumes `level` grows by one per nesting step and both scopes share a root;
// otherwise the climb would walk past the root — confirm against ScopeNode.
Scope LCA(Scope lhs, Scope rhs) {
  while (lhs != rhs) {
    // Snapshot both levels first: `lhs` may be reassigned before `rhs` is examined.
    const auto lhs_level = lhs->level;
    const auto rhs_level = rhs->level;
    if (lhs_level >= rhs_level) {
      lhs = lhs->parent;
    }
    if (rhs_level >= lhs_level) {
      rhs = rhs->parent;
    }
  }
  return lhs;
}
// Assign a scope to every node of the dependency graph.
//
// Nodes are processed in reverse post-DFS order, so the scopes of a node's parents are already
// known when the node itself is reached (the expr_scope.at() lookups below rely on this).
// - The single parentless node receives the global scope.
// - Every other node is placed in the lowest common ancestor (LCA) of its parents' scopes.
// - A node whose `new_scope` flag is set stores a fresh child of that scope, which is the scope
//   its own children observe.
//
// Returns the node -> scope map, together with the set of expressions whose scope differs from
// the scope of their first parent (operators excluded), i.e. the expressions that were lifted.
std::pair<NodeScopeMap, ExprSet> CalcScope(const DependencyGraph& dg) {
  NodeScopeMap expr_scope;
  ExprSet lifted_exprs;
  // Reverse index of dg.expr_node. Bind entries by const reference: copying each
  // (Expr, Node*) pair would bump the Expr reference count on every iteration.
  std::unordered_map<DependencyGraph::Node*, Expr> node_to_expr;
  for (const auto& expr_node : dg.expr_node) {
    node_to_expr[expr_node.second] = expr_node.first;
  }
  bool global_scope_used = false;
  Scope global_scope = std::make_shared<ScopeNode>();
  for (auto it = dg.post_dfs_order.rbegin(); it != dg.post_dfs_order.rend(); ++it) {
    DependencyGraph::Node* n = *it;
    auto iit = n->parents.head;
    Scope s;
    if (iit == nullptr) {
      // Exactly one node may be parentless; it owns the global scope.
      CHECK(!global_scope_used);
      s = global_scope;
      global_scope_used = true;
    } else {
      s = expr_scope.at(iit->value);
      const auto original_s = s;
      for (iit = iit->next; iit != nullptr; iit = iit->next) {
        s = LCA(s, expr_scope.at(iit->value));
      }
      if (s != original_s) {
        // One hash lookup instead of find() + operator[]. Nodes without an Expr, and operators
        // (whose scope does not matter), are not reported as lifted.
        auto found = node_to_expr.find(n);
        if (found != node_to_expr.end() && !found->second.as<OpNode>()) {
          lifted_exprs.insert(found->second);
        }
      }
    }
    if (n->new_scope) {
      auto child_scope = std::make_shared<ScopeNode>(s);
      expr_scope.insert({n, child_scope});
    } else {
      expr_scope.insert({n, s});
    }
  }
  CHECK(global_scope_used);
  // Move the locals out instead of copying two hash containers.
  return std::make_pair(std::move(expr_scope), std::move(lifted_exprs));
}
// Fill the per-scope LetLists for A-normal form and return the converted expression.
// A null include set means every compound expression receives its own binding.
Expr Fill::ToANormalForm(const Expr& e, const DependencyGraph& dg, NodeScopeMap* node_scope) {
  Fill filler(dg, node_scope, nullptr);
  Expr converted = filler.VisitExpr(e);
  Scope top = filler.GetScope(e);
  return top->let_list->Get(converted);
}
// Basic block normal form: a compound expression gets a fresh binding only when it is in
// `lifted` (presumably the lifted set computed by CalcScope — confirm at the call site), or when
// a target variable is already supplied (e.g. the binder of a Let). Everything else stays inline.
Expr Fill::ToBasicBlockNormalForm(const Expr& e, const DependencyGraph& dg,
                                  NodeScopeMap* node_scope, ExprSet* lifted) {
  Fill filler(dg, node_scope, lifted);
  Expr converted = filler.VisitExpr(e);
  Scope top = filler.GetScope(e);
  return top->let_list->Get(converted);
}
// Scope that CalcScope assigned to the dependency-graph node of `e`.
Scope Fill::GetScope(const Expr& e) {
  DependencyGraph::Node* node = dg_.expr_node.at(e);
  return node_scope_->at(node);
}
// Scope recorded for the i-th child (0-based) of the dependency-graph node of `e`.
// Fails a CHECK if the node has fewer than i + 1 children.
Scope Fill::GetSubScope(const Expr& e, size_t i) {
  DependencyGraph::Node* node = dg_.expr_node.at(e);
  auto h = node->children.head;
  for (size_t remaining = i; remaining != 0; --remaining) {
    CHECK(h);
    h = h->next;
  }
  CHECK(h);
  return node_scope_->at(h->value);
}
// Memoized dispatch. The first visit of `e` hands `e` and `v` to the node-specific handler and
// caches the result in `memo`. Later visits reuse the cached result; when a `v` is supplied they
// only push the binding `v = cached` into the scope of `e`.
Expr Fill::VisitExpr(const Expr& e, const Var& v) {
  if (memo.count(e) != 0) {
    if (v.defined()) {
      GetScope(e)->let_list->Push(v, memo.at(e));
    }
  } else {
    Expr converted = ExprFunctor<Expr(const Expr&, const Var&)>::VisitExpr(e, v);
    memo.insert({e, converted});
  }
  auto ret = memo.at(e);
  // Without an include set (A-normal form) every visited expression must come back atomic.
  if (include_set_ == nullptr) {
    CHECK(IsAtomic(ret));
  }
  return ret;
}
// Visit `e` without a target variable; an undefined Var means "no binder requested".
Expr Fill::VisitExpr(const Expr& e) {
  Var no_binder;
  return VisitExpr(e, no_binder);
}
// Atomic expressions are returned as-is unless a target `v` is given, in which case `v = e` is
// pushed into the scope of `e` and the variable is returned.
Expr Fill::Atomic(const Expr& e, const Var& v) {
  if (!v.defined()) {
    return e;
  }
  return GetScope(e)->let_list->Push(v, e);
}
// Bind `now` (the rewritten form of `orig`) in the scope of `orig`, or return it unbound:
// - if the caller supplied `v` (e.g. the binder of a Let), `now` is always bound to `v`;
// - otherwise, with no include set (A-normal form), `now` is bound to a fresh variable "x";
// - otherwise (include set present), `now` is bound to a fresh variable only when `orig` is in
//   the include set, and returned directly when it is not.
Expr Fill::Compound(const Expr& orig, const Expr& now, const Var& v) {
  const bool not_included = include_set_ && include_set_->find(orig) == include_set_->end();
  if (!v.defined() && not_included) {
    return now;
  }
  // Create the fresh variable only when a binding is actually emitted.
  Var var = v.defined() ? v : Var(String("x"), Type());
  return GetScope(orig)->let_list->Push(var, now);
}
// Convert a call: arguments first (in order), then the callee, then bind the rebuilt call.
Expr Fill::VisitExpr_(const CallNode* c, const Var& v) {
  Expr call = GetRef<Expr>(c);
  std::vector<Expr> new_args;
  new_args.reserve(c->args.size());
  for (const Expr& arg : c->args) {
    new_args.push_back(VisitExpr(arg));
  }
  Expr new_op = VisitExpr(c->op);
  return Compound(call, Call(new_op, new_args, c->attrs, c->type_args), v);
}
// Convert each field in order, then bind the rebuilt tuple.
Expr Fill::VisitExpr_(const TupleNode* t, const Var& v) {
  Expr tuple_expr = GetRef<Expr>(t);
  std::vector<Expr> new_fields;
  new_fields.reserve(t->fields.size());
  for (const Expr& field : t->fields) {
    new_fields.push_back(VisitExpr(field));
  }
  return Compound(tuple_expr, Tuple(new_fields), v);
}
// Convert the projected tuple, then bind the rebuilt projection.
Expr Fill::VisitExpr_(const TupleGetItemNode* t, const Var& v) {
  Expr item = GetRef<Expr>(t);
  Expr new_tuple = VisitExpr(t->tuple);
  return Compound(item, TupleGetItem(new_tuple, t->index), v);
}
// Convert the initial value, then bind the rebuilt reference creation.
Expr Fill::VisitExpr_(const RefCreateNode* r, const Var& v) {
  Expr create = GetRef<Expr>(r);
  Expr new_value = VisitExpr(r->value);
  return Compound(create, RefCreate(new_value), v);
}
// Convert the reference operand, then bind the rebuilt read.
Expr Fill::VisitExpr_(const RefReadNode* r, const Var& v) {
  Expr read = GetRef<Expr>(r);
  Expr new_ref = VisitExpr(r->ref);
  return Compound(read, RefRead(new_ref), v);
}
// Convert a reference write: the reference operand first, then the value, then bind the rebuilt
// write. The two operands are visited in separate statements on purpose. As arguments of a single
// call, their evaluation order (and hence the order of the bindings they push) would be
// unspecified in C++ and could differ between compilers.
Expr Fill::VisitExpr_(const RefWriteNode* r, const Var& v) {
  Expr e = GetRef<Expr>(r);
  Expr new_ref = VisitExpr(r->ref);
  Expr new_value = VisitExpr(r->value);
  return Compound(e, RefWrite(new_ref, new_value), v);
}
// Convert an if/else. Each branch is converted and then closed with the LetList of its own
// sub-scope (child edges 1 and 2 of the `if` node).
//
// The condition, true branch and false branch are converted in that order, in separate
// statements. As arguments of a single call their evaluation order would be unspecified in C++.
// The condition's bindings could then land after bindings that the branches hoist into the
// enclosing scope, and the output would depend on the compiler.
Expr Fill::VisitExpr_(const IfNode* i, const Var& v) {
  Expr e = GetRef<Expr>(i);
  Expr new_cond = VisitExpr(i->cond);
  Expr new_true = GetSubScope(e, 1)->let_list->Get(VisitExpr(i->true_branch));
  Expr new_false = GetSubScope(e, 2)->let_list->Get(VisitExpr(i->false_branch));
  return Compound(e, If(new_cond, new_true, new_false), v);
}
// Convert a function. Functions carrying a nonzero attr::kPrimitive are kept verbatim. Otherwise
// the body is converted and closed with the LetList of the function's sub-scope (child edge 0),
// and the function is rebuilt with its original params, return type, type params and attrs.
Expr Fill::VisitExpr_(const FunctionNode* f, const Var& v) {
  Expr e = GetRef<Expr>(f);
  if (f->HasNonzeroAttr(attr::kPrimitive)) {
    return Compound(e, e, v);
  }
  Expr new_body = GetSubScope(e, 0)->let_list->Get(VisitExpr(f->body));
  Expr rebuilt = Function(f->params, new_body, f->ret_type, f->type_params, f->attrs);
  return Compound(e, rebuilt, v);
}
// Convert a let. The bound value is visited with the let's own variable as the target, so its
// binding reuses that variable. The body is then converted and closed with the LetList of the
// let's sub-scope (child edge 0).
Expr Fill::VisitExpr_(const LetNode* l, const Var& v) {
  Expr let_expr = GetRef<Expr>(l);
  VisitExpr(l->value, l->var);
  Expr new_body = VisitExpr(l->body);
  Expr closed_body = GetSubScope(let_expr, 0)->let_list->Get(new_body);
  return Compound(let_expr, closed_body, v);
}
// Constants go through Compound (not Atomic), so they are bound like compound expressions.
Expr Fill::VisitExpr_(const ConstantNode* c, const Var& v) {
  Expr constant = GetRef<Expr>(c);
  return Compound(constant, constant, v);
}
// Variables are atomic: bound only when a target `v` is supplied.
Expr Fill::VisitExpr_(const VarNode* vn, const Var& v) { return Atomic(GetRef<Expr>(vn), v); }
// Global variables are atomic: bound only when a target `v` is supplied.
Expr Fill::VisitExpr_(const GlobalVarNode* gvn, const Var& v) {
  return Atomic(GetRef<GlobalVar>(gvn), v);
}
// Operators are atomic: bound only when a target `v` is supplied.
Expr Fill::VisitExpr_(const OpNode* op, const Var& v) { return Atomic(GetRef<Expr>(op), v); }
// ADT constructors are atomic: bound only when a target `v` is supplied.
Expr Fill::VisitExpr_(const ConstructorNode* c, const Var& v) {
  return Atomic(GetRef<Expr>(c), v);
}
// Convert a match. The scrutinee is converted first. Clause k (0-based) has its right-hand side
// converted and closed with the LetList of child edge 1 + k; patterns are kept unchanged.
Expr Fill::VisitExpr_(const MatchNode* m, const Var& v) {
  Expr match_expr = GetRef<Expr>(m);
  Expr new_data = VisitExpr(m->data);
  std::vector<Clause> new_clauses;
  new_clauses.reserve(m->clauses.size());
  size_t child_index = 1;
  for (const Clause& clause : m->clauses) {
    Expr new_rhs = GetSubScope(match_expr, child_index)->let_list->Get(VisitExpr(clause->rhs));
    new_clauses.push_back(Clause(clause->lhs, new_rhs));
    ++child_index;
  }
  return Compound(match_expr, Match(new_data, new_clauses, m->complete), v);
}
// Convert every Relay function of `m` to A-normal form, updating the module in place, and
// return it.
//
// Only Relay `Function`s are converted. Other function kinds stored in the module are left
// untouched instead of being passed to the Relay-only FreeVars / ToANormalForm analyses and
// Downcast<Function>, which cannot handle them. Functions carrying attr::kCompiler are also
// skipped (presumably partitioned for an external codegen — confirm).
IRModule ToANormalForm(const IRModule& m) {
  DLOG(INFO) << "ToANF:" << std::endl << m;
  tvm::Map<GlobalVar, Function> updates;
  auto funcs = m->functions;
  for (const auto& it : funcs) {
    const auto* n = it.second.as<FunctionNode>();
    if (n == nullptr) continue;
    CHECK_EQ(FreeVars(it.second).size(), 0);
    if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
    Expr ret = TransformF([&](const Expr& e) { return transform::ToANormalForm(e); }, it.second);
    CHECK_EQ(FreeVars(ret).size(), 0)
        << AsText(ret) << " should not have free vars: " << FreeVars(ret);
    updates.Set(it.first, Downcast<Function>(ret));
  }
  for (const auto& pair : updates) {
    m->Add(pair.first, pair.second, true);
  }
  DLOG(INFO) << "ToANF: transformed" << std::endl << m;
  return m;
}
namespace transform {
// Convert a single expression to A-normal form.
Expr ToANormalForm(const Expr& e) {
  /* Lifting a lambda also lifts everything inside it, so the scope of a lambda has to be decided
   * before the scope of its body.
   *
   * More generally, the scope of a parent is always determined before the scopes of its
   * children, which is why the dependencies between all nodes are computed first.
   */
  support::Arena arena;
  DependencyGraph dg = DependencyGraph::Create(&arena, e);
  /* To model the new sub-scopes introduced by lambdas, if/else and pattern matching, scopes are
   * assigned to edges as well as to nodes. The scope of an edge is either its parent's scope or
   * a new sub-scope of it.
   *
   * The whole expression lives in the global scope, and every sub-expression lives in the lowest
   * common ancestor of the scopes of all its incoming edges.
   *
   * Every scope also owns a LetList that collects the bindings of that scope. A final pass
   * (Fill) populates those LetLists and produces the result.
   */
  auto scopes = CalcScope(dg);
  NodeScopeMap* node_scope = &scopes.first;
  return Fill::ToANormalForm(e, dg, node_scope);
}
// Module pass wrapping relay::ToANormalForm(IRModule), created with opt_level 1 and no
// required passes. The pass context is not consulted.
Pass ToANormalForm() {
  auto convert_module = [](IRModule mod, PassContext ctx) { return relay::ToANormalForm(mod); };
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = convert_module;
  return CreateModulePass(pass_func, 1, "ToANormalForm", {});
}
// Register the module pass and the single-expression conversion as global packed functions.
TVM_REGISTER_GLOBAL("relay._transform.ToANormalForm").set_body_typed([]() -> Pass {
  return ToANormalForm();
});
TVM_REGISTER_GLOBAL("relay._transform.ToANormalFormExpr")
    .set_body_typed([](const Expr& expr) -> Expr { return ToANormalForm(expr); });
} // namespace transform
} // namespace relay
} // namespace tvm