blob: 0976e158aaf0b9d8a634d1f15b2047f0e564d86f [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"
namespace tvm {
namespace meta_schedule {
/******** Workload ********/
/*!
 * \brief Construct a workload whose structural hash is computed from `mod`.
 * \param mod The IRModule being tuned.
 * \note Delegates to the (mod, shash) constructor so node construction lives in one place.
 */
Workload::Workload(IRModule mod) : Workload(mod, tvm::StructuralHash()(mod)) {}
/*!
 * \brief Construct a workload from a module and a precomputed structural hash.
 * \param mod The IRModule being tuned.
 * \param shash The structural hash to store; it is not recomputed or verified here.
 */
Workload::Workload(IRModule mod, Workload::THashCode shash) {
  auto node = runtime::make_object<WorkloadNode>();
  node->shash = shash;
  node->mod = mod;
  data_ = std::move(node);
}
/*!
 * \brief Serialize the workload.
 * \return A 2-element array: [structural hash as a string, base64 of the module's JSON].
 */
ObjectRef WorkloadNode::AsJSON() const {
  // The module is first saved as JSON text, and that text is then base64-encoded.
  const std::string module_b64 = Base64Encode(tvm::SaveJSON(this->mod));
  String hash_text = SHash2Str(this->shash);
  return Array<ObjectRef>{hash_text, String(module_b64)};
}
/*!
 * \brief Deserialize a workload produced by WorkloadNode::AsJSON.
 * \param json_obj A 2-element array: [structural hash string, base64-encoded module JSON].
 * \return The decoded workload; its hash is the one recomputed from the decoded module.
 * \note Malformed input, or a recomputed hash that differs from the recorded one, is reported
 *       through LOG(FATAL).
 */
Workload Workload::FromJSON(const ObjectRef& json_obj) {
  THashCode shash = 0;
  IRModule mod{nullptr};
  try {
    const ArrayNode* json_array = json_obj.as<ArrayNode>();
    CHECK(json_array && json_array->size() == 2);
    // json[0]: the structural hash recorded at serialization time.
    String str_shash = Downcast<String>(json_array->at(0));
    // json[1]: base64 text wrapping the module's JSON form.
    String encoded = Downcast<String>(json_array->at(1));
    mod = Downcast<IRModule>(LoadJSON(Base64Decode(encoded)));
    // Re-hash the decoded module so a stale or corrupted record is detected.
    shash = tvm::StructuralHash()(mod);
    String recalc_shash = SHash2Str(shash);
    CHECK_EQ(recalc_shash, str_shash) << "ValueError: Structural hash changed. Given: " << str_shash
                                      << "; Recalculated: " << recalc_shash;
  } catch (const std::runtime_error& e) {  // includes tvm::Error and dmlc::Error
    LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj
               << "\nThe error is: " << e.what();
  }
  return Workload(mod, shash);
}
/******** TuningRecord ********/
/*!
 * \brief Construct a tuning record.
 * \param trace The schedule trace of this record.
 * \param workload The workload the trace applies to.
 * \param run_secs Optional measured run times.
 * \param target Optional target the record was measured on.
 * \param args_info Optional argument information of the entry function.
 * \note Parameters are sinks taken by value; moving them into the node avoids a redundant
 *       reference-count increment/decrement per field (same pattern as the Workload ctors).
 */
TuningRecord::TuningRecord(tir::Trace trace, Workload workload, Optional<Array<FloatImm>> run_secs,
                           Optional<Target> target, Optional<Array<ArgInfo>> args_info) {
  ObjectPtr<TuningRecordNode> n = make_object<TuningRecordNode>();
  n->trace = std::move(trace);
  n->workload = std::move(workload);
  n->run_secs = std::move(run_secs);
  n->target = std::move(target);
  n->args_info = std::move(args_info);
  this->data_ = std::move(n);
}
/*!
 * \brief Rebuild a measure candidate by replaying this record's trace on the workload module.
 * \return A candidate holding the replayed schedule and the arg info of its entry function.
 */
MeasureCandidate TuningRecordNode::AsMeasureCandidate() const {
  tir::Schedule sch = tir::Schedule::Traced(/*mod=*/workload->mod, /*seed=*/-1, /*debug_mask=*/0,
                                            tir::ScheduleErrorRenderLevel::kDetail);
  trace->ApplyToSchedule(sch, /*remove_postproc=*/false, /*decision_provider=*/nullptr);
  Array<ArgInfo> candidate_args = ArgInfo::FromEntryFunc(sch->mod(), /*remove_preproc=*/true);
  return MeasureCandidate(sch, candidate_args);
}
/*!
 * \brief Serialize the tuning record.
 * \return A 4-element array: [trace, run_secs, target, args_info]. Absent optional fields are
 *         emitted as null entries. The workload is not included.
 */
ObjectRef TuningRecordNode::AsJSON() const {
  Optional<ObjectRef> json_target{nullptr};
  if (target.defined()) {
    json_target = target.value()->Export();
  }
  Optional<Array<ObjectRef>> json_args_info{nullptr};
  if (args_info.defined()) {
    const Array<ArgInfo>& infos = args_info.value();
    Array<ObjectRef> serialized;
    serialized.reserve(infos.size());
    for (const ArgInfo& item : infos) {
      serialized.push_back(item->AsJSON());
    }
    json_args_info = serialized;
  }
  ObjectRef json_trace = trace->AsJSON(false);
  return Array<ObjectRef>{json_trace, run_secs, json_target, json_args_info};
}
/*!
 * \brief Deserialize a tuning record produced by TuningRecordNode::AsJSON.
 * \param json_obj A 4-element array: [trace, run_secs, target, args_info]; entries 1-3 may be null.
 * \param workload The workload the record belongs to; the trace is replayed on its module.
 * \return The reconstructed tuning record.
 * \note Malformed input is reported through LOG(FATAL) with a "ValueError" prefix.
 */
TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& workload) {
  tir::Trace trace{nullptr};
  Optional<Array<FloatImm>> run_secs{nullptr};
  Optional<Target> target{nullptr};
  Optional<Array<ArgInfo>> args_info{nullptr};
  try {
    const ArrayNode* json_array = json_obj.as<ArrayNode>();
    CHECK(json_array && json_array->size() == 4);
    // Load json[1] => run_secs
    if (json_array->at(1).defined()) {
      run_secs = AsFloatArray(json_array->at(1));
    }
    // Load json[2] => target
    if (json_array->at(2).defined()) {
      target = Target(Downcast<Map<String, ObjectRef>>(json_array->at(2)));
    }
    // Load json[3] => args_info
    if (json_array->at(3).defined()) {
      const ArrayNode* json_args_info = json_array->at(3).as<ArrayNode>();
      // A defined but non-array entry would otherwise be dereferenced as a null pointer below.
      // Fail through CHECK instead, as the size check above does, so the catch clause reports it.
      CHECK(json_args_info != nullptr)
          << "TypeError: Expect json[3] (args_info) to be an array, but gets: "
          << json_array->at(3)->GetTypeKey();
      Array<ArgInfo> info;
      info.reserve(json_args_info->size());
      for (const ObjectRef& json_arg_info : *json_args_info) {
        info.push_back(ArgInfo::FromJSON(json_arg_info));
      }
      args_info = info;
    }
    // Load json[0] => trace, by replaying it on a fresh traced schedule of the workload module.
    {
      const ObjectRef& json_trace = json_array->at(0);
      tir::Schedule sch =
          tir::Schedule::Traced(workload->mod, /*seed=*/-1, /*debug_mask=*/0,
                                /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
      tir::Trace::ApplyJSONToSchedule(json_trace, sch);
      trace = sch->trace().value();
    }
  } catch (const std::runtime_error& e) {  // includes tvm::Error and dmlc::Error
    LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj
               << "\nThe error is: " << e.what();
  }
  return TuningRecord(trace, workload, run_secs, target, args_info);
}
/******** Database ********/
/*!
 * \brief Look up the best tuning record for `mod`.
 * \param mod The module to query.
 * \param target Unused by this default implementation.
 * \param workload_name Unused by this default implementation.
 * \return The top-1 record from GetTopK, or NullOpt if the workload is unknown or has no records.
 */
Optional<TuningRecord> DatabaseNode::QueryTuningRecord(const IRModule& mod, const Target& target,
                                                       const String& workload_name) {
  if (this->HasWorkload(mod)) {
    Workload workload = this->CommitWorkload(mod);
    Array<TuningRecord> records = this->GetTopK(workload, 1);
    if (!records.empty()) {
      ICHECK_EQ(records.size(), 1);
      return records[0];
    }
  }
  return NullOpt;
}
/*!
 * \brief Look up the best record for `mod` and replay its trace into a schedule.
 * \return The replayed schedule, or NullOpt when QueryTuningRecord finds nothing.
 */
Optional<tir::Schedule> DatabaseNode::QuerySchedule(const IRModule& mod, const Target& target,
                                                    const String& workload_name) {
  Optional<TuningRecord> opt_record = this->QueryTuningRecord(mod, target, workload_name);
  if (!opt_record.defined()) {
    return NullOpt;
  }
  TuningRecord record = opt_record.value();
  tir::Schedule sch =
      tir::Schedule::Traced(record->workload->mod, /*seed=*/-1, /*debug_mask=*/0,
                            /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail);
  record->trace->ApplyToSchedule(sch, /*remove_postproc=*/false);
  return sch;
}
/*!
 * \brief Look up the best record for `mod` and return the module after replaying its trace.
 * \return The scheduled module, or NullOpt when QuerySchedule finds nothing.
 */
Optional<IRModule> DatabaseNode::QueryIRModule(const IRModule& mod, const Target& target,
                                               const String& workload_name) {
  Optional<tir::Schedule> opt_sch = this->QuerySchedule(mod, target, workload_name);
  if (!opt_sch.defined()) {
    return NullOpt;
  }
  return opt_sch.value()->mod();
}
/*!
 * \brief Per-thread stack of databases entered via Database::EnterWithScope.
 * \return Pointer to this thread's stack (block-scope thread_local already has static duration).
 */
std::vector<Database>* ThreadLocalDatabases() {
  thread_local std::vector<Database> database_stack;
  return &database_stack;
}
void Database::EnterWithScope() { ThreadLocalDatabases()->push_back(*this); }
void Database::ExitWithScope() { ThreadLocalDatabases()->pop_back(); }
/*!
 * \brief The innermost database entered on this thread.
 * \return The top of the thread-local scope stack, or NullOpt if no scope is active.
 */
Optional<Database> Database::Current() {
  const std::vector<Database>& stack = *ThreadLocalDatabases();
  if (!stack.empty()) {
    return stack.back();
  }
  return NullOpt;
}
/******** PyDatabase ********/
/*!
 * \brief Create a database whose operations are all delegated to the given packed functions.
 * \return A Database wrapping a PyDatabaseNode that stores every callback.
 */
Database Database::PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload,
                              PyDatabaseNode::FCommitWorkload f_commit_workload,
                              PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record,
                              PyDatabaseNode::FGetTopK f_get_top_k,
                              PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records,
                              PyDatabaseNode::FQueryTuningRecord f_query_tuning_record,
                              PyDatabaseNode::FQuerySchedule f_query_schedule,
                              PyDatabaseNode::FQueryIRModule f_query_ir_module,
                              PyDatabaseNode::FSize f_size) {
  ObjectPtr<PyDatabaseNode> node = make_object<PyDatabaseNode>();
  // Workload management.
  node->f_has_workload = std::move(f_has_workload);
  node->f_commit_workload = std::move(f_commit_workload);
  // Record management.
  node->f_commit_tuning_record = std::move(f_commit_tuning_record);
  node->f_get_top_k = std::move(f_get_top_k);
  node->f_get_all_tuning_records = std::move(f_get_all_tuning_records);
  // Query overrides.
  node->f_query_tuning_record = std::move(f_query_tuning_record);
  node->f_query_schedule = std::move(f_query_schedule);
  node->f_query_ir_module = std::move(f_query_ir_module);
  node->f_size = std::move(f_size);
  return Database(node);
}
/******** FFI ********/
// Reflection registration for the object types defined in this file.
// DatabaseNode uses TVM_REGISTER_OBJECT_TYPE rather than TVM_REGISTER_NODE_TYPE.
TVM_REGISTER_NODE_TYPE(WorkloadNode);
TVM_REGISTER_NODE_TYPE(TuningRecordNode);
TVM_REGISTER_OBJECT_TYPE(DatabaseNode);
TVM_REGISTER_NODE_TYPE(PyDatabaseNode);
// Workload: constructor and JSON round-trip.
TVM_REGISTER_GLOBAL("meta_schedule.Workload").set_body_typed([](IRModule mod) {
return Workload(mod);
});
TVM_REGISTER_GLOBAL("meta_schedule.WorkloadAsJSON")
.set_body_method<Workload>(&WorkloadNode::AsJSON);
TVM_REGISTER_GLOBAL("meta_schedule.WorkloadFromJSON").set_body_typed(&Workload::FromJSON);
// TuningRecord: constructor, conversion to a measure candidate, and JSON round-trip.
TVM_REGISTER_GLOBAL("meta_schedule.TuningRecord")
.set_body_typed([](tir::Trace trace, Workload workload, Optional<Array<FloatImm>> run_secs,
Optional<Target> target, Optional<Array<ArgInfo>> args_info) {
return TuningRecord(trace, workload, run_secs, target, args_info);
});
TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordAsMeasureCandidate")
.set_body_method<TuningRecord>(&TuningRecordNode::AsMeasureCandidate);
TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordAsJSON")
.set_body_method<TuningRecord>(&TuningRecordNode::AsJSON);
TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordFromJSON").set_body_typed(TuningRecord::FromJSON);
// Database: thread-local scope management (EnterWithScope / ExitWithScope / Current).
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseEnterWithScope")
.set_body_method(&Database::EnterWithScope);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseExitWithScope")
.set_body_method(&Database::ExitWithScope);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCurrent").set_body_typed(Database::Current);
// Database: methods dispatched on DatabaseNode.
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseHasWorkload")
.set_body_method<Database>(&DatabaseNode::HasWorkload);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCommitWorkload")
.set_body_method<Database>(&DatabaseNode::CommitWorkload);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCommitTuningRecord")
.set_body_method<Database>(&DatabaseNode::CommitTuningRecord);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseGetTopK")
.set_body_method<Database>(&DatabaseNode::GetTopK);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseGetAllTuningRecords")
.set_body_method<Database>(&DatabaseNode::GetAllTuningRecords);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseSize").set_body_method<Database>(&DatabaseNode::Size);
// Database: query helpers (tuning record -> schedule -> IRModule).
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseQueryTuningRecord")
.set_body_method<Database>(&DatabaseNode::QueryTuningRecord);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseQuerySchedule")
.set_body_method<Database>(&DatabaseNode::QuerySchedule);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseQueryIRModule")
.set_body_method<Database>(&DatabaseNode::QueryIRModule);
// Factory for a database implemented by packed-function callbacks.
TVM_REGISTER_GLOBAL("meta_schedule.DatabasePyDatabase").set_body_typed(Database::PyDatabase);
} // namespace meta_schedule
} // namespace tvm