# blob: ce1c98fbbf90b00e51a0cc6324d884b48d4a2c2f [file] [log] [blame]
# -*- coding: utf-8 -*-
from base64 import b64encode
from hashlib import sha1
import os
import socket
import ssl
from ambari_ws4py import WS_KEY, WS_VERSION
from ambari_ws4py.exc import HandshakeError
from ambari_ws4py.websocket import WebSocket
from ambari_ws4py.compat import urlsplit
__all__ = ['WebSocketBaseClient']


class WebSocketBaseClient(WebSocket):
    def __init__(self, url, protocols=None, extensions=None,
                 heartbeat_freq=None, ssl_options=None, headers=None, exclude_headers=None):
        """
        A websocket client that implements :rfc:`6455` and provides a simple
        interface to communicate with a websocket server.

        This class works on its own but will block if not run in
        its own thread.

        When an instance of this class is created, a :py:mod:`socket`
        is created. If the connection is a TCP socket,
        Nagle's algorithm is disabled.

        The address of the server will be extracted from the given
        websocket url.

        The websocket key is randomly generated, reset the
        `key` attribute if you want to provide yours.

        For instance to create a TCP client:

        .. code-block:: python

           >>> from ambari_ws4py.client import WebSocketBaseClient
           >>> ws = WebSocketBaseClient('ws://localhost/ws')

        Here is an example for a TCP client over SSL:

        .. code-block:: python

           >>> from ambari_ws4py.client import WebSocketBaseClient
           >>> ws = WebSocketBaseClient('wss://localhost/ws')

        Finally an example of a Unix-domain connection:

        .. code-block:: python

           >>> from ambari_ws4py.client import WebSocketBaseClient
           >>> ws = WebSocketBaseClient('ws+unix:///tmp/my.sock')

        Note that in this case, the initial Upgrade request
        will be sent to ``/``. You may need to change this
        by setting the resource explicitly before connecting:

        .. code-block:: python

           >>> from ambari_ws4py.client import WebSocketBaseClient
           >>> ws = WebSocketBaseClient('ws+unix:///tmp/my.sock')
           >>> ws.resource = '/ws'
           >>> ws.connect()

        You may provide extra headers by passing a list of tuples
        which must be unicode objects.

        ``ssl_options`` is copied; for ``wss`` URLs ``cert_reqs`` defaults
        to :data:`ssl.CERT_NONE` unless the caller provided it. Header names
        listed in ``exclude_headers`` (case-insensitive) are dropped from
        the upgrade request.

        :raises ValueError: if the URL, its hostname or its scheme is invalid.
        """
        self.url = url
        self.host = None
        self.scheme = None
        self.port = None
        self.unix_socket_path = None
        self.resource = None
        # Work on a copy so the defaults filled in below never leak back into
        # a dict the caller may share between several clients.
        self.ssl_options = dict(ssl_options or {})
        self.extra_headers = headers or []
        self.exclude_headers = [x.lower() for x in (exclude_headers or [])]

        # _parse_url() is what sets self.scheme, so it must run before any
        # scheme-dependent default is applied (previously the check below ran
        # first and could never match).
        self._parse_url()

        if self.scheme == "wss":
            # Prevent check_hostname requires server_hostname (ref #187)
            self.ssl_options.setdefault("cert_reqs", ssl.CERT_NONE)

        if self.unix_socket_path:
            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
        else:
            sock = self._create_tcp_socket()

        WebSocket.__init__(self, sock, protocols=protocols,
                           extensions=extensions,
                           heartbeat_freq=heartbeat_freq)

        # Client-side stream settings: mask every outgoing frame and do not
        # expect incoming frames to be masked.
        self.stream.always_mask = True
        self.stream.expect_masking = False
        # Random 16-byte nonce, base64-encoded; sent as Sec-WebSocket-Key.
        self.key = b64encode(os.urandom(16))

    def _create_tcp_socket(self):
        """
        Create (but do not connect) a TCP socket for ``self.host`` and
        ``self.port``, with Nagle's algorithm disabled.
        """
        # Let's handle IPv4 and IPv6 addresses
        # Simplified from CherryPy's code
        try:
            family, socktype, proto = socket.getaddrinfo(self.host, self.port,
                                                         socket.AF_UNSPEC,
                                                         socket.SOCK_STREAM,
                                                         0, socket.AI_PASSIVE)[0][:3]
        except socket.gaierror:
            # Resolution failed here: build a plain stream socket anyway;
            # connect() will try to resolve the address again.
            family = socket.AF_INET
            if self.host.startswith('::'):
                family = socket.AF_INET6
            socktype = socket.SOCK_STREAM
            proto = 0

        sock = socket.socket(family, socktype, proto)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        if hasattr(socket, 'AF_INET6') and family == socket.AF_INET6 and \
                self.host.startswith('::'):
            # Best effort: clear IPV6_V6ONLY; not every platform supports it.
            try:
                sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
            except (AttributeError, socket.error):
                pass
        return sock

    # Adapted from: https://github.com/liris/websocket-client/blob/master/websocket.py#L105
    def _parse_url(self):
        """
        Parses a URL which must have one of the following forms:

        - ws://host[:port][path]
        - wss://host[:port][path]
        - ws+unix:///path/to/my.socket

        In the first two cases, the ``host`` and ``port``
        attributes will be set to the parsed values. If no port
        is explicitly provided, it will be either 80 or 443
        based on the scheme. Also, the ``resource`` attribute is
        set to the path segment of the URL (alongside any querystring).

        In addition, if the scheme is ``ws+unix``, the
        ``unix_socket_path`` attribute is set to the path to
        the Unix socket while the ``resource`` attribute is
        set to ``/``.

        :raises ValueError: if the URL has no scheme separator, no usable
            hostname (for non-unix schemes), or an unsupported scheme.
        """
        # Python 2.6.1 and below don't parse ws or wss urls properly. netloc is empty.
        # See: https://github.com/Lawouach/WebSocket-for-Python/issues/59
        scheme, sep, url = self.url.partition(":")
        if not sep:
            raise ValueError("Invalid URL: %s" % self.url)
        parsed = urlsplit(url, scheme="http")

        if parsed.hostname:
            self.host = parsed.hostname
        elif '+unix' in scheme:
            self.host = 'localhost'
        else:
            # Format the message (the old code passed self.url as a second
            # exception argument, leaving "%s" unformatted).
            raise ValueError("Invalid hostname from: %s" % self.url)

        if parsed.port:
            self.port = parsed.port

        if scheme == "ws":
            if not self.port:
                self.port = 80
        elif scheme == "wss":
            if not self.port:
                self.port = 443
        elif scheme not in ('ws+unix', 'wss+unix'):
            raise ValueError("Invalid scheme: %s" % scheme)

        resource = parsed.path or "/"
        if '+unix' in scheme:
            self.unix_socket_path = resource
            resource = '/'

        if parsed.query:
            resource += "?" + parsed.query

        self.scheme = scheme
        self.resource = resource

    @property
    def bind_addr(self):
        """
        Returns the Unix socket path, or a ``(host, port)`` tuple,
        depending on the initial URL's scheme.
        """
        return self.unix_socket_path or (self.host, self.port)

    def close(self, code=1000, reason=''):
        """
        Initiate the closing handshake with the server.

        Only the first call writes a (masked) close frame; later calls are
        no-ops because ``client_terminated`` is already set.
        """
        if not self.client_terminated:
            self.client_terminated = True
            self._write(self.stream.close(code=code, reason=reason).single(mask=True))

    def _wrap_ssl_socket(self, sock):
        """
        Return ``sock`` wrapped for TLS using ``self.ssl_options``, which are
        interpreted as keyword arguments of the legacy :func:`ssl.wrap_socket`.
        """
        options = dict(self.ssl_options)
        legacy_wrap = getattr(ssl, 'wrap_socket', None)
        if legacy_wrap is not None:
            return legacy_wrap(sock, **options)

        # ssl.wrap_socket() was removed in Python 3.12: rebuild what it did
        # with an SSLContext so existing ssl_options keep working unchanged.
        keyfile = options.pop('keyfile', None)
        certfile = options.pop('certfile', None)
        cert_reqs = options.pop('cert_reqs', ssl.CERT_NONE)
        ssl_version = options.pop('ssl_version', None)
        ca_certs = options.pop('ca_certs', None)
        ciphers = options.pop('ciphers', None)
        if keyfile and not certfile:
            raise ValueError("certfile must be specified")

        if ssl_version is None:
            context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        else:
            context = ssl.SSLContext(ssl_version)
        # ssl.wrap_socket() never checked hostnames; check_hostname must be
        # cleared before verify_mode can be lowered to CERT_NONE.
        context.check_hostname = False
        context.verify_mode = cert_reqs
        if ca_certs:
            context.load_verify_locations(ca_certs)
        if certfile:
            context.load_cert_chain(certfile, keyfile)
        if ciphers:
            context.set_ciphers(ciphers)
        # Whatever is left (server_side, do_handshake_on_connect,
        # suppress_ragged_eofs) is forwarded to SSLContext.wrap_socket.
        return context.wrap_socket(sock, **options)

    def connect(self):
        """
        Connects this websocket and starts the upgrade handshake
        with the remote endpoint.

        :raises HandshakeError: if the server's response is empty, lacks a
            complete header block, or fails validation; the connection is
            closed before raising.
        """
        if self.scheme == "wss":
            # Upgrade the plain socket to TLS before connecting.
            self.sock = self._wrap_ssl_socket(self.sock)
            self._is_secure = True

        self.sock.settimeout(10.0)
        self.sock.connect(self.bind_addr)

        self._write(self.handshake_request)

        terminator = b'\r\n\r\n'
        response = b''
        while terminator not in response:
            chunk = self.sock.recv(128)
            if not chunk:
                break
            response += chunk

        if terminator not in response:
            # Nothing, or only a truncated header block, was received.
            self.close_connection()
            raise HandshakeError("Invalid response")

        headers, _, body = response.partition(terminator)
        response_line, _, headers = headers.partition(b'\r\n')
        try:
            self.process_response_line(response_line)
            self.protocols, self.extensions = self.process_handshake_header(headers)
        except HandshakeError:
            self.close_connection()
            raise

        self.handshake_ok()
        if body:
            # Bytes read past the header terminator are handed to process()
            # rather than discarded.
            self.process(body)

    @staticmethod
    def _format_authority(host, port=None):
        """
        Render ``host[:port]`` for an HTTP header, bracketing IPv6 literals
        and omitting the port when there is none (e.g. Unix sockets).
        """
        if ':' in host and not host.startswith('['):
            host = '[%s]' % host
        if port:
            return '%s:%s' % (host, port)
        return host

    @property
    def handshake_headers(self):
        """
        List of headers appropriate for the upgrade
        handshake.

        An ``Origin`` header derived from the URL is added unless one was
        supplied in the extra headers or ``origin`` is excluded. Any header
        whose lower-cased name is in ``exclude_headers`` is dropped.
        """
        headers = [
            ('Host', self._format_authority(self.host, self.port)),
            ('Connection', 'Upgrade'),
            ('Upgrade', 'websocket'),
            ('Sec-WebSocket-Key', self.key.decode('utf-8')),
            ('Sec-WebSocket-Version', str(max(WS_VERSION)))
        ]
        if self.protocols:
            headers.append(('Sec-WebSocket-Protocol', ','.join(self.protocols)))
        if self.extra_headers:
            headers.extend(self.extra_headers)

        if 'origin' not in self.exclude_headers and \
                not any(h[0].lower() == 'origin' for h in headers):
            scheme, _, url = self.url.partition(":")
            parsed = urlsplit(url, scheme="http")
            # Local name on purpose: reading this property must not
            # overwrite self.host.
            origin_host = parsed.hostname or 'localhost'
            origin = scheme + '://' + self._format_authority(origin_host, parsed.port)
            headers.append(('Origin', origin))

        return [h for h in headers if h[0].lower() not in self.exclude_headers]

    @property
    def handshake_request(self):
        """
        Prepare the request to be sent for the upgrade handshake.

        Returns the UTF-8 encoded request line and headers, CRLF-separated
        and terminated by an empty line.
        """
        headers = self.handshake_headers
        request = [("GET %s HTTP/1.1" % self.resource).encode('utf-8')]
        for header, value in headers:
            request.append(("%s: %s" % (header, value)).encode('utf-8'))
        request.append(b'\r\n')

        return b'\r\n'.join(request)

    def process_response_line(self, response_line):
        """
        Ensure that we received a HTTP `101` status code in
        response to our request and if not raises :exc:`HandshakeError`.
        """
        # Some servers omit the reason phrase (and its leading space); that
        # must not turn into a ValueError that bypasses connect()'s cleanup.
        parts = response_line.split(b' ', 2)
        code = parts[1] if len(parts) > 1 else b''
        status = parts[2] if len(parts) > 2 else b''
        if code != b'101':
            raise HandshakeError("Invalid response status: %s %s" % (code, status))

    def process_handshake_header(self, headers):
        """
        Read the upgrade handshake's response headers and
        validate them against :rfc:`6455`.

        :param headers: raw header block (bytes) that followed the status line.
        :returns: ``(protocols, extensions)``, lists of lower-cased byte strings.
        :raises HandshakeError: on a malformed header line, on an unexpected
            ``Upgrade``/``Connection``/``Sec-WebSocket-Accept`` value, or when
            any of those three headers is missing.
        """
        protocols = []
        extensions = []
        seen = set()
        expected_accept = b64encode(sha1(self.key + WS_KEY).digest())

        for header_line in headers.strip().split(b'\r\n'):
            if not header_line.strip():
                continue
            header, sep, raw_value = header_line.partition(b':')
            if not sep:
                raise HandshakeError("Invalid header line: %s" % header_line)
            header = header.strip().lower()
            raw_value = raw_value.strip()
            value = raw_value.lower()
            seen.add(header)

            if header == b'upgrade':
                if value != b'websocket':
                    raise HandshakeError("Invalid Upgrade header: %s" % value)
            elif header == b'connection':
                # Connection is a comma-separated token list ("keep-alive, Upgrade").
                if b'upgrade' not in [token.strip() for token in value.split(b',')]:
                    raise HandshakeError("Invalid Connection header: %s" % value)
            elif header == b'sec-websocket-accept':
                # Base64 is case-sensitive: compare the value as received.
                if raw_value != expected_accept:
                    raise HandshakeError("Invalid challenge response: %s" % value)
            elif header == b'sec-websocket-protocol':
                protocols.extend([x.strip() for x in value.split(b',')])
            elif header == b'sec-websocket-extensions':
                extensions.extend([x.strip() for x in value.split(b',')])

        # RFC 6455 section 4.1: the client must fail the connection when any
        # of these headers is absent from the server's response.
        for name in (b'upgrade', b'connection', b'sec-websocket-accept'):
            if name not in seen:
                raise HandshakeError("Missing %s header" % name.decode('ascii'))

        return protocols, extensions

    def handshake_ok(self):
        """Called by connect() once the handshake validated; invokes opened()."""
        self.opened()