/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file gradient.cc
* \brief API for Automatic Differentiation for the Relay IR.
*/
#include <tvm/ir/type_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/feature.h>
#include <tvm/relay/transform.h>
#include <tvm/te/operation.h>
#include "let_list.h"
#include "pass_util.h"
#include "pattern_util.h"
namespace tvm {
namespace relay {
using namespace tvm::runtime;
/*! What is automatic differentiation (AD) and why is it important?
* By AD, we roughly mean: given a term which denotes some mathematical function,
* derive a term which denotes the derivative of that mathematical function.
* Such a method can be compile-time, i.e. a macro on a completely known function.
* Formally speaking, this requirement means that the input function is a closed expression -
* that is, it only refers to local variables that are its parameters, or defined inside it.
* Every top-level definition satisfies this criterion.
* AD can also be run-time, meaning it is merely a function term of type
* AD : (Float[] -> Float[]) -> (Float[] -> Float[]).
* In Relay we currently only support compile-time AD, but it should be enough for a lot of
* use cases.
*
* In deep learning, the most common way to train a deep neural network is by gradient descent or
* some of its variants. Such optimization methods require us to compute the gradient of the neural
* network, which can be obtained easily using AD. In fact, back propagation is essentially
* reverse-mode automatic differentiation, a kind of AD!
*/
/*! In Relay, automatic differentiation (AD) is a macro
* that transforms a closed expr (an expr without free variables/free type variables) of type
* (x0, x1, x2, ...) -> Float[] to
* (x0, x1, x2, ...) -> (Float[], (x0, x1, x2, ...)),
* where x0, x1, x2, ... are Floats of possibly different shapes.
* The return value is a pair, with the original value on the left hand side and the
* gradient of the input on the right hand side. WithGradientType takes the type of the input
* and produces the type of the output. There are multiple implementations of AD in Relay, with
* different characteristics. However, they all transform the input expr according to
* WithGradientType.
*/
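/* An illustrative example (a sketch, not used by the pass itself; the shapes are made up):
* a function of type
*   fn (Tensor[(3, 4), float32], Tensor[(4,), float32]) -> Tensor[(3,), float32]
* is assigned, by WithGradientType, the type
*   fn (Tensor[(3, 4), float32], Tensor[(4,), float32])
*     -> (Tensor[(3,), float32], (Tensor[(3, 4), float32], Tensor[(4,), float32])),
* i.e. the original result paired with one gradient per parameter, each gradient having
* the same type as the corresponding parameter.
*/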
Type WithGradientType(const Type&);
/*! \brief Return an expression that represents the differentiation of e (according to
* WithGradientType).
* This version only works on first order code without control flow.
*/
Expr FirstOrderGradient(const Expr& e, const Optional<IRModule>& mod);
Type WithGradientType(const Type& t) {
// TODO(@M.K.): stricter checking
auto ty = t.as<FuncTypeNode>();
CHECK(ty) << "input should be a function";
return FuncType(ty->arg_types, TupleType({ty->ret_type, TupleType(ty->arg_types)}), {}, {});
}
//! \brief If the expression is a GlobalVar, transform it to the expression it refers to.
Expr DeGlobal(const Optional<IRModule>& mod, const Expr& e) {
const auto* x = e.as<GlobalVarNode>();
if (mod.defined() && x) {
BaseFunc base_func = mod.value()->Lookup(GetRef<GlobalVar>(x));
if (auto* n = base_func.as<FunctionNode>()) {
return GetRef<Function>(n);
} else {
return e;
}
} else {
return e;
}
}
/*! \brief A fragment of the program being built by the automatic differentiation
* pass.
*/
struct ADValueNode {
virtual ~ADValueNode() {}
template <typename T>
T& get() {
auto ret = dynamic_cast<T*>(this);
CHECK(ret) << "cannot downcast";
return *ret;
}
};
template <typename F>
Expr MultiFactory(const Type& t, F factory) {
if (auto* tt = t.as<TensorTypeNode>()) {
return factory(tt->shape, tt->dtype);
} else if (auto* tt = t.as<TupleTypeNode>()) {
std::vector<Expr> res;
for (size_t i = 0; i < tt->fields.size(); i++) {
res.push_back(MultiFactory(tt->fields[i], factory));
}
return Tuple(res);
} else {
LOG(FATAL) << "unsupported type to create tensors of: " << tt;
throw;
}
}
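// For example (illustrative only): applying MultiFactory to the type
//   (Tensor[(2,), float32], (Tensor[(3,), float32],))
// with the Zeros factory yields the expression
//   (zeros((2,), float32), (zeros((3,), float32),)),
// mirroring the nesting of the tuple type. MultiFactoryLike below does the same,
// except that for a plain tensor it uses the *_like variant on the given expression.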
template <typename F, typename F2>
Expr MultiFactoryLike(const Expr& e, const Type& t, F factory, F2 factory_like) {
if (t.as<TensorTypeNode>()) {
return factory_like(e);
} else if (t.as<TupleTypeNode>()) {
return MultiFactory(t, factory);
} else {
LOG(FATAL) << "unsupported type to create tensors of: " << t;
throw;
}
}
using ADValue = std::shared_ptr<ADValueNode>;
/*! \brief The AD value for a fragment of the program which generates a tensor output. */
struct ADTensor : ADValueNode {
Expr forward;
mutable Expr reverse; // must be a variable to avoid duplication
ADTensor(LetList* ll, const Expr& forward)
: forward(ll->Push(forward)),
reverse(
ll->Push(MultiFactoryLike(this->forward, forward->checked_type(), Zeros, ZerosLike))) {
this->forward->checked_type_ = forward->checked_type();
}
};
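// Note on ADTensor: `forward` is the let-bound primal value, and `reverse` holds the
// accumulated adjoint. It is initialized to zeros of the same structure as `forward`
// and is re-bound (via Add plus a fresh let binding) by the backprop actions below.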
/*! \brief A staged representation of the program: we reflect
* Relay functions into functions over fragments of AD. We
* can compute away this function to obtain a reverse mode program.
*/
struct ADFunction : ADValueNode {
std::function<ADValue(const Type&, const std::vector<ADValue>&, const Attrs&,
const tvm::Array<Type>&)>
func;
explicit ADFunction(const std::function<ADValue(const Type&, const std::vector<ADValue>&,
const Attrs&, const tvm::Array<Type>&)>& func)
: func(func) {}
};
struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
using TBase = ExprFunctor<ADValue(const Expr&)>;
const OpAttrMap<FPrimalGradient> rev_map = Op::GetAttrMap<FPrimalGradient>("FPrimalGradient");
std::vector<std::function<void(LetList* ll)>> backprop_actions;
// we assume no closure so no need for lexical scoping
std::unordered_map<Expr, ADValue, ObjectPtrHash, ObjectPtrEqual> env;
LetList* ll;
FirstOrderReverseAD(LetList* ll) : ll(ll) {}
ADValue VisitExpr(const Expr& n) final {
if (env.count(n)) {
return env.at(n);
}
auto ret = TBase::VisitExpr(n);
env[n] = ret;
return ret;
}
ADValue VisitExpr_(const OpNode* op) final {
Op op_ref = GetRef<Op>(op);
CHECK(rev_map.count(op_ref)) << op->name << " does not have reverse mode defined";
return std::make_shared<ADFunction>(
[this, op_ref](const Type& orig_type, const std::vector<ADValue>& args, const Attrs& attrs,
const tvm::Array<Type>& type_args) {
std::vector<Expr> call_args;
for (const ADValue& adval : args) {
call_args.push_back(adval->get<ADTensor>().forward);
}
auto orig = Call(op_ref, call_args, attrs, type_args);
orig->checked_type_ = orig_type;
auto ret = std::make_shared<ADTensor>(ll, orig);
backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) {
tvm::Array<Expr> rev = rev_map[op_ref](orig, ret->reverse);
CHECK(args.size() == rev.size());
for (size_t i = 0; i < args.size(); ++i) {
args[i]->get<ADTensor>().reverse =
ll->Push(Add(args[i]->get<ADTensor>().reverse, rev[i]));
}
});
return ret;
});
}
ADValue VisitExpr_(const TupleGetItemNode* op) final {
Expr e = GetRef<Expr>(op);
ADValue tup = VisitExpr(op->tuple);
auto tt = op->tuple->checked_type().as<TupleTypeNode>();
size_t size = tt->fields.size();
size_t idx = op->index;
auto ret = std::make_shared<ADTensor>(ll, e);
backprop_actions.push_back([tup, idx, size, ret](LetList* ll) {
auto rev = tup->get<ADTensor>().reverse;
// special-case Tuple, to avoid long chains of GetItem/Tuple,
// but we might have functions using tuples, so we don't know
// that the reverse node is always a tuple
std::vector<Expr> grfields;
if (auto tup_node = rev.as<TupleNode>()) {
for (size_t i = 0; i < size; ++i) {
grfields.push_back(i != idx ? tup_node->fields[i]
: Add(tup_node->fields[i], ret->reverse));
}
} else {
for (size_t i = 0; i < size; ++i) {
grfields.push_back(i != idx ? TupleGetItem(rev, i)
: Add(TupleGetItem(rev, i), ret->reverse));
}
}
tup->get<ADTensor>().reverse = ll->Push(Tuple(grfields));
});
return ret;
}
ADValue VisitExpr_(const TupleNode* op) final {
Expr e = GetRef<Expr>(op);
std::vector<ADValue> fields;
for (const auto& f : op->fields) {
fields.push_back(VisitExpr(f));
}
auto ret = std::make_shared<ADTensor>(ll, e);
backprop_actions.push_back([fields, ret](LetList* ll) {
for (size_t i = 0; i < fields.size(); ++i) {
fields[i]->get<ADTensor>().reverse =
ll->Push(Add(fields[i]->get<ADTensor>().reverse, TupleGetItem(ret->reverse, i)));
}
});
return ret;
}
ADValue VisitExpr_(const ConstantNode* op) final {
Expr e = GetRef<Expr>(op);
return std::make_shared<ADTensor>(ll, e);
}
ADValue VisitExpr_(const CallNode* op) final {
ADValue f = VisitExpr(op->op);
std::vector<ADValue> args;
for (const auto& arg : op->args) {
args.push_back(VisitExpr(arg));
}
return f->get<ADFunction>().func(op->checked_type(), args, op->attrs, op->type_args);
}
ADValue VisitExpr_(const FunctionNode* op) final {
Function f = GetRef<Function>(op);
// todo: assert no closure
return std::make_shared<ADFunction>(
[this, f](const Type& orig_type, const std::vector<ADValue>& args, const Attrs& attrs,
const tvm::Array<Type>& type_args) {
CHECK_EQ(f->params.size(), args.size());
for (size_t i = 0; i < f->params.size(); ++i) {
env[f->params[i]] = args[i];
}
return VisitExpr(f->body);
});
}
// Var will always be in env, handled in VisitExpr (without _), so we don't need
// to implement its VisitExpr_.
};
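/* How FirstOrderReverseAD is used (a sketch): the forward walk over the expression
* builds the primal program as a sequence of lets and, for every call, pushes a
* closure onto `backprop_actions`. FirstOrderGradient below then seeds the output's
* `reverse` field with ones and replays `backprop_actions` in reverse order, so each
* closure uses the op's registered FPrimalGradient to add its contribution to the
* arguments' `reverse` fields. For example, for a single call such as multiply(%x, %y),
* the recorded action would (roughly) add the adjoints returned by multiply's
* FPrimalGradient to the `reverse` fields of %x and %y.
*/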
Type GradRetType(const Function& f) {
// if type annotations are provided, we will construct a ret type;
// otherwise, leave it to be inferred
if (!f->ret_type.defined()) {
return Type();
}
std::vector<Type> vt;
for (const auto& p : f->params) {
if (!p->type_annotation.defined()) {
return Type();
}
vt.push_back(p->type_annotation);
}
return TupleType({f->ret_type, TupleType(vt)});
}
Expr FirstOrderGradient(const Expr& re, const Optional<IRModule>& mod) {
// Currently we first remove any global functions for the first
// order case.
auto e = DeGlobal(mod, re);
auto f = e.as<FunctionNode>();
CHECK(f) << "FOWithGradient expects its argument to be a function: " << f;
CHECK(f->type_params.size() == 0) << "no polymorphism supported for now";
// We will then build a sequence of lets which implement reverse mode.
Expr body = LetList::With([&](LetList* ll) {
FirstOrderReverseAD reverse_ad(ll);
ADValue rev = reverse_ad(e);
std::vector<ADValue> args;
for (const auto& p : f->params) {
args.push_back(std::make_shared<ADTensor>(ll, p));
}
auto c = rev->get<ADFunction>().func(f->checked_type(), args, Attrs(), {});
const auto& res = c->get<ADTensor>();
Expr grad = LetList::With([&](LetList* ll) {
res.reverse = MultiFactoryLike(res.forward, res.forward->checked_type(), Ones, OnesLike);
for (auto it = reverse_ad.backprop_actions.rbegin(); it != reverse_ad.backprop_actions.rend();
++it) {
(*it)(ll);
}
std::vector<Expr> grad_res;
for (const auto& a : args) {
grad_res.push_back(a->get<ADTensor>().reverse);
}
return Tuple(grad_res);
});
return Pair(res.forward, grad);
});
return Function(f->params, body, GradRetType(GetRef<Function>(f)), {});
}
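/* Sketch of the result (illustrative, in pseudo Relay; names such as %v and %dx are
* purely illustrative): for an input function
*   fn (%x: Tensor[(3,), float32]) { multiply(%x, %x) }
* FirstOrderGradient produces, roughly,
*   fn (%x) {
*     let %v = multiply(%x, %x);
*     ... reverse pass: seed the adjoint of %v with ones and accumulate into %dx ...
*     (%v, (%dx,))
*   }
* i.e. a pair of the original result and a tuple holding one gradient per parameter,
* matching GradRetType above.
*/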
TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient").set_body_typed(FirstOrderGradient);
static Type bpt = RelayRefType(FuncType({}, TupleType(Array<Type>()), {}, {}));
struct ReverseADType : TypeMutator {
Type VisitType_(const TensorTypeNode* ttn) final {
Type t = GetRef<Type>(ttn);
return TupleType({t, RelayRefType(t)});
}
Type VisitType_(const FuncTypeNode* ftn) final {
std::vector<Type> arg_types;
for (const auto& t : ftn->arg_types) {
arg_types.push_back(VisitType(t));
}
arg_types.push_back(bpt);
return FuncType(arg_types, ftn->ret_type, ftn->type_params, ftn->type_constraints);
}
};
Type ReverseType(const Type& t) { return ReverseADType()(t); }
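// For example (illustrative only): ReverseType maps a tensor type T to the pair type
//   (T, Ref(T)),
// pairing each value with a mutable reference that accumulates its gradient, and it
// appends an extra backpropagator argument (of type `bpt`) to every function type.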
/*! \brief Lift a function that transforms Tensors to a function that also transforms more types
* by doing a structure-preserving map.
*/
Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
const std::function<Type(const Type&)>& tf, const Type& forward_type, const Expr& e,
LetList* ll) {
CHECK(IsAtomic(e)) << e;
if (forward_type.as<TensorTypeNode>()) {
auto ret = ll->Push(f(e));
ret->checked_type_ = tf(forward_type);
return std::move(ret);
} else if (auto* tt = forward_type.as<TupleTypeNode>()) {
tvm::Array<Expr> fields;
tvm::Array<Type> types;
for (size_t i = 0; i < tt->fields.size(); ++i) {
auto field = LiftTensor(f, tf, tt->fields[i], ll->Push(GetField(e, i)), ll);
fields.push_back(field);
types.push_back(field->checked_type_);
}
auto ret = ll->Push(Tuple(fields));
ret->checked_type_ = TupleType(types);
return std::move(ret);
} else {
LOG(FATAL) << "unsupported input/output type: " << tt;
throw;
}
}
/*! \brief Transfers the gradients from an Expr to a deep duplication of the Expr,
* by stitching the references in the AD values.
*/
void TransferGrads(const Type& forward_type, const Expr& from, const Expr& to, LetList* ll) {
CHECK(IsAtomic(from)) << from;
CHECK(IsAtomic(to)) << to;
if (forward_type.as<TensorTypeNode>()) {
auto from_ref = TupleGetItem(from, 1);
auto to_ref = TupleGetItem(to, 1);
ll->Push(RefWrite(to_ref, RefRead(from_ref)));
} else if (auto* tt = forward_type.as<TupleTypeNode>()) {
for (size_t i = 0; i < tt->fields.size(); ++i) {
TransferGrads(tt->fields[i], ll->Push(TupleGetItem(from, i)), ll->Push(TupleGetItem(to, i)),
ll);
}
} else {
LOG(FATAL) << "Unsupported input/output type: " << forward_type;
throw;
}
}
// TODO(@M.K.): why take Expr?
/*! \brief t -> ReverseType(t). Transform to Reverse Mode Value. */
Expr GetRev(const Type& forward_type, const Expr& e, LetList* ll) {
auto rev = [&](const Expr& e) { return Pair(e, RefCreate(ZerosLike(e))); };
auto rev_type = [&](const Type& forward_type) { return ReverseType(forward_type); };
return LiftTensor(rev, rev_type, forward_type, e, ll);
}
/*! \brief ReverseType(t) -> t. Get the original value. */
Expr GetValue(const Type& forward_type, const Expr& e, LetList* ll) {
auto val = [&](const Expr& e) { return GetField(e, 0); };
auto val_type = [&](const Type& forward_type) { return forward_type; };
return LiftTensor(val, val_type, forward_type, e, ll);
}
/*! \brief ReverseType(t) -> t. Get the gradient. */
Expr GetGrad(const Type& forward_type, const Expr& e, LetList* ll) {
auto grad = [&](const Expr& e) { return RefRead(GetField(e, 1)); };
auto grad_type = [&](const Type& forward_type) { return forward_type; };
return LiftTensor(grad, grad_type, forward_type, e, ll);
}
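/*! \brief Accumulate `grad` into the gradient references stored inside `arg`
* (directly for tensors, recursively for tuples).
*/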
void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
if (t.as<TensorTypeNode>()) {
ll->Push(RefWrite(GetField(arg, 1), Add(RefRead(GetField(arg, 1)), grad)));
} else if (auto* tt = t.as<TupleTypeNode>()) {
for (size_t i = 0; i < tt->fields.size(); ++i) {
UpdateGrad(tt->fields[i], ll->Push(GetField(arg, i)), ll->Push(GetField(grad, i)), ll);
}
} else {
LOG(FATAL) << "unsupported arg type of operator: " << t;
throw;
}
}
Expr BPEmpty() {
Expr unitF = Function({}, Tuple(tvm::Array<Expr>({})), TupleType::Empty(), {});
return RefCreate(unitF);
}
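// A note on backpropagator threading (as implemented in ReverseAD below): `bp` starts
// out as BPEmpty(), a reference to a do-nothing closure. Each visited call replaces the
// content of `bp` with a new closure that first applies the op's FPrimalGradient to
// update the argument gradient references and then invokes the previous closure, so
// that calling RefRead(bp)() at the end runs all updates in reverse order.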
struct ReverseAD : ExprMutator {
using ADVarMap = std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>;
using ADGlobalVarMap = std::unordered_map<GlobalVar, GlobalVar, ObjectPtrHash, ObjectPtrEqual>;
Optional<IRModule> mod;
// TODO(@M.K.) refactor AD to always use mod.
Var bp;
std::shared_ptr<ADVarMap> ad_vars;
std::shared_ptr<ADGlobalVarMap> ad_gvars;
const OpAttrMap<FPrimalGradient> rev_map = Op::GetAttrMap<FPrimalGradient>("FPrimalGradient");
explicit ReverseAD(const Optional<IRModule>& mod, const Var& bp,
const std::shared_ptr<ADVarMap>& ad_vars,
const std::shared_ptr<ADGlobalVarMap>& ad_gvars)
: mod(mod), bp(bp), ad_vars(ad_vars), ad_gvars(ad_gvars) {}
Expr VisitExpr_(const OpNode* op) final {
LOG(FATAL) << "op should only be inside call";
throw;
}
Expr Remap(const Expr& e) {
struct Remapper : ExprMutator {
std::shared_ptr<ADVarMap> ad_vars;
LetList* ll;
Remapper(const std::shared_ptr<ADVarMap>& ad_vars, LetList* ll) : ad_vars(ad_vars), ll(ll) {}
Expr VisitExpr_(const VarNode* var) final {
// memoize Var -> ADVar so we don't end up with free Vars when checkpointing
auto var_ref = GetRef<Var>(var);
if (ad_vars->count(var_ref) == 0) {
return std::move(var_ref);
} else {
return GetValue(var_ref->checked_type(), ad_vars->at(var_ref), ll);
}
}
};
return LetList::With([&](LetList* ll) { return Remapper(ad_vars, ll)(e); });
}
Expr VisitCheckpoint(const CallNode* call) {
const OpNode* op_node = call->op.as<OpNode>();
CHECK(op_node) << "expected op in call";
Op op_ref = GetRef<Op>(op_node);
CHECK(op_ref->name == "annotation.checkpoint") << "expected checkpoint annotation";
auto x = call->args[0];
return LetList::With([&](LetList* ll) {
auto x_var = ll->Push(Remap(x));
auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll));
auto bpv = ll->Push(RefRead(bp));
Expr nbp = Function({}, LetList::With([&](LetList* ll) {
// we need a new ReverseAD visitor to avoid clobbering the bp local var
auto dup_bp = ll->Push(BPEmpty());
auto dup_ad =
ll->Push(ReverseAD(mod, dup_bp, ad_vars, ad_gvars)(DeDup(x)));
TransferGrads(call->checked_type(), ret, dup_ad, ll);
ll->Push(Call(RefRead(dup_bp), {}));
return Call(bpv, {});
}),
TupleType::Empty(), {});
ll->Push(RefWrite(bp, nbp));
return ret;
});
}
Expr VisitExpr_(const CallNode* call) final {
if (const OpNode* op_node = call->op.as<OpNode>()) {
Op op_ref = GetRef<Op>(op_node);
if (op_ref->name == "annotation.checkpoint") {
return VisitCheckpoint(call);
}
CHECK(rev_map.count(op_ref)) << op_node->name << " does not have reverse mode defined";
return LetList::With([&](LetList* ll) {
std::vector<Var> args;
for (const auto& arg : call->args) {
args.push_back(ll->Push(VisitExpr(arg)));
}
std::vector<Expr> orig_args;
for (size_t i = 0; i < args.size(); i++) {
orig_args.push_back(GetValue(call->args[i]->checked_type(), args[i], ll));
}
Expr orig = Call(call->op, orig_args, call->attrs, call->type_args);
orig->checked_type_ = call->checked_type();
Var orig_var = ll->Push(orig);
orig_var->checked_type_ = call->checked_type();
auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll));
auto bpv = ll->Push(RefRead(bp));
Expr nbp_body = LetList::With([&](LetList* ll) {
tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll));
CHECK(args.size() == rev.size());
for (size_t i = 0; i < args.size(); ++i) {
UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll);
}
return Call(bpv, {});
});
Expr nbp = Function({}, nbp_body, TupleType::Empty(), {});
ll->Push(RefWrite(bp, transform::ToANormalForm(nbp)));
// TODO(@M.K.): ToANF should be called on rev. Enhance ToANF for that.
return ret;
});
} else if (call->op.as<ConstructorNode>()) {
return ExprMutator::VisitExpr_(call);
} else {
std::vector<Expr> args;
for (const auto& arg : call->args) {
args.push_back(VisitExpr(arg));
}
args.push_back(bp);
return Call(VisitExpr(call->op), args);
}
}
Expr VisitExpr_(const ConstantNode* op) final {
return LetList::With([&](LetList* ll) {
Expr e = ll->Push(GetRef<Expr>(op));
return Pair(e, RefCreate(ZerosLike(e)));
});
}
Expr VisitExpr_(const IfNode* op) final {
return If(TupleGetItem(VisitExpr(op->cond), 0), VisitExpr(op->true_branch),
VisitExpr(op->false_branch));
}
Expr VisitExpr_(const VarNode* var) final {
// memoize Var -> ADVar so we don't end up with free Vars when checkpointing
auto var_ref = GetRef<Var>(var);
if (ad_vars->count(var_ref) == 0) {
auto res = Downcast<Var>(ExprMutator::VisitExpr_(var));
(*ad_vars)[var_ref] = res;
}
return ad_vars->at(var_ref);
}
Expr VisitExpr_(const GlobalVarNode* op) final {
// todo: concatenating string to add attribute seems like a brittle hack.
// maybe get module indexed by a rose tree of string?
CHECK(mod.defined());
auto orig_gv = GetRef<GlobalVar>(op);
if (ad_gvars->count(orig_gv) == 0) {
GlobalVar gv(op->name_hint + "_grad");
(*ad_gvars)[orig_gv] = gv;
Function orig_f = Downcast<Function>(DeDup(mod.value()->Lookup(orig_gv)));
std::vector<Var> params;
for (const auto& p : orig_f->params) {
params.push_back(Downcast<Var>(VisitExpr(p)));
}
params.push_back(bp);
Expr body = VisitExpr(orig_f->body);
Function f(params, body, VisitType(orig_f->ret_type), orig_f->type_params, orig_f->attrs);
std::cout << "gv " << op->name_hint << ": " << AsText(f, false) << std::endl;
mod.value()->Add(gv, f);
}
return ad_gvars->at(orig_gv);
}
Expr VisitExpr_(const FunctionNode* op) final {
std::vector<Var> params;
for (const auto& var : op->params) {
params.push_back(Downcast<Var>(VisitExpr(var)));
}
auto new_bp = Var("bp", bpt);
params.push_back(new_bp);
return Function(params, ReverseAD(mod, new_bp, ad_vars, ad_gvars)(op->body),
VisitType(op->ret_type), op->type_params, op->attrs);
}
Type VisitType(const Type& t) final { return t.defined() ? ReverseType(t) : t; }
};
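/* Sketch of what ReverseAD produces for a single primitive call (illustrative):
* an expression such as multiply(%a, %b), where %a and %b are already in reverse
* form (value, Ref grad) pairs, becomes roughly
*   let %v = multiply(%a.0, %b.0);
*   let %r = (%v, ref(zeros_like(%v)));
*   bp := fn () { update %a's and %b's grad refs using multiply's FPrimalGradient
*                 applied to %r's grad; call the old bp };
*   %r
* The field/assignment notation here is informal shorthand for the TupleGetItem,
* RefRead and RefWrite code emitted above.
*/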
bool MissingGrad(const Expr& e) {
struct MGVisitor : ExprVisitor {
const OpAttrMap<FPrimalGradient> rev_map = Op::GetAttrMap<FPrimalGradient>("FPrimalGradient");
std::unordered_set<std::string> op_names;
void VisitExpr_(const OpNode* op) final {
Op op_ref = GetRef<Op>(op);
if (op_ref->name != "annotation.checkpoint" && !rev_map.count(op_ref)) {
op_names.insert(op_ref->name);
}
ExprVisitor::VisitExpr_(op);
}
};
MGVisitor mg;
mg.VisitExpr(e);
if (mg.op_names.size() > 0) {
LOG(WARNING) << "found operators with missing gradients:";
for (const auto& op : mg.op_names) {
LOG(WARNING) << " " << op;
}
return true;
}
return false;
}
Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
CheckFeature(re, FeatureSet::All() - fGraph);
if (mod.defined()) {
CheckFeature(mod.value(), FeatureSet::All() - fGraph);
}
auto e = DeGlobal(mod, re);
auto f = e.as<FunctionNode>();
CHECK(f) << "input need to be a function";
CHECK(f->type_params.size() == 0) << "no polymorphism supported for now";
for (const auto& p : f->params) {
CHECK(p->checked_type().as<TensorTypeNode>()) << "input parameters need to be tensor";
}
CHECK(!MissingGrad(e)) << "input has operators with missing gradients";
Expr body = LetList::With([&](LetList* ll) {
Var bp = ll->Push(BPEmpty(), bpt);
Expr rev = ReverseAD(mod, bp, std::make_shared<ReverseAD::ADVarMap>(),
std::make_shared<ReverseAD::ADGlobalVarMap>())(e);
std::vector<Expr> normal_args, args;
for (const auto& p : f->params) {
auto x = ll->Push(Pair(p, RefCreate(ZerosLike(p))));
normal_args.push_back(x);
args.push_back(x);
}
args.push_back(bp);
auto c = ll->Push(Call(rev, args));
std::function<void(const Expr&, const Type&)> init_grad;
init_grad = [&](const Expr& e, const Type& t) {
if (t.as<TensorTypeNode>()) {
ll->Push(RefWrite(GetField(e, 1), OnesLike(GetField(e, 0))));
} else if (auto tt = t.as<TupleTypeNode>()) {
CHECK_GT(tt->fields.size(), 0);
init_grad(ll->Push(GetField(e, 0)), tt->fields[0]);
} else {
LOG(FATAL) << "unhandled type " << t;
throw;
}
};
init_grad(c, f->body->checked_type());
ll->Push(Call(RefRead(bp), {}));
std::vector<Expr> ret;
for (const auto& a : normal_args) {
ret.push_back(RefRead(GetField(a, 1)));
}
std::function<Expr(const Expr&, const Type&)> get_final_result;
get_final_result = [&](const Expr& e, const Type& t) -> Expr {
if (t.as<TensorTypeNode>()) {
return GetField(e, 0);
} else if (auto tt = t.as<TupleTypeNode>()) {
tvm::Array<Expr> fields;
for (size_t i = 0; i < tt->fields.size(); ++i) {
fields.push_back(get_final_result(ll->Push(GetField(e, i)), tt->fields[i]));
}
return Tuple(fields);
} else {
LOG(FATAL) << "unhandled type " << t;
throw;
}
};
return Pair(get_final_result(c, f->body->checked_type()), Tuple(ret));
});
auto ret = Function(f->params, body, GradRetType(GetRef<Function>(f)), {});
CheckFeature(ret, FeatureSet::All() - fGraph);
return std::move(ret);
}
TVM_REGISTER_GLOBAL("relay._transform.gradient").set_body_typed(Gradient);
} // namespace relay
} // namespace tvm