blob: b25ed4675641b9632900ee3d08c17a8b31fb348f [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""RPC server implementation.

Note
----
Server is TCP based with the following protocol:

- Initial handshake to the peer

  - [RPC_MAGIC, keysize(int32), key-bytes]
  - The key is in format

    - {server|client}:device-type[:random-key] [-timeout=timeout]
"""
# pylint: disable=invalid-name
import os
import ctypes
import socket
import select
import struct
import logging
import multiprocessing
import subprocess
import time
import sys
import signal
import platform
import tvm._ffi
from tvm._ffi.base import py_str
from tvm._ffi.libinfo import find_lib_path
from tvm.runtime.module import load_module as _load_module
from tvm.contrib import util
from . import _ffi_api
from . import base
from .base import TrackerCode
# Shared logger for this module; Server(silent=True) raises its level to ERROR.
logger = logging.getLogger("RPCServer")
def _server_env(load_library, work_path=None):
    """Set up the RPC server environment and return its working directory.

    Registers the ``tvm.rpc.server.workpath``, ``tvm.rpc.server.load_module``
    and ``tvm.rpc.server.download_linked_module`` global functions so that they
    resolve file names relative to the working directory, then loads any
    additional libraries requested by the user.

    Parameters
    ----------
    load_library : str or None
        Colon-separated list of additional libraries. Each entry is resolved
        with ``find_lib_path`` and loaded with ``RTLD_GLOBAL``.

    work_path : optional
        Working directory object (anything with ``relpath``, such as the result
        of ``util.tempdir()``). A new temporary directory is created when it is
        not given.

    Returns
    -------
    temp
        The working directory. Its ``libs`` attribute holds the loaded
        ``ctypes.CDLL`` handles, keeping them referenced as long as ``temp`` is.
    """
    temp = work_path if work_path else util.tempdir()

    # pylint: disable=unused-variable
    @tvm._ffi.register_func("tvm.rpc.server.workpath", override=True)
    def get_workpath(path):
        """Return ``path`` resolved inside the working directory."""
        return temp.relpath(path)

    @tvm._ffi.register_func("tvm.rpc.server.load_module", override=True)
    def load_module(file_name):
        """Load a module that the remote side uploaded into the working directory."""
        path = temp.relpath(file_name)
        m = _load_module(path)
        logger.info("load_module %s", path)
        return m

    @tvm._ffi.register_func("tvm.rpc.server.download_linked_module", override=True)
    def download_linked_module(file_name):
        """Link an uploaded object file or tar archive and return the library bytes.

        ``.o`` files and ``.tar`` archives of object files are linked into a
        shared library with the compiler named by ``$CXX`` (default ``g++``);
        ``.so``/``.dylib`` files are sent back unchanged.

        Raises
        ------
        RuntimeError
            If the file extension is not one of the above.
        """
        # c++ compiler/linker
        cc = os.environ.get("CXX", "g++")
        # pylint: disable=import-outside-toplevel
        path = temp.relpath(file_name)

        if path.endswith(".o"):
            # Extra dependencies during runtime.
            from tvm.contrib import cc as _cc

            _cc.create_shared(path + ".so", path, cc=cc)
            path += ".so"
        elif path.endswith(".tar"):
            # Extra dependencies during runtime.
            from tvm.contrib import cc as _cc, tar as _tar

            # Strip only the trailing ".tar"; str.replace would also rewrite
            # any ".tar" that happens to occur earlier in the path.
            tar_temp = util.tempdir(custom_path=path[: -len(".tar")])
            _tar.untar(path, tar_temp.temp_dir)
            files = [tar_temp.relpath(x) for x in tar_temp.listdir()]
            _cc.create_shared(path + ".so", files, cc=cc)
            path += ".so"
        elif path.endswith((".dylib", ".so")):
            pass
        else:
            raise RuntimeError("Do not know how to link %s" % file_name)

        logger.info("Send linked module %s to client", path)
        # Use a context manager so the file handle is closed promptly.
        with open(path, "rb") as lib_file:
            return bytearray(lib_file.read())

    libs = []
    lib_names = load_library.split(":") if load_library else []
    for file_name in lib_names:
        file_name = find_lib_path(file_name)[0]
        libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL))
        logger.info("Load additional library %s", file_name)
    temp.libs = libs
    return temp
def _serve_loop(sock, addr, load_library, work_path=None):
    """Serve a single RPC session on an already-connected socket.

    Parameters
    ----------
    sock : socket.socket
        The connected socket for this session; its file descriptor is handed
        to the C++ server loop.

    addr : tuple
        Peer address, used for logging only.

    load_library : str or None
        Colon-separated extra libraries, forwarded to ``_server_env``.

    work_path : optional
        Working directory owned by the caller. When omitted, a temporary
        directory is created here and removed once the session ends, even if
        the server loop raises.
    """
    sockfd = sock.fileno()
    temp = _server_env(load_library, work_path)
    try:
        _ffi_api.ServerLoop(sockfd)
    finally:
        # Only clean up a directory we created ourselves; a caller-provided
        # work_path is removed by the caller.
        if not work_path:
            temp.remove()
    logger.info("Finish serving %s", addr)
def _parse_server_opt(opts):
# parse client options
ret = {}
for kv in opts:
if kv.startswith("-timeout="):
ret["timeout"] = float(kv[9:])
return ret
def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
    """Listening loop of the server master.

    Repeatedly waits for a client whose handshake key matches, then serves that
    session in a child process, one session at a time. When a tracker address
    is given, the server registers with the tracker and re-registers whenever
    the tracker connection drops.

    Parameters
    ----------
    sock : socket.socket
        The bound, listening socket.

    port : int
        The port reported to the tracker.

    rpc_key : str
        The device key. Clients must present ``client:<matchkey>``, where the
        match key is ``rpc_key`` itself without a tracker, or a random key
        derived from it when a tracker is used.

    tracker_addr : tuple of (str, int) or None
        Address of the RPC tracker, or None to run stand-alone.

    load_library : str or None
        Colon-separated extra libraries to load in each session.

    custom_addr : str or None
        Custom address reported to the tracker.

    Raises
    ------
    RuntimeError
        If ``tracker_addr`` does not point to an RPC tracker.
    """

    def _accept_conn(listen_sock, tracker_conn, ping_period=2):
        """Accept connection from the other places.

        Parameters
        ----------
        listen_sock: Socket
            The socket used by listening process.

        tracker_conn : connection to tracker
            Tracker connection

        ping_period : float, optional
            ping tracker every k seconds if no connection is accepted.

        Returns
        -------
        conn : socket.socket
            The accepted client connection, after a successful handshake.

        addr : tuple
            The client address.

        opts : dict
            Options parsed from the client key by ``_parse_server_opt``.
        """
        old_keyset = set()
        # Report resource to tracker
        if tracker_conn:
            matchkey = base.random_key(rpc_key + ":")
            base.sendjson(tracker_conn, [TrackerCode.PUT, rpc_key, (port, matchkey), custom_addr])
            assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
        else:
            matchkey = rpc_key

        unmatch_period_count = 0
        unmatch_timeout = 4
        # Wait until we get a valid connection
        while True:
            if tracker_conn:
                trigger = select.select([listen_sock], [], [], ping_period)
                if listen_sock not in trigger[0]:
                    base.sendjson(tracker_conn, [TrackerCode.GET_PENDING_MATCHKEYS])
                    pending_keys = base.recvjson(tracker_conn)
                    old_keyset.add(matchkey)
                    # if match key not in pending key set
                    # it means the key is acquired by a client but not used.
                    if matchkey not in pending_keys:
                        unmatch_period_count += 1
                    else:
                        unmatch_period_count = 0
                    # regenerate match key if key is acquired but not used for a while
                    if unmatch_period_count * ping_period > unmatch_timeout + ping_period:
                        logger.info("no incoming connections, regenerate key ...")
                        matchkey = base.random_key(rpc_key + ":", old_keyset)
                        base.sendjson(
                            tracker_conn, [TrackerCode.PUT, rpc_key, (port, matchkey), custom_addr]
                        )
                        assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
                        unmatch_period_count = 0
                    continue
            conn, addr = listen_sock.accept()
            magic = struct.unpack("<i", base.recvall(conn, 4))[0]
            if magic != base.RPC_MAGIC:
                conn.close()
                continue
            keylen = struct.unpack("<i", base.recvall(conn, 4))[0]
            raw_key = base.recvall(conn, keylen)
            try:
                key = py_str(raw_key)
            except UnicodeDecodeError:
                # Undecodable key bytes cannot match; fall through to the
                # mismatch path instead of crashing the listener.
                key = ""
            arr = key.split()
            expect_header = "client:" + matchkey
            server_key = "server:" + rpc_key
            # An empty or whitespace-only key splits into no tokens; reject it
            # like any other mismatch rather than raising IndexError.
            if not arr or arr[0] != expect_header:
                conn.sendall(struct.pack("<i", base.RPC_CODE_MISMATCH))
                conn.close()
                logger.warning("mismatch key from %s", addr)
                continue
            conn.sendall(struct.pack("<i", base.RPC_CODE_SUCCESS))
            conn.sendall(struct.pack("<i", len(server_key)))
            conn.sendall(server_key.encode("utf-8"))
            return conn, addr, _parse_server_opt(arr[1:])

    # Server logic
    tracker_conn = None
    while True:
        try:
            # step 1: setup tracker and report to tracker
            if tracker_addr and tracker_conn is None:
                tracker_conn = base.connect_with_retry(tracker_addr)
                tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
                magic = struct.unpack("<i", base.recvall(tracker_conn, 4))[0]
                if magic != base.RPC_TRACKER_MAGIC:
                    raise RuntimeError("%s is not RPC Tracker" % str(tracker_addr))
                # report status of current queue
                cinfo = {"key": "server:" + rpc_key}
                base.sendjson(tracker_conn, [TrackerCode.UPDATE_INFO, cinfo])
                assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS

            # step 2: wait for in-coming connections
            conn, addr, opts = _accept_conn(sock, tracker_conn)
        except (socket.error, IOError):
            # retry when tracker is dropped
            if tracker_conn:
                tracker_conn.close()
                tracker_conn = None
            continue
        # A RuntimeError (e.g. the peer is not a tracker) is not caught and
        # ends the loop.

        # step 3: serving
        work_path = util.tempdir()
        logger.info("connection from %s", addr)
        server_proc = multiprocessing.Process(
            target=_serve_loop, args=(conn, addr, load_library, work_path)
        )
        # ``daemon`` is deliberately left unset: daemonic processes may not
        # start multiprocessing children, which a session might need to do.
        server_proc.start()
        # close from our side.
        conn.close()
        # wait until server process finish or timeout
        server_proc.join(opts.get("timeout", None))
        if server_proc.is_alive():
            logger.info("Timeout in RPC session, kill..")
            # pylint: disable=import-outside-toplevel
            import psutil

            try:
                parent = psutil.Process(server_proc.pid)
                # terminate worker childs
                for child in parent.children(recursive=True):
                    try:
                        child.terminate()
                    except psutil.NoSuchProcess:
                        pass  # the child exited on its own meanwhile
            except psutil.NoSuchProcess:
                pass  # the worker exited between is_alive() and here
            # terminate the worker
            server_proc.terminate()
        work_path.remove()
def _connect_proxy_loop(addr, key, load_library):
    """Connect to an RPC proxy and serve the sessions it forwards.

    The server identifies itself to the proxy as ``server:<key>``. Each
    connection is served by a child process, after which the loop reconnects.
    Socket errors are retried every ``retry_period`` seconds. The retry counter
    is reset after every successful session.

    Parameters
    ----------
    addr : tuple of (str, int)
        Address of the RPC proxy.

    key : str
        Device key, sent to the proxy as ``server:<key>``.

    load_library : str or None
        Colon-separated extra libraries to load in each session.

    Raises
    ------
    RuntimeError
        If the key is already in use on the proxy, if the peer is not an RPC
        proxy, or if more than ``max_retry`` consecutive attempts fail.
    """
    key = "server:" + key
    retry_count = 0
    max_retry = 5
    retry_period = 5
    while True:
        sock = None
        try:
            sock = socket.socket(base.get_addr_family(addr), socket.SOCK_STREAM)
            sock.connect(addr)
            sock.sendall(struct.pack("<i", base.RPC_MAGIC))
            sock.sendall(struct.pack("<i", len(key)))
            sock.sendall(key.encode("utf-8"))
            magic = struct.unpack("<i", base.recvall(sock, 4))[0]
            if magic == base.RPC_CODE_DUPLICATE:
                raise RuntimeError("key: %s has already been used in proxy" % key)

            if magic == base.RPC_CODE_MISMATCH:
                # NOTE(review): execution continues and reads the remote key
                # after a mismatch; confirm this matches the proxy protocol.
                logger.warning("RPCProxy do not have matching client key %s", key)
            elif magic != base.RPC_CODE_SUCCESS:
                raise RuntimeError("%s is not RPC Proxy" % str(addr))
            keylen = struct.unpack("<i", base.recvall(sock, 4))[0]
            remote_key = py_str(base.recvall(sock, keylen))
            opts = _parse_server_opt(remote_key.split()[1:])
            logger.info("connected to %s", str(addr))
            process = multiprocessing.Process(target=_serve_loop, args=(sock, addr, load_library))
            # ``daemon`` is deliberately left unset: daemonic processes may not
            # start multiprocessing children, which a session might need to do.
            process.start()
            sock.close()
            process.join(opts.get("timeout", None))
            if process.is_alive():
                logger.info("Timeout in RPC session, kill..")
                process.terminate()
            retry_count = 0
        except (socket.error, IOError) as err:
            # Release the failed connection before retrying with a new socket.
            if sock is not None:
                sock.close()
            retry_count += 1
            logger.warning("Error encountered %s, retry in %g sec", str(err), retry_period)
            if retry_count > max_retry:
                raise RuntimeError("Maximum retry error: last error: %s" % str(err)) from err
            time.sleep(retry_period)
        except RuntimeError:
            # Fatal protocol error: close the socket, then propagate.
            if sock is not None:
                sock.close()
            raise
def _popen(cmd):
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=os.environ)
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "Server invoke error:\n"
msg += out
raise RuntimeError(msg)
class Server(object):
    """Start RPC server on a separate process.

    This is a simple python implementation based on multi-processing.
    It is also possible to implement a similar C based server with
    TVM runtime which does not depend on the python.

    Parameters
    ----------
    host : str
        The host url of the server.

    port : int
        The port to bind to. In direct mode (neither proxy nor Popen), ``0``
        lets the OS choose a free port. The port actually bound is then
        available as ``self.port``.

    port_end : int, optional
        The end (exclusive) of the port range to search.

    is_proxy : bool, optional
        Whether the address specified is a proxy.
        If this is true, the host and port actually corresponds to the
        address of the proxy server.

    use_popen : bool, optional
        Whether to use Popen to start a fresh new process instead of fork.
        This is recommended to switch on if we want to do local RPC demonstration
        for GPU devices to avoid fork safety issues.

    tracker_addr: Tuple (str, int) , optional
        The address of RPC Tracker in (host, port) format.
        If is not None, the server will register itself to the tracker.

    key : str, optional
        The key used to identify the device type in tracker.

    load_library : str, optional
        Colon-separated list of additional libraries to be loaded during execution.

    custom_addr: str, optional
        Custom IP Address to Report to RPC Tracker

    silent: bool, optional
        Whether run this server in silent mode.

    utvm_dev_id : str, optional
        Forwarded as ``--utvm-dev-id`` to the server process. Only used when
        ``use_popen`` is True.

    utvm_dev_config_args : str, optional
        Forwarded as ``--utvm-dev-config-args``. Required whenever
        ``utvm_dev_id`` is given with ``use_popen``.
    """

    def __init__(
        self,
        host,
        port=9091,
        port_end=9199,
        is_proxy=False,
        use_popen=False,
        tracker_addr=None,
        key="",
        load_library=None,
        custom_addr=None,
        silent=False,
        utvm_dev_id=None,
        utvm_dev_config_args=None,
    ):
        # pylint: disable=import-outside-toplevel
        import errno

        # Set everything terminate() reads before anything can raise, so that
        # __del__ on a partially constructed server does not fail as well.
        self.proc = None
        self.sock = None
        self.use_popen = use_popen

        # A missing module attribute raises AttributeError (not NameError), so
        # catch both when probing for the RPC server loop.
        try:
            server_loop = _ffi_api.ServerLoop
        except (NameError, AttributeError):
            server_loop = None
        if server_loop is None:
            raise RuntimeError("Please compile with USE_RPC=1")

        self.host = host
        self.port = port
        self.libs = []
        self.custom_addr = custom_addr

        if silent:
            logger.setLevel(logging.ERROR)

        if use_popen:
            cmd = [
                sys.executable,
                "-m",
                "tvm.exec.rpc_server",
                "--host=%s" % host,
                "--port=%s" % port,
                "--port-end=%s" % port_end,
            ]
            if tracker_addr:
                assert key
                cmd += ["--tracker=%s:%d" % tracker_addr, "--key=%s" % key]
            if load_library:
                cmd += ["--load-library", load_library]
            if custom_addr:
                cmd += ["--custom-addr", custom_addr]
            if silent:
                cmd += ["--silent"]
            if utvm_dev_id is not None:
                assert utvm_dev_config_args is not None
                cmd += [f"--utvm-dev-id={utvm_dev_id}"]
                cmd += [f"--utvm-dev-config-args={utvm_dev_config_args}"]

            if platform.system() == "Windows":
                self.proc = subprocess.Popen(cmd, creationflags=subprocess.CREATE_NEW_PROCESS_GROUP)
            else:
                # start_new_session=True calls setsid() in the child (the
                # thread-safe replacement for preexec_fn=os.setsid), making it
                # a process-group leader so terminate() can killpg the group.
                self.proc = subprocess.Popen(cmd, start_new_session=True)
            time.sleep(0.5)
        elif not is_proxy:
            sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
            # "Address already in use" is 98 on Linux and 48 on macOS;
            # errno.EADDRINUSE also covers the running platform (e.g. Windows).
            addr_in_use = (98, 48, errno.EADDRINUSE)
            self.port = None
            for my_port in range(port, port_end):
                try:
                    sock.bind((host, my_port))
                except socket.error as sock_err:
                    if sock_err.errno in addr_in_use:
                        continue
                    sock.close()
                    raise
                # Read the port back so that port=0 reports the OS-chosen port.
                self.port = sock.getsockname()[1]
                break
            if self.port is None:
                sock.close()
                raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
            logger.info("bind to %s:%d", host, self.port)
            sock.listen(1)
            self.sock = sock
            self.proc = multiprocessing.Process(
                target=_listen_loop,
                args=(self.sock, self.port, key, tracker_addr, load_library, self.custom_addr),
            )
            # Not daemonic: _listen_loop starts a multiprocessing child per
            # session, which daemonic processes are not allowed to do.
            self.proc.start()
        else:
            self.proc = multiprocessing.Process(
                target=_connect_proxy_loop, args=((host, port), key, load_library)
            )
            # Not daemonic: _connect_proxy_loop also starts a child per session.
            self.proc.start()

    def terminate(self):
        """Terminate the server process and release the listening socket.

        Safe to call more than once, and on a server whose ``__init__`` failed.
        """
        proc = getattr(self, "proc", None)
        if proc is not None:
            if getattr(self, "use_popen", False):
                if platform.system() == "Windows":
                    # CREATE_NEW_PROCESS_GROUP disables CTRL+C in the child's
                    # group, so CTRL_C_EVENT would be ignored; CTRL_BREAK is not.
                    os.kill(proc.pid, signal.CTRL_BREAK_EVENT)
                else:
                    os.killpg(proc.pid, signal.SIGTERM)
            else:
                proc.terminate()
            self.proc = None
        sock = getattr(self, "sock", None)
        if sock is not None:
            # Our copy of the listening socket keeps the port bound otherwise.
            sock.close()
            self.sock = None

    def __del__(self):
        self.terminate()