blob: 69c4bb1a29e50de34a9eb7ca4a40e367e13e6aa0 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Defines functions to work with TVMModule FuncRegistry."""
import json
def graph_json_to_c_func_registry(graph_path, func_registry_path):
"""Convert a graph json file to a CRT-compatible FuncRegistry.
Parameters
----------
graph_path : str
Path to the graph JSON file.
func_registry_path : str
Path to a .c file which will be written containing the function registry.
"""
with open(graph_path) as json_f:
graph = json.load(json_f)
funcs = []
for n in graph["nodes"]:
if n["op"] != "tvm_op":
continue
funcs.append(n["attrs"]["func_name"])
encoded_funcs = f"\\{len(funcs):03o}" + "\\0".join(funcs)
lines = [
"#include <tvm/runtime/c_runtime_api.h>",
"#include <tvm/runtime/crt/module.h>",
"#include <stdio.h>",
"",
]
for f in funcs:
lines.append(
f"extern int {f}(TVMValue* args, int* type_codes, int num_args, "
"TVMValue* out_ret_value, int* out_ret_tcode, void* resource_handle);"
)
lines.append("static TVMBackendPackedCFunc funcs[] = {")
for f in funcs:
lines.append(f" (TVMBackendPackedCFunc) &{f},")
lines += [
"};",
"static const TVMFuncRegistry system_lib_registry = {",
f' "{encoded_funcs}\\0",',
" funcs,",
"};",
"static const TVMModule system_lib = {",
" &system_lib_registry,",
"};",
"",
"const TVMModule* TVMSystemLibEntryPoint(void) {",
" return &system_lib;",
"}",
"", # blank line to end the file
]
with open(func_registry_path, "w") as wrapper_f:
wrapper_f.write("\n".join(lines))