blob: f093f5425d94223c45dc50badf5b179e08e2beeb [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
*
* \file forward_rewrite.cc
* \brief Apply rewriting rules in a forward fashion.
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <utility>
#include "pass_util.h"
namespace tvm {
namespace relay {
/*!
 * \brief Mutator that replaces every TempExpr reachable from an expression
 *  with the result of that TempExpr's own Realize() method.
 *
 * All non-TempExpr nodes go through the default MixedModeMutator dispatch.
 * The mutator's internal memo means repeated Realize calls on shared
 * sub-expressions should not redo work (presumably the MixedModeMutator
 * memo table; verify against expr_functor.h).
 */
class TempRealizer : private MixedModeMutator {
 public:
  /*! \brief Return \p expr with its TempExpr nodes realized. */
  Expr Realize(Expr expr) { return Mutate(expr); }

 private:
  Expr DispatchVisitExpr(const Expr& expr) final {
    const auto* temp_node = expr.as<TempExprNode>();
    if (temp_node == nullptr) {
      // Ordinary node: let the base class rebuild it from its children.
      return MixedModeMutator::DispatchVisitExpr(expr);
    }
    // A TempExpr knows how to lower itself into a regular expression.
    return temp_node->Realize();
  }
};
/*!
 * \brief Mutator that applies FForwardRewrite rules to call nodes in one
 *  post-order pass over an expression.
 *
 * Rules come either from a per-op attribute map (rewrite_map) or from a single
 * function applied to every call (rewrite_func). The values a rule returns may
 * be TempExpr nodes. Those stay unrealized while they flow into other calls that
 * have a rule. They are realized through TempRealizer in three cases:
 *  - when they feed a call with no rule;
 *  - when a rule declines (returns an undefined Expr);
 *  - once more on the final result of Rewrite().
 */
class ForwardRewriter : private MixedModeMutator {
 public:
  /*!
   * \brief Construct a rewriter that looks up one rule per op.
   * \param rewrite_map Op attribute map of rules. Only the pointer is stored,
   *  so the map must outlive this rewriter.
   * \param fcontext Optional callback producing the context object passed to
   *  a rule for a given call; may be null.
   * \param fmulti_ref_trigger Optional callback applied to a rewritten
   *  sub-expression whose original has a reference count above one; may be null.
   */
  ForwardRewriter(const OpAttrMap<FForwardRewrite>* rewrite_map,
                  std::function<ObjectRef(const Call&)> fcontext,
                  std::function<Expr(const Expr&)> fmulti_ref_trigger)
      : rewrite_map_(rewrite_map),
        fcontext_(std::move(fcontext)),
        fmulti_ref_trigger_(std::move(fmulti_ref_trigger)) {}

  /*!
   * \brief Construct a rewriter that applies one rule to every call.
   * \param rewrite_func The rule. Only the pointer is stored, so the function
   *  must outlive this rewriter.
   * \param fcontext See the op-map constructor.
   * \param fmulti_ref_trigger See the op-map constructor.
   */
  ForwardRewriter(const FForwardRewrite* rewrite_func,
                  std::function<ObjectRef(const Call&)> fcontext,
                  std::function<Expr(const Expr&)> fmulti_ref_trigger)
      : rewrite_func_(rewrite_func),
        fcontext_(std::move(fcontext)),
        fmulti_ref_trigger_(std::move(fmulti_ref_trigger)) {}

  /*!
   * \brief Rewrite \p expr, then realize any TempExpr left in the result.
   * \param expr The expression to transform.
   * \return The rewritten expression.
   */
  Expr Rewrite(const Expr& expr) {
    if (fmulti_ref_trigger_ != nullptr) {
      // Reference counts are only needed to decide when to fire the trigger.
      ref_counter_ = GetExprRefCount(expr);
    }
    return realizer_.Realize(this->VisitExpr(expr));
  }

 private:
  // Per-op rewrite rules (used when rewrite_func_ is null). Not owned.
  const OpAttrMap<FForwardRewrite>* rewrite_map_{nullptr};
  // Single rewrite rule applied to every call, if set. Not owned.
  const FForwardRewrite* rewrite_func_{nullptr};
  // Produces the per-call context passed to a rule; may be null.
  std::function<ObjectRef(const Call&)> fcontext_{nullptr};
  // Applied to sub-expressions referenced more than once; may be null.
  std::function<Expr(const Expr&)> fmulti_ref_trigger_{nullptr};
  // Reference count of each node in the original expression (keyed by node).
  std::unordered_map<const Object*, size_t> ref_counter_;
  // Realizes TempExpr nodes; shared across calls so its memo is reused.
  TempRealizer realizer_;

  /*!
   * \brief Return the possibly-unrealized rewritten form of a sub-expression.
   * \param expr The original sub-expression, used to look up its reference count.
   * \param post The already-rewritten version of \p expr.
   * \return \p post. When a trigger is set and \p expr has more than one
   *  reference, the result of fmulti_ref_trigger_(post) is returned instead.
   */
  Expr GetTempExpr(const Expr& expr, const Expr& post) {
    if (fmulti_ref_trigger_ == nullptr) return post;
    auto it = ref_counter_.find(expr.get());
    CHECK(it != ref_counter_.end());
    if (it->second > 1) {
      return fmulti_ref_trigger_(post);
    }
    return post;
  }

  // Fold TupleGetItem(Tuple(...), i) directly into the i-th tuple field.
  Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final {
    Expr tuple = this->GetTempExpr(op->tuple, post.as<TupleGetItemNode>()->tuple);
    if (const auto* ptuple = tuple.as<TupleNode>()) {
      return ptuple->fields[op->index];
    }
    if (tuple.same_as(op->tuple)) {
      return GetRef<Expr>(op);
    }
    return TupleGetItem(tuple, op->index);
  }

  // Rebuild a tuple from its (possibly unrealized) rewritten fields, reusing
  // the original node when no field changed.
  Expr Rewrite_(const TupleNode* op, const Expr& post) final {
    tvm::Array<Expr> fields;
    bool all_fields_unchanged = true;
    const auto* post_node = post.as<TupleNode>();
    for (size_t i = 0; i < op->fields.size(); ++i) {
      auto new_field = this->GetTempExpr(op->fields[i], post_node->fields[i]);
      fields.push_back(new_field);
      all_fields_unchanged &= new_field.same_as(op->fields[i]);
    }
    if (all_fields_unchanged) {
      return GetRef<Expr>(op);
    }
    return Tuple(fields);
  }

  /*!
   * \brief Apply the forward rewrite rule, if any, to a call.
   *
   * When a rule exists, the arguments are passed to it unrealized, so the rule
   * can consume TempExpr values produced by earlier rewrites. When there is no
   * rule, or the rule returns an undefined Expr, the arguments are realized.
   * An ordinary Call is then rebuilt, or the original call is returned if
   * nothing changed.
   */
  Expr Rewrite_(const CallNode* call_node, const Expr& post) final {
    // Hold by value: GetRef returns a temporary.
    Call ref_call = GetRef<Call>(call_node);
    PackedFunc frewrite;
    if (rewrite_func_) {
      frewrite = *rewrite_func_;
    } else {
      CHECK(rewrite_map_);
      frewrite = rewrite_map_->get(call_node->op, nullptr);
    }
    const auto* post_node = post.as<CallNode>();
    auto new_op = post_node->op;
    bool unchanged = call_node->op.same_as(new_op);
    Array<Expr> call_args;
    for (size_t i = 0; i < call_node->args.size(); ++i) {
      Expr new_arg = this->GetTempExpr(call_node->args[i], post_node->args[i]);
      if (frewrite == nullptr) {
        // No rule will consume this argument, so materialize it now.
        new_arg = realizer_.Realize(new_arg);
      }
      unchanged &= new_arg.same_as(call_node->args[i]);
      call_args.push_back(new_arg);
    }
    // Try to rewrite.
    if (frewrite != nullptr) {
      // A null ObjectRef is passed as context when no fcontext_ was supplied.
      Expr res = frewrite(ref_call, call_args,
                          fcontext_ != nullptr ? fcontext_(ref_call) : ObjectRef(nullptr));
      if (res.defined()) return res;
      // The rule declined: fall back to rebuilding the call with realized args.
      for (size_t i = 0; i < call_args.size(); ++i) {
        Expr arg = call_args[i];
        Expr new_arg = realizer_.Realize(arg);
        if (!arg.same_as(new_arg)) {
          call_args.Set(i, new_arg);
          unchanged = false;
        }
      }
    }
    if (unchanged) return ref_call;
    return Call(new_op, call_args, call_node->attrs, call_node->type_args);
  }
};
/*!
 * \brief Forward-rewrite \p expr using the rules registered under an op attribute.
 * \param expr The expression to rewrite.
 * \param rewrite_map_name Name of the op attribute holding FForwardRewrite rules.
 * \param fcontext Optional callback producing the per-call rule context; may be null.
 * \param fmulti_ref_trigger Optional callback applied to multiply-referenced
 *  sub-expressions; may be null.
 * \return The rewritten expression.
 */
Expr ForwardRewrite(const Expr& expr, const String& rewrite_map_name,
                    std::function<ObjectRef(const Call&)> fcontext,
                    std::function<Expr(const Expr&)> fmulti_ref_trigger) {
  auto rewrite_map = Op::GetAttrMap<FForwardRewrite>(rewrite_map_name);
  // The rewriter keeps only a pointer to rewrite_map. That is safe here because
  // the temporary rewriter is destroyed at the end of this statement.
  return ForwardRewriter(&rewrite_map, std::move(fcontext), std::move(fmulti_ref_trigger))
      .Rewrite(expr);
}
/*!
 * \brief Forward-rewrite \p expr, applying a single rule to every call.
 * \param expr The expression to rewrite.
 * \param rewrite_func The rewrite rule invoked for each call node.
 * \param fcontext Optional callback producing the per-call rule context; may be null.
 * \param fmulti_ref_trigger Optional callback applied to multiply-referenced
 *  sub-expressions; may be null.
 * \return The rewritten expression.
 */
Expr ForwardRewrite(const Expr& expr, const FForwardRewrite& rewrite_func,
                    std::function<ObjectRef(const Call&)> fcontext,
                    std::function<Expr(const Expr&)> fmulti_ref_trigger) {
  // The rewriter keeps only a pointer to rewrite_func. That is safe because the
  // temporary rewriter does not outlive this call.
  return ForwardRewriter(&rewrite_func, std::move(fcontext), std::move(fmulti_ref_trigger))
      .Rewrite(expr);
}
} // namespace relay
} // namespace tvm