# Source blob: 0d0ce1d2c8dc32b8bd37ce4a1bbde53067b2ac74
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Tests for state sampler."""
# pytype: skip-file
import logging
import time
import unittest
from unittest import mock
from unittest.mock import Mock
from unittest.mock import patch
from tenacity import retry
from tenacity import stop_after_attempt
from apache_beam.internal import pickler
from apache_beam.runners import common
from apache_beam.runners.worker import operation_specs
from apache_beam.runners.worker import operations
from apache_beam.runners.worker import statesampler
from apache_beam.transforms import core
from apache_beam.transforms import userstate
from apache_beam.transforms.core import GlobalWindows
from apache_beam.transforms.core import Windowing
from apache_beam.transforms.window import GlobalWindow
from apache_beam.utils.counters import CounterFactory
from apache_beam.utils.counters import CounterName
from apache_beam.utils.windowed_value import PaneInfo
# Module-level logger; the tests below use it to report measured timings.
_LOGGER = logging.getLogger(__name__)
class TimerDoFn(core.DoFn):
  """A DoFn with a watermark timer whose callback can sleep.

  Args:
    sleep_duration_s: Seconds the timer callback sleeps each time it runs.
      A falsy value (the default, 0) makes the callback return immediately.
  """
  TIMER_SPEC = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK)

  def __init__(self, sleep_duration_s=0):
    self._sleep_duration_s = sleep_duration_s

  @userstate.on_timer(TIMER_SPEC)
  def on_timer_f(self):
    """Timer callback: sleeps for the configured duration, if any."""
    pause_s = self._sleep_duration_s
    if not pause_s:
      return
    time.sleep(pause_s)
class ExceptionTimerDoFn(core.DoFn):
  """A DoFn whose timer callback always ends by raising RuntimeError.

  Args:
    sleep_duration_s: Seconds the timer callback sleeps before raising.
      A falsy value (the default, 0) skips the sleep.
  """
  TIMER_SPEC = userstate.TimerSpec('ts-timer', userstate.TimeDomain.WATERMARK)

  def __init__(self, sleep_duration_s=0):
    self._sleep_duration_s = sleep_duration_s

  @userstate.on_timer(TIMER_SPEC)
  def on_timer_f(self):
    """Timer callback: optionally sleeps, then raises RuntimeError."""
    pause_s = self._sleep_duration_s
    if pause_s:
      time.sleep(pause_s)
    raise RuntimeError("Test exception from timer")
class StateSamplerTest(unittest.TestCase):
  """Tests for StateSampler counters and DoOperation timer sampling.

  The tests depend on wall-clock sleeps and on sampled measurements, so each
  one is wrapped in ``retry`` (at most three attempts) to absorb timing noise.
  Every test stops its sampler in a ``finally`` clause so that a failed
  attempt does not leave a started sampler behind when the retry creates a
  new one.
  """

  def _build_do_operation(self, fn, counter_factory, sampler):
    """Returns a DoOperation for step 'step1' wrapping ``fn``, after setup().

    Args:
      fn: DoFn instance that is pickled into the operation spec.
      counter_factory: CounterFactory handed to the operation.
      sampler: StateSampler handed to the operation.
    """
    spec = operation_specs.WorkerDoFn(
        serialized_fn=pickler.dumps(
            (fn, [], {}, [], Windowing(GlobalWindows()))),
        output_tags=[],
        input=None,
        side_inputs=[],
        output_coders=[])
    op = operations.DoOperation(
        common.NameContext('step1'),
        spec,
        counter_factory,
        sampler,
        user_state_context=mock.MagicMock())
    op.setup()
    return op

  def _build_timer_data(self):
    """Returns a Mock carrying the timer attributes used by the timer tests."""
    timer_data = Mock()
    timer_data.user_key = None
    timer_data.windows = [GlobalWindow()]
    timer_data.fire_timestamp = 0
    timer_data.paneinfo = PaneInfo(
        is_first=False,
        is_last=False,
        timing=0,
        index=0,
        nonspeculative_index=0)
    timer_data.dynamic_timer_tag = ''
    return timer_data

  def _process_timers_msecs(self, counter_factory, stage_name):
    """Returns the value of step1's 'process-timers-msecs' counter.

    Fails the test if ``counter_factory`` holds no counter with that name.

    Args:
      counter_factory: CounterFactory to search.
      stage_name: Stage name the counter is expected to carry.
    """
    expected_name = CounterName(
        'process-timers-msecs', step_name='step1', stage_name=stage_name)
    found_counter = next(
        (
            counter for counter in counter_factory.get_counters()
            if counter.name == expected_name),
        None)
    self.assertIsNotNone(
        found_counter, f"Expected counter '{expected_name}' to be created.")
    return found_counter.value()

  # Due to somewhat non-deterministic nature of state sampling and sleep,
  # this test is flaky when state duration is low.
  # Since increasing state duration significantly would also slow down
  # the test suite, we are retrying twice on failure as a mitigation.
  @retry(reraise=True, stop=stop_after_attempt(3))
  def test_basic_sampler(self):
    """Checks per-state msecs counters after a nested three-state workload."""
    def state_counter(state):
      return CounterName(
          f'{state}-msecs', step_name='step1', stage_name='basic')

    # Set up state sampler.
    counter_factory = CounterFactory()
    sampler = statesampler.StateSampler(
        'basic', counter_factory, sampling_period_ms=1)

    # Duration of the fastest state. Total test duration is 6 times longer.
    state_duration_ms = 1000
    margin_of_error = 0.25

    # Run basic workload transitioning between 3 states.
    sampler.start()
    try:
      with sampler.scoped_state('step1', 'statea'):
        time.sleep(state_duration_ms / 1000)
        self.assertEqual(sampler.current_state().name, state_counter('statea'))
        with sampler.scoped_state('step1', 'stateb'):
          time.sleep(state_duration_ms / 1000)
          self.assertEqual(
              sampler.current_state().name, state_counter('stateb'))
          with sampler.scoped_state('step1', 'statec'):
            time.sleep(3 * state_duration_ms / 1000)
            self.assertEqual(
                sampler.current_state().name, state_counter('statec'))
          # Back in stateb: this second sleep is why stateb expects 2x.
          time.sleep(state_duration_ms / 1000)
    finally:
      # Stop the sampler even if an assertion above failed.
      sampler.stop()
    sampler.commit_counters()

    if not statesampler.FAST_SAMPLER:
      # The slow sampler does not implement sampling, so we won't test it.
      return

    # Test that sampled state timings are close to their expected values.
    expected_counter_values = {
        state_counter('statea'): state_duration_ms,
        state_counter('stateb'): 2 * state_duration_ms,
        state_counter('statec'): 3 * state_duration_ms,
    }
    for counter in counter_factory.get_counters():
      self.assertIn(counter.name, expected_counter_values)
      expected_value = expected_counter_values[counter.name]
      actual_value = counter.value()
      deviation = abs(actual_value - expected_value) / expected_value
      _LOGGER.info('Sampling deviation from expectation: %f', deviation)
      lower_bound = expected_value * (1.0 - margin_of_error)
      upper_bound = expected_value * (1.0 + margin_of_error)
      self.assertGreater(actual_value, lower_bound)
      self.assertLess(actual_value, upper_bound)

  # TODO: This test is flaky when it is run under load. A better solution
  # would be to change the test structure to not depend on specific timings.
  @retry(reraise=True, stop=stop_after_attempt(3))
  def test_sampler_transition_overhead(self):
    """Checks that the average cost of a state transition stays small."""
    # Set up state sampler.
    counter_factory = CounterFactory()
    sampler = statesampler.StateSampler(
        'overhead-', counter_factory, sampling_period_ms=10)

    # Create the scoped states once so the loop measures transitions only.
    state_a = sampler.scoped_state('step1', 'statea')
    state_b = sampler.scoped_state('step1', 'stateb')
    state_c = sampler.scoped_state('step1', 'statec')

    # The measured interval covers start(), the workload and stop().
    start_time = time.time()
    sampler.start()
    try:
      for _ in range(100000):
        with state_a:
          with state_b:
            for _ in range(10):
              with state_c:
                pass
    finally:
      sampler.stop()
    elapsed_time = time.time() - start_time

    state_transition_count = sampler.get_info().transition_count
    overhead_us = 1000000.0 * elapsed_time / state_transition_count
    _LOGGER.info('Overhead per transition: %fus', overhead_us)

    # Conservative upper bound on overhead in microseconds (we expect this to
    # take 0.17us when compiled in opt mode or 0.48us when compiled in debug
    # mode).
    self.assertLess(overhead_us, 20.0)

  @retry(reraise=True, stop=stop_after_attempt(3))
  # Patch get_dofn_specs so that it reports only TimerDoFn's timer spec.
  @patch('apache_beam.transforms.userstate.get_dofn_specs')
  def test_do_operation_process_timer(self, mock_get_dofn_specs):
    """Checks that time spent in a timer callback reaches its msecs counter."""
    mock_get_dofn_specs.return_value = ([], [TimerDoFn.TIMER_SPEC])

    if not statesampler.FAST_SAMPLER:
      self.skipTest('DoOperation test requires FAST_SAMPLER')

    state_duration_ms = 200
    margin_of_error = 0.75

    counter_factory = CounterFactory()
    sampler = statesampler.StateSampler(
        'test_do_op', counter_factory, sampling_period_ms=1)
    op = self._build_do_operation(
        TimerDoFn(sleep_duration_s=state_duration_ms / 1000.0),
        counter_factory,
        sampler)
    timer_data = self._build_timer_data()

    sampler.start()
    try:
      # NOTE(review): the tag is the spec name 'timer' with a 'ts-' prefix;
      # presumably TimerSpec adds that prefix -- confirm against userstate.
      op.process_timer('ts-timer', timer_data=timer_data)
    finally:
      sampler.stop()
    sampler.commit_counters()

    actual_value = self._process_timers_msecs(counter_factory, 'test_do_op')
    _LOGGER.info("Actual value %d", actual_value)
    self.assertGreater(
        actual_value, state_duration_ms * (1.0 - margin_of_error))

  @retry(reraise=True, stop=stop_after_attempt(3))
  # Patch get_dofn_specs so that it reports only ExceptionTimerDoFn's spec.
  @patch('apache_beam.runners.worker.operations.userstate.get_dofn_specs')
  def test_do_operation_process_timer_with_exception(self, mock_get_dofn_specs):
    """Checks the timer msecs counter is still recorded when the timer raises."""
    mock_get_dofn_specs.return_value = ([], [ExceptionTimerDoFn.TIMER_SPEC])

    if not statesampler.FAST_SAMPLER:
      self.skipTest('DoOperation test requires FAST_SAMPLER')

    state_duration_ms = 200
    margin_of_error = 0.50

    counter_factory = CounterFactory()
    sampler = statesampler.StateSampler(
        'test_do_op_exception', counter_factory, sampling_period_ms=1)
    op = self._build_do_operation(
        ExceptionTimerDoFn(sleep_duration_s=state_duration_ms / 1000.0),
        counter_factory,
        sampler)
    timer_data = self._build_timer_data()

    sampler.start()
    try:
      # Assert that the expected exception is raised.
      # NOTE(review): the tag is the spec name 'ts-timer' with a 'ts-' prefix;
      # presumably TimerSpec adds that prefix -- confirm against userstate.
      with self.assertRaises(RuntimeError):
        op.process_timer('ts-ts-timer', timer_data=timer_data)
    finally:
      sampler.stop()
    sampler.commit_counters()

    actual_value = self._process_timers_msecs(
        counter_factory, 'test_do_op_exception')
    self.assertGreater(
        actual_value, state_duration_ms * (1.0 - margin_of_error))
    _LOGGER.info("Exception test finished successfully.")
if __name__ == '__main__':
  # Raise the root logger to INFO before handing control to unittest.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  unittest.main()