blob: 5c015ca46a9b561b29c5019f992bb8f7f3af7926 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file auto_scheduler/search_policy/utils.h
* \brief Common utilities for search policies.
*/
#ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_UTILS_H_
#define TVM_AUTO_SCHEDULER_SEARCH_POLICY_UTILS_H_
#include <dmlc/common.h>
#include <tvm/auto_scheduler/loop_state.h>
#include <tvm/auto_scheduler/search_policy.h>
#include <tvm/ir/expr.h>
#include <tvm/te/operation.h>
#include <algorithm>
#include <condition_variable>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "../utils.h"
namespace tvm {
namespace auto_scheduler {
/*!
 * \brief Check whether a search task targets a CPU device.
 * \param task The search task whose target is inspected.
 * \return true if the device type of the task's target kind equals kDLCPU.
 */
inline bool IsCPUTask(const SearchTask& task) {
  const auto device_type = task->target->kind->device_type;
  return device_type == kDLCPU;
}
/*!
 * \brief Check whether a search task targets a GPU-like device.
 * \param task The search task whose target is inspected.
 * \return true if the device type of the task's target kind is one of kDLGPU, kDLOpenCL,
 *  kDLVulkan, kDLMetal, kDLROCM or kOpenGL.
 */
inline bool IsGPUTask(const SearchTask& task) {
  const auto device_type = task->target->kind->device_type;
  return device_type == kDLGPU || device_type == kDLOpenCL || device_type == kDLVulkan ||
         device_type == kDLMetal || device_type == kDLROCM || device_type == kOpenGL;
}
/*!
 * \brief Check whether a search task targets a CUDA GPU.
 * \param task The search task whose target is inspected.
 * \return true if the device type of the task's target kind equals kDLGPU.
 */
inline bool IsCUDATask(const SearchTask& task) {
  const auto device_type = task->target->kind->device_type;
  return device_type == kDLGPU;
}
/*!
 * \brief Check whether a search task targets an OpenCL device.
 * \param task The search task whose target is inspected.
 * \return true if the device type of the task's target kind equals kDLOpenCL.
 */
inline bool IsOpenCLTask(const SearchTask& task) {
  const auto device_type = task->target->kind->device_type;
  return device_type == kDLOpenCL;
}
/*!
 * \brief Compute the permutation that sorts `scores` in descending order.
 * \tparam T Element type; must be comparable with operator>.
 * \param scores The values to rank.
 * \return Indices into `scores`, ordered from the largest score to the smallest.
 * \note The sort is not stable, so the relative order of equal scores is unspecified.
 */
template <typename T>
inline std::vector<int> Argsort(const std::vector<T>& scores) {
  const int count = static_cast<int>(scores.size());
  // Start from the identity permutation 0, 1, ..., count - 1.
  std::vector<int> order(count);
  for (int pos = 0; pos < count; ++pos) {
    order[pos] = pos;
  }
  // Compare indices through the scores they refer to: larger score comes first.
  std::sort(order.begin(), order.end(),
            [&scores](int lhs, int rhs) { return scores[lhs] > scores[rhs]; });
  return order;
}
/*!
 * \brief Look up the stage id of an operation in a state.
 * \param op The operation to search for.
 * \param state The state whose stages are scanned in order.
 * \return The index of the first stage whose op equals `op`.
 * \note Logs a fatal error if no stage holds `op`; the trailing `return -1` only exists to
 *  give the function a return value on that path.
 */
inline int OperationToStage(const te::Operation& op, const State& state) {
  const size_t num_stages = state->stages.size();
  for (size_t id = 0; id < num_stages; ++id) {
    if (op == state->stages[id]->op) {
      return id;
    }
  }
  LOG(FATAL) << "Cannot find op: " << op;
  return -1;
}
/********** Get Parameters **********/
/*! \brief Get an integer from a tvm str Map. */
inline int GetIntParam(const Map<String, ObjectRef>& attr_dict, const std::string& key) {
CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict;
auto pint = attr_dict[key].as<IntImmNode>();
CHECK(pint != nullptr);
return pint->value;
}
/*! \brief Get a double from a tvm str Map. */
inline double GetDoubleParam(const Map<String, ObjectRef>& attr_dict, const std::string& key) {
CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict;
auto pdouble = attr_dict[key].as<FloatImmNode>();
CHECK(pdouble != nullptr);
return pdouble->value;
}
/*!
 * \brief Read a string attribute from a string-keyed tvm Map.
 * \param attr_dict The attribute map.
 * \param key The attribute name; a CHECK fails if it is missing.
 * \return The string stored under `key`.
 * \note The stored object may be either a StringImmNode or a StringObj; a CHECK fails for
 *  any other type.
 */
inline std::string GetStringParam(const Map<String, ObjectRef>& attr_dict, const std::string& key) {
  CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict;
  const auto& value = attr_dict[key];
  // Try the StringImm representation first.
  const auto* imm_node = value.as<StringImmNode>();
  if (imm_node != nullptr) {
    return imm_node->value;
  }
  // Otherwise the value has to be a runtime String object.
  const auto* str_obj = value.as<StringObj>();
  CHECK(str_obj != nullptr);
  return str_obj->data;
}
/*!
 * \brief Read a set of iterator names from a string-keyed tvm Map.
 * \param attr_dict The attribute map.
 * \param key The attribute name; a CHECK fails if it is missing.
 * \return The names stored under `key`, collected into an ordered set.
 * \note A CHECK fails if the stored object is not an ArrayNode, or if any element of the
 *  array is not a StringObj.
 */
inline std::set<std::string> GetIterNameSetParam(const Map<String, ObjectRef>& attr_dict,
                                                 const std::string& key) {
  std::set<std::string> ret;
  CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict;
  auto names = attr_dict[key].as<ArrayNode>();
  CHECK(names != nullptr);
  for (const auto& name : *names) {
    // Fail with a message instead of dereferencing a null pointer on a non-string element.
    const auto* str_obj = name.as<StringObj>();
    CHECK(str_obj != nullptr) << "Expect every element of \"" << key << "\" to be a string";
    ret.insert(str_obj->data);
  }
  return ret;
}
/********** Checks with ComputeDAG **********/
/*!
 * \brief Query the access analyzer on whether the op of a stage is strictly inlineable.
 * \param task The search task; its compute DAG is used when the state has none of its own.
 * \param state The current state.
 * \param stage_id Index of the stage in `state->stages`.
 * \return The answer of AccessAnalyzer::IsStrictlyInlineable for the stage's op.
 */
inline bool IsStrictlyInlineable(const SearchTask& task, const State& state, int stage_id) {
  const auto& op = state->stages[stage_id]->op;
  if (!state->current_compute_dag) {
    return task->compute_dag->access_analyzer.IsStrictlyInlineable(op);
  }
  // The state carries its own compute DAG, which takes precedence over the task's DAG.
  return state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.IsStrictlyInlineable(
      op);
}
/*!
 * \brief Query the access analyzer on whether the op of a stage is an output op.
 * \param task The search task; its compute DAG is used when the state has none of its own.
 * \param state The current state.
 * \param stage_id Index of the stage in `state->stages`.
 * \return The answer of AccessAnalyzer::IsOutput for the stage's op.
 */
inline bool IsOutputOp(const SearchTask& task, const State& state, int stage_id) {
  const auto& op = state->stages[stage_id]->op;
  if (!state->current_compute_dag) {
    return task->compute_dag->access_analyzer.IsOutput(op);
  }
  // The state carries its own compute DAG, which takes precedence over the task's DAG.
  return state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.IsOutput(op);
}
/*!
 * \brief Query the access analyzer on whether the op of a stage needs multi-level tiling.
 * \param task The search task; its compute DAG is used when the state has none of its own.
 * \param state The current state.
 * \param stage_id Index of the stage in `state->stages`.
 * \return The answer of AccessAnalyzer::NeedsMultiLevelTiling for the stage's op.
 */
inline bool NeedsMultilevelTiling(const SearchTask& task, const State& state, int stage_id) {
  const auto& op = state->stages[stage_id]->op;
  if (!state->current_compute_dag) {
    return task->compute_dag->access_analyzer.NeedsMultiLevelTiling(op);
  }
  // The state carries its own compute DAG, which takes precedence over the task's DAG.
  return state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.NeedsMultiLevelTiling(
      op);
}
/*!
 * \brief Get the stage ids of all consumers of a stage.
 *  This function propagates the relation for inlined ops.
 * \param task The search task; its compute DAG is used when the state has none of its own.
 * \param state The current state.
 * \param stage_id Index of the stage in `state->stages`.
 * \return The ids (in `state`) of the stages that hold the consumer ops.
 */
inline std::set<int> GetConsumers(const SearchTask& task, const State& state, int stage_id) {
  const auto& op = state->stages[stage_id]->op;
  // Prefer the compute DAG attached to the state; fall back to the task's DAG.
  const std::unordered_set<te::Operation, ObjectHash, ObjectEqual> consumer_ops =
      state->current_compute_dag
          ? state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.GetConsumers(state,
                                                                                          op)
          : task->compute_dag->access_analyzer.GetConsumers(state, op);
  // Translate operations back into stage indices of this state.
  std::set<int> stage_ids;
  for (const auto& consumer_op : consumer_ops) {
    stage_ids.insert(OperationToStage(consumer_op, state));
  }
  return stage_ids;
}
/*!
 * \brief Find the single consumer of a stage, or the common root shared by all consumers.
 * \param task The search task.
 * \param state The current state.
 * \param stage_id Index of the stage in `state->stages`.
 * \return The id of the only consumer if there is exactly one; the common root stage id if
 *  there are several consumers that all share one root; -1 if there is no consumer or the
 *  roots differ.
 * \note The root of a consumer is the consumer itself when it is computed at root, or the
 *  stage it is attached to when it is computed at an iterator. Any other compute_at kind
 *  is a fatal error.
 */
inline int GetSingleConsumerId(const SearchTask& task, const State& state, int stage_id) {
  const std::set<int> consumers = GetConsumers(task, state, stage_id);
  if (consumers.empty()) {
    return -1;
  }
  if (consumers.size() == 1) {
    return *consumers.begin();
  }

  // Several consumers: accept them only if every one resolves to the same root stage.
  int common_root_id = -1;
  for (const int consumer_id : consumers) {
    const auto attach_kind = state->stages[consumer_id]->compute_at;
    int root_id = -1;
    if (attach_kind == ComputeAtKind::kRoot) {
      root_id = consumer_id;
    } else if (attach_kind == ComputeAtKind::kIter) {
      root_id = state->attach_map->stage_to_attach_iter.at(consumer_id).first;
    } else {
      LOG(FATAL) << "Invalid case";
    }

    if (common_root_id == -1) {
      // First consumer seen: its root becomes the reference.
      common_root_id = root_id;
    } else if (common_root_id != root_id) {
      return -1;
    }
  }
  return common_root_id;
}
/*!
 * \brief Get the stage ids of all producers of a stage.
 *  This function propagates the relation for inlined ops.
 * \param task The search task; its compute DAG is used when the state has none of its own.
 * \param state The current state.
 * \param stage_id Index of the stage in `state->stages`.
 * \return The ids (in `state`) of the stages that hold the producer ops.
 */
inline std::set<int> GetProducers(const SearchTask& task, const State& state, int stage_id) {
  const auto& op = state->stages[stage_id]->op;
  // Prefer the compute DAG attached to the state; fall back to the task's DAG.
  const std::unordered_set<te::Operation, ObjectHash, ObjectEqual> producer_ops =
      state->current_compute_dag
          ? state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.GetProducers(state,
                                                                                          op)
          : task->compute_dag->access_analyzer.GetProducers(state, op);
  // Translate operations back into stage indices of this state.
  std::set<int> stage_ids;
  for (const auto& producer_op : producer_ops) {
    stage_ids.insert(OperationToStage(producer_op, state));
  }
  return stage_ids;
}
/*!
 * \brief Get the stage ids of the direct producers of a stage.
 *  This function DOES NOT propagate the relation for inlined ops.
 * \param task The search task; its compute DAG is used when the state has none of its own.
 * \param state The current state.
 * \param stage_id Index of the stage in `state->stages`.
 * \return The ids (in `state`) of the stages that hold the direct producer ops.
 */
inline std::set<int> GetDirectProducers(const SearchTask& task, const State& state, int stage_id) {
  const auto& op = state->stages[stage_id]->op;
  // Prefer the compute DAG attached to the state; fall back to the task's DAG.
  const std::unordered_set<te::Operation, ObjectHash, ObjectEqual> producer_ops =
      state->current_compute_dag
          ? state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.GetDirectProducers(op)
          : task->compute_dag->access_analyzer.GetDirectProducers(op);
  // Translate operations back into stage indices of this state.
  std::set<int> stage_ids;
  for (const auto& producer_op : producer_ops) {
    stage_ids.insert(OperationToStage(producer_op, state));
  }
  return stage_ids;
}
/*!
 * \brief Get the number of common outer iterators of two stages.
 *  This function propagates the relation for chains with multiple ops.
 * \param task The search task; its compute DAG is used when the state has none of its own.
 * \param state The current state.
 * \param stage_id Index of the first stage in `state->stages`.
 * \param target_stage_id Index of the second stage in `state->stages`.
 * \return The answer of AccessAnalyzer::GetNumCommonOuterIterator for the two stages' ops.
 */
inline int GetNumCommonOuterIterator(const SearchTask& task, const State& state, int stage_id,
                                     int target_stage_id) {
  const auto& op = state->stages[stage_id]->op;
  const auto& target_op = state->stages[target_stage_id]->op;
  if (!state->current_compute_dag) {
    return task->compute_dag->access_analyzer.GetNumCommonOuterIterator(op, target_op);
  }
  // The state carries its own compute DAG, which takes precedence over the task's DAG.
  return state->current_compute_dag.as<ComputeDAGNode>()
      ->access_analyzer.GetNumCommonOuterIterator(op, target_op);
}
/*!
 * \brief Query the access analyzer on whether the ops of two stages are elementwise-matched.
 * \param task The search task; its compute DAG is used when the state has none of its own.
 * \param state The current state.
 * \param stage_id Index of the first stage in `state->stages`.
 * \param target_stage_id Index of the second stage in `state->stages`.
 * \return The answer of AccessAnalyzer::ElementWiseMatch for the two stages' ops.
 */
inline bool ElementwiseMatch(const SearchTask& task, const State& state, int stage_id,
                             int target_stage_id) {
  const auto& source_op = state->stages[stage_id]->op;
  const auto& dest_op = state->stages[target_stage_id]->op;
  if (!state->current_compute_dag) {
    return task->compute_dag->access_analyzer.ElementWiseMatch(source_op, dest_op);
  }
  // The state carries its own compute DAG, which takes precedence over the task's DAG.
  return state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.ElementWiseMatch(
      source_op, dest_op);
}
/********** Get information from Stage/Iterator **********/
/*!
 * \brief Get the constant extent of an iterator.
 * \param it The iterator to inspect.
 * \return The extent if the iterator has a defined range whose extent is an IntImm;
 *  -1 otherwise.
 */
inline int64_t GetExtent(const Iterator& it) {
  if (!it->range.defined()) {
    return -1;
  }
  const auto* extent = it->range->extent.as<IntImmNode>();
  if (extent == nullptr) {
    // The extent is not a compile-time integer constant.
    return -1;
  }
  return extent->value;
}
/*!
 * \brief Compute the product of the extents of all spatial iterators and of all reduction
 *  iterators of a stage, respectively.
 * \param stage The stage whose iterators are multiplied.
 * \return A pair (spatial product, reduction product). Each product starts at 1; iterators
 *  of any other kind are ignored.
 * \note Extents come from GetExtent, so an iterator without a constant extent contributes
 *  a factor of -1.
 */
inline std::pair<int64_t, int64_t> GetCumulativeSpaceAndReductionLength(const Stage& stage) {
  int64_t space_product = 1;
  int64_t reduce_product = 1;
  for (const auto& iter : stage->iters) {
    switch (iter->iter_kind) {
      case IteratorKind::kSpatial:
        space_product *= GetExtent(iter);
        break;
      case IteratorKind::kReduction:
        reduce_product *= GetExtent(iter);
        break;
      default:
        break;
    }
  }
  return {space_product, reduce_product};
}
/*!
 * \brief Decide whether rfactor should be tried for a stage.
 * \param task The search task; `hardware_params->num_cores` is used as the parallelism budget.
 * \param state The current state.
 * \param stage_id Index of the stage in `state->stages`.
 * \return false if the stage's op is not a ComputeOp. For an op that needs multi-level
 *  tiling, true unless the spatial length exceeds the reduction length or exceeds
 *  16 * num_cores. For any other op, true only if the reduction length exceeds both 1 and
 *  num_cores.
 */
inline bool NeedsRfactor(const SearchTask& task, const State& state, int stage_id) {
  const auto& op = state->stages[stage_id]->op;
  if (!op->IsInstance<te::ComputeOpNode>()) {
    return false;
  }

  // Compute the product of lengths of all space iters and all reduce iters.
  // Keep the full 64-bit width: the products are int64_t and can exceed the range of int.
  int64_t cum_space_len = 1;
  int64_t cum_reduce_len = 1;
  std::tie(cum_space_len, cum_reduce_len) =
      GetCumulativeSpaceAndReductionLength(state->stages[stage_id]);

  if (NeedsMultilevelTiling(task, state, stage_id)) {
    // Do not use rfactor if we have enough parallelism on space iters
    const bool enough_space_parallelism =
        cum_space_len > cum_reduce_len || cum_space_len > task->hardware_params->num_cores * 16;
    return !enough_space_parallelism;
  }
  if (cum_reduce_len > 1) {
    // Reduction op without multi-level tiling: try rfactor only if the reduction is longer
    // than the number of cores.
    return cum_reduce_len > task->hardware_params->num_cores;
  }
  return false;
}
/*!
 * \brief Return whether the stage has reduce iterators.
 * \param stage The stage to inspect.
 * \return true if at least one iterator of the stage has a kind other than
 *  IteratorKind::kSpatial.
 */
inline bool HasReduceIter(const Stage& stage) {
  const size_t num_iters = stage->iters.size();
  for (size_t i = 0; i < num_iters; ++i) {
    if (stage->iters[i]->iter_kind != IteratorKind::kSpatial) {
      return true;
    }
  }
  return false;
}
/*!
 * \brief Return whether the stage has an iterator with a specific annotation.
 * \param stage The stage to inspect.
 * \param type The annotation to look for.
 * \return true if at least one iterator of the stage carries annotation `type`.
 */
inline bool HasAnnotatedIter(const Stage& stage, IteratorAnnotation type) {
  const size_t num_iters = stage->iters.size();
  for (size_t i = 0; i < num_iters; ++i) {
    if (stage->iters[i]->annotation == type) {
      return true;
    }
  }
  return false;
}
/*!
 * \brief Return whether the stage has exactly one consumer and the two are
 *  elementwise-matched.
 * \param task The search task.
 * \param state The current state.
 * \param stage_id Index of the stage in `state->stages`.
 * \param target_stage_id Optional output. When the stage has exactly one consumer, the
 *  consumer's stage id is written here, even if the function then returns false.
 * \return true if there is exactly one consumer, it is elementwise-matched with the stage,
 *  the stage and the consumer do not both have reduce iterators, and the consumer op's name
 *  does not end with ".shared".
 */
inline bool HasSingleElementwiseMatchedConsumer(const SearchTask& task, const State& state,
                                                int stage_id, int* target_stage_id = nullptr) {
  // Scratch slot so the code below can write unconditionally when no output is requested.
  int scratch_id;
  int* consumer_id = (target_stage_id != nullptr) ? target_stage_id : &scratch_id;

  const std::set<int> consumers = GetConsumers(task, state, stage_id);
  if (consumers.size() != 1) {
    return false;
  }
  *consumer_id = *consumers.begin();

  if (!ElementwiseMatch(task, state, stage_id, *consumer_id)) {
    return false;
  }
  if (HasReduceIter(state->stages[stage_id]) && HasReduceIter(state->stages[*consumer_id])) {
    return false;
  }
  return !StrEndsWith(state->stages[*consumer_id]->op->name, ".shared");
}
/*!
 * \brief Return whether the step changes the number of stages.
 * \param step The transform step to classify.
 * \return true if the step is a CacheWriteStep, a CacheReadStep or an RfactorStep.
 */
inline bool IsStageNumberChangingStep(const Step& step) {
  if (step->IsInstance<CacheWriteStepNode>()) {
    return true;
  }
  if (step->IsInstance<CacheReadStepNode>()) {
    return true;
  }
  return step->IsInstance<RfactorStepNode>();
}
/*!
 * \brief Return whether the state does cache_read for stage_id.
 * \param s The state whose transform history is searched.
 * \param stage_id The stage id, expressed in the final state's numbering.
 * \return true if some CacheReadStep in the history targets this stage.
 * \note The history is walked from the newest step to the oldest. Each time a step that
 *  changes the number of stages is passed and its stage id is smaller than the tracked id,
 *  the tracked id is decremented so it stays valid for the older steps.
 */
inline bool HasCacheReadStage(const State& s, int stage_id) {
  const int num_steps = static_cast<int>(s->transform_steps.size());
  for (int step_idx = num_steps - 1; step_idx >= 0; --step_idx) {
    const auto step = s->transform_steps[step_idx];
    const auto* cache_read = step.as<CacheReadStepNode>();
    if (cache_read != nullptr && stage_id == cache_read->stage_id) {
      return true;
    }
    if (IsStageNumberChangingStep(step) && stage_id > step->stage_id) {
      --stage_id;
    }
  }
  return false;
}
/*!
 * \brief Return whether the state does cache_write for stage_id.
 * \param s The state whose transform history is searched.
 * \param stage_id The stage id, expressed in the final state's numbering.
 * \return true if some CacheWriteStep in the history targets this stage.
 * \note The history is walked from the newest step to the oldest. Each time a step that
 *  changes the number of stages is passed and its stage id is smaller than the tracked id,
 *  the tracked id is decremented so it stays valid for the older steps.
 */
inline bool HasCacheWriteStage(const State& s, int stage_id) {
  const int num_steps = static_cast<int>(s->transform_steps.size());
  for (int step_idx = num_steps - 1; step_idx >= 0; --step_idx) {
    const auto step = s->transform_steps[step_idx];
    const auto* cache_write = step.as<CacheWriteStepNode>();
    if (cache_write != nullptr && stage_id == cache_write->stage_id) {
      return true;
    }
    if (IsStageNumberChangingStep(step) && stage_id > step->stage_id) {
      --stage_id;
    }
  }
  return false;
}
/*!
 * \brief Return whether the state does rfactor for stage_id.
 * \param s The state whose transform history is searched.
 * \param stage_id The stage id, expressed in the final state's numbering.
 * \return true if some RfactorStep in the history targets this stage.
 * \note The history is walked from the newest step to the oldest. Each time a step that
 *  changes the number of stages is passed and its stage id is smaller than the tracked id,
 *  the tracked id is decremented so it stays valid for the older steps.
 */
inline bool HasRfactorStage(const State& s, int stage_id) {
  const int num_steps = static_cast<int>(s->transform_steps.size());
  for (int step_idx = num_steps - 1; step_idx >= 0; --step_idx) {
    const auto step = s->transform_steps[step_idx];
    const auto* rfactor = step.as<RfactorStepNode>();
    if (rfactor != nullptr && stage_id == rfactor->stage_id) {
      return true;
    }
    if (IsStageNumberChangingStep(step) && stage_id > step->stage_id) {
      --stage_id;
    }
  }
  return false;
}
/*! \brief Return whether the stage does cross thread reduction. */
inline bool HasCrossThreadReduction(const State& state, int stage_id) {
std::function<bool(const Stage&)> check_stage = [](const Stage& in_stage) {
for (const auto& iter : in_stage->iters) {
if (iter->annotation == IteratorAnnotation::kThreadX &&
iter->iter_kind == IteratorKind::kReduction) {
return true;
}
}
return false;
};
// Check the stage itself
if (check_stage(state->stages[stage_id])) {
return true;
}
// Check the attached stages
for (size_t iter_id = 0; iter_id < state->stages[stage_id]->iters.size(); iter_id++) {
const auto& res =
state->attach_map->iter_to_attached_stages.find(std::make_pair(stage_id, iter_id));
if (res != state->attach_map->iter_to_attached_stages.end()) {
for (int attached_stage_id : res->second) {
if (check_stage(state->stages[attached_stage_id])) {
return true;
}
}
}
}
return false;
}
/*!
 * \brief Return whether the stage has been tiled already.
 * \param stage The stage to inspect; a CHECK fails if its op is not a ComputeOp.
 * \return true if the number of iterators of the stage differs from the number of axes plus
 *  reduce axes of its op.
 */
inline bool IsTiled(const Stage& stage) {
  const auto* compute_op = stage->op.as<te::ComputeOpNode>();
  CHECK(compute_op != nullptr);
  const size_t num_original_axes = compute_op->axis.size() + compute_op->reduce_axis.size();
  return stage->iters.size() != num_original_axes;
}
/*!
 * \brief Extract primitive iterator names from the name of a nested fused or split iterator.
 *
 * The name is cut at every '@' (fuse) and '.' (split). A segment is inserted into `rets`
 * unless it is empty or starts with a decimal digit, e.g. "i.0@j.1" yields {"i", "j"}.
 *
 * \param name The iterator name to decompose.
 * \param rets Output set; extracted names are added to it, existing entries are kept.
 */
inline void ExtractOriginalIterators(const std::string& name, std::set<std::string>* rets) {
  // ASCII digit test. Unlike isdigit(char), this is well defined for negative char values
  // (bytes >= 0x80 on platforms where char is signed) and does not depend on the locale.
  const auto is_ascii_digit = [](char c) { return c >= '0' && c <= '9'; };

  // Insert the segment name[begin, end) if it starts with a valid first character.
  const auto add_segment = [&](size_t begin, size_t end) {
    if (begin >= name.size()) {
      return;
    }
    const char first = name[begin];
    // An empty segment starts on the separator itself; a digit marks a split/fuse index.
    if (is_ascii_digit(first) || first == '@' || first == '.') {
      return;
    }
    rets->insert(name.substr(begin, end - begin));
  };

  size_t segment_begin = 0;
  for (size_t i = 0; i < name.size(); ++i) {
    if (name[i] == '@' || name[i] == '.') {  // '@' for fuse and '.' for split
      add_segment(segment_begin, i);
      segment_begin = i + 1;
    }
  }
  // Trailing segment after the last separator (or the whole name if there is none).
  add_segment(segment_begin, name.size());
}
/*!
 * \brief Get the last reduce iterator in the outermost reduce tile.
 * \param stage The stage to inspect; a CHECK fails if its op is not a ComputeOp.
 * \return If the op has reduce axes that are not listed under the
 *  SearchPolicyKey::no_split_at_inner attribute, the first reduction iterator at which the
 *  original names collected so far (via ExtractOriginalIterators) reach the number of such
 *  axes. Otherwise the first reduction iterator of the stage.
 * \note Logs a fatal error if no such iterator exists.
 */
inline Iterator GetLastReduceIteratorInOutermostReduceTile(const Stage& stage) {
  const auto* compute_op = stage->op.as<te::ComputeOpNode>();
  CHECK(compute_op != nullptr);

  // Names of reduce axes that are excluded from the count below.
  std::set<std::string> excluded_names;
  if (stage->op->attrs.count(SearchPolicyKey::no_split_at_inner)) {
    excluded_names = GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner);
  }

  size_t num_counted_axes = 0;
  for (const auto& axis : compute_op->reduce_axis) {
    if (excluded_names.count(axis->var->name_hint) == 0) {
      ++num_counted_axes;
    }
  }

  std::set<std::string> seen_names;
  for (const auto& iter : stage->iters) {
    if (iter->iter_kind != IteratorKind::kReduction) {
      continue;
    }
    if (num_counted_axes == 0) {
      // Return the first reduce iterator
      return iter;
    }
    // Accumulate original names until all counted reduce axes have been seen.
    ExtractOriginalIterators(iter->name, &seen_names);
    if (seen_names.size() == num_counted_axes) {
      return iter;
    }
  }

  LOG(FATAL) << "Cannot find the iterator.";
  return stage->iters[0];
}
/*!
 * \brief Get the target stage id of a history step in the new state.
 *  We need this because the stage_id in the history may be stale due to later steps.
 * \param s The state whose transform history is used.
 * \param step_id Index of the step in `s->transform_steps`.
 * \return The step's recorded stage id, increased by one for every later step that changes
 *  the number of stages at a stage id not greater than the id tracked so far.
 */
inline int GetTargetStageIDInState(const State& s, int step_id) {
  const int recorded_stage_id = s->transform_steps[step_id]->stage_id;
  int shift = 0;
  for (size_t i = step_id + 1; i < s->transform_steps.size(); ++i) {
    const auto later_step = s->transform_steps[i];
    if (IsStageNumberChangingStep(later_step) &&
        later_step->stage_id <= recorded_stage_id + shift) {
      ++shift;
    }
  }
  return recorded_stage_id + shift;
}
/*!
 * \brief Get all split steps for one stage.
 * \param s The state whose transform history is searched.
 * \param stage_id The stage id, expressed in the final state's numbering.
 * \param split_step_ids Output; indices of matching SplitSteps are appended, newest first.
 * \note The history is walked from the newest step to the oldest. Each time a step that
 *  changes the number of stages is passed and its stage id is smaller than the tracked id,
 *  the tracked id is decremented so it stays valid for the older steps.
 */
inline void GetSplitStepIds(const State& s, int stage_id, std::vector<int>* split_step_ids) {
  const int num_steps = static_cast<int>(s->transform_steps.size());
  for (int step_idx = num_steps - 1; step_idx >= 0; --step_idx) {
    const auto step = s->transform_steps[step_idx];
    const auto* split = step.as<SplitStepNode>();
    if (split != nullptr && stage_id == split->stage_id) {
      split_step_ids->push_back(step_idx);
    }
    if (IsStageNumberChangingStep(step) && stage_id > step->stage_id) {
      --stage_id;
    }
  }
}
/*!
 * \brief Fuse all reduction iterators of a stage.
 * \param state The input state; it is copied, not modified.
 * \param stage_id Index of the stage in `state->stages`.
 * \param fused_iter Output; the fused iterator, or the single reduction iterator when there
 *  is only one and no fuse is applied.
 * \param space_iters Output; cleared, then filled with the spatial iterators of the stage.
 * \param reduce_iters Output; cleared, then filled with the reduction iterators of the stage.
 * \return The state after the fuse (an unchanged copy if nothing was fused).
 * \note A CHECK fails if the stage has no reduction iterator.
 */
inline State FuseAllReductionIterators(const State& state, int stage_id, Iterator* fused_iter,
                                       Array<Iterator>* space_iters,
                                       Array<Iterator>* reduce_iters) {
  space_iters->clear();
  reduce_iters->clear();

  // Partition the iterators by kind; kinds other than spatial/reduction are dropped.
  for (const auto& iter : state->stages[stage_id]->iters) {
    switch (iter->iter_kind) {
      case IteratorKind::kSpatial:
        space_iters->push_back(iter);
        break;
      case IteratorKind::kReduction:
        reduce_iters->push_back(iter);
        break;
      default:
        break;
    }
  }
  CHECK(!reduce_iters->empty());

  State fused_state = state;
  if (reduce_iters->size() <= 1) {
    // A single reduction iterator needs no fusing.
    *fused_iter = (*reduce_iters)[0];
  } else {
    *fused_iter = fused_state.fuse(stage_id, *reduce_iters);
  }
  return fused_state;
}
/*!
 * \brief Fuse all outer level space iterators of a stage.
 * \param state The input state; it is copied, not modified.
 * \param stage_id Index of the stage in `state->stages`.
 * \param fused_iter Output; the fused iterator, or the single collected iterator when there
 *  is only one and no fuse is applied.
 * \return The state after the fuse (an unchanged copy if nothing was fused).
 * \note Iterators are collected from the outermost inwards until a reduction iterator, an
 *  annotated iterator, or an iterator whose predecessor has stages attached to it is
 *  reached. A CHECK fails if no iterator is collected.
 */
inline State FuseAllOuterSpaceIterators(const State& state, int stage_id, Iterator* fused_iter) {
  std::vector<Iterator> to_fuse;
  // Use a signed index: the attach-point lookup below needs `iter_id - 1`, which must be a
  // plain -1 for the first iterator instead of an unsigned wrap-around that is narrowed back
  // to int.
  const int num_iters = static_cast<int>(state->stages[stage_id]->iters.size());
  for (int iter_id = 0; iter_id < num_iters; ++iter_id) {
    const auto& it = state->stages[stage_id]->iters[iter_id];
    // Stop at reduce iterator or annotated iterator
    if (it->iter_kind == IteratorKind::kReduction || it->annotation != IteratorAnnotation::kNone) {
      break;
    }
    // Stop at compute_at attach point: the previous iterator has stages attached to it.
    if (state->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, iter_id - 1))) {
      break;
    }
    to_fuse.push_back(it);
  }
  CHECK(!to_fuse.empty());

  State tmp_s = state;
  if (to_fuse.size() > 1) {
    *fused_iter = tmp_s.fuse(stage_id, to_fuse);
  } else {
    // A single iterator needs no fusing.
    *fused_iter = to_fuse[0];
  }
  return tmp_s;
}
/*!
 * \brief Randomly sample states with replacement.
 * \param in_states The pool of states to draw from.
 * \param random_gen The random number generator; it is advanced once per drawn state.
 * \param out_size The number of states to draw.
 * \return An array of `out_size` states, each picked as
 *  `in_states[(*random_gen)() % in_states.size()]`.
 * \note A CHECK fails if `out_size` is positive while `in_states` is empty. With
 *  `out_size == 0` an empty array is returned and the generator is not touched.
 */
inline Array<State> RandomSampleStates(const Array<State>& in_states, std::mt19937* random_gen,
                                       size_t out_size) {
  Array<State> out_states;
  if (out_size == 0) {
    return out_states;
  }
  const size_t num_candidates = in_states.size();
  // Drawing from an empty pool would evaluate `x % 0`, which is undefined behavior.
  CHECK_GT(num_candidates, 0) << "Cannot sample " << out_size << " states from an empty array";
  for (size_t i = 0; i < out_size; i++) {
    out_states.push_back(in_states[(*random_gen)() % num_candidates]);
  }
  return out_states;
}
/*!
 * \brief Compute prefix-sum probabilities based on the given weights.
 * \param weights The selection weights; negative weights are treated as 0.
 * \param prefix_sum_probs Output; resized to `weights.size()`. Element i receives the sum of
 *  the clamped weights 0..i divided by the sum of all clamped weights.
 * \note If no weight is positive the total is zero and the normalization would produce NaN
 *  (0 / 0). In that case a uniform distribution is emitted instead, i.e. element i is
 *  (i + 1) / weights.size().
 */
inline void ComputePrefixSumProb(const std::vector<float>& weights,
                                 std::vector<double>* prefix_sum_probs) {
  const size_t count = weights.size();
  prefix_sum_probs->resize(count);

  // Running sum of the clamped weights. Kept in float so that results for valid inputs stay
  // bit-identical to the previous implementation.
  float sum = 0.0f;
  for (size_t i = 0; i < count; ++i) {
    sum += std::max(weights[i], 0.0f);
    (*prefix_sum_probs)[i] = sum;
  }

  if (sum > 0.0f) {
    // Normalize so that the last entry becomes 1.
    for (size_t i = 0; i < count; ++i) {
      (*prefix_sum_probs)[i] /= sum;
    }
  } else {
    // Degenerate weights: fall back to a uniform distribution.
    for (size_t i = 0; i < count; ++i) {
      (*prefix_sum_probs)[i] = static_cast<double>(i + 1) / static_cast<double>(count);
    }
  }
}
/*!
 * \brief Randomly choose an index according to prefix-sum probabilities.
 * \param prefix_sum_probs The prefix-sum probabilities; a CHECK fails if it is empty.
 * \param random_gen The random number generator; one sample is drawn from it.
 * \return The position of the first element of `prefix_sum_probs` that is not less than a
 *  sample drawn uniformly from [0, 1).
 */
inline int RandomChoose(const std::vector<double>& prefix_sum_probs, std::mt19937* random_gen) {
  std::uniform_real_distribution<> unit_dist(0.0, 1.0);
  const double sample = unit_dist(*random_gen);
  CHECK(!prefix_sum_probs.empty());
  const auto first = prefix_sum_probs.begin();
  const auto hit = std::lower_bound(first, prefix_sum_probs.end(), sample);
  return hit - first;
}
/*!
 * \brief Print a title framed by two rules of 60 dashes.
 * \param title The text of the title.
 * \param verbose The verbosity level, forwarded to StdCout to select the output stream.
 */
inline void PrintTitle(const std::string& title, int verbose) {
  const auto rule = Chars('-', 60);
  const auto indent = Chars('-', 25);
  StdCout(verbose) << rule << "\n" << indent << " [ " << title << " ]\n" << rule << std::endl;
}
/*!
* \brief Enumerate all possible factorization schemes for splitting an axis.
* \note This class memoizes the results for reuse.
*/
class SplitFactorizationMemo {
public:
/*! \brief Key type built from three ints.
* NOTE(review): presumably (extent, n_lengths, max_innermost_factor), mirroring the
* parameters of GetFactorizationSchemes -- confirm against the implementation. */
using QueryKey = std::tuple<int, int, int>;
/*!
* \brief Get the factorization schemes for a query.
* \param extent The extent to factorize.
* \param n_lengths The number of lengths in each scheme.
* \param max_innermost_factor The upper bound of the innermost factor.
* \return A reference to the schemes; each scheme is an array of integers.
*/
const Array<Array<Integer>>& GetFactorizationSchemes(int extent, int n_lengths,
int max_innermost_factor);
/*!
* \brief Get the factors of an integer.
* \param n The integer to query.
* \return A reference to the vector of factors.
*/
const std::vector<int>& GetFactors(int n);
private:
/*! \brief Recursive helper of the enumeration; defined outside this header. */
void DfsEnumerate(int now, int remaining_length, int max_innermost_factor);
/*!
* \brief A simple implementation of a read-write lock.
* The guarded block can be read by multiple threads at the same time, while other operations
* will be blocked if one thread is writing.
* \note Writing threads will wait until all reading threads have finished. If there are
* multiple writing threads, the order in which they are processed is not guaranteed.
*/
class ReadWriteLock {
public:
/*! \brief The method to get the read lock. One thread can process read if there are no other
* writing threads. */
void GetRead();
/*! \brief The method to get the write lock. One thread can process write if there are no
* other reading or writing threads. */
void GetWrite();
/*! \brief The method to release the read lock. */
void UnlockRead();
/*! \brief The method to release the write lock. */
void UnlockWrite();
private:
/*! \brief Reader counter, initially 0. */
uint32_t read_count_ = 0;
/*! \brief Writer flag, initially false. */
bool is_writing_ = false;
/*! \brief Mutex paired with the condition variable below. */
std::mutex cv_mutex_;
/*! \brief Condition variable of the lock. */
std::condition_variable cv_;
} lock_;
/*! \brief Map from a query key to factorization schemes.
* NOTE(review): presumably the memoization table behind GetFactorizationSchemes -- confirm. */
std::unordered_map<QueryKey, Array<Array<Integer>>> memory_;
/*! \brief NOTE(review): looks like working state of the enumeration (n_lengths_, tmp_stack_
* and results_) -- confirm against the implementation. */
int n_lengths_;
Array<Integer> tmp_stack_;
Array<Array<Integer>>* results_;
/*! \brief Map from an integer to a vector of ints.
* NOTE(review): presumably the memoization table behind GetFactors -- confirm. */
std::unordered_map<int, std::vector<int>> factor_memory_;
};
/*!
 * \brief Get the indices of the SplitSteps that operate on spatial iterators.
 * \param s The state to inspect.
 * \param stage_id The id of the stage.
 * \return The indices of the matching steps.
 */
Array<Integer> GetSpatialSplitStepIds(const State& s, int stage_id);
/*!
 * \brief Get the possible compute locations for a stage.
 * \param task The search task.
 * \param state The current state.
 * \param stage_id The id of the stage.
 * \return The candidates as pairs of ints.
 *  NOTE(review): presumably (target stage id, iterator id) -- confirm against the definition.
 */
std::vector<std::pair<int, int>> GetComputeLocationCandidates(const SearchTask& task,
const State& state, int stage_id);
/*!
 * \brief Apply a multi-level tiling structure according to a string format,
 *  where "S" stands for a space level and "R" stands for a reduction level.
 *
 *  For example, if the format is "SSRSRS", then we will use the tiling structure:
 *  space_L0, space_L1, reduce_L0, space_L2, reduce_L1, space_L3.
 *  For example, if we apply "SSRSRS" to matrix multiplication, we have space iterators
 *  i and j and reduce iterator k. Then the tiling structure is:
 *  i0, j0, i1, j1, k0, i2, j2, k1, i3, j3.
 *
 * \param state The state to apply the tiling to.
 * \param stage_id The id of the stage to tile.
 * \param format The tiling format string.
 * \param spatial_split_step_ids Optional output (defaults to nullptr).
 *  NOTE(review): presumably receives the ids of the spatial split steps -- confirm.
 * \return The resulting state.
 */
State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format,
std::vector<int>* spatial_split_step_ids = nullptr);
/*!
 * \brief Apply the tiling structure space, space, space, ..., with tile sizes taken from
 *  other SplitSteps.
 * \param state The state to apply the tiling to.
 * \param stage_id The id of the stage to tile.
 * \param split_step_ids The ids of the split steps to follow.
 * \param n_split NOTE(review): presumably the number of split levels -- confirm.
 * \return The resulting state.
 */
State FollowTiling(const State& state, int stage_id, const std::vector<int>& split_step_ids,
int n_split);
/*!
 * \brief Prune invalid states and return the results in-place.
 * \param task The search task.
 * \param states The states to prune; the array is modified in place.
 */
void PruneInvalidState(const SearchTask& task, Array<State>* states);
} // namespace auto_scheduler
} // namespace tvm
#endif // TVM_AUTO_SCHEDULER_SEARCH_POLICY_UTILS_H_