/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file auto_scheduler/auto_schedule.cc
* \brief The user interface of the TVM Auto-scheduler. This is the entry point that receives the
* schedule search requirements from the upper level (Python API) and returns a high-performance
* schedule once the search process has finished.
*/
#include <tvm/auto_scheduler/auto_schedule.h>
#include <tvm/runtime/registry.h>
#include "utils.h"
namespace tvm {
namespace auto_scheduler {
// Register TuningOptionsNode through TVM's node-type registration macro.
TVM_REGISTER_NODE_TYPE(TuningOptionsNode);
/*!
 * \brief Build a TuningOptions handle backed by a newly allocated TuningOptionsNode.
 * \param num_measure_trials Stored on the node; AutoSchedule forwards it to the policy's Search.
 * \param early_stopping Stored on the node; AutoSchedule forwards it to the policy's Search.
 * \param num_measures_per_round Stored on the node; AutoSchedule forwards it to the policy's
 *  Search.
 * \param verbose Stored on the node; AutoSchedule passes it to the ProgramMeasurer and to StdCout.
 * \param builder Stored on the node; AutoSchedule passes it to the ProgramMeasurer.
 * \param runner Stored on the node; AutoSchedule passes it to the ProgramMeasurer.
 * \param measure_callbacks Optional callback list stored on the node; AutoSchedule passes it to
 *  the ProgramMeasurer.
 */
TuningOptions::TuningOptions(int num_measure_trials, int early_stopping, int num_measures_per_round,
                             int verbose, ProgramBuilder builder, ProgramRunner runner,
                             Optional<Array<MeasureCallback>> measure_callbacks) {
  auto options = make_object<TuningOptionsNode>();

  // Reference-typed arguments arrive by value, so they are moved into the node rather than copied.
  options->builder = std::move(builder);
  options->runner = std::move(runner);
  options->measure_callbacks = std::move(measure_callbacks);

  // Plain integer settings are copied as-is.
  options->verbose = verbose;
  options->num_measures_per_round = num_measures_per_round;
  options->early_stopping = early_stopping;
  options->num_measure_trials = num_measure_trials;

  // Hand ownership of the populated node to this handle.
  data_ = std::move(options);
}
/*!
 * \brief Run the given search policy and convert its result into a TE schedule.
 * \param search_policy The policy whose Search method performs the search; its search_task
 *  provides the compute DAG used to build the returned schedule.
 * \param tuning_options Provides the builder, runner, measure callbacks and verbosity for the
 *  measurer, as well as the trial, early-stopping and per-round counts handed to Search.
 * \return The (schedule, tensors) pair obtained by applying the found state's transform steps to
 *  the compute DAG. When Search yields an undefined state, a default schedule over the DAG's ops
 *  is returned together with the DAG's tensors.
 */
std::pair<te::Schedule, Array<te::Tensor>> AutoSchedule(SearchPolicy search_policy,
                                                        TuningOptions tuning_options) {
  // The measurer takes care of building candidate schedules and measuring their performance.
  ProgramMeasurer measurer(tuning_options->builder, tuning_options->runner,
                           tuning_options->measure_callbacks, tuning_options->verbose);

  // Let the policy look for the best state.
  State best_state =
      search_policy->Search(tuning_options->num_measure_trials, tuning_options->early_stopping,
                            tuning_options->num_measures_per_round, measurer);

  if (!best_state.defined()) {
    // Nothing usable came out of this round: report it and fall back to the default schedule.
    StdCout(tuning_options->verbose)
        << "No valid state found in this search round. Check if it has traversed all of the "
        << "search space." << std::endl;
    return {te::Schedule(search_policy->search_task->compute_dag->ops),
            search_policy->search_task->compute_dag->tensors};
  }

  // Replay the recorded transform steps on the compute DAG to obtain the schedule.
  return search_policy->search_task->compute_dag.ApplySteps(best_state->transform_steps);
}
// Global function "auto_scheduler.TuningOptions": forwards its arguments, in order, to the
// TuningOptions constructor and returns the resulting object.
TVM_REGISTER_GLOBAL("auto_scheduler.TuningOptions")
    .set_body_typed([](int num_measure_trials, int early_stopping, int num_measures_per_round,
                       int verbose, ProgramBuilder builder, ProgramRunner runner,
                       Optional<Array<MeasureCallback>> measure_callbacks) -> TuningOptions {
      TuningOptions options(num_measure_trials, early_stopping, num_measures_per_round, verbose,
                            builder, runner, measure_callbacks);
      return options;
    });
// Global function "auto_scheduler.AutoSchedule": runs AutoSchedule and repackages the resulting
// (schedule, tensors) pair as a two-element Array<ObjectRef> in that order.
TVM_REGISTER_GLOBAL("auto_scheduler.AutoSchedule")
    .set_body_typed([](SearchPolicy search_policy, TuningOptions tuning_options) {
      std::pair<te::Schedule, Array<te::Tensor>> result =
          AutoSchedule(search_policy, tuning_options);
      return Array<ObjectRef>{result.first, result.second};
    });
} // namespace auto_scheduler
} // namespace tvm