| # -*- coding: utf-8 -*- |
| from base64 import b64encode |
| from hashlib import sha1 |
| import os |
| import socket |
| import ssl |
| |
| from ambari_ws4py import WS_KEY, WS_VERSION |
| from ambari_ws4py.exc import HandshakeError |
| from ambari_ws4py.websocket import WebSocket |
| from ambari_ws4py.compat import urlsplit |
| |
| __all__ = ['WebSocketBaseClient'] |
| |
class WebSocketBaseClient(WebSocket):
    """
    Client side of a websocket connection (see :rfc:`6455`).

    The constructor parses the URL and creates the socket;
    :meth:`connect` opens the connection and performs the upgrade
    handshake.
    """

    def __init__(self, url, protocols=None, extensions=None,
                 heartbeat_freq=None, ssl_options=None, headers=None, exclude_headers=None):
        """
        A websocket client that implements :rfc:`6455` and provides a simple
        interface to communicate with a websocket server.

        This class works on its own but will block if not run in
        its own thread.

        When an instance of this class is created, a :py:mod:`socket`
        is created. If the connection is a TCP socket,
        Nagle's algorithm is disabled.

        The address of the server will be extracted from the given
        websocket url.

        The websocket key is randomly generated, reset the
        `key` attribute if you want to provide yours.

        For instance to create a TCP client:

        .. code-block:: python

           >>> from ambari_ws4py.client import WebSocketBaseClient
           >>> ws = WebSocketBaseClient('ws://localhost/ws')


        Here is an example for a TCP client over SSL:

        .. code-block:: python

           >>> from ambari_ws4py.client import WebSocketBaseClient
           >>> ws = WebSocketBaseClient('wss://localhost/ws')


        Finally an example of a Unix-domain connection:

        .. code-block:: python

           >>> from ambari_ws4py.client import WebSocketBaseClient
           >>> ws = WebSocketBaseClient('ws+unix:///tmp/my.sock')

        Note that in this case, the initial Upgrade request
        will be sent to ``/``. You may need to change this
        by setting the resource explicitly before connecting:

        .. code-block:: python

           >>> from ambari_ws4py.client import WebSocketBaseClient
           >>> ws = WebSocketBaseClient('ws+unix:///tmp/my.sock')
           >>> ws.resource = '/ws'
           >>> ws.connect()

        You may provide extra headers by passing a list of tuples
        which must be unicode objects.

        :param url: websocket URL, see :meth:`_parse_url` for the
            accepted forms.
        :param protocols: forwarded unchanged to ``WebSocket.__init__``;
            when set, also sent as ``Sec-WebSocket-Protocol``.
        :param extensions: forwarded unchanged to ``WebSocket.__init__``.
        :param heartbeat_freq: forwarded unchanged to ``WebSocket.__init__``.
        :param ssl_options: keyword arguments used to wrap the socket
            in SSL when the scheme is ``wss``. The mapping is copied.
        :param headers: list of ``(name, value)`` tuples appended to the
            handshake headers.
        :param exclude_headers: header names (compared case-insensitively)
            that must be left out of the handshake request.
        :raises ValueError: if the URL has no usable host or scheme.
        """
        self.url = url
        self.host = None
        self.scheme = None
        self.port = None
        self.unix_socket_path = None
        self.resource = None
        # Copied so that the ``cert_reqs`` default applied below never
        # leaks into a dictionary owned by the caller.
        self.ssl_options = dict(ssl_options) if ssl_options else {}
        self.extra_headers = headers or []
        self.exclude_headers = [x.lower() for x in (exclude_headers or [])]

        # The URL has to be parsed first: ``self.scheme`` is still None
        # until then, and the check below would never match.
        self._parse_url()

        if self.scheme == "wss":
            # Prevent check_hostname requires server_hostname (ref #187)
            self.ssl_options.setdefault("cert_reqs", ssl.CERT_NONE)

        if self.unix_socket_path:
            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
        else:
            # Let's handle IPv4 and IPv6 addresses
            # Simplified from CherryPy's code
            try:
                # Only the first three fields of the first result are
                # needed to create the socket.
                family, socktype, proto = socket.getaddrinfo(
                    self.host, self.port,
                    socket.AF_UNSPEC,
                    socket.SOCK_STREAM,
                    0, socket.AI_PASSIVE)[0][:3]
            except socket.gaierror:
                # Resolution failed: fall back to a plain TCP socket and
                # guess the family from the shape of the host string.
                family = socket.AF_INET6 if self.host.startswith('::') else socket.AF_INET
                socktype = socket.SOCK_STREAM
                proto = 0

            sock = socket.socket(family, socktype, proto)
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            if hasattr(socket, 'AF_INET6') and family == socket.AF_INET6 and \
               self.host.startswith('::'):
                try:
                    sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
                except (AttributeError, socket.error):
                    # Best effort only: the option may be missing or
                    # refused on this platform.
                    pass

        WebSocket.__init__(self, sock, protocols=protocols,
                           extensions=extensions,
                           heartbeat_freq=heartbeat_freq)

        # NOTE(review): presumably "mask what we send, do not expect
        # masked frames back", as RFC 6455 requires of clients --
        # confirm against the stream implementation.
        self.stream.always_mask = True
        self.stream.expect_masking = False
        self.key = b64encode(os.urandom(16))

    # Adapted from: https://github.com/liris/websocket-client/blob/master/websocket.py#L105
    def _parse_url(self):
        """
        Parses a URL which must have one of the following forms:

        - ws://host[:port][path]
        - wss://host[:port][path]
        - ws+unix:///path/to/my.socket

        In the first two cases, the ``host`` and ``port``
        attributes will be set to the parsed values. If no port
        is explicitly provided, it will be either 80 or 443
        based on the scheme. Also, the ``resource`` attribute is
        set to the path segment of the URL (alongside any querystring).

        In addition, if the scheme is ``ws+unix``, the
        ``unix_socket_path`` attribute is set to the path to
        the Unix socket while the ``resource`` attribute is
        set to ``/``.

        :raises ValueError: if the URL has no ``scheme:`` prefix, if no
            hostname can be extracted from a non-Unix URL, or if the
            scheme is not one of ``ws``, ``wss``, ``ws+unix``,
            ``wss+unix``.
        """
        # Python 2.6.1 and below don't parse ws or wss urls properly. netloc is empty.
        # See: https://github.com/Lawouach/WebSocket-for-Python/issues/59
        scheme, url = self.url.split(":", 1)
        is_unix = '+unix' in scheme

        parsed = urlsplit(url, scheme="http")
        if parsed.hostname:
            self.host = parsed.hostname
        elif is_unix:
            self.host = 'localhost'
        else:
            raise ValueError("Invalid hostname from: %s" % self.url)

        if parsed.port:
            self.port = parsed.port

        if scheme == "ws":
            if not self.port:
                self.port = 80
        elif scheme == "wss":
            if not self.port:
                self.port = 443
        elif scheme not in ('ws+unix', 'wss+unix'):
            raise ValueError("Invalid scheme: %s" % scheme)

        resource = parsed.path or "/"

        if is_unix:
            # For Unix sockets the URL path names the socket file, so
            # the HTTP resource falls back to the root.
            self.unix_socket_path = resource
            resource = '/'

        if parsed.query:
            resource += "?" + parsed.query

        self.scheme = scheme
        self.resource = resource

    @property
    def bind_addr(self):
        """
        Returns the Unix socket path if one was parsed from the URL,
        otherwise a tuple ``(host, port)``.
        """
        return self.unix_socket_path or (self.host, self.port)

    def close(self, code=1000, reason=''):
        """
        Initiate the closing handshake with the server.

        Does nothing when ``client_terminated`` is already set.

        :param code: closing status code handed to ``stream.close``.
        :param reason: closing reason handed to ``stream.close``.
        """
        if not self.client_terminated:
            self.client_terminated = True
            self._write(self.stream.close(code=code, reason=reason).single(mask=True))

    @staticmethod
    def _wrap_socket_with_context(sock, keyfile=None, certfile=None,
                                  server_side=False, cert_reqs=ssl.CERT_NONE,
                                  ssl_version=None, ca_certs=None,
                                  do_handshake_on_connect=True,
                                  suppress_ragged_eofs=True, ciphers=None):
        """
        Replicate the legacy ``ssl.wrap_socket`` on top of ``ssl.SSLContext``.

        Accepts the same keyword arguments as the legacy function so
        that existing ``ssl_options`` dictionaries keep working.
        """
        if server_side and not certfile:
            raise ValueError("certfile must be specified for server-side operations")
        if keyfile and not certfile:
            raise ValueError("certfile must be specified")
        if ssl_version is None:
            ssl_version = ssl.PROTOCOL_TLS

        context = ssl.SSLContext(ssl_version)
        context.verify_mode = cert_reqs
        if ca_certs:
            context.load_verify_locations(ca_certs)
        if certfile:
            context.load_cert_chain(certfile, keyfile)
        if ciphers:
            context.set_ciphers(ciphers)
        return context.wrap_socket(sock, server_side=server_side,
                                   do_handshake_on_connect=do_handshake_on_connect,
                                   suppress_ragged_eofs=suppress_ragged_eofs)

    def _wrap_ssl_socket(self, sock):
        """
        Wrap ``sock`` in SSL using ``self.ssl_options``.
        """
        legacy_wrap = getattr(ssl, 'wrap_socket', None)
        if legacy_wrap is not None:
            return legacy_wrap(sock, **self.ssl_options)
        # ssl.wrap_socket was removed in Python 3.12; build the
        # equivalent context by hand there.
        return self._wrap_socket_with_context(sock, **self.ssl_options)

    def connect(self):
        """
        Connects this websocket and starts the upgrade handshake
        with the remote endpoint.

        :raises HandshakeError: if the server sends no response or
            the response fails validation; the connection is closed
            before the error propagates.
        """
        if self.scheme == "wss":
            # default port is now 443; upgrade self.sender to send ssl
            self.sock = self._wrap_ssl_socket(self.sock)
            self._is_secure = True

        self.sock.settimeout(10.0)
        self.sock.connect(self.bind_addr)

        self._write(self.handshake_request)

        # Read until the blank line ending the HTTP headers shows up,
        # or until the peer closes the connection.
        response = b''
        double_crlf = b'\r\n\r\n'
        while True:
            chunk = self.sock.recv(128)
            if not chunk:
                break
            response += chunk
            if double_crlf in response:
                break

        if not response:
            self.close_connection()
            raise HandshakeError("Invalid response")

        headers, _, body = response.partition(double_crlf)
        response_line, _, headers = headers.partition(b'\r\n')

        try:
            self.process_response_line(response_line)
            self.protocols, self.extensions = self.process_handshake_header(headers)
        except HandshakeError:
            self.close_connection()
            raise

        self.handshake_ok()
        # Bytes received after the headers already belong to the
        # websocket stream.
        if body:
            self.process(body)

    @property
    def handshake_headers(self):
        """
        List of headers appropriate for the upgrade
        handshake.

        An ``Origin`` header derived from the URL is added unless one
        was supplied through the extra headers or ``origin`` is
        excluded. Headers named in ``exclude_headers`` are dropped
        last.
        """
        headers = [
            ('Host', '%s:%s' % (self.host, self.port)),
            ('Connection', 'Upgrade'),
            ('Upgrade', 'websocket'),
            ('Sec-WebSocket-Key', self.key.decode('utf-8')),
            ('Sec-WebSocket-Version', str(max(WS_VERSION)))
        ]

        if self.protocols:
            headers.append(('Sec-WebSocket-Protocol', ','.join(self.protocols)))

        if self.extra_headers:
            headers.extend(self.extra_headers)

        has_origin = any(x[0].lower() == 'origin' for x in headers)
        if not has_origin and 'origin' not in self.exclude_headers:
            scheme, url = self.url.split(":", 1)
            parsed = urlsplit(url, scheme="http")
            # Kept in a local: reading this property must not rewrite
            # ``self.host``.
            origin_host = parsed.hostname or 'localhost'
            origin = scheme + '://' + origin_host
            if parsed.port:
                origin = origin + ':' + str(parsed.port)
            headers.append(('Origin', origin))

        headers = [x for x in headers if x[0].lower() not in self.exclude_headers]

        return headers

    @property
    def handshake_request(self):
        """
        Prepare the request to be sent for the upgrade handshake.

        Returns the request line and headers as UTF-8 encoded bytes,
        terminated by an empty line.
        """
        request = [("GET %s HTTP/1.1" % self.resource).encode('utf-8')]
        request.extend(("%s: %s" % (header, value)).encode('utf-8')
                       for header, value in self.handshake_headers)
        # Joining with this trailing element yields the blank line that
        # ends the header section.
        request.append(b'\r\n')

        return b'\r\n'.join(request)

    def process_response_line(self, response_line):
        """
        Ensure that we received a HTTP `101` status code in
        response to our request and if not raises :exc:`HandshakeError`.

        A status line without a reason phrase is accepted; one without
        a status code raises :exc:`HandshakeError` as well.
        """
        parts = response_line.split(b' ', 2)
        if len(parts) < 2:
            raise HandshakeError("Invalid response line: %s" % (response_line,))

        code = parts[1]
        status = parts[2] if len(parts) > 2 else b''
        if code != b'101':
            raise HandshakeError("Invalid response status: %s %s" % (code, status))

    def process_handshake_header(self, headers):
        """
        Read the upgrade handshake's response headers and
        validate them against :rfc:`6455`.

        Only headers present in the response are checked; this method
        does not require any of them to be present.

        :param headers: raw header section as bytes, lines separated
            by CRLF, without the status line.
        :returns: tuple ``(protocols, extensions)``, two lists of
            lowercased byte strings taken from the
            ``Sec-WebSocket-Protocol`` and ``Sec-WebSocket-Extensions``
            headers.
        :raises HandshakeError: on a malformed header line or an
            invalid ``Upgrade``, ``Connection`` or
            ``Sec-WebSocket-Accept`` value.
        """
        protocols = []
        extensions = []

        for header_line in headers.strip().split(b'\r\n'):
            header, separator, raw_value = header_line.partition(b':')
            if not separator:
                # Raised as HandshakeError so that connect() closes the
                # connection instead of leaking a bare ValueError.
                raise HandshakeError("Invalid header line: %s" % (header_line,))

            header = header.strip().lower()
            raw_value = raw_value.strip()
            value = raw_value.lower()

            if header == b'upgrade' and value != b'websocket':
                raise HandshakeError("Invalid Upgrade header: %s" % value)

            elif header == b'connection' and value != b'upgrade':
                raise HandshakeError("Invalid Connection header: %s" % value)

            elif header == b'sec-websocket-accept':
                # Base64 is case-sensitive, so the challenge response is
                # compared exactly, not after lowercasing.
                expected = b64encode(sha1(self.key + WS_KEY).digest())
                if raw_value != expected:
                    raise HandshakeError("Invalid challenge response: %s" % raw_value)

            elif header == b'sec-websocket-protocol':
                protocols.extend([x.strip() for x in value.split(b',')])

            elif header == b'sec-websocket-extensions':
                extensions.extend([x.strip() for x in value.split(b',')])

        return protocols, extensions

    def handshake_ok(self):
        """
        Called by :meth:`connect` once the server response has been
        validated; invokes ``opened()``.
        """
        self.opened()