blob: 6701308fbfb7858a720d778dc4220fb8fe87f614 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/ir/instrument.cc
* \brief Infrastructure for instrumentation.
*/
#include <dmlc/thread_local.h>
#include <tvm/ir/instrument.h>
#include <tvm/ir/transform.h>
#include <tvm/node/repr_printer.h>
#include <tvm/runtime/registry.h>

#include <chrono>
#include <iomanip>
#include <sstream>
#include <stack>
#include <tuple>
#include <utility>
#include <vector>
namespace tvm {
namespace instrument {
/*!
 * \brief Base PassInstrument implementation whose hooks forward to user-supplied callbacks.
 *
 * Every callback is optional: a null callback makes the corresponding hook a no-op
 * (and makes ShouldRun return true).
 * \sa BasePassInstrument
 */
class BasePassInstrumentNode : public PassInstrumentNode {
 public:
  /*! \brief Callback to run when entering PassContext (may be null). */
  runtime::TypedPackedFunc<void()> enter_pass_ctx_callback;
  /*! \brief Callback to run when exiting PassContext (may be null). */
  runtime::TypedPackedFunc<void()> exit_pass_ctx_callback;
  /*! \brief Callback that determines whether to run a pass or not (may be null). */
  runtime::TypedPackedFunc<bool(const IRModule&, const transform::PassInfo&)> should_run_callback;
  /*! \brief Callback to run before a pass (may be null). */
  runtime::TypedPackedFunc<void(const IRModule&, const transform::PassInfo&)>
      run_before_pass_callback;
  /*! \brief Callback to run after a pass (may be null). */
  runtime::TypedPackedFunc<void(const IRModule&, const transform::PassInfo&)>
      run_after_pass_callback;

  /*! \brief Instrument when entering PassContext. */
  void EnterPassContext() const final;

  /*! \brief Instrument when exiting PassContext. */
  void ExitPassContext() const final;

  /*!
   * \brief Determine whether to run the pass or not.
   * \param mod The module that an optimization pass runs on.
   * \param info The pass information.
   *
   * \return true to run the pass; false to skip the pass. Returns true when no
   *  should_run_callback is set.
   */
  bool ShouldRun(const IRModule& mod, const transform::PassInfo& info) const final;

  /*!
   * \brief Instrument before pass run.
   * \param mod The module that an optimization pass runs on.
   * \param info The pass information.
   */
  void RunBeforePass(const IRModule& mod, const transform::PassInfo& info) const final;

  /*!
   * \brief Instrument after pass run.
   * \param mod The module that an optimization pass runs on.
   * \param info The pass information.
   */
  void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const final;

  static constexpr const char* _type_key = "instrument.PassInstrument";
  TVM_DECLARE_FINAL_OBJECT_INFO(BasePassInstrumentNode, PassInstrumentNode);
};
/*!
 * \brief Managed reference to a BasePassInstrumentNode.
 * \sa BasePassInstrumentNode
 */
class BasePassInstrument : public PassInstrument {
 public:
  /*!
   * \brief Build an instrument whose hooks forward to the given callbacks.
   *
   * \param name Name for this instrumentation.
   * \param enter_pass_ctx_callback Called when entering the pass context.
   * \param exit_pass_ctx_callback Called when exiting the pass context.
   * \param should_run_callback Decides whether a pass runs (true: run it; false: skip it).
   * \param run_before_pass_callback Called before each pass runs.
   * \param run_after_pass_callback Called after each pass runs.
   */
  TVM_DLL BasePassInstrument(
      String name, runtime::TypedPackedFunc<void()> enter_pass_ctx_callback,
      runtime::TypedPackedFunc<void()> exit_pass_ctx_callback,
      runtime::TypedPackedFunc<bool(const IRModule&, const transform::PassInfo&)>
          should_run_callback,
      runtime::TypedPackedFunc<void(const IRModule&, const transform::PassInfo&)>
          run_before_pass_callback,
      runtime::TypedPackedFunc<void(const IRModule&, const transform::PassInfo&)>
          run_after_pass_callback);

  TVM_DEFINE_OBJECT_REF_METHODS(BasePassInstrument, PassInstrument, BasePassInstrumentNode);
};
BasePassInstrument::BasePassInstrument(
    String name, runtime::TypedPackedFunc<void()> enter_pass_ctx_callback,
    runtime::TypedPackedFunc<void()> exit_pass_ctx_callback,
    runtime::TypedPackedFunc<bool(const IRModule&, const transform::PassInfo&)> should_run_callback,
    runtime::TypedPackedFunc<void(const IRModule&, const transform::PassInfo&)>
        run_before_pass_callback,
    runtime::TypedPackedFunc<void(const IRModule&, const transform::PassInfo&)>
        run_after_pass_callback) {
  // Allocate the backing node, then move every by-value argument into it.
  auto node = make_object<BasePassInstrumentNode>();
  node->name = std::move(name);
  node->enter_pass_ctx_callback = std::move(enter_pass_ctx_callback);
  node->exit_pass_ctx_callback = std::move(exit_pass_ctx_callback);
  node->should_run_callback = std::move(should_run_callback);
  node->run_before_pass_callback = std::move(run_before_pass_callback);
  node->run_after_pass_callback = std::move(run_after_pass_callback);
  // This reference takes ownership of the node.
  data_ = std::move(node);
}
void BasePassInstrumentNode::EnterPassContext() const {
  // An unset callback means there is nothing to do on context entry.
  if (enter_pass_ctx_callback == nullptr) {
    return;
  }
  enter_pass_ctx_callback();
}
void BasePassInstrumentNode::ExitPassContext() const {
  // An unset callback means there is nothing to do on context exit.
  if (exit_pass_ctx_callback == nullptr) {
    return;
  }
  exit_pass_ctx_callback();
}
bool BasePassInstrumentNode::ShouldRun(const IRModule& ir_module,
                                       const transform::PassInfo& pass_info) const {
  // Without a predicate every pass is allowed to run; otherwise defer to the predicate.
  return should_run_callback == nullptr || should_run_callback(ir_module, pass_info);
}
void BasePassInstrumentNode::RunBeforePass(const IRModule& ir_module,
                                           const transform::PassInfo& pass_info) const {
  // No pre-pass hook registered: nothing to do.
  if (run_before_pass_callback == nullptr) {
    return;
  }
  run_before_pass_callback(ir_module, pass_info);
}
void BasePassInstrumentNode::RunAfterPass(const IRModule& ir_module,
                                          const transform::PassInfo& pass_info) const {
  // No post-pass hook registered: nothing to do.
  if (run_after_pass_callback == nullptr) {
    return;
  }
  run_after_pass_callback(ir_module, pass_info);
}
TVM_REGISTER_NODE_TYPE(BasePassInstrumentNode);

// Global "instrument.PassInstrument": builds a BasePassInstrument from a name and callbacks.
TVM_REGISTER_GLOBAL("instrument.PassInstrument")
    .set_body_typed(
        [](String name, runtime::TypedPackedFunc<void()> enter_pass_ctx,
           runtime::TypedPackedFunc<void()> exit_pass_ctx,
           runtime::TypedPackedFunc<bool(const IRModule&, const transform::PassInfo&)> should_run,
           runtime::TypedPackedFunc<void(const IRModule&, const transform::PassInfo&)>
               run_before_pass,
           runtime::TypedPackedFunc<void(const IRModule&, const transform::PassInfo&)>
               run_after_pass) {
          // Arguments arrive by value, so hand them over rather than copying them again.
          return BasePassInstrument(std::move(name), std::move(enter_pass_ctx),
                                    std::move(exit_pass_ctx), std::move(should_run),
                                    std::move(run_before_pass), std::move(run_after_pass));
        });
// Repr of a BasePassInstrumentNode is just its name.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<BasePassInstrumentNode>([](const ObjectRef& ref, ReprPrinter* p) {
      const auto* instrument = static_cast<const BasePassInstrumentNode*>(ref.get());
      p->stream << instrument->name;
    });
/*! \brief PassProfile stores profiling information for a given pass and its sub-passes. */
struct PassProfile {
  // TODO(@altanh): expose PassProfile through TVM Object API
  using Clock = std::chrono::steady_clock;
  using Duration = std::chrono::duration<double, std::micro>;
  using Time = std::chrono::time_point<Clock>;

  /*! \brief The name of the pass being profiled. */
  String name;
  /*! \brief The time when the pass was entered. */
  Time start;
  /*! \brief The time when the pass completed (equal to `start` until the pass exits). */
  Time end;
  /*! \brief The total duration of the pass, i.e. end - start (zero until the pass exits). */
  Duration duration;
  /*! \brief PassProfiles for all sub-passes invoked during the execution of the pass. */
  std::vector<PassProfile> children;

  /*!
   * \brief Start a profile for a pass, stamping the current time as its start.
   * \param name The name of the pass being profiled.
   */
  explicit PassProfile(String name)
      : name(std::move(name)),
        start(Clock::now()),
        // Reuse `start` (declared before `end`) instead of sampling the clock a second time.
        end(start),
        // std::chrono::duration's default constructor leaves a double rep uninitialized;
        // start at zero so an unfinished pass never carries an indeterminate value.
        duration(Duration::zero()),
        children() {}

  /*! \brief Gets the PassProfile of the currently executing pass. */
  static PassProfile* Current();
  /*! \brief Pushes a new PassProfile with the given pass name. */
  static void EnterPass(String name);
  /*! \brief Pops the current PassProfile. */
  static void ExitPass();
};
/*! \brief Per-thread pass-profiling state. */
struct PassProfileThreadLocalEntry {
  /*! \brief Placeholder top-level PassProfile; top-level passes become its children. */
  PassProfile root{"root"};
  /*! \brief Stack of PassProfiles for the nested passes currently running (innermost on top). */
  std::stack<PassProfile*> profile_stack;
};
/*! \brief Thread-local store holding each thread's pass profiling data. */
using PassProfileThreadLocalStore = dmlc::ThreadLocalStore<PassProfileThreadLocalEntry>;
void PassProfile::EnterPass(String name) {
  // The new pass is a child of whichever pass is running now (or of the root placeholder).
  std::vector<PassProfile>& siblings = PassProfile::Current()->children;
  siblings.emplace_back(name);
  // The freshly created profile becomes the innermost active one.
  PassProfileThreadLocalStore::Get()->profile_stack.push(&siblings.back());
}
void PassProfile::ExitPass() {
PassProfile* cur = PassProfile::Current();
ICHECK_NE(cur->name, "root") << "mismatched enter/exit for pass profiling";
cur->end = PassProfile::Clock::now();
cur->duration = std::chrono::duration_cast<PassProfile::Duration>(cur->end - cur->start);
PassProfileThreadLocalStore::Get()->profile_stack.pop();
}
PassProfile* PassProfile::Current() {
  PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get();
  // With no pass running, the placeholder root stands in as the current profile.
  return entry->profile_stack.empty() ? &entry->root : entry->profile_stack.top();
}
/*!
 * \brief Render the pass profiles collected on the calling thread as an indented text tree.
 *
 * Each profiled pass produces one line, indented with one tab per nesting level:
 *   `<name>: <total>us [<self>us] (<percent of all top-level time>%; <percent of parent>%)`
 * where "self" time is the pass's own duration minus the durations of its sub-passes.
 * Passes are listed in the order they were entered.
 *
 * \return The rendered report, or an empty String (after logging a warning) if nothing was
 *  profiled.
 */
String RenderPassProfiles() {
  PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get();
  CHECK(entry->profile_stack.empty()) << "cannot print pass profile while still in a pass!";
  if (entry->root.children.empty()) {
    LOG(WARNING) << "no passes have been profiled, did you enable pass profiling?";
    return String();
  }
  // `part` as a percentage of `whole`. A zero `whole` (possible when passes finish within the
  // clock's resolution) yields 0 instead of nan/inf.
  auto percent_of = [](PassProfile::Duration part, PassProfile::Duration whole) {
    return whole.count() > 0 ? part.count() / whole.count() * 100.0 : 0.0;
  };
  // Total time of all top-level passes: the denominator of the first percentage.
  PassProfile::Duration top_dur(0);
  for (const PassProfile& top : entry->root.children) {
    top_dur += top.duration;
  }
  // Explicit DFS over (depth, parent_duration, pass). Siblings are pushed in reverse so that
  // they are popped -- and printed -- in the order they were entered.
  std::stack<std::tuple<size_t, PassProfile::Duration, PassProfile*>> profiles;
  for (auto it = entry->root.children.rbegin(); it != entry->root.children.rend(); ++it) {
    profiles.push(std::make_tuple(0, top_dur, &*it));
  }
  std::ostringstream os;
  os << std::fixed;
  while (!profiles.empty()) {
    auto [depth, parent_duration, profile] = profiles.top();
    profiles.pop();
    // indent depth
    for (size_t i = 0; i < depth; ++i) {
      os << '\t';
    }
    // Time spent in the pass itself (excluding sub-passes); also queue the sub-passes.
    PassProfile::Duration self_duration = profile->duration;
    for (auto it = profile->children.rbegin(); it != profile->children.rend(); ++it) {
      self_duration -= it->duration;
      profiles.push(std::make_tuple(depth + 1, profile->duration, &*it));
    }
    double parent_pct = percent_of(profile->duration, parent_duration);
    double total_pct = percent_of(profile->duration, top_dur);
    os << profile->name << ": ";
    os << std::setprecision(0);
    os << profile->duration.count() << "us [" << self_duration.count() << "us] ";
    os << std::setprecision(2) << "(" << total_pct << "%; " << parent_pct << "%)\n";
  }
  return os.str();
}
TVM_REGISTER_GLOBAL("instrument.RenderTimePassProfiles").set_body_typed(RenderPassProfiles);

/*!
 * \brief Global "instrument.MakePassTimingInstrument": creates a PassInstrument that times
 *  (with std::chrono::steady_clock) every pass run under it.
 *
 * Results are read with "instrument.RenderTimePassProfiles" and are discarded when the
 * PassContext exits.
 */
TVM_REGISTER_GLOBAL("instrument.MakePassTimingInstrument").set_body_typed([]() {
  // The callback type returns void, so the lambda does too (it previously returned a stray
  // `true` that was silently discarded).
  auto run_before_pass = [](const IRModule&, const transform::PassInfo& pass_info) {
    PassProfile::EnterPass(pass_info->name);
  };
  auto run_after_pass = [](const IRModule&, const transform::PassInfo&) {
    PassProfile::ExitPass();
  };
  // Drop all collected profiles when leaving the PassContext.
  // NOTE(review): profile_stack is not reset here. If a pass could still be in flight when the
  // context exits (RunAfterPass never reached), the stack would keep pointers into the cleared
  // `children` -- confirm this cannot happen.
  auto exit_pass_ctx = []() { PassProfileThreadLocalStore::Get()->root.children.clear(); };
  return BasePassInstrument("PassTimingInstrument",
                            /* enter_pass_ctx */ nullptr, exit_pass_ctx, /* should_run */ nullptr,
                            run_before_pass, run_after_pass);
});
} // namespace instrument
} // namespace tvm