blob: 56facea1567f2993ca23d2f98ab6f15dda820915 [file] [log] [blame]
/*!
* Copyright (c) 2017 by Contributors
* \file source_module.cc
* \brief Source code modules: they hold generated code for viewing (and, for some kinds, saving) but cannot execute it
*/
#include <tvm/runtime/packed_func.h>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include "codegen_source_base.h"
#include "../runtime/file_util.h"
#include "../runtime/meta_data.h"
namespace tvm {
namespace codegen {
using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::PackedFunc;
using runtime::GetFileFormat;
using runtime::GetMetaFilePath;
using runtime::FunctionInfo;
using runtime::SaveBinaryToFile;
// Module that only holds generated source code for viewing; it cannot execute functions.
class SourceModuleNode : public runtime::ModuleNode {
public:
SourceModuleNode(std::string code,
std::string fmt)
: code_(code), fmt_(fmt) {}
const char* type_key() const {
return "source";
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
LOG(FATAL) << "Source module cannot execute, to get executable module"
<< " build TVM with \'" << fmt_ << "\' runtime support";
return PackedFunc();
}
std::string GetSource(const std::string& format) final {
return code_;
}
protected:
std::string code_;
std::string fmt_;
};
/*!
 * \brief Create a view-only source module.
 * \param code The source code.
 * \param fmt The format of the code.
 * \return A module wrapping a SourceModuleNode.
 */
runtime::Module SourceModuleCreate(std::string code, std::string fmt) {
  // Arguments are owned copies; move them into the node instead of copying again.
  std::shared_ptr<SourceModuleNode> n =
      std::make_shared<SourceModuleNode>(std::move(code), std::move(fmt));
  return runtime::Module(n);
}
// Module that holds generated C source code; it cannot execute functions, but the code can be saved to a file.
class CSourceModuleNode : public runtime::ModuleNode {
public:
CSourceModuleNode(std::string code,
std::string fmt)
: code_(code), fmt_(fmt) {}
const char* type_key() const {
return "c";
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
LOG(FATAL) << "C Source module cannot execute, to get executable module"
<< " build TVM with \'" << fmt_ << "\' runtime support";
return PackedFunc();
}
std::string GetSource(const std::string& format) final {
return code_;
}
void SaveToFile(const std::string& file_name,
const std::string& format) final {
std::string fmt = GetFileFormat(file_name, format);
std::string meta_file = GetMetaFilePath(file_name);
if (fmt == "cc") {
CHECK_NE(code_.length(), 0);
SaveBinaryToFile(file_name, code_);
} else {
CHECK_EQ(fmt, fmt_)
<< "Can only save to format=" << fmt_;
}
}
protected:
std::string code_;
std::string fmt_;
};
/*!
 * \brief Create a module holding C source code.
 * \param code The C source code.
 * \param fmt The format of the module.
 * \return A module wrapping a CSourceModuleNode.
 */
runtime::Module CSourceModuleCreate(std::string code, std::string fmt) {
  // Arguments are owned copies; move them into the node instead of copying again.
  std::shared_ptr<CSourceModuleNode> n =
      std::make_shared<CSourceModuleNode>(std::move(code), std::move(fmt));
  return runtime::Module(n);
}
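// Device source module: cannot execute functions, but supports limited saving (to its own format, with function metadata) without cross compilation.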
class DeviceSourceModuleNode final : public runtime::ModuleNode {
public:
DeviceSourceModuleNode(std::string data,
std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string type_key,
std::function<std::string(const std::string&)> fget_source)
: data_(data),
fmt_(fmt),
fmap_(fmap),
type_key_(type_key),
fget_source_(fget_source) {}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
LOG(FATAL) << "Source module cannot execute, to get executable module"
<< " build TVM with \'" << fmt_ << "\' runtime support";
return PackedFunc();
}
std::string GetSource(const std::string& format) final {
if (fget_source_ != nullptr) {
return fget_source_(format);
} else {
return data_;
}
}
const char* type_key() const {
return type_key_.c_str();
}
void SaveToFile(const std::string& file_name,
const std::string& format) final {
std::string fmt = GetFileFormat(file_name, format);
CHECK_EQ(fmt, fmt_)
<< "Can only save to format=" << fmt_;
std::string meta_file = GetMetaFilePath(file_name);
SaveMetaDataToFile(meta_file, fmap_);
SaveBinaryToFile(file_name, data_);
}
void SaveToBinary(dmlc::Stream* stream) final {
stream->Write(fmt_);
stream->Write(fmap_);
stream->Write(data_);
}
private:
std::string data_;
std::string fmt_;
std::unordered_map<std::string, FunctionInfo> fmap_;
std::string type_key_;
std::function<std::string(const std::string&)> fget_source_;
};
/*!
 * \brief Create a device source module.
 * \param data The module payload.
 * \param fmt The payload format.
 * \param fmap Function information saved as metadata.
 * \param type_key The module type key.
 * \param fget_source Optional callback used by GetSource; may be null.
 * \return A module wrapping a DeviceSourceModuleNode.
 */
runtime::Module DeviceSourceModuleCreate(
    std::string data,
    std::string fmt,
    std::unordered_map<std::string, FunctionInfo> fmap,
    std::string type_key,
    std::function<std::string(const std::string&)> fget_source) {
  // Arguments are owned copies; move them into the node. Moving avoids
  // duplicating the function map and the std::function target.
  std::shared_ptr<DeviceSourceModuleNode> n =
      std::make_shared<DeviceSourceModuleNode>(std::move(data),
                                               std::move(fmt),
                                               std::move(fmap),
                                               std::move(type_key),
                                               std::move(fget_source));
  return runtime::Module(n);
}
// Global "module.source_module_create"(code, fmt): build a view-only source module.
TVM_REGISTER_GLOBAL("module.source_module_create")
.set_body([](TVMArgs call_args, TVMRetValue* ret) {
    const std::string code = call_args[0];
    const std::string fmt = call_args[1];
    *ret = SourceModuleCreate(code, fmt);
  });
} // namespace codegen
} // namespace tvm