blob: ec45ba00f1b3d9e87777e26e498693d075ec986b [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>
#include <utility>
#include "../../support/ordered_set.h"
#include "utils.h"
namespace tvm {
namespace relax {
namespace {
/*!
 * \brief Mutator that replaces calls to selected functions with an
 *  inlined copy of the callee.
 *
 * The callees to inline are looked up in a map whose keys are either
 * a GlobalVar, or a string that is compared against the GlobalVar's
 * `name_hint`.
 */
class FunctionInliner : public ExprMutator {
 public:
  /*!
   * \brief Construct the inliner.
   *
   * \param replacements Map from callee (GlobalVar or name) to the
   *  function definition that is inlined at each call site.  The map
   *  handle is copied, so the argument does not need to outlive the
   *  mutator.
   */
  explicit FunctionInliner(
      const ffi::Map<ffi::Variant<ffi::String, GlobalVar>, Function>& replacements)
      : replacements_(replacements) {}

  using ExprMutator::VisitExpr_;

  /*!
   * \brief Visit a function, then clean up if anything was changed.
   *
   * When the default visitor returns a different node, the result is
   * passed through CanonicalizeBindings and RemoveAllUnused.
   */
  Expr VisitExpr_(const FunctionNode* op) override {
    auto node = ExprMutator::VisitExpr_(op);
    if (node.get() != op) {
      node = CanonicalizeBindings(node);
      node = RemoveAllUnused(node);
    }
    return node;
  }

  /*!
   * \brief Visit a call, inlining it if the callee is in the map.
   *
   * Calls whose operator is not a GlobalVar, or whose GlobalVar has
   * no entry in the replacement map, are returned as produced by the
   * default visitor.  The inlined expression is itself visited, so
   * that calls nested inside the callee are inlined as well.
   */
  Expr VisitExpr_(const CallNode* op) override {
    auto node = Downcast<Call>(ExprMutator::VisitExpr_(op));

    if (auto opt_gvar = node->op.as<GlobalVar>()) {
      auto gvar = opt_gvar.value();
      if (auto opt_callee = GetFunction(gvar)) {
        auto callee = opt_callee.value();

        TVM_FFI_ICHECK_EQ(callee->params.size(), node->args.size())
            << "Attempted to inline call to " << gvar << ", which accepts " << callee->params.size()
            << " parameters.  "
            << "However, it was called with " << node->args.size() << " arguments in expression "
            << node;

        // Reject recursion before doing any work on the callee.  The
        // stack holds every GlobalVar whose inlined body is currently
        // being visited.
        TVM_FFI_ICHECK(!inline_stack_.count(gvar))
            << "Relax function inlining does not support recursive functions.  "
            << "However, recursive function " << gvar << " was requested to be inlined.";

        Expr inlined = InlinedCall(callee, node->args);

        inline_stack_.insert(gvar);
        inlined = VisitExpr(std::move(inlined));
        inline_stack_.erase(gvar);
        return inlined;
      }
    }

    return node;
  }

 private:
  /*!
   * \brief Look up the replacement for a GlobalVar.
   *
   * The GlobalVar itself is tried first, followed by its `name_hint`.
   *
   * \return The function to inline, or std::nullopt if neither key is
   *  present in the map.
   */
  ffi::Optional<Function> GetFunction(const GlobalVar& gvar) const {
    if (auto by_gvar = replacements_.Get(gvar)) {
      return by_gvar;
    }
    if (auto by_name = replacements_.Get(gvar->name_hint)) {
      return by_name;
    }
    return std::nullopt;
  }

  /*!
   * \brief Build the expression that replaces a single call.
   *
   * \param func The callee definition.
   * \param args The arguments at the call site.  The caller has
   *  already checked that there is one argument per parameter.
   *
   * \return A SeqExpr that binds a fresh variable to each argument,
   *  followed by the callee's body with its parameters replaced by
   *  those variables.
   */
  Expr InlinedCall(Function func, const ffi::Array<Expr>& args) const {
    // Ensures that the inlined instance does not have duplicate usage
    // with other inlined copies, or with the original callee.
    func = CopyWithNewVars(std::move(func));

    ffi::Array<Binding> param_bindings;
    ffi::Map<Var, Expr> param_map;
    for (size_t i = 0; i < args.size(); i++) {
      // Option 1: Use tvm::relax::Bind to substitute arguments into
      // the body.  If the arguments contain DataflowVar instances,
      // but the subroutine does not use DataflowBlock, this would
      // result in invalid AST.
      //
      // Option 2: Define a VarBinding `param[i] = args[i]` for each
      // parameter, then rely on CanonicalizeBindings to replace with
      // DataflowVar where possible.  This would solve the invalid use
      // of DataflowVar, but wouldn't handle symbolic variables.  If
      // the subroutine has symbolic variables defined by its
      // arguments, the VarBinding would leave them undefined.
      //
      // Option 3: Define a MatchCast `param[i] = args[i]` for each
      // parameter, followed by CanonicalizeBindings.  This is the
      // first option that would result in well-formed AST, but it
      // wouldn't be optimal.  Symbolic variables would have two
      // copies, one from the initial definition, and one
      // from the MatchCast inlined portion.
      //
      // Option 4: Define a VarBinding `param[i] = args[i]`, with
      // CanonicalizeBindings to handle conversion of Var to
      // DataflowVar, and tvm::relax::Bind to handle substitution of
      // symbolic variables.  This would result in a well-formed Relax
      // function, with no duplicate definitions of symbolic
      // variables.
      //
      // This implementation uses Option 4.
      Var param_var(func->params[i]->name_hint(), args[i]->struct_info_.as<StructInfo>());
      param_bindings.push_back(VarBinding(param_var, args[i]));
      param_map.Set(func->params[i], param_var);
    }

    DataflowBlock binding_block(param_bindings);
    // Checked cast: fail with a diagnostic rather than dereferencing
    // a null pointer if the bound result is not a Function.
    Expr body = Downcast<Function>(Bind(func, param_map))->body;
    return SeqExpr({binding_block}, body);
  }

  /*! \brief Functions to inline, keyed by GlobalVar or by name.
   *
   * Held by value so that the mutator cannot outlive the map it
   * reads from.
   */
  ffi::Map<ffi::Variant<ffi::String, GlobalVar>, Function> replacements_;

  /*! \brief GlobalVars whose inlined bodies are currently being visited. */
  std::unordered_set<GlobalVar, ObjectPtrHash, ObjectPtrEqual> inline_stack_;
};
} // namespace
/*!
 * \brief Inline calls to the specified functions within a Relax function.
 *
 * \param func The Relax function whose call sites are rewritten.
 *
 * \param replacements Map from callee to the function definition that
 *  replaces calls to it.  Each key is either a GlobalVar or a string.
 *  The map must not contain both a GlobalVar and a string equal to
 *  that GlobalVar's `name_hint`; a ValueError is raised if it does.
 *
 * \return The result of running FunctionInliner over `func`.
 */
Function FunctionInlineFunctions(
    Function func, const ffi::Map<ffi::Variant<ffi::String, GlobalVar>, Function>& replacements) {
  // Only the keys are inspected here.  The mapped value is named
  // distinctly so that it does not shadow the `func` parameter.
  for (const auto& [key, replacement] : replacements) {
    if (auto ptr = key.as<GlobalVarNode>()) {
      TVM_FFI_CHECK(!replacements.count(ptr->name_hint), ValueError)
          << "Map of functions to inline must be unambiguous.  "
          << "However, the map provided contains both the GlobalVar " << key << " and the string \'"
          << ptr->name_hint << "'";
    }
  }

  FunctionInliner mutator(replacements);
  return Downcast<Function>(mutator(std::move(func)));
}
// Registers FunctionInlineFunctions under the global name
// "relax.FunctionInlineFunctions".
TVM_FFI_STATIC_INIT_BLOCK() {
  tvm::ffi::reflection::GlobalDef().def("relax.FunctionInlineFunctions", FunctionInlineFunctions);
}
namespace transform {
/*!
 * \brief Create a module pass that inlines private Relax functions.
 *
 * A Relax function is treated as private when it has no
 * `tvm::attr::kGlobalSymbol` attribute.  Private functions reported
 * by DetectRecursion are left in place.  Every other private function
 * is inlined into the remaining Relax functions of the module and
 * then removed from the module.
 *
 * \return The module pass, named "InlinePrivateFunctions".
 */
Pass InlinePrivateFunctions() {
  auto pass_func = [=](IRModule mod, PassContext pc) {
    // Step 1: Collect every Relax function without a global symbol.
    ffi::Map<ffi::Variant<ffi::String, GlobalVar>, Function> to_inline;
    for (const auto& [gvar, base_func] : mod->functions) {
      auto maybe_relax_func = base_func.as<relax::Function>();
      if (!maybe_relax_func.has_value()) {
        continue;
      }
      auto relax_func = maybe_relax_func.value();
      const bool has_global_symbol =
          relax_func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol).has_value();
      if (!has_global_symbol) {
        to_inline.Set(gvar, relax_func);
      }
    }

    if (to_inline.empty()) {
      // Nothing to do: the module has no private functions.
      return mod;
    }

    // Step 2: Recursive functions cannot be inlined, so drop them
    // from the candidate set.
    for (const auto& cycle : DetectRecursion(mod)) {
      for (const auto& member : cycle) {
        to_inline.erase(member);
      }
    }

    if (to_inline.empty()) {
      // Nothing to do: every private function was recursive.
      return mod;
    }

    // Step 3: Rewrite each Relax function that is staying in the
    // module, recording only those that were actually changed.
    IRModule updates;
    for (const auto& [gvar, base_func] : mod->functions) {
      if (to_inline.count(gvar)) {
        continue;
      }
      auto maybe_relax_func = base_func.as<relax::Function>();
      if (!maybe_relax_func.has_value()) {
        continue;
      }
      auto rewritten = FunctionInlineFunctions(maybe_relax_func.value(), to_inline);
      if (!base_func.same_as(rewritten)) {
        updates->Add(gvar, rewritten);
      }
    }

    // Step 4: Remove the inlined functions, then apply the rewrites.
    auto write_ptr = mod.CopyOnWrite();
    for (const auto& [key, inlined_func] : to_inline) {
      write_ptr->Remove(Downcast<GlobalVar>(key));
    }
    write_ptr->Update(updates);

    return mod;
  };
  return tvm::transform::CreateModulePass(pass_func, 0, "InlinePrivateFunctions", {});
}
// Registers the pass constructor under the global name
// "relax.transform.InlinePrivateFunctions".
TVM_FFI_STATIC_INIT_BLOCK() {
  tvm::ffi::reflection::GlobalDef().def("relax.transform.InlinePrivateFunctions",
                                        InlinePrivateFunctions);
}
} // namespace transform
} // namespace relax
} // namespace tvm