| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| """Unit tests for the retry module.""" |
| |
| from __future__ import absolute_import |
| |
| import unittest |
| from builtins import object |
| |
| from apache_beam.utils import retry |
| |
| # Protect against environments where apitools library is not available. |
| # pylint: disable=wrong-import-order, wrong-import-position |
| # TODO(sourabhbajaj): Move the GCP specific error code to a submodule |
| try: |
| from apitools.base.py.exceptions import HttpError |
| except ImportError: |
| HttpError = None |
| # pylint: enable=wrong-import-order, wrong-import-position |
| |
| |
class FakeClock(object):
  """Stand-in clock whose sleep() records the request instead of waiting."""

  def __init__(self):
    # Every value handed to sleep(), in call order.
    self.calls = list()

  def sleep(self, value):
    """Remember the requested sleep value; never actually blocks."""
    recorded = self.calls
    recorded.append(value)
| |
| |
class FakeLogger(object):
  """Stand-in logger whose log() stores a summary of each invocation."""

  def __init__(self):
    # One (message, func_name, exn_name) tuple per log() call, in order.
    self.calls = list()

  def log(self, message, interval, func_name, exn_name, exn_traceback):
    """Record the call; interval and exn_traceback are accepted but dropped."""
    del interval, exn_traceback  # Deliberately not recorded.
    entry = (message, func_name, exn_name)
    self.calls.append(entry)
| |
| |
@retry.with_exponential_backoff(clock=FakeClock())
def test_function(a, b):
  """Always fail; decorated at module level with a fake clock."""
  del a, b  # Arguments exist only to exercise argument forwarding.
  raise NotImplementedError
| |
| |
@retry.with_exponential_backoff(initial_delay_secs=0.1, num_retries=1)
def test_function_with_real_clock(a, b):
  """Always fail; decorated without a clock override and with one retry."""
  del a, b  # Arguments exist only to exercise argument forwarding.
  raise NotImplementedError
| |
| |
@retry.no_retries
def test_no_retry_function(a, b):
  """Always fail; decorated with retry.no_retries."""
  del a, b  # Arguments exist only to exercise argument forwarding.
  raise NotImplementedError
| |
| |
class RetryTest(unittest.TestCase):
  """Tests for retry.with_exponential_backoff and retry.no_retries.

  setUp gives every test its own FakeClock and FakeLogger, so the sleeps and
  log calls made by the decorator can be counted.
  """

  def setUp(self):
    self.calls = 0
    self.logger = FakeLogger()
    self.clock = FakeClock()

  def permanent_failure(self, a, b):
    """Raise NotImplementedError on every call."""
    raise NotImplementedError

  def transient_failure(self, a, b):
    """Raise on the first four calls, then return a + b."""
    self.calls += 1
    if self.calls <= 4:
      raise NotImplementedError
    return a + b

  def http_error(self, code):
    """Raise an HttpError whose response status is str(code)."""
    if HttpError is None:
      raise RuntimeError("This is not a valid test as GCP is not enabled")
    response = {'status': str(code)}
    raise HttpError(response, '', '')

  def test_with_explicit_decorator(self):
    # One positional and one keyword argument, so that both argument
    # handling code paths are covered.
    with self.assertRaises(NotImplementedError):
      test_function(10, b=20)

  def test_with_no_retry_decorator(self):
    with self.assertRaises(NotImplementedError):
      test_no_retry_function(1, 2)

  def test_with_real_clock(self):
    with self.assertRaises(NotImplementedError):
      test_function_with_real_clock(10, b=20)

  def test_with_default_number_of_retries(self):
    decorated = retry.with_exponential_backoff(clock=self.clock)(
        self.permanent_failure)
    with self.assertRaises(NotImplementedError):
      decorated(10, b=20)
    # With no num_retries given, seven sleeps are expected.
    self.assertEqual(len(self.clock.calls), 7)

  def test_with_explicit_number_of_retries(self):
    decorated = retry.with_exponential_backoff(
        clock=self.clock, num_retries=10)(self.permanent_failure)
    with self.assertRaises(NotImplementedError):
      decorated(10, b=20)
    self.assertEqual(len(self.clock.calls), 10)

  @unittest.skipIf(HttpError is None, 'google-apitools is not installed')
  def test_with_http_error_that_should_not_be_retried(self):
    decorated = retry.with_exponential_backoff(
        clock=self.clock, num_retries=10)(self.http_error)
    with self.assertRaises(HttpError):
      decorated(404)
    # No sleep was requested, i.e. just one call was made.
    self.assertEqual(len(self.clock.calls), 0)

  @unittest.skipIf(HttpError is None, 'google-apitools is not installed')
  def test_with_http_error_that_should_be_retried(self):
    decorated = retry.with_exponential_backoff(
        clock=self.clock, num_retries=10)(self.http_error)
    with self.assertRaises(HttpError):
      decorated(500)
    self.assertEqual(len(self.clock.calls), 10)

  def test_with_explicit_initial_delay(self):
    decorated = retry.with_exponential_backoff(
        initial_delay_secs=10.0, clock=self.clock, fuzz=False)(
            self.permanent_failure)
    with self.assertRaises(NotImplementedError):
      decorated(10, b=20)
    sleeps = self.clock.calls
    self.assertEqual(len(sleeps), 7)
    # With fuzz disabled the first sleep equals the initial delay.
    self.assertEqual(sleeps[0], 10.0)

  def test_log_calls_for_permanent_failure(self):
    decorated = retry.with_exponential_backoff(
        clock=self.clock, logger=self.logger.log)(self.permanent_failure)
    with self.assertRaises(NotImplementedError):
      decorated(10, b=20)
    self.assertEqual(len(self.logger.calls), 7)
    for msg, fn_name, exc_name in self.logger.calls:
      self.assertTrue(msg.startswith('Retry with exponential backoff:'))
      self.assertEqual(exc_name, 'NotImplementedError\n')
      self.assertEqual(fn_name, 'permanent_failure')

  def test_log_calls_for_transient_failure(self):
    decorated = retry.with_exponential_backoff(
        clock=self.clock, logger=self.logger.log, fuzz=False)(
            self.transient_failure)
    result = decorated(10, b=20)
    self.assertEqual(result, 30)
    self.assertEqual(len(self.clock.calls), 4)
    # Four failures precede the success, so four doubling sleeps are expected.
    expected_delays = [5.0 * factor for factor in (1, 2, 4, 8)]
    self.assertEqual(self.clock.calls, expected_delays)
    self.assertEqual(len(self.logger.calls), 4)
    for msg, fn_name, exc_name in self.logger.calls:
      self.assertTrue(msg.startswith('Retry with exponential backoff:'))
      self.assertEqual(exc_name, 'NotImplementedError\n')
      self.assertEqual(fn_name, 'transient_failure')
| |
| |
class DummyClass(object):
  """Object with a retried method that replays a scripted list of results.

  Each func() attempt consumes the next entry of results. An entry equal to
  "Error", or running past the end of the list, raises ValueError("Error").
  """

  def __init__(self, results):
    self.results = results
    # Number of func() attempts made so far, counting retried ones.
    self.index = 0

  @retry.with_exponential_backoff(num_retries=2, initial_delay_secs=0.1)
  def func(self):
    """Advance to the next scripted result and return it, or raise."""
    self.index += 1
    position = self.index - 1
    exhausted = position >= len(self.results)
    if exhausted or self.results[position] == "Error":
      raise ValueError("Error")
    return self.results[position]
| |
| |
class RetryStateTest(unittest.TestCase):
  """Guards against retry state being shared between decorated objects.

  A bug was found where the decorator's state was shared across objects, so
  retries were not available correctly. test_two_failures and
  test_single_failure would fail if any such shared state existed;
  test_call_two_objects checks the same thing within a single test.
  """

  def _call_func_on(self, results):
    """Build a DummyClass over results, call func() once, and return it."""
    instance = DummyClass(results)
    instance.func()
    return instance

  def test_two_failures(self):
    instance = self._call_func_on(["Error", "Error", "Success"])
    self.assertEqual(3, instance.index)

  def test_single_failure(self):
    instance = self._call_func_on(["Error", "Success"])
    self.assertEqual(2, instance.index)

  def test_call_two_objects(self):
    first = self._call_func_on(["Error", "Error", "Success"])
    self.assertEqual(3, first.index)

    second = self._call_func_on(["Error", "Success"])
    self.assertEqual(2, second.index)
| |
| |
if __name__ == '__main__':
  # Run the tests in this module when the file is executed as a script.
  unittest.main()