// blob: ce69053311d1b7c0adc598b4e54ab4afbd8478f0 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file inline_private_functions.cc
* \brief Inline private functions to their callsite
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <utility>
namespace tvm {
namespace tir {
namespace transform {
namespace {
/*! \brief Unordered set that uses ObjectPtrHash / ObjectPtrEqual as hasher and equality. */
template <typename Key>
using PSet = std::unordered_set<Key, ObjectPtrHash, ObjectPtrEqual>;
/*! \brief Unordered map that uses ObjectPtrHash / ObjectPtrEqual as hasher and equality. */
template <typename Key, typename Value>
using PMap = std::unordered_map<Key, Value, ObjectPtrHash, ObjectPtrEqual>;
/*!
 * \brief Collect the direct callers of every GlobalVar that is called from a PrimFunc.
 *
 * \param mod The module to scan. Only entries of `mod->functions` that are
 * PrimFuncs have their bodies visited.
 *
 * \return A map from each called GlobalVar (the callee) to the set of
 * GlobalVars whose PrimFunc body contains a call to it (its direct callers).
 * A GlobalVar that is never called does not appear as a key.
 */
PMap<GlobalVar, PSet<GlobalVar>> CollectCallMap(const IRModule& mod) {
  struct Visitor : StmtExprVisitor {
    // GlobalVar of the function whose body is currently being visited.
    GlobalVar current;
    // Lookup from a callee to the functions that contain a call to it.
    PMap<GlobalVar, PSet<GlobalVar>> caller_lookup;

    // Marked `override` so that a signature mismatch with the base class
    // is a compile error rather than a visitor that silently records nothing.
    void VisitExpr_(const CallNode* op) override {
      if (auto gvar = op->op.as<GlobalVar>()) {
        caller_lookup[gvar.value()].insert(current);
      }
      // Keep descending so that calls nested in the arguments are recorded too.
      StmtExprVisitor::VisitExpr_(op);
    }
  } visitor;

  for (const auto& [gvar, base_func] : mod->functions) {
    if (auto prim_func = base_func.as<PrimFuncNode>()) {
      visitor.current = gvar;
      visitor(prim_func->body);
    }
  }

  // The visitor is discarded after this point; move the map out instead
  // of copying every per-callee set.
  return std::move(visitor.caller_lookup);
}
/*!
 * \brief Find every function that can call itself, directly or indirectly.
 *
 * \param mod The module to analyze.
 *
 * \return The set of GlobalVars that appear in their own (transitive) set
 * of callers.
 */
PSet<GlobalVar> CollectRecursiveFunctions(const IRModule& mod) {
  // Map from callee to its callers.  Initially holds direct callers only.
  auto call_map = CollectCallMap(mod);

  // Expand each callee's caller set to its transitive closure: whoever
  // calls one of my callers is an indirect caller of mine.  Repeat until
  // a full sweep adds nothing.  This terminates because every set is
  // bounded by the number of GlobalVars in the call map.
  bool made_change = true;
  while (made_change) {
    made_change = false;
    for (auto& [callee, callers] : call_map) {
      // Stage the additions so that `callers` is never modified while
      // it is being iterated (an insert may rehash and invalidate the
      // iterators of the enclosing loop).
      PSet<GlobalVar> indirect_callers;
      for (const auto& caller : callers) {
        auto it = call_map.find(caller);
        if (it == call_map.end()) {
          // Nobody calls `caller`, so it contributes no indirect callers.
          continue;
        }
        for (const auto& caller_of_caller : it->second) {
          if (!callers.count(caller_of_caller)) {
            indirect_callers.insert(caller_of_caller);
          }
        }
      }
      if (!indirect_callers.empty()) {
        callers.insert(indirect_callers.begin(), indirect_callers.end());
        made_change = true;
      }
    }
  }

  // A function is recursive exactly when it is one of its own
  // (direct or indirect) callers.
  PSet<GlobalVar> recursive_funcs;
  for (const auto& [callee, callers] : call_map) {
    if (callers.count(callee)) {
      recursive_funcs.insert(callee);
    }
  }
  return recursive_funcs;
}
/*!
 * \brief Decide whether a PrimFunc is a candidate for inlining.
 *
 * \param gvar The GlobalVar that names the function.
 * \param prim_func The function being considered.
 * \param recursive_functions Functions already known to be recursive.
 *
 * \return True only if none of the disqualifying conditions below holds.
 */
bool IsInlinablePrimFunc(const GlobalVar& gvar, const PrimFunc& prim_func,
                         const PSet<GlobalVar>& recursive_functions) {
  return
      // Private functions only.  A function carrying a global symbol is
      // externally exposed and must be preserved, so that callsites
      // outside of the IRModule are not broken.
      !prim_func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol).has_value() &&
      // No termination analysis is implemented.  If a recursive function
      // relies on runtime checks to terminate, inlining would continue
      // until the recursive visits segfault.
      recursive_functions.count(gvar) == 0 &&
      // Functions that accept buffer arguments are not yet supported.
      prim_func->buffer_map.empty() &&
      // Schedulable TIR functions are not yet supported.  Supporting them
      // would require de-duplicating the `tir::Block` names that repeat
      // when the same function is inlined at several callsites.
      !prim_func->body.as<BlockRealizeNode>();
}
/*!
 * \brief Gather every PrimFunc in the module that may be inlined.
 *
 * \param mod The module to scan.
 *
 * \return Map from GlobalVar to PrimFunc for each function accepted by
 * IsInlinablePrimFunc.
 */
ffi::Map<GlobalVar, PrimFunc> CollectInlinablePrimFuncs(const IRModule& mod) {
  const PSet<GlobalVar> recursive = CollectRecursiveFunctions(mod);

  ffi::Map<GlobalVar, PrimFunc> inlinable;
  for (const auto& [gvar, base_func] : mod->functions) {
    auto maybe_prim_func = base_func.as<PrimFunc>();
    if (!maybe_prim_func) continue;  // Not a PrimFunc; nothing to consider.

    const PrimFunc prim_func = maybe_prim_func.value();
    if (!IsInlinablePrimFunc(gvar, prim_func, recursive)) continue;

    inlinable.Set(gvar, prim_func);
  }
  return inlinable;
}
/*!
 * \brief Mutator that replaces `Evaluate(call)` statements calling an
 * inlinable PrimFunc with the (specialized) body of the callee.
 *
 * The mutator also tracks which of the inlinable functions are still
 * referenced by a call that was not inlined, so that the caller of this
 * class knows which functions may be removed afterwards.
 */
class PrimFuncInliner : StmtExprMutator {
 public:
  /*!
   * \brief Construct the inliner.
   * \param inlinable_funcs Map from GlobalVar to the PrimFuncs that may be inlined.
   */
  explicit PrimFuncInliner(ffi::Map<GlobalVar, PrimFunc> inlinable_funcs)
      : inlinable_funcs_(std::move(inlinable_funcs)) {
    // Every inlinable function starts out as removable; entries are
    // erased again whenever a call to them is found that stays in place.
    for (const auto& [gvar, callee] : inlinable_funcs_) {
      removable_funcs_.insert(gvar);
    }
  }

  /*!
   * \brief Inline calls occurring in the body of a function.
   * \param func The function to update.
   * \return The function with an updated body, or `func` unchanged if
   * visiting the body produced the same statement.
   */
  PrimFunc VisitFunc(PrimFunc func) {
    // The target of the caller is compared against each callee's target
    // while the body is being visited.
    current_target_ = func->GetAttr<Target>(tvm::attr::kTarget);
    auto new_body = VisitStmt(func->body);
    current_target_ = std::nullopt;

    if (!new_body.same_as(func->body)) {
      func.CopyOnWrite()->body = new_body;
    }
    return func;
  }

  /*! \brief The inlinable functions for which no un-inlined call was visited. */
  PSet<GlobalVar> GetRemovableFunctions() const { return removable_funcs_; }

 private:
  // Inline the call if possible; otherwise fall back to the default
  // mutation, which visits the call through `VisitExpr_` below.
  Stmt VisitStmt_(const EvaluateNode* eval) override {
    if (auto inlined = GetInlinedFunction(eval)) {
      return inlined.value();
    } else {
      return StmtExprMutator::VisitStmt_(eval);
    }
  }

  /*!
   * \brief Attempt to inline the call held by an Evaluate node.
   * \param eval The Evaluate node to inspect.
   * \return The inlined and recursively visited callee body, or
   * `std::nullopt` if the node does not hold a call to an inlinable
   * function, or if caller and callee both have a target and the
   * string forms of those targets differ.
   */
  ffi::Optional<Stmt> GetInlinedFunction(const EvaluateNode* eval) {
    auto call = eval->value.as<CallNode>();
    if (!call) return std::nullopt;

    auto gvar = call->op.as<GlobalVar>();
    if (!gvar) return std::nullopt;

    auto opt_callee = inlinable_funcs_.Get(gvar.value());
    if (!opt_callee) return std::nullopt;
    auto callee = opt_callee.value();

    // A target mismatch is only reported when both functions are
    // annotated; a missing annotation on either side is treated as
    // compatible.
    bool is_same_target = [&]() -> bool {
      auto callee_target = callee->GetAttr<Target>(tvm::attr::kTarget);
      if (current_target_ && callee_target) {
        return callee_target.value()->str() == current_target_.value()->str();
      } else {
        return true;
      }
    }();
    if (!is_same_target) return std::nullopt;

    Stmt inlined = InlineArguments(gvar.value(), callee, call->args);
    // Visit the inlined body as well, so that calls made by the callee
    // are themselves inlined.
    return VisitStmt(inlined);
  }

  PrimExpr VisitExpr_(const CallNode* call) override {
    // Because the current implementation inlines a subroutine by
    // inserting the `tir::Stmt` body at the point of use, replacement
    // must occur in a context where a `tir::Stmt` can be returned.
    // Support of subroutines that are called within an expression
    // (e.g. replacing func in `Buf[0] = func(1) + func(2)`) would
    // require hoisting preprocessing done in the subroutine to the
    // parent `tir::Stmt`.
    //
    // See `TestInlineCallOccurringInExpression` in
    // `test_tir_inline_private_functions.py` for a test of this
    // behavior, currently marked with `pytest.mark.xfail`.
    //
    // Any callee that hasn't been inlined at this point must be kept
    // in the output IRModule.
    if (auto gvar = call->op.as<GlobalVar>()) {
      removable_funcs_.erase(gvar.value());
    }
    return StmtExprMutator::VisitExpr_(call);
  }

  /*!
   * \brief Produce the callee's body with its parameters bound to the call arguments.
   * \param gvar The GlobalVar of the callee, used in error messages only.
   * \param callee The function being inlined.
   * \param args The arguments at the callsite.
   * \return The body of `callee` after specializing each parameter to
   * the argument at the same position.
   */
  Stmt InlineArguments(const GlobalVar& gvar, PrimFunc callee,
                       const ffi::Array<PrimExpr>& args) const {
    CHECK_EQ(callee->params.size(), args.size())
        << "Callee " << gvar << " accepts " << callee->params.size() << " parameters ("
        << callee->params << "), but is called with " << args.size() << " arguments (" << args
        << ")";

    ICHECK(callee->buffer_map.empty())
        << "Inlining of PrimFuncs with buffer arguments is not yet supported, "
        << "but callee " << gvar << " has non-empty buffer map " << callee->buffer_map;

    ffi::Map<Var, ffi::Variant<tir::Buffer, tvm::PrimExpr>> param_map;
    for (size_t i = 0; i < callee->params.size(); i++) {
      param_map.Set(callee->params[i], args[i]);
    }

    callee = Specialize(callee, param_map);
    return callee->body;
  }

  /*! \brief Map from GlobalVar to PrimFuncs which may be inlined. */
  ffi::Map<GlobalVar, PrimFunc> inlinable_funcs_;

  /*! \brief Set of callees that may be removed
   *
   * Some constructs may not be inlined (e.g. if the call site occurs
   * outside of an Evaluate node).  For these cases, the output
   * IRModule must still contain the callee.
   */
  PSet<GlobalVar> removable_funcs_;

  /*! \brief Target of the function currently being visited by VisitFunc, if any. */
  ffi::Optional<Target> current_target_ = std::nullopt;
};
} // namespace
/*!
 * \brief Create the module pass that inlines private PrimFuncs into their callsites.
 *
 * \return A module-level pass registered under the name
 * "tir.InlinePrivateFunctions".
 */
Pass InlinePrivateFunctions() {
  auto pass_func = [](IRModule mod, PassContext ctx) {
    auto candidates = CollectInlinablePrimFuncs(mod);
    // Nothing can be inlined, so the module is returned untouched.
    if (candidates.empty()) return mod;

    PrimFuncInliner inliner(std::move(candidates));

    // Gather only the functions whose body actually changed.
    IRModule rewritten;
    for (const auto& [gvar, base_func] : mod->functions) {
      auto maybe_prim_func = base_func.as<PrimFunc>();
      if (!maybe_prim_func) continue;

      auto visited = inliner.VisitFunc(maybe_prim_func.value());
      if (visited.same_as(base_func)) continue;

      rewritten->Add(gvar, visited);
    }

    // No callsite was rewritten, so the input module is still valid as-is.
    if (rewritten->functions.empty()) return mod;

    auto write_ptr = mod.CopyOnWrite();
    write_ptr->Update(rewritten);
    // Drop the callees that no longer have any remaining call.
    for (const auto& gvar : inliner.GetRemovableFunctions()) {
      write_ptr->Remove(gvar);
    }
    // Re-run SSA conversion on the module now that callee bodies have
    // been inserted at their callsites.
    mod = ConvertSSA()(mod);
    return mod;
  };
  return tvm::transform::CreateModulePass(pass_func, 0, "tir.InlinePrivateFunctions", {});
}
// Register the pass constructor as the global function
// "tir.transform.InlinePrivateFunctions".
TVM_FFI_STATIC_INIT_BLOCK() {
  tvm::ffi::reflection::GlobalDef().def("tir.transform.InlinePrivateFunctions",
                                        InlinePrivateFunctions);
}
} // namespace transform
} // namespace tir
} // namespace tvm