blob: 2572a72ae05c7d528833ef61fa7e6c4d8575fbfb [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Tests for Throttling Handler of GCSIO."""
import unittest
from unittest.mock import Mock
from apache_beam.metrics.execution import MetricsContainer
from apache_beam.metrics.execution import MetricsEnvironment
from apache_beam.metrics.metricbase import MetricName
from apache_beam.runners.worker import statesampler
from apache_beam.utils import counters
# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
try:
from google.api_core import exceptions as api_exceptions
from apache_beam.io.gcp import gcsio_retry
except ImportError:
gcsio_retry = None
api_exceptions = None
# pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports
@unittest.skipIf((gcsio_retry is None or api_exceptions is None),
                 'GCP dependencies are not installed')
class TestGCSIORetry(unittest.TestCase):
  """Tests for ``gcsio_retry.DEFAULT_RETRY_WITH_THROTTLING_COUNTER``.

  Covers two behaviours of the default GCSIO retry policy:

  * an error that is not retriable is surfaced to the caller without the
    wrapped callable being invoked again;
  * ``TooManyRequests`` (throttling) errors are retried until the callable
    succeeds, and the ``gcsio``/``cumulativeThrottlingSeconds`` counter of
    the current metrics container is increased.
  """

  @staticmethod
  def _throttled_seconds(container):
    """Return the cumulative value of the GCSIO throttling counter.

    Args:
      container: the metrics container to read the
        ``gcsio``/``cumulativeThrottlingSeconds`` counter from.
    """
    return container.get_counter(
        MetricName('gcsio', "cumulativeThrottlingSeconds")).get_cumulative()

  def test_retry_on_non_retriable(self):
    """A plain ``Exception`` propagates and the call is not retried."""
    mock_fn = Mock(side_effect=[
        Exception('Something wrong!'),
    ])
    retry = gcsio_retry.DEFAULT_RETRY_WITH_THROTTLING_COUNTER
    # Match the message so the test only passes when the original error
    # propagates. A retried call would instead hit the exhausted side_effect
    # list and raise StopIteration, which a bare assertRaises(Exception)
    # would also accept.
    with self.assertRaisesRegex(Exception, 'Something wrong!'):
      retry(mock_fn)()
    # Non-retriable: the wrapped callable must have been invoked exactly once.
    self.assertEqual(mock_fn.call_count, 1)

  def test_retry_on_throttling(self):
    """Throttling errors are retried and counted as throttled time."""
    mock_fn = Mock(
        side_effect=[
            api_exceptions.TooManyRequests("Slow down!"),
            api_exceptions.TooManyRequests("Slow down again!"),
            12345
        ])
    retry = gcsio_retry.DEFAULT_RETRY_WITH_THROTTLING_COUNTER

    # The throttling counter is reported through the metrics container of the
    # currently active state, so run the call inside a sampled scoped state.
    sampler = statesampler.StateSampler('', counters.CounterFactory())
    statesampler.set_current_tracker(sampler)
    state = sampler.scoped_state(
        'my_step', 'my_state', metrics_container=MetricsContainer('my_step'))
    try:
      sampler.start()
      with state:
        container = MetricsEnvironment.current_container()
        self.assertEqual(self._throttled_seconds(container), 0)

        self.assertEqual(12345, retry(mock_fn)())
        # Two throttled attempts followed by the successful third attempt.
        self.assertEqual(mock_fn.call_count, 3)

        self.assertGreater(self._throttled_seconds(container), 1)
    finally:
      sampler.stop()
if __name__ == '__main__':
  # Executed directly (not collected by a test runner): hand control to
  # unittest's command-line entry point for this module's test cases.
  unittest.main(module='__main__')