| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| """User-facing interfaces for the Beam State and Timer APIs. |
| |
| Experimental; no backwards-compatibility guarantees. |
| """ |
| |
| # pytype: skip-file |
| # mypy: disallow-untyped-defs |
| |
| from __future__ import absolute_import |
| |
| import collections |
| import types |
| from builtins import object |
| from typing import TYPE_CHECKING |
| from typing import Any |
| from typing import Callable |
| from typing import Dict |
| from typing import Iterable |
| from typing import NamedTuple |
| from typing import Optional |
| from typing import Set |
| from typing import Tuple |
| from typing import TypeVar |
| |
| from apache_beam.coders import Coder |
| from apache_beam.coders import coders |
| from apache_beam.portability.api import beam_runner_api_pb2 |
| from apache_beam.transforms.timeutil import TimeDomain |
| |
| if TYPE_CHECKING: |
| from apache_beam.runners.pipeline_context import PipelineContext |
| from apache_beam.transforms.core import CombineFn, DoFn |
| from apache_beam.utils import windowed_value |
| from apache_beam.utils.timestamp import Timestamp |
| |
| CallableT = TypeVar('CallableT', bound=Callable) |
| |
| |
class StateSpec(object):
  """Base specification of a user state cell declared on a DoFn.

  A spec pairs the name that identifies the state cell with the coder recorded
  for it. Subclasses supply the translation to the Runner API proto.
  """
  def __init__(self, name, coder):
    # type: (str, Coder) -> None

    """Validates and records the name and coder of the state cell.

    Args:
      name (str): The name by which the state is identified.
      coder (Coder): The coder associated with this state cell.

    Raises:
      TypeError: If ``name`` is not a ``str`` or ``coder`` is not a ``Coder``.
    """
    # The name is checked first, so a bad name is reported before a bad coder.
    if not isinstance(name, str):
      raise TypeError("name is not a string")
    elif not isinstance(coder, Coder):
      raise TypeError("coder is not of type Coder")
    self.name, self.coder = name, coder

  def __repr__(self):
    # type: () -> str

    """Returns ``ClassName(state_name)`` using the concrete class name."""
    class_name = type(self).__name__
    return '%s(%s)' % (class_name, self.name)

  def to_runner_api(self, context):
    # type: (PipelineContext) -> beam_runner_api_pb2.StateSpec

    """Converts this spec to a Runner API proto; subclasses must override."""
    raise NotImplementedError()
| |
| |
class ReadModifyWriteStateSpec(StateSpec):
  """Spec for a user DoFn state cell that holds a single value."""
  def to_runner_api(self, context):
    # type: (PipelineContext) -> beam_runner_api_pb2.StateSpec

    """Builds the Runner API ``StateSpec`` proto for this value state.

    Args:
      context (PipelineContext): Pipeline context; its ``coders`` registry is
        asked for the id of this spec's coder.

    Returns:
      A ``StateSpec`` proto with ``read_modify_write_spec`` populated.
    """
    coder_id = context.coders.get_id(self.coder)
    value_spec = beam_runner_api_pb2.ReadModifyWriteStateSpec(
        coder_id=coder_id)
    return beam_runner_api_pb2.StateSpec(read_modify_write_spec=value_spec)
| |
| |
class BagStateSpec(StateSpec):
  """Spec for a user DoFn bag state cell."""
  def to_runner_api(self, context):
    # type: (PipelineContext) -> beam_runner_api_pb2.StateSpec

    """Builds the Runner API ``StateSpec`` proto for this bag state.

    Args:
      context (PipelineContext): Pipeline context; its ``coders`` registry is
        asked for the id of this spec's element coder.

    Returns:
      A ``StateSpec`` proto with ``bag_spec`` populated.
    """
    element_coder_id = context.coders.get_id(self.coder)
    bag_spec = beam_runner_api_pb2.BagStateSpec(
        element_coder_id=element_coder_id)
    return beam_runner_api_pb2.StateSpec(bag_spec=bag_spec)
| |
| |
class SetStateSpec(StateSpec):
  """Spec for a user DoFn set state cell."""
  def to_runner_api(self, context):
    # type: (PipelineContext) -> beam_runner_api_pb2.StateSpec

    """Builds the Runner API ``StateSpec`` proto for this set state.

    Args:
      context (PipelineContext): Pipeline context; its ``coders`` registry is
        asked for the id of this spec's element coder.

    Returns:
      A ``StateSpec`` proto with ``set_spec`` populated.
    """
    element_coder_id = context.coders.get_id(self.coder)
    set_spec = beam_runner_api_pb2.SetStateSpec(
        element_coder_id=element_coder_id)
    return beam_runner_api_pb2.StateSpec(set_spec=set_spec)
| |
| |
class CombiningValueStateSpec(StateSpec):
  """Spec for a user DoFn state cell whose values are merged by a CombineFn."""
  def __init__(self, name, coder=None, combine_fn=None):
    # type: (str, Optional[Coder], Any) -> None

    """Initializes the specification for CombiningValue state.

    Two call shapes are accepted::

      CombiningValueStateSpec(name, combine_fn)         # coder is inferred
      CombiningValueStateSpec(name, coder, combine_fn)  # both given

    Args:
      name (str): The name by which the state is identified.
      coder (Coder): Coder for the values being combined. When omitted it is
        taken from the combine_fn's accumulator coder.
      combine_fn (``CombineFn`` or ``callable``): Function specifying how to
        combine the values passed to state.

    Raises:
      ValueError: If neither ``coder`` nor ``combine_fn`` is supplied.
    """
    # Avoid circular import.
    from apache_beam.transforms.core import CombineFn

    # ``coder`` precedes the required ``combine_fn`` in the signature and this
    # cannot change for backwards compatibility. When only one of the two is
    # given positionally, it lands in ``coder`` and is really the combine_fn.
    if combine_fn is None and coder is None:
      raise ValueError('combine_fn must be provided')
    if combine_fn is None:
      combine_fn = coder
      coder = None

    self.combine_fn = CombineFn.maybe_from_callable(combine_fn)

    # The coder here should be for the accumulator type of the CombineFn.
    accumulator_coder = coder
    if accumulator_coder is None:
      accumulator_coder = self.combine_fn.get_accumulator_coder()

    super(CombiningValueStateSpec, self).__init__(name, accumulator_coder)

  def to_runner_api(self, context):
    # type: (PipelineContext) -> beam_runner_api_pb2.StateSpec

    """Builds the Runner API ``StateSpec`` proto for this combining state.

    Args:
      context (PipelineContext): Pipeline context used to serialize the
        combine_fn and to obtain the accumulator coder id.

    Returns:
      A ``StateSpec`` proto with ``combining_spec`` populated.
    """
    # The combine_fn is serialized before the coder id is requested, keeping
    # the order in which the two calls reach the context.
    combine_fn_proto = self.combine_fn.to_runner_api(context)
    accumulator_coder_id = context.coders.get_id(self.coder)
    combining_spec = beam_runner_api_pb2.CombiningStateSpec(
        combine_fn=combine_fn_proto, accumulator_coder_id=accumulator_coder_id)
    return beam_runner_api_pb2.StateSpec(combining_spec=combining_spec)
| |
| |
| # TODO(BEAM-9562): Update Timer to have of() and clear() APIs. |
class Timer(NamedTuple):
  # Immutable record carrying the data of one user timer. Field meanings are
  # inferred from their names only -- confirm against the code that builds
  # and consumes Timer values.
  user_key: Any
  dynamic_timer_tag: str
  windows: Tuple['windowed_value.BoundedWindow', ...]
  clear_bit: bool
  fire_timestamp: Optional['Timestamp']
  hold_timestamp: Optional['Timestamp']
  paneinfo: Optional['windowed_value.PaneInfo']
| |
| |
| # TODO(BEAM-9562): Plumb through actual key_coder and window_coder. |
class TimerSpec(object):
  """Specification of a timer declared on a user stateful DoFn."""
  # Prepended to every user-supplied timer name when building ``self.name``.
  prefix = "ts-"

  def __init__(self, name, time_domain):
    # type: (str, str) -> None

    """Records the prefixed timer name and validates the time domain.

    Args:
      name (str): User-supplied timer name; stored with ``prefix`` prepended.
      time_domain (str): Must be ``TimeDomain.WATERMARK`` or
        ``TimeDomain.REAL_TIME``.

    Raises:
      ValueError: If ``time_domain`` is not one of the two accepted domains.
    """
    self.name = self.prefix + name
    supported_domains = (TimeDomain.WATERMARK, TimeDomain.REAL_TIME)
    if time_domain not in supported_domains:
      raise ValueError('Unsupported TimeDomain: %r.' % (time_domain, ))
    self.time_domain = time_domain
    # Populated by the ``on_timer`` decorator when a callback is registered.
    self._attached_callback = None  # type: Optional[Callable]

  def __repr__(self):
    # type: () -> str

    """Returns ``ClassName(prefixed_name)`` using the concrete class name."""
    class_name = type(self).__name__
    return '%s(%s)' % (class_name, self.name)

  def to_runner_api(self, context, key_coder, window_coder):
    # type: (PipelineContext, Coder, Coder) -> beam_runner_api_pb2.TimerFamilySpec

    """Builds the Runner API ``TimerFamilySpec`` proto for this timer.

    Args:
      context (PipelineContext): Pipeline context; its ``coders`` registry is
        asked for the id of the timer coder.
      key_coder (Coder): Coder passed to the timer coder for keys.
      window_coder (Coder): Coder passed to the timer coder for windows.

    Returns:
      A ``TimerFamilySpec`` proto with the time domain and timer coder id.
    """
    domain_proto = TimeDomain.to_runner_api(self.time_domain)
    timer_coder = coders._TimerCoder(key_coder, window_coder)
    timer_coder_id = context.coders.get_id(timer_coder)
    return beam_runner_api_pb2.TimerFamilySpec(
        time_domain=domain_proto, timer_family_coder_id=timer_coder_id)
| |
| |
def on_timer(timer_spec):
  # type: (TimerSpec) -> Callable[[CallableT], CallableT]

  """Decorator marking a DoFn method as the callback for a timer.

  The decorated method is recorded on ``timer_spec`` and returned unchanged.
  Sample usage::

    class MyDoFn(DoFn):
      TIMER_SPEC = TimerSpec('timer', TimeDomain.WATERMARK)

      @on_timer(TIMER_SPEC)
      def my_timer_expiry_callback(self):
        logging.info('Timer expired!')

  Args:
    timer_spec (TimerSpec): The timer whose firing invokes the method.

  Raises:
    ValueError: If ``timer_spec`` is not a ``TimerSpec``; or, when the
      decorator is applied, if the target is not callable or the spec
      already has a callback attached.
  """
  if not isinstance(timer_spec, TimerSpec):
    raise ValueError('@on_timer decorator expected TimerSpec.')

  def _attach(method):
    # type: (CallableT) -> CallableT
    if not callable(method):
      raise ValueError('@on_timer decorator expected callable.')
    # A spec may carry only one callback; a second registration is an error.
    already_attached = timer_spec._attached_callback
    if already_attached:
      raise ValueError(
          'Multiple on_timer callbacks registered for %r.' % timer_spec)
    timer_spec._attached_callback = method
    return method

  return _attach
| |
| |
def get_dofn_specs(dofn):
  # type: (DoFn) -> Tuple[Set[StateSpec], Set[TimerSpec]]

  """Collects the state and timer specs used by a DoFn, if any.

  Every bound method of ``dofn`` is inspected. State and timer parameters
  found among the method's default argument values contribute their specs.

  Args:
    dofn (apache_beam.transforms.core.DoFn): The DoFn instance to introspect
      for timer and state specs.

  Returns:
    A ``(state_specs, timer_specs)`` pair of sets.

  Raises:
    ValueError: If a method declares the same DoFn parameter more than once.
  """

  # Avoid circular import.
  from apache_beam.runners.common import MethodWrapper
  from apache_beam.transforms.core import _DoFnParam
  from apache_beam.transforms.core import _StateDoFnParam
  from apache_beam.transforms.core import _TimerDoFnParam

  state_specs = set()  # type: Set[StateSpec]
  timer_specs = set()  # type: Set[TimerSpec]

  # Covers process(), start_bundle(), finish_bundle() and any on_timer
  # callbacks, since all of them are bound methods of the DoFn.
  for attr_name in dir(dofn):
    if isinstance(getattr(dofn, attr_name, None), types.MethodType):
      wrapper = MethodWrapper(dofn, attr_name)

      # Reject a method that lists the same DoFn parameter twice.
      param_ids = [
          default.param_id for default in wrapper.defaults
          if isinstance(default, _DoFnParam)
      ]
      if len(set(param_ids)) < len(param_ids):
        raise ValueError(
            'DoFn %r has duplicate %s method parameters: %s.' %
            (dofn, attr_name, param_ids))

      for default in wrapper.defaults:
        if isinstance(default, _StateDoFnParam):
          state_specs.add(default.state_spec)
        elif isinstance(default, _TimerDoFnParam):
          timer_specs.add(default.timer_spec)

  return state_specs, timer_specs
| |
| |
def is_stateful_dofn(dofn):
  # type: (DoFn) -> bool

  """Tells whether a DoFn is stateful, i.e. uses user state or timers.

  Args:
    dofn (apache_beam.transforms.core.DoFn): The DoFn instance to inspect.

  Returns:
    True if the DoFn declares at least one state spec or timer spec.
  """
  state_specs, timer_specs = get_dofn_specs(dofn)
  if state_specs:
    return True
  return bool(timer_specs)
| |
| |
def validate_stateful_dofn(dofn):
  # type: (DoFn) -> None

  """Validates the proper specification of a stateful DoFn.

  Checks that state spec names and timer spec names are each unique within
  the DoFn, and that every timer spec has an on_timer callback which is still
  the method of that name on the DoFn.

  Args:
    dofn (apache_beam.transforms.core.DoFn): The DoFn instance to validate.

  Raises:
    ValueError: If two state specs or two timer specs share a name, if a
      timer spec has no attached callback, or if the attached callback is not
      the function behind the same-named method of ``dofn``.
  """

  # Get state and timer specs.
  all_state_specs, all_timer_specs = get_dofn_specs(dofn)

  # Reject DoFns that have multiple state or timer specs with the same name.
  if len(all_state_specs) != len({s.name for s in all_state_specs}):
    raise ValueError(
        'DoFn %r has multiple StateSpecs with the same name: %s.' %
        (dofn, all_state_specs))
  if len(all_timer_specs) != len({s.name for s in all_timer_specs}):
    raise ValueError(
        'DoFn %r has multiple TimerSpecs with the same name: %s.' %
        (dofn, all_timer_specs))

  # Reject DoFns that use timer specs without corresponding timer callbacks.
  for timer_spec in all_timer_specs:
    callback = timer_spec._attached_callback
    if not callback:
      raise ValueError((
          'DoFn %r has a TimerSpec without an associated on_timer '
          'callback: %s.') % (dofn, timer_spec))
    method_name = callback.__name__
    # Resolve the underlying function defensively. The DoFn may have no
    # attribute with this name (e.g. the callback was registered on another
    # class), or the attribute may not be a bound method. Both cases must
    # yield the ValueError below, not an AttributeError on ``__func__``.
    dofn_attr = getattr(dofn, method_name, None)
    dofn_func = getattr(dofn_attr, '__func__', None)
    if callback != dofn_func:
      raise ValueError((
          'The on_timer callback for %s is not the specified .%s method '
          'for DoFn %r (perhaps it was overwritten?).') %
                       (timer_spec, method_name, dofn))
| |
| |
class BaseTimer(object):
  """Abstract timer interface; both operations must be overridden."""
  def clear(self, dynamic_timer_tag=''):
    # type: (str) -> None

    """Clears the timer for the given dynamic tag (default: empty tag)."""
    raise NotImplementedError()

  def set(self, timestamp, dynamic_timer_tag=''):
    # type: (Timestamp, str) -> None

    """Sets the timer for the given dynamic tag to ``timestamp``."""
    raise NotImplementedError()
| |
| |
# Latest recorded action for one dynamic timer tag: whether the timer was
# cleared, and the timestamp it was set to (None when cleared).
_TimerTuple = collections.namedtuple('timer_tuple', 'cleared timestamp')


class RuntimeTimer(BaseTimer):
  """Timer interface object passed to user code.

  Each ``set``/``clear`` call is recorded per dynamic timer tag; a later call
  for the same tag replaces the earlier record.
  """
  def __init__(self) -> None:
    self._timer_recordings = {}  # type: Dict[str, _TimerTuple]
    self._cleared = False
    self._new_timestamp = None  # type: Optional[Timestamp]

  def clear(self, dynamic_timer_tag=''):
    # type: (str) -> None

    """Records that the timer for ``dynamic_timer_tag`` was cleared."""
    record = _TimerTuple(True, None)
    self._timer_recordings[dynamic_timer_tag] = record

  def set(self, timestamp, dynamic_timer_tag=''):
    # type: (Timestamp, str) -> None

    """Records that the timer for ``dynamic_timer_tag`` was set."""
    record = _TimerTuple(False, timestamp)
    self._timer_recordings[dynamic_timer_tag] = record
| |
| |
class RuntimeState(object):
  """Base class of the state interface objects passed to user code."""
  def prefetch(self):
    # type: () -> None

    """Hook that does nothing in this default implementation."""
    return None

  def finalize(self):
    # type: () -> None

    """Hook that does nothing in this default implementation."""
    return None
| |
| |
class ReadModifyWriteRuntimeState(RuntimeState):
  """Abstract interface for single-value state: read, write, clear, commit.

  Every method raises ``NotImplementedError`` carrying the concrete class, so
  a missing override is reported against the subclass that lacks it.
  """
  def read(self):
    # type: () -> Any
    concrete_type = type(self)
    raise NotImplementedError(concrete_type)

  def write(self, value):
    # type: (Any) -> None
    concrete_type = type(self)
    raise NotImplementedError(concrete_type)

  def clear(self):
    # type: () -> None
    concrete_type = type(self)
    raise NotImplementedError(concrete_type)

  def commit(self):
    # type: () -> None
    concrete_type = type(self)
    raise NotImplementedError(concrete_type)
| |
| |
class AccumulatingRuntimeState(RuntimeState):
  """Abstract interface for accumulating state: read, add, clear, commit.

  Every method raises ``NotImplementedError`` carrying the concrete class, so
  a missing override is reported against the subclass that lacks it.
  """
  def read(self):
    # type: () -> Iterable[Any]
    concrete_type = type(self)
    raise NotImplementedError(concrete_type)

  def add(self, value):
    # type: (Any) -> None
    concrete_type = type(self)
    raise NotImplementedError(concrete_type)

  def clear(self):
    # type: () -> None
    concrete_type = type(self)
    raise NotImplementedError(concrete_type)

  def commit(self):
    # type: () -> None
    concrete_type = type(self)
    raise NotImplementedError(concrete_type)
| |
| |
class BagRuntimeState(AccumulatingRuntimeState):
  """Runtime interface to bag state, as handed to user code.

  Adds nothing to the read/add/clear/commit interface it inherits.
  """
| |
| |
class SetRuntimeState(AccumulatingRuntimeState):
  """Runtime interface to set state, as handed to user code.

  Adds nothing to the read/add/clear/commit interface it inherits.
  """
| |
| |
class CombiningValueRuntimeState(AccumulatingRuntimeState):
  """Runtime interface to combining value state, as handed to user code.

  Adds nothing to the read/add/clear/commit interface it inherits.
  """
| |
| |
class UserStateContext(object):
  """Wrapper allowing user state and timers to be accessed by a DoFnInvoker.

  This class only defines the interface: every method raises
  ``NotImplementedError`` carrying the concrete class.
  """
  def get_timer(self,
                timer_spec,  # type: TimerSpec
                key,  # type: Any
                window,  # type: windowed_value.BoundedWindow
                timestamp,  # type: Timestamp
                pane,  # type: windowed_value.PaneInfo
               ):
    # type: (...) -> BaseTimer

    """Returns the timer for the given spec, key and window (abstract)."""
    concrete_type = type(self)
    raise NotImplementedError(concrete_type)

  def get_state(self,
                state_spec,  # type: StateSpec
                key,  # type: Any
                window,  # type: windowed_value.BoundedWindow
               ):
    # type: (...) -> RuntimeState

    """Returns the state for the given spec, key and window (abstract)."""
    concrete_type = type(self)
    raise NotImplementedError(concrete_type)

  def commit(self):
    # type: () -> None

    """Commits pending state and timer changes (abstract)."""
    concrete_type = type(self)
    raise NotImplementedError(concrete_type)