blob: af1ef7225bcbb90c6c1320eff95d722bc60fef69 [file] [log] [blame]
#include <memory>
#include <string>
#include <utility>

#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/registry.h>
extern unsigned char build_graph_json[];
extern unsigned int build_graph_json_len;
extern unsigned char build_params_bin[];
extern unsigned int build_params_bin_len;
#define TVM_BUNDLE_FUNCTION __attribute__((visibility("default"))) extern "C"
/*!
 * \brief Instantiate a graph runtime module from the graph JSON and the
 *        parameter blob that are linked into this bundle.
 *
 * Looks up "module._GetSystemLib" and "tvm.graph_runtime.create" in the TVM
 * global registry, creates the graph on device type kDLCPU / device id 0
 * using the system library, then invokes the module's "load_params" function
 * with the embedded parameter bytes.
 *
 * \return An opaque handle that owns a heap-allocated tvm::runtime::Module.
 *         Release it with tvm_runtime_destroy(). Returns nullptr if either
 *         required registry function is not registered.
 */
TVM_BUNDLE_FUNCTION void *tvm_runtime_create() {
  // Registry::Get yields nullptr for an unregistered name; dereferencing it
  // unchecked would crash rather than report failure across the C ABI.
  const auto *get_system_lib =
      tvm::runtime::Registry::Get("module._GetSystemLib");
  const auto *graph_runtime_create =
      tvm::runtime::Registry::Get("tvm.graph_runtime.create");
  if (get_system_lib == nullptr || graph_runtime_create == nullptr) {
    return nullptr;
  }

  // The graph JSON is embedded as a raw byte array with a separate length.
  const std::string json_data(&build_graph_json[0],
                              &build_graph_json[0] + build_graph_json_len);

  tvm::runtime::Module mod_syslib = (*get_system_lib)();
  const int device_type = kDLCPU;
  const int device_id = 0;
  tvm::runtime::Module mod =
      (*graph_runtime_create)(json_data, mod_syslib, device_type, device_id);

  // Hand the embedded parameter blob to the module's "load_params" function.
  TVMByteArray params;
  params.data = reinterpret_cast<const char *>(&build_params_bin[0]);
  params.size = build_params_bin_len;
  mod.GetFunction("load_params")(params);

  // Move into heap storage so the handle outlives this call; ownership
  // passes to the caller (see tvm_runtime_destroy).
  return new tvm::runtime::Module(std::move(mod));
}
/*!
 * \brief Free a handle obtained from tvm_runtime_create().
 * \param handle Opaque pointer to a heap-allocated tvm::runtime::Module.
 *        A null handle is accepted (deleting a null pointer does nothing).
 */
TVM_BUNDLE_FUNCTION void tvm_runtime_destroy(void *handle) {
  auto *owned_module = static_cast<tvm::runtime::Module *>(handle);
  delete owned_module;
}
/*!
 * \brief Forward an input tensor to the module's "set_input" function.
 * \param handle Handle returned by tvm_runtime_create() (must not be null).
 * \param name   Input name passed through unchanged to "set_input".
 * \param tensor Pointer to a DLTensor, passed through as the input value.
 *        NOTE(review): whether the runtime copies or aliases this tensor is
 *        decided by the graph runtime, not here; confirm before reusing it.
 */
TVM_BUNDLE_FUNCTION void tvm_runtime_set_input(void *handle, const char *name,
                                               void *tensor) {
  auto *module = static_cast<tvm::runtime::Module *>(handle);
  auto *input_tensor = static_cast<DLTensor *>(tensor);
  auto set_input_fn = module->GetFunction("set_input");
  set_input_fn(name, input_tensor);
}
/*!
 * \brief Invoke the module's "run" function with no arguments.
 * \param handle Handle returned by tvm_runtime_create() (must not be null).
 */
TVM_BUNDLE_FUNCTION void tvm_runtime_run(void *handle) {
  auto *module = static_cast<tvm::runtime::Module *>(handle);
  auto run_fn = module->GetFunction("run");
  run_fn();
}
/*!
 * \brief Call the module's "get_output" function for a given output index.
 * \param handle Handle returned by tvm_runtime_create() (must not be null).
 * \param index  Output index passed through unchanged to "get_output".
 * \param tensor Pointer to a caller-provided DLTensor passed as the
 *        destination argument. Presumably the runtime fills it with output
 *        `index`; confirm against the graph runtime's get_output contract.
 */
TVM_BUNDLE_FUNCTION void tvm_runtime_get_output(void *handle, int index,
                                                void *tensor) {
  auto *module = static_cast<tvm::runtime::Module *>(handle);
  auto *output_tensor = static_cast<DLTensor *>(tensor);
  auto get_output_fn = module->GetFunction("get_output");
  get_output_fn(index, output_tensor);
}