blob: 557c9ae24d402f42de2ce7240c592ebc026a98cb [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""RPC Tracker, tracks and distributes the TVM RPC resources.
This folder implements the tracker server logic.
Note
----
Tracker is a TCP based rest api with the following protocol:
- Initial handshake to the peer
- RPC_TRACKER_MAGIC
- Normal message: [size(int32), json-data]
- Each message is initiated by the client, and the tracker replies with a json.
List of available APIs:
- PING: check if tracker is alive
- input: [TrackerCode.PING]
- return: TrackerCode.SUCCESS
- PUT: report resource to tracker
- input: [TrackerCode.PUT, key, [port, match-key], custom-addr (optional)]
- return: TrackerCode.SUCCESS
- note: match-key is a randomly generated key that identifies the resource during connection.
- REQUEST: request a new resource from tracker
- input: [TrackerCode.REQUEST, key, user, priority]
- return: [TrackerCode.SUCCESS, [url, port, match-key]]
"""
# pylint: disable=invalid-name
import heapq
import time
import logging
import socket
import multiprocessing
import errno
import struct
import json
try:
from tornado import ioloop
from . import tornado_util
except ImportError as error_msg:
raise ImportError(
"RPCTracker module requires tornado package %s. Try 'pip install tornado'." % error_msg
)
from .._ffi.base import py_str
from . import base
from .base import RPC_TRACKER_MAGIC, TrackerCode
logger = logging.getLogger("RPCTracker")
class Scheduler(object):
    """Abstract interface shared by all resource schedulers."""

    def put(self, value):
        """Hand a resource over to the scheduler.

        Putting a resource can trigger callbacks of pending requests.

        Parameters
        ----------
        value : object
            The resource to be put in the scheduler.
        """
        raise NotImplementedError

    def request(self, user, priority, callback):
        """Ask the scheduler for a resource.

        Parameters
        ----------
        user : str
            The user who is requesting the resource.

        priority : int
            The job priority.

        callback : function: value->bool
            Invoked with a resource once one is ready; it returns True
            if the resource is consumed.
        """
        raise NotImplementedError

    def remove(self, value):
        """Remove a resource from the scheduler.

        The base implementation does nothing.

        Parameters
        ----------
        value : object
            The resource to remove.
        """
        return None

    def summary(self):
        """Return summary information about the scheduler."""
        raise NotImplementedError
class PriorityScheduler(Scheduler):
    """Priority based scheduler; requests of equal priority are served FIFO.

    Free resources are kept in arrival order. Pending requests live in a heap
    ordered by (highest priority first, arrival time, arrival sequence number).
    """

    def __init__(self, key):
        self._key = key
        # Free resources, oldest first.
        self._values = []
        # Heap of (-priority, timestamp, sequence, callback) entries.
        self._requests = []
        # Strictly increasing tie-breaker. Without it, two requests with equal
        # priority and equal timestamp make heapq compare the callbacks
        # themselves, which raises TypeError on Python 3.
        self._request_seq = 0

    def _schedule(self):
        """Pair free resources with pending requests until either runs out."""
        while self._requests and self._values:
            value = self._values.pop(0)
            item = heapq.heappop(self._requests)
            callback = item[-1]
            # value[0] is the reporting connection; the requester only
            # receives the remaining fields.
            if callback(value[1:]):
                value[0].pending_matchkeys.remove(value[-1])
            else:
                # The request could not be served (and is dropped); the
                # resource goes back to the pool.
                self._values.append(value)

    def put(self, value):
        """Add a free resource and try to serve pending requests."""
        self._values.append(value)
        self._schedule()

    def request(self, user, priority, callback):
        """Queue a request and try to serve it right away.

        Parameters
        ----------
        user : str
            The user who is requesting the resource (not used for ordering).

        priority : int
            The job priority; larger values are served first.

        callback : function: value->bool
            Receives the resource; returns True if the resource is consumed.
        """
        entry = (-priority, time.time(), self._request_seq, callback)
        self._request_seq += 1
        heapq.heappush(self._requests, entry)
        self._schedule()

    def remove(self, value):
        """Drop a free resource if it is still waiting in the pool."""
        if value in self._values:
            self._values.remove(value)
            self._schedule()

    def summary(self):
        """Get summary information of the scheduler."""
        return {"free": len(self._values), "pending": len(self._requests)}
class TCPEventHandler(tornado_util.TCPHandler):
    """Base asynchronous message handler.

    The tracker and client follow a simple message protocol.
    The message is in form [nbytes(int32)] [json-str].
    All the information is packed in json-str.

    Parameters
    ----------
    tracker : object
        The tracker that owns this connection; the handler registers itself
        in the tracker's connection set.

    sock : socket.socket
        The accepted connection, handed to the base handler.

    addr : tuple
        Peer address; ``addr[0]`` is used as the resource host when a PUT
        message carries no custom address.
    """

    def __init__(self, tracker, sock, addr):
        super(TCPEventHandler, self).__init__(sock)
        # Bytes received but not yet consumed as a complete message.
        self._data = bytearray()
        self._tracker = tracker
        # Payload size of the message being assembled; 0 means "header not read yet".
        self._msg_size = 0
        self._addr = addr
        # Non-zero until the magic-number handshake has completed.
        self._init_req_nbytes = 4
        self._info = {"addr": addr}
        # list of pending match keys that has not been used.
        self.pending_matchkeys = set()
        self._tracker._connections.add(self)
        # Every resource tuple reported through this connection.
        self.put_values = []

    def name(self):
        """name of connection"""
        return "TCPSocket: %s" % str(self._addr)

    def summary(self):
        """Summary of this connection"""
        return self._info

    def _init_conn(self, message):
        """Initialize the connection.

        Validates the 4-byte magic number sent by the peer and echoes it
        back. An invalid handshake closes the connection and stops there.
        """
        if len(message) != 4:
            logger.warning("Invalid connection from %s", self.name())
            self.close()
            # Do not unpack a wrong-sized buffer or write to a closed connection.
            return
        magic = struct.unpack("<i", message)[0]
        if magic != RPC_TRACKER_MAGIC:
            logger.warning("Invalid magic from %s", self.name())
            self.close()
            return
        self.write_message(struct.pack("<i", RPC_TRACKER_MAGIC), binary=True)
        self._init_req_nbytes = 0

    def on_message(self, message):
        """Callback when a message is received.

        Parameters
        ----------
        message : bytes
            The bytes received
        """
        assert isinstance(message, bytes)
        if self._init_req_nbytes:
            self._init_conn(message)
            return

        self._data += message

        # Several complete messages may be buffered; drain all of them.
        while True:
            if self._msg_size == 0:
                if len(self._data) < 4:
                    return
                self._msg_size = struct.unpack("<i", self._data[:4])[0]
                if self._msg_size <= 0:
                    # A zero size would stall the stream forever and a negative
                    # size would mis-slice the buffer; neither can be a valid
                    # JSON message, so drop the connection.
                    logger.warning(
                        "Invalid message size %d from %s", self._msg_size, self.name()
                    )
                    self._msg_size = 0
                    del self._data[:]
                    self.close()
                    return
            if len(self._data) < self._msg_size + 4:
                # Wait for the rest of the payload.
                return
            msg = py_str(bytes(self._data[4 : 4 + self._msg_size]))
            del self._data[: 4 + self._msg_size]
            self._msg_size = 0
            self.call_handler(json.loads(msg))

    def ret_value(self, data):
        """Send `data` back to the peer as a length-prefixed JSON message."""
        payload = json.dumps(data).encode("utf-8")
        # The prefix must be the size of the encoded bytes, not of the str.
        self.write_message(struct.pack("<i", len(payload)), binary=True)
        self.write_message(payload, binary=True)

    def call_handler(self, args):
        """Event handler when json request arrives.

        Parameters
        ----------
        args : list
            The decoded request; ``args[0]`` is the TrackerCode and the
            remaining entries depend on that code.
        """
        code = args[0]
        if code == TrackerCode.PUT:
            key = args[1]
            port, matchkey = args[2]
            self.pending_matchkeys.add(matchkey)
            # got custom address (from rpc server)
            if len(args) >= 4 and args[3] is not None:
                value = (self, args[3], port, matchkey)
            else:
                value = (self, self._addr[0], port, matchkey)
            self._tracker.put(key, value)
            self.put_values.append(value)
            self.ret_value(TrackerCode.SUCCESS)
        elif code == TrackerCode.REQUEST:
            key = args[1]
            user = args[2]
            priority = args[3]

            def _cb(value):
                # if the connection is already closed
                if not self._sock:
                    return False
                try:
                    self.ret_value([TrackerCode.SUCCESS, value])
                except (socket.error, IOError):
                    return False
                return True

            self._tracker.request(key, user, priority, _cb)
        elif code == TrackerCode.PING:
            self.ret_value(TrackerCode.SUCCESS)
        elif code == TrackerCode.GET_PENDING_MATCHKEYS:
            self.ret_value(list(self.pending_matchkeys))
        elif code == TrackerCode.STOP:
            # safe stop tracker
            if self._tracker._stop_key == args[1]:
                self.ret_value(TrackerCode.SUCCESS)
                self._tracker.stop()
            else:
                self.ret_value(TrackerCode.FAIL)
        elif code == TrackerCode.UPDATE_INFO:
            self._info.update(args[1])
            self.ret_value(TrackerCode.SUCCESS)
        elif code == TrackerCode.SUMMARY:
            status = self._tracker.summary()
            self.ret_value([TrackerCode.SUCCESS, status])
        else:
            logger.warning("Unknown tracker code %d", code)
            self.close()

    def on_close(self):
        """Tell the tracker that this connection is gone."""
        self._tracker.close(self)

    def on_error(self, err):
        """Log the error and close the connection."""
        logger.warning("%s: Error in RPC Tracker: %s", self.name(), err)
        self.close()
class TrackerServerHandler(object):
    """Tracker that tracks the resources.

    Parameters
    ----------
    sock : socket.socket
        The listening socket; it is switched to non-blocking mode and
        registered with the current IOLoop.

    stop_key : str
        Key a STOP request must present for the tracker to shut down.
    """

    def __init__(self, sock, stop_key):
        # Maps resource key -> scheduler for that key.
        self._scheduler_map = {}
        self._sock = sock
        self._sock.setblocking(0)
        self._ioloop = ioloop.IOLoop.current()
        self._stop_key = stop_key
        # Live connection handlers.
        self._connections = set()

        def _event_handler(_, events):
            self._on_event(events)

        self._ioloop.add_handler(self._sock.fileno(), _event_handler, self._ioloop.READ)

    def _on_event(self, _):
        """Accept every connection that is currently waiting on the socket."""
        while True:
            try:
                conn, addr = self._sock.accept()
                TCPEventHandler(self, conn, addr)
            except socket.error as err:
                if err.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK):
                    # Nothing left to accept.
                    break
                if err.args[0] in (errno.ECONNABORTED, errno.EINTR):
                    # Transient failure for one connection; try the next one.
                    continue
                # Any other error: looping here would spin forever and starve
                # the IOLoop, so report it and hand control back.
                logger.warning("Error accepting connection: %s", err)
                break

    def create_scheduler(self, key):
        """Create a new scheduler."""
        return PriorityScheduler(key)

    def put(self, key, value):
        """Report a new resource to the tracker."""
        if key not in self._scheduler_map:
            self._scheduler_map[key] = self.create_scheduler(key)
        self._scheduler_map[key].put(value)

    def request(self, key, user, priority, callback):
        """Request a new resource."""
        if key not in self._scheduler_map:
            self._scheduler_map[key] = self.create_scheduler(key)
        self._scheduler_map[key].request(user, priority, callback)

    def close(self, conn):
        """Forget a closed connection and withdraw the resources it reported.

        Connections whose info carries no usable "key", or whose key has no
        scheduler, are dropped without touching any scheduler.
        """
        self._connections.discard(conn)
        if "key" not in conn._info:
            return
        parts = conn._info["key"].split(":")  # 'server:rasp3b' -> 'rasp3b'
        if len(parts) < 2:
            return
        scheduler = self._scheduler_map.get(parts[1])
        if scheduler is None:
            return
        for value in conn.put_values:
            scheduler.remove(value)

    def stop(self):
        """Safely stop tracker."""
        # Iterate over a copy: closing a connection removes it from the set.
        for conn in list(self._connections):
            conn.close()
        self._sock.close()
        self._ioloop.stop()

    def summary(self):
        """Return a dict summarizing current status."""
        qinfo = {k: v.summary() for k, v in self._scheduler_map.items()}
        cinfo = []
        # ignore client connections without key
        for conn in self._connections:
            res = conn.summary()
            if res.get("key", "").startswith("server"):
                cinfo.append(res)
        return {"queue_info": qinfo, "server_info": cinfo}

    def run(self):
        """Run the tracker server"""
        self._ioloop.start()
def _tracker_server(listen_sock, stop_key):
    """Build a tracker handler on `listen_sock` and run its event loop."""
    TrackerServerHandler(listen_sock, stop_key).run()
class Tracker(object):
    """Start RPC tracker on a separate process.

    Python implementation based on multi-processing.

    Parameters
    ----------
    host : str
        The host url of the server.

    port : int
        The TCP port to be bind to

    port_end : int, optional
        The end TCP port to search

    silent: bool, optional
        Whether run in silent mode

    Raises
    ------
    ValueError
        If no port in [port, port_end) could be bound.
    """

    def __init__(self, host, port=9190, port_end=9199, silent=False):
        # Set first so terminate()/__del__ stay safe if construction fails below.
        self.proc = None
        self.port = None
        self.host = host
        if silent:
            logger.setLevel(logging.WARN)

        sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
        self.stop_key = base.random_key("tracker")
        # 98 and 48 are the historical hard-coded values; errno.EADDRINUSE
        # covers platforms where the code differs.
        addr_in_use = (errno.EADDRINUSE, 98, 48)
        try:
            for my_port in range(port, port_end):
                try:
                    sock.bind((host, my_port))
                except socket.error as sock_err:
                    if sock_err.errno in addr_in_use:
                        continue
                    raise
                self.port = my_port
                break
            if not self.port:
                raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
            logger.info("bind to %s:%d", host, self.port)
            sock.listen(1)
            self.proc = multiprocessing.Process(
                target=_tracker_server, args=(sock, self.stop_key)
            )
            self.proc.start()
        finally:
            # close the socket on this process (also when binding failed)
            sock.close()

    def _stop_tracker(self):
        """Connect to the tracker and ask it to stop using the stop key."""
        sock = socket.socket(base.get_addr_family((self.host, self.port)), socket.SOCK_STREAM)
        try:
            sock.connect((self.host, self.port))
            sock.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
            magic = struct.unpack("<i", base.recvall(sock, 4))[0]
            assert magic == base.RPC_TRACKER_MAGIC
            base.sendjson(sock, [TrackerCode.STOP, self.stop_key])
            assert base.recvjson(sock) == TrackerCode.SUCCESS
        finally:
            # Release the socket even if the handshake or the request failed.
            sock.close()

    def terminate(self):
        """Terminate the server process.

        A graceful stop is attempted first; if it fails or the process is
        still alive afterwards, the process is terminated forcefully.
        Calling this on an instance without a running process is a no-op.
        """
        proc = getattr(self, "proc", None)
        if proc is None:
            return
        if proc.is_alive():
            try:
                self._stop_tracker()
            except Exception as err:  # pylint: disable=broad-except
                # Best effort only: fall through to the forced termination
                # below instead of leaking the child process.
                logger.warning("Graceful tracker stop failed: %s", err)
            proc.join(1)
        if proc.is_alive():
            logger.info("Terminating Tracker Server...")
            proc.terminate()
        self.proc = None

    def __del__(self):
        self.terminate()