[VTA] Recover rpc server support (#8604)
diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py
index 0b49b67..52a7a89 100644
--- a/python/tvm/rpc/server.py
+++ b/python/tvm/rpc/server.py
@@ -365,9 +365,14 @@
custom_addr=None,
silent=False,
no_fork=False,
+ server_init_callback=None,
):
if no_fork:
multiprocessing.set_start_method("spawn")
+
+ if server_init_callback:
+ server_init_callback()
+
# This is a function that will be sent to the
# Popen worker to run on a separate process.
# Create and start the server in a different thread
@@ -420,6 +425,25 @@
no_fork: bool, optional
Whether forbid fork in multiprocessing.
+
+ server_init_callback: Callable, optional
+ Additional initialization function when starting the server.
+
+ Note
+ ----
+ The RPC server only sees functions in the tvm namespace.
+ To bring additional custom functions to the server env, you can use server_init_callback.
+
+ .. code:: python
+
+ def server_init_callback():
+ import tvm
+ # must import mypackage here
+ import mypackage
+
+ tvm.register_func("function", mypackage.func)
+
+ server = rpc.Server(host, server_init_callback=server_init_callback)
"""
def __init__(
@@ -434,6 +458,7 @@
custom_addr=None,
silent=False,
no_fork=False,
+ server_init_callback=None,
):
try:
if _ffi_api.ServerLoop is None:
@@ -455,6 +480,7 @@
custom_addr,
silent,
no_fork,
+ server_init_callback,
],
)
# receive the port
diff --git a/vta/python/vta/exec/rpc_server.py b/vta/python/vta/exec/rpc_server.py
index b7a9c79..dcf564d 100644
--- a/vta/python/vta/exec/rpc_server.py
+++ b/vta/python/vta/exec/rpc_server.py
@@ -34,7 +34,6 @@
from ..libinfo import find_libvta
-@tvm.register_func("tvm.rpc.server.start", override=True)
def server_start():
"""VTA RPC server extension."""
# pylint: disable=unused-variable
@@ -148,8 +147,21 @@
else:
tracker_addr = None
+ # register the initialization callback
+ def server_init_callback():
+ # pylint: disable=redefined-outer-name, reimported, import-outside-toplevel, import-self
+ import tvm
+ import vta.exec.rpc_server
+
+ tvm.register_func("tvm.rpc.server.start", vta.exec.rpc_server.server_start, override=True)
+
server = rpc.Server(
- args.host, args.port, args.port_end, key=args.key, tracker_addr=tracker_addr
+ args.host,
+ args.port,
+ args.port_end,
+ key=args.key,
+ tracker_addr=tracker_addr,
+ server_init_callback=server_init_callback,
)
server.proc.join()