[VTA] Recover rpc server support (#8604)

diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py
index 0b49b67..52a7a89 100644
--- a/python/tvm/rpc/server.py
+++ b/python/tvm/rpc/server.py
@@ -365,9 +365,14 @@
     custom_addr=None,
     silent=False,
     no_fork=False,
+    server_init_callback=None,
 ):
     if no_fork:
         multiprocessing.set_start_method("spawn")
+
+    if server_init_callback:
+        server_init_callback()
+
     # This is a function that will be sent to the
     # Popen worker to run on a separate process.
     # Create and start the server in a different thread
@@ -420,6 +425,25 @@
 
     no_fork: bool, optional
         Whether forbid fork in multiprocessing.
+
+    server_init_callback: Callable, optional
+        Additional initialization function when starting the server.
+
+    Note
+    ----
+    The RPC server only sees functions in the tvm namespace.
+    To bring additional custom functions to the server env, you can use server_init_callback.
+
+    .. code:: python
+
+        def server_init_callback():
+            import tvm
+            # must import mypackage here
+            import mypackage
+
+            tvm.register_func("function", mypackage.func)
+
+        server = rpc.Server(host, server_init_callback=server_init_callback)
     """
 
     def __init__(
@@ -434,6 +458,7 @@
         custom_addr=None,
         silent=False,
         no_fork=False,
+        server_init_callback=None,
     ):
         try:
             if _ffi_api.ServerLoop is None:
@@ -455,6 +480,7 @@
                 custom_addr,
                 silent,
                 no_fork,
+                server_init_callback,
             ],
         )
         # receive the port
diff --git a/vta/python/vta/exec/rpc_server.py b/vta/python/vta/exec/rpc_server.py
index b7a9c79..dcf564d 100644
--- a/vta/python/vta/exec/rpc_server.py
+++ b/vta/python/vta/exec/rpc_server.py
@@ -34,7 +34,6 @@
 from ..libinfo import find_libvta
 
 
-@tvm.register_func("tvm.rpc.server.start", override=True)
 def server_start():
     """VTA RPC server extension."""
     # pylint: disable=unused-variable
@@ -148,8 +147,21 @@
     else:
         tracker_addr = None
 
+    # register the initialization callback
+    def server_init_callback():
+        # pylint: disable=redefined-outer-name, reimported, import-outside-toplevel, import-self
+        import tvm
+        import vta.exec.rpc_server
+
+        tvm.register_func("tvm.rpc.server.start", vta.exec.rpc_server.server_start, override=True)
+
     server = rpc.Server(
-        args.host, args.port, args.port_end, key=args.key, tracker_addr=tracker_addr
+        args.host,
+        args.port,
+        args.port_end,
+        key=args.key,
+        tracker_addr=tracker_addr,
+        server_init_callback=server_init_callback,
     )
     server.proc.join()