blob: d8ef66d3db105327370b0cffb8370dd7e2390e11 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_META_SCHEDULE_UTILS_H_
#define TVM_META_SCHEDULE_UTILS_H_
#include <dmlc/memory_io.h>
#include <tvm/arith/analyzer.h>
#include <tvm/meta_schedule/arg_info.h>
#include <tvm/meta_schedule/builder.h>
#include <tvm/meta_schedule/cost_model.h>
#include <tvm/meta_schedule/database.h>
#include <tvm/meta_schedule/extracted_task.h>
#include <tvm/meta_schedule/feature_extractor.h>
#include <tvm/meta_schedule/measure_callback.h>
#include <tvm/meta_schedule/profiler.h>
#include <tvm/meta_schedule/runner.h>
#include <tvm/meta_schedule/schedule_rule.h>
#include <tvm/meta_schedule/search_strategy.h>
#include <tvm/meta_schedule/space_generator.h>
#include <tvm/meta_schedule/task_scheduler.h>
#include <tvm/meta_schedule/tune_context.h>
#include <tvm/node/node.h>
#include <tvm/node/serialization.h>
#include <tvm/runtime/container/optional.h>
#include <tvm/support/parallel_for.h>
#include <tvm/tir/schedule/schedule.h>
#include <tvm/tir/transform.h>
#include <algorithm>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "../support/array.h"
#include "../support/base64.h"
#include "../support/nd_int_set.h"
#include "../support/table_printer.h"
#include "../support/utils.h"
#include "../tir/schedule/primitive.h"
#include "../tir/schedule/utils.h"
/*!
 * \brief Start a log message that is delivered when the temporary PyLogMessage is destroyed.
 * \param logging_level The name of a PyLogMessage::Level enumerator (e.g. DEBUG, INFO).
 * \param logger The logging function handed to PyLogMessage.
 * \note Every name is fully qualified so that the macro expands correctly regardless of the
 * namespace it is used from.
 */
#define TVM_PY_LOG(logging_level, logger)                                     \
  ::tvm::meta_schedule::PyLogMessage(                                         \
      __FILE__, __LINE__, logger,                                             \
      ::tvm::meta_schedule::PyLogMessage::Level::logging_level)               \
      .stream()
/*!
 * \brief Clear the logging output, see ::tvm::meta_schedule::clear_logging.
 * \param logging_func The logging function handed to clear_logging.
 */
#define TVM_PY_LOG_CLEAR_SCREEN(logging_func) \
  ::tvm::meta_schedule::clear_logging(__FILE__, __LINE__, logging_func)
namespace tvm {
namespace meta_schedule {
/*!
 * \brief Class to accumulate a log message on the python side. Do not use directly, instead use
 * TVM_PY_LOG(DEBUG, logger), TVM_PY_LOG(INFO, logger), TVM_PY_LOG(WARNING, logger),
 * TVM_PY_LOG(ERROR, logger).
 * \sa TVM_PY_LOG
 * \sa TVM_PY_LOG_CLEAR_SCREEN
 */
class PyLogMessage {
 public:
  /*! \brief The logging levels; the underlying integer is what is forwarded to the logger. */
  enum class Level : int32_t {
    CLEAR = -10,
    DEBUG = 10,
    INFO = 20,
    WARNING = 30,
    ERROR = 40,
    // FATAL not included
  };

  /*!
   * \brief Start a message attributed to the given source location.
   * \param filename The source file reported together with the message.
   * \param lineno The source line reported together with the message.
   * \param logger The logging function receiving the message; when it compares equal to
   * nullptr the message goes through the runtime logging facilities instead.
   * \param logging_level The level of the message, must not be Level::CLEAR.
   */
  explicit PyLogMessage(const char* filename, int lineno, PackedFunc logger, Level logging_level)
      : filename_(filename), lineno_(lineno), logger_(logger), logging_level_(logging_level) {}

  /*! \brief Deliver the accumulated message to the logger, or to the runtime logging. */
  TVM_NO_INLINE ~PyLogMessage() {
    ICHECK(logging_level_ != Level::CLEAR)
        << "Cannot use CLEAR as logging level in TVM_PY_LOG, please use TVM_PY_LOG_CLEAR_SCREEN.";
    // A logger was supplied: hand over level, location and text, nothing else to do.
    if (logger_ != nullptr) {
      logger_(static_cast<int>(logging_level_), std::string(filename_), lineno_, stream_.str());
      return;
    }
    // No logger: map the level onto the matching runtime log level.
    switch (logging_level_) {
      case Level::INFO:
        runtime::detail::LogMessage(filename_, lineno_, TVM_LOG_LEVEL_INFO).stream()
            << stream_.str();
        break;
      case Level::WARNING:
        runtime::detail::LogMessage(filename_, lineno_, TVM_LOG_LEVEL_WARNING).stream()
            << stream_.str();
        break;
      case Level::ERROR:
        runtime::detail::LogMessage(filename_, lineno_, TVM_LOG_LEVEL_ERROR).stream()
            << stream_.str();
        break;
      case Level::DEBUG:
        runtime::detail::LogMessage(filename_, lineno_, TVM_LOG_LEVEL_DEBUG).stream()
            << stream_.str();
        break;
      default:
        // Any level without a dedicated mapping is reported as fatal.
        runtime::detail::LogFatal(filename_, lineno_).stream() << stream_.str();
        break;
    }
  }

  /*! \brief The stream that the message text is written into. */
  std::ostringstream& stream() { return stream_; }

 private:
  /*! \brief The source file the message is attributed to. */
  const char* filename_;
  /*! \brief The source line the message is attributed to. */
  int lineno_;
  /*! \brief The buffer accumulating the message text. */
  std::ostringstream stream_;
  /*! \brief The logging function receiving the message, if any. */
  PackedFunc logger_;
  /*! \brief The level of the message. */
  Level logging_level_;
};
/*!
 * \brief Whether the tuning is running on ipython kernel.
 * \return The result of the registered "meta_schedule.using_ipython" function, or false when
 * no such function is registered.
 */
inline bool using_ipython() {
  const auto* f_using_ipython = runtime::Registry::Get("meta_schedule.using_ipython");
  if (f_using_ipython == nullptr) {
    return false;
  }
  const bool is_ipython = (*f_using_ipython)();
  return is_ipython;
}
/*!
 * \brief Print out the performance table interactively in jupyter notebook.
 * \param data The serialized performance table.
 * \note Fails an ICHECK when "meta_schedule.print_interactive_table" is not registered.
 */
inline void print_interactive_table(const String& data) {
  const auto* f_print_interactive_table =
      runtime::Registry::Get("meta_schedule.print_interactive_table");
  // Test the pointer itself before dereferencing it, so that a missing registration is
  // reported through the check below instead of dereferencing a null pointer.
  ICHECK(f_print_interactive_table != nullptr && f_print_interactive_table->defined())
      << "Cannot find print_interactive_table function in registry.";
  (*f_print_interactive_table)(data);
}
/*!
 * \brief A helper function to clear logging output for ipython kernel and console.
 *
 * Nothing happens unless the environment variable TVM_META_SCHEDULE_CLEAR_SCREEN is set to "1".
 * \param file The file name.
 * \param lineno The line number.
 * \param logging_func The logging function.
 */
inline void clear_logging(const char* file, int lineno, PackedFunc logging_func) {
  const char* env_p = std::getenv("TVM_META_SCHEDULE_CLEAR_SCREEN");
  if (env_p == nullptr || std::string(env_p) != "1") {
    return;
  }
  if (logging_func.defined() && using_ipython()) {
    // Ask the logging function to clear the output by sending the CLEAR level.
    logging_func(static_cast<int>(PyLogMessage::Level::CLEAR), file, lineno, "");
    return;
  }
  // this would clear all logging output in the console
  runtime::detail::LogMessage(file, lineno, TVM_LOG_LEVEL_INFO).stream()
      << "\033c\033[3J\033[2J\033[0m\033[H";
}
/*!
 * \brief The type of the random state, an alias of
 * support::LinearCongruentialEngine::TRandState.
 */
using TRandState = support::LinearCongruentialEngine::TRandState;
/*!
* \brief Get the base64 encoded result of a string.
* \param str The string to encode.
* \return The base64 encoded string.
*/
inline std::string Base64Encode(std::string str) {
std::string result;
dmlc::MemoryStringStream m_stream(&result);
support::Base64OutStream b64stream(&m_stream);
static_cast<dmlc::Stream*>(&b64stream)->Write(str);
b64stream.Finish();
return result;
}
/*!
* \brief Get the base64 decoded result of a string.
* \param str The string to decode.
* \return The base64 decoded string.
*/
inline std::string Base64Decode(std::string str) {
std::string result;
dmlc::MemoryStringStream m_stream(&str);
support::Base64InStream b64stream(&m_stream);
b64stream.InitPosition();
static_cast<dmlc::Stream*>(&b64stream)->Read(&result);
return result;
}
/*!
 * \brief Parses a json string into a json object.
 * \note Only declared in this header; the definition is not part of it.
 * \param json_str The json string.
 * \return The json object
 */
ObjectRef JSONLoads(std::string json_str);
/*!
 * \brief Dumps a json object into a json string.
 * \note Only declared in this header; the definition is not part of it.
 * \param json_obj The json object.
 * \return The json string
 */
std::string JSONDumps(ObjectRef json_obj);
/*!
 * \brief Converts a structural hash code to its decimal string form.
 * \param hash_code The hash code
 * \return The string representation of the hash code, as produced by std::to_string
 */
inline String SHash2Str(Workload::THashCode hash_code) {
  const std::string decimal = std::to_string(hash_code);
  return String(decimal);
}
/*!
 * \brief Converts a TVM object to the hex string representation of its structural hash.
 * \param obj The TVM object; an undefined object is treated as having hash code 0.
 * \return The hash code as "0x" followed by hex digits, zero-padded to at least 16 digits.
 */
inline String SHash2Hex(const ObjectRef& obj) {
  const size_t hash_code =
      obj.defined() ? static_cast<size_t>(StructuralHash()(obj)) : static_cast<size_t>(0);
  std::ostringstream os;
  // The prefix is written first so that the field width only applies to the digits.
  os << "0x";
  os << std::hex << std::setfill('0') << std::setw(16) << hash_code;
  return os.str();
}
/*!
 * \brief Fork a random state into another, i.e. PRNG splitting.
 * The given random state is also mutated.
 * \param rand_state The random state to be forked
 * \return The forked random state
 */
inline support::LinearCongruentialEngine::TRandState ForkSeed(
    support::LinearCongruentialEngine::TRandState* rand_state) {
  support::LinearCongruentialEngine engine(rand_state);
  return engine.ForkSeed();
}
/*!
 * \brief Fork a random state into multiple ones, i.e. PRNG splitting.
 * The given random state is also mutated.
 * \param rand_state The random state to be forked
 * \param n The number of forks
 * \return The forked random states, in the order they were forked
 */
inline std::vector<support::LinearCongruentialEngine::TRandState> ForkSeed(
    support::LinearCongruentialEngine::TRandState* rand_state, int n) {
  std::vector<support::LinearCongruentialEngine::TRandState> forked;
  forked.reserve(n);
  // Each generated element wraps the state in a fresh engine and forks it once.
  std::generate_n(std::back_inserter(forked), n, [rand_state]() {
    return support::LinearCongruentialEngine(rand_state).ForkSeed();
  });
  return forked;
}
/*!
 * \brief Get deep copy of an IRModule.
 *
 * The copy is obtained by saving the module to JSON and loading it back.
 * \param mod The IRModule to make a deep copy.
 * \return The deep copy of the IRModule.
 */
inline IRModule DeepCopyIRModule(IRModule mod) {
  auto serialized = SaveJSON(mod);
  ObjectRef restored = LoadJSON(std::move(serialized));
  return Downcast<IRModule>(restored);
}
/*!
* \brief Concatenate strings
* \param strs The strings to concatenate
* \param delim The delimiter
* \return The concatenated string
*/
inline std::string Concat(const Array<String>& strs, const std::string& delim) {
if (strs.empty()) {
return "";
}
std::ostringstream os;
os << strs[0];
for (int i = 1, n = strs.size(); i < n; ++i) {
os << delim << strs[i];
}
return os.str();
}
/*!
 * \brief Get the BlockRV from a block StmtSRef
 *
 * The block is looked up in the schedule by its name hint.
 * \param sch The schedule
 * \param block_sref The block StmtSRef
 * \param global_var_name The global variable name
 * \return The BlockRV
 */
inline tir::BlockRV GetRVFromSRef(const tir::Schedule& sch, const tir::StmtSRef& block_sref,
                                  const String& global_var_name) {
  const tir::BlockNode* block_node = TVM_SREF_TO_BLOCK(block_sref);
  const String& block_name = block_node->name_hint;
  return sch->GetBlock(block_name, global_var_name);
}
/*!
* \brief A helper data structure that replays a trace and collects failure counts
* for each postprocessor
*/
struct ThreadedTraceApply {
/*! \brief Constructor */
explicit ThreadedTraceApply(const Array<Postproc>& postprocs)
: n_(postprocs.size()), items_(new Item[n_]) {
for (int i = 0; i < n_; ++i) {
items_[i].postproc = postprocs[i];
items_[i].fail_counter = 0;
}
}
/*! \brief Destructor */
~ThreadedTraceApply() { delete[] items_; }
/*!
* \brief Apply the trace and postprocessors to an IRModule
* \param mod The IRModule to be applied
* \param trace The trace to apply to the IRModule
* \param rand_state The random seed
* \return The schedule created, or NullOpt if any postprocessor fails
*/
Optional<tir::Schedule> Apply(const IRModule& mod, const tir::Trace& trace,
TRandState* rand_state) {
tir::Schedule sch =
tir::Schedule::Traced(mod,
/*rand_state=*/ForkSeed(rand_state),
/*debug_mode=*/0,
/*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
trace->ApplyToSchedule(sch, /*remove_postproc=*/true);
sch->EnterPostproc();
for (int i = 0; i < n_; ++i) {
Item& item = items_[i];
if (!item.postproc->Apply(sch)) {
item.fail_counter++;
return NullOpt;
}
}
return sch;
}
/*! \brief Returns a string summarizing the failures on each postprocessor */
std::string SummarizeFailures() const {
std::ostringstream os;
for (int i = 0; i < n_; ++i) {
const Item& item = items_[i];
os << "Postproc #" << i << " [" << item.postproc //
<< "]: " << item.fail_counter.load() << " failure(s)";
if (i != n_ - 1) {
os << "\n";
}
}
return os.str();
}
private:
/*! \brief A helper data structure that stores the fail count for each postprocessor. */
struct Item {
/*! \brief The postprocessor. */
Postproc postproc{nullptr};
/*! \brief The thread-safe postprocessor failure counter. */
std::atomic<int> fail_counter{0};
};
/*! \brief The number of total postprocessors. */
int n_;
/*! \brief The pointer to the list of postprocessor items. */
Item* items_;
};
/*!
 * \brief Get the number of cores in CPU
 *
 * The value is read from the target attribute "num-cores". When the attribute is absent a
 * fatal error is raised; the core count reported by "meta_schedule.cpu_count" is only used to
 * suggest a value in the error message.
 * \param target The target
 * \return The number of cores.
 */
inline int GetTargetNumCores(const Target& target) {
  int num_cores = target->GetAttr<Integer>("num-cores").value_or(-1).IntValue();
  if (num_cores == -1) {
    static const auto* f_cpu_count = runtime::Registry::Get("meta_schedule.cpu_count");
    // The message names the same registry key that is looked up above.
    ICHECK(f_cpu_count)
        << "ValueError: Cannot find the packed function \"meta_schedule.cpu_count\"";
    num_cores = (*f_cpu_count)(false);
    LOG(FATAL)
        << "Target does not have attribute \"num-cores\", physical core number must be "
           "defined! For example, on the local machine, the target must be \"llvm -num-cores "
        << num_cores << "\"";
  }
  return num_cores;
}
/*!
* \brief Get the median of the running time from RunnerResult in millisecond
* \param results The results from RunnerResult
* \return The median of the running time in millisecond
*/
inline double GetRunMsMedian(const RunnerResult& runner_result) {
Array<FloatImm> run_secs = runner_result->run_secs.value();
ICHECK(!run_secs.empty());
std::vector<double> v;
v.reserve(run_secs.size());
std::transform(run_secs.begin(), run_secs.end(), std::back_inserter(v),
[](const FloatImm& f) -> double { return f->value; });
std::sort(v.begin(), v.end());
int n = v.size();
if (n % 2 == 0) {
return (v[n / 2 - 1] + v[n / 2]) * 0.5 * 1000.0;
} else {
return v[n / 2] * 1000.0;
}
}
/*!
 * \brief Convert the given object to an array of floating point numbers
 *
 * Integer and floating point immediates are both accepted and turned into float32 immediates.
 * \param obj The object to be converted, must be an array
 * \return The array of floating point numbers
 */
inline Array<FloatImm> AsFloatArray(const ObjectRef& obj) {
  const ArrayNode* arr = obj.as<ArrayNode>();
  ICHECK(arr) << "TypeError: Expect an array, but gets: " << obj->GetTypeKey();
  const DataType f32 = DataType::Float(32);
  Array<FloatImm> floats;
  floats.reserve(arr->size());
  for (const ObjectRef& item : *arr) {
    if (const auto* as_int = item.as<IntImmNode>()) {
      floats.push_back(FloatImm(f32, as_int->value));
      continue;
    }
    if (const auto* as_float = item.as<FloatImmNode>()) {
      floats.push_back(FloatImm(f32, as_float->value));
      continue;
    }
    LOG(FATAL) << "TypeError: Expect an array of float or int, but gets: " << item->GetTypeKey();
  }
  return floats;
}
/*!
 * \brief Convert the given object to an array of integers
 * \param obj The object to be converted, must be an array of integer immediates
 * \return The array of integers
 */
inline Array<Integer> AsIntArray(const ObjectRef& obj) {
  const ArrayNode* arr = obj.as<ArrayNode>();
  ICHECK(arr) << "TypeError: Expect an array, but gets: " << obj->GetTypeKey();
  Array<Integer> ints;
  ints.reserve(arr->size());
  for (const ObjectRef& item : *arr) {
    const auto* as_int = item.as<IntImmNode>();
    if (as_int == nullptr) {
      LOG(FATAL) << "TypeError: Expect an array of integers, but gets: " << item->GetTypeKey();
    } else {
      ints.push_back(Integer(as_int->value));
    }
  }
  return ints;
}
/*! \brief The struct defining comparison function of sorting by mean run seconds. */
struct SortTuningRecordByMeanRunSecs {
  /*! \brief The mean reported for an empty list of run seconds. */
  static const constexpr double kMaxMeanTime = 1e10;

  /*!
   * \brief Compute the arithmetic mean of the given values.
   * \param a The values to average
   * \return The mean, or kMaxMeanTime when the array is empty
   */
  static double Mean(const Array<FloatImm>& a) {
    if (a.empty()) {
      return kMaxMeanTime;
    }
    double total = 0.0;
    for (const FloatImm& sec : a) {
      total += sec->value;
    }
    return total / a.size();
  }

  /*!
   * \brief Compare two tuning records by mean run seconds.
   * \param a The first record
   * \param b The second record
   * \return Whether the mean of `a` is strictly smaller than the mean of `b`; a record
   * without run seconds is treated as having an empty list of run seconds
   */
  bool operator()(const TuningRecord& a, const TuningRecord& b) const {
    return Mean(a->run_secs.value_or({})) < Mean(b->run_secs.value_or({}));
  }
};
/*!
* \brief The helper function to clone schedule rules, postprocessors, and mutators.
* \param src The source space generator.
* \param dst The destination space generator.
*/
inline void CloneRules(const SpaceGeneratorNode* src, SpaceGeneratorNode* dst) {
if (src->sch_rules.defined()) {
Array<ScheduleRule> original = src->sch_rules.value();
Array<ScheduleRule> sch_rules;
sch_rules.reserve(original.size());
for (const ScheduleRule& sch_rule : original) {
sch_rules.push_back(sch_rule->Clone());
}
dst->sch_rules = std::move(sch_rules);
}
if (src->postprocs.defined()) {
Array<Postproc> original = src->postprocs.value();
Array<Postproc> postprocs;
postprocs.reserve(original.size());
for (const Postproc& postproc : original) {
postprocs.push_back(postproc->Clone());
}
dst->postprocs = std::move(postprocs);
}
if (src->mutator_probs.defined()) {
Map<Mutator, FloatImm> original = src->mutator_probs.value();
Map<Mutator, FloatImm> mutator_probs;
for (const auto& kv : original) {
mutator_probs.Set(kv.first->Clone(), kv.second);
}
dst->mutator_probs = std::move(mutator_probs);
}
}
/*!
 * \brief Returns true if the given target is one of the supported gpu targets.
 * \param target_name The name of the target, matched exactly against "cuda", "rocm",
 * "vulkan" and "metal"
 */
inline bool IsGPUTarget(const std::string& target_name) {
  static const std::unordered_set<std::string> gpu_targets{"cuda", "rocm", "vulkan", "metal"};
  const auto found = gpu_targets.find(target_name);
  return found != gpu_targets.end();
}
/*!
 * \brief Create an AutoInline schedule rule for the given target.
 *
 * The default schedule rules of the target are searched for the first rule whose type key is
 * "meta_schedule.AutoInline". An unsupported target, or default rules without such a rule,
 * is a fatal error.
 * \param target_name The name of the target ("llvm", "cuda", etc.)
 * \return The AutoInline schedule rule for the given target.
 */
inline ScheduleRule GetDefaultAutoInline(const std::string& target_name) {
  // Pick the default rule set of the target.
  Array<ScheduleRule> rules{nullptr};
  if (target_name == "llvm") {
    rules = ScheduleRule::DefaultLLVM();
  } else if (target_name == "hexagon") {
    rules = ScheduleRule::DefaultHexagon();
  } else if (target_name == "c") {
    rules = ScheduleRule::DefaultMicro();
  } else if (IsGPUTarget(target_name)) {
    rules = ScheduleRule::DefaultCUDA();
  } else {
    LOG(FATAL) << "ValueError: Unsupported target: " << target_name;
  }
  // Return the first AutoInline rule of that set.
  const auto auto_inline = std::find_if(
      rules.begin(), rules.end(),
      [](const ScheduleRule& rule) { return rule->GetTypeKey() == "meta_schedule.AutoInline"; });
  if (auto_inline != rules.end()) {
    return *auto_inline;
  }
  LOG(FATAL) << "ValueError: AutoInline rule is not found in the default rules for target: "
             << target_name;
  throw;
}
/*!
* \brief Summarize the run time of the given FloatImm array.
* \param arr The array of FloatImm.
* \return The summary of the values in the given array.
*/
inline double Sum(const Array<FloatImm>& arr) {
double sum = 0;
for (const FloatImm& f : arr) {
sum += f->value;
}
return sum;
}
} // namespace meta_schedule
} // namespace tvm
#endif // TVM_META_SCHEDULE_UTILS_H_