blob: 136a92abde31f4ac8b2142dc91168b2a31f900a7 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file primfunc_utils.cc
* \brief Passes that serve as helper functions.
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/transform.h>
namespace tvm {
namespace tir {
namespace transform {
transform::Pass AnnotateEntryFunc() {
auto fpass = [](IRModule mod, transform::PassContext ctx) -> IRModule {
// If only a single function exists, that function must be the entry
if (mod->functions.size() == 1) {
auto [gvar, base_func] = *mod->functions.begin();
if (!base_func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
if (auto ptr = base_func.as<PrimFuncNode>()) {
mod->Update(gvar, WithAttr(ffi::GetRef<PrimFunc>(ptr), tir::attr::kIsEntryFunc, true));
}
}
return mod;
}
// If the module has multiple functions, but only one is exposed
// externally, that function must be the entry.
bool has_external_non_primfuncs = false;
IRModule with_annotations;
for (const auto& [gvar, base_func] : mod->functions) {
bool is_external = base_func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol).has_value();
if (is_external) {
if (auto ptr = base_func.as<PrimFuncNode>()) {
with_annotations->Add(
gvar, WithAttr(ffi::GetRef<PrimFunc>(ptr), tir::attr::kIsEntryFunc, true));
} else {
has_external_non_primfuncs = true;
}
}
}
if (with_annotations->functions.size() == 1 && !has_external_non_primfuncs) {
mod->Update(with_annotations);
return mod;
}
// Default fallback, no annotations may be inferred.
return mod;
};
return tvm::transform::CreateModulePass(fpass, 0, "tir.AnnotateEntryFunc", {});
}
transform::Pass Filter(ffi::TypedFunction<bool(PrimFunc)> fcond) {
auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
if (fcond(f)) {
return f;
} else {
return tir::PrimFunc(nullptr);
}
};
return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.Filter", {});
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("tir.transform.AnnotateEntryFunc", AnnotateEntryFunc)
.def("tir.transform.Filter", Filter);
}
} // namespace transform
} // namespace tir
} // namespace tvm