blob: a75a4cd8ae86874c53e5420e39ef155639412151 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_META_SCHEDULE_SEARCH_STRATEGY_H_
#define TVM_META_SCHEDULE_SEARCH_STRATEGY_H_
#include <tvm/meta_schedule/arg_info.h>
#include <tvm/meta_schedule/cost_model.h>
#include <tvm/meta_schedule/database.h>
#include <tvm/meta_schedule/measure_candidate.h>
#include <tvm/meta_schedule/runner.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/optional.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/tir/schedule/schedule.h>
namespace tvm {
namespace meta_schedule {
// Forward declaration
class TuneContext;
/*!
* \brief The search strategy for measure candidates generation.
* \note The relationship between SearchStrategy and other classes is as follows:
┌──────────────────────────────────────────────────────────────┐
┌──┴───────────────────────────────────────────────────────────┐ │
┌──┴────────────────── Tune Context ───────────────────────────┐ │ │
│ ┌─────────────────────┐ │ │ │
│ │ │ Generate │ │ │
│ │ Space Generator ├──────────────┐ │ │ │
│ │ │ │ │ │ │
│ └─────────────────────┘ ▼ │ │ │
│ Design Space │ │ │
│ ┌─────────────────────┐ │ │ │ │
│ Generate │ │ Pretuning │ │ │ │
│ ┌───────────┤ Search Strategy │◄─────────────┘ │ │ │
│ │ │ │ │ ├──┘
│ │ └─────────────────────┘ ├──┘
└────┼─────────────────────────────────────────────────────────┘
┌────┼──────────────── Managed By Task Scheduler ─────────────────────┐
│ │ ┌───────────┐ │
│ │ Send to │ │ Send to │
│ ▼ ┌─────────────►│ Builder ├──────────┐ │
│ Measure Candidate │ Builder │ │ Runner │ │
│ │ │ └───────────┘ │ │
│ │ ┌────────────┴────────┐ │ │
│ │ │ │ ┌───────────┐ │ │
│ └────►│ Task Scheduler │ │ │ │ │
│ │ │ │ Runner │◄─────────┘ │
│ └─────────────────────┘ │ │ │
│ ▲ └─────┬─────┘ │
│ │ │ │
│ └─── Runner Future ◄────┘ │
└─────────────────────────────────────────────────────────────────────┘
*/
class SearchStrategyNode : public runtime::Object {
public:
/*! \brief Virtual destructor */
virtual ~SearchStrategyNode() = default;
/*!
* \brief Initialize the search strategy with tuning context.
* \param context The tuning context for initialization.
* \note This method is supposed to be called exactly once, before any other method.
*/
virtual void InitializeWithTuneContext(const TuneContext& context) = 0;
/*!
* \brief Pre-tuning for the search strategy.
* \param design_spaces The design spaces used during tuning process.
* \param database The database used during tuning process, if any.
* \param cost_model The cost model used during tuning process, if any.
* \note Pre-tuning is supposed to be called before the tuning process and after the
* initialization. Because the search strategy is stateful, pre-tuning can be called again
* (after post-tuning) to reset the search strategy for another tuning process.
*/
virtual void PreTuning(const Array<tir::Schedule>& design_spaces,
const Optional<Database>& database,
const Optional<CostModel>& cost_model) = 0;
/*!
* \brief Post-tuning for the search strategy.
* \note Post-tuning is supposed to be called after the tuning process and before we reset the
* search strategy with another pre-tuning. Post-tuning can be empty.
*/
virtual void PostTuning() = 0;
/*!
* \brief Generate measure candidates from design spaces for measurement.
* \return The measure candidates generated, or NullOpt (an empty Optional) if the search is
* finished.
*/
virtual Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() = 0;
/*!
* \brief Update the search strategy with measurement results.
* \param measure_candidates The candidates that were sent for measurement.
* \param results The measurement results from the runner (presumably one entry per element of
* `measure_candidates`, in the same order -- confirm against callers).
*/
virtual void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
const Array<RunnerResult>& results) = 0;
/*! \brief Type key of this node, consumed by TVM_DECLARE_BASE_OBJECT_INFO below. */
static constexpr const char* _type_key = "meta_schedule.SearchStrategy";
TVM_DECLARE_BASE_OBJECT_INFO(SearchStrategyNode, Object);
};
/*! \brief The python side customizable class for measure candidate generation */
class PySearchStrategyNode : public SearchStrategyNode {
public:
/*!
* \brief The function type of `InitializeWithTuneContext` method.
* \param context The tuning context for initialization.
*/
using FInitializeWithTuneContext = runtime::TypedPackedFunc<void(const TuneContext&)>;
/*!
* \brief The function type of `PreTuning` method.
* \param design_spaces The design spaces for pre-tuning.
*/
using FPreTuning = runtime::TypedPackedFunc<void(
const Array<tir::Schedule>&, const Optional<Database>&, const Optional<CostModel>&)>;
/*! \brief The function type of `PostTuning` method. */
using FPostTuning = runtime::TypedPackedFunc<void()>;
/*!
* \brief The function type of `GenerateMeasureCandidates` method.
* \return The measure candidates generated, nullptr if finished.
*/
using FGenerateMeasureCandidates = runtime::TypedPackedFunc<Optional<Array<MeasureCandidate>>()>;
/*!
* \brief The function type of `NotifyRunnerResults` method.
* \param results The measurement results from the runner.
*/
using FNotifyRunnerResults =
runtime::TypedPackedFunc<void(const Array<MeasureCandidate>&, const Array<RunnerResult>&)>;
/*! \brief The packed function to the `InitializeWithTuneContext` method. */
FInitializeWithTuneContext f_initialize_with_tune_context;
/*! \brief The packed function to the `PreTuning` method. */
FPreTuning f_pre_tuning;
/*! \brief The packed function to the `PostTuning` method. */
FPostTuning f_post_tuning;
/*! \brief The packed function to the `GenerateMeasureCandidates` method. */
FGenerateMeasureCandidates f_generate_measure_candidates;
/*! \brief The packed function to the `NotifyRunnerResults` method. */
FNotifyRunnerResults f_notify_runner_results;
void VisitAttrs(tvm::AttrVisitor* v) {
// `f_initialize_with_tune_context` is not visited
// `f_pre_tuning` is not visited
// `f_post_tuning` is not visited
// `f_generate_measure_candidates` is not visited
// `f_notify_runner_results` is not visited
}
void InitializeWithTuneContext(const TuneContext& context) final;
void PreTuning(const Array<tir::Schedule>& design_spaces, const Optional<Database>& database,
const Optional<CostModel>& cost_model) final;
void PostTuning() final;
Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final;
void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
const Array<RunnerResult>& results);
static constexpr const char* _type_key = "meta_schedule.PySearchStrategy";
TVM_DECLARE_FINAL_OBJECT_INFO(PySearchStrategyNode, SearchStrategyNode);
};
/*!
* \brief Managed (reference-counted) handle to a SearchStrategyNode.
* \sa SearchStrategyNode
*/
class SearchStrategy : public runtime::ObjectRef {
public:
/*!
* \brief Build a search strategy whose methods are supplied from the python side.
* \param f_initialize_with_tune_context Packed function backing `InitializeWithTuneContext`.
* \param f_pre_tuning Packed function backing `PreTuning`.
* \param f_post_tuning Packed function backing `PostTuning`.
* \param f_generate_measure_candidates Packed function backing `GenerateMeasureCandidates`.
* \param f_notify_runner_results Packed function backing `NotifyRunnerResults`.
* \return The newly created search strategy.
*/
TVM_DLL static SearchStrategy PySearchStrategy(
PySearchStrategyNode::FInitializeWithTuneContext f_initialize_with_tune_context,
PySearchStrategyNode::FPreTuning f_pre_tuning, PySearchStrategyNode::FPostTuning f_post_tuning,
PySearchStrategyNode::FGenerateMeasureCandidates f_generate_measure_candidates,
PySearchStrategyNode::FNotifyRunnerResults f_notify_runner_results);
/*!
* \brief Build the "replay func" search strategy.
* \param num_trials_per_iter Number of trials per iteration (the batch size).
* \param max_trials_per_task Total trial budget for func replaying.
*/
TVM_DLL static SearchStrategy ReplayFunc(int num_trials_per_iter, int max_trials_per_task);
/*!
* \brief Build the "replay trace" search strategy.
* \param num_trials_per_iter Number of trials per iteration (the batch size).
* \param max_trials_per_task Total trial budget for trace replaying.
* \param max_fail_count Maximum number of failures tolerated during trace replaying.
*/
TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter,
int max_trials_per_task,
int max_fail_count);
/*!
* \brief Build the evolutionary search strategy.
* \param num_trials_per_iter Number of trials per iteration (the batch size).
* \param max_trials_per_task Total trial budget for evolutionary search.
* \param population_size Size of the initial sample population.
* \param init_measured_ratio Ratio of already-measured samples in the initial population.
* \param init_min_unmeasured Minimum number of unmeasured samples in the initial sampling.
* \param genetic_num_iters Number of iterations of the genetic algorithm.
* \param genetic_mutate_prob Probability of mutation.
* \param genetic_max_fail_count Maximum number of attempts to evolve a given trace.
* \param eps_greedy Ratio of samples picked greedily by their predicted score.
*/
TVM_DLL static SearchStrategy EvolutionarySearch(
int num_trials_per_iter, int max_trials_per_task, int population_size,
double init_measured_ratio, int init_min_unmeasured, int genetic_num_iters,
double genetic_mutate_prob, int genetic_max_fail_count, double eps_greedy);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchStrategy, ObjectRef, SearchStrategyNode);
};
} // namespace meta_schedule
} // namespace tvm
#endif // TVM_META_SCHEDULE_SEARCH_STRATEGY_H_