blob: 702d5c367228caa95722b66f3b9b4aee23eef4b6 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Unit tests for the Beam State and Timer API interfaces."""
# pytype: skip-file
import unittest
from typing import Any
from typing import List
import mock
import apache_beam as beam
from apache_beam.coders import BytesCoder
from apache_beam.coders import IterableCoder
from apache_beam.coders import StrUtf8Coder
from apache_beam.coders import VarIntCoder
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.runners.common import DoFnSignature
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.test_stream import TestStream
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms import trigger
from apache_beam.transforms import userstate
from apache_beam.transforms import window
from apache_beam.transforms.combiners import ToListCombineFn
from apache_beam.transforms.combiners import TopCombineFn
from apache_beam.transforms.core import DoFn
from apache_beam.transforms.timeutil import TimeDomain
from apache_beam.transforms.userstate import BagStateSpec
from apache_beam.transforms.userstate import CombiningValueStateSpec
from apache_beam.transforms.userstate import ReadModifyWriteStateSpec
from apache_beam.transforms.userstate import SetStateSpec
from apache_beam.transforms.userstate import TimerSpec
from apache_beam.transforms.userstate import get_dofn_specs
from apache_beam.transforms.userstate import is_stateful_dofn
from apache_beam.transforms.userstate import on_timer
from apache_beam.transforms.userstate import validate_stateful_dofn
class TestStatefulDoFn(DoFn):
  """An example stateful DoFn with state and timers.

  Declares two bag states, three watermark timers and one timer whose
  callback receives a dynamic timer tag. ``process`` passes elements through
  unchanged, and each timer has exactly one ``on_timer`` callback. The tests
  below use this class to exercise DoFn signature validation and
  ``get_dofn_specs``.
  """
  # Bag state holding raw bytes.
  BUFFER_STATE_1 = BagStateSpec('buffer', BytesCoder())
  # Bag state holding varint-encoded integers.
  BUFFER_STATE_2 = BagStateSpec('buffer2', VarIntCoder())
  EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK)
  EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK)
  # Not referenced by process(); only the on_timer callbacks take it.
  EXPIRY_TIMER_3 = TimerSpec('expiry3', TimeDomain.WATERMARK)
  # Used as a timer family: its callback takes DoFn.DynamicTimerTagParam.
  EXPIRY_TIMER_FAMILY = TimerSpec('expiry_family', TimeDomain.WATERMARK)

  def process(
      self,
      element,
      t=DoFn.TimestampParam,
      buffer_1=DoFn.StateParam(BUFFER_STATE_1),
      buffer_2=DoFn.StateParam(BUFFER_STATE_2),
      timer_1=DoFn.TimerParam(EXPIRY_TIMER_1),
      timer_2=DoFn.TimerParam(EXPIRY_TIMER_2),
      dynamic_timer=DoFn.TimerParam(EXPIRY_TIMER_FAMILY)):
    """Emit ``element`` unchanged; the state/timer params are never used."""
    yield element

  @on_timer(EXPIRY_TIMER_1)
  def on_expiry_1(
      self,
      window=DoFn.WindowParam,
      timestamp=DoFn.TimestampParam,
      key=DoFn.KeyParam,
      buffer=DoFn.StateParam(BUFFER_STATE_1),
      timer_1=DoFn.TimerParam(EXPIRY_TIMER_1),
      timer_2=DoFn.TimerParam(EXPIRY_TIMER_2),
      timer_3=DoFn.TimerParam(EXPIRY_TIMER_3)):
    """Callback for EXPIRY_TIMER_1; requests window/timestamp/key params."""
    yield 'expired1'

  @on_timer(EXPIRY_TIMER_2)
  def on_expiry_2(
      self,
      buffer=DoFn.StateParam(BUFFER_STATE_2),
      timer_2=DoFn.TimerParam(EXPIRY_TIMER_2),
      timer_3=DoFn.TimerParam(EXPIRY_TIMER_3)):
    """Callback for EXPIRY_TIMER_2."""
    yield 'expired2'

  @on_timer(EXPIRY_TIMER_3)
  def on_expiry_3(
      self,
      buffer_1=DoFn.StateParam(BUFFER_STATE_1),
      buffer_2=DoFn.StateParam(BUFFER_STATE_2),
      timer_3=DoFn.TimerParam(EXPIRY_TIMER_3)):
    """Callback for EXPIRY_TIMER_3; requests both bag states."""
    yield 'expired3'

  @on_timer(EXPIRY_TIMER_FAMILY)
  def on_expiry_family(
      self,
      dynamic_timer=DoFn.TimerParam(EXPIRY_TIMER_FAMILY),
      dynamic_timer_tag=DoFn.DynamicTimerTagParam):
    """Callback for the timer family; emits (tag, marker string) pairs."""
    yield (dynamic_timer_tag, 'expired_dynamic_timer')
class InterfaceTest(unittest.TestCase):
  """Tests for state/timer spec construction and DoFn signature validation."""

  def _validate_dofn(self, dofn):
    # Construction of DoFnSignature performs validation of the given DoFn.
    # In particular, it ends up calling userstate.validate_stateful_dofn.
    # That behavior is explicitly tested below in test_validate_dofn()
    return DoFnSignature(dofn)

  @mock.patch('apache_beam.transforms.userstate.validate_stateful_dofn')
  def test_validate_dofn(self, mock_validate):
    dofn = TestStatefulDoFn()
    self._validate_dofn(dofn)
    # mock.patch injects the replacement mock as an argument; assert on it
    # directly rather than looking the patched attribute up again.
    mock_validate.assert_called_with(dofn)

  def test_spec_construction(self):
    BagStateSpec('statename', VarIntCoder())
    with self.assertRaises(TypeError):
      BagStateSpec(123, VarIntCoder())
    CombiningValueStateSpec('statename', VarIntCoder(), TopCombineFn(10))
    with self.assertRaises(TypeError):
      CombiningValueStateSpec(123, VarIntCoder(), TopCombineFn(10))
    with self.assertRaises(TypeError):
      CombiningValueStateSpec('statename', VarIntCoder(), object())
    SetStateSpec('setstatename', VarIntCoder())
    with self.assertRaises(TypeError):
      SetStateSpec(123, VarIntCoder())
    with self.assertRaises(TypeError):
      SetStateSpec('setstatename', object())
    ReadModifyWriteStateSpec('valuestatename', VarIntCoder())
    with self.assertRaises(TypeError):
      ReadModifyWriteStateSpec(123, VarIntCoder())
    with self.assertRaises(TypeError):
      ReadModifyWriteStateSpec('valuestatename', object())

    # TODO: add more spec tests
    TimerSpec('timer', TimeDomain.WATERMARK)
    TimerSpec('timer', TimeDomain.REAL_TIME)
    with self.assertRaises(ValueError):
      TimerSpec('timer', 'bogus_time_domain')
    # Passing the wrong kind of spec to DoFn.StateParam / DoFn.TimerParam is
    # covered by test_param_construction.

  def test_param_construction(self):
    with self.assertRaises(ValueError):
      DoFn.StateParam(TimerSpec('timer', TimeDomain.WATERMARK))
    with self.assertRaises(ValueError):
      DoFn.TimerParam(BagStateSpec('elements', BytesCoder()))

  def test_stateful_dofn_detection(self):
    self.assertFalse(is_stateful_dofn(DoFn()))
    self.assertTrue(is_stateful_dofn(TestStatefulDoFn()))

  def test_good_signatures(self):
    class BasicStatefulDoFn(DoFn):
      BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
      EXPIRY_TIMER = TimerSpec('expiry1', TimeDomain.WATERMARK)
      EXPIRY_TIMER_FAMILY = TimerSpec('expiry_family_1', TimeDomain.WATERMARK)

      def process(
          self,
          element,
          buffer=DoFn.StateParam(BUFFER_STATE),
          timer1=DoFn.TimerParam(EXPIRY_TIMER),
          dynamic_timer=DoFn.TimerParam(EXPIRY_TIMER_FAMILY)):
        yield element

      @on_timer(EXPIRY_TIMER)
      def expiry_callback(self, element, timer=DoFn.TimerParam(EXPIRY_TIMER)):
        yield element

      @on_timer(EXPIRY_TIMER_FAMILY)
      def expiry_family_callback(
          self, element, dynamic_timer=DoFn.TimerParam(EXPIRY_TIMER_FAMILY)):
        yield element

    # Validate get_dofn_specs() and timer callbacks in
    # DoFnSignature.
    stateful_dofn = BasicStatefulDoFn()
    signature = self._validate_dofn(stateful_dofn)
    expected_specs = (
        {BasicStatefulDoFn.BUFFER_STATE},
        {
            BasicStatefulDoFn.EXPIRY_TIMER,
            BasicStatefulDoFn.EXPIRY_TIMER_FAMILY
        },
    )
    self.assertEqual(expected_specs, get_dofn_specs(stateful_dofn))
    self.assertEqual(
        stateful_dofn.expiry_callback,
        signature.timer_methods[BasicStatefulDoFn.EXPIRY_TIMER].method_value)
    self.assertEqual(
        stateful_dofn.expiry_family_callback,
        signature.timer_methods[
            BasicStatefulDoFn.EXPIRY_TIMER_FAMILY].method_value)

    stateful_dofn = TestStatefulDoFn()
    signature = self._validate_dofn(stateful_dofn)
    expected_specs = (
        {TestStatefulDoFn.BUFFER_STATE_1, TestStatefulDoFn.BUFFER_STATE_2},
        {
            TestStatefulDoFn.EXPIRY_TIMER_1,
            TestStatefulDoFn.EXPIRY_TIMER_2,
            TestStatefulDoFn.EXPIRY_TIMER_3,
            TestStatefulDoFn.EXPIRY_TIMER_FAMILY
        })
    self.assertEqual(expected_specs, get_dofn_specs(stateful_dofn))
    self.assertEqual(
        stateful_dofn.on_expiry_1,
        signature.timer_methods[TestStatefulDoFn.EXPIRY_TIMER_1].method_value)
    self.assertEqual(
        stateful_dofn.on_expiry_2,
        signature.timer_methods[TestStatefulDoFn.EXPIRY_TIMER_2].method_value)
    self.assertEqual(
        stateful_dofn.on_expiry_3,
        signature.timer_methods[TestStatefulDoFn.EXPIRY_TIMER_3].method_value)
    self.assertEqual(
        stateful_dofn.on_expiry_family,
        signature.timer_methods[
            TestStatefulDoFn.EXPIRY_TIMER_FAMILY].method_value)

  def test_bad_signatures(self):
    # (1) The same state parameter is duplicated on the process method.
    class BadStatefulDoFn1(DoFn):
      BUFFER_STATE = BagStateSpec('buffer', BytesCoder())

      def process(
          self,
          element,
          b1=DoFn.StateParam(BUFFER_STATE),
          b2=DoFn.StateParam(BUFFER_STATE)):
        yield element

    with self.assertRaises(ValueError):
      self._validate_dofn(BadStatefulDoFn1())

    # (2) The same timer parameter is duplicated on the process method.
    class BadStatefulDoFn2(DoFn):
      TIMER = TimerSpec('timer', TimeDomain.WATERMARK)

      def process(
          self, element, t1=DoFn.TimerParam(TIMER), t2=DoFn.TimerParam(TIMER)):
        yield element

    with self.assertRaises(ValueError):
      self._validate_dofn(BadStatefulDoFn2())

    # (3) The same state parameter is duplicated on the on_timer method.
    class BadStatefulDoFn3(DoFn):
      BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
      EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK)
      EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK)

      @on_timer(EXPIRY_TIMER_1)
      def expiry_callback(
          self,
          element,
          b1=DoFn.StateParam(BUFFER_STATE),
          b2=DoFn.StateParam(BUFFER_STATE)):
        yield element

    with self.assertRaises(ValueError):
      self._validate_dofn(BadStatefulDoFn3())

    # (4) The same timer parameter is duplicated on the on_timer method.
    class BadStatefulDoFn4(DoFn):
      BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
      EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK)
      EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK)

      @on_timer(EXPIRY_TIMER_1)
      def expiry_callback(
          self,
          element,
          t1=DoFn.TimerParam(EXPIRY_TIMER_2),
          t2=DoFn.TimerParam(EXPIRY_TIMER_2)):
        yield element

    with self.assertRaises(ValueError):
      self._validate_dofn(BadStatefulDoFn4())

    # (5) The same timer family parameter is duplicated on the process method.
    class BadStatefulDoFn5(DoFn):
      EXPIRY_TIMER_FAMILY = TimerSpec('dynamic_timer', TimeDomain.WATERMARK)

      def process(
          self,
          element,
          dynamic_timer_1=DoFn.TimerParam(EXPIRY_TIMER_FAMILY),
          dynamic_timer_2=DoFn.TimerParam(EXPIRY_TIMER_FAMILY)):
        yield element

    with self.assertRaises(ValueError):
      self._validate_dofn(BadStatefulDoFn5())

  def test_validation_typos(self):
    # (1) Here, the user mistakenly used the same timer spec twice for two
    # different timer callbacks. The error is raised by the second on_timer
    # decorator, i.e. while the class body itself is being executed.
    with self.assertRaisesRegex(
        ValueError,
        r'Multiple on_timer callbacks registered for TimerSpec\(.*expiry1\).'):

      class StatefulDoFnWithTimerWithTypo1(DoFn):  # pylint: disable=unused-variable
        BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
        EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK)
        EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK)

        def process(self, element):
          pass

        @on_timer(EXPIRY_TIMER_1)
        def on_expiry_1(self, buffer_state=DoFn.StateParam(BUFFER_STATE)):
          yield 'expired1'

        # Note that we mistakenly associate this with the first timer.
        @on_timer(EXPIRY_TIMER_1)
        def on_expiry_2(self, buffer_state=DoFn.StateParam(BUFFER_STATE)):
          yield 'expired2'

    # (2) Here, the user mistakenly used the same callback name and overwrote
    # the first on_expiry_1 callback.
    class StatefulDoFnWithTimerWithTypo2(DoFn):
      BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
      EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK)
      EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK)

      def process(
          self,
          element,
          timer1=DoFn.TimerParam(EXPIRY_TIMER_1),
          timer2=DoFn.TimerParam(EXPIRY_TIMER_2)):
        pass

      @on_timer(EXPIRY_TIMER_1)
      def on_expiry_1(self, buffer_state=DoFn.StateParam(BUFFER_STATE)):
        yield 'expired1'

      # Note that we mistakenly reuse the "on_expiry_1" name; this is valid
      # syntactically in Python.
      @on_timer(EXPIRY_TIMER_2)
      def on_expiry_1(self, buffer_state=DoFn.StateParam(BUFFER_STATE)):
        yield 'expired2'

      # Use a stable string value for matching.
      def __repr__(self):
        return 'StatefulDoFnWithTimerWithTypo2'

    dofn = StatefulDoFnWithTimerWithTypo2()
    with self.assertRaisesRegex(
        ValueError,
        (r'The on_timer callback for TimerSpec\(.*expiry1\) is not the '
         r'specified .on_expiry_1 method for DoFn '
         r'StatefulDoFnWithTimerWithTypo2 \(perhaps it was overwritten\?\).')):
      validate_stateful_dofn(dofn)

    # (3) Here, the user forgot to add an on_timer decorator for 'expiry2'
    class StatefulDoFnWithTimerWithTypo3(DoFn):
      BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
      EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK)
      EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK)

      def process(
          self,
          element,
          timer1=DoFn.TimerParam(EXPIRY_TIMER_1),
          timer2=DoFn.TimerParam(EXPIRY_TIMER_2)):
        pass

      @on_timer(EXPIRY_TIMER_1)
      def on_expiry_1(self, buffer_state=DoFn.StateParam(BUFFER_STATE)):
        yield 'expired1'

      def on_expiry_2(self, buffer_state=DoFn.StateParam(BUFFER_STATE)):
        yield 'expired2'

      # Use a stable string value for matching.
      def __repr__(self):
        return 'StatefulDoFnWithTimerWithTypo3'

    dofn = StatefulDoFnWithTimerWithTypo3()
    with self.assertRaisesRegex(
        ValueError,
        (r'DoFn StatefulDoFnWithTimerWithTypo3 has a TimerSpec without an '
         r'associated on_timer callback: TimerSpec\(.*expiry2\).')):
      validate_stateful_dofn(dofn)
class StatefulDoFnOnDirectRunnerTest(unittest.TestCase):
  """End-to-end tests running stateful DoFns in TestPipelines.

  Results are collected into the class attribute ``all_records`` by the DoFn
  returned from ``record_dofn`` and asserted after each pipeline completes.
  """
  # pylint: disable=expression-not-assigned
  all_records = None  # type: List[Any]

  def setUp(self):
    # Use state on the TestCase class, since other references would be pickled
    # into a closure and not have the desired side effects.
    #
    # TODO(BEAM-5295): Use assert_that after it works for the cases here in
    # streaming mode.
    StatefulDoFnOnDirectRunnerTest.all_records = []

  def record_dofn(self):
    """Return a DoFn that appends every element it sees to all_records."""
    class RecordDoFn(DoFn):
      def process(self, element):
        StatefulDoFnOnDirectRunnerTest.all_records.append(element)

    return RecordDoFn()

  def test_simple_stateful_dofn(self):
    class SimpleTestStatefulDoFn(DoFn):
      BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
      EXPIRY_TIMER = TimerSpec('expiry', TimeDomain.WATERMARK)

      def process(
          self,
          element,
          buffer=DoFn.StateParam(BUFFER_STATE),
          timer1=DoFn.TimerParam(EXPIRY_TIMER)):
        unused_key, value = element
        buffer.add(b'A' + str(value).encode('latin1'))
        timer1.set(20)

      @on_timer(EXPIRY_TIMER)
      def expiry_callback(
          self,
          buffer=DoFn.StateParam(BUFFER_STATE),
          timer=DoFn.TimerParam(EXPIRY_TIMER)):
        yield b''.join(sorted(buffer.read()))

    with TestPipeline() as p:
      test_stream = (
          TestStream()
          .advance_watermark_to(10)
          .add_elements([1, 2])
          .add_elements([3])
          .advance_watermark_to(25)
          .add_elements([4]))
      (
          p
          | test_stream
          | beam.Map(lambda x: ('mykey', x))
          | beam.ParDo(SimpleTestStatefulDoFn())
          | beam.ParDo(self.record_dofn()))

    # Two firings should occur: once after element 3 since the timer should
    # fire after the watermark passes time 20, and another time after element
    # 4, since the timer issued at that point should fire immediately.
    self.assertEqual([b'A1A2A3', b'A1A2A3A4'],
                     StatefulDoFnOnDirectRunnerTest.all_records)

  def test_clearing_bag_state(self):
    class BagStateClearingStatefulDoFn(beam.DoFn):
      BAG_STATE = BagStateSpec('bag_state', StrUtf8Coder())
      EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.WATERMARK)
      CLEAR_TIMER = TimerSpec('clear_timer', TimeDomain.WATERMARK)

      def process(
          self,
          element,
          bag_state=beam.DoFn.StateParam(BAG_STATE),
          emit_timer=beam.DoFn.TimerParam(EMIT_TIMER),
          clear_timer=beam.DoFn.TimerParam(CLEAR_TIMER)):
        value = element[1]
        bag_state.add(value)
        clear_timer.set(100)
        emit_timer.set(1000)

      @on_timer(EMIT_TIMER)
      def emit_values(self, bag_state=beam.DoFn.StateParam(BAG_STATE)):
        for value in bag_state.read():
          yield value
        yield 'extra'

      @on_timer(CLEAR_TIMER)
      def clear_values(self, bag_state=beam.DoFn.StateParam(BAG_STATE)):
        bag_state.clear()

    with TestPipeline() as p:
      test_stream = (
          TestStream()
          .advance_watermark_to(0)
          .add_elements([('key', 'value')])
          .advance_watermark_to(100))
      _ = (
          p
          | test_stream
          | beam.ParDo(BagStateClearingStatefulDoFn())
          | beam.ParDo(self.record_dofn()))

    # The clear timer (t=100) empties the bag before the emit timer (t=1000)
    # fires, so only the unconditional 'extra' marker is emitted.
    self.assertEqual(['extra'], StatefulDoFnOnDirectRunnerTest.all_records)

  def test_two_timers_one_function(self):
    class BagStateClearingStatefulDoFn(beam.DoFn):
      BAG_STATE = BagStateSpec('bag_state', StrUtf8Coder())
      EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.WATERMARK)
      EMIT_TWICE_TIMER = TimerSpec('clear_timer', TimeDomain.WATERMARK)

      def process(
          self,
          element,
          bag_state=beam.DoFn.StateParam(BAG_STATE),
          emit_timer=beam.DoFn.TimerParam(EMIT_TIMER),
          emit_twice_timer=beam.DoFn.TimerParam(EMIT_TWICE_TIMER)):
        value = element[1]
        bag_state.add(value)
        emit_twice_timer.set(100)
        emit_timer.set(1000)

      # A single callback registered for two different timers.
      @on_timer(EMIT_TWICE_TIMER)
      @on_timer(EMIT_TIMER)
      def emit_values(self, bag_state=beam.DoFn.StateParam(BAG_STATE)):
        for value in bag_state.read():
          yield value

    with TestPipeline() as p:
      test_stream = (
          TestStream()
          .advance_watermark_to(0)
          .add_elements([('key', 'value')])
          .advance_watermark_to(100))
      _ = (
          p
          | test_stream
          | beam.ParDo(BagStateClearingStatefulDoFn())
          | beam.ParDo(self.record_dofn()))

    self.assertEqual(['value', 'value'],
                     StatefulDoFnOnDirectRunnerTest.all_records)

  def test_simple_read_modify_write_stateful_dofn(self):
    class SimpleTestReadModifyWriteStatefulDoFn(DoFn):
      VALUE_STATE = ReadModifyWriteStateSpec('value', StrUtf8Coder())

      def process(self, element, last_element=DoFn.StateParam(VALUE_STATE)):
        last_element.write('%s:%s' % element)
        yield last_element.read()

    with TestPipeline() as p:
      test_stream = (
          TestStream()
          .advance_watermark_to(0)
          .add_elements([('a', 1)])
          .advance_watermark_to(10)
          .add_elements([('a', 3)])
          .advance_watermark_to(20)
          .add_elements([('a', 5)]))
      (
          p | test_stream
          | beam.ParDo(SimpleTestReadModifyWriteStatefulDoFn())
          | beam.ParDo(self.record_dofn()))

    self.assertEqual(['a:1', 'a:3', 'a:5'],
                     StatefulDoFnOnDirectRunnerTest.all_records)

  def test_clearing_read_modify_write_state(self):
    class SimpleClearingReadModifyWriteStatefulDoFn(DoFn):
      VALUE_STATE = ReadModifyWriteStateSpec('value', StrUtf8Coder())

      def process(self, element, last_element=DoFn.StateParam(VALUE_STATE)):
        value = last_element.read()
        if value is not None:
          yield value
        last_element.clear()
        # After clear() the read returns None, hence the 'None:' prefix.
        last_element.write("%s:%s" % (last_element.read(), element[1]))
        if element[1] == 5:
          yield last_element.read()

    with TestPipeline() as p:
      test_stream = (
          TestStream()
          .advance_watermark_to(0)
          .add_elements([('a', 1)])
          .advance_watermark_to(10)
          .add_elements([('a', 3)])
          .advance_watermark_to(20)
          .add_elements([('a', 5)]))
      (
          p | test_stream
          | beam.ParDo(SimpleClearingReadModifyWriteStatefulDoFn())
          | beam.ParDo(self.record_dofn()))

    self.assertEqual(['None:1', 'None:3', 'None:5'],
                     StatefulDoFnOnDirectRunnerTest.all_records)

  def test_simple_set_stateful_dofn(self):
    class SimpleTestSetStatefulDoFn(DoFn):
      BUFFER_STATE = SetStateSpec('buffer', VarIntCoder())
      EXPIRY_TIMER = TimerSpec('expiry', TimeDomain.WATERMARK)

      def process(
          self,
          element,
          buffer=DoFn.StateParam(BUFFER_STATE),
          timer1=DoFn.TimerParam(EXPIRY_TIMER)):
        unused_key, value = element
        buffer.add(value)
        timer1.set(20)

      @on_timer(EXPIRY_TIMER)
      def expiry_callback(self, buffer=DoFn.StateParam(BUFFER_STATE)):
        yield sorted(buffer.read())

    with TestPipeline() as p:
      test_stream = (
          TestStream()
          .advance_watermark_to(10)
          .add_elements([1, 2, 3])
          .add_elements([2])
          .advance_watermark_to(24))
      (
          p
          | test_stream
          | beam.Map(lambda x: ('mykey', x))
          | beam.ParDo(SimpleTestSetStatefulDoFn())
          | beam.ParDo(self.record_dofn()))

    # A single firing occurs once the watermark passes time 20. All elements
    # arrive before that, and the set state deduplicates the repeated 2.
    self.assertEqual([[1, 2, 3]], StatefulDoFnOnDirectRunnerTest.all_records)

  def test_clearing_set_state(self):
    class SetStateClearingStatefulDoFn(beam.DoFn):
      SET_STATE = SetStateSpec('buffer', StrUtf8Coder())
      EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.WATERMARK)
      CLEAR_TIMER = TimerSpec('clear_timer', TimeDomain.WATERMARK)

      def process(
          self,
          element,
          set_state=beam.DoFn.StateParam(SET_STATE),
          emit_timer=beam.DoFn.TimerParam(EMIT_TIMER),
          clear_timer=beam.DoFn.TimerParam(CLEAR_TIMER)):
        value = element[1]
        set_state.add(value)
        clear_timer.set(100)
        emit_timer.set(1000)

      @on_timer(EMIT_TIMER)
      def emit_values(self, set_state=beam.DoFn.StateParam(SET_STATE)):
        for value in set_state.read():
          yield value

      @on_timer(CLEAR_TIMER)
      def clear_values(self, set_state=beam.DoFn.StateParam(SET_STATE)):
        set_state.clear()
        set_state.add('different-value')

    with TestPipeline() as p:
      test_stream = (
          TestStream()
          .advance_watermark_to(0)
          .add_elements([('key1', 'value1')])
          .advance_watermark_to(100))
      _ = (
          p
          | test_stream
          | beam.ParDo(SetStateClearingStatefulDoFn())
          | beam.ParDo(self.record_dofn()))

    self.assertEqual(['different-value'],
                     StatefulDoFnOnDirectRunnerTest.all_records)

  def test_stateful_set_state_portably(self):
    class SetStatefulDoFn(beam.DoFn):
      SET_STATE = SetStateSpec('buffer', VarIntCoder())

      def process(self, element, set_state=beam.DoFn.StateParam(SET_STATE)):
        _, value = element
        aggregated_value = 0
        set_state.add(value)
        for saved_value in set_state.read():
          aggregated_value += saved_value
        yield aggregated_value

    with TestPipeline() as p:
      values = p | beam.Create([('key', 1), ('key', 2), ('key', 3), ('key', 4),
                                ('key', 3)],
                               reshuffle=False)
      actual_values = (values | beam.ParDo(SetStatefulDoFn()))
      # The duplicate 3 does not grow the set, so the last sum stays 10.
      assert_that(actual_values, equal_to([1, 3, 6, 10, 10]))

  def test_stateful_set_state_clean_portably(self):
    class SetStateClearingStatefulDoFn(beam.DoFn):
      SET_STATE = SetStateSpec('buffer', VarIntCoder())
      EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.WATERMARK)

      def process(
          self,
          element,
          set_state=beam.DoFn.StateParam(SET_STATE),
          emit_timer=beam.DoFn.TimerParam(EMIT_TIMER)):
        _, value = element
        set_state.add(value)

        all_elements = list(set_state.read())

        if len(all_elements) == 5:
          set_state.clear()
          set_state.add(100)
          emit_timer.set(1)

      @on_timer(EMIT_TIMER)
      def emit_values(self, set_state=beam.DoFn.StateParam(SET_STATE)):
        yield sorted(set_state.read())

    with TestPipeline() as p:
      values = p | beam.Create([('key', 1), ('key', 2), ('key', 3), ('key', 4),
                                ('key', 5)])
      actual_values = (
          values
          | beam.Map(lambda t: window.TimestampedValue(t, 1))
          | beam.WindowInto(window.FixedWindows(1))
          | beam.ParDo(SetStateClearingStatefulDoFn()))

      assert_that(actual_values, equal_to([[100]]))

  def test_stateful_dofn_nonkeyed_input(self):
    p = TestPipeline()
    values = p | beam.Create([1, 2, 3])
    # Applying the stateful ParDo to non-KV input fails at construction time,
    # so the pipeline never needs to run.
    with self.assertRaisesRegex(
        ValueError,
        ('Input elements to the transform .* with stateful DoFn must be '
         'key-value pairs.')):
      values | beam.ParDo(TestStatefulDoFn())

  def test_generate_sequence_with_realtime_timer(self):
    from apache_beam.transforms.combiners import CountCombineFn

    class GenerateRecords(beam.DoFn):
      EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.REAL_TIME)
      COUNT_STATE = CombiningValueStateSpec(
          'count_state', VarIntCoder(), CountCombineFn())

      def __init__(self, frequency, total_records):
        self.total_records = total_records
        self.frequency = frequency

      def process(self, element, emit_timer=beam.DoFn.TimerParam(EMIT_TIMER)):
        # Processing time timers should be set on ABSOLUTE TIME.
        emit_timer.set(self.frequency)
        yield element[1]

      @on_timer(EMIT_TIMER)
      def emit_values(
          self,
          emit_timer=beam.DoFn.TimerParam(EMIT_TIMER),
          count_state=beam.DoFn.StateParam(COUNT_STATE)):
        count = count_state.read() or 0
        if self.total_records == count:
          return
        count_state.add(1)
        # Processing time timers should be set on ABSOLUTE TIME.
        emit_timer.set(count + 1 + self.frequency)
        yield 'value'

    TOTAL_RECORDS = 3
    FREQUENCY = 1

    test_stream = (
        TestStream()
        .advance_watermark_to(0)
        .add_elements([('key', 0)])
        .advance_processing_time(1)  # Timestamp: 1
        .add_elements([('key', 1)])
        .advance_processing_time(1)  # Timestamp: 2
        .add_elements([('key', 2)])
        .advance_processing_time(1)  # Timestamp: 3
        .add_elements([('key', 3)]))

    with beam.Pipeline(argv=['--streaming', '--runner=DirectRunner']) as p:
      _ = (
          p
          | test_stream
          | beam.ParDo(GenerateRecords(FREQUENCY, TOTAL_RECORDS))
          | beam.ParDo(self.record_dofn()))

    self.assertEqual(
        # 4 RECORDS go through process
        # 3 values are emitted from timer
        # Timestamp moves gradually.
        [0, 'value', 1, 'value', 2, 'value', 3],
        StatefulDoFnOnDirectRunnerTest.all_records)

  def test_simple_stateful_dofn_combining(self):
    class SimpleTestStatefulDoFn(DoFn):
      BUFFER_STATE = CombiningValueStateSpec(
          'buffer', IterableCoder(VarIntCoder()), ToListCombineFn())
      EXPIRY_TIMER = TimerSpec('expiry1', TimeDomain.WATERMARK)

      def process(
          self,
          element,
          buffer=DoFn.StateParam(BUFFER_STATE),
          timer1=DoFn.TimerParam(EXPIRY_TIMER)):
        unused_key, value = element
        buffer.add(value)
        timer1.set(20)

      @on_timer(EXPIRY_TIMER)
      def expiry_callback(
          self,
          buffer=DoFn.StateParam(BUFFER_STATE),
          timer=DoFn.TimerParam(EXPIRY_TIMER)):
        yield ''.join(str(x) for x in sorted(buffer.read()))

    with TestPipeline() as p:
      test_stream = (
          TestStream()
          .advance_watermark_to(10)
          .add_elements([1, 2])
          .add_elements([3])
          .advance_watermark_to(25)
          .add_elements([4]))
      (
          p
          | test_stream
          | beam.Map(lambda x: ('mykey', x))
          | beam.ParDo(SimpleTestStatefulDoFn())
          | beam.ParDo(self.record_dofn()))

    self.assertEqual(['123', '1234'],
                     StatefulDoFnOnDirectRunnerTest.all_records)

  def test_timer_output_timestamp(self):
    class TimerEmittingStatefulDoFn(DoFn):
      EMIT_TIMER_1 = TimerSpec('emit1', TimeDomain.WATERMARK)
      EMIT_TIMER_2 = TimerSpec('emit2', TimeDomain.WATERMARK)
      EMIT_TIMER_3 = TimerSpec('emit3', TimeDomain.WATERMARK)

      def process(
          self,
          element,
          timer1=DoFn.TimerParam(EMIT_TIMER_1),
          timer2=DoFn.TimerParam(EMIT_TIMER_2),
          timer3=DoFn.TimerParam(EMIT_TIMER_3)):
        timer1.set(10)
        timer2.set(20)
        timer3.set(30)

      @on_timer(EMIT_TIMER_1)
      def emit_callback_1(self):
        yield 'timer1'

      @on_timer(EMIT_TIMER_2)
      def emit_callback_2(self):
        yield 'timer2'

      @on_timer(EMIT_TIMER_3)
      def emit_callback_3(self):
        yield 'timer3'

    class TimestampReifyingDoFn(DoFn):
      def process(self, element, ts=DoFn.TimestampParam):
        yield (element, int(ts))

    with TestPipeline() as p:
      test_stream = (TestStream().advance_watermark_to(10).add_elements([1]))
      (
          p
          | test_stream
          | beam.Map(lambda x: ('mykey', x))
          | beam.ParDo(TimerEmittingStatefulDoFn())
          | beam.ParDo(TimestampReifyingDoFn())
          | beam.ParDo(self.record_dofn()))

    self.assertEqual([('timer1', 10), ('timer2', 20), ('timer3', 30)],
                     sorted(StatefulDoFnOnDirectRunnerTest.all_records))

  def test_timer_output_timestamp_and_window(self):
    class TimerEmittingStatefulDoFn(DoFn):
      EMIT_TIMER_1 = TimerSpec('emit1', TimeDomain.WATERMARK)

      def process(self, element, timer1=DoFn.TimerParam(EMIT_TIMER_1)):
        timer1.set(10)

      @on_timer(EMIT_TIMER_1)
      def emit_callback_1(
          self,
          window=DoFn.WindowParam,
          ts=DoFn.TimestampParam,
          key=DoFn.KeyParam):
        yield (
            'timer1-{key}'.format(key=key),
            int(ts),
            int(window.start),
            int(window.end))

    pipeline_options = PipelineOptions()
    with TestPipeline(options=pipeline_options) as p:
      test_stream = (TestStream().advance_watermark_to(10).add_elements([1]))
      (
          p
          | test_stream
          | beam.Map(lambda x: ('mykey', x))
          | "window_into" >> beam.WindowInto(
              window.FixedWindows(5),
              accumulation_mode=trigger.AccumulationMode.DISCARDING)
          | beam.ParDo(TimerEmittingStatefulDoFn())
          | beam.ParDo(self.record_dofn()))

    self.assertEqual([('timer1-mykey', 10, 10, 15)],
                     sorted(StatefulDoFnOnDirectRunnerTest.all_records))

  def test_timer_default_tag(self):
    class DynamicTimerDoFn(DoFn):
      EMIT_TIMER_FAMILY = TimerSpec('emit', TimeDomain.WATERMARK)

      def process(self, element, emit=DoFn.TimerParam(EMIT_TIMER_FAMILY)):
        # The untagged set and the '' tag address the same timer, so the
        # second call overrides the first.
        emit.set(10)
        emit.set(20, dynamic_timer_tag='')

      @on_timer(EMIT_TIMER_FAMILY)
      def emit_callback(
          self, ts=DoFn.TimestampParam, tag=DoFn.DynamicTimerTagParam):
        yield (tag, ts)

    with TestPipeline() as p:
      test_stream = (
          TestStream()
          .advance_watermark_to(10)
          .add_elements([1])
          .advance_watermark_to_infinity())
      (
          p
          | test_stream
          | beam.Map(lambda x: ('mykey', x))
          | beam.ParDo(DynamicTimerDoFn())
          | beam.ParDo(self.record_dofn()))

    self.assertEqual([('', 20)],
                     sorted(StatefulDoFnOnDirectRunnerTest.all_records))

  def test_dynamic_timer_simple_dofn(self):
    class DynamicTimerDoFn(DoFn):
      EMIT_TIMER_FAMILY = TimerSpec('emit', TimeDomain.WATERMARK)

      def process(self, element, emit=DoFn.TimerParam(EMIT_TIMER_FAMILY)):
        emit.set(10, dynamic_timer_tag='emit1')
        emit.set(20, dynamic_timer_tag='emit2')
        emit.set(30, dynamic_timer_tag='emit3')

      @on_timer(EMIT_TIMER_FAMILY)
      def emit_callback(
          self, ts=DoFn.TimestampParam, tag=DoFn.DynamicTimerTagParam):
        yield (tag, ts)

    with TestPipeline() as p:
      test_stream = (
          TestStream()
          .advance_watermark_to(10)
          .add_elements([1])
          .advance_watermark_to_infinity())
      (
          p
          | test_stream
          | beam.Map(lambda x: ('mykey', x))
          | beam.ParDo(DynamicTimerDoFn())
          | beam.ParDo(self.record_dofn()))

    self.assertEqual([('emit1', 10), ('emit2', 20), ('emit3', 30)],
                     sorted(StatefulDoFnOnDirectRunnerTest.all_records))

  def test_dynamic_timer_clear_timer(self):
    class DynamicTimerDoFn(DoFn):
      EMIT_TIMER_FAMILY = TimerSpec('emit', TimeDomain.WATERMARK)

      def process(self, element, emit=DoFn.TimerParam(EMIT_TIMER_FAMILY)):
        if element[1] == 'set':
          emit.set(10, dynamic_timer_tag='emit1')
          emit.set(20, dynamic_timer_tag='emit2')
          emit.set(30, dynamic_timer_tag='emit3')
        if element[1] == 'clear':
          emit.clear(dynamic_timer_tag='emit3')

      @on_timer(EMIT_TIMER_FAMILY)
      def emit_callback(
          self, ts=DoFn.TimestampParam, tag=DoFn.DynamicTimerTagParam):
        yield (tag, ts)

    with TestPipeline() as p:
      test_stream = (
          TestStream()
          .advance_watermark_to(5)
          .add_elements(['set'])
          .advance_watermark_to(10)
          .add_elements(['clear'])
          .advance_watermark_to_infinity())
      (
          p
          | test_stream
          | beam.Map(lambda x: ('mykey', x))
          | beam.ParDo(DynamicTimerDoFn())
          | beam.ParDo(self.record_dofn()))

    self.assertEqual([('emit1', 10), ('emit2', 20)],
                     sorted(StatefulDoFnOnDirectRunnerTest.all_records))

  def test_dynamic_timer_multiple(self):
    class DynamicTimerDoFn(DoFn):
      EMIT_TIMER_FAMILY1 = TimerSpec('emit_family_1', TimeDomain.WATERMARK)
      EMIT_TIMER_FAMILY2 = TimerSpec('emit_family_2', TimeDomain.WATERMARK)

      def process(
          self,
          element,
          emit1=DoFn.TimerParam(EMIT_TIMER_FAMILY1),
          emit2=DoFn.TimerParam(EMIT_TIMER_FAMILY2)):
        emit1.set(10, dynamic_timer_tag='emit11')
        emit1.set(20, dynamic_timer_tag='emit12')
        emit1.set(30, dynamic_timer_tag='emit13')
        emit2.set(30, dynamic_timer_tag='emit21')
        emit2.set(20, dynamic_timer_tag='emit22')
        emit2.set(10, dynamic_timer_tag='emit23')

      @on_timer(EMIT_TIMER_FAMILY1)
      def emit_callback(
          self, ts=DoFn.TimestampParam, tag=DoFn.DynamicTimerTagParam):
        yield (tag, ts)

      @on_timer(EMIT_TIMER_FAMILY2)
      def emit_callback_2(
          self, ts=DoFn.TimestampParam, tag=DoFn.DynamicTimerTagParam):
        yield (tag, ts)

    with TestPipeline() as p:
      test_stream = (
          TestStream()
          .advance_watermark_to(5)
          .add_elements(['1'])
          .advance_watermark_to_infinity())
      (
          p
          | test_stream
          | beam.Map(lambda x: ('mykey', x))
          | beam.ParDo(DynamicTimerDoFn())
          | beam.ParDo(self.record_dofn()))

    self.assertEqual([('emit11', 10), ('emit12', 20), ('emit13', 30),
                      ('emit21', 30), ('emit22', 20), ('emit23', 10)],
                     sorted(StatefulDoFnOnDirectRunnerTest.all_records))

  def test_dynamic_timer_and_simple_timer(self):
    class DynamicTimerDoFn(DoFn):
      EMIT_TIMER_FAMILY = TimerSpec('emit', TimeDomain.WATERMARK)
      GC_TIMER = TimerSpec('gc', TimeDomain.WATERMARK)

      def process(
          self,
          element,
          emit=DoFn.TimerParam(EMIT_TIMER_FAMILY),
          gc=DoFn.TimerParam(GC_TIMER)):
        emit.set(10, dynamic_timer_tag='emit1')
        emit.set(20, dynamic_timer_tag='emit2')
        emit.set(30, dynamic_timer_tag='emit3')
        gc.set(40)

      @on_timer(EMIT_TIMER_FAMILY)
      def emit_callback(
          self, ts=DoFn.TimestampParam, tag=DoFn.DynamicTimerTagParam):
        yield (tag, ts)

      @on_timer(GC_TIMER)
      def gc(self, ts=DoFn.TimestampParam):
        yield ('gc', ts)

    with TestPipeline() as p:
      test_stream = (
          TestStream()
          .advance_watermark_to(5)
          .add_elements(['1'])
          .advance_watermark_to_infinity())
      (
          p
          | test_stream
          | beam.Map(lambda x: ('mykey', x))
          | beam.ParDo(DynamicTimerDoFn())
          | beam.ParDo(self.record_dofn()))

    self.assertEqual([('emit1', 10), ('emit2', 20), ('emit3', 30), ('gc', 40)],
                     sorted(StatefulDoFnOnDirectRunnerTest.all_records))

  def test_index_assignment(self):
    class IndexAssigningStatefulDoFn(DoFn):
      INDEX_STATE = CombiningValueStateSpec('index', sum)

      def process(self, element, state=DoFn.StateParam(INDEX_STATE)):
        unused_key, value = element
        current_index = state.read()
        yield (value, current_index)
        state.add(1)

    with TestPipeline() as p:
      test_stream = (
          TestStream()
          .advance_watermark_to(10)
          .add_elements(['A', 'B'])
          .add_elements(['C'])
          .advance_watermark_to(25)
          .add_elements(['D']))
      (
          p
          | test_stream
          | beam.Map(lambda x: ('mykey', x))
          | beam.ParDo(IndexAssigningStatefulDoFn())
          | beam.ParDo(self.record_dofn()))

    self.assertEqual([('A', 0), ('B', 1), ('C', 2), ('D', 3)],
                     StatefulDoFnOnDirectRunnerTest.all_records)

  def test_hash_join(self):
    class HashJoinStatefulDoFn(DoFn):
      BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
      UNMATCHED_TIMER = TimerSpec('unmatched', TimeDomain.WATERMARK)

      def process(
          self,
          element,
          state=DoFn.StateParam(BUFFER_STATE),
          timer=DoFn.TimerParam(UNMATCHED_TIMER)):
        key, value = element
        existing_values = list(state.read())
        if not existing_values:
          # First value for this key: buffer it and wait for a partner.
          state.add(value)
          timer.set(100)
        else:
          yield b'Record<%s,%s,%s>' % (key, existing_values[0], value)
          state.clear()
          timer.clear()

      @on_timer(UNMATCHED_TIMER)
      def expiry_callback(self, state=DoFn.StateParam(BUFFER_STATE)):
        buffered = list(state.read())
        assert len(buffered) == 1, buffered
        state.clear()
        yield b'Unmatched<%s>' % (buffered[0], )

    with TestPipeline() as p:
      test_stream = (
          TestStream()
          .advance_watermark_to(10)
          .add_elements([(b'A', b'a'), (b'B', b'b')])
          .add_elements([(b'A', b'aa'), (b'C', b'c')])
          .advance_watermark_to(25)
          .add_elements([(b'A', b'aaa'), (b'B', b'bb')])
          .add_elements([
              (b'D', b'd'), (b'D', b'dd'), (b'D', b'ddd'), (b'D', b'dddd')
          ])
          .advance_watermark_to(125)
          .add_elements([(b'C', b'cc')]))
      (
          p
          | test_stream
          | beam.ParDo(HashJoinStatefulDoFn())
          | beam.ParDo(self.record_dofn()))

    # equal_to(expected) returns a matcher that is applied to the actual
    # values; keep that order so failure messages label missing/unexpected
    # elements correctly.
    equal_to([
        b'Record<A,a,aa>',
        b'Record<B,b,bb>',
        b'Record<D,d,dd>',
        b'Record<D,ddd,dddd>',
        b'Unmatched<aaa>',
        b'Unmatched<c>',
        b'Unmatched<cc>'
    ])(StatefulDoFnOnDirectRunnerTest.all_records)
if __name__ == '__main__':
  # Run all TestCase classes in this module with the standard unittest runner
  # when the file is executed directly.
  unittest.main()