blob: ec614d23a02e3127e92d96c486599bc8cc10e6bb [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
*
* \file defunctionalization.cc
*
* \brief Defunctionalization for Relay IR
*
* This pass transforms a higher-order program into a first-order program with defunctionalization.
* This means that all higher-order functions (i.e. functions that take function arguments or return
* functions) should be transformed into a semantically equivalent first order one.
*
* This pass implements a basic typed defunctionalization method.
* All higher order functions are cloned and specialized (so that there are no type params).
* Function type arguments are encoded as datatypes and a helper `apply` function is used
* to "call" them.
*
* For example, take the following higher-order program:
* fun map F y = case y of
*     Nil => Nil
*   | Cons(x, xs) => Cons(F x, map F xs)
* fun addone l = map (\x -> x + 1) l
*
* where `addone` is our program.
* When we call the `map` function, we see that it is a higher-order function,
* but we can clone the `map` function and specialize it with the type_params of the call.
* In addition, our function argument `(\x -> x + 1)` will be encoded as a datatype constructor,
* which we will call `incr`, and all calls to `F` in our specialized map function will use the
* helper `apply` function.
*
* After defunctionalization, we get:
* fun apply encoding arg = case encoding of
*     "incr" => incr arg
* fun map' F y = case y of
*     Nil => Nil
*   | Cons(x, xs) => Cons(apply F x, map' F xs)
* fun addone l = map' "incr" l
*
* Currently, defunctionalization makes the following assumptions:
* - functions cannot return function values
* - function arguments take one of two forms: an identifier or a lambda abstraction
* - no functions are stored in datatypes
* - functions are not let-bound
*/
#include <tvm/ir/type_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/feature.h>
#include <tvm/relay/transform.h>
#include <tvm/te/operation.h>
#include "../analysis/type_solver.h"
#include "../transforms/pass_util.h"
namespace tvm {
namespace relay {
// determine if type contains a FuncType
bool HasFuncType(const Type& t) {
struct FuncTypeVisitor : TypeVisitor {
bool has_func_type;
FuncTypeVisitor() : has_func_type(false) {}
void VisitType_(const FuncTypeNode* op) { this->has_func_type = true; }
};
auto visitor = FuncTypeVisitor();
visitor.VisitType(t);
return visitor.has_func_type;
}
/*!
 * \brief Check whether a function type is higher order, i.e. whether any argument
 *        type or the return type contains a function type.
 * \param t The function type to inspect.
 * \return true if \p t takes or returns something containing a FuncType.
 */
bool IsHigherOrderFunc(const FuncType& t) {
  for (const auto& arg_type : t->arg_types) {
    if (HasFuncType(arg_type)) {
      return true;
    }
  }
  return HasFuncType(t->ret_type);
}
/*!
* \brief mutator for driving the Defunctionalization transformation
*/
class DefuncMutator : public ExprMutator {
public:
explicit DefuncMutator(const IRModule& mod) : mod(mod), constructor_counter(0) {}
Expr VisitExpr_(const CallNode* call) {
if (auto op = call->op.as<GlobalVarNode>()) {
CHECK_EQ(call->type_args.size(), op->checked_type().as<FuncTypeNode>()->type_params.size())
<< "all type args must be explicit";
auto op_type = InstFuncType(op->checked_type().as<FuncTypeNode>(), call->type_args);
CHECK_EQ(FreeTypeVars(op_type, mod).size(), 0) << "free type vars in instantiated";
CHECK(!HasFuncType(op_type->ret_type)) << "returning functions not supported";
if (!IsHigherOrderFunc(op_type)) {
// not higher order function
return ExprMutator::VisitExpr_(call);
}
// first we encode function arguments
Array<Expr> args;
for (size_t i = 0; i < call->args.size(); i++) {
auto arg = call->args[i];
auto type = op_type->arg_types[i];
if (!HasFuncType(type)) {
args.push_back(arg);
} else {
args.push_back(EncodeArg(arg, type));
}
}
auto name = op->name_hint + TypeToString(op_type);
auto gv = GlobalVar(name);
if (specialized_gv_map.count(name)) {
gv = specialized_gv_map[name];
} else {
specialized_gv_map[name] = gv;
// clone and specialize with specific type
auto clone = Downcast<Function>(DeDup(mod->Lookup(GetRef<GlobalVar>(op))));
auto specialized_function = Specialize(clone, call->type_args);
// change var types and change all applications to use `apply` method
auto f = Downcast<Function>(FirstifyVars(specialized_function));
mod->Add(gv, f);
}
return Call(gv, args);
} else if (auto op = call->op.as<FunctionNode>()) {
// reduction by applying vars
std::unordered_map<Var, Expr, ObjectHash, ObjectEqual> var_binding_map;
for (size_t i = 0; i < op->params.size(); i++) {
var_binding_map[op->params[i]] = call->args[i];
}
auto e = Bind(op->body, var_binding_map);
return this->VisitExpr(e);
} else if (auto op = call->op.as<VarNode>()) {
// var node will be encoded as datatype
// so we need to use the `apply` helper method
auto var_original_type = GetUnencodedType(op->type_annotation).as<FuncTypeNode>();
CHECK(var_original_type) << "var original type not saved in var_save_type map";
auto op_type = InstFuncType(var_original_type, call->type_args);
Array<Expr> args = {GetRef<Var>(op)};
for (auto arg : call->args) {
args.push_back(this->VisitExpr(arg));
}
return Call(GetApplyFunction(op_type), args);
}
return ExprMutator::VisitExpr_(call);
}
private:
// module
IRModule mod;
// gv + str(type) to specialized clone gv
std::unordered_map<std::string, GlobalVar> specialized_gv_map;
// str(func_type) to ADT
std::unordered_map<std::string, GlobalTypeVar> func_encoding;
// str(func_tyoe) to apply gv
std::unordered_map<std::string, GlobalVar> apply_map;
// encoded ADT handle to FuncType
std::unordered_map<GlobalTypeVar, Type, ObjectHash, StructuralEqual> original_func_type_map;
// gv to (str(func_type) to constructor encoding)
std::unordered_map<GlobalVar, std::unordered_map<std::string, Constructor>, ObjectHash,
ObjectEqual>
gv_datatype_map;
// use monotonically increasing integer to represent new constructor_name
uint64_t constructor_counter;
/*!
* \brief add a constructor to the GlobalTypeVar, creating a new TypeDef if GlobalTypeVar does not
* exist
*/
void AddConstructor(GlobalTypeVar gtv, Constructor c) {
if (!mod->ContainGlobalTypeVar(gtv->name_hint)) {
mod->AddTypeDef(gtv, TypeData(gtv, {}, {c}));
} else {
auto typedata = mod->LookupTypeDef(gtv);
auto constructors = typedata->constructors;
constructors.push_back(c);
mod->UpdateTypeDef(gtv, TypeData(typedata->header, typedata->type_vars, constructors));
}
}
/*!
* \brief add a case to the apply function, creating the function if it does not exist
*
* \param apply_gv GlobalVar of the apply function
* \param ft is the type functions the apply function handles
* \param c constructor to add a case for
* \param expr calls this expr with the args to the apply_gv
* \param patterns PatterVars to match with the constructor, used for handling free vars in
* functions
*/
void AddApplyCase(GlobalVar apply_gv, FuncType ft, Constructor c, const Expr& expr,
const Array<Pattern> patterns) {
CHECK(c->inputs.size() == patterns.size())
<< "constructor function and pattern vars have different sizes";
if (!mod->ContainGlobalVar(apply_gv->name_hint)) {
auto x = Var("x", TypeCall(c->belong_to, {}));
auto vars = Array<Var>({x});
auto args = Array<Expr>();
for (auto t : ft->arg_types) {
auto y = Var("y", t);
vars.push_back(y);
args.push_back(y);
}
auto clauses = Array<Clause>({Clause(PatternConstructor(c, patterns), Call(expr, args))});
auto body = Match(x, clauses);
auto f = Function(vars, body, ft->ret_type, {});
mod->Add(apply_gv, f);
} else {
auto f = Downcast<Function>(mod->Lookup(apply_gv));
auto body = f->body.as<MatchNode>();
CHECK(body) << "internal invariant broken; apply function body should be a match node";
auto clauses = body->clauses;
auto x = f->params[0];
auto args = Array<Expr>();
for (size_t i = 1; i < f->params.size(); i++) {
args.push_back(f->params[i]);
}
clauses.push_back(Clause(PatternConstructor(c, patterns), Call(expr, args)));
mod->Add(apply_gv, Function(f->params, Match(x, clauses), f->ret_type, f->type_params), true);
}
}
Expr EncodeArg(const Expr& arg, const Type& type) {
// we assume arg is either an identifier (var or globalvar) or a function
CHECK(type.as<FuncTypeNode>()) << "assume no nested functions";
CHECK(arg.as<VarNode>() || arg.as<GlobalVarNode>() || arg.as<FunctionNode>())
<< "assume all first-order-parameters are identifiers or functions";
if (arg.as<VarNode>()) {
// variable with functype will be encoded as datatype in surrounding function
return arg;
} else if (arg.as<GlobalVarNode>()) {
return EncodeGlobalVar(Downcast<GlobalVar>(arg), Downcast<FuncType>(type));
} else if (auto fn = arg.as<FunctionNode>()) {
// we handle free vars in anonymous functions by adding arguments to
// the constructor function
auto free_vars = FreeVars(arg);
auto ft = Downcast<FuncType>(type);
auto arg_types = Array<Type>();
auto pattern_vars = Array<Pattern>();
auto call_args = Array<Expr>();
Map<Var, Expr> free_var_bind_map;
for (auto free_var : free_vars) {
// free vars are already encoded, can only exist within
// specialized functions
if (free_var->type_annotation.defined()) {
arg_types.push_back(free_var->type_annotation);
} else {
arg_types.push_back(free_var->checked_type());
}
auto new_var = Var(free_var->name_hint(), free_var->type_annotation);
free_var_bind_map.Set(free_var, new_var);
pattern_vars.push_back(PatternVar(new_var));
call_args.push_back(free_var);
}
auto gtv = GetFuncEncode(ft);
auto c = Constructor(std::to_string(++constructor_counter), arg_types, gtv);
AddConstructor(gtv, c);
auto apply_gv = GetApplyFunction(ft);
auto body = this->VisitExpr(Bind(fn->body, free_var_bind_map));
AddApplyCase(apply_gv, ft, c, Function(fn->params, body, fn->ret_type, fn->type_params),
pattern_vars);
return Call(c, call_args);
}
throw std::runtime_error("EncodeArg failed to cast arg into identifier node or function node");
}
/*!
* \brief encode a global var with a specialized type with a datatype
*/
Expr EncodeGlobalVar(const GlobalVar& gv, const FuncType& ft) {
auto map = gv_datatype_map[gv];
auto type_key = TypeToString(ft);
if (map.count(type_key) == 0) {
auto gtv = GetFuncEncode(ft);
auto c = Constructor(std::to_string(constructor_counter++), {}, gtv);
map[type_key] = c;
AddConstructor(gtv, c);
AddApplyCase(GetApplyFunction(ft), ft, c, gv, {});
}
return Call(map[type_key], {});
}
/*!
* \brief type to string
*/
std::string TypeToString(const Type& t) {
std::ostringstream s;
s << t;
return s.str();
}
/*!
* \brief get ADT handle for encoding type t
*/
GlobalTypeVar GetFuncEncode(const Type& t) {
auto adt_name = "Defunc" + TypeToString(t);
if (func_encoding.count(adt_name) == 0) {
func_encoding[adt_name] = GlobalTypeVar(adt_name, TypeKind::kAdtHandle);
}
original_func_type_map[func_encoding[adt_name]] = t;
return func_encoding[adt_name];
}
/*!
* \brief get original function type represented by type t
*/
FuncType GetUnencodedType(const Type& t) {
auto tc = t.as<TypeCallNode>();
CHECK(tc) << "expected type call when getting original type from encoded type";
auto gv = tc->func.as<GlobalTypeVarNode>();
CHECK(gv) << "expected global type var in encoded type";
auto type = original_func_type_map[GetRef<GlobalTypeVar>(gv)];
CHECK(type.defined()) << "reverse mapping from encoded type to original type not found";
return Downcast<FuncType>(type);
}
/*!
* \brief get the apply function for calling datatypes encoding functions of type t
*/
GlobalVar GetApplyFunction(const Type& t) {
auto f_name = "apply" + TypeToString(t);
if (apply_map.count(f_name) == 0) {
apply_map[f_name] = GlobalVar("apply" + TypeToString(t));
}
return apply_map[f_name];
}
/*!
* \brief specialize a function type
*/
FuncType InstFuncType(const FuncTypeNode* fty, const Array<Type> type_args) {
CHECK(fty) << "InstFuncType functype is null";
CHECK_EQ(fty->type_params.size(), type_args.size())
<< "size mismatch between function type params and type args";
auto map = tvm::Map<TypeVar, Type>();
for (size_t i = 0; i < type_args.size(); i++) {
map.Set(fty->type_params[i], type_args[i]);
}
// copy with typevars removed
return Downcast<FuncType>(TypeSubst(FuncType(fty->arg_types, fty->ret_type, {}, {}), map));
}
/*!
* \brief specialize a function expression
*/
Function Specialize(const Function& f, const Array<Type> type_args) {
CHECK_EQ(f->type_params.size(), type_args.size())
<< "cannot specialize function with size mismatch between function type params and type "
"args";
auto map = tvm::Map<TypeVar, Type>();
for (size_t i = 0; i < type_args.size(); i++) {
map.Set(f->type_params[i], type_args[i]);
}
// copy with typevars removed
auto copy = TypeSubst(Function(f->params, f->body, f->ret_type, {}), map);
return Downcast<Function>(copy);
}
/*!
* \brief transform a function to be first order by transforming arg_types and
* using the `apply` function for applications
*/
Function FirstifyVars(const Function& f) {
CHECK(f->type_params.size() == 0) << "firstify function has type params";
tvm::Map<Var, Expr> var_bind_map;
Array<Var> params;
for (auto var : f->params) {
if (auto var_type = var->type_annotation.as<FuncTypeNode>()) {
// first order parameter
auto fop_type = GetRef<FuncType>(var_type);
auto adt = GetFuncEncode(fop_type);
auto new_var = Var(var->name_hint(), TypeCall(adt, {}));
mod->LookupTypeDef(adt);
var_bind_map.Set(var, new_var);
params.push_back(new_var);
} else {
CHECK(!HasFuncType(var->type_annotation))
<< "nested function type in parameter not supported yet";
params.push_back(var);
}
}
auto bind = Downcast<Function>(Bind(f, var_bind_map));
return Function(params, this->VisitExpr(bind->body), bind->ret_type, {});
}
};
/*!
 * \brief Defunctionalize the program whose entry point is \p f.
 *
 * \param f The entry function. It must have no type params, no parameter whose checked type
 *          contains a function type, and a return type without function types.
 * \param mod The module holding the globals \p f calls; DefuncMutator adds specialized clones,
 *            encoding ADTs and `apply` helpers to it.
 * \return The transformed entry function.
 */
Expr Defunctionalization(const Function& f, const IRModule& mod) {
  // f is the starting point of the program, so all of its types MUST be known
  CHECK(f->type_params.empty()) << "no polymorphism supported for defunctionalization";
  for (const auto& param : f->params) {
    CHECK(!HasFuncType(param->checked_type())) << "program cannot have func type parameters";
  }
  CHECK(!HasFuncType(f->ret_type)) << "return type cannot contain function";
  DefuncMutator mutator(mod);
  Expr transformed = mutator.VisitExpr(f);
  return Downcast<Function>(transformed);
}

TVM_REGISTER_GLOBAL("relay._transform.Defunctionalization").set_body_typed(Defunctionalization);
} // namespace relay
} // namespace tvm