| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| |
| import inspect |
| import logging |
| import time |
| import os |
| import platform |
| import ssl |
| import sys |
| import tempfile |
| import threading |
| import unittest |
| import warnings |
| from contextlib import contextmanager |
| |
| import _import_local_thrift # noqa |
| |
# Directory containing this test script (symlinks resolved).
SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__))
# Repository root: three directory levels above SCRIPT_DIR.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR)))
# Shared directory holding the test certificates and private keys.
_KEYS_DIR = os.path.join(ROOT_DIR, 'test', 'keys')

SERVER_PEM = os.path.join(_KEYS_DIR, 'server.pem')
SERVER_CERT = os.path.join(_KEYS_DIR, 'server.crt')
SERVER_KEY = os.path.join(_KEYS_DIR, 'server.key')
CLIENT_CERT_NO_IP = os.path.join(_KEYS_DIR, 'client.crt')
CLIENT_KEY_NO_IP = os.path.join(_KEYS_DIR, 'client.key')
CLIENT_CERT = os.path.join(_KEYS_DIR, 'client_v3.crt')
CLIENT_KEY = os.path.join(_KEYS_DIR, 'client_v3.key')
CLIENT_CA = os.path.join(_KEYS_DIR, 'CA.pem')

# Cipher list handed to both server and client in the cipher tests.
TEST_CIPHERS = 'DES-CBC3-SHA:ECDHE-RSA-AES128-GCM-SHA256'
| |
| |
class ServerAcceptor(threading.Thread):
    """Daemon thread that serves exactly one client on a Thrift server socket.

    The thread calls ``server.listen()`` and records the bound port. It then
    accepts one connection and reads 5 bytes from it (``b"hello"`` or
    ``b"sleep"``; the latter delays the reply by two seconds). Finally it
    answers with ``b"there"``.

    Callers synchronize through three waiters:

    * ``await_listening()`` returns once listening has started;
    * ``port`` returns once the bound port is known;
    * ``client`` returns once the accept/exchange attempt is over.

    All three are released even when setup or the exchange fails, so a broken
    server makes the test fail instead of hanging.

    :param server: server transport exposing ``listen()``, ``accept()``,
        ``close()`` and ``handle.getsockname()``.
    :param expect_failure: if true, errors while accepting or talking to the
        client are logged and swallowed instead of re-raised. Setup errors
        (``listen``/``getsockname``) always propagate.
    """

    def __init__(self, server, expect_failure=False):
        super(ServerAcceptor, self).__init__()
        self.daemon = True
        self._server = server
        self._listening = threading.Event()
        self._port = None
        self._port_bound = threading.Event()
        self._client = None
        self._client_accepted = threading.Event()
        self._expect_failure = expect_failure
        # Name the thread after the function two call levels above this
        # constructor (frame record index 3 is the function name), so that
        # server-side log lines identify the originating test helper.
        frame = inspect.stack(3)[2]
        self.name = frame[3]
        # Drop the frame reference promptly, as the inspect docs advise.
        del frame

    def run(self):
        try:
            self._server.listen()
            self._listening.set()

            address = self._server.handle.getsockname()
            # AF_INET addresses are 2-tuples (host, port) and AF_INET6 are
            # 4-tuples (host, port, ...), but in each case port is in the second slot.
            # AF_UNIX yields a path string instead, which has no port; indexing
            # it would record a single character of the path.
            if isinstance(address, tuple) and len(address) > 1:
                self._port = address[1]
        except Exception:
            # Setup failed, so no client will ever be accepted: release the
            # client waiter as well before letting the error surface.
            self._client_accepted.set()
            raise
        finally:
            # Release these waiters unconditionally so await_listening() and
            # port cannot block forever when setup fails.
            self._listening.set()
            self._port_bound.set()

        try:
            self._client = self._server.accept()
            if self._client:
                data = self._client.read(5)  # hello/sleep
                if data == b"sleep":
                    time.sleep(2)
                self._client.write(b"there")
        except Exception:
            logging.exception('error on server side (%s):', self.name)
            if not self._expect_failure:
                raise
        finally:
            self._client_accepted.set()

    def await_listening(self):
        """Block until the server is listening (or its setup has failed)."""
        self._listening.wait()

    @property
    def port(self):
        """Bound port, or ``None`` if the address has no port or setup failed."""
        self._port_bound.wait()
        return self._port

    @property
    def client(self):
        """Accepted client transport, or ``None`` if nothing was accepted."""
        self._client_accepted.wait()
        return self._client

    def close(self):
        """Close the accepted client (if any), then always the server."""
        try:
            if self._client:
                self._client.close()
        finally:
            self._server.close()
| |
| |
# Python 2.6 compat
class AssertRaises(object):
    """Fallback for ``TestCase.assertRaises`` used as a context manager.

    An exception whose type is a subclass of ``expected`` (a class, or a tuple
    of classes as accepted by ``issubclass``) raised in the ``with`` body is
    suppressed. If nothing, or anything else, is raised, a generic
    ``Exception('fail')`` is raised instead.
    """

    def __init__(self, expected):
        self._expected = expected

    def __enter__(self):
        return None

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type and issubclass(exc_type, self._expected):
            # A true return value tells the with-statement to swallow it.
            return True
        raise Exception('fail')
| |
| |
@unittest.skip("failing SSL test to be fixed in subsequent pull request")
class TSSLSocketTest(unittest.TestCase):
    """End-to-end tests for ``TSSLSocket`` and ``TSSLServerSocket``.

    Each scenario starts a ``ServerAcceptor`` thread on a fresh server socket
    and connects a client with the given options. It then asserts that the
    handshake plus a ``hello``/``there`` exchange either succeeds or fails.
    Optional features missing from the running Python/ssl build are reported
    through ``skipTest`` rather than silently passing.
    """

    def _server_socket(self, **kwargs):
        """Create a ``TSSLServerSocket`` on port 0 (port is read back by the acceptor)."""
        return TSSLServerSocket(port=0, **kwargs)

    @contextmanager
    def _connectable_client(self, server, expect_failure=False, path=None, **client_kwargs):
        """Start an acceptor for ``server`` and yield ``(acceptor, client)``.

        The client targets ``localhost`` on the acceptor's port, or the unix
        socket ``path`` when one is given. The acceptor (and with it the
        server) is closed on exit.
        """
        acc = ServerAcceptor(server, expect_failure)
        try:
            acc.start()
            acc.await_listening()

            host, port = ('localhost', acc.port) if path is None else (None, None)
            client = TSSLSocket(host, port, unix_socket=path, **client_kwargs)
            yield acc, client
        finally:
            acc.close()

    def _assert_connection_failure(self, server, path=None, **client_args):
        """Assert that a client built from ``client_args`` cannot talk to ``server``.

        Server-side errors are expected, so logging is disabled during the
        check and restored afterwards.
        """
        logging.disable(logging.CRITICAL)
        try:
            with self._connectable_client(server, True, path=path, **client_args) as (acc, client):
                # We need to wait for a connection failure, but not too long. 20ms is a tunable
                # compromise between test speed and stability
                client.setTimeout(20)
                with self._assert_raises(TTransportException):
                    client.open()
                    client.write(b"hello")
                    client.read(5)  # b"there"
        finally:
            logging.disable(logging.NOTSET)

    def _assert_raises(self, exc):
        """Return an ``assertRaises``-style context manager (with a 2.6 fallback)."""
        if sys.hexversion >= 0x020700F0:
            return self.assertRaises(exc)
        else:
            return AssertRaises(exc)

    def _assert_connection_success(self, server, path=None, **client_args):
        """Assert that a client built from ``client_args`` completes a hello/there round trip."""
        with self._connectable_client(server, path=path, **client_args) as (acc, client):
            try:
                self.assertFalse(client.isOpen())
                client.open()
                self.assertTrue(client.isOpen())
                client.write(b"hello")
                self.assertEqual(client.read(5), b"there")
                self.assertIsNotNone(acc.client)
            finally:
                client.close()

    # deprecated feature
    def test_deprecation(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            TSSLSocket('localhost', 0, validate=True, ca_certs=SERVER_CERT)
            self.assertEqual(len(w), 1)

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            # Deprecated signature
            # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None):
            TSSLSocket('localhost', 0, True, SERVER_CERT, CLIENT_KEY, CLIENT_CERT, None, TEST_CIPHERS)
            self.assertEqual(len(w), 7)

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            # Deprecated signature
            # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
            TSSLServerSocket(None, 0, SERVER_PEM, None, TEST_CIPHERS)
            self.assertEqual(len(w), 3)

    # deprecated feature
    def test_set_cert_reqs_by_validate(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            c1 = TSSLSocket('localhost', 0, validate=True, ca_certs=SERVER_CERT)
            self.assertEqual(c1.cert_reqs, ssl.CERT_REQUIRED)

            c1 = TSSLSocket('localhost', 0, validate=False)
            self.assertEqual(c1.cert_reqs, ssl.CERT_NONE)

            self.assertEqual(len(w), 2)

    # deprecated feature
    def test_set_validate_by_cert_reqs(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            c1 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_NONE)
            self.assertFalse(c1.validate)

            c2 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
            self.assertTrue(c2.validate)

            c3 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_OPTIONAL, ca_certs=SERVER_CERT)
            self.assertTrue(c3.validate)

            self.assertEqual(len(w), 3)

    def test_unix_domain_socket(self):
        if platform.system() == 'Windows':
            self.skipTest('skipping test_unix_domain_socket on Windows')
        # Reserve a unique temporary name, then remove the file so the server
        # can create its socket at that path.
        fd, path = tempfile.mkstemp()
        os.close(fd)
        os.unlink(path)
        try:
            server = self._server_socket(unix_socket=path, keyfile=SERVER_KEY, certfile=SERVER_CERT)
            self._assert_connection_success(server, path=path, cert_reqs=ssl.CERT_NONE)
        finally:
            # The path only exists again if a socket was actually created there;
            # unlinking unconditionally would raise and mask the real failure.
            if os.path.exists(path):
                os.unlink(path)

    def test_server_cert(self):
        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)

        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        # server cert not in ca_certs
        self._assert_connection_failure(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=CLIENT_CERT)

        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE)

    def test_set_server_cert(self):
        server = self._server_socket(keyfile=SERVER_KEY, certfile=CLIENT_CERT)
        with self._assert_raises(Exception):
            server.certfile = 'foo'
        with self._assert_raises(Exception):
            server.certfile = None
        server.certfile = SERVER_CERT
        self._assert_connection_success(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)

    def test_client_cert(self):
        if not _match_has_ipaddress:
            self.skipTest('skipping test_client_cert: _match_has_ipaddress is false')
        server = self._server_socket(
            cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CERT)
        self._assert_connection_failure(server, cert_reqs=ssl.CERT_NONE, certfile=SERVER_CERT, keyfile=SERVER_KEY)

        server = self._server_socket(
            cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
        self._assert_connection_failure(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT_NO_IP, keyfile=CLIENT_KEY_NO_IP)

        server = self._server_socket(
            cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)

        server = self._server_socket(
            cert_reqs=ssl.CERT_OPTIONAL, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)

    def test_ciphers(self):
        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
        self._assert_connection_success(server, ca_certs=SERVER_CERT, ciphers=TEST_CIPHERS)

        if not TSSLSocket._has_ciphers:
            self.skipTest('skipping test_ciphers: TSSLSocket._has_ciphers is false')
        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        self._assert_connection_failure(server, ca_certs=SERVER_CERT, ciphers='NULL')

        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
        self._assert_connection_failure(server, ca_certs=SERVER_CERT, ciphers='NULL')

    def test_ssl2_and_ssl3_disabled(self):
        if not hasattr(ssl, 'PROTOCOL_SSLv3'):
            print('PROTOCOL_SSLv3 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)

            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT)

        if not hasattr(ssl, 'PROTOCOL_SSLv2'):
            print('PROTOCOL_SSLv2 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)

            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT)

    def test_newer_tls(self):
        if not TSSLSocket._has_ssl_context:
            self.skipTest('skipping test_newer_tls: TSSLSocket._has_ssl_context is false')
        if not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
            print('PROTOCOL_TLSv1_2 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            self._assert_connection_success(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)

        if not hasattr(ssl, 'PROTOCOL_TLSv1_1'):
            print('PROTOCOL_TLSv1_1 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
            self._assert_connection_success(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)

        if not hasattr(ssl, 'PROTOCOL_TLSv1_1') or not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
            print('PROTOCOL_TLSv1_1 and/or PROTOCOL_TLSv1_2 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)

    def test_ssl_context(self):
        if not TSSLSocket._has_ssl_context:
            self.skipTest('skipping test_ssl_context: TSSLSocket._has_ssl_context is false')
        server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        server_context.load_cert_chain(SERVER_CERT, SERVER_KEY)
        server_context.load_verify_locations(CLIENT_CA)
        server_context.verify_mode = ssl.CERT_REQUIRED
        server = self._server_socket(ssl_context=server_context)

        client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        client_context.load_cert_chain(CLIENT_CERT, CLIENT_KEY)
        client_context.load_verify_locations(SERVER_CERT)
        client_context.verify_mode = ssl.CERT_REQUIRED

        self._assert_connection_success(server, ssl_context=client_context)
| |
| |
# Starting with Python 3.12, a test file in which every test is skipped is
# treated as an error, so keep one trivially passing test in this module.
class DummyTest(unittest.TestCase):
    """Placeholder ensuring this module always runs at least one test."""

    def test_dummy(self):
        self.assertTrue(True)
| |
| |
if __name__ == '__main__':
    logging.basicConfig(level=logging.WARNING)
    # The thrift names are bound as module globals only when this file runs
    # as a script; the test methods above look them up at call time.
    from thrift.transport.TSSLSocket import (
        TSSLSocket,
        TSSLServerSocket,
        _match_has_ipaddress,
    )
    from thrift.transport.TTransport import TTransportException

    unittest.main()