blob: c8370d6daca2a62bcbf105a4f02a4f740835dc5e [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file auto_scheduler/search_policy/sketch_policy_rules.cc
* \brief Rules defined to generate the sketches and initial sampled states in SketchPolicy.
*/
#include "sketch_policy_rules.h"
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "sketch_policy.h"
namespace tvm {
namespace auto_scheduler {
// Candidate "auto_unroll_max_step" values sampled by InitUnroll when the task is not a GPU task.
static std::vector<int> auto_unroll_configs_cpu = {0, 16, 64, 512};
// Candidate "auto_unroll_max_step" values sampled by InitUnroll when the task is a GPU task.
static std::vector<int> auto_unroll_configs_gpu = {0, 16, 64, 512, 1024};
/********** Sketch Generation Rule **********/
/********** RuleSkipStage **********/
SketchGenerationRule::ConditionKind RuleSkipStage::MeetCondition(const SketchPolicyNode& policy,
                                                                 const State& state,
                                                                 int stage_id) const {
  // The skip rule is meant to be the last rule tried for a stage. It applies unconditionally so
  // that the stage index always gets decreased by Apply().
  const ConditionKind always_apply = ConditionKind::kApply;
  return always_apply;
}
std::vector<std::pair<State, int>> RuleSkipStage::Apply(const SketchPolicyNode& policy,
const State& state, int stage_id) const {
return {std::make_pair(state, stage_id - 1)};
}
/********** RuleAlwaysInline **********/
SketchGenerationRule::ConditionKind RuleAlwaysInline::MeetCondition(const SketchPolicyNode& policy,
                                                                    const State& state,
                                                                    int stage_id) const {
  const Stage& stage = state->stages[stage_id];

  // Stages that cannot be inlined at all: placeholders, output ops and stages with reduce iters.
  if (stage->op_type == StageKind::kPlaceholder) {
    return ConditionKind::kSkip;
  }
  if (IsOutputOp(policy.search_task, state, stage_id)) {
    return ConditionKind::kSkip;
  }
  if (HasReduceIter(stage)) {
    return ConditionKind::kSkip;
  }

  // Inline whenever the stage is strictly inlineable, and unconditionally for GPU tasks.
  if (IsStrictlyInlineable(policy.search_task, state, stage_id)) {
    return ConditionKind::kApplyAndSkipRest;
  }
  if (IsGPUTask(policy.search_task)) {
    return ConditionKind::kApplyAndSkipRest;
  }
  return ConditionKind::kSkip;
}
std::vector<std::pair<State, int>> RuleAlwaysInline::Apply(const SketchPolicyNode& policy,
const State& state, int stage_id) const {
State tmp_s = state;
tmp_s.compute_inline(stage_id);
return {std::make_pair(std::move(tmp_s), stage_id - 1)};
}
/********** RuleMultiLevelTiling **********/
SketchGenerationRule::ConditionKind RuleMultiLevelTiling::MeetCondition(
    const SketchPolicyNode& policy, const State& state, int stage_id) const {
  // The rule fires, and suppresses the remaining rules, only for stages that need
  // multi-level tiling.
  if (NeedsMultilevelTiling(policy.search_task, state, stage_id)) {
    return ConditionKind::kApplyAndSkipRest;
  }
  return ConditionKind::kSkip;
}
std::vector<std::pair<State, int>> RuleMultiLevelTiling::Apply(const SketchPolicyNode& policy,
const State& state,
int stage_id) const {
const std::string& multi_level_tiling_structure =
IsGPUTask(policy.search_task)
? GetStringParam(policy.params, SketchParamKey::MultiLevelTiling::gpu_structure)
: GetStringParam(policy.params, SketchParamKey::MultiLevelTiling::cpu_structure);
State tmp_s = DoMultiLevelTiling(state, stage_id, multi_level_tiling_structure);
return {std::make_pair(std::move(tmp_s), stage_id - 1)};
}
/********** RuleMultiLevelTilingWithFusion **********/
SketchGenerationRule::ConditionKind RuleMultiLevelTilingWithFusion::MeetCondition(
    const SketchPolicyNode& policy, const State& state, int stage_id) const {
  // The rule only concerns stages that need multi-level tiling and have a single
  // element-wise matched consumer.
  if (!NeedsMultilevelTiling(policy.search_task, state, stage_id)) {
    return ConditionKind::kSkip;
  }
  if (!HasSingleElementwiseMatchedConsumer(policy.search_task, state, stage_id)) {
    return ConditionKind::kSkip;
  }

  // Fusion is mandatory (no other rule is tried) when the stage has a cache_write stage or
  // when the task is a GPU task; otherwise fusion is just one of the alternatives.
  if (HasCacheWriteStage(state, stage_id)) {
    return ConditionKind::kApplyAndSkipRest;
  }
  if (IsGPUTask(policy.search_task)) {
    return ConditionKind::kApplyAndSkipRest;
  }
  return ConditionKind::kApply;
}
std::vector<std::pair<State, int>> RuleMultiLevelTilingWithFusion::Apply(
const SketchPolicyNode& policy, const State& state, int stage_id) const {
int target_stage_id;
CHECK(HasSingleElementwiseMatchedConsumer(policy.search_task, state, stage_id, &target_stage_id));
const std::string& multi_level_tiling_structure =
IsGPUTask(policy.search_task)
? GetStringParam(policy.params, SketchParamKey::MultiLevelTiling::gpu_structure)
: GetStringParam(policy.params, SketchParamKey::MultiLevelTiling::cpu_structure);
std::vector<int> spatial_split_step_ids;
State base_state =
DoMultiLevelTiling(state, stage_id, multi_level_tiling_structure, &spatial_split_step_ids);
std::vector<std::pair<State, int>> ret;
std::vector<int> follow_tiling_levels =
IsGPUTask(policy.search_task) ? std::vector<int>{3} : std::vector<int>{1, 2};
for (int level : follow_tiling_levels) {
if (tolower(multi_level_tiling_structure[level - 1]) != 's') {
continue;
}
State tmp_s = base_state;
tmp_s = FollowTiling(tmp_s, target_stage_id, spatial_split_step_ids, level);
const Iterator& target_iter =
tmp_s->stages[target_stage_id]->iters[level * spatial_split_step_ids.size() - 1];
tmp_s.compute_at(stage_id, target_stage_id, target_iter);
ret.emplace_back(std::move(tmp_s), stage_id - 1);
}
return ret;
}
/********** RuleAddCacheRead **********/
SketchGenerationRule::ConditionKind RuleAddCacheRead::MeetCondition(const SketchPolicyNode& policy,
                                                                    const State& state,
                                                                    int stage_id) const {
  const SearchTask& task = policy.search_task;

  // cache_read is only considered for a stage with exactly one consumer.
  const std::set<int>& consumer_ids = GetConsumers(task, state, stage_id);
  if (consumer_ids.size() != 1) {
    return ConditionKind::kSkip;
  }
  const int consumer_id = *consumer_ids.begin();

  // The consumer has to need multi-level tiling and must not do cross-thread reduction.
  if (!NeedsMultilevelTiling(task, state, consumer_id) ||
      HasCrossThreadReduction(state, consumer_id)) {
    return ConditionKind::kSkip;
  }

  // The stage has to be a direct producer of that consumer.
  const std::set<int>& direct_producers = GetDirectProducers(task, state, consumer_id);
  if (direct_producers.count(stage_id) == 0) {
    return ConditionKind::kSkip;
  }
  return ConditionKind::kApplyAndSkipRest;
}
std::vector<std::pair<State, int>> RuleAddCacheRead::Apply(const SketchPolicyNode& policy,
const State& state, int stage_id) const {
const SearchTask& task = policy.search_task;
const std::set<int>& consumers = GetConsumers(task, state, stage_id);
CHECK_EQ(consumers.size(), 1);
int target_stage_id = *consumers.begin();
State tmp_s = state;
// Cache read add shared memory
int added_stage_id = tmp_s.cache_read(stage_id, "shared", {target_stage_id}, task->compute_dag);
target_stage_id++;
const auto& share_read_pos =
GetLastReduceIteratorInOutermostReduceTile(tmp_s->stages[target_stage_id]);
tmp_s.compute_at(added_stage_id, target_stage_id, share_read_pos);
return {std::make_pair(tmp_s, stage_id)};
}
/********** RuleAddCacheWrite **********/
SketchGenerationRule::ConditionKind RuleAddCacheWrite::MeetCondition(const SketchPolicyNode& policy,
                                                                     const State& state,
                                                                     int stage_id) const {
  // A cache write is added for a stage that needs multi-level tiling but has no single
  // element-wise matched consumer.
  if (!NeedsMultilevelTiling(policy.search_task, state, stage_id)) {
    return ConditionKind::kSkip;
  }
  if (HasSingleElementwiseMatchedConsumer(policy.search_task, state, stage_id)) {
    return ConditionKind::kSkip;
  }

  // On GPU the remaining rules are skipped; the follow-up apply-and-skip decision is taken by
  // RuleMultiLevelTilingWithFusion.
  if (IsGPUTask(policy.search_task)) {
    return ConditionKind::kApplyAndSkipRest;
  }
  return ConditionKind::kApply;
}
std::vector<std::pair<State, int>> RuleAddCacheWrite::Apply(const SketchPolicyNode& policy,
const State& state,
int stage_id) const {
State tmp_s = state;
tmp_s.cache_write(stage_id, "local", policy.search_task->compute_dag);
return {std::make_pair(std::move(tmp_s), stage_id)};
}
/********** RuleAddRfactor **********/
SketchGenerationRule::ConditionKind RuleAddRfactor::MeetCondition(const SketchPolicyNode& policy,
                                                                  const State& state,
                                                                  int stage_id) const {
  // rfactor is an option only when the stage needs it and has no cache_write stage.
  if (!NeedsRfactor(policy.search_task, state, stage_id)) {
    return ConditionKind::kSkip;
  }
  if (HasCacheWriteStage(state, stage_id)) {
    return ConditionKind::kSkip;
  }
  return ConditionKind::kApply;
}
std::vector<std::pair<State, int>> RuleAddRfactor::Apply(const SketchPolicyNode& policy,
const State& state, int stage_id) const {
// Fuse all reduction iters
Array<Iterator> space_iters, reduce_iters;
Iterator fused_reduce_iter;
State base_state =
FuseAllReductionIterators(state, stage_id, &fused_reduce_iter, &space_iters, &reduce_iters);
// TODO(merrymercy): We can do more analysis here to generate less and more efficient sketches.
// In some cases, we only need rfactor for more parallel
// In some cases, we only need rfactor for vectorization.
// Now we will generate two versions and let the search figure out the bette one.
// Split reduction iters
const auto& split_res = base_state.split(stage_id, fused_reduce_iter, {Integer(1)});
int factor_axis_id = static_cast<int>(space_iters.size());
std::vector<std::pair<State, int>> ret;
for (const auto& split_iter : split_res) {
State tmp_s = base_state;
int rstage_id =
tmp_s.rfactor(stage_id, split_iter, factor_axis_id, policy.search_task->compute_dag);
// reorder the space iterator to innermost for vectorization
if (split_iter == split_res[1]) {
Array<Iterator> new_order;
for (size_t i = 0; i < tmp_s->stages[rstage_id]->iters.size(); ++i) {
if (i != space_iters.size()) {
new_order.push_back(tmp_s->stages[rstage_id]->iters[i]);
}
}
new_order.push_back(tmp_s->stages[rstage_id]->iters[space_iters.size()]);
tmp_s.reorder(rstage_id, new_order);
}
ret.emplace_back(std::move(tmp_s), rstage_id - 1);
}
return ret;
}
/********** RuleSimplifyComputeWithConstTensor **********/
SketchGenerationRule::ConditionKind RuleSimplifyComputeWithConstTensor::MeetCondition(
    const SketchPolicyNode& policy, const State& state, int stage_id) const {
  // The rule is triggered by the presence of the simplify_const_tensor_indices attribute on
  // the stage's operation.
  if (state->stages[stage_id]->op->attrs.count(SearchPolicyKey::simplify_const_tensor_indices)) {
    return ConditionKind::kApplyAndSkipRest;
  }
  return ConditionKind::kSkip;
}
std::vector<std::pair<State, int>> RuleSimplifyComputeWithConstTensor::Apply(
const SketchPolicyNode& policy, const State& state, int stage_id) const {
std::set<std::string> const_tensor_indices = GetIterNameSetParam(
state->stages[stage_id]->op->attrs, SearchPolicyKey::simplify_const_tensor_indices);
State tmp_s = state;
Array<Array<Iterator>> tiled_outer_iters;
Array<Iterator> unrolled_inner_iters;
// Currently set to 2
size_t tile_level = 2;
for (const auto& iter : state->stages[stage_id]->iters) {
if (const_tensor_indices.count(iter->name)) {
// unroll indices of const tensors
unrolled_inner_iters.push_back(tmp_s.unroll(stage_id, iter));
} else {
// tile other space indices
CHECK(iter->iter_kind == IteratorKind::kSpatial);
tiled_outer_iters.push_back(
tmp_s.split(stage_id, iter, Array<Optional<Integer>>(tile_level - 1, NullOpt)));
}
}
// reorder them
Array<Iterator> new_order;
for (size_t i = 0; i < tile_level; ++i) {
for (size_t j = 0; j < tiled_outer_iters.size(); ++j) {
new_order.push_back(tiled_outer_iters[j][i]);
}
}
new_order.insert(new_order.end(), unrolled_inner_iters.begin(), unrolled_inner_iters.end());
tmp_s.reorder(stage_id, new_order);
return {std::make_pair(tmp_s, stage_id - 1)};
}
/********** RuleCrossThreadReduction **********/
SketchGenerationRule::ConditionKind RuleCrossThreadReduction::MeetCondition(
    const SketchPolicyNode& policy, const State& state, int stage_id) const {
  // This rule is only valid for GPU tasks.
  CHECK(IsGPUTask(policy.search_task));

  // An intermediate state created by RuleAddCacheWrite is simply skipped.
  if (HasCacheWriteStage(state, stage_id)) {
    return ConditionKind::kSkip;
  }

  // Only compute ops are candidates.
  const auto& op = state->stages[stage_id]->op;
  if (!op->IsInstance<te::ComputeOpNode>()) {
    return ConditionKind::kSkip;
  }

  // Products of the lengths of all space iterators and of all reduce iterators.
  int cum_space_len, cum_reduce_len;
  std::tie(cum_space_len, cum_reduce_len) =
      GetCumulativeSpaceAndReductionLength(state->stages[stage_id]);

  if (NeedsMultilevelTiling(policy.search_task, state, stage_id)) {
    // Apply when the space iterators do not offer enough parallelism.
    if (cum_space_len < cum_reduce_len) {
      return ConditionKind::kApply;
    }
    return ConditionKind::kSkip;
  }

  // Other reduction operators: apply when the reduction is longer than a warp.
  if (cum_reduce_len > 1 && cum_reduce_len > policy.search_task->hardware_params->warp_size) {
    return ConditionKind::kApply;
  }
  return ConditionKind::kSkip;
}
std::vector<std::pair<State, int>> RuleCrossThreadReduction::Apply(const SketchPolicyNode& policy,
const State& state,
int stage_id) const {
const SearchTask& task = policy.search_task;
State tmp_s = state;
// fuse all reduction iters
Array<Iterator> space_iters, reduce_iters;
Iterator fused_reduce_iter;
tmp_s =
FuseAllReductionIterators(tmp_s, stage_id, &fused_reduce_iter, &space_iters, &reduce_iters);
// Check the opportunity for kernel fusion
bool fusible = false;
int target_stage_id = GetSingleConsumerId(policy.search_task, tmp_s, stage_id);
int num_common_outer = -1;
if (target_stage_id >= 0) {
num_common_outer =
GetNumCommonOuterIterator(policy.search_task, tmp_s, stage_id, target_stage_id);
if (num_common_outer > 0 &&
!NeedsMultilevelTiling(policy.search_task, state, target_stage_id)) {
fusible = true;
}
}
if (fusible) {
const Stage& target_stage = state->stages[target_stage_id];
std::vector<int> split_step_ids;
GetSplitStepIds(tmp_s, target_stage_id, &split_step_ids);
if (split_step_ids.size() == 0) {
// If the target stage does not have split step,
// it must be a simple stage without reduce iters.
// We then should do a split for it.
CHECK(!HasReduceIter(target_stage));
const auto& split_res = tmp_s.split(target_stage_id, target_stage->iters.back(),
{Integer(task->hardware_params->warp_size)});
tmp_s.bind(target_stage_id, split_res[1], IteratorAnnotation::kThreadX);
split_step_ids.push_back(tmp_s->transform_steps.size() - 2);
}
CHECK_EQ(split_step_ids.size(), 1);
const Iterator& target_iter = tmp_s->stages[target_stage_id]->iters[num_common_outer - 1];
const auto& split_res = tmp_s.follow_split(stage_id, fused_reduce_iter, split_step_ids[0], 1);
tmp_s.bind(stage_id, split_res[1], IteratorAnnotation::kThreadX);
tmp_s.compute_at(stage_id, target_stage_id, target_iter);
} else {
const auto& split_res =
tmp_s.split(stage_id, fused_reduce_iter, {Integer(task->hardware_params->warp_size)});
tmp_s.bind(stage_id, split_res[1], IteratorAnnotation::kThreadX);
}
return {std::make_pair(std::move(tmp_s), stage_id - 1)};
}
/********** RuleSpecialComputeLocationGPU **********/
SketchGenerationRule::ConditionKind RuleSpecialComputeLocationGPU::MeetCondition(
    const SketchPolicyNode& policy, const State& state, int stage_id) const {
  // Stages without producers are not handled by this rule.
  if (GetProducers(policy.search_task, state, stage_id).empty()) {
    return ConditionKind::kSkip;
  }

  // The rule fires when the only consumer carries the simplify_const_tensor_indices attribute.
  const std::set<int>& consumer_ids = GetConsumers(policy.search_task, state, stage_id);
  const bool matches =
      consumer_ids.size() == 1 && state->stages[*consumer_ids.begin()]->op->attrs.count(
                                      SearchPolicyKey::simplify_const_tensor_indices);
  if (matches) {
    return ConditionKind::kApplyAndSkipRest;
  }
  return ConditionKind::kSkip;
}
std::vector<std::pair<State, int>> RuleSpecialComputeLocationGPU::Apply(
const SketchPolicyNode& policy, const State& state, int stage_id) const {
State tmp_s = state;
const std::set<int>& consumers = GetConsumers(policy.search_task, state, stage_id);
CHECK_EQ(consumers.size(), 1);
// Get the last outer space iterator that is not unrolled.
const Stage& target_stage = state->stages[*consumers.begin()];
for (size_t i = 0; i < target_stage->iters.size(); ++i) {
if (target_stage->iters[i]->annotation == IteratorAnnotation::kUnroll) {
CHECK_GT(i, 0);
tmp_s.compute_at(stage_id, *consumers.begin(), target_stage->iters[i - 1]);
break;
}
}
return {std::make_pair(std::move(tmp_s), stage_id - 1)};
}
/********** Init Population **********/
/*!
 * \brief Randomly fill in the tile sizes of every SplitStep that still has undefined lengths.
 * \param policy The sketch policy; provides the split factorization memo and the parameters.
 * \param state The state to update in place. It is marked as concrete on return.
 * \param rand_gen The random number generator used to pick a factorization scheme.
 * \return Always ResultKind::kValid.
 */
PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy, State* state,
                                                             std::mt19937* rand_gen) const {
  // The limit on the innermost split factor does not depend on the step, so read it once.
  const int max_innermost_split_factor =
      GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor);

  StateNode* pstate = state->CopyOnWrite();
  // Scan the transformation history and randomly fill tiles size for all SplitStep
  for (size_t step_id = 0; step_id < (*state)->transform_steps.size(); ++step_id) {
    if (auto ps = (*state)->transform_steps[step_id].as<SplitStepNode>()) {
      // Steps whose lengths are all defined already are left untouched.
      bool all_defined = true;
      for (const auto& len : ps->lengths) {
        if (!len) {
          all_defined = false;
          break;
        }
      }
      if (all_defined) {
        continue;
      }

      CHECK(ps->extent);
      int extent = GetIntImm(ps->extent.value());
      const auto& candidate_lens = policy->split_memo.GetFactorizationSchemes(
          extent, ps->lengths.size(), max_innermost_split_factor);
      // Picking a random index below is a modulo by zero if no scheme is available, so fail
      // loudly instead.
      CHECK(!candidate_lens.empty())
          << "No factorization scheme available for a split with extent " << extent;
      const auto& candidate_lengths = candidate_lens[(*rand_gen)() % candidate_lens.size()];

      // Replace the step by an equivalent one carrying the sampled lengths.
      pstate->transform_steps.Set(
          step_id,
          SplitStep(ps->stage_id, ps->iter_id, ps->extent,
                    Array<Optional<Integer>>(candidate_lengths.begin(), candidate_lengths.end()),
                    ps->inner_to_outer));
    }
  }
  pstate->concrete = true;
  return ResultKind::kValid;
}
PopulationGenerationRule::ResultKind InitChangeComputeLocation::Apply(
    SketchPolicyNode* policy, State* state, std::mt19937* rand_gen) const {
  // The whole rule can be switched off through the policy parameters.
  if (GetIntParam(policy->params, SketchParamKey::disable_change_compute_location)) {
    return ResultKind::kValid;
  }

  // Visit the stages from the last one to the first one.
  const int num_stages = static_cast<int>((*state)->stages.size());
  for (int stage_id = num_stages - 1; stage_id >= 0; stage_id--) {
    const Stage& stage = (*state)->stages[stage_id];

    // Placeholders and inlined stages are left alone.
    const bool is_placeholder = stage->op_type == StageKind::kPlaceholder;
    if (is_placeholder || stage->compute_at == ComputeAtKind::kInlined) {
      continue;
    }
    // So are stages that are tiled or that need multi-level tiling.
    if (IsTiled(stage) || NeedsMultilevelTiling(policy->search_task, *state, stage_id)) {
      continue;
    }

    // Draw uniformly among: inline (0), compute at root (1), or one of the candidate locations.
    std::vector<std::pair<int, int>> candidates =
        GetComputeLocationCandidates(policy->search_task, *state, stage_id);
    int choice = (*rand_gen)() % (candidates.size() + 2);

    if (choice == 0) {
      // Inline only stages without reduce iterators that currently have an attach point.
      if (!HasReduceIter(stage)) {
        const auto& stage_to_attach_iter = (*state)->attach_map->stage_to_attach_iter;
        if (stage_to_attach_iter.find(stage_id) != stage_to_attach_iter.end()) {
          state->compute_inline(stage_id);
        }
      }
    } else if (choice == 1) {
      state->compute_root(stage_id);
    } else {
      // Each candidate is a (stage id, iterator id) pair.
      choice = choice - 2;
      const Stage& target_stage = (*state)->stages[candidates[choice].first];
      state->compute_at(stage_id, candidates[choice].first,
                        target_stage->iters[candidates[choice].second]);
    }
  }

  // A state whose bounds cannot be inferred is reported as invalid.
  try {
    *state = policy->search_task->compute_dag.InferBound(*state);
  } catch (const std::exception&) {
    return ResultKind::kInvalid;
  }
  return ResultKind::kValid;
}
PopulationGenerationRule::ResultKind InitParallel::Apply(SketchPolicyNode* policy, State* state,
                                                         std::mt19937* rand_gen) const {
  // Recursive helper: annotate the outermost iterators of `stage_id`, starting at
  // `iter_offset`, with parallel. It is stored in a std::function so that it can call itself.
  using AnnotateFn = std::function<void(const SketchPolicyNode&, State*, int, int)>;
  AnnotateFn annotate_parallel;
  annotate_parallel = [&annotate_parallel](const SketchPolicyNode& policy_node, State* pstate,
                                           int stage_id, int iter_offset) {
    const Stage& stage = (*pstate)->stages[stage_id];

    // Collect the outermost iterators to fuse and parallelize. Collection stops at a reduce
    // or already annotated iterator, once the parallel degree is large enough, or right after
    // an iterator that other stages are attached to.
    Array<Iterator> to_fuse;
    int64_t parallel_degree = 1;
    size_t iter_id = iter_offset;
    while (iter_id < stage->iters.size()) {
      const Iterator& it = stage->iters[iter_id];
      if (it->iter_kind == IteratorKind::kReduction ||
          it->annotation != IteratorAnnotation::kNone) {
        break;
      }
      to_fuse.push_back(it);
      parallel_degree *= GetExtent(it);
      if (parallel_degree > policy_node.search_task->hardware_params->num_cores * 16) {
        break;
      }
      if ((*pstate)->attach_map->iter_to_attached_stages.count(
              std::make_pair(stage_id, iter_id))) {
        break;
      }
      ++iter_id;
    }

    // With no parallelism found, descend into the stages attached at the stopping iterator,
    // then continue with the rest of this stage.
    if (parallel_degree == 1) {
      auto attached =
          (*pstate)->attach_map->iter_to_attached_stages.find(std::make_pair(stage_id, iter_id));
      if (attached != (*pstate)->attach_map->iter_to_attached_stages.end()) {
        for (int attached_stage_id : attached->second) {
          annotate_parallel(policy_node, pstate, attached_stage_id, 0);
        }
        annotate_parallel(policy_node, pstate, stage_id, iter_id + 1);
      }
    }

    // Fuse the collected iterators (when there are several) and annotate the result.
    if (to_fuse.empty()) {
      return;
    }
    if (to_fuse.size() == 1) {
      pstate->parallel(stage_id, to_fuse[0]);
    } else {
      Iterator fused_iter = pstate->fuse(stage_id, to_fuse);
      pstate->parallel(stage_id, fused_iter);
    }
  };

  // Start from every non-placeholder stage that is computed at root.
  for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
    const Stage& stage = (*state)->stages[stage_id];
    const bool at_root = stage->compute_at == ComputeAtKind::kRoot;
    if (!at_root || stage->op_type == StageKind::kPlaceholder) {
      continue;
    }
    annotate_parallel(*policy, state, stage_id, 0);
  }
  return ResultKind::kValid;
}
PopulationGenerationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy, State* state,
                                                       std::mt19937* rand_gen) const {
  // Candidate auto-unroll steps depend on whether the task targets a GPU.
  std::vector<int>& unroll_choices =
      IsGPUTask(policy->search_task) ? auto_unroll_configs_gpu : auto_unroll_configs_cpu;

  for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
    const Stage& stage = (*state)->stages[stage_id];

    // Inlined stages and placeholders are not touched.
    if (stage->op_type == StageKind::kPlaceholder ||
        stage->compute_at == ComputeAtKind::kInlined) {
      if (stage->compute_at == ComputeAtKind::kInlined ||
          stage->op_type == StageKind::kPlaceholder) {
        continue;
      }
    }

    // always_unroll_inner: unroll the listed space and reduce iterators of the innermost tile.
    if (stage->op->attrs.count(SearchPolicyKey::always_unroll_inner)) {
      const auto& to_unroll_name_set =
          GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::always_unroll_inner);

      std::set<std::string> seen_names;
      const int last_iter = static_cast<int>(stage->iters.size()) - 1;
      for (int n = last_iter; n >= 0; n--) {
        const Iterator& it = stage->iters[n];

        // Once an iterator contributes no new original iterator name, two iterators derived
        // from the same original iterator have been met: the innermost tile is over.
        const size_t num_seen_before = seen_names.size();
        ExtractOriginalIterators(it->name, &seen_names);
        if (seen_names.size() == num_seen_before) {
          break;
        }

        std::set<std::string> own_names;
        ExtractOriginalIterators(it->name, &own_names);
        if (own_names.size() == 1 && to_unroll_name_set.count(*own_names.begin()) &&
            it->annotation == IteratorAnnotation::kNone) {
          state->unroll(stage_id, it);
        }
      }
    }

    // Stages with reduce iterators get a randomly chosen auto_unroll_max_step pragma on
    // their first iterator.
    if (HasReduceIter(stage)) {
      int value = unroll_choices[(*rand_gen)() % unroll_choices.size()];
      state->pragma(stage_id, (*state)->stages[stage_id]->iters[0],
                    std::string("auto_unroll_max_step") + "$" + std::to_string(value));
    }
  }
  return ResultKind::kValid;
}
PopulationGenerationRule::ResultKind InitVectorization::Apply(SketchPolicyNode* policy,
                                                              State* state,
                                                              std::mt19937* rand_gen) const {
  for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
    const Stage& stage = (*state)->stages[stage_id];

    // Inlined stages and placeholders are not vectorized.
    if (stage->compute_at == ComputeAtKind::kInlined) {
      continue;
    }
    if (stage->op_type == StageKind::kPlaceholder) {
      continue;
    }

    // Count, from the innermost iterator outwards, how many space iterators can be fused and
    // vectorized.
    const int num_iters = static_cast<int>(stage->iters.size());
    int64_t cum_length_prod = 1;
    int num_fusible = 0;
    while (num_fusible < num_iters) {
      const int iter_id = num_iters - 1 - num_fusible;

      // An iterator used as a compute_at attach point ends the range.
      if ((*state)->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, iter_id))) {
        break;
      }

      // So does a reduce iterator or an iterator that is already annotated.
      const Iterator& it = stage->iters[iter_id];
      if (it->iter_kind == IteratorKind::kReduction ||
          it->annotation != IteratorAnnotation::kNone) {
        break;
      }

      // Heuristic for non-continuous memory access (the exact check is too hard): in a tiled
      // stage only the innermost iterator is taken.
      if (num_fusible != 0 && IsTiled(stage)) {
        break;
      }

      // Stop once the fused length exceeds the maximum vectorize size.
      cum_length_prod *= GetExtent(it);
      if (cum_length_prod > GetIntParam(policy->params, SketchParamKey::max_vectorize_size)) {
        break;
      }

      num_fusible++;
    }

    // Randomly shrink the range to fuse when more than one iterator is eligible.
    if (num_fusible > 1) {
      num_fusible = 1 + (*rand_gen)() % (num_fusible - 1);
    }

    if (num_fusible == 1) {
      state->vectorize(stage_id, stage->iters.back());
    } else if (num_fusible > 1) {
      Array<Iterator> to_fuse(stage->iters.end() + (-num_fusible), stage->iters.end());
      state->vectorize(stage_id, state->fuse(stage_id, to_fuse));
    }
  }
  return ResultKind::kValid;
}
/*!
 * \brief Bind the iterators of every stage to GPU block / virtual thread / thread indices.
 * \param policy The sketch policy; provides the search task and its hardware parameters.
 * \param state The state to update in place.
 * \param rand_gen Unused by this rule.
 * \return ResultKind::kInvalid when the fused vthread extent exceeds max_vthread_extent or the
 *         fused threadIdx extent is smaller than the warp size; ResultKind::kValid otherwise.
 */
PopulationGenerationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* policy, State* state,
                                                           std::mt19937* rand_gen) const {
  // Collect the stages that multi-level tiled stages are attached to.
  std::set<int> multi_level_tiling_root_set;
  for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
    if (NeedsMultilevelTiling(policy->search_task, *state, stage_id)) {
      const Stage& stage = (*state)->stages[stage_id];
      if (stage->compute_at != ComputeAtKind::kIter) {
        // This stage is not multi-level tiled,
        // so it must be produced by RuleCrossThreadReduction.
        CHECK(HasCrossThreadReduction(*state, stage_id));
      } else {
        const auto res = (*state)->attach_map->stage_to_attach_iter.find(stage_id);
        CHECK(res != (*state)->attach_map->stage_to_attach_iter.end());
        multi_level_tiling_root_set.insert(res->second.first);
      }
    }
  }

  *state = policy->search_task->compute_dag.InferBound(*state);

  // Visit the stages from the last one to the first one.
  for (int stage_id = static_cast<int>((*state)->stages.size()) - 1; stage_id >= 0; --stage_id) {
    const Stage& stage = (*state)->stages[stage_id];
    if (stage->compute_at == ComputeAtKind::kInlined || stage->op_type == StageKind::kPlaceholder) {
      continue;
    }

    // Deal with the cross-thread reduction generated by RuleCrossThreadReduction
    if (HasCrossThreadReduction(*state, stage_id)) {
      if (stage->compute_at != ComputeAtKind::kRoot) {
        continue;
      }
      Iterator fused_it;
      *state = FuseAllOuterSpaceIterators(*state, stage_id, &fused_it);
      state->bind(stage_id, fused_it, IteratorAnnotation::kBlockX);
      continue;
    }

    // Skip if this stage has already been annotated with threadIdx.x
    if (HasAnnotatedIter(stage, IteratorAnnotation::kThreadX)) {
      continue;
    }

    if (stage->compute_at == ComputeAtKind::kRoot) {
      // This stage has not been tiled, but in GPU schedule, we must tile the root stage
      // to do thread binding
      if (!multi_level_tiling_root_set.count(stage_id)) {
        Iterator fused_it;
        *state = FuseAllOuterSpaceIterators(*state, stage_id, &fused_it);
        if (GetExtent(fused_it) <= policy->search_task->hardware_params->warp_size) {
          state->bind(stage_id, fused_it, IteratorAnnotation::kThreadX);
        } else {
          // Set threadIdx.x = default_warp_size by default.
          // The later EvolutionarySearch will try more possibility
          const auto& split_its = state->split(
              stage_id, fused_it, {Integer(policy->search_task->hardware_params->warp_size)});
          state->bind(stage_id, split_its[0], IteratorAnnotation::kBlockX);
          state->bind(stage_id, split_its[1], IteratorAnnotation::kThreadX);
        }
        continue;
      }

      // Otherwise, this is a tiled root stage, we assume it should be tiled with 3 space level
      // in the outer iterators.
      // The remaining part deals with the thread binding for multi-level tiled stages
      auto pop = stage->op.as<te::ComputeOpNode>();
      CHECK(pop != nullptr);

      // Accumulate in 64 bits: the product of the extents of a large stage does not fit in a
      // 32-bit int, and a wrapped value would wrongly select the "too small" path below.
      int64_t total_space_extent = 1;
      for (const auto& i : pop->root_iter_vars()) {
        CHECK(i->dom.defined());
        const auto& pint = i->dom->extent.as<IntImmNode>();
        CHECK(pint);
        total_space_extent *= pint->value;
      }

      // Check if the total space extent is too small for multi-level thread binding
      if (total_space_extent <= policy->search_task->hardware_params->warp_size) {
        Iterator fused_it;
        *state = FuseAllOuterSpaceIterators(*state, stage_id, &fused_it);
        state->bind(stage_id, fused_it, IteratorAnnotation::kThreadX);
        continue;
      }

      // Collect the iterators of one tiling level of the current stage: scan at most
      // `pop->axis.size()` iterators starting at index `first`, and stop at the first iterator
      // whose name does not end with `suffix`. There may be some iterators that are marked
      // with no split, so the scan can reach the next tiling level early.
      auto collect_tile_level = [&](size_t first, const char* suffix) {
        std::vector<Iterator> to_fuse;
        for (size_t i = first; i < pop->axis.size() + first; i++) {
          const auto& it = (*state)->stages[stage_id]->iters[i];
          if (!StrEndsWith(it->name, suffix)) {
            break;
          }
          to_fuse.push_back(it);
        }
        return to_fuse;
      };

      // Fuse the outermost space tile as blockIdx
      const auto& blockidx_it = state->fuse(stage_id, collect_tile_level(0, ".0"));
      state->bind(stage_id, blockidx_it, IteratorAnnotation::kBlockX);

      // Fuse the second outermost space tile as vthread. The previous fuse left a single
      // iterator at index 0, so this level starts at index 1.
      const auto& vthread_it = state->fuse(stage_id, collect_tile_level(1, ".1"));
      if (GetExtent(vthread_it) > policy->search_task->hardware_params->max_vthread_extent) {
        return ResultKind::kInvalid;
      }
      state->bind(stage_id, vthread_it, IteratorAnnotation::kVThread);

      // Fuse the third outermost space tile as threadIdx (starting at index 2 for the same
      // reason as above).
      const auto& threadidx_it = state->fuse(stage_id, collect_tile_level(2, ".2"));
      if (GetExtent(threadidx_it) < policy->search_task->hardware_params->warp_size) {
        return ResultKind::kInvalid;
      }
      state->bind(stage_id, threadidx_it, IteratorAnnotation::kThreadX);
    } else if (stage->compute_at == ComputeAtKind::kIter &&
               StrEndsWith(stage->op->name, ".shared")) {
      // Do cooperative fetching for the cache read stage.
      // Get spatial_split_step_ids from the root stage
      const auto& it = (*state)->attach_map->stage_to_attach_iter.find(stage_id);
      CHECK(it != (*state)->attach_map->stage_to_attach_iter.end());
      Array<Integer> spatial_split_step_ids = GetSpatialSplitStepIds(*state, it->second.first);

      // Fuse all iterators to do cooperative fetching
      Iterator fused = state->fuse(stage_id, (*state)->stages[stage_id]->iters);
      // Split out an extra iterator for vectorization
      // The later EvolutionarySearch will try more possibility
      const auto& iters0 = state->split(stage_id, fused, {Integer(1)});
      state->vectorize(stage_id, iters0[1]);
      // Follow split to keep a same thread extent with the root stage
      const auto& iters1 =
          state->follow_fused_split(stage_id, iters0[0], spatial_split_step_ids, 1, true);
      state->bind(stage_id, iters1[1], IteratorAnnotation::kThreadX);
    }
  }
  return ResultKind::kValid;
}
PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* policy, State* state,
                                                           std::mt19937* rand_gen) const {
  int max_innermost_split_factor =
      GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor);

  // Gather the SplitSteps that can be mutated: those with a constant integer extent whose
  // innermost factor is defined and within the allowed maximum.
  std::vector<size_t> split_step_ids;
  for (size_t i = 0; i < (*state)->transform_steps.size(); ++i) {
    auto split = (*state)->transform_steps[i].as<SplitStepNode>();
    if (!split) {
      continue;
    }
    if (!split->extent.defined() || !split->extent.value()->IsInstance<IntImmNode>()) {
      continue;
    }
    // An undefined innermost factor is mapped to a value above the limit, i.e. rejected.
    auto innermost_factor = split->lengths.back().value_or(max_innermost_split_factor + 1);
    if (GetIntImm(innermost_factor) <= max_innermost_split_factor) {
      split_step_ids.push_back(i);
    }
  }
  if (split_step_ids.empty()) {
    // There is no tile size to mutate.
    return ResultKind::kInvalid;
  }

  // Randomly pick a SplitStep whose extent is larger than one, retrying a bounded number of
  // times (four times the number of candidate steps).
  const int max_tries = static_cast<int>(split_step_ids.size()) << 2;
  int num_tries = 0;
  int64_t extent = 1;
  int step_id;
  const SplitStepNode* ps;
  do {
    step_id = split_step_ids[(*rand_gen)() % split_step_ids.size()];
    ps = (*state)->transform_steps[step_id].as<SplitStepNode>();
    CHECK(ps != nullptr);
    extent = GetIntImm(ps->extent.value());
    num_tries += 1;
  } while (num_tries < max_tries && (extent == 1 || extent == 0));
  if (extent <= 1) {
    // No step with an extent larger than one was found.
    return ResultKind::kInvalid;
  }

  // Current tile sizes: slot 0 holds the outermost size derived from the extent, slots 1..n
  // hold the lengths recorded in the step.
  std::vector<int> lengths(ps->lengths.size() + 1, 1);
  for (int i = 0; i < static_cast<int>(ps->lengths.size()); ++i) {
    lengths[i + 1] = GetIntImm(ps->lengths[i].value());
  }
  lengths[0] = extent / ElementProduct(lengths);

  // Visit the tile sizes in random order.
  std::vector<int> random_perm;
  RandomPermutation(lengths.size(), &random_perm, rand_gen);

  // Try to move one factor from a tile size (source) to the next one in the permutation
  // (destination). The first successful move ends the rule.
  const size_t innermost_idx = lengths.size() - 1;
  for (size_t i = 0; i < random_perm.size(); ++i) {
    size_t src_idx = random_perm[i];
    int src_length = lengths[src_idx];
    if (src_length <= 1) {
      continue;
    }

    size_t dst_idx = random_perm[(i + 1) % random_perm.size()];
    const std::vector<int>& factors = policy->split_memo.GetFactors(src_length);
    CHECK_GE(factors.size(), 1);

    int divide_factor;
    if (dst_idx == innermost_idx) {
      // The innermost tile size has to stay within max_innermost_split_factor: find the
      // largest factor index that still satisfies the limit.
      int max_factor_index = static_cast<int>(factors.size()) - 1;
      while (max_factor_index >= 1 &&
             !(factors[max_factor_index] * lengths[dst_idx] <= max_innermost_split_factor)) {
        max_factor_index--;
      }
      if (max_factor_index == 0) {
        // No factor fits for this destination; try the next source.
        continue;
      }
      divide_factor = factors[1 + (*rand_gen)() % (max_factor_index)];
    } else {
      divide_factor = factors[1 + (*rand_gen)() % (factors.size() - 1)];
    }

    // Build the new lengths of the step (slot 0 is implied by the extent and is not stored).
    Array<Integer> new_lengths;
    for (size_t j = 1; j < lengths.size(); ++j) {
      int new_length = lengths[j];
      if (j == src_idx) {
        new_length = lengths[j] / divide_factor;
      } else if (j == dst_idx) {
        new_length = lengths[j] * divide_factor;
      }
      new_lengths.push_back(Integer(new_length));
    }
    CHECK_LE(GetIntImm(new_lengths.back()), max_innermost_split_factor);

    // Replace the chosen step by one carrying the mutated lengths.
    StateNode* pstate = state->CopyOnWrite();
    pstate->transform_steps.Set(
        step_id, SplitStep(ps->stage_id, ps->iter_id, ps->extent,
                           Array<Optional<Integer>>(new_lengths.begin(), new_lengths.end()),
                           ps->inner_to_outer));
    return ResultKind::kValid;
  }
  return ResultKind::kInvalid;
}
/*!
 * \brief Replace the value of one randomly selected "auto_unroll_max_step" pragma.
 *
 * The new value is drawn from auto_unroll_configs_gpu for GPU tasks and from
 * auto_unroll_configs_cpu otherwise.
 *
 * \param policy The sketch policy; its search task decides which candidate list is used.
 * \param state The state to mutate in place.
 * \param rand_gen The random number generator.
 * \return kValid on success, kInvalid if the state has no auto unroll pragma step.
 */
PopulationGenerationRule::ResultKind MutateAutoUnroll::Apply(SketchPolicyNode* policy, State* state,
                                                             std::mt19937* rand_gen) const {
  // Collect the indices of all pragma steps whose type starts with "auto_unroll_max_step".
  std::vector<int> pragma_steps;
  for (size_t idx = 0; idx < (*state)->transform_steps.size(); ++idx) {
    auto pragma = (*state)->transform_steps[idx].as<PragmaStepNode>();
    if (pragma && StrStartsWith(pragma->pragma_type, "auto_unroll_max_step")) {
      pragma_steps.push_back(idx);
    }
  }
  if (pragma_steps.empty()) {
    return ResultKind::kInvalid;
  }

  std::vector<int>& auto_unroll_configs =
      IsGPUTask(policy->search_task) ? auto_unroll_configs_gpu : auto_unroll_configs_cpu;

  // First draw: which pragma step to mutate.
  auto step_id = pragma_steps[(*rand_gen)() % pragma_steps.size()];
  auto ps = (*state)->transform_steps[step_id].as<PragmaStepNode>();
  CHECK(ps);

  // Second draw: the new unroll limit.
  const int unroll_value = auto_unroll_configs[(*rand_gen)() % auto_unroll_configs.size()];
  const std::string val = std::to_string(unroll_value);

  // Overwrite the step with a pragma carrying the new value.
  StateNode* mutable_state = state->CopyOnWrite();
  mutable_state->transform_steps.Set(
      step_id,
      PragmaStep(ps->stage_id, ps->iter_id, std::string("auto_unroll_max_step") + "$" + val));
  return ResultKind::kValid;
}
/*!
 * \brief Move one randomly selected compute_at step to a new, randomly chosen location.
 *
 * The whole transform history is replayed from the initial state with the selected
 * ComputeAtStep replaced, because later steps depend on the loop structure it produces.
 *
 * \param policy The sketch policy; provides the parameters and the search task.
 * \param state The state to mutate. It is only overwritten when the replay succeeds.
 * \param rand_gen The random number generator.
 * \return kValid on success. kInvalid if the mutation is disabled by
 *         SketchParamKey::disable_change_compute_location, if there is no eligible
 *         compute_at step, if there is no candidate location, or if replaying a step
 *         raises dmlc::Error.
 */
PopulationGenerationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNode* policy,
                                                                  State* state,
                                                                  std::mt19937* rand_gen) const {
  if (GetIntParam(policy->params, SketchParamKey::disable_change_compute_location)) {
    return ResultKind::kInvalid;
  }

  // Extract all compute_at steps.
  std::vector<int> compute_at_steps;
  for (size_t s = 0; s < (*state)->transform_steps.size(); ++s) {
    if (auto ps = (*state)->transform_steps[s].as<ComputeAtStepNode>()) {
      // Difference between the stage id in the current state and the id recorded in the step.
      int stage_inc = GetTargetStageIDInState(*state, s) - ps->stage_id;

      // Skip stages that are tiled or that need multi-level tiling.
      if (IsTiled((*state)->stages[ps->stage_id + stage_inc])) {
        continue;
      }
      if (NeedsMultilevelTiling(policy->search_task, *state, ps->stage_id + stage_inc)) {
        continue;
      }
      compute_at_steps.push_back(s);
    }
  }
  if (compute_at_steps.empty()) {
    return ResultKind::kInvalid;
  }

  // Randomly pick one step
  size_t step_id = compute_at_steps[(*rand_gen)() % compute_at_steps.size()];
  auto ps = (*state)->transform_steps[step_id].as<ComputeAtStepNode>();
  // Validate the pointer before its first dereference below.
  CHECK(ps != nullptr);
  int stage_inc = GetTargetStageIDInState(*state, step_id) - ps->stage_id;

  // Randomly pick a new computation location
  std::vector<std::pair<int, int>> candidates =
      GetComputeLocationCandidates(policy->search_task, *state, ps->stage_id + stage_inc);
  if (candidates.empty()) {
    return ResultKind::kInvalid;
  }
  int choice = (*rand_gen)() % (candidates.size());
  int new_compute_at_stage_id = candidates[choice].first;
  int new_compute_at_iter_id = candidates[choice].second;

  // Replay a new state.
  State tmp_s = policy->search_task->compute_dag->init_state;
  for (size_t s = 0; s < (*state)->transform_steps.size(); ++s) {
    if (s == step_id) {
      // The candidate stage id is expressed in the current state; subtract stage_inc to
      // convert it back to the numbering used by the recorded step.
      tmp_s.CopyOnWrite()->transform_steps.push_back(
          ComputeAtStep(ps->stage_id, new_compute_at_stage_id - stage_inc, new_compute_at_iter_id));
    } else {
      tmp_s.CopyOnWrite()->transform_steps.push_back((*state)->transform_steps[s]);
    }
    try {
      StepApplyToState(tmp_s->transform_steps.back(), &tmp_s, policy->search_task->compute_dag);
    } catch (const dmlc::Error&) {
      // A step that cannot be applied at the new location invalidates this mutation.
      return ResultKind::kInvalid;
    }
  }

  *state = tmp_s;
  return ResultKind::kValid;
}
/*!
 * \brief Mutate the number of loops fused in front of a parallel annotation.
 *
 * This mutation rule only focuses on the case where parallel was added to the outermost
 * loop and that loop was generated by fusing other loops. In short, we mutate the fusion
 * step right before the parallel step and replay the remaining steps with adjusted
 * iterator ids.
 *
 * \param policy The sketch policy; provides the search task.
 * \param state The state to mutate. It is only overwritten when the replay succeeds.
 * \param rand_gen The random number generator.
 * \return kValid on success. kInvalid if there is no eligible parallel step, if no iterator
 *         can be fused, if the drawn granularity equals the current one, if a later step on
 *         the same stage is neither an annotation nor a pragma step, if a later step changes
 *         the number of stages, or if replaying a later step raises dmlc::Error.
 */
PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* policy, State* state,
                                                           std::mt19937* rand_gen) const {
  // Extract all parallel steps.
  std::vector<int> parallel_steps;
  for (size_t s = 0; s < (*state)->transform_steps.size(); ++s) {
    auto ps = (*state)->transform_steps[s].as<AnnotationStepNode>();
    if (!ps || ps->annotation != IteratorAnnotation::kParallel) {
      continue;
    }

    // Skip non-outermost loop or the parallel step without fusion beforehand.
    if (ps->iter_id != 0 || s == 0 || !(*state)->transform_steps[s - 1].as<FuseStepNode>()) {
      continue;
    }
    // Only fusions that start at the outermost iterator are considered.
    auto fuse_step = (*state)->transform_steps[s - 1].as<FuseStepNode>();
    if (fuse_step->fused_ids[0] != 0) {
      continue;
    }

    parallel_steps.push_back(s);
  }
  if (parallel_steps.empty()) {
    return ResultKind::kInvalid;
  }

  // Randomly pick one parallel step.
  size_t step_id = parallel_steps[(*rand_gen)() % parallel_steps.size()];

  // Replay a new state until the picked fuse step (step_id >= 1 is guaranteed by the filter).
  State tmp_s = policy->search_task->compute_dag->init_state;
  for (size_t s = 0; s < step_id - 1; ++s) {
    const auto& step = (*state)->transform_steps[s];
    tmp_s.CopyOnWrite()->transform_steps.push_back(step);
    StepApplyToState(step, &tmp_s, policy->search_task->compute_dag);
  }

  // Compute all possible fusion granularities: count the leading iterators that are not
  // reduction iterators, carry no annotation and have no stage attached to them.
  auto fuse_step = (*state)->transform_steps[step_id - 1].as<FuseStepNode>();
  int stage_id = fuse_step->stage_id;
  const Stage& stage = tmp_s->stages[stage_id];
  size_t max_fusable_iter_id;
  for (max_fusable_iter_id = 0; max_fusable_iter_id < stage->iters.size(); ++max_fusable_iter_id) {
    const Iterator& it = stage->iters[max_fusable_iter_id];
    if (it->iter_kind == IteratorKind::kReduction || it->annotation != IteratorAnnotation::kNone) {
      break;
    }

    if (tmp_s->attach_map->iter_to_attached_stages.count(
            std::make_pair(stage_id, max_fusable_iter_id))) {
      break;
    }
  }

  // No iterator can be fused. Bail out here, otherwise the modulo below divides by zero.
  if (max_fusable_iter_id == 0) {
    return ResultKind::kInvalid;
  }

  // Randomly pick one granularity
  int fuse_to_iter_id = (*rand_gen)() % max_fusable_iter_id + 1;
  Array<Integer> fused_ids;
  for (int i = 0; i < fuse_to_iter_id; ++i) {
    fused_ids.push_back(i);
  }
  // Shift to apply to iterator ids of later steps; zero means the fusion is unchanged.
  int iter_offset = fuse_step->fused_ids.back()->value - fused_ids.back()->value;
  if (iter_offset == 0) {
    return ResultKind::kInvalid;
  }

  // Replay the mutated fused and annotation step.
  auto new_fuse_step = FuseStep(stage_id, fused_ids);
  tmp_s.CopyOnWrite()->transform_steps.push_back(new_fuse_step);
  StepApplyToState(new_fuse_step, &tmp_s, policy->search_task->compute_dag);
  tmp_s.CopyOnWrite()->transform_steps.push_back((*state)->transform_steps[step_id]);
  StepApplyToState((*state)->transform_steps[step_id], &tmp_s, policy->search_task->compute_dag);

  // Replay the rest steps.
  for (size_t s = step_id + 1; s < (*state)->transform_steps.size(); ++s) {
    auto step = (*state)->transform_steps[s];
    if (step->stage_id == stage_id) {
      // Since we changed the loop structure, iter ID in later steps to the same stage
      // has to be adjusted.
      if (auto ps = step.as<AnnotationStepNode>()) {
        if (ps->iter_id == 0) {
          step = AnnotationStep(ps->stage_id, 0, ps->annotation);
        } else {
          CHECK_LE(ps->iter_id + iter_offset, tmp_s->stages[stage_id]->iters.size());
          step = AnnotationStep(ps->stage_id, ps->iter_id + iter_offset, ps->annotation);
        }
      } else if (auto ps = step.as<PragmaStepNode>()) {
        if (ps->iter_id == 0) {
          step = PragmaStep(ps->stage_id, 0, ps->pragma_type);
        } else {
          CHECK_LE(ps->iter_id + iter_offset, tmp_s->stages[stage_id]->iters.size());
          step = PragmaStep(ps->stage_id, ps->iter_id + iter_offset, ps->pragma_type);
        }
      } else {
        // Other step kinds on the mutated stage are not handled; give up this mutation.
        return ResultKind::kInvalid;
      }
    }
    if (IsStageNumberChangingStep(step)) {
      // For these steps, we have to update stage_id because these steps will make stage_id
      // out-dated. But here we just simply give up this mutation for simplicity.
      // This is not an issue because this will never happen in normal cases where all these steps
      // are before parallel steps.
      return ResultKind::kInvalid;
    }
    tmp_s.CopyOnWrite()->transform_steps.push_back(step);
    try {
      StepApplyToState(tmp_s->transform_steps.back(), &tmp_s, policy->search_task->compute_dag);
    } catch (const dmlc::Error&) {
      // A step that cannot be applied to the new loop structure invalidates this mutation.
      return ResultKind::kInvalid;
    }
  }

  *state = tmp_s;
  return ResultKind::kValid;
}
} // namespace auto_scheduler
} // namespace tvm