blob: 8ab4160306ff8ea99869a3c135bebf7c81f93041 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
// LINT_C_FILE
/*!
* \file aot_executor.c
* \brief implement AoT executor in C
*/
#include <inttypes.h>
#include <string.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/crt/aot_executor.h>
#include <tvm/runtime/crt/logging.h>
#include <tvm/runtime/crt/module.h>
#include <tvm/runtime/crt/packed_func.h>
#include <tvm/runtime/crt/page_allocator.h>
static void DumpMetadata(const TVMMetadata* md) {
LOG_DEBUG("%s:\n", __FUNCTION__);
LOG_DEBUG("\tmod_name=%s\n", md->mod_name);
LOG_DEBUG("\tversion=%" PRId64 "\n", md->version);
LOG_DEBUG("\tnum_inputs=%" PRId64 "\n", md->num_inputs);
LOG_DEBUG("\tnum_outputs=%" PRId64 "\n", md->num_outputs);
LOG_DEBUG("\tnum_workspace_pools=%" PRId64 "\n", md->num_workspace_pools);
LOG_DEBUG("\tnum_constant_pools=%" PRId64 "\n", md->num_constant_pools);
int i;
for (i = 0; i < md->num_inputs; ++i) {
LOG_DEBUG("\tinput[%d]: %s\n", i, md->inputs[i].name);
}
for (i = 0; i < md->num_outputs; ++i) {
LOG_DEBUG("\toutput[%d]: %s\n", i, md->outputs[i].name);
}
for (i = 0; i < md->num_workspace_pools; ++i) {
LOG_DEBUG("\tworkspace_pools[%d]: %s\n", i, md->workspace_pools[i].name);
}
for (i = 0; i < md->num_constant_pools; ++i) {
LOG_DEBUG("\tconstant_pools[%d]: %s\n", i, md->constant_pools[i].name_hint);
}
}
int TVMAotExecutor_GetNumInputs(TVMAotExecutor* executor) { return executor->metadata->num_inputs; }
int TVMAotExecutor_GetNumOutputs(TVMAotExecutor* executor) {
return executor->metadata->num_outputs;
}
int TVMAotExecutor_GetInputIndex(TVMAotExecutor* executor, const char* name) {
int i;
int rv = -1;
const TVMMetadata* md = executor->metadata;
for (i = 0; i < md->num_inputs; ++i) {
if (!strcmp(md->inputs[i].name, name)) {
rv = i;
break;
}
}
CHECK_GE(rv, 0, "cannot find '%s' among input.", name);
return rv;
}
int TVMAotExecutor_GetInputName(TVMAotExecutor* executor, int index, const char** name) {
const TVMMetadata* md = executor->metadata;
*name = md->inputs[index].name;
return 0;
}
int TVMAotExecutor_Run(TVMAotExecutor* executor) {
const char* tvm_main_suffix = "_run";
char tvm_main_name[TVM_CRT_MAX_STRLEN_FUNCTION_NAME];
{
const size_t max_strlen = TVM_CRT_MAX_STRLEN_FUNCTION_NAME;
size_t len = strnlen(executor->metadata->mod_name, max_strlen);
len += strnlen(tvm_main_suffix, max_strlen);
CHECK_LT(len, max_strlen, "tvm_main name too long %zu\n", len);
}
// create main function name string, e.g. "tvmgen_default___tvm_main__"
snprintf(tvm_main_name, sizeof(tvm_main_name), "%s%s", executor->metadata->mod_name,
tvm_main_suffix);
TVMPackedFunc tvm_main;
TVMArgs temp_args;
CHECK_LE(executor->num_args, TVM_CRT_MAX_ARGS, "too many args %" PRId64 "\n", executor->num_args);
int i;
for (i = 0; i < executor->num_args; ++i) {
temp_args.values[i].v_handle = &executor->args[i].dl_tensor;
temp_args.tcodes[i] = kTVMDLTensorHandle;
}
temp_args.values_count = executor->num_args;
int status =
TVMPackedFunc_InitModuleFunc(&tvm_main, executor->module_handle, tvm_main_name, &temp_args);
if (status != 0) {
return status;
}
CHECK_EQ(tvm_main.Call(&tvm_main), 0, "call to %s failed", tvm_main_name);
return 0;
}
int TVMAotExecutor_Init(TVMAotExecutor* executor, TVMModuleHandle module_handle,
const DLDevice device, const char* module_name) {
executor->module_handle = module_handle;
executor->device = device;
// get a pointer to the PackedFunc get_c_metadata() which gives us access to the top-level
// metadata structure
TVMPackedFunc get_c_metadata;
TVMArgs temp_args;
temp_args.values_count = 0;
const char* tvmgen_prefix = "tvmgen_";
const char* get_c_metdata_suffix = "_get_c_metadata";
char get_c_metdata_name[TVM_CRT_MAX_STRLEN_FUNCTION_NAME];
{
size_t max_strlen = TVM_CRT_MAX_STRLEN_FUNCTION_NAME;
size_t len = strnlen(tvmgen_prefix, max_strlen);
len += strnlen(module_name, max_strlen);
len += strnlen(get_c_metdata_suffix, max_strlen);
CHECK_LT(len, max_strlen, "get_c_metadata name too long %zu\n", len);
}
// create get_c_metadata() function name string, e.g. "tvmgen_default_get_c_metadata()"
snprintf(get_c_metdata_name, sizeof(get_c_metdata_name), "%s%s%s", tvmgen_prefix, module_name,
get_c_metdata_suffix);
int status = TVMPackedFunc_InitModuleFunc(&get_c_metadata, executor->module_handle,
get_c_metdata_name, &temp_args);
if (status != 0) {
return status;
}
CHECK_EQ(get_c_metadata.Call(&get_c_metadata), 0, "get_c_metadata");
// save the returned pointer to the top-level metadata
executor->metadata = (TVMMetadata*)get_c_metadata.ret_value.values[0].v_handle;
const TVMMetadata* md = executor->metadata;
DumpMetadata(md);
executor->num_args = md->num_inputs + md->num_outputs + md->num_workspace_pools;
tvm_crt_error_t err = TVMPlatformMemoryAllocate(executor->num_args * sizeof(*executor->args),
executor->device, (void**)(&executor->args));
if (err != kTvmErrorNoError) {
return -1;
}
int i;
int arg_idx = 0;
for (i = 0; i < md->num_inputs; ++i) {
LOG_DEBUG("input allocate[%d]: %s\n", i, md->inputs[i].name);
TVMNDArray* array = &executor->args[arg_idx++];
status = TVMNDArray_Empty(md->inputs[i].num_shape, md->inputs[i].shape, md->inputs[i].dtype,
executor->device, array);
if (status != 0) {
return status;
}
TVMNDArray_IncrementReference(array);
}
for (i = 0; i < md->num_outputs; ++i) {
LOG_DEBUG("output allocate[%d]: %s\n", i, md->outputs[i].name);
TVMNDArray* array = &executor->args[arg_idx++];
status = TVMNDArray_Empty(md->outputs[i].num_shape, md->outputs[i].shape, md->outputs[i].dtype,
executor->device, array);
if (status != 0) {
return status;
}
TVMNDArray_IncrementReference(array);
}
return status;
}
int TVMAotExecutor_Create(TVMModuleHandle module_handle, const DLDevice device,
TVMAotExecutor** executor, const char* module_name) {
tvm_crt_error_t err = TVMPlatformMemoryAllocate(sizeof(**executor), device, (void**)executor);
if (err != kTvmErrorNoError) {
return -1;
}
memset(*executor, 0, sizeof(**executor));
return TVMAotExecutor_Init(*executor, module_handle, device, module_name);
}
int TVMAotExecutor_Release(TVMAotExecutor* executor, const DLDevice device) {
int status;
if (executor->num_args > 0) {
// free TVMNDArray data memory for each argument
int i;
for (i = 0; i < executor->num_args; ++i) {
status = TVMNDArray_Release(&executor->args[i]);
if (status != 0) {
return status;
}
}
// free TVMNDArray argument list
status = TVMPlatformMemoryFree(executor->args, executor->device);
if (status != 0) {
return status;
}
}
status = TVMPlatformMemoryFree(executor, device);
if (status != 0) {
return status;
}
return 0;
}