// blob: 3d732e7fbd99275f71a22074240229b2943a5a1e [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_META_SCHEDULE_TUNE_CONTEXT_H_
#define TVM_META_SCHEDULE_TUNE_CONTEXT_H_
#include <tvm/ir/expr.h>
#include <tvm/ir/module.h>
#include <tvm/meta_schedule/builder.h>
#include <tvm/meta_schedule/mutator.h>
#include <tvm/meta_schedule/postproc.h>
#include <tvm/meta_schedule/runner.h>
#include <tvm/meta_schedule/schedule_rule.h>
#include <tvm/meta_schedule/search_strategy.h>
#include <tvm/meta_schedule/space_generator.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/map.h>
#include <tvm/runtime/container/optional.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/support/random_engine.h>
#include <tvm/target/target.h>
namespace tvm {
namespace meta_schedule {
// Forward declarations of types that are defined in other headers.
class TaskSchedulerNode;
class MeasureCallback;
/*!
* \brief The auto tuning context.
*
* Groups the configuration of one tuning task (workload, target, space generator, search
* strategy, schedule rules, postprocessors, mutator probabilities, task name, logging function,
* random state, thread count) with the measurement state that the `_SetMeasureCandidates`,
* `_SendToBuilder`, `_SendToRunner`, `_Join` and `_ClearMeasureState` methods operate on.
*/
class TuneContextNode : public runtime::Object {
public:
/*! \brief The workload to be tuned. */
Optional<IRModule> mod;
/*! \brief The target to be tuned for. */
Optional<Target> target;
/*! \brief The design space generator. */
Optional<SpaceGenerator> space_generator;
/*! \brief The search strategy. */
Optional<SearchStrategy> search_strategy;
/*! \brief The schedule rules. */
Array<ScheduleRule> sch_rules;
/*! \brief The postprocessors. */
Array<Postproc> postprocs;
/*! \brief The probability of using certain mutator. */
Map<Mutator, FloatImm> mutator_probs;
/*! \brief The name of the tuning task. */
Optional<String> task_name;
/*! \brief The tuning task's logging function. */
PackedFunc logging_func;
/*! \brief The random state. */
support::LinearCongruentialEngine::TRandState rand_state;
/*! \brief The number of threads to be used. */
int num_threads;
/*! \brief Whether the tuning task has been stopped or finished. */
bool is_terminated; // TODO(@junrushao1994): move to TaskScheduler
/*! \brief The measure candidates. */
Optional<Array<MeasureCandidate>> measure_candidates;
/*! \brief The building results. */
Optional<Array<BuilderResult>> builder_results;
/*! \brief Packed functions to fetch the runner results asynchronously. */
Optional<Array<RunnerFuture>> runner_futures;
/*!
* \brief Pass each field of this node, together with its name, to an attribute visitor.
* \param v The attribute visitor that receives the fields.
* \note Every field is visited under its own member name, except `logging_func`,
* which is deliberately left out.
*/
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("mod", &mod);
v->Visit("target", &target);
v->Visit("space_generator", &space_generator);
v->Visit("search_strategy", &search_strategy);
v->Visit("sch_rules", &sch_rules);
v->Visit("postprocs", &postprocs);
v->Visit("mutator_probs", &mutator_probs);
v->Visit("task_name", &task_name);
// `logging_func` is not visited
v->Visit("rand_state", &rand_state);
v->Visit("num_threads", &num_threads);
v->Visit("is_terminated", &is_terminated);
v->Visit("measure_candidates", &measure_candidates);
v->Visit("builder_results", &builder_results);
v->Visit("runner_futures", &runner_futures);
}
/*! \brief Initialize the members that need to be initialized with the tune context. */
void Initialize();
/*!
* \brief Set the measure candidates from the SearchStrategy.
* \param candidates The measure candidates to be set.
*/
void _SetMeasureCandidates(const Array<MeasureCandidate>& candidates);
/*!
* \brief Send the measure candidates to builder.
* \param builder The builder to send the candidates to.
*/
void _SendToBuilder(const Builder& builder);
/*!
* \brief Send the built measure candidates to runner.
* \param runner The runner to send the candidates to.
*/
void _SendToRunner(const Runner& runner);
/*!
* \brief Join the running tasks.
* \returns The results from the runner
*/
Array<RunnerResult> _Join();
/*! \brief Set `measure_candidates`, `builder_results` and `runner_futures` to null. */
void _ClearMeasureState();
/*! \brief The type key string of this node type. */
static constexpr const char* _type_key = "meta_schedule.TuneContext";
// Declares the object type information for this node with `Object` as its parent; per the
// macro name the type is final. NOTE(review): exact expansion lives in the macro definition.
TVM_DECLARE_FINAL_OBJECT_INFO(TuneContextNode, Object);
};
/*!
* \brief Managed reference to TuneContextNode.
* \sa TuneContextNode
*/
class TuneContext : public runtime::ObjectRef {
public:
/*!
* \brief Constructor.
* \param mod The workload to be tuned.
* \param target The target to be tuned for.
* \param space_generator The design space generator.
* \param search_strategy The search strategy.
* \param sch_rules The schedule rules.
* \param postprocs The postprocessors.
* \param mutator_probs The probability of using certain mutator.
* \param task_name The name of the tuning task.
* \param logging_func The tuning task's logging function.
* \param rand_state The random state.
* \param num_threads The number of threads to be used.
* \note `sch_rules`, `postprocs` and `mutator_probs` are optional here, whereas the
* TuneContextNode fields of the same names are not. How an absent value is handled is decided
* by the out-of-line definition, which is not part of this header.
*/
// NOTE(review): the trailing `//` on each parameter line presumably keeps the formatter from
// re-flowing the list onto fewer lines -- confirm before removing.
TVM_DLL explicit TuneContext(Optional<IRModule> mod, //
Optional<Target> target, //
Optional<SpaceGenerator> space_generator, //
Optional<SearchStrategy> search_strategy, //
Optional<Array<ScheduleRule>> sch_rules, //
Optional<Array<Postproc>> postprocs, //
Optional<Map<Mutator, FloatImm>> mutator_probs, //
Optional<String> task_name, //
PackedFunc logging_func, //
support::LinearCongruentialEngine::TRandState rand_state, //
int num_threads);
// Defines the reference-class boilerplate that ties TuneContext to TuneContextNode. Per the
// macro name the reference is mutable and not nullable; see the macro definition for details.
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TuneContext, ObjectRef, TuneContextNode);
};
} // namespace meta_schedule
} // namespace tvm
#endif // TVM_META_SCHEDULE_TUNE_CONTEXT_H_