blob: 50aa5a775d51051c0767a158b127d2958f4162e5 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""RPC proxy, allows both client/server to connect and match connection.
In normal RPC, client directly connect to server's IP address.
Sometimes this cannot be done when server do not have a static address.
RPCProxy allows both client and server connect to the proxy server,
the proxy server will forward the message between the client and server.
"""
# pylint: disable=unused-variable, unused-argument
import os
import asyncio
import logging
import socket
import threading
import errno
import struct
import time
try:
import tornado
from tornado import gen
from tornado import websocket
from tornado import ioloop
from . import tornado_util
except ImportError as error_msg:
raise ImportError(
"RPCProxy module requires tornado package %s. Try 'pip install tornado'." % error_msg
)
from tvm.contrib.popen_pool import PopenWorker
from . import _ffi_api
from . import base
from .base import TrackerCode
from .server import _server_env
from .._ffi.base import py_str
class ForwardHandler(object):
"""Forward handler to forward the message."""
def _init_handler(self):
"""Initialize handler."""
self._init_message = bytes()
self._init_req_nbytes = 4
self._magic = None
self.timeout = None
self._rpc_key_length = None
self._done = False
self._proxy = ProxyServerHandler.current
assert self._proxy
self.rpc_key = None
self.match_key = None
self.forward_proxy = None
self.alloc_time = None
def __del__(self):
logging.info("Delete %s...", self.name())
def name(self):
"""Name of this connection."""
return "RPCConnection"
def _init_step(self, message):
if self._magic is None:
assert len(message) == 4
self._magic = struct.unpack("<i", message)[0]
if self._magic != base.RPC_MAGIC:
logging.info("Invalid RPC magic from %s", self.name())
self.close()
self._init_req_nbytes = 4
elif self._rpc_key_length is None:
assert len(message) == 4
self._rpc_key_length = struct.unpack("<i", message)[0]
self._init_req_nbytes = self._rpc_key_length
elif self.rpc_key is None:
assert len(message) == self._rpc_key_length
self.rpc_key = py_str(message)
# match key is used to do the matching
self.match_key = self.rpc_key[7:].split()[0]
self.on_start()
else:
assert False
def on_start(self):
"""Event when the initialization is completed"""
self._proxy.handler_ready(self)
def on_data(self, message):
"""on data"""
assert isinstance(message, bytes)
if self.forward_proxy:
self.forward_proxy.send_data(message)
else:
while message and self._init_req_nbytes > len(self._init_message):
nbytes = self._init_req_nbytes - len(self._init_message)
self._init_message += message[:nbytes]
message = message[nbytes:]
if self._init_req_nbytes == len(self._init_message):
temp = self._init_message
self._init_req_nbytes = 0
self._init_message = bytes()
self._init_step(temp)
if message:
logging.info("Invalid RPC protocol, too many bytes %s", self.name())
self.close()
def on_error(self, err):
logging.info("%s: Error in RPC %s", self.name(), err)
self.close_pair()
def close_pair(self):
if self.forward_proxy:
self.forward_proxy.signal_close()
self.forward_proxy = None
self.close()
def on_close_event(self):
"""on close event"""
assert not self._done
logging.info("RPCProxy:on_close_event %s ...", self.name())
if self.match_key:
key = self.match_key
if self._proxy._client_pool.get(key, None) == self:
self._proxy._client_pool.pop(key)
if self._proxy._server_pool.get(key, None) == self:
self._proxy._server_pool.pop(key)
self._done = True
self.forward_proxy = None
class TCPHandler(tornado_util.TCPHandler, ForwardHandler):
"""Event driven TCP handler."""
def __init__(self, sock, addr):
super(TCPHandler, self).__init__(sock)
self._init_handler()
self.addr = addr
def name(self):
return "TCPSocketProxy:%s:%s" % (str(self.addr[0]), self.rpc_key)
def send_data(self, message, binary=True):
self.write_message(message, True)
def on_message(self, message):
self.on_data(message)
def on_close(self):
logging.info("RPCProxy: on_close %s ...", self.name())
self._close_process = True
if self.forward_proxy:
self.forward_proxy.signal_close()
self.forward_proxy = None
self.on_close_event()
class WebSocketHandler(websocket.WebSocketHandler, ForwardHandler):
"""Handler for websockets."""
def __init__(self, *args, **kwargs):
super(WebSocketHandler, self).__init__(*args, **kwargs)
self._init_handler()
def name(self):
return "WebSocketProxy:%s" % (self.rpc_key)
def on_message(self, message):
self.on_data(message)
def data_received(self, _):
raise NotImplementedError()
def send_data(self, message):
try:
self.write_message(message, True)
except websocket.WebSocketClosedError as err:
self.on_error(err)
def on_close(self):
logging.info("RPCProxy: on_close %s ...", self.name())
if self.forward_proxy:
self.forward_proxy.signal_close()
self.forward_proxy = None
self.on_close_event()
def signal_close(self):
self.close()
class RequestHandler(tornado.web.RequestHandler):
"""Handles html request."""
def __init__(self, *args, **kwargs):
file_path = kwargs.pop("file_path")
if file_path.endswith("html"):
self.page = open(file_path).read()
web_port = kwargs.pop("rpc_web_port", None)
if web_port:
self.page = self.page.replace(
"ws://localhost:9190/ws", "ws://localhost:%d/ws" % web_port
)
else:
self.page = open(file_path, "rb").read()
super(RequestHandler, self).__init__(*args, **kwargs)
def data_received(self, _):
pass
def get(self, *args, **kwargs):
self.write(self.page)
class ProxyServerHandler(object):
"""Internal proxy server handler class."""
current = None
def __init__(
self,
sock,
listen_port,
web_port,
timeout_client,
timeout_server,
tracker_addr,
index_page=None,
resource_files=None,
):
assert ProxyServerHandler.current is None
ProxyServerHandler.current = self
if web_port:
handlers = [
(r"/ws", WebSocketHandler),
]
if index_page:
handlers.append(
(r"/", RequestHandler, {"file_path": index_page, "rpc_web_port": web_port})
)
logging.info("Serving RPC index html page at http://localhost:%d", web_port)
resource_files = resource_files if resource_files else []
for fname in resource_files:
basename = os.path.basename(fname)
pair = (r"/%s" % basename, RequestHandler, {"file_path": fname})
handlers.append(pair)
logging.info(pair)
self.app = tornado.web.Application(handlers)
self.app.listen(web_port)
self.sock = sock
self.sock.setblocking(0)
self.loop = ioloop.IOLoop.current()
def event_handler(_, events):
self._on_event(events)
self.loop.add_handler(self.sock.fileno(), event_handler, self.loop.READ)
self._client_pool = {}
self._server_pool = {}
self.timeout_alloc = 5
self.timeout_client = timeout_client
self.timeout_server = timeout_server
# tracker information
self._listen_port = listen_port
self._tracker_addr = tracker_addr
self._tracker_conn = None
self._tracker_pending_puts = []
self._key_set = set()
self.update_tracker_period = 2
if tracker_addr:
logging.info("Tracker address:%s", str(tracker_addr))
def _callback():
self._update_tracker(True)
self.loop.call_later(self.update_tracker_period, _callback)
logging.info("RPCProxy: Websock port bind to %d", web_port)
def _on_event(self, _):
while True:
try:
conn, addr = self.sock.accept()
TCPHandler(conn, addr)
except socket.error as err:
if err.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK):
break
def _pair_up(self, lhs, rhs):
lhs.forward_proxy = rhs
rhs.forward_proxy = lhs
lhs.send_data(struct.pack("<i", base.RPC_CODE_SUCCESS))
lhs.send_data(struct.pack("<i", len(rhs.rpc_key)))
lhs.send_data(rhs.rpc_key.encode("utf-8"))
rhs.send_data(struct.pack("<i", base.RPC_CODE_SUCCESS))
rhs.send_data(struct.pack("<i", len(lhs.rpc_key)))
rhs.send_data(lhs.rpc_key.encode("utf-8"))
logging.info("Pairup connect %s and %s", lhs.name(), rhs.name())
def _regenerate_server_keys(self, keys):
"""Regenerate keys for server pool"""
keyset = set(self._server_pool.keys())
new_keys = []
# re-generate the server match key, so old information is invalidated.
for key in keys:
rpc_key, _ = key.split(":")
handle = self._server_pool[key]
del self._server_pool[key]
new_key = base.random_key(rpc_key + ":", keyset)
self._server_pool[new_key] = handle
keyset.add(new_key)
new_keys.append(new_key)
return new_keys
def _update_tracker(self, period_update=False):
"""Update information on tracker."""
try:
if self._tracker_conn is None:
self._tracker_conn = socket.socket(
base.get_addr_family(self._tracker_addr), socket.SOCK_STREAM
)
self._tracker_conn.connect(self._tracker_addr)
self._tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("<i", base.recvall(self._tracker_conn, 4))[0]
if magic != base.RPC_TRACKER_MAGIC:
self.loop.stop()
raise RuntimeError("%s is not RPC Tracker" % str(self._tracker_addr))
# just connect to tracker, need to update all keys
self._tracker_pending_puts = self._server_pool.keys()
if self._tracker_conn and period_update:
# periodically update tracker information
# regenerate key if the key is not in tracker anymore
# and there is no in-coming connection after timeout_alloc
base.sendjson(self._tracker_conn, [TrackerCode.GET_PENDING_MATCHKEYS])
pending_keys = set(base.recvjson(self._tracker_conn))
update_keys = []
for k, v in self._server_pool.items():
if k not in pending_keys:
if v.alloc_time is None:
v.alloc_time = time.time()
elif time.time() - v.alloc_time > self.timeout_alloc:
update_keys.append(k)
v.alloc_time = None
if update_keys:
logging.info(
"RPCProxy: No incoming conn on %s, regenerate keys...", str(update_keys)
)
new_keys = self._regenerate_server_keys(update_keys)
self._tracker_pending_puts += new_keys
need_update_info = False
# report new connections
for key in self._tracker_pending_puts:
rpc_key = key.split(":")[0]
base.sendjson(
self._tracker_conn, [TrackerCode.PUT, rpc_key, (self._listen_port, key), None]
)
assert base.recvjson(self._tracker_conn) == TrackerCode.SUCCESS
if rpc_key not in self._key_set:
self._key_set.add(rpc_key)
need_update_info = True
if need_update_info:
keylist = "[" + ",".join(self._key_set) + "]"
cinfo = {"key": "server:proxy" + keylist, "addr": [None, self._listen_port]}
base.sendjson(self._tracker_conn, [TrackerCode.UPDATE_INFO, cinfo])
assert base.recvjson(self._tracker_conn) == TrackerCode.SUCCESS
self._tracker_pending_puts = []
except (socket.error, IOError) as err:
logging.info(
"Lost tracker connection: %s, try reconnect in %g sec",
str(err),
self.update_tracker_period,
)
self._tracker_conn.close()
self._tracker_conn = None
self._regenerate_server_keys(self._server_pool.keys())
if period_update:
def _callback():
self._update_tracker(True)
self.loop.call_later(self.update_tracker_period, _callback)
def _handler_ready_tracker_mode(self, handler):
"""tracker mode to handle handler ready."""
if handler.rpc_key.startswith("server:"):
key = base.random_key(handler.match_key + ":", self._server_pool)
handler.match_key = key
self._server_pool[key] = handler
self._tracker_pending_puts.append(key)
self._update_tracker()
else:
if handler.match_key in self._server_pool:
self._pair_up(self._server_pool.pop(handler.match_key), handler)
else:
handler.send_data(struct.pack("<i", base.RPC_CODE_MISMATCH))
handler.signal_close()
def _handler_ready_proxy_mode(self, handler):
"""Normal proxy mode when handler is ready."""
if handler.rpc_key.startswith("server:"):
pool_src, pool_dst = self._client_pool, self._server_pool
timeout = self.timeout_server
else:
pool_src, pool_dst = self._server_pool, self._client_pool
timeout = self.timeout_client
key = handler.match_key
if key in pool_src:
self._pair_up(pool_src.pop(key), handler)
return
if key not in pool_dst:
pool_dst[key] = handler
def cleanup():
"""Cleanup client connection if timeout"""
if pool_dst.get(key, None) == handler:
logging.info(
"Timeout client connection %s, cannot find match key=%s",
handler.name(),
key,
)
pool_dst.pop(key)
handler.send_data(struct.pack("<i", base.RPC_CODE_MISMATCH))
handler.signal_close()
self.loop.call_later(timeout, cleanup)
else:
logging.info("Duplicate connection with same key=%s", key)
handler.send_data(struct.pack("<i", base.RPC_CODE_DUPLICATE))
handler.signal_close()
def handler_ready(self, handler):
"""Report handler to be ready."""
logging.info("Handler ready %s", handler.name())
if self._tracker_addr:
self._handler_ready_tracker_mode(handler)
else:
self._handler_ready_proxy_mode(handler)
def run(self):
"""Run the proxy server"""
ioloop.IOLoop.current().start()
def _proxy_server(
listen_sock,
listen_port,
web_port,
timeout_client,
timeout_server,
tracker_addr,
index_page,
resource_files,
):
asyncio.set_event_loop(asyncio.new_event_loop())
handler = ProxyServerHandler(
listen_sock,
listen_port,
web_port,
timeout_client,
timeout_server,
tracker_addr,
index_page,
resource_files,
)
handler.run()
class PopenProxyServerState(object):
"""Internal PopenProxy State for Popen"""
current = None
def __init__(
self,
host,
port=9091,
port_end=9199,
web_port=0,
timeout_client=600,
timeout_server=600,
tracker_addr=None,
index_page=None,
resource_files=None,
):
sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
self.port = None
for my_port in range(port, port_end):
try:
sock.bind((host, my_port))
self.port = my_port
break
except socket.error as sock_err:
if sock_err.errno in [errno.EADDRINUSE]:
continue
raise sock_err
if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
logging.info("RPCProxy: client port bind to %s:%d", host, self.port)
sock.listen(1)
self.thread = threading.Thread(
target=_proxy_server,
args=(
sock,
self.port,
web_port,
timeout_client,
timeout_server,
tracker_addr,
index_page,
resource_files,
),
)
# start the server in a different thread
# so we can return the port directly
self.thread.start()
def _popen_start_proxy_server(
host,
port=9091,
port_end=9199,
web_port=0,
timeout_client=600,
timeout_server=600,
tracker_addr=None,
index_page=None,
resource_files=None,
):
# This is a function that will be sent to the
# Popen worker to run on a separate process.
# Create and start the server in a different thread
state = PopenProxyServerState(
host,
port,
port_end,
web_port,
timeout_client,
timeout_server,
tracker_addr,
index_page,
resource_files,
)
PopenProxyServerState.current = state
# returns the port so that the main can get the port number.
return state.port
class Proxy(object):
"""Start RPC proxy server on a separate process.
Python implementation based on PopenWorker.
Parameters
----------
host : str
The host url of the server.
port : int
The TCP port to be bind to
port_end : int, optional
The end TCP port to search
web_port : int, optional
The http/websocket port of the server.
timeout_client : float, optional
Timeout of client until it sees a matching connection.
timeout_server : float, optional
Timeout of server until it sees a matching connection.
tracker_addr: Tuple (str, int) , optional
The address of RPC Tracker in tuple (host, ip) format.
If is not None, the server will register itself to the tracker.
index_page : str, optional
Path to an index page that can be used to display at proxy index.
resource_files : str, optional
Path to local resources that can be included in the http request
"""
def __init__(
self,
host,
port=9091,
port_end=9199,
web_port=0,
timeout_client=600,
timeout_server=600,
tracker_addr=None,
index_page=None,
resource_files=None,
):
self.proc = PopenWorker()
# send the function
self.proc.send(
_popen_start_proxy_server,
[
host,
port,
port_end,
web_port,
timeout_client,
timeout_server,
tracker_addr,
index_page,
resource_files,
],
)
# receive the port
self.port = self.proc.recv()
self.host = host
def terminate(self):
"""Terminate the server process"""
if self.proc:
logging.info("Terminating Proxy Server...")
self.proc.kill()
self.proc = None
def __del__(self):
self.terminate()
def websocket_proxy_server(url, key=""):
"""Create a RPC server that uses an websocket that connects to a proxy.
Parameters
----------
url : str
The url to be connected.
key : str
The key to identify the server.
"""
def create_on_message(conn):
def _fsend(data):
data = bytes(data)
conn.write_message(data, binary=True)
return len(data)
on_message = _ffi_api.CreateEventDrivenServer(_fsend, "WebSocketProxyServer", "%toinit")
return on_message
@gen.coroutine
def _connect(key):
conn = yield websocket.websocket_connect(url)
on_message = create_on_message(conn)
temp = _server_env(None)
# Start connecton
conn.write_message(struct.pack("<i", base.RPC_MAGIC), binary=True)
key = "server:" + key
conn.write_message(struct.pack("<i", len(key)), binary=True)
conn.write_message(key.encode("utf-8"), binary=True)
msg = yield conn.read_message()
assert len(msg) >= 4
magic = struct.unpack("<i", msg[:4])[0]
if magic == base.RPC_CODE_DUPLICATE:
raise RuntimeError("key: %s has already been used in proxy" % key)
if magic == base.RPC_CODE_MISMATCH:
logging.info("RPCProxy do not have matching client key %s", key)
elif magic != base.RPC_CODE_SUCCESS:
raise RuntimeError("%s is not RPC Proxy" % url)
msg = msg[4:]
logging.info("Connection established with remote")
if msg:
on_message(bytearray(msg), 3)
while True:
try:
msg = yield conn.read_message()
if msg is None:
break
on_message(bytearray(msg), 3)
except websocket.WebSocketClosedError as err:
break
logging.info("WebSocketProxyServer closed...")
temp.remove()
ioloop.IOLoop.current().stop()
ioloop.IOLoop.current().spawn_callback(_connect, key)
ioloop.IOLoop.current().start()