blob: a89fa4b0c77a8e9b72988e162732f0ad4ee3ac9c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file auto_scheduler/search_policy/sketch_policy.cc
* \brief The search policy that searches in a hierarchical search space defined by sketches.
* The policy randomly samples programs from the space defined by sketches
* and uses evolutionary search to fine-tune them.
*/
#include "sketch_policy.h"
#include <tvm/runtime/registry.h>
#include <tvm/support/parallel_for.h>
#include <algorithm>
#include <iomanip>
#include <limits>
#include <memory>
#include <queue>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "sketch_policy_rules.h"
namespace tvm {
namespace auto_scheduler {
/********** Sketch generation rules **********/
// File-scope rule singletons. The SketchPolicy constructor stores raw pointers to these objects
// in `sketch_rules` / `init_rules`, so they need static storage duration to outlive every policy.
// A single instance of each rule is shared by all SketchPolicy objects (presumably the rules are
// stateless — TODO confirm before adding mutable state to a rule).
static RuleSkipStage rule_skip_stage;
static RuleAlwaysInline rule_always_inline;
static RuleMultiLevelTiling rule_multi_level_tiling;
static RuleMultiLevelTilingWithFusion rule_multi_level_tiling_with_fusion;
static RuleAddCacheRead rule_add_cache_read_stage;
static RuleAddCacheWrite rule_add_cache_write_stage;
static RuleAddRfactor rule_add_rfactor;
static RuleCrossThreadReduction rule_cross_thread_reduction;
static RuleSimplifyComputeWithConstTensor rule_simplify_compute_with_const_tensor;
static RuleSpecialComputeLocationGPU rule_special_compute_location_gpu;
/********** Init population rules **********/
// Rules applied one after another to a randomly chosen sketch when sampling the initial
// population; a rule returning kInvalid discards the sample. Judging by their names they fill in
// tile sizes, compute locations, parallel/unroll/vectorize annotations and GPU thread bindings.
static InitFillTileSize init_fill_tile_size;
static InitChangeComputeLocation init_change_compute_location;
static InitParallel init_parallel;
static InitUnroll init_unroll;
static InitVectorization init_vectorization;
static InitThreadBind init_thread_bind;
/********** Sketch policy **********/
// Register the SketchPolicyNode type with TVM's node type registry.
TVM_REGISTER_NODE_TYPE(SketchPolicyNode);
/*!
 * \brief Construct a SketchPolicy and install the default rule sets for the task's target.
 * \param task The search task to optimize.
 * \param program_cost_model The cost model used to score candidate programs.
 * \param params Tunable parameters, read later through GetIntParam/GetDoubleParam.
 * \param seed Seed for the policy's random number generator.
 * \param verbose Verbosity level forwarded to PrintTitle/StdCout.
 * \param init_search_callbacks Optional callbacks that run before the default rules are added.
 */
SketchPolicy::SketchPolicy(SearchTask task, CostModel program_cost_model,
                           Map<String, ObjectRef> params, int seed, int verbose,
                           Optional<Array<SearchCallback>> init_search_callbacks) {
  auto node = make_object<SketchPolicyNode>();
  node->search_task = std::move(task);
  node->program_cost_model = std::move(program_cost_model);
  node->rand_gen = std::mt19937(seed);
  node->params = std::move(params);
  node->verbose = verbose;

  if (init_search_callbacks) {
    PrintTitle("Call init-search callbacks", verbose);
    // Candidates:
    // - auto_scheduler.PreloadMeasuredStates: Load already measured states to
    //   `measured_states_set_`, `measured_states_vector_` and `measured_states_throughputs_`.
    // - auto_scheduler.PreloadCustomSketchRule: Add user custom sketch rules to `sketch_rules`,
    //   these rules will be processed prior to the default rules.
    node->RunCallbacks(init_search_callbacks.value());
  }

  // NOTE: There are strong dependency among the rules below,
  // so the order to push them into the vector should be considered carefully.
  if (IsCPUTask(node->search_task)) {
    // Sketch Generation Rules
    node->sketch_rules.push_back(&rule_always_inline);
    node->sketch_rules.push_back(&rule_simplify_compute_with_const_tensor);
    node->sketch_rules.push_back(&rule_add_rfactor);
    node->sketch_rules.push_back(&rule_add_cache_write_stage);
    node->sketch_rules.push_back(&rule_multi_level_tiling_with_fusion);
    node->sketch_rules.push_back(&rule_multi_level_tiling);
    node->sketch_rules.push_back(&rule_skip_stage);

    // Initial Population Generation Rules
    node->init_rules.push_back(&init_fill_tile_size);
    node->init_rules.push_back(&init_change_compute_location);
    node->init_rules.push_back(&init_parallel);
    node->init_rules.push_back(&init_unroll);
    node->init_rules.push_back(&init_vectorization);

    // Mutation Rules for Evolutionary Search
    node->mutation_rules.push_back(std::make_shared<MutateTileSize>(0.90));
    node->mutation_rules.push_back(std::make_shared<MutateAutoUnroll>(0.04));
    node->mutation_rules.push_back(std::make_shared<MutateComputeLocation>(0.05));
    node->mutation_rules.push_back(std::make_shared<MutateParallel>(0.01));
  } else if (IsGPUTask(node->search_task)) {
    // Sketch Generation Rules
    node->sketch_rules.push_back(&rule_add_cache_read_stage);
    node->sketch_rules.push_back(&rule_always_inline);
    node->sketch_rules.push_back(&rule_special_compute_location_gpu);
    node->sketch_rules.push_back(&rule_simplify_compute_with_const_tensor);
    node->sketch_rules.push_back(&rule_cross_thread_reduction);
    node->sketch_rules.push_back(&rule_add_cache_write_stage);
    node->sketch_rules.push_back(&rule_multi_level_tiling_with_fusion);
    node->sketch_rules.push_back(&rule_multi_level_tiling);
    node->sketch_rules.push_back(&rule_skip_stage);

    // Initial Population Generation Rules
    node->init_rules.push_back(&init_fill_tile_size);
    node->init_rules.push_back(&init_thread_bind);
    node->init_rules.push_back(&init_unroll);

    // Mutation Rules for Evolutionary Search
    node->mutation_rules.push_back(std::make_shared<MutateTileSize>(0.90));
    node->mutation_rules.push_back(std::make_shared<MutateAutoUnroll>(0.10));
  } else {
    // `task` was moved into `node->search_task` above, so it is empty here and must not be
    // dereferenced; read the target through the node instead.
    LOG(FATAL) << "No default sketch rules for target: " << node->search_task->target;
  }

  data_ = std::move(node);
}
/*!
 * \brief Run the tuning loop: search for candidates, measure them, and retrain the cost model.
 * \param n_trials Total measurement budget. With n_trials <= 1 nothing is measured and the first
 *        state returned by a single search round is returned.
 * \param early_stopping Stop once this many trials pass without improvement; a negative value
 *        effectively disables early stopping.
 * \param num_measure_per_iter Number of states to measure per round.
 * \param measurer The measurer that runs the candidates and tracks the best state per workload.
 * \return The best state recorded by `measurer` for this task's workload.
 */
State SketchPolicyNode::Search(int n_trials, int early_stopping, int num_measure_per_iter,
                               ProgramMeasurer measurer) {
  num_measure_per_iter_ = num_measure_per_iter;

  if (n_trials <= 1) {
    // No measurement is allowed
    const Array<State>& best_states = SearchOneRound(0);
    CHECK_GT(best_states.size(), 0);
    return best_states[0];
  } else {
    int num_random =
        static_cast<int>(GetDoubleParam(params, SketchParamKey::eps_greedy) * num_measure_per_iter);
    early_stopping = early_stopping < 0 ? std::numeric_limits<int>::max() >> 1 : early_stopping;
    measurer->Reset();

    int ct = 0;
    int empty_retry_count = GetIntParam(params, SketchParamKey::empty_retry_count);
    Array<MeasureInput> inputs;
    Array<MeasureResult> results;
    while (ct < n_trials) {
      if (!inputs.empty()) {
        // Retrain cost models before the next search round
        PrintTitle("Train cost model", verbose);
        program_cost_model->Update(inputs, results);
      }

      // Search one round to get promising states
      PrintTitle("Search", verbose);
      Array<State> random_states;
      Array<State> best_states = SearchOneRound(num_random, &random_states);

      // Infer bound. This is necessary for computing the correct ToStr() for redundancy check
      best_states = search_task->compute_dag.InferBound(best_states);
      random_states = search_task->compute_dag.InferBound(random_states);

      // Pick `num_measure_per_iter` states to measure, check hash to remove already measured state
      // Also pick some random states to do eps-greedy
      inputs = PickStatesWithEpsGreedy(best_states, random_states, n_trials - ct);

      // Currently it's hard to detect if all of the search space has been traversed
      // Stop if no extra valid states found in several retries
      if (inputs.empty()) {
        if (empty_retry_count-- > 0) {
          continue;
        } else {
          StdCout(verbose) << "It seems all candidates in the search space have been measured."
                           << std::endl;
          break;
        }
      } else {
        // Reset the retry count
        empty_retry_count = GetIntParam(params, SketchParamKey::empty_retry_count);
      }

      // Measure candidate states
      PrintTitle("Measure", verbose);
      measurer->Measure(search_task, GetRef<SearchPolicy>(this), inputs, &results);
      ct += inputs.size();

      // Update measured states throughputs. These states will join the EvolutionarySearch in later
      // search rounds. This must happen before the early-stopping check below:
      // PickStatesWithEpsGreedy has already appended these states to measured_states_vector_, and
      // SearchOneRound pairs that vector with measured_states_throughputs_ by index, so skipping
      // this on `break` would leave the two containers misaligned for any later search.
      for (const auto& res : results) {
        measured_states_throughputs_.push_back(1.0 / FloatArrayMean(res->costs));
      }

      // Check if reach the early stopping condition
      if (ct - measurer->best_ct[search_task->workload_key] > early_stopping) {
        StdCout(verbose) << "Stop early since no performance improvement in the last "
                         << early_stopping << " measure steps.\n";
        break;
      }
    }
    PrintTitle("Done", verbose);

    return measurer->best_state[search_task->workload_key];
  }
}
/*!
 * \brief Run one round of candidate generation.
 * \param num_random_states Controls eps-greedy sampling: with a learned cost model,
 *        `num_random_states * 3` random states are written to `random_states`.
 * \param random_states Optional output for randomly sampled states; may be nullptr.
 * \return With a learned cost model, the result of EvolutionarySearch; with RandomModel, a random
 *         pick of valid sampled states.
 */
Array<State> SketchPolicyNode::SearchOneRound(int num_random_states, Array<State>* random_states) {
  // Temporal object to be used if the input pointer is nullptr
  Array<State> temp_random_states;
  if (random_states == nullptr) {
    random_states = &temp_random_states;
  } else {
    random_states->clear();
  }

  // Get parameters
  int population = GetIntParam(params, SketchParamKey::EvolutionarySearch::population);
  // `measured_states_throughputs_[i]` is the throughput of `measured_states_vector_[i]`. Only
  // states that have a recorded throughput can be ranked, so bound by both sizes: a size mismatch
  // would otherwise index past the end of one of the containers.
  int num_measured = std::min(static_cast<int>(measured_states_vector_.size()),
                              static_cast<int>(measured_states_throughputs_.size()));
  int num_use_measured = std::min(
      num_measured,
      static_cast<int>(
          GetDoubleParam(params, SketchParamKey::EvolutionarySearch::use_measured_ratio) *
          population));
  bool is_cost_model_reasonable = !program_cost_model->IsInstance<RandomModelNode>();

  // 1. Generate sketches
  if (sketch_cache_.empty()) {
    sketch_cache_ = GenerateSketches();
  }

  // 2. Sample the init population
  Array<State> init_population = SampleInitPopulation(
      sketch_cache_, is_cost_model_reasonable ? population - num_use_measured : population);

  // 3. If the cost model is useless (i.e. RandomCostModel), just random pick some generated
  // states, else perform evolutionary search
  if (is_cost_model_reasonable) {
    // Also insert already measured good states to the initial population
    std::vector<int> indices = Argsort(measured_states_throughputs_);
    const int num_states = static_cast<int>(measured_states_vector_.size());
    for (int i = 0; i < num_use_measured; i++) {
      // Skip any throughput entry that has no matching measured state.
      if (indices[i] < num_states) {
        init_population.push_back(measured_states_vector_[indices[i]]);
      }
    }
    // Sample some random states for eps-greedy
    *random_states = RandomSampleStates(init_population, &rand_gen, num_random_states * 3);
    return EvolutionarySearch(init_population, num_measure_per_iter_ * 2);
  } else {
    PruneInvalidState(search_task, &init_population);
    return RandomSampleStates(init_population, &rand_gen, num_measure_per_iter_ * 3);
  }
}
/*!
 * \brief Enumerate sketches by applying `sketch_rules` to the task's initial state.
 * Enumeration starts at the last stage of the initial state. Each state is expanded by every rule
 * whose MeetCondition is not kSkip, and each rule application yields new states paired with the
 * stage id to work on next. A state whose working stage id is below 0 is a finished sketch.
 * \return All finished sketches. In each one, the split step that directly precedes an rfactor
 * step has its lengths reset to NullOpt, so that the factor can be sampled later.
 */
Array<State> SketchPolicyNode::GenerateSketches() {
const State& init_state = search_task->compute_dag->init_state;
// Two ping pong buffers to avoid copy
Array<State> states_buf1{init_state}, states_buf2;
Array<State>* pnow = &states_buf1;
Array<State>* pnext = &states_buf2;
// A map that maps state to its current working position (stage_id).
// Every state pushed into *pnext is recorded here first, so the operator[] lookups below always
// find an existing entry.
std::unordered_map<State, int, ObjectHash, ObjectEqual> cur_stage_id_map;
cur_stage_id_map[init_state] = static_cast<int>(init_state->stages.size()) - 1;
// Derivation rule based enumeration
Array<State> out_states;
while (!pnow->empty()) {
pnext->clear();
for (const State& state : *pnow) {
int stage_id = cur_stage_id_map[state];
// All stages have been processed: this state is a finished sketch
if (stage_id < 0) {
out_states.push_back(state);
continue;
}
// Try all derivation rules, in the order they appear in `sketch_rules`
for (const auto& rule : sketch_rules) {
auto cond = rule->MeetCondition(*this, state, stage_id);
if (cond != SketchGenerationRule::ConditionKind::kSkip) {
for (const auto& pair : rule->Apply(*this, state, stage_id)) {
cur_stage_id_map[pair.first] = pair.second;
pnext->push_back(pair.first);
}
// kApplyAndSkipRest: do not try the remaining rules for this state
if (cond == SketchGenerationRule::ConditionKind::kApplyAndSkipRest) {
break;
}
}
}
}
std::swap(pnow, pnext);
}
// Hack for rfactor: replace the split factor used by rfactor with an undefined value (NullOpt),
// so that a random value for the split factor can be sampled later.
// Why not use an undefined factor when doing the split for rfactor in the first place?
// Because during ApplySteps, an rfactor with an undefined factor will crash TVM, which would
// conflict with cache_write, cache_read and rfactor in other stages.
// The code below requires the step directly before every rfactor step to be a split step.
for (size_t i = 0; i < out_states.size(); ++i) {
auto state = out_states[i];
auto pstate = state.CopyOnWrite();
for (size_t step_id = 0; step_id < pstate->transform_steps.size(); ++step_id) {
if (pstate->transform_steps[step_id]->IsInstance<RfactorStepNode>()) {
CHECK_GE(step_id, 1);
int split_step_id = static_cast<int>(step_id - 1);
auto step = pstate->transform_steps[split_step_id].as<SplitStepNode>();
CHECK(step != nullptr);
pstate->transform_steps.Set(
split_step_id, SplitStep(step->stage_id, step->iter_id, step->extent, {NullOpt},
step->inner_to_outer));
}
}
out_states.Set(i, std::move(state));
}
StdCout(verbose) << "Generate Sketches\t\t#s: " << out_states.size() << std::endl;
return out_states;
}
/*!
 * \brief Sample complete states by applying every `init_rules` entry to randomly chosen sketches.
 * \param sketches The sketches to sample from.
 * \param out_size The number of states wanted.
 * \return Up to `out_size` valid states. Fewer are returned if `out_size` samples in total fail,
 *         or if `sketches` is empty.
 */
Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& sketches, int out_size) {
  int fail_ct = 0;
  Array<State> out_states;
  // One generator per slot, seeded from the policy's generator, so parallel workers never share
  // an RNG.
  std::vector<std::mt19937> rand_gens;
  rand_gens.reserve(out_size);
  for (int i = 0; i < out_size; i++) {
    rand_gens.push_back(std::mt19937(rand_gen()));
  }

  auto tic_begin = std::chrono::high_resolution_clock::now();
  // With no sketches there is nothing to sample from, and the modulo below would divide by zero.
  while (!sketches.empty() && static_cast<int>(out_states.size()) < out_size &&
         fail_ct < out_size) {
    // Sample exactly the number of states still missing, and inspect only those slots, so that
    // slots that were never sampled are not counted as failures.
    const int num_remaining = out_size - static_cast<int>(out_states.size());
    std::vector<State> temp_states(num_remaining);

    support::parallel_for(0, num_remaining, [this, &temp_states, &sketches, &rand_gens](int index) {
      // Random choose a starting sketch
      // TODO(jcf94, merrymercy): Maybe choose sketches in different
      // possibility for they may have different potential on generating state
      // with better performance
      State tmp_s = sketches[(rand_gens[index])() % sketches.size()];
      // Derivation rule based enumeration
      bool valid = true;
      for (const auto& rule : init_rules) {
        if (rule->Apply(this, &tmp_s, &rand_gens[index]) ==
            PopulationGenerationRule::ResultKind::kInvalid) {
          valid = false;
          break;
        }
      }
      if (valid) {
        temp_states[index] = std::move(tmp_s);
      }
    });

    for (int i = 0; i < num_remaining; i++) {
      if (temp_states[i].defined()) {
        out_states.push_back(std::move(temp_states[i]));
      } else {
        fail_ct++;
      }
    }
  }

  double duration = std::chrono::duration_cast<std::chrono::duration<double>>(
                        std::chrono::high_resolution_clock::now() - tic_begin)
                        .count();
  StdCout(verbose) << "Sample Initial Population\t#s: " << out_states.size()
                   << "\tfail_ct: " << fail_ct << "\tTime elapsed: " << std::fixed
                   << std::setprecision(2) << duration << std::endl;
  return out_states;
}
/*!
 * \brief Evolve `init_population` under the cost model and keep the best-scored new states.
 * \param init_population The starting population.
 * \param out_size The maximum number of states to return.
 * \return Up to `out_size` distinct states that are not in measured_states_set_, sorted by
 *         predicted score in descending order.
 */
Array<State> SketchPolicyNode::EvolutionarySearch(const Array<State>& init_population,
                                                  int out_size) {
  Array<State> best_states;
  auto tic_begin = std::chrono::high_resolution_clock::now();

  size_t population = GetIntParam(params, SketchParamKey::EvolutionarySearch::population);
  int num_iters = GetIntParam(params, SketchParamKey::EvolutionarySearch::num_iters);
  double mutation_prob = GetDoubleParam(params, SketchParamKey::EvolutionarySearch::mutation_prob);

  // Two ping pong buffers to avoid copy.
  Array<State> states_buf1{init_population}, states_buf2;
  states_buf1.reserve(population);
  states_buf2.reserve(population);
  Array<State>* pnow = &states_buf1;
  Array<State>* pnext = &states_buf2;

  // A heap to keep the best states during evolution. With this comparator the std heap functions
  // keep the lowest score at heap.front(), so that entry is the one evicted first.
  using StateHeapItem = std::pair<State, float>;
  auto cmp = [](const StateHeapItem& left, const StateHeapItem& right) {
    return left.second > right.second;
  };
  std::vector<StateHeapItem> heap;
  // String keys that must not be (re)inserted: every measured state plus the current heap content.
  std::unordered_set<std::string> in_heap(measured_states_set_);
  heap.reserve(out_size);

  // auxiliary global variables
  std::vector<float> pop_scores;
  std::vector<double> pop_selection_probs;
  float max_score = 0.0;
  pop_scores.reserve(population);
  pop_selection_probs.reserve(population);
  std::uniform_real_distribution<> dis(0.0, 1.0);

  // mutation rules
  int mutation_success_ct, mutation_fail_ct;
  mutation_success_ct = mutation_fail_ct = 0;
  std::vector<float> rule_weights;
  std::vector<double> rule_selection_probs;
  for (const auto& rule : mutation_rules) {
    rule_weights.push_back(rule->weight);
  }
  ComputePrefixSumProb(rule_weights, &rule_selection_probs);

  // Genetic Algorithm
  for (int k = 0; k < num_iters + 1; ++k) {
    // Maintain the heap
    *pnow = search_task->compute_dag.InferBound(*pnow);
    PruneInvalidState(search_task, pnow);
    program_cost_model->Predict(search_task, *pnow, &pop_scores);

    for (size_t i = 0; i < pnow->size(); ++i) {
      const State& state = (*pnow)[i];
      std::string state_str = state.ToStr();

      if (in_heap.count(state_str) == 0) {
        if (static_cast<int>(heap.size()) < out_size) {
          heap.emplace_back((*pnow)[i], pop_scores[i]);
          std::push_heap(heap.begin(), heap.end(), cmp);
          in_heap.insert(state_str);
        } else if (pop_scores[i] > heap.front().second) {
          // Evict the current lowest-scored entry in favor of this better one
          std::string old_state_str = heap.front().first.ToStr();
          in_heap.erase(old_state_str);
          in_heap.insert(state_str);

          std::pop_heap(heap.begin(), heap.end(), cmp);
          heap.back() = StateHeapItem(state, pop_scores[i]);
          std::push_heap(heap.begin(), heap.end(), cmp);
        }
        if (pop_scores[i] > max_score) {
          max_score = pop_scores[i];
        }
      }
    }

    // Print statistical information
    if (k % 5 == 0 || k == num_iters) {
      StdCout(verbose) << "GA Iter: " << k << std::fixed << std::setprecision(4);
      // The heap can still be empty (e.g. every candidate was pruned or already measured);
      // reading heap.front() then would be undefined behavior.
      if (!heap.empty()) {
        StdCout(verbose) << "\tMax score: " << max_score << "\tMin score: " << heap.front().second;
      } else {
        StdCout(verbose) << "\tMax score: N/A\tMin score: N/A";
      }
      StdCout(verbose) << "\t#Pop: " << pnow->size() << "\t#M+: " << mutation_success_ct / (k + 1)
                       << "\t#M-: " << mutation_fail_ct / (k + 1) << std::endl;
    }
    if (k == num_iters) {
      break;
    }

    // Every state was pruned: there are no parents to select from, and indexing `*pnow` in the
    // mutation loop below would be out of range.
    if (pnow->empty()) {
      break;
    }

    // Compute selection probability
    ComputePrefixSumProb(pop_scores, &pop_selection_probs);

    // Do mutation
    while (pnext->size() < population) {
      State tmp_s = (*pnow)[RandomChoose(pop_selection_probs, &rand_gen)];

      if (dis(rand_gen) < mutation_prob) {
        const auto& rule = mutation_rules[RandomChoose(rule_selection_probs, &rand_gen)];
        if (rule->Apply(this, &tmp_s, &rand_gen) == PopulationGenerationRule::ResultKind::kValid) {
          pnext->push_back(std::move(tmp_s));
          mutation_success_ct++;
        } else {
          mutation_fail_ct++;
        }
      } else {
        pnext->push_back(std::move(tmp_s));
      }
    }

    std::swap(pnext, pnow);
    pnext->clear();
  }

  // Copy best states in the heap to out_states
  std::sort(heap.begin(), heap.end(), cmp);
  for (auto& item : heap) {
    best_states.push_back(std::move(item.first));
  }

  double duration = std::chrono::duration_cast<std::chrono::duration<double>>(
                        std::chrono::high_resolution_clock::now() - tic_begin)
                        .count();
  StdCout(verbose) << "EvolutionarySearch\t\t#s: " << best_states.size()
                   << "\tTime elapsed: " << std::fixed << std::setprecision(2) << duration
                   << std::endl;
  return best_states;
}
/*!
 * \brief Choose the states to measure in this round, mixing model-ranked and random candidates.
 *
 * The first `num_measure_per_iter_ - num_random` picks prefer `best_states`, and the remaining
 * picks prefer `random_states`. Either pool is used as a fallback when the preferred one runs dry.
 * States whose string form is already in `measured_states_set_` are skipped. Accepted states are
 * recorded in `measured_states_set_` and `measured_states_vector_`.
 *
 * \param best_states Candidates ranked by the search.
 * \param random_states Randomly sampled candidates, used for eps-greedy exploration.
 * \param remaining_n_trials The measurement budget left; caps the number of picks.
 * \return The measure inputs for the picked states.
 */
Array<MeasureInput> SketchPolicyNode::PickStatesWithEpsGreedy(const Array<State>& best_states,
                                                              const Array<State>& random_states,
                                                              int remaining_n_trials) {
  const int num_random =
      static_cast<int>(GetDoubleParam(params, SketchParamKey::eps_greedy) * num_measure_per_iter_);
  const int num_good = num_measure_per_iter_ - num_random;
  const int pick_limit = std::min(num_measure_per_iter_, remaining_n_trials);

  Array<MeasureInput> picked;
  size_t next_best = 0;
  size_t next_random = 0;
  while (static_cast<int>(picked.size()) < pick_limit) {
    const bool best_left = next_best < best_states.size();
    const bool random_left = next_random < random_states.size();
    if (!best_left && !random_left) {
      break;  // both candidate pools are exhausted
    }

    // Decide which pool to draw from: the preferred one if it still has candidates,
    // otherwise the other one.
    const bool prefer_best = static_cast<int>(picked.size()) < num_good;
    const bool take_best = prefer_best ? best_left : !random_left;
    State candidate = take_best ? best_states[next_best++] : random_states[next_random++];

    // Skip anything that has already been measured
    std::string key = candidate.ToStr();
    if (measured_states_set_.count(key) != 0) {
      continue;
    }
    measured_states_set_.insert(std::move(key));
    measured_states_vector_.push_back(candidate);
    picked.push_back(MeasureInput(search_task, candidate));
  }
  return picked;
}
// Construct a SketchPolicy through the FFI.
TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicy")
    .set_body_typed([](SearchTask search_task, CostModel cost_model,
                       Map<String, ObjectRef> policy_params, int rng_seed, int verbosity,
                       Optional<Array<SearchCallback>> callbacks) {
      return SketchPolicy(search_task, cost_model, policy_params, rng_seed, verbosity, callbacks);
    });

// Generate the sketches of a policy's search task.
TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicyGenerateSketches")
    .set_body_typed([](SketchPolicy sketch_policy) { return sketch_policy->GenerateSketches(); });

// Generate sketches, then sample `pop_size` initial states from them.
TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicySampleInitialPopulation")
    .set_body_typed([](SketchPolicy sketch_policy, int pop_size) {
      return sketch_policy->SampleInitPopulation(sketch_policy->GenerateSketches(), pop_size);
    });

// Run evolutionary search starting from a caller-supplied population.
TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicyEvolutionarySearch")
    .set_body_typed([](SketchPolicy sketch_policy, Array<State> init_population, int out_size) {
      return sketch_policy->EvolutionarySearch(init_population, out_size);
    });
} // namespace auto_scheduler
} // namespace tvm