blob: 9f6f4da7eaf3e980eb5305c6f893f52198521783 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
namespace tvm {
namespace tir {
/*!
 * \brief Find the entry PrimFunc of an IRModule.
 *
 * Candidates are selected in the following priority order:
 *   1. The first PrimFunc (in iteration order of `mod->functions`) whose
 *      `tir::attr::kIsEntryFunc` attribute is nonzero.
 *   2. A PrimFunc whose GlobalVar name hint is "main".
 *   3. The only PrimFunc in the module, if it contains exactly one.
 *
 * \param mod The IRModule to search.
 * \param result_g_var Optional out-parameter. When non-null and an entry
 *        function is found, it receives the GlobalVar bound to that function.
 *        It is left untouched when nullptr is returned.
 * \return The selected PrimFunc, or nullptr if none of the rules apply.
 */
const PrimFuncNode* FindEntryFunc(const IRModule& mod, GlobalVar* result_g_var) {
  int num_prim_func = 0;
  const tir::PrimFuncNode* main_func = nullptr;
  GlobalVar main_gv = NullValue<GlobalVar>();
  const tir::PrimFuncNode* last_func = nullptr;
  // Track the GlobalVar of the last PrimFunc seen as well, so that Priority 3
  // reports the correct GlobalVar (not a null one) through `result_g_var`.
  GlobalVar last_gv = NullValue<GlobalVar>();

  // Writes `gv` to the out-parameter (if requested) and returns `func`.
  auto select = [result_g_var](const PrimFuncNode* func, const GlobalVar& gv) {
    if (result_g_var != nullptr) {
      *result_g_var = gv;
    }
    return func;
  };

  for (const auto& kv : mod->functions) {
    GlobalVar gv = kv.first;
    BaseFunc base_func = kv.second;
    if (const auto* func = base_func.as<tir::PrimFuncNode>()) {
      // Priority 1: an explicitly marked entry function wins immediately.
      if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
        return select(func, gv);
      }
      if (gv->name_hint == "main") {
        main_func = func;
        main_gv = gv;
      }
      last_func = func;
      last_gv = gv;
      ++num_prim_func;
    }
  }
  // Priority 2: PrimFunc whose name is `main`
  if (main_func != nullptr) {
    return select(main_func, main_gv);
  }
  // Priority 3: The only PrimFunc in the IRModule
  if (num_prim_func == 1) {
    return select(last_func, last_gv);
  }
  return nullptr;
}
/*!
 * \brief Find the top-level loop of a function body that contains `block`.
 *
 * The function first locates a "root" SeqStmt of `func_body`. This is a
 * SeqStmt not nested inside another SeqStmt; if several such SeqStmts exist,
 * the last one visited wins. It then returns the first `For` child of that
 * SeqStmt whose subtree contains `block`. Blocks nested inside other blocks
 * are found too.
 *
 * \param block The block to locate.
 * \param func_body The function body to search.
 * \return The enclosing top-level For statement. Aborts via ICHECK/LOG(FATAL)
 *         if no root SeqStmt exists or no such loop is found.
 */
Stmt GetEnclosingLoop(const BlockNode* block, Stmt func_body) {
  // Records a SeqStmt without descending into it (see the function comment).
  struct GetRootSeqStmt : public StmtVisitor {
    void VisitStmt_(const SeqStmtNode* seq) override { result = seq; }
    // Must start as nullptr: the ICHECK below relies on it when no SeqStmt exists.
    const SeqStmtNode* result = nullptr;
  };
  // Sets `found` when `target` occurs anywhere in the visited subtree.
  struct BlockFinder : public StmtVisitor {
    explicit BlockFinder(const BlockNode* tgt) : target(tgt) {}
    void VisitStmt_(const BlockNode* block) override {
      if (found) {
        return;
      }
      if (block == target) {
        found = true;
        return;
      }
      // The target may be nested inside another block; keep searching.
      StmtVisitor::VisitStmt_(block);
    }
    const BlockNode* target;
    bool found = false;
  };
  GetRootSeqStmt seq_finder;
  seq_finder(func_body);
  ICHECK(seq_finder.result) << "No SeqStmt found in the function body";
  for (const auto& stmt : seq_finder.result->seq) {
    if (stmt->IsInstance<ForNode>()) {
      BlockFinder finder(block);
      finder(stmt);
      if (finder.found) {
        return stmt;
      }
    }
  }
  LOG(FATAL) << "Enclosing loop not found for a block " << ffi::GetRef<Block>(block);
  TVM_FFI_UNREACHABLE();
}
/*!
 * \brief Find the "anchor" block of a module's entry PrimFunc.
 *
 * The entry PrimFunc is chosen by FindEntryFunc. Candidate blocks are those
 * with an `init` statement (reduction blocks):
 *   - with no candidates, nullptr is returned;
 *   - with one candidate, that block is returned directly;
 *   - otherwise, the candidate whose enclosing top-level loop has the highest
 *     estimated FLOP count (EstimateTIRFlops) wins. On ties, the earliest
 *     candidate in collection order is kept.
 *
 * \param mod The IRModule to analyze.
 * \return The anchor block, or nullptr if there is no entry function or no
 *         reduction block.
 */
const BlockNode* FindAnchorBlock(const IRModule& mod) {
  // Collects blocks with an `init` statement, in pre-order. Below each block
  // only `block->body` is traversed; `init` and other fields are not visited.
  struct ReductionBlockCollector : public StmtVisitor {
    void VisitStmt_(const BlockNode* block) override {
      if (block->init) {
        blocks.push_back(block);
      }
      StmtVisitor::VisitStmt(block->body);
    }
    std::vector<const BlockNode*> blocks;
  };

  const PrimFuncNode* prim_func = FindEntryFunc(mod, nullptr);
  if (prim_func == nullptr) {
    return nullptr;
  }

  ReductionBlockCollector collector;
  collector(prim_func->body);
  const std::vector<const BlockNode*>& candidates = collector.blocks;
  if (candidates.empty()) {
    return nullptr;
  }
  if (candidates.size() == 1) {
    return candidates[0];
  }

  // Strict '>' keeps the first candidate among equal FLOP estimates.
  double best_flops = -1;
  size_t best_idx = 0;
  for (size_t i = 0; i < candidates.size(); ++i) {
    const Stmt loop = GetEnclosingLoop(candidates[i], prim_func->body);
    const auto flops = EstimateTIRFlops(loop);
    if (flops > best_flops) {
      best_flops = flops;
      best_idx = i;
    }
  }
  return candidates[best_idx];
}
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  // Register FindAnchorBlock as the global FFI function
  // "tir.analysis.find_anchor_block". A null result becomes an empty Optional.
  refl::GlobalDef().def("tir.analysis.find_anchor_block",
                        [](const IRModule& mod) -> ffi::Optional<Block> {
                          const BlockNode* anchor = FindAnchorBlock(mod);
                          if (anchor == nullptr) {
                            return ffi::Optional<Block>(std::nullopt);
                          }
                          return ffi::Optional<Block>(ffi::GetRef<Block>(anchor));
                        });
}
} // namespace tir
} // namespace tvm