| # -*- coding: utf-8 -*- |
import logging
import ssl

from tornado import iostream, escape

from ambari_ws4py.client import WebSocketBaseClient
from ambari_ws4py.exc import HandshakeError
| |
# Names exported by ``from <this module> import *``.
__all__ = ['TornadoWebSocketClient']
| |
class TornadoWebSocketClient(WebSocketBaseClient):
    """WebSocket client driven by a Tornado ``IOStream``.

    Connecting, the HTTP upgrade handshake and all subsequent frame reads
    are chained through callbacks registered on ``self.io``. Nothing happens
    until :meth:`connect` is called and the IOLoop is running.
    """

    def __init__(self, url, protocols=None, extensions=None,
                 io_loop=None, ssl_options=None, headers=None, exclude_headers=None):
        """
        Create a client for ``url``; call :meth:`connect` to start it.

        ``url``, ``protocols``, ``extensions``, ``ssl_options``, ``headers``
        and ``exclude_headers`` are forwarded to ``WebSocketBaseClient``.
        ``io_loop`` is passed to the Tornado ``IOStream``/``SSLIOStream``
        constructor and also kept as ``self.io_loop``. For ``wss`` URLs the
        socket is TLS-wrapped with ``self.ssl_options``.

        .. code-block:: python

            from tornado import ioloop

            class MyClient(TornadoWebSocketClient):
                def opened(self):
                    for i in range(0, 200, 25):
                        self.send("*" * i)

                def received_message(self, m):
                    print((m, len(str(m))))

                def closed(self, code, reason=None):
                    ioloop.IOLoop.instance().stop()

            ws = MyClient('ws://localhost:9000/echo', protocols=['http-only', 'chat'])
            ws.connect()

            ioloop.IOLoop.instance().start()
        """
        WebSocketBaseClient.__init__(self, url, protocols, extensions,
                                     ssl_options=ssl_options, headers=headers,
                                     exclude_headers=exclude_headers)
        if self.scheme == "wss":
            # NOTE(review): ssl.wrap_socket() is deprecated since Python 3.7 and
            # removed in 3.12; running on 3.12+ needs an SSLContext built from
            # self.ssl_options instead. TODO.
            # The TLS handshake is deferred (do_handshake_on_connect=False)
            # because the socket is not connected yet.
            self.sock = ssl.wrap_socket(self.sock, do_handshake_on_connect=False, **self.ssl_options)
            self._is_secure = True
            self.io = iostream.SSLIOStream(self.sock, io_loop, ssl_options=self.ssl_options)
        else:
            self.io = iostream.IOStream(self.sock, io_loop)
        self.io_loop = io_loop

    def connect(self):
        """
        Connects the websocket and initiate the upgrade handshake.
        """
        # A stream close before the connect callback fires is reported to
        # closed() as a refused connection.
        self.io.set_close_callback(self.__connection_refused)
        self.io.connect((self.host, int(self.port)), self.__send_handshake)

    def _write(self, b):
        """
        Trying to prevent a write operation
        on an already closed websocket stream.

        This cannot be bullet proof but hopefully
        will catch almost all use cases.
        """
        if self.terminated:
            raise RuntimeError("Cannot send on a terminated websocket")

        self.io.write(b)

    def __connection_refused(self, *args, **kwargs):
        # Close callback while the TCP connection is still being established.
        self.server_terminated = True
        self.closed(1005, 'Connection refused')

    def __send_handshake(self):
        # Connect callback: from now on a close means the handshake failed.
        self.io.set_close_callback(self.__connection_closed)
        self.io.write(escape.utf8(self.handshake_request),
                      self.__handshake_sent)

    def __connection_closed(self, *args, **kwargs):
        # Close callback while the upgrade request/response is in flight.
        self.server_terminated = True
        self.closed(1006, 'Connection closed during handshake')

    def __handshake_sent(self):
        # Write callback: read the response up to the blank line ending the headers.
        self.io.read_until(b"\r\n\r\n", self.__handshake_completed)

    def __handshake_completed(self, data):
        # The full response header block has arrived; detach the
        # handshake-phase close handler before validating it.
        self.io.set_close_callback(None)
        try:
            response_line, _, headers = data.partition(b'\r\n')
            self.process_response_line(response_line)
            # Keep what the server negotiated instead of discarding it, the
            # same way ws4py's blocking base client assigns these attributes.
            self.protocols, self.extensions = self.process_handshake_header(headers)
        except HandshakeError:
            self.close_connection()
            raise

        self.opened()
        self.io.set_close_callback(self.__stream_closed)
        self.io.read_bytes(self.reading_buffer_size, self.__fetch_more)

    def __fetch_more(self, data):
        # Read callback: feed the bytes to the frame parser and either request
        # the next chunk (sized by reading_buffer_size) or shut down.
        try:
            should_continue = self.process(data)
        except Exception:
            # Any processing failure ends the connection, but record why
            # instead of silently discarding the error.
            logging.getLogger(__name__).exception(
                "Error while processing websocket data; terminating connection")
            should_continue = False

        if should_continue:
            self.io.read_bytes(self.reading_buffer_size, self.__fetch_more)
        else:
            self.__gracefully_terminate()

    def __gracefully_terminate(self):
        # Mark both sides terminated and close the IOStream. If no close frame
        # was received, report an abnormal closure (1006) here; otherwise
        # __stream_closed reports the received code/reason when the stream closes.
        self.client_terminated = self.server_terminated = True

        report_here = not self.stream.closing
        try:
            if report_here:
                # Detach __stream_closed first: closing the IOStream below
                # would otherwise invoke it and call closed() a second time.
                self.io.set_close_callback(None)
                self.closed(1006)
        finally:
            self.close_connection()
            if report_here:
                # __stream_closed is detached, so perform its cleanup here.
                self.stream._cleanup()

    def __stream_closed(self, *args, **kwargs):
        # Close callback after a successful handshake: report the code/reason
        # of a received close frame if any, otherwise 1006.
        self.io.set_close_callback(None)
        code = 1006
        reason = None
        if self.stream.closing:
            code, reason = self.stream.closing.code, self.stream.closing.reason
        self.closed(code, reason)
        self.stream._cleanup()

    def close_connection(self):
        """
        Close the underlying connection
        """
        self.io.close()
| |
if __name__ == '__main__':
    # Manual smoke test against a local echo server on port 9000.
    from tornado import ioloop

    class DemoClient(TornadoWebSocketClient):
        def opened(self):
            # One streamed payload built lazily from a generator...
            self.send("#" * width for width in range(0, 200, 25))

            # ...followed by one frame per size.
            for width in range(0, 200, 25):
                self.send("*" * width)

        def received_message(self, m):
            size = len(m)
            print("#%d" % size)
            # The largest echoed frame marks the end of the demo.
            if size == 175:
                self.close()

        def closed(self, code, reason=None):
            ioloop.IOLoop.instance().stop()
            print(("Closed down", code, reason))

    client = DemoClient('ws://localhost:9000/ws', protocols=['http-only', 'chat'])
    client.connect()

    ioloop.IOLoop.instance().start()