| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| """RPC Tracker, tracks and distributes the TVM RPC resources. |
| |
| This module implements the tracker server logic. |
| |
| Note |
| ---- |
| Tracker is a TCP based rest api with the following protocol: |
| - Initial handshake to the peer |
| - RPC_TRACKER_MAGIC |
| - Normal message: [size(int32), json-data] |
| - Each message is initiated by the client, and the tracker replies with a json. |
| |
| List of available APIs: |
| |
| - PING: check if tracker is alive |
| - input: [TrackerCode.PING] |
| - return: TrackerCode.SUCCESS |
| - PUT: report resource to tracker |
| - input: [TrackerCode.PUT, key, [port, match-key], custom-addr (optional)] |
| - return: TrackerCode.SUCCESS |
| - note: match-key is a randomly generated identifier for the resource during connection. |
| - REQUEST: request a new resource from tracker |
| - input: [TrackerCode.REQUEST, key, user, priority] |
| - return: [TrackerCode.SUCCESS, [url, port, match-key]] |
| """ |
| # pylint: disable=invalid-name |
| |
| import heapq |
| import time |
| import logging |
| import socket |
| import multiprocessing |
| import errno |
| import struct |
| import json |
| |
| try: |
| from tornado import ioloop |
| from . import tornado_util |
| except ImportError as error_msg: |
| raise ImportError( |
| "RPCTracker module requires tornado package %s. Try 'pip install tornado'." % error_msg |
| ) |
| |
| from .._ffi.base import py_str |
| from . import base |
| from .base import RPC_TRACKER_MAGIC, TrackerCode |
| |
# Logger shared by every class and function in this module.
logger = logging.getLogger(name="RPCTracker")
| |
| |
class Scheduler(object):
    """Abstract interface of scheduler."""

    def put(self, value):
        """Push a resource into the scheduler.

        This function can trigger callbacks in the scheduler.

        Parameters
        ----------
        value : object
            The resource to be put in the scheduler.
        """
        raise NotImplementedError()

    def request(self, user, priority, callback):
        """Request a resource.

        Parameters
        ----------
        user : str
            The user who is requesting the resource.

        priority : int
            The job priority.

        callback : function: value->bool
            Callback function to receive a resource when ready;
            returns True if the resource is consumed.
        """
        raise NotImplementedError()

    def remove(self, value):
        """Remove a resource in the scheduler.

        The base implementation does nothing, so subclasses that do not
        track individual resources do not have to override it.

        Parameters
        ----------
        value: object
            The resource to remove
        """

    def summary(self):
        """Get summary information of the scheduler."""
        raise NotImplementedError()


class PriorityScheduler(Scheduler):
    """Priority based scheduler, FIFO based on time.

    Pending requests are served highest priority first, and requests with
    equal priority are served in arrival order. Free resources are handed
    out in the order in which they were put.

    Each resource ``value`` is a tuple whose first element exposes a
    ``pending_matchkeys`` set and whose last element is the match key; the
    request callback receives ``value[1:]``.
    """

    def __init__(self, key):
        self._key = key
        # Free resources, oldest first.
        self._values = []
        # Heap of (-priority, arrival time, sequence number, callback).
        self._requests = []
        # Strictly increasing tie-breaker: when priority and time.time() are
        # equal, heapq would otherwise compare the callbacks and raise TypeError.
        self._seq = 0

    def _schedule(self):
        """Match pending requests with free resources while both exist."""
        while self._requests and self._values:
            value = self._values.pop(0)
            item = heapq.heappop(self._requests)
            callback = item[-1]
            if callback(value[1:]):
                # The resource was consumed, so its match key is no longer
                # pending. discard() tolerates a key that is already gone.
                value[0].pending_matchkeys.discard(value[-1])
            else:
                # The requester could not take it; keep the resource free.
                self._values.append(value)

    def put(self, value):
        self._values.append(value)
        self._schedule()

    def request(self, user, priority, callback):
        self._seq += 1
        heapq.heappush(self._requests, (-priority, time.time(), self._seq, callback))
        self._schedule()

    def remove(self, value):
        if value in self._values:
            self._values.remove(value)
        self._schedule()

    def summary(self):
        """Get summary information of the scheduler."""
        return {"free": len(self._values), "pending": len(self._requests)}
| |
| |
class TCPEventHandler(tornado_util.TCPHandler):
    """Base asynchronous message handler.

    The tracker and client follow a simple message protocol.
    The message is in form [nbytes(int32)] [json-str].
    All the information is packed in json-str.
    """

    def __init__(self, tracker, sock, addr):
        super(TCPEventHandler, self).__init__(sock)
        self._data = bytearray()
        self._tracker = tracker
        self._msg_size = 0
        self._addr = addr
        # Non-zero until the initial 4-byte magic handshake has been handled.
        self._init_req_nbytes = 4
        self._info = {"addr": addr}
        # list of pending match keys that has not been used.
        self.pending_matchkeys = set()
        self._tracker._connections.add(self)
        # Every value this connection PUT; the tracker withdraws them on close.
        self.put_values = []

    def name(self):
        """name of connection"""
        return "TCPSocket: %s" % str(self._addr)

    def summary(self):
        """Summary of this connection"""
        return self._info

    def _init_conn(self, message):
        """Initialize the connection by checking and echoing the tracker magic.

        If the first message is not exactly the 4-byte ``RPC_TRACKER_MAGIC``,
        the connection is closed and nothing is sent back.
        """
        if len(message) != 4:
            logger.warning("Invalid connection from %s", self.name())
            self.close()
            return
        magic = struct.unpack("<i", message)[0]
        if magic != RPC_TRACKER_MAGIC:
            logger.warning("Invalid magic from %s", self.name())
            self.close()
            return
        self.write_message(struct.pack("<i", RPC_TRACKER_MAGIC), binary=True)
        self._init_req_nbytes = 0

    def on_message(self, message):
        """Callback when a message is received.

        Bytes are buffered until a full ``[size(int32)][json]`` frame is
        available. Each complete frame is then decoded and dispatched to
        ``call_handler``.

        Parameters
        ----------
        message : bytearray
            The bytes received
        """
        assert isinstance(message, bytes)
        if self._init_req_nbytes:
            self._init_conn(message)
            return

        self._data += message

        while True:
            if self._msg_size == 0:
                if len(self._data) < 4:
                    return
                self._msg_size = struct.unpack("<i", self._data[:4])[0]
                if self._msg_size <= 0:
                    # A non-positive size can never form a valid frame. Left in
                    # the buffer, the header would wedge the connection forever.
                    logger.warning(
                        "Invalid message size %d from %s", self._msg_size, self.name()
                    )
                    self.close()
                    return
            if len(self._data) < self._msg_size + 4:
                return
            msg = py_str(bytes(self._data[4 : 4 + self._msg_size]))
            del self._data[: 4 + self._msg_size]
            self._msg_size = 0
            self.call_handler(json.loads(msg))

    def ret_value(self, data):
        """Send ``data`` to the peer as a size-prefixed JSON message."""
        payload = json.dumps(data).encode("utf-8")
        # The size prefix counts the bytes of the encoded payload.
        self.write_message(struct.pack("<i", len(payload)), binary=True)
        self.write_message(payload, binary=True)

    def call_handler(self, args):
        """Event handler when json request arrives; ``args[0]`` is the TrackerCode."""
        code = args[0]
        if code == TrackerCode.PUT:
            # [PUT, key, [port, matchkey], custom_addr (optional)]
            key = args[1]
            port, matchkey = args[2]
            self.pending_matchkeys.add(matchkey)
            # got custom address (from rpc server)
            if len(args) >= 4 and args[3] is not None:
                value = (self, args[3], port, matchkey)
            else:
                value = (self, self._addr[0], port, matchkey)
            self._tracker.put(key, value)
            self.put_values.append(value)
            self.ret_value(TrackerCode.SUCCESS)
        elif code == TrackerCode.REQUEST:
            # [REQUEST, key, user, priority]
            key = args[1]
            user = args[2]
            priority = args[3]

            def _cb(value):
                # if the connection is already closed
                if not self._sock:
                    return False
                try:
                    self.ret_value([TrackerCode.SUCCESS, value])
                except (socket.error, IOError):
                    return False
                return True

            self._tracker.request(key, user, priority, _cb)
        elif code == TrackerCode.PING:
            self.ret_value(TrackerCode.SUCCESS)
        elif code == TrackerCode.GET_PENDING_MATCHKEYS:
            self.ret_value(list(self.pending_matchkeys))
        elif code == TrackerCode.STOP:
            # safe stop tracker: only honoured with the matching stop key
            if self._tracker._stop_key == args[1]:
                self.ret_value(TrackerCode.SUCCESS)
                self._tracker.stop()
            else:
                self.ret_value(TrackerCode.FAIL)
        elif code == TrackerCode.UPDATE_INFO:
            self._info.update(args[1])
            self.ret_value(TrackerCode.SUCCESS)
        elif code == TrackerCode.SUMMARY:
            status = self._tracker.summary()
            self.ret_value([TrackerCode.SUCCESS, status])
        else:
            logger.warning("Unknown tracker code %d", code)
            self.close()

    def on_close(self):
        self._tracker.close(self)

    def on_error(self, err):
        logger.warning("%s: Error in RPC Tracker: %s", self.name(), err)
        self.close()
| |
| |
class TrackerServerHandler(object):
    """Tracker that tracks the resources."""

    def __init__(self, sock, stop_key):
        self._scheduler_map = {}
        self._sock = sock
        self._sock.setblocking(0)
        self._ioloop = ioloop.IOLoop.current()
        self._stop_key = stop_key
        self._connections = set()

        def _event_handler(_, events):
            self._on_event(events)

        self._ioloop.add_handler(self._sock.fileno(), _event_handler, self._ioloop.READ)

    def _on_event(self, _):
        """Accept every pending connection on the listening socket."""
        while True:
            try:
                conn, addr = self._sock.accept()
                TCPEventHandler(self, conn, addr)
            except socket.error as err:
                if err.args[0] not in (errno.EAGAIN, errno.EWOULDBLOCK):
                    # Retrying immediately (e.g. on EMFILE) would spin forever
                    # and block the ioloop. Assuming a level-triggered ioloop,
                    # we are called again while the listening socket stays readable.
                    logger.warning("Failed to accept connection: %s", err)
                break

    def create_scheduler(self, key):
        """Create a new scheduler."""
        return PriorityScheduler(key)

    def _get_scheduler(self, key):
        """Return the scheduler for ``key``, creating it on first use."""
        if key not in self._scheduler_map:
            self._scheduler_map[key] = self.create_scheduler(key)
        return self._scheduler_map[key]

    def put(self, key, value):
        """Report a new resource to the tracker."""
        self._get_scheduler(key).put(value)

    def request(self, key, user, priority, callback):
        """Request a new resource."""
        self._get_scheduler(key).request(user, priority, callback)

    def close(self, conn):
        """Forget a closed connection and withdraw every resource it reported.

        Values are removed from all schedulers, so cleanup does not depend on
        the connection having reported a well-formed ``"key"`` (such as
        ``'server:rasp3b'``) through UPDATE_INFO.
        """
        self._connections.discard(conn)
        if not conn.put_values:
            return
        for scheduler in self._scheduler_map.values():
            for value in conn.put_values:
                scheduler.remove(value)

    def stop(self):
        """Safely stop tracker."""
        for conn in list(self._connections):
            conn.close()
        self._sock.close()
        self._ioloop.stop()

    def summary(self):
        """Return a dict summarizing current status."""
        qinfo = {}
        for k, v in self._scheduler_map.items():
            qinfo[k] = v.summary()
        cinfo = []
        # ignore client connections without key
        for conn in self._connections:
            res = conn.summary()
            if res.get("key", "").startswith("server"):
                cinfo.append(res)
        return {"queue_info": qinfo, "server_info": cinfo}

    def run(self):
        """Run the tracker server"""
        self._ioloop.start()
| |
| |
def _tracker_server(listen_sock, stop_key):
    """Child-process entry point: serve on ``listen_sock`` until the ioloop is stopped."""
    TrackerServerHandler(listen_sock, stop_key).run()
| |
| |
class Tracker(object):
    """Start RPC tracker on a separate process.

    Python implementation based on multi-processing.

    Parameters
    ----------
    host : str
        The host url of the server.

    port : int
        The TCP port to be bind to. ``0`` lets the OS pick a free port;
        the chosen port is stored in ``self.port``.

    port_end : int, optional
        The end TCP port to search (exclusive).

    silent: bool, optional
        Whether run in silent mode
    """

    def __init__(self, host, port=9190, port_end=9199, silent=False):
        # Set these first so terminate()/__del__ stay safe if __init__ raises.
        self.proc = None
        self.host = host
        self.port = None
        if silent:
            logger.setLevel(logging.WARNING)

        sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
        self.stop_key = base.random_key("tracker")
        try:
            for my_port in range(port, port_end):
                try:
                    sock.bind((host, my_port))
                except socket.error as sock_err:
                    # Port already taken: try the next one in the range.
                    if sock_err.errno == errno.EADDRINUSE:
                        continue
                    raise
                # Read the bound port back so that port=0 also works.
                self.port = sock.getsockname()[1]
                break
            if self.port is None:
                raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
            logger.info("bind to %s:%d", host, self.port)
            sock.listen(1)
            self.proc = multiprocessing.Process(target=_tracker_server, args=(sock, self.stop_key))
            self.proc.start()
        finally:
            # The child process serves on the socket; close our copy here,
            # and never leak it when binding fails.
            sock.close()

    def _stop_tracker(self):
        """Send the STOP request, with this tracker's stop key, to the tracker process."""
        sock = socket.socket(base.get_addr_family((self.host, self.port)), socket.SOCK_STREAM)
        try:
            sock.connect((self.host, self.port))
            sock.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
            magic = struct.unpack("<i", base.recvall(sock, 4))[0]
            assert magic == base.RPC_TRACKER_MAGIC
            base.sendjson(sock, [TrackerCode.STOP, self.stop_key])
            assert base.recvjson(sock) == TrackerCode.SUCCESS
        finally:
            sock.close()

    def terminate(self):
        """Terminate the server process.

        First asks the tracker to stop itself and waits up to one second.
        If the STOP request fails, or the process is still alive afterwards,
        the process is killed with ``Process.terminate()``.
        """
        # getattr: __init__ may have failed before ``proc`` existed.
        proc = getattr(self, "proc", None)
        if proc:
            if proc.is_alive():
                try:
                    self._stop_tracker()
                except Exception as err:  # pylint: disable=broad-except
                    # Fall through to the forced terminate() below.
                    logger.warning("Failed to stop tracker gracefully: %s", err)
                else:
                    proc.join(1)
            if proc.is_alive():
                logger.info("Terminating Tracker Server...")
                proc.terminate()
            self.proc = None

    def __del__(self):
        self.terminate()