blob: a3173e990ccc4581bc24aae01b291e5d7bdb1a05 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/ir/apply_pass_to_function.cc
* \brief Utility transformation that applies an inner pass to a subset of an IRModule
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/transform.h>
#include <tvm/relax/expr.h>
#include <tvm/tir/function.h>
#include <unordered_set>
#include "../runtime/regex.h"
namespace tvm {
namespace transform {
namespace {
/*!
 * \brief Return `func` with attribute `attr_key` set to `attr_value`.
 *
 * Only `tir::PrimFunc` and `relax::Function` are handled; any other
 * kind of `BaseFunc` is returned unchanged.
 */
BaseFunc BaseFuncWithAttr(BaseFunc func, const std::string& attr_key, Any attr_value) {
  if (auto as_prim_func = func.as<tir::PrimFunc>()) {
    return WithAttr(as_prim_func.value(), attr_key, attr_value);
  }
  if (auto as_relax_func = func.as<relax::Function>()) {
    return WithAttr(as_relax_func.value(), attr_key, attr_value);
  }
  // Neither of the two supported function kinds: hand the input back untouched.
  return func;
}
/*!
 * \brief Return `func` with attribute `attr_key` removed.
 *
 * Only `tir::PrimFunc` and `relax::Function` are handled; any other
 * kind of `BaseFunc` is returned unchanged.
 */
BaseFunc BaseFuncWithoutAttr(BaseFunc func, const std::string& attr_key) {
  if (auto as_prim_func = func.as<tir::PrimFunc>()) {
    return WithoutAttr(as_prim_func.value(), attr_key);
  }
  if (auto as_relax_func = func.as<relax::Function>()) {
    return WithoutAttr(as_relax_func.value(), attr_key);
  }
  // Neither of the two supported function kinds: hand the input back untouched.
  return func;
}
} // namespace
/*!
 * \brief Wrap a pass so that it only rewrites functions whose name matches a regex.
 *
 * The returned module pass builds a temporary IRModule in which every
 * function whose GlobalVar name matches `func_name_regex` is kept, and
 * every other function is replaced by a placeholder `relax::ExternFunc`.
 * The inner pass is run on that temporary module, and every function of
 * the result that is not one of the placeholders is written back into
 * the original module.
 *
 * \param pass The inner pass to run on the temporary module.
 * \param func_name_regex Pattern handed to `tvm::runtime::regex_match`
 *        together with each function's name.
 * \param error_if_no_function_matches_regex When true, a check fails if no
 *        function in the module matches the pattern.
 * \return A module pass whose name is "ApplyPassTo" followed by the pattern.
 */
Pass ApplyPassToFunction(Pass pass, ffi::String func_name_regex,
                         bool error_if_no_function_matches_regex) {
  // Plain string concatenation: no stringstream temporary and no cast needed.
  const std::string regex_text = func_name_regex;
  const std::string pass_name = "ApplyPassTo" + regex_text;

  auto pass_func = [pass, func_name_regex, error_if_no_function_matches_regex](
                       IRModule mod, PassContext) -> IRModule {
    bool at_least_one_function_matched_regex = false;
    // Names of functions hidden from the inner pass; their originals stay in `mod`.
    std::unordered_set<ffi::String> keep_original_version;
    // Names of matched functions that were temporarily given a global symbol.
    std::unordered_set<ffi::String> internal_functions;
    IRModule subset;

    for (auto [gvar, func] : mod->functions) {
      std::string name = gvar->name_hint;
      if (tvm::runtime::regex_match(name, func_name_regex)) {
        at_least_one_function_matched_regex = true;
        if (!func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol).has_value()) {
          // Function may be mutated, but is an internal function.  Mark
          // it as externally-exposed, so that any call-tracing internal
          // transforms do not remove this function, in case its
          // callers are not being mutated.
          internal_functions.insert(gvar->name_hint);
          func = BaseFuncWithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint);
        }
      } else {
        // Function may not be mutated.  Replace it with a
        // `relax::ExternFunc` to prevent references to it from
        // dangling.
        keep_original_version.insert(gvar->name_hint);
        func = relax::ExternFunc("dummy_" + name);
        // Give the placeholder the same struct info as the GlobalVar it stands in for.
        func->struct_info_ = gvar->struct_info_;
      }
      subset->Add(gvar, func);
    }

    if (error_if_no_function_matches_regex) {
      // The list of function names is only built when it is streamed into the message.
      TVM_FFI_ICHECK(at_least_one_function_matched_regex)
          << "No function matched regex '" << func_name_regex << "', out of functions " << [&]() {
               ffi::Array<ffi::String> function_names;
               for (const auto& [gvar, func] : mod->functions) {
                 function_names.push_back(gvar->name_hint);
               }
               return function_names;
             }();
    }

    IRModule new_subset = pass(subset);
    if (new_subset.same_as(subset)) {
      // The inner pass returned the very same module object: nothing to write back.
      return mod;
    }

    auto write_ptr = mod.CopyOnWrite();
    for (auto [gvar, func] : new_subset->functions) {
      if (!keep_original_version.count(gvar->name_hint)) {
        // Drop any existing definition with this name before adding the new one.
        if (auto it = write_ptr->global_var_map_.find(gvar->name_hint);
            it != write_ptr->global_var_map_.end()) {
          write_ptr->Remove((*it).second);
        }
        if (internal_functions.count(gvar->name_hint)) {
          // Undo the temporary global symbol that was added before running the pass.
          func = BaseFuncWithoutAttr(func, tvm::attr::kGlobalSymbol);
        }
        write_ptr->Add(gvar, func);
      }
    }
    return mod;
  };

  return CreateModulePass(pass_func, 0, pass_name, {});
}
// Registers ApplyPassToFunction as a global definition under the name
// "transform.ApplyPassToFunction".
TVM_FFI_STATIC_INIT_BLOCK() {
  tvm::ffi::reflection::GlobalDef().def("transform.ApplyPassToFunction", ApplyPassToFunction);
}
} // namespace transform
} // namespace tvm