blob: 99c01b17e78e7e9f15638efcf19a2ff99f379a19 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file auto_scheduler/measure_record.cc
* \brief Json serialization format for dumping and loading tuning records.
*/
#include <dmlc/json.h>
#include <tvm/auto_scheduler/loop_state.h>
#include <tvm/auto_scheduler/measure_record.h>
#include <tvm/auto_scheduler/transform_step.h>
#include <tvm/runtime/registry.h>

#include <cstdint>
#include <fstream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "utils.h"
// Json serialization handler for MeasureInput, MeasureResult
// (and recursively for SearchTask, State, Step, ...)
namespace dmlc {
namespace json {
template <>
struct Handler<::tvm::Array<::tvm::auto_scheduler::Stage>> {
  /*!
   * \brief Stages are not serialized: always emit an empty JSON array.
   *
   * The contents of the array argument are ignored. (Presumably the stages are
   * reconstructed from the transform steps when a record is loaded -- TODO confirm.)
   */
  inline static void Write(dmlc::JSONWriter* writer,
                           const ::tvm::Array<::tvm::auto_scheduler::Stage>& /* data */) {
    writer->BeginArray(false);
    writer->EndArray();
  }

  /*!
   * \brief Consume the empty placeholder array produced by Write.
   *
   * The output array is left untouched; a non-empty JSON array fails the CHECK.
   */
  inline static void Read(dmlc::JSONReader* reader,
                          ::tvm::Array<::tvm::auto_scheduler::Stage>* /* data */) {
    reader->BeginArray();
    const bool s = reader->NextArrayItem();
    CHECK(!s);
  }
};
template <>
struct Handler<::tvm::Array<::tvm::auto_scheduler::Step>> {
inline static void Write(dmlc::JSONWriter* writer,
const ::tvm::Array<::tvm::auto_scheduler::Step>& data) {
writer->BeginArray(false);
for (const auto& step : data) {
writer->WriteArraySeperator();
writer->BeginArray(false);
step->WriteToRecord(writer);
writer->EndArray();
}
writer->EndArray();
}
inline static void Read(dmlc::JSONReader* reader,
::tvm::Array<::tvm::auto_scheduler::Step>* data) {
bool s;
reader->BeginArray();
data->clear();
while (reader->NextArrayItem()) {
reader->BeginArray();
data->push_back(::tvm::auto_scheduler::StepReadFromRecord(reader));
s = reader->NextArrayItem();
CHECK(!s);
}
}
};
template <>
struct Handler<::tvm::auto_scheduler::StateNode> {
  /*!
   * \brief Serialize a State as the two-element JSON array [stages, transform_steps].
   *
   * The stages element is always written as an empty array by the Array<Stage> handler.
   */
  inline static void Write(dmlc::JSONWriter* writer, const ::tvm::auto_scheduler::StateNode& data) {
    writer->BeginArray(false);
    writer->WriteArrayItem(data.stages);
    writer->WriteArrayItem(data.transform_steps);
    writer->EndArray();
  }
  /*!
   * \brief Parse [stages, transform_steps] into `data`.
   *
   * Exactly two elements are required; a missing or extra element fails a CHECK.
   */
  inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::StateNode* data) {
    bool s;
    reader->BeginArray();
    s = reader->NextArrayItem();
    CHECK(s);
    reader->Read(&data->stages);
    s = reader->NextArrayItem();
    CHECK(s);
    reader->Read(&data->transform_steps);
    // The array must end after the second element.
    s = reader->NextArrayItem();
    CHECK(!s);
  }
};
template <>
struct Handler<::tvm::auto_scheduler::SearchTaskNode> {
  /*!
   * \brief Serialize a search task as the JSON array [workload_key, target_string].
   *
   * Only these two fields are written; the target is stored via Target::str().
   */
  inline static void Write(dmlc::JSONWriter* writer,
                           const ::tvm::auto_scheduler::SearchTaskNode& data) {
    writer->BeginArray(false);
    writer->WriteArrayItem(std::string(data.workload_key));
    writer->WriteArrayItem(data.target->str());
    writer->EndArray();
  }
  /*!
   * \brief Parse [workload_key, target_string] into `data`.
   *
   * Only workload_key and target are assigned; other fields of `data` are not touched
   * here. The target is rebuilt from its string form. Exactly two elements are required.
   */
  inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::SearchTaskNode* data) {
    bool s;
    std::string str_value;  // scratch buffer reused for both string fields
    reader->BeginArray();
    s = reader->NextArrayItem();
    CHECK(s);
    reader->Read(&str_value);
    data->workload_key = std::move(str_value);
    s = reader->NextArrayItem();
    CHECK(s);
    // str_value was moved from above; using a moved-from std::string as the target of a
    // new assignment is valid.
    reader->Read(&str_value);
    data->target = ::tvm::Target(str_value);
    s = reader->NextArrayItem();
    CHECK(!s);
  }
};
template <>
struct Handler<::tvm::auto_scheduler::MeasureInputNode> {
  /*!
   * \brief Serialize a measure input as the two-element JSON array [task, state],
   *        delegating to the SearchTaskNode and StateNode handlers.
   */
  inline static void Write(dmlc::JSONWriter* writer,
                           const ::tvm::auto_scheduler::MeasureInputNode& data) {
    writer->BeginArray(false);
    writer->WriteArrayItem(*data.task.operator->());
    writer->WriteArrayItem(*data.state.operator->());
    writer->EndArray();
  }
  /*!
   * \brief Parse [task, state] into freshly allocated nodes and store them in `data`.
   *
   * New task and state nodes are created on every call, so the references previously
   * held by `data` are replaced rather than mutated in place. Exactly two elements
   * are required; otherwise a CHECK fails.
   */
  inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::MeasureInputNode* data) {
    auto task_node = ::tvm::make_object<::tvm::auto_scheduler::SearchTaskNode>();
    auto state_node = ::tvm::make_object<::tvm::auto_scheduler::StateNode>();
    // NOTE(review): the loaded state is flagged concrete up front; the flag's meaning is
    // defined with StateNode -- confirm. Its stages are not restored from the record
    // (the Array<Stage> handler only reads an empty array).
    state_node->concrete = true;
    bool s;
    reader->BeginArray();
    s = reader->NextArrayItem();
    CHECK(s);
    reader->Read(task_node.get());
    s = reader->NextArrayItem();
    CHECK(s);
    reader->Read(state_node.get());
    s = reader->NextArrayItem();
    CHECK(!s);
    // Assign to `data` only after the whole array has been parsed successfully.
    data->task = ::tvm::auto_scheduler::SearchTask(task_node);
    data->state = ::tvm::auto_scheduler::State(state_node);
  }
};
template <>
struct Handler<::tvm::auto_scheduler::MeasureResultNode> {
  /*!
   * \brief Serialize a measure result as the JSON array
   *        [[cost_0, cost_1, ...], error_no, all_cost, timestamp].
   *
   * Every entry of `costs` must be a FloatImm; anything else aborts via CHECK.
   * The timestamp is written as an integer (its fractional part is truncated).
   */
  inline static void Write(dmlc::JSONWriter* writer,
                           const ::tvm::auto_scheduler::MeasureResultNode& data) {
    writer->BeginArray(false);
    // First element: the costs, written as a nested inline array of doubles.
    writer->WriteArraySeperator();
    writer->BeginArray(false);
    for (const auto& x : data.costs) {
      auto pf = x.as<::tvm::tir::FloatImmNode>();
      CHECK(pf != nullptr) << "Cost can only contain float values";
      writer->WriteArrayItem(pf->value);
    }
    writer->EndArray();
    writer->WriteArrayItem(data.error_no);
    writer->WriteArrayItem(data.all_cost);
    // Narrow to a 64-bit integer rather than `int`: converting a double that does not
    // fit the target type is undefined behavior, and (presumably epoch-second)
    // timestamps exceed INT32_MAX from January 2038 on. Values that fit in `int`
    // are written exactly as before.
    writer->WriteArrayItem(static_cast<int64_t>(data.timestamp));
    writer->EndArray();
  }

  /*!
   * \brief Parse the array produced by Write into `data`.
   *
   * `costs` is cleared and refilled with float64 FloatImm values. Exactly four
   * elements are required; otherwise a CHECK fails.
   */
  inline static void Read(dmlc::JSONReader* reader,
                          ::tvm::auto_scheduler::MeasureResultNode* data) {
    std::vector<double> double_list;
    bool s;
    reader->BeginArray();
    s = reader->NextArrayItem();
    CHECK(s);
    reader->Read(&double_list);
    data->costs.clear();
    for (const auto& i : double_list) {
      data->costs.push_back(::tvm::FloatImm(::tvm::DataType::Float(64), i));
    }
    s = reader->NextArrayItem();
    CHECK(s);
    reader->Read(&data->error_no);
    s = reader->NextArrayItem();
    CHECK(s);
    reader->Read(&data->all_cost);
    s = reader->NextArrayItem();
    CHECK(s);
    // The integer written by Write is read back into the (floating-point) timestamp.
    reader->Read(&data->timestamp);
    s = reader->NextArrayItem();
    CHECK(!s);
  }
};
} // namespace json
} // namespace dmlc
namespace tvm {
namespace auto_scheduler {
// Register the two record-related object node types with TVM's object type registry.
TVM_REGISTER_OBJECT_TYPE(RecordToFileNode);
TVM_REGISTER_OBJECT_TYPE(RecordReaderNode);
// Version tag stored under the "v" key of every record written by WriteMeasureRecords.
// Presumably bumped whenever the on-disk record format changes -- confirm with maintainers.
const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.2";  // NOLINT(*)
/*!
 * \brief Create a callback object that appends measure records to `filename`.
 * \param filename Path of the log file. The file is not opened here;
 *        RecordToFileNode::Callback opens it in append mode on each invocation.
 */
RecordToFile::RecordToFile(String filename) {
  auto n = make_object<RecordToFileNode>();
  n->filename = std::move(filename);
  data_ = std::move(n);
}
void WriteMeasureRecords(std::ostream* os, const Array<MeasureInput>& inputs,
const Array<MeasureResult>& results) {
dmlc::JSONWriter writer(os);
for (size_t i = 0; i < inputs.size(); ++i) {
writer.BeginObject(false);
writer.WriteObjectKeyValue("i", *inputs[i].operator->());
writer.WriteObjectKeyValue("r", *results[i].operator->());
writer.WriteObjectKeyValue("v", AUTO_SCHEDULER_LOG_VERSION);
writer.EndObject();
*os << "\n";
}
}
void ReadMeasureRecord(const std::string& str, MeasureInputNode* inp, MeasureResultNode* res,
std::string* log_version) {
std::istringstream ss(str);
dmlc::JSONReader reader(&ss);
std::string key;
reader.BeginObject();
while (reader.NextObjectItem(&key)) {
if (key == "i") {
reader.Read(inp);
} else if (key == "r") {
reader.Read(res);
} else if (key == "v") {
reader.Read(log_version);
} else {
LOG(FATAL) << "Invalid key in json log: " << key;
}
}
}
/*!
 * \brief Measurement callback: append this batch of records to `filename`.
 *
 * The search policy argument is unused. The file is opened in append mode on every
 * call and closed again when the stream goes out of scope at the end of the call.
 */
void RecordToFileNode::Callback(const SearchPolicy& policy, const Array<MeasureInput>& inputs,
                                const Array<MeasureResult>& results) {
  std::ofstream out_file;
  out_file.open(filename, std::ofstream::app);
  WriteMeasureRecords(&out_file, inputs, results);
}
/*!
 * \brief Create a reader over the record file `filename`.
 *
 * The file is opened for reading immediately. Success is not checked here; if the
 * open fails, the stream is left in a failed state and ReadNext simply returns false.
 */
RecordReader::RecordReader(String filename) {
  auto reader_node = make_object<RecordReaderNode>();
  reader_node->filename = filename;
  reader_node->infile.open(filename, std::ifstream::in);
  data_ = std::move(reader_node);
}
/*! \brief Close the underlying record file. */
RecordReaderNode::~RecordReaderNode() {
  // std::ifstream would also close on destruction; the explicit call makes the
  // release point obvious.
  infile.close();
}
/*!
 * \brief Read the next record from the file into `inp` and `res`.
 *
 * Blank lines (empty or whitespace-only) and lines beginning with '#' or ' ' are
 * skipped. The record's version tag is parsed but not currently checked.
 * \param inp Receives the measure input of the record.
 * \param res Receives the measure result of the record.
 * \return true if a record was read; false once no more lines can be read
 *         (end of file, or a stream that failed to open).
 */
bool RecordReaderNode::ReadNext(MeasureInputNode* inp, MeasureResultNode* res) {
  std::string log_version;
  while (std::getline(infile, cur_line_)) {
    // Skip blank lines, including a lone '\r' left by CRLF line endings. They used to
    // reach the JSON parser, which fails on input that contains no object. This check
    // also guarantees cur_line_ is non-empty before cur_line_[0] is read below.
    if (cur_line_.find_first_not_of(" \t\r") == std::string::npos) {
      continue;
    }
    // Skip comment lines that begin with '#' or ' '.
    if (cur_line_[0] == '#' || cur_line_[0] == ' ') {
      continue;
    }
    ReadMeasureRecord(cur_line_, inp, res, &log_version);
    return true;
  }
  return false;
}
/*!
 * \brief Read many records at once.
 * \param max_size Maximum number of records to collect; a value <= 0 means no limit.
 * \param skip_size Number of leading records to read and discard first.
 * \return A pair (inputs, results) whose entries are index-aligned.
 */
std::pair<Array<MeasureInput>, Array<MeasureResult>> RecordReaderNode::ReadLines(int max_size,
                                                                                 int skip_size) {
  // One pair of scratch nodes is reused for every record; each collected record is
  // copied out so subsequent reads do not overwrite it.
  auto scratch_inp = make_object<MeasureInputNode>();
  auto scratch_res = make_object<MeasureResultNode>();
  Array<MeasureInput> inputs;
  Array<MeasureResult> results;
  const bool unlimited = max_size <= 0;
  while (ReadNext(scratch_inp.get(), scratch_res.get())) {
    if (skip_size > 0) {
      --skip_size;
      continue;
    }
    inputs.push_back(scratch_inp->copy());
    results.push_back(scratch_res->copy());
    if (!unlimited && static_cast<int>(inputs.size()) >= max_size) {
      break;
    }
  }
  return std::make_pair(inputs, results);
}
// FFI-visible factories for the record-writing callback and the record reader.
TVM_REGISTER_GLOBAL("auto_scheduler.RecordToFile")
    .set_body_typed([](const String& filename) -> RecordToFile { return RecordToFile(filename); });
TVM_REGISTER_GLOBAL("auto_scheduler.RecordReader")
    .set_body_typed([](const String& filename) -> RecordReader { return RecordReader(filename); });
// Bulk read exposed over FFI: returns the two-element array [inputs, results].
TVM_REGISTER_GLOBAL("auto_scheduler.RecordReaderReadLines")
    .set_body_typed([](RecordReader reader, int size, int skip_size) {
      auto lines = reader->ReadLines(size, skip_size);
      return Array<ObjectRef>{lines.first, lines.second};
    });
// Single-record read exposed over FFI: returns [input, result], or an empty array
// when no further record could be read.
TVM_REGISTER_GLOBAL("auto_scheduler.RecordReaderReadNext").set_body_typed([](RecordReader reader) {
  auto inp = make_object<MeasureInputNode>();
  auto res = make_object<MeasureResultNode>();
  if (!reader->ReadNext(inp.get(), res.get())) {
    return Array<ObjectRef>();
  }
  return Array<ObjectRef>{ObjectRef(inp), ObjectRef(res)};
});
// Append the given (input, result) pairs to `filename` as JSON lines.
TVM_REGISTER_GLOBAL("auto_scheduler.SaveRecords")
    .set_body_typed([](String filename, Array<MeasureInput> in, Array<MeasureResult> res) {
      std::ofstream out_file;
      out_file.open(filename, std::ofstream::app);
      WriteMeasureRecords(&out_file, in, res);
    });
} // namespace auto_scheduler
} // namespace tvm