// blob: d73bd911a921233920f9d150dc332e2214e5572a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file auto_scheduler/search_policy/search_policy.cc
* \brief The base class of search policies.
*/
#include <tvm/auto_scheduler/measure_record.h>
#include <tvm/auto_scheduler/search_policy.h>
#include <tvm/runtime/registry.h>
#include "utils.h"
namespace tvm {
namespace auto_scheduler {
// Register the node types used by search policies and their callbacks as TVM object types.
TVM_REGISTER_OBJECT_TYPE(SearchCallbackNode);
TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode);
TVM_REGISTER_OBJECT_TYPE(PreloadMeasuredStatesNode);
void SearchPolicyNode::PreloadMeasuredStates(const String& log_file) {
RecordReader reader = RecordReader(log_file);
const auto& res = reader->ReadLines(-1);
size_t log_size = res.first.size();
CHECK_EQ(log_size, res.second.size());
if (log_size) {
Array<State> measured_states;
std::vector<float> measured_throughputs;
for (size_t i = 0; i < log_size; i++) {
const auto& inp = res.first[i];
if (inp->task->workload_key == search_task->workload_key &&
inp->task->target->kind->name.compare(search_task->target->kind->name) == 0) {
State state = search_task->compute_dag->init_state;
auto pstate = state.CopyOnWrite();
pstate->transform_steps = inp->state->transform_steps;
for (const auto& step : pstate->transform_steps) {
StepApplyToState(step, &state, search_task->compute_dag);
}
measured_states.push_back(std::move(state));
measured_throughputs.push_back(
res.second[i]->error_no == 0 ? (1.0 / FloatArrayMean(res.second[i]->costs)) : 0.0);
}
}
// We can assume the recorded states will all be valid after infer bound
measured_states = search_task->compute_dag.InferBound(measured_states);
for (size_t i = 0; i < measured_states.size(); i++) {
auto& state = measured_states[i];
const auto& state_str = state.ToStr();
if (!measured_states_set_.count(state_str)) {
measured_states_set_.insert(state_str);
if (measured_throughputs[i] != 0.0) {
measured_states_vector_.emplace_back(std::move(state));
measured_states_throughputs_.emplace_back(measured_throughputs[i]);
}
}
}
StdCout(verbose) << "SearchPolicy: Loaded " << measured_states_set_.size()
<< " measurement records from " << log_file << " for "
<< search_task->workload_key << std::endl;
} else {
StdCout(verbose) << "SearchPolicy: No measurement records found in " << log_file << " for "
<< search_task->workload_key << std::endl;
}
}
void SearchPolicyNode::RunCallbacks(const Array<SearchCallback>& callbacks) {
for (const auto& callback : callbacks) {
callback->Callback(this);
}
}
/*!
 * \brief Construct a callback object that remembers the given record file name.
 * \param filename The file name stored in the new node.
 */
PreloadMeasuredStates::PreloadMeasuredStates(String filename) {
  auto callback_node = make_object<PreloadMeasuredStatesNode>();
  callback_node->filename = std::move(filename);
  data_ = std::move(callback_node);
}
/*!
 * \brief Ask the given policy to preload the records from this node's `filename`.
 * \param policy The search policy to load the records into.
 */
void PreloadMeasuredStatesNode::Callback(SearchPolicyNode* policy) {
  policy->PreloadMeasuredStates(this->filename);
}
// Global function: run the supplied callbacks on a policy; a missing callback list is a no-op.
TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyRunCallbacks")
    .set_body_typed([](SearchPolicy policy, Optional<Array<SearchCallback>> callbacks) {
      if (!callbacks) {
        return;
      }
      policy->RunCallbacks(callbacks.value());
    });
// Global function: assign a search task to a policy.
TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicySetTask")
    .set_body_typed([](SearchPolicy target_policy, SearchTask new_task) {
      target_policy->search_task = new_task;
    });
// Global function: set the `verbose` field of a policy.
TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicySetVerbose")
    .set_body_typed([](SearchPolicy target_policy, int level) {
      target_policy->verbose = level;
    });
// Global function: build a PreloadMeasuredStates callback for the given record file name.
TVM_REGISTER_GLOBAL("auto_scheduler.PreloadMeasuredStates")
    .set_body_typed([](String log_file) {
      return PreloadMeasuredStates(log_file);
    });
} // namespace auto_scheduler
} // namespace tvm