| #!/usr/bin/env python |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| from __future__ import print_function, unicode_literals |
| |
| import re |
| import ssl |
| |
| from thrift.transport import TSSLSocket |
| from thrift.transport.TTransport import TTransportException |
| |
class CertificateError(ValueError):
    """Signals that a peer certificate failed hostname validation.

    Derives from ValueError, so handlers that catch ValueError also catch it.
    """
| |
class TSSLSocketWithWildcardSAN(TSSLSocket.TSSLSocket):
    """TSSLSocket that also accepts wildcard and subjectAltName certificates.

    This is a subclass of thrift's TSSLSocket which has been extended to add the
    missing functionality of validating wildcard certificates and certificates
    with SANs (subjectAlternativeName).

    The core of the validation logic is based on the python-ssl library:
    See <https://svn.python.org/projects/python/tags/r32/Lib/ssl.py>
    """

    def __init__(self,
                 host='localhost',
                 port=9090,
                 validate=True,
                 ca_certs=None,
                 unix_socket=None):
        """Initialise the underlying TSSLSocket.

        host, port, ca_certs and unix_socket are forwarded unchanged to
        TSSLSocket.__init__. A true *validate* selects ssl.CERT_REQUIRED,
        a false one ssl.CERT_NONE.
        """
        cert_reqs = ssl.CERT_REQUIRED if validate else ssl.CERT_NONE
        # Set client protocol choice to be very permissive, as we rely on
        # servers to enforce good protocol selection. This value is forwarded to
        # the ssl.wrap_socket() API during open(). See
        # https://docs.python.org/2/library/ssl.html#socket-creation for a table
        # that shows a better option is not readily available for sockets that
        # use wrap_socket().
        # THRIFT-3505 changes transport/TSSLSocket.py. The SSL_VERSION is passed
        # to TSSLSocket via a parameter.
        TSSLSocket.TSSLSocket.__init__(self, host=host, port=port,
                                       cert_reqs=cert_reqs, ca_certs=ca_certs,
                                       unix_socket=unix_socket,
                                       ssl_version=ssl.PROTOCOL_SSLv23)

    # THRIFT-5595: override TSocket.isOpen because it's broken for TSSLSocket
    def isOpen(self):
        """Return True when a socket handle is currently set on this transport."""
        return self.handle is not None

    def _validate_cert(self):
        """Validate the peer certificate of the open handle against self.host.

        The decoded certificate is stored in self.peercert. On success
        self.is_valid is set to True and nothing is returned.

        Raises TTransportException with type NOT_OPEN when the peer supplied no
        usable certificate, and with type UNKNOWN when the certificate does not
        match self.host.
        """
        cert = self.handle.getpeercert()
        self.peercert = cert
        # getpeercert() returns None when the peer sent no certificate and an
        # empty dict when the certificate was not validated; report both as a
        # missing certificate instead of failing with a TypeError on None.
        if not cert or 'subject' not in cert:
            raise TTransportException(
                type=TTransportException.NOT_OPEN,
                message='No SSL certificate found from %s:%s' % (self.host, self.port))
        try:
            self._match_hostname(cert, self.host)
        except CertificateError as ce:
            raise TTransportException(
                type=TTransportException.UNKNOWN,
                message='Certificate error with remote host: %s' % (ce))
        # _match_hostname() either returns (match) or raises CertificateError,
        # so reaching this point means the certificate is valid for the host.
        self.is_valid = True

    def _match_hostname(self, cert, hostname):
        """Verify that *cert* (in decoded format as returned by
        SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 and RFC 6125
        rules are followed, but IP addresses are not accepted for *hostname*.

        CertificateError is raised on failure. On success, the function
        returns nothing.
        """
        dnsnames = []
        for key, value in cert.get('subjectAltName', ()):
            if key == 'DNS':
                if self._dnsname_match(value, hostname):
                    return
                dnsnames.append(value)
        if not dnsnames:
            # The subject is only checked when there is no dNSName entry
            # in subjectAltName
            for sub in cert.get('subject', ()):
                for key, value in sub:
                    # XXX according to RFC 2818, the most specific Common Name
                    # must be used.
                    if key == 'commonName':
                        if self._dnsname_match(value, hostname):
                            return
                        dnsnames.append(value)
        # Nothing matched: report every candidate name that was tried.
        if len(dnsnames) > 1:
            raise CertificateError("hostname %r "
                                   "doesn't match either of %s"
                                   % (hostname, ', '.join(map(repr, dnsnames))))
        elif len(dnsnames) == 1:
            raise CertificateError("hostname %r "
                                   "doesn't match %r"
                                   % (hostname, dnsnames[0]))
        else:
            raise CertificateError("no appropriate commonName or "
                                   "subjectAltName fields were found")

    def _dnsname_match(self, dn, hostname, max_wildcards=1):
        """Matching according to RFC 6125, section 6.4.3
        http://tools.ietf.org/html/rfc6125#section-6.4.3

        Returns True when the certificate name *dn* matches *hostname*
        (case-insensitively) and False otherwise, including when *dn* is empty.
        Only wildcards in the left-most label of *dn* are honoured.

        Raises CertificateError when the left-most label of *dn* contains more
        than *max_wildcards* '*' characters.
        """
        if not dn:
            return False

        # Ported from python3-syntax:
        # leftmost, *remainder = dn.split(r'.')
        parts = dn.split(r'.')
        leftmost = parts[0]
        remainder = parts[1:]

        wildcards = leftmost.count('*')
        if wildcards > max_wildcards:
            # Issue #17980: avoid denials of service by refusing more
            # than one wildcard per fragment. A survey of established
            # policy among SSL implementations showed it to be a
            # reasonable choice.
            raise CertificateError(
                "too many wildcards in certificate DNS name: " + repr(dn))

        # speed up common case w/o wildcards
        if not wildcards:
            return dn.lower() == hostname.lower()

        pats = []
        # RFC 6125, section 6.4.3, subitem 1.
        # The client SHOULD NOT attempt to match a presented identifier in which
        # the wildcard character comprises a label other than the left-most label.
        if leftmost == '*':
            # When '*' is a fragment by itself, it matches a non-empty dotless
            # fragment.
            pats.append('[^.]+')
        elif leftmost.startswith('xn--') or hostname.startswith('xn--'):
            # RFC 6125, section 6.4.3, subitem 3.
            # The client SHOULD NOT attempt to match a presented identifier
            # where the wildcard character is embedded within an A-label or
            # U-label of an internationalized domain name.
            pats.append(re.escape(leftmost))
        else:
            # Otherwise, '*' matches any dotless string, e.g. www*
            pats.append(re.escape(leftmost).replace(r'\*', '[^.]*'))

        # add the remaining fragments, ignore any wildcards
        for frag in remainder:
            pats.append(re.escape(frag))

        pat = re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE)
        # Normalise to bool so every return path yields the same type.
        return pat.match(hostname) is not None