/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file auto_scheduler/search_policy/utils.cc
* \brief Common utilities
*/
#include "utils.h"
#include <algorithm>
namespace tvm {
namespace auto_scheduler {
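// Return the indices (in s->transform_steps) of the SplitSteps that were applied to the
// spatial iterators of the given stage.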
Array<Integer> GetSpatialSplitStepIds(const State& s, int stage_id) {
const auto& stage = s->stages[stage_id];
const auto* pop = stage->op.as<te::ComputeOpNode>();
CHECK(pop != nullptr);
const std::set<std::string>& no_split_at_inner_name_set =
stage->op->attrs.count(SearchPolicyKey::no_split_at_inner)
? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner)
: std::set<std::string>();
size_t reduce_count = 0;
for (const auto& axis : pop->reduce_axis) {
if (!no_split_at_inner_name_set.count(axis->var->name_hint)) {
reduce_count++;
}
}
Array<Integer> spatial_split_step_ids;
for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) {
if (IsStageNumberChangingStep(s->transform_steps[i])) {
if (stage_id > s->transform_steps[i]->stage_id) {
stage_id--;
}
} else if (auto ps = s->transform_steps[i].as<SplitStepNode>()) {
if (stage_id == ps->stage_id) {
// Assume SplitSteps on reduction axes always come after SplitSteps on spatial axes.
if (reduce_count) {
reduce_count--;
} else {
spatial_split_step_ids.push_back(i);
}
}
}
}
return spatial_split_step_ids;
}
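// Enumerate candidate compute_at locations for a stage as (stage_id, iterator index) pairs,
// based on its single consumer stage and, if that consumer is itself attached to another
// stage, that stage as well.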
std::vector<std::pair<int, int>> GetComputeLocationCandidates(const SearchTask& task,
const State& state, int stage_id) {
int target_stage_id = GetSingleConsumerId(task, state, stage_id);
if (target_stage_id < 0) {
return {};
}
const Stage& target_stage = state->stages[target_stage_id];
std::vector<std::pair<int, int>> candidates;
bool target_compute_at_other = target_stage->compute_at == ComputeAtKind::kIter;
bool target_is_tiled = IsTiled(target_stage);
bool visited_reduce = false;
// Enumerate compute_at locations at target_stage
// TODO(merrymercy): More analysis here to make smarter choices
for (size_t i = 0; i < target_stage->iters.size(); ++i) {
const Iterator& target_iter = target_stage->iters[i];
if (target_iter->iter_kind == IteratorKind::kReduction) {
visited_reduce = true;
if (!target_is_tiled) { // Do not go into reduce iter
break;
}
} else if (target_iter->iter_kind == IteratorKind::kSpatial) {
if (visited_reduce) { // Do not go into inner tile
break;
}
}
if (target_iter->annotation == IteratorAnnotation::kUnroll) {
// Do not go into the unroll region of const tensor indices
break;
}
if (GetExtent(target_iter) == 1) {
// Skip iterators with length of 1
continue;
}
if (target_compute_at_other && target_iter->iter_kind == IteratorKind::kSpatial &&
StrEndsWith(target_iter->name, ".0")) {
// Skip the first-level iterators if the target stage is computed at another stage.
// In this case, the lengths of the first-level iterators are always one.
continue;
}
candidates.emplace_back(target_stage_id, i);
if (state->attach_map->iter_to_attached_stages.count(std::make_pair(target_stage_id, i))) {
break;
}
}
// If target_stage is already computed at another stage X, also try computing at X.
// We call stage X `target_target_stage`.
if (target_compute_at_other) {
int target_target_stage_id = state->attach_map->stage_to_attach_iter.at(target_stage_id).first;
const Stage& target_target_stage = state->stages[target_target_stage_id];
for (size_t i = 0; i < target_target_stage->iters.size(); ++i) {
const Iterator& target_target_iter = target_target_stage->iters[i];
if (target_target_iter->iter_kind == IteratorKind::kReduction ||
state->attach_map->iter_to_attached_stages.count(
std::make_pair(target_target_stage_id, i))) {
break;
}
if (target_target_iter->annotation == IteratorAnnotation::kUnroll) {
// Do not go into the unroll region of const tensor indices
break;
}
if (GetExtent(target_target_iter) == 1) { // skip iterators with length of 1
continue;
}
candidates.emplace_back(target_target_stage_id, i);
}
}
return candidates;
}
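// Apply multi-level tiling to a stage following a format string such as "SSRSRS", where each
// 'S' adds a level of spatial tiling and each 'R' a level of reduction tiling. The ids of the
// generated spatial split steps are stored in spatial_split_step_ids.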
State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format,
std::vector<int>* spatial_split_step_ids) {
// Temporary object used when the input pointer is nullptr
std::vector<int> temp_split_step_ids;
if (spatial_split_step_ids == nullptr) {
spatial_split_step_ids = &temp_split_step_ids;
}
std::vector<std::vector<Iterator>> space_levels;
std::vector<std::vector<Iterator>> reduce_levels;
std::vector<Iterator> space_outer, space_inner, reduce_outer, reduce_inner;
Array<Iterator> split_res;
for (const auto c : format) {
if (tolower(c) == 's') {
space_levels.emplace_back();
} else if (tolower(c) == 'r') {
reduce_levels.emplace_back();
} else {
LOG(FATAL) << "Invalid multi-level tiling format: " << format;
}
}
size_t n_space = space_levels.size();
size_t n_reduce = reduce_levels.size();
spatial_split_step_ids->clear();
State tmp_s = state;
const Stage& stage = state->stages[stage_id];
const std::set<std::string>& no_split_at_inner_name_set =
stage->op->attrs.count(SearchPolicyKey::no_split_at_inner)
? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner)
: std::set<std::string>();
for (const auto& iter : state->stages[stage_id]->iters) {
if (!no_split_at_inner_name_set.count(iter->name)) {
if (iter->iter_kind == IteratorKind::kSpatial) {
CHECK_GE(n_space, 1);
if (n_space == 1) {
space_levels[0].push_back(iter);
} else {
split_res = tmp_s.split(stage_id, iter, Array<Optional<Integer>>(n_space - 1, NullOpt));
for (size_t i = 0; i < n_space; i++) {
space_levels[i].push_back(split_res[i]);
}
spatial_split_step_ids->push_back(tmp_s->transform_steps.size() - 1);
}
} else if (iter->iter_kind == IteratorKind::kReduction) {
CHECK_GE(n_reduce, 1);
if (n_reduce == 1) {
reduce_levels[0].push_back(iter);
} else {
split_res = tmp_s.split(stage_id, iter, Array<Optional<Integer>>(n_reduce - 1, NullOpt));
for (size_t i = 0; i < n_reduce; i++) {
reduce_levels[i].push_back(split_res[i]);
}
}
} else {
LOG(FATAL) << "Invalid iter type: " << int(iter->iter_kind);
}
} else {
if (iter->iter_kind == IteratorKind::kSpatial) {
space_inner.push_back(iter);
} else if (iter->iter_kind == IteratorKind::kReduction) {
reduce_inner.push_back(iter);
} else {
LOG(FATAL) << "Invalid iter type: " << int(iter->iter_kind);
}
}
}
if (!space_outer.empty()) {
CHECK(!space_levels.empty());
space_levels.front().insert(space_levels.front().begin(),
std::make_move_iterator(space_outer.begin()),
std::make_move_iterator(space_outer.end()));
}
if (!space_inner.empty()) {
CHECK(!space_levels.empty());
space_levels.back().insert(space_levels.back().begin(),
std::make_move_iterator(space_inner.begin()),
std::make_move_iterator(space_inner.end()));
}
if (!reduce_outer.empty()) {
CHECK(!reduce_levels.empty());
reduce_levels.front().insert(reduce_levels.front().begin(),
std::make_move_iterator(reduce_outer.begin()),
std::make_move_iterator(reduce_outer.end()));
}
if (!reduce_inner.empty()) {
CHECK(!reduce_levels.empty());
reduce_levels.back().insert(reduce_levels.back().begin(),
std::make_move_iterator(reduce_inner.begin()),
std::make_move_iterator(reduce_inner.end()));
}
Array<Iterator> order;
int space_ct = 0, reduce_ct = 0;
for (const auto c : format) {
if (tolower(c) == 's') {
order.insert(order.end(), std::make_move_iterator(space_levels[space_ct].begin()),
std::make_move_iterator(space_levels[space_ct].end()));
space_ct++;
} else if (tolower(c) == 'r') {
order.insert(order.end(), std::make_move_iterator(reduce_levels[reduce_ct].begin()),
std::make_move_iterator(reduce_levels[reduce_ct].end()));
reduce_ct++;
} else {
LOG(FATAL) << "Invalid multi level tiling format: " << format;
}
}
tmp_s.reorder(stage_id, order);
return tmp_s;
}
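// Re-apply a tiling structure to a stage by following existing split steps (follow_split),
// restoring iterator annotations (unroll/vectorize go to the innermost level, parallel to the
// outermost) and reordering the resulting iterators.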
State FollowTiling(const State& state, int stage_id, const std::vector<int>& split_step_ids,
int n_split) {
if (n_split < 1 || n_split > 3) {
LOG(FATAL) << "Invalid split parts, currently only support 1, 2 and 3";
}
// Apply up to a three-way split to each spatial iterator, producing tiling levels space_0
// (outermost) to space_3 (innermost).
std::vector<Iterator> space_0, space_1, space_2, space_3, tmp_order;
Array<Iterator> split_res;
auto pop = state->stages[stage_id]->op.as<te::ComputeOpNode>();
CHECK(pop != nullptr);
const Stage& stage = state->stages[stage_id];
const std::set<std::string>& no_split_at_inner_name_set =
stage->op->attrs.count(SearchPolicyKey::no_split_at_inner)
? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner)
: std::set<std::string>();
int no_split_at_inner_name_in_stage_cnt = 0;
for (const auto& iter : state->stages[stage_id]->iters) {
no_split_at_inner_name_in_stage_cnt += no_split_at_inner_name_set.count(iter->name);
}
CHECK_EQ(state->stages[stage_id]->iters.size() - no_split_at_inner_name_in_stage_cnt,
split_step_ids.size());
State tmp_s = state;
int ct = 0;
for (const auto& iter : state->stages[stage_id]->iters) {
if (iter->iter_kind == IteratorKind::kSpatial) {
// For a spatial iterator, split it into multiple iterators
if (!no_split_at_inner_name_set.count(iter->name)) {
IteratorAnnotation ann_type = iter->annotation;
split_res = tmp_s.follow_split(stage_id, iter, split_step_ids[ct], n_split);
// Restore annotation. Move unroll and vectorize to inner, move parallel
// to outer
switch (ann_type) {
case IteratorAnnotation::kUnroll:
split_res.Set(n_split, tmp_s.unroll(stage_id, split_res[n_split]));
break;
case IteratorAnnotation::kVectorize:
split_res.Set(n_split, tmp_s.vectorize(stage_id, split_res[n_split]));
break;
case IteratorAnnotation::kParallel:
split_res.Set(0, tmp_s.parallel(stage_id, split_res[0]));
break;
default:
break;
}
space_0.push_back(split_res[0]);
space_1.push_back(split_res[1]);
if (n_split >= 2) {
space_2.push_back(split_res[2]);
if (n_split == 3) {
space_3.push_back(split_res[3]);
}
}
ct++;
} else {
// Iterators in no_split_at_inner_name_set are not split; keep them at the innermost level.
if (n_split == 1) {
space_1.push_back(iter);
} else if (n_split == 2) {
space_2.push_back(iter);
} else {
CHECK_EQ(n_split, 3);
space_3.push_back(iter);
}
}
} else {
LOG(FATAL) << "Invalid iter type: " << int(iter->iter_kind);
}
}
if (n_split == 3) {
ConcatenateMove(&tmp_order, &space_0, &space_1, &space_2, &space_3);
} else if (n_split == 2) {
ConcatenateMove(&tmp_order, &space_0, &space_1, &space_2);
} else {
ConcatenateMove(&tmp_order, &space_0, &space_1);
}
tmp_s.reorder(stage_id, tmp_order);
return tmp_s;
}
// Return whether a state contains nested parallelism, which is invalid on CPUs
bool HasNestedParallel(const State& state) {
std::function<void(int stage_id, size_t*)> count_parallel_ct;
count_parallel_ct = [&state, &count_parallel_ct](int stage_id, size_t* parallel_ct) {
const Stage& stage = state->stages[stage_id];
if (stage->compute_at == ComputeAtKind::kInlined) {
return;
}
for (size_t i = 0; i < stage->iters.size(); ++i) {
if (stage->iters[i]->annotation == IteratorAnnotation::kParallel) {
(*parallel_ct)++;
}
IterKey iter_key(stage_id, i);
auto pair = state->attach_map->iter_to_attached_stages.find(iter_key);
if (pair != state->attach_map->iter_to_attached_stages.end()) {
for (const auto& attach_stage_id : pair->second) {
count_parallel_ct(attach_stage_id, parallel_ct);
}
}
}
};
for (size_t stage_id = 0; stage_id < state->stages.size(); ++stage_id) {
size_t parallel_ct = 0;
if (state->stages[stage_id]->compute_at == ComputeAtKind::kRoot) {
count_parallel_ct(stage_id, &parallel_ct);
if (parallel_ct >= 2) {
return true;
}
}
}
return false;
}
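// Remove invalid states in place: undefined states, and states with nested parallelism when
// the target is not a GPU.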
void PruneInvalidState(const SearchTask& task, Array<State>* states) {
size_t pt = 0;
for (size_t i = 0; i < states->size(); ++i) {
if (!(*states)[i].defined()) {
continue;
}
if (!IsGPUTask(task) && HasNestedParallel((*states)[i])) {
continue;
}
if (i != pt) {
states->Set(pt, (*states)[i]);
}
pt++;
}
if (pt == 0) {
LOG(FATAL) << "Internal error: All states are invalid.";
} else {
states->resize(pt);
}
}
/********** SplitFactorizationMemo **********/
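// The factorization memo can be queried from multiple threads, so it is protected by a simple
// reader-writer lock built on a mutex and a condition variable.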
void SplitFactorizationMemo::ReadWriteLock::GetRead() {
std::unique_lock<std::mutex> lock(cv_mutex_);
// Block until there is no active writer
cv_.wait(lock, [this]() { return !this->is_writing_; });
read_count_++;
}
void SplitFactorizationMemo::ReadWriteLock::GetWrite() {
std::unique_lock<std::mutex> lock(cv_mutex_);
// Block until there are no active readers or writers
cv_.wait(lock, [this]() { return this->read_count_ == 0 && !this->is_writing_; });
is_writing_ = true;
}
void SplitFactorizationMemo::ReadWriteLock::UnlockRead() {
std::lock_guard<std::mutex> lock(cv_mutex_);
read_count_--;
// Notify a waiting writer if this was the last active reader
if (read_count_ == 0) {
cv_.notify_one();
}
}
void SplitFactorizationMemo::ReadWriteLock::UnlockWrite() {
std::lock_guard<std::mutex> lock(cv_mutex_);
is_writing_ = false;
// Notify all blocked threads (waiting readers and writers)
cv_.notify_all();
}
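// Return the memoized factorization schemes for splitting `extent` into `n_lengths` parts,
// keeping only schemes whose innermost factor does not exceed `max_innermost_factor`.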
const Array<Array<Integer>>& SplitFactorizationMemo::GetFactorizationSchemes(
int extent, int n_lengths, int max_innermost_factor) {
QueryKey key = std::make_tuple(extent, n_lengths, max_innermost_factor);
const auto& const_memory = memory_;
lock_.GetRead();
// Grab a pointer to the cached entry while still holding the read lock, since a concurrent
// writer may invalidate iterators after the lock is released.
const auto it = const_memory.find(key);
const Array<Array<Integer>>* cached = (it != const_memory.end()) ? &it->second : nullptr;
lock_.UnlockRead();
if (cached != nullptr) {
return *cached;
}
lock_.GetWrite();
results_ = &memory_[key];
// Re-check under the write lock: another thread may have computed this entry after we
// released the read lock above.
if (results_->empty()) {
tmp_stack_ = Array<Integer>(n_lengths, Integer());
n_lengths_ = n_lengths;
DfsEnumerate(0, extent, max_innermost_factor);
}
lock_.UnlockWrite();
return *results_;
}
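// Recursively enumerate factorization schemes: at each position, try every factor of the
// remaining length; keep a scheme only if its innermost factor respects max_innermost_factor.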
void SplitFactorizationMemo::DfsEnumerate(int now, int remaining_length, int max_innermost_factor) {
if (now == n_lengths_) {
if (tmp_stack_.back().as<IntImmNode>()->value <= max_innermost_factor) {
results_->push_back(tmp_stack_);
}
} else {
for (const auto& f : GetFactors(remaining_length)) {
tmp_stack_.Set(now, Integer(f));
DfsEnumerate(now + 1, remaining_length / f, max_innermost_factor);
}
}
}
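// Return all factors of n in ascending order, with memoization.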
const std::vector<int>& SplitFactorizationMemo::GetFactors(int n) {
auto it = factor_memory_.find(n);
if (it != factor_memory_.end()) {
return it->second;
}
std::vector<int>& res = factor_memory_[n];
int step = n % 2 == 0 ? 1 : 2;
for (size_t i = 1; i < static_cast<size_t>(std::sqrt(n)) + 1; i += step) {
if (n % i == 0) {
res.push_back(i);
if (n / i != i) {
res.push_back(n / i);
}
}
}
std::sort(res.begin(), res.end());
return res;
}
/********** Utils interface API for ffi **********/
TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsIsTiled")
.set_body_typed([](const Stage& stage) { return IsTiled(stage); });
TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsHasCacheReadStage")
.set_body_typed([](const State& s, int stage_id) { return HasCacheReadStage(s, stage_id); });
TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsHasCacheWriteStage")
.set_body_typed([](const State& s, int stage_id) { return HasCacheWriteStage(s, stage_id); });
TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsHasRfactorStage")
.set_body_typed([](const State& s, int stage_id) { return HasRfactorStage(s, stage_id); });
TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsHasCrossThreadReduction")
.set_body_typed([](const State& s, int stage_id) {
return HasCrossThreadReduction(s, stage_id);
});
} // namespace auto_scheduler
} // namespace tvm