blob: 4726b09dd2a9fcefdc52fee83464416686c0458b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file file_utils.cc
*/
#include "file_utils.h"
#include <tvm/ffi/extra/json.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/logging.h>
#include <tvm/support/io.h>
#include <fstream>
#include <utility>
#include <vector>
#include "../support/bytes_io.h"
namespace tvm {
namespace runtime {
/*!
 * \brief Build a FunctionInfo handle from its individual fields.
 * \param name Name of the function.
 * \param arg_types Data type of each function argument.
 * \param launch_param_tags Tags stored in FunctionInfoObj::launch_param_tags.
 * \param arg_extra_tags Per-argument extra tags.
 */
FunctionInfo::FunctionInfo(ffi::String name, ffi::Array<DLDataType> arg_types,
                           ffi::Array<ffi::String> launch_param_tags,
                           ffi::Array<ArgExtraTags> arg_extra_tags) {
  auto node = ffi::make_object<FunctionInfoObj>();
  // Each field is moved in from the by-value parameter; the order is irrelevant.
  node->arg_extra_tags = std::move(arg_extra_tags);
  node->launch_param_tags = std::move(launch_param_tags);
  node->arg_types = std::move(arg_types);
  node->name = std::move(name);
  data_ = std::move(node);
}
/*!
 * \brief Serialize this FunctionInfoObj into a JSON object.
 *
 * Keys written (in this order): "name", "arg_types" (each dtype rendered via
 * DLDataTypeToString), "launch_param_tags", and "arg_extra_tags" (each tag
 * stored as an int64).
 *
 * \return The JSON object wrapped as a json::Value.
 */
ffi::json::Value FunctionInfoObj::SaveToJSON() const {
  namespace json = ::tvm::ffi::json;

  json::Array dtype_strings;
  for (const auto& dtype : arg_types) {
    dtype_strings.push_back(ffi::String(DLDataTypeToString(dtype)));
  }

  json::Array launch_tags;
  for (const auto& tag : launch_param_tags) {
    launch_tags.push_back(tag);
  }

  // Enum values are persisted as plain integers.
  json::Array extra_tag_ints;
  for (const auto& tag : arg_extra_tags) {
    extra_tag_ints.push_back(static_cast<int64_t>(tag));
  }

  json::Object result;
  result.Set("name", name);
  result.Set("arg_types", std::move(dtype_strings));
  result.Set("launch_param_tags", std::move(launch_tags));
  result.Set("arg_extra_tags", std::move(extra_tag_ints));
  return result;
}
/*!
 * \brief Populate this FunctionInfoObj from a JSON object.
 *
 * "name" and "arg_types" are required; `at()` is used for them.
 * "launch_param_tags" is optional. When it is absent, the legacy key
 * "thread_axis_tags" is consulted instead. "arg_extra_tags" is optional.
 * Optional fields that are missing leave the corresponding member unchanged.
 *
 * \param src JSON object, presumably produced by SaveToJSON.
 */
void FunctionInfoObj::LoadFromJSON(ffi::json::Object src) {
  namespace json = ::tvm::ffi::json;
  name = src.at("name").cast<ffi::String>();

  // Dtypes are stored as strings and parsed back into DLDataType.
  auto dtype_strings = src.at("arg_types").cast<json::Array>();
  arg_types = ffi::Array<DLDataType>();
  for (const auto& entry : dtype_strings) {
    arg_types.push_back(StringToDLDataType(std::string(entry.cast<ffi::String>())));
  }

  // Prefer the current key; fall back to the legacy "thread_axis_tags" key.
  auto tags_it = src.find("launch_param_tags");
  if (tags_it == src.end()) {
    tags_it = src.find("thread_axis_tags");
  }
  if (tags_it != src.end()) {
    auto tag_list = (*tags_it).second.cast<json::Array>();
    launch_param_tags = ffi::Array<ffi::String>();
    for (const auto& tag : tag_list) {
      launch_param_tags.push_back(tag.cast<ffi::String>());
    }
  }

  auto extra_it = src.find("arg_extra_tags");
  if (extra_it != src.end()) {
    auto extra_list = (*extra_it).second.cast<json::Array>();
    arg_extra_tags = ffi::Array<ArgExtraTags>();
    for (const auto& tag : extra_list) {
      arg_extra_tags.push_back(static_cast<ArgExtraTags>(tag.cast<int64_t>()));
    }
  }
}
/*!
 * \brief Determine the format of a file.
 * \param file_name Path of the file.
 * \param format Explicit format. When non-empty, it is returned unchanged.
 * \return \p format if it is non-empty. Otherwise, the text after the last '.'
 *         in the final path component of \p file_name, or "" if that component
 *         contains no '.'.
 */
std::string GetFileFormat(const std::string& file_name, const std::string& format) {
  if (!format.empty()) return format;
  // Only a '.' inside the last path component can start an extension. Dots in
  // directory names (e.g. "/tmp/build.d/lib" or "./lib") must not be used.
  const size_t slash = file_name.find_last_of('/');
  const size_t dot = file_name.find_last_of('.');
  if (dot == std::string::npos || (slash != std::string::npos && dot < slash)) {
    return "";
  }
  return file_name.substr(dot + 1);
}
/*!
 * \brief Resolve the directory TVM uses for its on-disk cache.
 *
 * Lookup order:
 * - $TVM_CACHE_DIR, returned verbatim;
 * - $XDG_CACHE_HOME + "/tvm";
 * - $HOME + "/.cache/tvm";
 * - otherwise ".".
 *
 * A variable counts as set whenever getenv returns non-null, even if its
 * value is empty.
 */
std::string GetCacheDir() {
  if (const char* tvm_dir = getenv("TVM_CACHE_DIR")) {
    return tvm_dir;
  }
  if (const char* xdg_dir = getenv("XDG_CACHE_HOME")) {
    return std::string(xdg_dir) + "/tvm";
  }
  if (const char* home_dir = getenv("HOME")) {
    return std::string(home_dir) + "/.cache/tvm";
  }
  return ".";
}
/*!
 * \brief Return the final '/'-separated component of a path.
 * \param file_name Path to inspect.
 * \return Everything after the last '/', or \p file_name itself if there is no '/'.
 */
std::string GetFileBasename(const std::string& file_name) {
  const size_t sep = file_name.rfind('/');
  return sep == std::string::npos ? file_name : file_name.substr(sep + 1);
}
/*!
 * \brief Compute the path of the ".tvm_meta.json" file that accompanies \p file_name.
 *
 * If the final path component has an extension, that extension (from its
 * last '.') is replaced. Otherwise the suffix is appended.
 * Examples: "dir/lib.so" -> "dir/lib.tvm_meta.json",
 * "dir.d/lib" -> "dir.d/lib.tvm_meta.json".
 *
 * \param file_name Path of the primary file.
 * \return Path of the metadata file.
 */
std::string GetMetaFilePath(const std::string& file_name) {
  // Ignore dots in directory components. Otherwise "a.b/c" would map to
  // "a.tvm_meta.json" and "./lib" to ".tvm_meta.json".
  const size_t slash = file_name.find_last_of('/');
  const size_t dot = file_name.find_last_of('.');
  const bool has_ext =
      dot != std::string::npos && (slash == std::string::npos || dot > slash);
  return (has_ext ? file_name.substr(0, dot) : file_name) + ".tvm_meta.json";
}
/*!
 * \brief Read the entire contents of a file, in binary mode, into \p data.
 * \param file_name Path of the file to read.
 * \param data Output string. It is resized to the file size and overwritten.
 * \note Fails a TVM_FFI_ICHECK in any of these cases: the file cannot be
 *       opened, its size cannot be determined, or fewer bytes than expected
 *       are read.
 */
void LoadBinaryFromFile(const std::string& file_name, std::string* data) {
  std::ifstream fs(file_name, std::ios::in | std::ios::binary);
  TVM_FFI_ICHECK(!fs.fail()) << "Cannot open " << file_name;
  // Determine the file size by seeking to the end.
  fs.seekg(0, std::ios::end);
  const std::streamoff end_pos = fs.tellg();
  // tellg() reports failure as -1, which would otherwise wrap to SIZE_MAX.
  TVM_FFI_ICHECK(end_pos >= 0) << "Cannot determine size of " << file_name;
  fs.seekg(0, std::ios::beg);
  const size_t size = static_cast<size_t>(end_pos);
  data->resize(size);
  fs.read(&(*data)[0], static_cast<std::streamsize>(size));
  // A short read would leave the tail of `data` zero-filled without any error.
  TVM_FFI_ICHECK(static_cast<size_t>(fs.gcount()) == size) << "Failed to read " << file_name;
}
/*!
 * \brief Write \p data to \p file_name in binary mode, truncating any existing file.
 * \param file_name Destination path.
 * \param data Bytes to write.
 * \note Fails a TVM_FFI_ICHECK if the file cannot be opened or the write fails.
 */
void SaveBinaryToFile(const std::string& file_name, const std::string& data) {
  std::ofstream fs(file_name, std::ios::out | std::ios::binary);
  TVM_FFI_ICHECK(!fs.fail()) << "Cannot open " << file_name;
  fs.write(data.data(), static_cast<std::streamsize>(data.size()));
  // Close explicitly so that errors while flushing buffered output (e.g. a full
  // disk) are caught instead of leaving a silently truncated file.
  fs.close();
  TVM_FFI_ICHECK(!fs.fail()) << "Failed to write " << file_name;
}
/*!
 * \brief Write function metadata to a JSON file.
 *
 * The file contains {"tvm_version": "0.1.0", "func_info": {<name>: <info>}},
 * where each <info> comes from FunctionInfoObj::SaveToJSON. The output is
 * pretty-printed with 2-space indentation.
 *
 * \param file_name Destination path.
 * \param fmap Map from function name to its FunctionInfo.
 * \note Fails a TVM_FFI_ICHECK if the file cannot be opened or written.
 */
void SaveMetaDataToFile(const std::string& file_name,
                        const ffi::Map<ffi::String, FunctionInfo>& fmap) {
  namespace json = ::tvm::ffi::json;
  json::Object root;
  root.Set("tvm_version", ffi::String("0.1.0"));
  json::Object func_info;
  for (const auto& kv : fmap) {
    func_info.Set(kv.first, kv.second->SaveToJSON());
  }
  root.Set("func_info", std::move(func_info));
  std::ofstream fs(file_name.c_str());
  TVM_FFI_ICHECK(!fs.fail()) << "Cannot open file " << file_name;
  fs << std::string(json::Stringify(root, 2));
  // Close before checking so errors while flushing buffered output are caught.
  fs.close();
  TVM_FFI_ICHECK(!fs.fail()) << "Failed to write file " << file_name;
}
/*!
 * \brief Read function metadata from a JSON file written by SaveMetaDataToFile.
 *
 * Each entry of the top-level "func_info" object is decoded with
 * FunctionInfoObj::LoadFromJSON and inserted into \p fmap under its key.
 * The "tvm_version" entry is not consulted.
 *
 * \param file_name Path of the JSON file.
 * \param fmap Output map. Entries are added or overwritten, not cleared first.
 */
void LoadMetaDataFromFile(const std::string& file_name, ffi::Map<ffi::String, FunctionInfo>* fmap) {
  namespace json = ::tvm::ffi::json;
  std::string text;
  {
    std::ifstream in(file_name.c_str());
    TVM_FFI_ICHECK(!in.fail()) << "Cannot open file " << file_name;
    text.assign(std::istreambuf_iterator<char>(in), std::istreambuf_iterator<char>());
  }
  auto document = json::Parse(text).cast<json::Object>();
  auto entries = document.at("func_info").cast<json::Object>();
  for (const auto& entry : entries) {
    auto node = ffi::make_object<FunctionInfoObj>();
    node->LoadFromJSON(entry.second.cast<json::Object>());
    fmap->Set(entry.first.cast<ffi::String>(), FunctionInfo(std::move(node)));
  }
}
/*!
 * \brief Delete a file on a best-effort basis.
 * \param file_name Path of the file to delete.
 * \note The result of std::remove is discarded, so a missing file or a failed
 *       deletion is not reported to the caller.
 *       FIXME: This doesn't check the return code.
 */
void RemoveFile(const std::string& file_name) {
  static_cast<void>(std::remove(file_name.c_str()));
}
/*!
 * \brief Copy \p src_file_name to \p dest_file_name byte-for-byte.
 *
 * The destination is opened in binary mode and truncated if it already exists.
 *
 * \param src_file_name Path of the file to copy.
 * \param dest_file_name Path of the copy.
 * \note Fails a TVM_FFI_ICHECK if either file cannot be opened or the copy fails.
 */
void CopyFile(const std::string& src_file_name, const std::string& dest_file_name) {
  std::ifstream src(src_file_name, std::ios::binary);
  TVM_FFI_ICHECK(src) << "Unable to open source file '" << src_file_name << "'";
  std::ofstream dest(dest_file_name, std::ios::binary | std::ios::trunc);
  TVM_FFI_ICHECK(dest) << "Unable to open destination file '" << dest_file_name << "'";
  // `ostream << streambuf*` sets failbit on `dest` when it inserts no
  // characters. Without this guard, copying an empty source would trip the
  // final check even though the copy succeeded.
  if (src.peek() != std::ifstream::traits_type::eof()) {
    dest << src.rdbuf();
  }
  src.close();
  dest.close();
  TVM_FFI_ICHECK(dest) << "File-copy operation failed."
                       << " src='" << src_file_name << "'"
                       << " dest='" << dest_file_name << "'";
}
/*!
 * \brief Deserialize a named tensor map from an in-memory blob.
 * \param param_blob Bytes in the format produced by SaveParams.
 * \return Map from parameter name to tensor, as decoded by LoadParams(support::Stream*).
 */
ffi::Map<ffi::String, Tensor> LoadParams(const std::string& param_blob) {
  support::BytesInStream reader(param_blob);
  return LoadParams(&reader);
}
/*!
 * \brief Deserialize a named tensor map from a stream.
 *
 * Layout read, in order:
 * - a uint64 header, which must equal kTVMTensorListMagic;
 * - a uint64 reserved word;
 * - a vector of names;
 * - a uint64 tensor count, which must equal the number of names;
 * - that many tensors, paired with the names by position.
 *
 * \param strm Stream to read from.
 * \return Map from parameter name to tensor.
 * \note Fails a TVM_FFI_ICHECK with "Invalid parameters file format" on any
 *       malformed or truncated header field.
 */
ffi::Map<ffi::String, Tensor> LoadParams(support::Stream* strm) {
  ffi::Map<ffi::String, Tensor> params;
  uint64_t header = 0;
  uint64_t reserved = 0;
  TVM_FFI_ICHECK(strm->Read(&header)) << "Invalid parameters file format";
  TVM_FFI_ICHECK(header == kTVMTensorListMagic) << "Invalid parameters file format";
  TVM_FFI_ICHECK(strm->Read(&reserved)) << "Invalid parameters file format";
  std::vector<std::string> names;
  TVM_FFI_ICHECK(strm->Read(&names)) << "Invalid parameters file format";
  uint64_t sz = 0;
  // This read was previously unchecked. A truncated stream would have left
  // `sz` uninitialized.
  TVM_FFI_ICHECK(strm->Read(&sz)) << "Invalid parameters file format";
  // Compare in 64 bits so a large count cannot be truncated on 32-bit hosts.
  TVM_FFI_ICHECK(sz == static_cast<uint64_t>(names.size())) << "Invalid parameters file format";
  for (const std::string& name : names) {
    // The data_entry is allocated on device; Tensor.Load always loads the array into CPU.
    Tensor temp;
    temp.Load(strm);
    params.Set(name, temp);
  }
  return params;
}
/*!
 * \brief Serialize a named tensor map to a stream.
 *
 * Writes, in order:
 * - the kTVMTensorListMagic header;
 * - a zero reserved word;
 * - the vector of names;
 * - the tensor count as uint64;
 * - each tensor via SaveDLTensor, in the same order as the names.
 *
 * \param strm Stream to write to.
 * \param params Map from parameter name to tensor.
 */
void SaveParams(support::Stream* strm, const ffi::Map<ffi::String, Tensor>& params) {
  std::vector<std::string> names;
  std::vector<const DLTensor*> arrays;
  // Split the map into parallel name/tensor lists so both share one order.
  for (const auto& entry : params) {
    names.push_back(entry.first);
    arrays.push_back(entry.second.operator->());
  }

  const uint64_t header = kTVMTensorListMagic;
  const uint64_t reserved = 0;
  strm->Write(header);
  strm->Write(reserved);
  strm->Write(names);

  const uint64_t count = static_cast<uint64_t>(arrays.size());
  strm->Write(count);
  for (const DLTensor* tensor : arrays) {
    tvm::runtime::SaveDLTensor(strm, tensor);
  }
}
/*!
 * \brief Serialize a named tensor map into an in-memory byte string.
 * \param params Map from parameter name to tensor.
 * \return The bytes written by SaveParams(support::Stream*, params).
 */
std::string SaveParams(const ffi::Map<ffi::String, Tensor>& params) {
  std::string blob;
  support::BytesOutStream writer(&blob);
  SaveParams(&writer, params);
  return blob;
}
/*!
 * \brief Static registration for this translation unit.
 *
 * Registers reflection metadata for FunctionInfoObj and the global
 * "runtime.*" functions for (de)serializing parameter maps.
 */
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  // Expose the four FunctionInfoObj fields through reflection via def_ro
  // (by name, presumably as read-only accessors).
  refl::ObjectDef<FunctionInfoObj>()
      .def_ro("name", &FunctionInfoObj::name)
      .def_ro("arg_types", &FunctionInfoObj::arg_types)
      .def_ro("launch_param_tags", &FunctionInfoObj::launch_param_tags)
      .def_ro("arg_extra_tags", &FunctionInfoObj::arg_extra_tags);
  refl::GlobalDef()
      // Serialize a parameter map in memory and return the blob as ffi::Bytes.
      .def("runtime.SaveParams",
           [](const ffi::Map<ffi::String, Tensor>& params) {
             std::string s = ::tvm::runtime::SaveParams(params);
             return ffi::Bytes(std::move(s));
           })
      // Serialize a parameter map directly into a file opened with mode "wb".
      .def("runtime.SaveParamsToFile",
           [](const ffi::Map<ffi::String, Tensor>& params, const ffi::String& path) {
             tvm::runtime::SimpleBinaryFileStream strm(path, "wb");
             SaveParams(&strm, params);
           })
      // Deserialize a parameter map from an in-memory blob.
      .def("runtime.LoadParams", [](const ffi::Bytes& s) { return ::tvm::runtime::LoadParams(s); })
      // Deserialize a parameter map from a file opened with mode "rb".
      .def("runtime.LoadParamsFromFile", [](const ffi::String& path) {
        tvm::runtime::SimpleBinaryFileStream strm(path, "rb");
        return LoadParams(&strm);
      });
}
} // namespace runtime
} // namespace tvm