import time

from tools.datahelp import flatten_into_set
from tools.misc import list_to_hashed_dict


class Page(object):
    data = None

    def __init__(self):
        self.data = []

    def add_row(self, row):
        self.data.append(row)


class PageFetcher(object):
    """
    Requests pages, handles their receipt,
    and provides paged data for testing.

    The first page is automatically retrieved, so an initial
    call to request_one is actually getting the *second* page!
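
    Illustrative usage (a sketch; assumes `session` is a cassandra-driver
    Session and the statement's fetch_size is small enough to force paging):

        future = session.execute_async(statement)
        fetcher = PageFetcher(future)   # blocks until page 1 arrives
        fetcher.request_all()           # fetch any remaining pages
        rows = fetcher.all_data()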
"""

    pages = None
    error = None
    future = None
    requested_pages = None
    retrieved_pages = None
    retrieved_empty_pages = None

    def __init__(self, future):
        self.pages = []

        # the first page is automagically returned (eventually),
        # so we'll count this as a request, but the retrieved count
        # won't be incremented until it actually arrives
        self.requested_pages = 1
        self.retrieved_pages = 0
        self.retrieved_empty_pages = 0

        self.future = future
        self.future.add_callbacks(
            callback=self.handle_page,
            errback=self.handle_error
        )

        # wait for the first page to arrive, otherwise we may call
        # future.has_more_pages too early, since it should only be
        # called after the first page is returned
        self.wait(seconds=30)

    def handle_page(self, rows):
        # occasionally get a final blank page that is useless
        if rows == []:
            self.retrieved_empty_pages += 1
            return

        page = Page()
        self.pages.append(page)

        for row in rows:
            page.add_row(row)

        self.retrieved_pages += 1

    def handle_error(self, exc):
        self.error = exc
        raise exc

    def request_one(self, timeout=None):
        """
        Requests the next page if there is one.

        If the future is exhausted, this is a no-op.
        Returns self, so calls can be chained.

        @param timeout Time, in seconds, to wait for the page to arrive.
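
        For example (hypothetical fetcher built from an earlier paged query):

            fetcher.request_one(timeout=10)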
"""
        if self.future.has_more_pages:
            self.future.start_fetching_next_page()
            self.requested_pages += 1
            self.wait(seconds=timeout)

        return self

    def request_all(self, timeout=None):
        """
        Requests any remaining pages.

        If the future is exhausted, this is a no-op.
        Returns self, so calls can be chained.

        @param timeout Time, in seconds, to wait for all pages.
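
        For example (hypothetical fetcher built from an earlier paged query):

            rows = fetcher.request_all(timeout=10).all_data()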
"""
        while self.future.has_more_pages:
            self.future.start_fetching_next_page()
            self.requested_pages += 1
            self.wait(seconds=timeout)

        return self

    def wait(self, seconds=None):
        """
        Blocks until all *requested* pages have been returned.

        Requests are made by calling request_one and/or request_all.

        Raises RuntimeError if seconds is exceeded.
        """
        seconds = 5 if seconds is None else seconds
        expiry = time.time() + seconds

        while time.time() < expiry:
            if self.requested_pages == (self.retrieved_pages + self.retrieved_empty_pages):
                return self
            # small sleep so we don't burn excess cpu while polling
            time.sleep(0.1)

        raise RuntimeError(
            "Requested pages were not delivered before timeout. "
            "Requested: {}; retrieved: {}; empty retrieved: {}".format(
                self.requested_pages, self.retrieved_pages, self.retrieved_empty_pages))

    def pagecount(self):
        """
        Returns the count of *retrieved* pages which were not empty.

        Pages are retrieved by requesting them with request_one and/or request_all.
        """
        return len(self.pages)

    def num_results(self, page_num):
        """
        Returns the number of results found at page_num.
        """
        return len(self.pages[page_num - 1].data)

    def num_results_all(self):
        """
        Returns a list with the number of results for each retrieved page.
        """
        return [len(page.data) for page in self.pages]

    def page_data(self, page_num):
        """
        Returns retrieved data found at page_num.

        The page should have already been requested with request_one and/or request_all.
        """
        return self.pages[page_num - 1].data

    def all_data(self):
        """
        Returns all retrieved data flattened into a single list
        (instead of separated into Page objects).

        The page(s) should have already been requested with request_one and/or request_all.
        """
        all_pages_combined = []
        for page in self.pages:
            all_pages_combined.extend(page.data[:])

        return all_pages_combined

    @property  # make property to match python driver api
    def has_more_pages(self):
        """
        Returns a bool indicating whether any pages have not yet been retrieved.
        """
        return self.future.has_more_pages


class PageAssertionMixin(object):
    """Assertion helpers; can be added to test classes (e.g. subclasses of unittest.TestCase)."""

    def assertEqualIgnoreOrder(self, actual, expected):
        # hash each row so the lists can be compared without regard to order
        hashed_expected = list_to_hashed_dict(expected)
        hashed_actual = list_to_hashed_dict(actual)

        # check each direction separately so a failure names the offending row
        for key, expected_item in hashed_expected.items():
            assert key in hashed_actual, "expected %s not in actual" % str(expected_item)
            actual_item = hashed_actual[key]
            assert actual_item == expected_item, "actual %s not same as expected %s" % (str(actual_item), str(expected_item))

        for key, actual_item in hashed_actual.items():
            assert key in hashed_expected, "actual %s not in expected" % str(actual_item)
            expected_item = hashed_expected[key]
            assert expected_item == actual_item, "expected %s not same as actual %s" % (str(expected_item), str(actual_item))

        # belt-and-braces: the two hashed dicts must match exactly
        assert hashed_expected == hashed_actual

    def assertIsSubsetOf(self, subset, superset):
        assert flatten_into_set(subset) <= flatten_into_set(superset)
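

# Illustrative use of the mixin (a sketch; `PagingTest`, `Tester`, `stmt` and
# `expected_rows` are hypothetical stand-ins, not part of this module):
#
#     class PagingTest(Tester, PageAssertionMixin):
#         def test_pages_match(self):
#             fetcher = PageFetcher(self.session.execute_async(stmt)).request_all()
#             self.assertEqualIgnoreOrder(fetcher.all_data(), expected_rows)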