| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| import unittest |
| import random |
| import requests |
| from libcloud.common.base import Response |
| from libcloud.http import LibcloudConnection |
| from libcloud.utils.py3 import PY2 |
| |
| if PY2: |
| from StringIO import StringIO |
| else: |
| from io import StringIO |
| |
| import requests_mock |
| |
| from libcloud.utils.py3 import httplib |
| from libcloud.utils.py3 import urlparse |
| from libcloud.utils.py3 import parse_qs |
| from libcloud.utils.py3 import parse_qsl |
| from libcloud.utils.py3 import urlquote |
| |
| |
# Header dictionary declaring an XML content type.
XML_HEADERS = {"content-type": "application/xml"}
| |
| |
class LibcloudTestCase(unittest.TestCase):
    """
    Test case base class which keeps track of the URLs that were visited and
    the names of the mock methods that were executed during a test.
    """

    def __init__(self, *args, **kwargs):
        # The tracking lists are created before unittest's own initialisation
        # so they exist even when setUp() has not run.
        self._visited_urls, self._executed_mock_methods = [], []
        super(LibcloudTestCase, self).__init__(*args, **kwargs)

    def setUp(self):
        """Start every test with fresh, empty tracking lists."""
        self._visited_urls, self._executed_mock_methods = [], []

    def _add_visited_url(self, url):
        """Record ``url`` as visited."""
        self._visited_urls.append(url)

    def _add_executed_mock_method(self, method_name):
        """Record ``method_name`` as an executed mock method."""
        self._executed_mock_methods.append(method_name)

    def assertExecutedMethodCount(self, expected):
        """
        Fail unless exactly ``expected`` mock methods have been recorded.

        :param expected: Number of mock method executions expected.
        :type expected: ``int``
        """
        executed = len(self._executed_mock_methods)
        message = "expected %d, but %d mock methods were executed" % (
            expected,
            executed,
        )
        self.assertEqual(executed, expected, message)
| |
| |
class multipleresponse(object):
    """
    Decorator which lets a MockHttp method return a different response on
    each call.

    The wrapped function returns an indexable collection of responses; every
    invocation of the decorated object hands back the next element.
    """

    count = 0
    func = None

    def __init__(self, f):
        self.func = f

    def __call__(self, *args, **kwargs):
        # The wrapped function receives its own class object in the first
        # (``self``) position, followed by the caller's arguments.
        responses = self.func(self.func.__class__, *args, **kwargs)
        position = self.count
        # Index first so the counter only advances when a response exists.
        selected = responses[position]
        self.count = position + 1
        return selected
| |
| |
class BodyStream(StringIO):
    """
    StringIO variant whose ``next``, ``__next__`` and ``read`` methods accept
    an optional ``chunk_size`` argument and ignore it.
    """

    def next(self, chunk_size=None):
        # Python 2 style iterator method; chunk_size is not used.
        item = StringIO.next(self)
        return item

    def __next__(self, chunk_size=None):
        # chunk_size is not used.
        item = StringIO.__next__(self)
        return item

    def read(self, chunk_size=None):
        # Reads everything that is left in the buffer, whatever chunk_size is.
        remaining = StringIO.read(self)
        return remaining
| |
| |
class MockHttp(LibcloudConnection):
    """
    A mock HTTP client/server suitable for testing purposes. This replaces
    `HTTPConnection` by implementing its API and returning a mock response.

    Define methods by request path, replacing slashes (/) with underscores (_).
    Each of these mock methods should return a tuple of:

        (int status, str body, dict headers, str reason)
    """

    type = None
    use_param = None  # will use this param to namespace the request function
    test = None  # TestCase instance which is using this mock
    proxy_url = None

    def __init__(self, *args, **kwargs):
        # Load assertion methods into the class, in case people want to assert
        # within a response
        if isinstance(self, unittest.TestCase):
            unittest.TestCase.__init__(self, "__init__")
        super(MockHttp, self).__init__(*args, **kwargs)

    def _get_request(self, method, url, body=None, headers=None):
        """
        Look up the mock method matching ``url`` and call it.

        If ``self.test`` is a ``LibcloudTestCase``, the URL and the computed
        method name are recorded on it before the mock method is invoked.

        :return: Whatever the mock method returns; expected to be a
                 ``(status, body, headers, reason)`` tuple.
        :raises AttributeError: If no mock method with the computed name
                                exists on this object.
        """
        # Find a method we can use for this request
        parsed = urlparse.urlparse(url)
        _, _, path, _, query, _ = parsed
        qs = parse_qs(query)
        if path.endswith("/"):
            path = path[:-1]
        meth_name = self._get_method_name(
            type=self.type, use_param=self.use_param, qs=qs, path=path
        )
        # "%" cannot appear in an attribute name written in source, so the
        # lookup uses "_" instead; the recorded name keeps the original form.
        meth = getattr(self, meth_name.replace("%", "_"))

        if self.test and isinstance(self.test, LibcloudTestCase):
            self.test._add_visited_url(url=url)
            self.test._add_executed_mock_method(method_name=meth_name)
        return meth(method, url, body, headers)

    def request(self, method, url, body=None, headers=None, raw=False, stream=False):
        """
        Perform a request against a mocked endpoint.

        The response returned by the matching mock method is registered with
        ``requests_mock`` and the request is then issued through the parent
        class. The parent's return value is not propagated.

        :raises AttributeError: If the URL requested by the parent class does
                                not match the registered mock address.
        """
        headers = self._normalize_headers(headers=headers)
        r_status, r_body, r_headers, r_reason = self._get_request(
            method, url, body, headers
        )
        if r_body is None:
            r_body = ""
        # this is to catch any special chars e.g. ~ in the request. URL
        url = urlquote(url)

        with requests_mock.mock() as m:
            m.register_uri(
                method,
                url,
                text=r_body,
                reason=r_reason,
                headers=r_headers,
                status_code=r_status,
            )
            try:
                super(MockHttp, self).request(
                    method=method,
                    url=url,
                    body=body,
                    headers=headers,
                    raw=raw,
                    stream=stream,
                )
            except requests_mock.exceptions.NoMockAddress as nma:
                raise AttributeError(
                    "Failed to mock out URL {0} - {1}".format(url, nma.request.url)
                )

    def prepared_request(
        self, method, url, body=None, headers=None, raw=False, stream=False
    ):
        """
        Same flow as :meth:`request`, but delegating to the parent's
        ``prepared_request``. Unlike :meth:`request`, the URL is not quoted,
        a ``None`` body is passed through unchanged and ``NoMockAddress`` is
        not translated.
        """
        headers = self._normalize_headers(headers=headers)
        r_status, r_body, r_headers, r_reason = self._get_request(
            method, url, body, headers
        )

        with requests_mock.mock() as m:
            m.register_uri(
                method,
                url,
                text=r_body,
                reason=r_reason,
                headers=r_headers,
                status_code=r_status,
            )
            super(MockHttp, self).prepared_request(
                method=method,
                url=url,
                body=body,
                headers=headers,
                raw=raw,
                stream=stream,
            )

    # Mock request/response example
    def _example(self, method, url, body, headers):
        """
        Return a simple message and header, regardless of input.
        """
        return (
            httplib.OK,
            "Hello World!",
            {"X-Foo": "libcloud"},
            httplib.responses[httplib.OK],
        )

    def _example_fail(self, method, url, body, headers):
        """
        Return a FORBIDDEN response, regardless of input.
        """
        return (
            httplib.FORBIDDEN,
            "Oh Noes!",
            {"X-Foo": "fail"},
            httplib.responses[httplib.FORBIDDEN],
        )

    def _get_method_name(self, type, use_param, qs, path):
        """
        Build the mock method name for a request path.

        ``/``, ``.`` and ``-`` in the path become ``_`` and ``~`` becomes
        ``%7E``. When ``type`` is truthy it is appended as ``_<type>``; when
        ``use_param`` names a key present in ``qs``, the first value of that
        query parameter (with ``.`` and ``-`` replaced by ``_``) is appended
        as well. An empty result is mapped to ``"root"``.

        :param type: Optional suffix namespacing the method.
        :param use_param: Optional name of a query parameter to namespace by.
        :param qs: Parsed query string mapping names to lists of values.
        :type qs: ``dict``
        :param path: Request path.
        :type path: ``str``

        :rtype: ``str``
        """
        path = path.split("?")[0]
        meth_name = (
            path.replace("/", "_")
            .replace(".", "_")
            .replace("-", "_")
            .replace("~", "%7E")
        )  # Python 3.7 no longer quotes ~

        if type:
            # Use the argument that was passed in rather than self.type, so
            # the suffix always matches the value that was tested above.
            meth_name = "%s_%s" % (meth_name, type)

        if use_param and use_param in qs:
            param = qs[use_param][0].replace(".", "_").replace("-", "_")
            meth_name = "%s_%s" % (meth_name, param)

        if meth_name == "":
            meth_name = "root"

        return meth_name

    def assertUrlContainsQueryParams(self, url, expected_params, strict=False):
        """
        Assert that provided url contains provided query parameters.

        :param url: URL to assert.
        :type url: ``str``

        :param expected_params: Dictionary of expected query parameters.
        :type expected_params: ``dict``

        :param strict: Assert that provided url contains only expected_params.
                       (defaults to ``False``)
        :type strict: ``bool``

        :raises AssertionError: If the query parameters do not match.
        """
        question_mark_index = url.find("?")

        if question_mark_index != -1:
            url = url[question_mark_index + 1 :]

        params = dict(parse_qsl(url))

        # AssertionError is raised explicitly (instead of using ``assert``)
        # so the checks still run when Python is started with -O.
        if strict:
            if params != expected_params:
                raise AssertionError(
                    "Query parameters %r do not equal expected parameters %r"
                    % (params, expected_params)
                )
        else:
            for key, value in expected_params.items():
                if key not in params:
                    raise AssertionError(
                        "Query parameter %r is missing from %r" % (key, params)
                    )
                if params[key] != value:
                    raise AssertionError(
                        "Query parameter %r has value %r, expected %r"
                        % (key, params[key], value)
                    )
| |
| |
class MockConnection(object):
    """Minimal connection stand-in which only stores an ``action`` value."""

    def __init__(self, action):
        # Keep the supplied action available as an instance attribute.
        self.action = action
| |
| |
# Alias: StorageMockHttp is bound to the very same class object as MockHttp.
StorageMockHttp = MockHttp
| |
| |
def make_response(status=200, headers=None, connection=None):
    """
    Build a ``Response`` wrapping a bare ``requests.Response``.

    :param status: Status code to set on the underlying response.
    :type status: ``int``

    :param headers: Headers to set on the underlying response. When omitted,
                    a new empty dictionary is used for each call.
    :type headers: ``dict``

    :param connection: Connection object handed to ``Response``.

    :return: ``Response`` constructed from the underlying response and
             ``connection``.
    """
    # A fresh dict per call: a ``{}`` default would be one shared object that
    # ends up attached to (and mutable through) every default response.
    if headers is None:
        headers = {}

    response = requests.Response()
    response.status_code = status
    response.headers = headers
    return Response(response, connection)
| |
| |
def generate_random_data(size):
    """
    Return a string of random decimal digits.

    Digits are produced one at a time with ``random.randint(0, 9)`` until the
    accumulated length is no longer smaller than ``size``.

    :param size: Minimum length of the generated string. Values of zero or
                 below yield an empty string.
    :type size: ``int``

    :rtype: ``str``
    """
    # Collect the pieces and join once instead of growing a string with
    # ``+=``, which may copy the whole string on every iteration.
    pieces = []
    current_size = 0
    while current_size < size:
        value = str(random.randint(0, 9))
        pieces.append(value)
        current_size += len(value)
    return "".join(pieces)
| |
| |
if __name__ == "__main__":
    # When this file is executed directly, run the doctests found in it.
    import doctest

    doctest.testmod()