blob: 513f0d730e8c144ae4d5606cfa7015475fd40a56 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file profile_instrumentation.cc
*/
// Insert profile intrinsics at loop and function level. During codegen,
// these intrinsics can be replaced with a call to a target-specific handler
// and can be used to capture profiling information such as processor cycles.
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
namespace tvm {
namespace tir {
namespace lwp {
// Pass-context options read by transform::InstrumentProfileIntrinsics below.
// The defaults quoted here are the fallbacks that pass supplies to GetConfig.

// When true, the function-level start/end profile intrinsics are not emitted
// (default: false).
TVM_REGISTER_PASS_CONFIG_OPTION("tir.lwp_disable_func_prof", Bool);
// Largest loop depth that is still instrumented; only consulted when
// "tir.instr_siblings" is false (default: 0).
TVM_REGISTER_PASS_CONFIG_OPTION("tir.lwp_max_depth", Integer);
// Smallest loop height that is still instrumented; loops below it are skipped
// (default: 1).
TVM_REGISTER_PASS_CONFIG_OPTION("tir.lwp_min_height", Integer);
// When true, only loops that have sibling loops are instrumented and the
// max-depth limit is ignored (default: true).
TVM_REGISTER_PASS_CONFIG_OPTION("tir.instr_siblings", Bool);
// When true, the id counter is reset to 0 before the module is processed
// (default: false).
TVM_REGISTER_PASS_CONFIG_OPTION("tir.reset_start_id", Bool);
// Next profile id to hand out. It is bumped once per analyzed loop and once per
// processed function, and it is only reset when "tir.reset_start_id" is set, so
// ids keep growing across pass invocations otherwise.
// NOTE(review): plain file-level counter with no synchronization visible in
// this file — confirm the pass is never run concurrently.
static int32_t start_id = 0;
/*!
 * \brief Bookkeeping recorded for a single loop during analysis.
 *
 * Every field has a default member initializer so that a default-constructed
 * LoopInfo is always in a well-defined state.
 */
struct LoopInfo {
  LoopInfo() = default;
  /*!
   * \param i Unique id assigned to the loop.
   * \param d Nesting depth of the loop.
   * \param h Height of the loop (defaults to 0).
   */
  LoopInfo(unsigned i, unsigned d, unsigned h = 0)
      : id(i), depth(static_cast<int32_t>(d)), height(static_cast<int32_t>(h)) {}

  // Unique id passed to the start/end profile intrinsics.
  unsigned id = 0;
  // Nesting depth of the loop.
  int32_t depth = 0;
  // Height of the loop; inner-most loops have height 0.
  int32_t height = 0;
  // Set to 'true' if the loop shares its statement sequence with another loop.
  bool has_siblings = false;
  // Set to 'true' if ForKind::kParallel is set for one of the loop's ancestors.
  bool has_parallel = false;
};
// Analysis result: maps each analyzed loop node (by address) to its LoopInfo.
using LoopInfoMap = std::unordered_map<const ForNode*, LoopInfo>;
// Traverse loops depth first and assign them a unique number.
class LoopAnalyzer : public StmtExprVisitor {
public:
LoopInfoMap Analyze(const Stmt& stmt) {
this->VisitStmt(stmt);
return loops;
}
void VisitStmt_(const ForNode* op) final {
LoopInfo loop_info(start_id, 0);
start_id++;
loop_info.height = TraverseLoop(op->body, 0);
loops[op] = loop_info;
}
unsigned TraverseLoop(const Stmt& stmt, unsigned parent_depth, bool has_parallel = false) {
if (stmt->IsInstance<SeqStmtNode>()) {
std::vector<const ForNode*> siblings;
unsigned height = 0;
bool has_loop = false;
const SeqStmtNode* n = stmt.as<SeqStmtNode>();
for (Stmt s : n->seq) {
if (s->IsInstance<ForNode>()) {
has_loop = true;
const ForNode* f = s.as<ForNode>();
LoopInfo loop_info(start_id, parent_depth + 1);
start_id++;
bool parent_parallel = false;
if (has_parallel) {
loop_info.has_parallel = true;
parent_parallel = true;
} else if (f->kind == ForKind::kParallel) {
// has_parallel for the current loop is being set to 'false' since the
// intrinsic is added outside of the loop. The instrumentation isn't
// allowed for the subsequent nested loops.
loop_info.has_parallel = false;
parent_parallel = true;
}
siblings.push_back(f);
height = std::max(height, TraverseLoop(f->body, parent_depth + 1, parent_parallel));
loop_info.height = height;
loops[f] = loop_info;
}
}
if (siblings.size() > 1) {
for (auto* l : siblings) {
loops[l].has_siblings = true;
}
}
height = has_loop ? height + 1 : height;
return height; // Parent's height : max of all children's height
} else if (stmt->IsInstance<IfThenElseNode>()) {
const IfThenElseNode* n = stmt.as<IfThenElseNode>();
unsigned height = TraverseLoop(n->then_case, parent_depth, has_parallel);
if (n->else_case) {
height = std::max(height, TraverseLoop(n->else_case.value(), parent_depth, has_parallel));
}
return height;
} else if (stmt->IsInstance<ForNode>()) {
const ForNode* f = stmt.as<ForNode>();
LoopInfo loop_info(start_id, parent_depth + 1);
start_id++;
bool parent_parallel = false;
if (has_parallel) {
loop_info.has_parallel = true;
parent_parallel = true;
} else if (f->kind == ForKind::kParallel) {
// has_parallel for the current loop is being set to 'false' since the
// intrinsic is added outside of the loop. The instrumentation isn't
// allowed for the subsequent nested loops.
loop_info.has_parallel = false;
parent_parallel = true;
}
unsigned height = TraverseLoop(f->body, parent_depth + 1, parent_parallel);
loop_info.height = height;
loops[f] = loop_info;
return height + 1;
} else if (stmt->IsInstance<LetStmtNode>()) {
const LetStmtNode* n = stmt.as<LetStmtNode>();
return TraverseLoop(n->body, parent_depth, has_parallel);
} else if (stmt->IsInstance<AttrStmtNode>()) {
const AttrStmtNode* n = stmt.as<AttrStmtNode>();
return TraverseLoop(n->body, parent_depth, has_parallel);
} else if (stmt->IsInstance<AllocateNode>()) {
const AllocateNode* n = stmt.as<AllocateNode>();
return TraverseLoop(n->body, parent_depth, has_parallel);
} else {
return 0; // inner-most loop
}
}
private:
LoopInfoMap loops;
};
/*!
 * \brief Mutator that surrounds selected loops with start/end profile intrinsics.
 *
 * Which loops are selected is decided from the LoopInfo gathered by
 * GetLoopInfo together with the depth/height/sibling settings given at
 * construction.
 */
class InstrumentIntrin : public StmtMutator {
 public:
  /*!
   * \param max_depth Largest loop depth that is instrumented (ignored when
   *        instr_siblings is true).
   * \param min_height Smallest loop height that is instrumented.
   * \param instr_siblings Whether only loops with siblings are instrumented.
   */
  InstrumentIntrin(int32_t max_depth, int32_t min_height, bool instr_siblings)
      : max_instr_depth_(max_depth),
        min_instr_height_(min_height),
        instr_siblings_(instr_siblings) {}

  /*! \brief Run the loop analysis on the function body and keep the result. */
  void GetLoopInfo(PrimFuncNode* op) { loops_ = LoopAnalyzer().Analyze(op->body); }

  /*! \brief Flatten sequences so wrapped loops do not leave nested SeqStmts. */
  Stmt VisitStmt_(const SeqStmtNode* op) final {
    return SeqStmt::Flatten(StmtMutator::VisitStmt_(op));
  }

  /*! \brief Wrap the loop with profile intrinsics when it qualifies. */
  Stmt VisitStmt_(const ForNode* op) final {
    Stmt mutated = StmtMutator::VisitStmt_(op);

    // Loops the analysis did not record are left untouched.
    const auto found = loops_.find(op);
    if (found == loops_.end()) return mutated;
    const LoopInfo& info = found->second;

    // Loops nested in a parallel loop are never instrumented.
    const bool inside_parallel = info.has_parallel;
    // Exclude inner-most loops from instrumentation. The inner-most loop has
    // height '0' and it increases as we move outward in the loop nest.
    const bool too_low = info.height < min_instr_height_;
    // With instr_siblings_ only loops that have a sibling qualify and the max
    // depth is ignored; otherwise the max depth is the limit.
    const bool filtered_out =
        instr_siblings_ ? !info.has_siblings : info.depth > max_instr_depth_;
    if (inside_parallel || too_low || filtered_out) {
      return mutated;
    }

    const PrimExpr loop_id = static_cast<int32_t>(info.id);
    const Stmt begin_marker =
        Evaluate(Call(DataType::Handle(), builtin::start_profile_intrinsic(), {loop_id}));
    const Stmt end_marker =
        Evaluate(Call(DataType::Handle(), builtin::end_profile_intrinsic(), {loop_id}));
    return SeqStmt({begin_marker, mutated, end_marker});
  }

 private:
  // Result of the loop analysis, keyed by loop node.
  LoopInfoMap loops_;
  int32_t max_instr_depth_;
  int32_t min_instr_height_;
  bool instr_siblings_;
};
/*!
 * \brief Visitor reporting whether a statement contains a ForKind::kParallel loop.
 */
class CheckParallelLoops : public StmtExprVisitor {
 public:
  /*!
   * \brief Visit `stmt` and report whether a parallel loop was encountered.
   * \note The flag is never cleared, so it also reflects earlier calls made on
   *       the same instance.
   */
  bool HasParallelLoops(const Stmt& stmt) {
    this->VisitStmt(stmt);
    return has_parallel;
  }

 private:
  void VisitStmt_(const ForNode* op) final {
    if (op->kind != ForKind::kParallel) {
      // Not parallel: keep looking inside the loop.
      StmtExprVisitor::VisitStmt_(op);
      return;
    }
    // A parallel loop was found; its body is not visited.
    has_parallel = true;
  }

  bool has_parallel = false;
};
/*!
 * \brief Add function-level and loop-level profile intrinsics to a PrimFunc.
 * \param func The function to instrument.
 * \param max_instr_depth Largest loop depth that is instrumented.
 * \param min_instr_height Smallest loop height that is instrumented.
 * \param instr_siblings Whether only loops with siblings are instrumented.
 * \param disable_func_instrumentation When true, the function body itself is
 *        not wrapped with start/end intrinsics.
 * \return The instrumented function.
 */
PrimFunc AddProfileBuiltins(PrimFunc func, int32_t max_instr_depth, int32_t min_instr_height,
                            bool instr_siblings, bool disable_func_instrumentation) {
  PrimFuncNode* node = func.CopyOnWrite();

  // An id is consumed for the function even when it ends up not being wrapped.
  const PrimExpr func_id = start_id++;
  if (!disable_func_instrumentation) {
    const Stmt enter =
        Evaluate(Call(DataType::Handle(), builtin::start_profile_intrinsic(), {func_id}));
    const Stmt leave =
        Evaluate(Call(DataType::Handle(), builtin::end_profile_intrinsic(), {func_id}));
    node->body = SeqStmt({enter, std::move(node->body), leave});
  }

  // Loop analysis runs on the (possibly wrapped) body, then the loops are rewritten.
  InstrumentIntrin instrumenter(max_instr_depth, min_instr_height, instr_siblings);
  instrumenter.GetLoopInfo(node);
  node->body = instrumenter(std::move(node->body));
  return func;
}
} // namespace lwp
namespace transform {
/*!
 * \brief Create the module pass that instruments every PrimFunc of a module
 *        with profile intrinsics.
 * \return The "tir.InstrumentProfileIntrinsics" module pass.
 */
Pass InstrumentProfileIntrinsics() {
  auto pass_func = [](IRModule m, PassContext ctx) {
    auto* mptr = m.CopyOnWrite();

    // Loop selection settings:
    //  - "tir.lwp_max_depth": loops with depth <= this value are candidates.
    //    The default of 0 selects only the outer-most loops.
    //  - "tir.lwp_min_height": loops with a height below this value are
    //    skipped. The default of 1 leaves the inner-most loops alone.
    //  - "tir.instr_siblings": restrict instrumentation to loops that have
    //    siblings; the max depth is not applied in that mode.
    const auto depth_opt = ctx->GetConfig<Integer>("tir.lwp_max_depth", Integer(0)).value();
    const auto height_opt = ctx->GetConfig<Integer>("tir.lwp_min_height", Integer(1)).value();
    const bool instr_siblings = ctx->GetConfig<Bool>("tir.instr_siblings", Bool(true)).value();
    const bool disable_func_instrumentation =
        ctx->GetConfig<Bool>("tir.lwp_disable_func_prof", Bool(false)).value();

    // Optionally restart the id numbering for this module.
    if (ctx->GetConfig<Bool>("tir.reset_start_id", Bool(false)).value()) {
      lwp::start_id = 0;
    }

    // Collect the rewritten functions first so that the function table is not
    // modified while it is being iterated.
    std::vector<std::pair<GlobalVar, PrimFunc>> updates;
    for (const auto& entry : mptr->functions) {
      auto prim_func = entry.second.as<PrimFunc>();
      if (!prim_func) continue;
      updates.emplace_back(
          entry.first,
          lwp::AddProfileBuiltins(prim_func.value(), depth_opt.IntValue(), height_opt.IntValue(),
                                  instr_siblings, disable_func_instrumentation));
    }
    for (const auto& update : updates) {
      mptr->AddUnchecked(update.first, update.second);
    }
    return m;
  };
  return tvm::transform::CreateModulePass(pass_func, 0, "tir.InstrumentProfileIntrinsics", {});
}
// Register the pass constructor under its global FFI name.
TVM_FFI_STATIC_INIT_BLOCK() {
  tvm::ffi::reflection::GlobalDef().def("tir.transform.InstrumentProfileIntrinsics",
                                        InstrumentProfileIntrinsics);
}
} // namespace transform
} // namespace tir
} // namespace tvm