blob: 610fec96617afff2e761f8dc324c7319b4f82e9a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file auto_scheduler/utils.h
* \brief Common utilities.
*/
#ifndef TVM_AUTO_SCHEDULER_UTILS_H_
#define TVM_AUTO_SCHEDULER_UTILS_H_
#include <dmlc/common.h>
#include <tvm/tir/expr.h>
#include <algorithm>
#include <deque>
#include <exception>
#include <future>
#include <numeric>
#include <random>
#include <string>
#include <thread>
#include <tuple>
#include <utility>
#include <vector>
namespace std {
/*!
 * \brief std::hash specialization for std::pair.
 *
 * Hashes both members with their own std::hash and merges the two values with
 * ::dmlc::HashCombine (first member is the seed).
 * NOTE(review): specializing std::hash for pairs of non-program-defined types is formally
 * not permitted by the standard; it works in practice but may clash with other libraries.
 */
template <typename T1, typename T2>
struct hash<std::pair<T1, T2>> {
  std::size_t operator()(const std::pair<T1, T2>& k) const {
    const std::size_t first_hash = std::hash<T1>()(k.first);
    const std::size_t second_hash = std::hash<T2>()(k.second);
    return ::dmlc::HashCombine(first_hash, second_hash);
  }
};
/*!
 * \brief std::hash specialization for 3-element std::tuple.
 *
 * Folds the element hashes left to right: combine(combine(h0, h1), h2).
 */
template <typename T1, typename T2, typename T3>
struct hash<std::tuple<T1, T2, T3>> {
  std::size_t operator()(const std::tuple<T1, T2, T3>& k) const {
    std::size_t seed = std::hash<T1>()(std::get<0>(k));
    seed = ::dmlc::HashCombine(seed, std::hash<T2>()(std::get<1>(k)));
    return ::dmlc::HashCombine(seed, std::hash<T3>()(std::get<2>(k)));
  }
};
}  // namespace std
namespace tvm {
namespace auto_scheduler {
/********** Utilities for Array, std::vector, std::string **********/
/*!
 * \brief Get the first appearance index of each element of \p to_locate inside \p array.
 * \param array The array to search in.
 * \param to_locate The elements to look up, in order.
 * \param indices Output; one index per element of \p to_locate is appended.
 * \note An element that is missing from \p array triggers LOG(FATAL).
 */
template <typename T>
inline void GetIndices(const Array<T>& array, const Array<T>& to_locate, Array<Integer>* indices) {
  auto first = array.begin();
  auto last = array.end();
  for (const auto& target : to_locate) {
    auto pos = std::find(first, last, target);
    if (pos != last) {
      indices->push_back(pos - first);
    } else {
      LOG(FATAL) << "Cannot find the item";
    }
  }
}
/*!
 * \brief Get the first appearance index of an element in an Array.
 * \param array The array to search in.
 * \param to_locate The element to look for.
 * \return The index of the first element equal to \p to_locate.
 * \note A missing element triggers LOG(FATAL); the trailing -1 only satisfies the compiler.
 */
template <typename T>
inline int GetIndex(const Array<T>& array, const T& to_locate) {
  const size_t count = array.size();
  for (size_t idx = 0; idx < count; ++idx) {
    if (array[idx] == to_locate) return static_cast<int>(idx);
  }
  LOG(FATAL) << "Cannot find the item";
  return -1;
}
/*!
 * \brief Erase the first element equal to \p to_delete from a vector, if there is one.
 * \param array The vector to modify in place.
 * \param to_delete The value to remove; the vector is left untouched when it is absent.
 */
template <typename T>
inline void FindAndDeleteItem(std::vector<T>* array, const T& to_delete) {
  const auto pos = std::find(array->begin(), array->end(), to_delete);
  if (pos == array->end()) {
    return;  // nothing to delete
  }
  array->erase(pos);
}
/*!
 * \brief Compute the product of all elements in a vector.
 * \param array The factors; each is widened to int64_t before multiplying.
 * \return The product as int64_t (1 for an empty vector). Overflow is not detected.
 */
inline int64_t ElementProduct(const std::vector<int>& array) {
  return std::accumulate(array.begin(), array.end(), static_cast<int64_t>(1),
                         [](int64_t acc, int elem) { return acc * elem; });
}
/*!
 * \brief Append all elements of one vector to the end of another, moving them.
 * \param out The destination; elements are appended after its existing contents.
 * \param in The source. Its elements are left moved-from, but its size is unchanged.
 * \return A reference to *out.
 */
template <typename T>
std::vector<T>& ConcatenateMove(std::vector<T>* out, std::vector<T>* in) {
  auto src_begin = std::make_move_iterator(in->begin());
  auto src_end = std::make_move_iterator(in->end());
  out->insert(out->end(), src_begin, src_end);
  return *out;
}
/*!
 * \brief Move elements from multiple vectors, in argument order, to the end of one vector.
 * \param out The destination vector.
 * \param first The first source vector.
 * \param args The remaining source vectors (pointers).
 * \return A reference to *out.
 */
template <typename T, typename... Args>
std::vector<T>& ConcatenateMove(std::vector<T>* out, std::vector<T>* first, Args... args) {
  ConcatenateMove(out, first);
  return ConcatenateMove(out, args...);
}
/*!
 * \brief Get a random permutation of integers [0, n-1].
 * \param n The number of integers.
 * \param out Output; its previous contents are replaced by the permutation.
 * \param gen The random engine passed to std::shuffle.
 */
template <typename G>
void RandomPermutation(int n, std::vector<int>* out, G* gen) {
  std::vector<int>& perm = *out;
  perm.assign(n, 0);
  std::iota(perm.begin(), perm.end(), 0);  // 0, 1, ..., n-1
  std::shuffle(perm.begin(), perm.end(), *gen);
}
/*!
 * \brief Replace every non-overlapping occurrence of a sub-string, scanning left to right.
 * \param base The string to modify in place.
 * \param from The sub-string to search for. If empty, \p base is left unchanged.
 * \param to The replacement text. Inserted text is not re-scanned, so \p to may contain
 *        \p from without causing repeated replacement.
 */
inline void StrReplace(std::string* base, const std::string& from, const std::string& to) {
  // An empty pattern matches at every position (find("") never returns npos while
  // pos <= size()), so the loop below would never terminate.
  if (from.empty()) {
    return;
  }
  for (auto pos = base->find(from); pos != std::string::npos;
       pos = base->find(from, pos + to.size())) {
    base->replace(pos, from.size(), to);
  }
}
/*! \brief Return whether two int arrays are elementwise-equal */
inline bool IntArrayEqual(const Array<PrimExpr>& arr1, const Array<PrimExpr>& arr2) {
if (arr1.size() != arr2.size()) {
return false;
}
for (size_t i = 0; i < arr1.size(); ++i) {
auto int1 = arr1[i].as<IntImmNode>();
auto int2 = arr2[i].as<IntImmNode>();
CHECK(int1 != nullptr);
CHECK(int2 != nullptr);
if (int1->value != int2->value) {
return false;
}
}
return true;
}
/********** Utilities for TVM Containers / ByteArray **********/
/*! \brief Compute mean of a FloatImm array */
inline double FloatArrayMean(const Array<PrimExpr>& float_array) {
double sum = 0;
if (float_array.empty()) {
return 0.0;
}
for (const auto& x : float_array) {
auto floatimm = x.as<tir::FloatImmNode>();
CHECK(floatimm != nullptr);
sum += floatimm->value;
}
return sum / float_array.size();
}
/*!
 * \brief Return whether a string starts with another substring.
 * \param a The string to test.
 * \param b The candidate prefix; an empty prefix always matches.
 */
inline bool StrStartsWith(const String& a, const String& b) {
  const size_t prefix_len = b.size();
  if (prefix_len > a.size()) {
    return false;
  }
  return std::char_traits<char>::compare(a.c_str(), b.c_str(), prefix_len) == 0;
}
/*!
 * \brief Return whether a string ends with another substring.
 * \param a The string to test.
 * \param b The candidate suffix; an empty suffix always matches.
 */
inline bool StrEndsWith(const String& a, const String& b) {
  const size_t suffix_len = b.size();
  if (suffix_len > a.size()) {
    return false;
  }
  const char* tail = a.c_str() + (a.size() - suffix_len);
  return std::char_traits<char>::compare(tail, b.c_str(), suffix_len) == 0;
}
/********** Other Utilities **********/
/*! \brief Get an int value from an Expr */
inline int64_t GetIntImm(const PrimExpr& expr) {
auto pint = expr.as<IntImmNode>();
CHECK(pint != nullptr);
return pint->value;
}
/*!
 * \brief Compute the product of the lengths of axes.
 * \param axes The iteration variables whose domain extents are multiplied.
 * \return The product of all extents (1 for an empty list), or -1 if any extent is not an
 *         IntImm. Overflow is not detected.
 */
inline int64_t AxisLengthProd(const Array<tir::IterVar>& axes) {
  int64_t ret = 1;
  for (const auto& x : axes) {
    if (const IntImmNode* imm = x->dom->extent.as<IntImmNode>()) {
      ret *= imm->value;
    } else {
      // A non-constant extent makes the product unknown; -1 is the sentinel.
      return -1;
    }
  }
  return ret;
}
/*!
 * \brief Clean the name of an iterator or an op to make it valid in python code.
 *
 * The rewrites '.' -> '_', '@' -> '_', "outer" -> "o" and "inner" -> "i" are applied in
 * that order.
 * \param str The original name.
 * \param prefix The name prefix to differentiate the same name (e.g., the same iterator names).
 *        When non-empty, the result is "<prefix>_<cleaned>".
 * \return The cleaned name.
 */
inline std::string CleanName(const std::string& str, const std::string& prefix = "") {
  const char* const kRules[][2] = {{".", "_"}, {"@", "_"}, {"outer", "o"}, {"inner", "i"}};
  std::string cleaned = str;
  for (const auto& rule : kRules) {
    StrReplace(&cleaned, rule[0], rule[1]);
  }
  if (prefix.empty()) {
    return cleaned;
  }
  return prefix + "_" + cleaned;
}
/*!
 * \brief An output stream that discards everything written to it.
 *
 * The std::ostream base is constructed with a null stream buffer, which leaves it in the
 * bad state, so writes made through a plain std::ostream& are dropped as well.
 */
class NullStream : public std::ostream {
 public:
  NullStream() : std::ostream(nullptr) {}
  /*! \brief Copying yields another buffer-less stream; no state is taken from the source. */
  NullStream(const NullStream& /*other*/) : NullStream() {}
  /*! \brief Accessor for a NullStream instance; defined outside this header. */
  static NullStream& Global();
};
/*!
 * \brief Ignore any value streamed into a NullStream.
 * \return The same stream, so insertions can be chained.
 */
template <class T>
NullStream& operator<<(NullStream& os, const T& /*value*/) {
  return os;
}
/*!
 * \brief Get std::cout with verbose control.
 * \param verbose The current verbosity level.
 * \param setting The minimum verbosity at which output is shown.
 * \return std::cout when verbose >= setting, otherwise the NullStream from NullStream::Global().
 */
inline std::ostream& StdCout(int verbose, int setting = 1) {
  if (verbose < setting) {
    return NullStream::Global();
  }
  return std::cout;
}
/*!
 * \brief Build a string consisting of one character repeated several times.
 * \param str The character to repeat.
 * \param times The number of repetitions; zero or negative values yield an empty string.
 * \return The repeated string.
 */
inline std::string Chars(const char& str, int times) {
  // The fill constructor replaces a per-character stringstream loop. Guard non-positive
  // counts, which would otherwise convert to a huge size_t.
  return times > 0 ? std::string(static_cast<size_t>(times), str) : std::string();
}
/*!
 * \brief Parse shape and axis names from layout string
 *
 * The string is scanned one character at a time. Runs of characters in the ASCII range
 * ['A', 'z'] accumulate into an axis name, and runs of decimal digits accumulate into an
 * integer factor. Each transition closes the previous run:
 *   - when a digit follows letters, the pending axis name is appended to \p axes;
 *   - when a letter follows digits, the pending factor is appended to \p shape if non-zero.
 * Any other character triggers LOG(FATAL).
 *
 * NOTE(review): shape and axes only line up one-to-one when every axis is written as
 * `<extent><name>`, e.g. "32O64I3H3W" -> shape [32, 64, 3, 3], axes [O, I, H, W].
 * A trailing digit run is never appended to \p shape, and a factor of 0 is never appended.
 * Confirm this is the format callers produce.
 *
 * \param layout The layout string to parse.
 * \param shape Output; parsed factors are appended (existing contents are kept).
 * \param axes Output; parsed axis names are appended (existing contents are kept).
 */
inline void ParseKernelLayout(const String& layout, Array<PrimExpr>* shape,
                              std::vector<std::string>* axes) {
  // Value of the digit run currently being read (0 when none is pending).
  int32_t factor = 0;
  // Characters of the axis name currently being read.
  std::string axis = "";
  for (char c : std::string(layout)) {
    // NOTE(review): the range 'A'..'z' also admits '[', backslash, ']', '^', '_' and '`',
    // so those characters are treated as part of an axis name. Confirm this is intended.
    if (c >= 'A' && c <= 'z') {
      axis += c;
      // Letters after a digit run close that run: record its value as a shape entry.
      if (factor != 0) {
        shape->push_back(factor);
        factor = 0;
      }
    } else if (c >= '0' && c <= '9') {
      // No overflow check: a long enough digit run overflows int32_t.
      factor = factor * 10 + c - '0';
      // Digits after letters close the name run: record the axis name.
      if (!axis.empty()) {
        axes->push_back(axis);
        axis = "";
      }
    } else {
      LOG(FATAL) << "Invalid layout " << layout;
    }
  }
  // Flush a trailing axis name. A trailing digit run, by contrast, is not flushed into shape.
  if (!axis.empty()) {
    axes->push_back(axis);
  }
}
/*!
 * \brief Get the base name before the last '_' of an axis.
 * \param str The axis name.
 * \return The part of \p str before its last '_', or \p str unchanged if it contains no '_'.
 */
inline std::string AxisBaseName(const std::string& str) {
  const std::size_t cut = str.rfind('_');
  return str.substr(0, cut);  // cut == npos keeps the whole string
}
} // namespace auto_scheduler
} // namespace tvm
#endif // TVM_AUTO_SCHEDULER_UTILS_H_