blob: e67ac2de0de68e9f7bf2edf07fbbfc055c83186b [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import random
import requests
from libcloud.common.base import Response
from libcloud.http import LibcloudConnection
from libcloud.utils.py3 import PY2
if PY2:
from StringIO import StringIO
else:
from io import StringIO
import requests_mock
from libcloud.utils.py3 import httplib
from libcloud.utils.py3 import urlparse
from libcloud.utils.py3 import parse_qs
from libcloud.utils.py3 import parse_qsl
from libcloud.utils.py3 import urlquote
# Response headers declaring an XML body ("content-type: application/xml").
XML_HEADERS = {"content-type": "application/xml"}
class LibcloudTestCase(unittest.TestCase):
    """
    TestCase base class that keeps track of the URLs a mock visited and the
    mock methods it executed, so tests can assert on them.
    """

    def __init__(self, *args, **kwargs):
        # Tracking lists exist even before setUp() runs.
        self._visited_urls = list()
        self._executed_mock_methods = list()
        super(LibcloudTestCase, self).__init__(*args, **kwargs)

    def setUp(self):
        # Every test starts from empty tracking lists.
        self._visited_urls, self._executed_mock_methods = list(), list()

    def _add_visited_url(self, url):
        """Record ``url`` as visited."""
        self._visited_urls.append(url)

    def _add_executed_mock_method(self, method_name):
        """Record ``method_name`` as an executed mock method."""
        self._executed_mock_methods.append(method_name)

    def assertExecutedMethodCount(self, expected):
        """
        Fail unless exactly ``expected`` mock methods have been executed.

        :param expected: Expected number of executed mock methods.
        :type expected: ``int``
        """
        executed_count = len(self._executed_mock_methods)
        failure_message = "expected %d, but %d mock methods were executed" % (
            expected,
            executed_count,
        )
        self.assertEqual(executed_count, expected, failure_message)
class multipleresponse(object):
    """
    A decorator that allows MockHttp objects to return multi responses.

    The wrapped function returns a sequence of responses. The first call of
    the decorated object yields element 0, the second element 1, and so on.
    The position counter lives on the decorator instance, so it is shared by
    every caller of the decorated function.

    The wrapped function is invoked with ``self.func.__class__`` (the type of
    the function object) as its first positional argument instead of a real
    instance, so it should not rely on that argument.
    """

    count = 0
    func = None

    def __init__(self, f):
        self.func = f

    def __call__(self, *args, **kwargs):
        all_responses = self.func(self.func.__class__, *args, **kwargs)
        # Index before advancing: if the sequence is exhausted, the
        # IndexError propagates and the counter is left untouched.
        current = all_responses[self.count]
        self.count += 1
        return current
class BodyStream(StringIO):
    """
    In-memory response body whose iteration/read methods accept (and ignore)
    a ``chunk_size`` argument.
    """

    def read(self, chunk_size=None):
        # chunk_size is accepted for API compatibility only: the whole
        # remaining buffer is returned in one go.
        remaining = StringIO.read(self)
        return remaining

    def __next__(self, chunk_size=None):
        # Python 3 iteration protocol; yields the next line.
        line = StringIO.__next__(self)
        return line

    def next(self, chunk_size=None):
        # Python 2 iteration protocol; yields the next line.
        line = StringIO.next(self)
        return line
class MockHttp(LibcloudConnection):
    """
    A mock HTTP client/server suitable for testing purposes. This replaces
    `HTTPConnection` by implementing its API and returning a mock response.

    Define methods by request path, replacing slashes (/) with underscores (_).
    Each of these mock methods should return a tuple of:

        (int status, str body, dict headers, str reason)
    """

    type = None
    use_param = None  # will use this param to namespace the request function
    test = None  # TestCase instance which is using this mock
    proxy_url = None

    def __init__(self, *args, **kwargs):
        # Load assertion methods into the class, in case people want to assert
        # within a response
        if isinstance(self, unittest.TestCase):
            unittest.TestCase.__init__(self, "__init__")
        super(MockHttp, self).__init__(*args, **kwargs)

    def _get_request(self, method, url, body=None, headers=None):
        """
        Dispatch a request to the mock method matching its path.

        :return: The ``(status, body, headers, reason)`` tuple produced by the
            matching mock method.
        :raises AttributeError: If no mock method with the derived name exists.
        """
        # Find a method we can use for this request
        parsed = urlparse.urlparse(url)
        _, _, path, _, query, _ = parsed
        qs = parse_qs(query)

        if path.endswith("/"):
            path = path[:-1]

        meth_name = self._get_method_name(
            type=self.type, use_param=self.use_param, qs=qs, path=path
        )
        # "%" cannot appear in a Python identifier, so an encoded "~" ("%7E")
        # is looked up as "_7E".
        meth = getattr(self, meth_name.replace("%", "_"))

        if self.test and isinstance(self.test, LibcloudTestCase):
            self.test._add_visited_url(url=url)
            self.test._add_executed_mock_method(method_name=meth_name)

        return meth(method, url, body, headers)

    def request(self, method, url, body=None, headers=None, raw=False, stream=False):
        """
        Serve a request from the matching mock method via ``requests_mock``.

        :raises AttributeError: If ``requests_mock`` has no registered response
            for the URL that was actually requested.
        """
        headers = self._normalize_headers(headers=headers)
        r_status, r_body, r_headers, r_reason = self._get_request(
            method, url, body, headers
        )
        if r_body is None:
            r_body = ""
        # this is to catch any special chars e.g. ~ in the request URL
        url = urlquote(url)

        with requests_mock.mock() as m:
            m.register_uri(
                method,
                url,
                text=r_body,
                reason=r_reason,
                headers=r_headers,
                status_code=r_status,
            )
            try:
                super(MockHttp, self).request(
                    method=method,
                    url=url,
                    body=body,
                    headers=headers,
                    raw=raw,
                    stream=stream,
                )
            except requests_mock.exceptions.NoMockAddress as nma:
                raise AttributeError(
                    "Failed to mock out URL {0} - {1}".format(url, nma.request.url)
                )

    def prepared_request(
        self, method, url, body=None, headers=None, raw=False, stream=False
    ):
        """
        Serve a prepared request from the matching mock method via
        ``requests_mock``.
        """
        headers = self._normalize_headers(headers=headers)
        r_status, r_body, r_headers, r_reason = self._get_request(
            method, url, body, headers
        )
        # Treat a missing mock body as empty, exactly like request() does.
        if r_body is None:
            r_body = ""

        with requests_mock.mock() as m:
            m.register_uri(
                method,
                url,
                text=r_body,
                reason=r_reason,
                headers=r_headers,
                status_code=r_status,
            )
            super(MockHttp, self).prepared_request(
                method=method,
                url=url,
                body=body,
                headers=headers,
                raw=raw,
                stream=stream,
            )

    # Mock request/response example
    def _example(self, method, url, body, headers):
        """
        Return a simple message and header, regardless of input.
        """
        return (
            httplib.OK,
            "Hello World!",
            {"X-Foo": "libcloud"},
            httplib.responses[httplib.OK],
        )

    def _example_fail(self, method, url, body, headers):
        """
        Return a 403 response with a fixed body and header, regardless of input.
        """
        return (
            httplib.FORBIDDEN,
            "Oh Noes!",
            {"X-Foo": "fail"},
            httplib.responses[httplib.FORBIDDEN],
        )

    def _get_method_name(self, type, use_param, qs, path):
        """
        Derive the mock method name for a request path.

        :param type: Optional suffix appended as ``_<type>``.
        :param use_param: Optional query-string parameter whose first value is
            appended as a further suffix when present in ``qs``.
        :param qs: Parsed query string (as returned by ``parse_qs``).
        :param path: Request path.
        :return: Method name; ``"root"`` when the derived name is empty.
        """
        path = path.split("?")[0]
        meth_name = (
            path.replace("/", "_")
            .replace(".", "_")
            .replace("-", "_")
            .replace("~", "%7E")
        )  # Python 3.7 no longer quotes ~

        if type:
            # Use the ``type`` argument that was just checked; previously this
            # read ``self.type``, silently ignoring the argument.
            meth_name = "%s_%s" % (meth_name, type)

        if use_param and use_param in qs:
            param = qs[use_param][0].replace(".", "_").replace("-", "_")
            meth_name = "%s_%s" % (meth_name, param)

        if meth_name == "":
            meth_name = "root"

        return meth_name

    def assertUrlContainsQueryParams(self, url, expected_params, strict=False):
        """
        Assert that provided url contains provided query parameters.

        :param url: URL to assert.
        :type url: ``str``

        :param expected_params: Dictionary of expected query parameters.
        :type expected_params: ``dict``

        :param strict: Assert that provided url contains only expected_params.
                       (defaults to ``False``)
        :type strict: ``bool``
        """
        question_mark_index = url.find("?")

        if question_mark_index != -1:
            url = url[question_mark_index + 1 :]

        params = dict(parse_qsl(url))

        if strict:
            assert params == expected_params
        else:
            for key, value in expected_params.items():
                assert key in params
                assert params[key] == value
class MockConnection(object):
    """
    Bare-bones connection stand-in that only remembers ``action``.
    """

    def __init__(self, action):
        # Stored verbatim; nothing else is configured.
        self.action = action
# Alias of MockHttp under a storage-specific name.
# NOTE(review): presumably kept so tests importing StorageMockHttp keep
# working - confirm before removing.
StorageMockHttp = MockHttp
def make_response(status=200, headers=None, connection=None):
    """
    Build a libcloud ``Response`` wrapping a bare ``requests.Response``.

    :param status: HTTP status code set on the underlying response.
    :type status: ``int``

    :param headers: Headers set on the underlying response. Defaults to a new
                    empty ``dict`` per call; a shared ``{}`` default would be
                    aliased by every response built without explicit headers.
    :type headers: ``dict`` or ``None``

    :param connection: Connection passed through to ``Response``.

    :return: The constructed ``Response``.
    """
    response = requests.Response()
    response.status_code = status
    response.headers = {} if headers is None else headers
    return Response(response, connection)
def generate_random_data(size):
    """
    Return a string of random decimal digits ``size`` characters long.

    A non-integer ``size`` is rounded up, and a ``size`` of zero or less
    yields an empty string.

    :param size: Number of digits to generate.
    :type size: ``int``

    :rtype: ``str``
    """
    import math

    # One randint() call per character keeps the output identical to the
    # former char-by-char loop for any given random seed.
    digit_count = max(0, int(math.ceil(size)))
    return "".join(str(random.randint(0, 9)) for _ in range(digit_count))
if __name__ == "__main__":
    # Run this module's doctests when it is executed directly.
    from doctest import testmod

    testmod()