#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# cython: language_level=3
"""Worker operations executor.
For internal use only; no backwards-compatibility guarantees.
"""
# pytype: skip-file
import collections
import copy
import logging
import sys
import threading
import traceback
from enum import Enum
from typing import TYPE_CHECKING
from typing import Any
from typing import Dict
from typing import Iterable
from typing import List
from typing import Mapping
from typing import Optional
from typing import Tuple
from apache_beam.coders import TupleCoder
from apache_beam.coders import coders
from apache_beam.internal import util
from apache_beam.options.value_provider import RuntimeValueProvider
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.pvalue import TaggedOutput
from apache_beam.runners.sdf_utils import NoOpWatermarkEstimatorProvider
from apache_beam.runners.sdf_utils import RestrictionTrackerView
from apache_beam.runners.sdf_utils import SplitResultPrimary
from apache_beam.runners.sdf_utils import SplitResultResidual
from apache_beam.runners.sdf_utils import ThreadsafeRestrictionTracker
from apache_beam.runners.sdf_utils import ThreadsafeWatermarkEstimator
from apache_beam.transforms import DoFn
from apache_beam.transforms import core
from apache_beam.transforms import environments
from apache_beam.transforms import userstate
from apache_beam.transforms.core import RestrictionProvider
from apache_beam.transforms.core import WatermarkEstimatorProvider
from apache_beam.transforms.window import GlobalWindow
from apache_beam.transforms.window import GlobalWindows
from apache_beam.transforms.window import TimestampedValue
from apache_beam.transforms.window import WindowFn
from apache_beam.typehints import typehints
from apache_beam.typehints.batch import BatchConverter
from apache_beam.utils.counters import Counter
from apache_beam.utils.counters import CounterName
from apache_beam.utils.timestamp import Timestamp
from apache_beam.utils.windowed_value import HomogeneousWindowedBatch
from apache_beam.utils.windowed_value import WindowedBatch
from apache_beam.utils.windowed_value import WindowedValue
if TYPE_CHECKING:
from apache_beam.runners.worker.bundle_processor import ExecutionContext
from apache_beam.transforms import sideinputs
from apache_beam.transforms.core import TimerSpec
from apache_beam.io.iobase import RestrictionProgress
  from apache_beam.io.iobase import RestrictionTracker
  from apache_beam.io.iobase import WatermarkEstimator
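# Coder and pre-encoded payload for the Impulse primitive: a single empty
# bytes value in the global window, encoded with a windowed BytesCoder.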
IMPULSE_VALUE_CODER_IMPL = coders.WindowedValueCoder(
coders.BytesCoder(), coders.GlobalWindowCoder()).get_impl()
ENCODED_IMPULSE_VALUE = IMPULSE_VALUE_CODER_IMPL.encode_nested(
GlobalWindows.windowed_value(b''))
_LOGGER = logging.getLogger(__name__)
class NameContext(object):
"""Holds the name information for a step."""
def __init__(self, step_name, transform_id=None):
# type: (str, Optional[str]) -> None
"""Creates a new step NameContext.
Args:
step_name: The name of the step.
"""
self.step_name = step_name
self.transform_id = transform_id
def __eq__(self, other):
return self.step_name == other.step_name
def __repr__(self):
return 'NameContext(%s)' % self.__dict__
def __hash__(self):
return hash(self.step_name)
def metrics_name(self):
"""Returns the step name used for metrics reporting."""
return self.step_name
def logging_name(self):
"""Returns the step name used for logging."""
return self.step_name
class Receiver(object):
"""For internal use only; no backwards-compatibility guarantees.
An object that consumes a WindowedValue.
  This class can be used to efficiently pass values between the
  SDK and worker harnesses.
"""
def receive(self, windowed_value):
# type: (WindowedValue) -> None
raise NotImplementedError
def receive_batch(self, windowed_batch):
# type: (WindowedBatch) -> None
raise NotImplementedError
def flush(self):
raise NotImplementedError
class MethodWrapper(object):
"""For internal use only; no backwards-compatibility guarantees.
Represents a method that can be invoked by `DoFnInvoker`."""
def __init__(self, obj_to_invoke, method_name):
"""
Initiates a ``MethodWrapper``.
Args:
obj_to_invoke: the object that contains the method. Has to either be a
`DoFn` object or a `RestrictionProvider` object.
method_name: name of the method as a string.
"""
if not isinstance(obj_to_invoke,
(DoFn, RestrictionProvider, WatermarkEstimatorProvider)):
      raise ValueError(
          '\'obj_to_invoke\' has to be a \'DoFn\', \'RestrictionProvider\' or '
          '\'WatermarkEstimatorProvider\'. Received %r instead.' %
          obj_to_invoke)
self.args, self.defaults = core.get_function_arguments(obj_to_invoke,
method_name)
# TODO(BEAM-5878) support kwonlyargs on Python 3.
self.method_value = getattr(obj_to_invoke, method_name)
self.method_name = method_name
self.has_userstate_arguments = False
self.state_args_to_replace = {} # type: Dict[str, core.StateSpec]
self.timer_args_to_replace = {} # type: Dict[str, core.TimerSpec]
self.timestamp_arg_name = None # type: Optional[str]
self.window_arg_name = None # type: Optional[str]
self.key_arg_name = None # type: Optional[str]
self.restriction_provider = None
self.restriction_provider_arg_name = None
self.watermark_estimator_provider = None
self.watermark_estimator_provider_arg_name = None
self.dynamic_timer_tag_arg_name = None
if hasattr(self.method_value, 'unbounded_per_element'):
self.unbounded_per_element = True
else:
self.unbounded_per_element = False
for kw, v in zip(self.args[-len(self.defaults):], self.defaults):
if isinstance(v, core.DoFn.StateParam):
self.state_args_to_replace[kw] = v.state_spec
self.has_userstate_arguments = True
elif isinstance(v, core.DoFn.TimerParam):
self.timer_args_to_replace[kw] = v.timer_spec
self.has_userstate_arguments = True
elif core.DoFn.TimestampParam == v:
self.timestamp_arg_name = kw
elif core.DoFn.WindowParam == v:
self.window_arg_name = kw
elif core.DoFn.KeyParam == v:
self.key_arg_name = kw
elif isinstance(v, core.DoFn.RestrictionParam):
self.restriction_provider = v.restriction_provider or obj_to_invoke
self.restriction_provider_arg_name = kw
elif isinstance(v, core.DoFn.WatermarkEstimatorParam):
self.watermark_estimator_provider = (
v.watermark_estimator_provider or obj_to_invoke)
self.watermark_estimator_provider_arg_name = kw
elif core.DoFn.DynamicTimerTagParam == v:
self.dynamic_timer_tag_arg_name = kw
# Create NoOpWatermarkEstimatorProvider if there is no
# WatermarkEstimatorParam provided.
if self.watermark_estimator_provider is None:
self.watermark_estimator_provider = NoOpWatermarkEstimatorProvider()
def invoke_timer_callback(
self,
user_state_context,
key,
window,
timestamp,
pane_info,
dynamic_timer_tag):
# TODO(ccy): support side inputs.
kwargs = {}
if self.has_userstate_arguments:
for kw, state_spec in self.state_args_to_replace.items():
kwargs[kw] = user_state_context.get_state(state_spec, key, window)
for kw, timer_spec in self.timer_args_to_replace.items():
kwargs[kw] = user_state_context.get_timer(
timer_spec, key, window, timestamp, pane_info)
if self.timestamp_arg_name:
kwargs[self.timestamp_arg_name] = Timestamp.of(timestamp)
if self.window_arg_name:
kwargs[self.window_arg_name] = window
if self.key_arg_name:
kwargs[self.key_arg_name] = key
if self.dynamic_timer_tag_arg_name:
kwargs[self.dynamic_timer_tag_arg_name] = dynamic_timer_tag
if kwargs:
return self.method_value(**kwargs)
else:
return self.method_value()
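# Illustrative sketch (hypothetical DoFn, not part of this module): given
#
#   class MyDoFn(DoFn):
#     def process(self, element,
#                 w=DoFn.WindowParam, ts=DoFn.TimestampParam):
#       ...
#
# MethodWrapper(MyDoFn(), 'process') inspects the defaults and records
# window_arg_name == 'w' and timestamp_arg_name == 'ts', which lets invokers
# fill those arguments in at call time.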
class BatchingPreference(Enum):
DO_NOT_CARE = 1 # This operation can operate on batches or element-at-a-time
# TODO: Should we also store batching parameters here? (time/size preferences)
BATCH_REQUIRED = 2 # This operation can only operate on batches
BATCH_FORBIDDEN = 3 # This operation can only work element-at-a-time
# Other possibilities: BATCH_PREFERRED (with min batch size specified)
@property
def supports_batches(self) -> bool:
return self in (self.BATCH_REQUIRED, self.DO_NOT_CARE)
@property
def supports_elements(self) -> bool:
return self in (self.BATCH_FORBIDDEN, self.DO_NOT_CARE)
@property
def requires_batches(self) -> bool:
return self == self.BATCH_REQUIRED
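  # For example, DO_NOT_CARE.supports_batches and DO_NOT_CARE.supports_elements
  # are both True, while BATCH_REQUIRED.requires_batches is True and
  # BATCH_FORBIDDEN only supports element-at-a-time processing.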
class DoFnSignature(object):
"""Represents the signature of a given ``DoFn`` object.
Signature of a ``DoFn`` provides a view of the properties of a given ``DoFn``.
  Among other things, this gives an extensible way of (1) accessing the
  structure of the ``DoFn``, including methods and method parameters,
  (2) identifying features that a given ``DoFn`` supports, for example, whether
  a given ``DoFn`` is a Splittable ``DoFn``
  (https://s.apache.org/splittable-do-fn), and (3) validating a ``DoFn`` based
  on the feature set it offers.
"""
def __init__(self, do_fn):
# type: (core.DoFn) -> None
# We add a property here for all methods defined by Beam DoFn features.
assert isinstance(do_fn, core.DoFn)
self.do_fn = do_fn
self.process_method = MethodWrapper(do_fn, 'process')
self.process_batch_method = MethodWrapper(do_fn, 'process_batch')
self.start_bundle_method = MethodWrapper(do_fn, 'start_bundle')
self.finish_bundle_method = MethodWrapper(do_fn, 'finish_bundle')
self.setup_lifecycle_method = MethodWrapper(do_fn, 'setup')
self.teardown_lifecycle_method = MethodWrapper(do_fn, 'teardown')
restriction_provider = self.get_restriction_provider()
watermark_estimator_provider = self.get_watermark_estimator_provider()
self.create_watermark_estimator_method = (
MethodWrapper(
watermark_estimator_provider, 'create_watermark_estimator'))
self.initial_restriction_method = (
MethodWrapper(restriction_provider, 'initial_restriction')
if restriction_provider else None)
self.create_tracker_method = (
MethodWrapper(restriction_provider, 'create_tracker')
if restriction_provider else None)
self.split_method = (
MethodWrapper(restriction_provider, 'split')
if restriction_provider else None)
self._validate()
# Handle stateful DoFns.
self._is_stateful_dofn = userstate.is_stateful_dofn(do_fn)
self.timer_methods = {} # type: Dict[TimerSpec, MethodWrapper]
if self._is_stateful_dofn:
# Populate timer firing methods, keyed by TimerSpec.
_, all_timer_specs = userstate.get_dofn_specs(do_fn)
for timer_spec in all_timer_specs:
method = timer_spec._attached_callback
self.timer_methods[timer_spec] = MethodWrapper(do_fn, method.__name__)
def get_restriction_provider(self):
# type: () -> RestrictionProvider
return self.process_method.restriction_provider
def get_watermark_estimator_provider(self):
# type: () -> WatermarkEstimatorProvider
return self.process_method.watermark_estimator_provider
def is_unbounded_per_element(self):
return self.process_method.unbounded_per_element
def _validate(self):
# type: () -> None
self._validate_process()
self._validate_process_batch()
self._validate_bundle_method(self.start_bundle_method)
self._validate_bundle_method(self.finish_bundle_method)
self._validate_stateful_dofn()
def _check_duplicate_dofn_params(self, method: MethodWrapper):
param_ids = [
d.param_id for d in method.defaults if isinstance(d, core._DoFnParam)
]
if len(param_ids) != len(set(param_ids)):
raise ValueError(
'DoFn %r has duplicate %s method parameters: %s.' %
(self.do_fn, method.method_name, param_ids))
def _validate_process(self):
# type: () -> None
"""Validate that none of the DoFnParameters are repeated in the function
"""
self._check_duplicate_dofn_params(self.process_method)
def _validate_process_batch(self):
# type: () -> None
self._check_duplicate_dofn_params(self.process_batch_method)
for d in self.process_batch_method.defaults:
if not isinstance(d, core._DoFnParam):
continue
# Helpful errors for params which will be supported in the future
if d == (core.DoFn.ElementParam):
# We currently assume we can just get the typehint from the first
# parameter. ElementParam breaks this assumption
raise NotImplementedError(
f"DoFn {self.do_fn!r} uses unsupported DoFn param ElementParam.")
if d in (core.DoFn.KeyParam, core.DoFn.StateParam, core.DoFn.TimerParam):
raise NotImplementedError(
f"DoFn {self.do_fn!r} has unsupported per-key DoFn param {d}. "
"Per-key DoFn params are not yet supported for process_batch "
"(https://github.com/apache/beam/issues/21653).")
# Fallback to catch anything not explicitly supported
      if d not in (core.DoFn.WindowParam,
                   core.DoFn.TimestampParam,
                   core.DoFn.PaneInfoParam):
raise ValueError(
f"DoFn {self.do_fn!r} has unsupported process_batch "
f"method parameter {d}")
def _validate_bundle_method(self, method_wrapper):
"""Validate that none of the DoFnParameters are used in the function
"""
for param in core.DoFn.DoFnProcessParams:
if param in method_wrapper.defaults:
raise ValueError(
'DoFn.process() method-only parameter %s cannot be used in %s.' %
(param, method_wrapper))
def _validate_stateful_dofn(self):
# type: () -> None
userstate.validate_stateful_dofn(self.do_fn)
def is_splittable_dofn(self):
# type: () -> bool
return self.get_restriction_provider() is not None
def get_restriction_coder(self):
# type: () -> Optional[TupleCoder]
"""Get coder for a restriction when processing an SDF. """
if self.is_splittable_dofn():
return TupleCoder([
(self.get_restriction_provider().restriction_coder()),
(self.get_watermark_estimator_provider().estimator_state_coder())
])
else:
return None
def is_stateful_dofn(self):
# type: () -> bool
return self._is_stateful_dofn
def has_timers(self):
# type: () -> bool
_, all_timer_specs = userstate.get_dofn_specs(self.do_fn)
return bool(all_timer_specs)
def has_bundle_finalization(self):
for sig in (self.start_bundle_method,
self.process_method,
self.finish_bundle_method):
for d in sig.defaults:
try:
if d == DoFn.BundleFinalizerParam:
return True
except Exception: # pylint: disable=broad-except
# Default value might be incomparable.
pass
return False
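# Illustrative sketch (hypothetical DoFn, not part of this module): a
# DoFnSignature can be used to introspect user code, e.g.
#
#   sig = DoFnSignature(MyDoFn())
#   sig.is_splittable_dofn()  # True only if process() takes a
#                             # DoFn.RestrictionParam default.
#   sig.is_stateful_dofn()    # True if the DoFn declares state or timer specs.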
class DoFnInvoker(object):
"""An abstraction that can be used to execute DoFn methods.
A DoFnInvoker describes a particular way for invoking methods of a DoFn
represented by a given DoFnSignature."""
def __init__(self,
output_handler, # type: _OutputHandler
signature # type: DoFnSignature
):
# type: (...) -> None
"""
Initializes `DoFnInvoker`
:param output_handler: an OutputHandler for receiving elements produced
by invoking functions of the DoFn.
:param signature: a DoFnSignature for the DoFn being invoked
"""
self.output_handler = output_handler
self.signature = signature
self.user_state_context = None # type: Optional[userstate.UserStateContext]
self.bundle_finalizer_param = None # type: Optional[core._BundleFinalizerParam]
@staticmethod
def create_invoker(
signature, # type: DoFnSignature
output_handler, # type: OutputHandler
context=None, # type: Optional[DoFnContext]
side_inputs=None, # type: Optional[List[sideinputs.SideInputMap]]
input_args=None, input_kwargs=None,
process_invocation=True,
user_state_context=None, # type: Optional[userstate.UserStateContext]
bundle_finalizer_param=None # type: Optional[core._BundleFinalizerParam]
):
# type: (...) -> DoFnInvoker
""" Creates a new DoFnInvoker based on given arguments.
Args:
output_handler: an OutputHandler for receiving elements produced by
invoking functions of the DoFn.
signature: a DoFnSignature for the DoFn being invoked.
context: Context to be used when invoking the DoFn (deprecated).
      side_inputs: side inputs to be used when invoking the process method.
input_args: arguments to be used when invoking the process method. Some
of the arguments given here might be placeholders (for
example for side inputs) that get filled before invoking the
process method.
input_kwargs: keyword arguments to be used when invoking the process
method. Some of the keyword arguments given here might be
placeholders (for example for side inputs) that get filled
before invoking the process method.
process_invocation: If True, this function may return an invoker that
performs extra optimizations for invoking process()
method efficiently.
user_state_context: The UserStateContext instance for the current
Stateful DoFn.
bundle_finalizer_param: The param that passed to a process method, which
allows a callback to be registered.
"""
side_inputs = side_inputs or []
use_per_window_invoker = process_invocation and (
side_inputs or input_args or input_kwargs or
signature.process_method.defaults or
signature.process_batch_method.defaults or signature.is_stateful_dofn())
if not use_per_window_invoker:
return SimpleInvoker(output_handler, signature)
else:
if context is None:
raise TypeError("Must provide context when not using SimpleInvoker")
return PerWindowInvoker(
output_handler,
signature,
context,
side_inputs,
input_args,
input_kwargs,
user_state_context,
bundle_finalizer_param)
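  # Illustrative usage (a sketch; `MyDoFn`, `element`, `handler` and `ctx` are
  # hypothetical names created by the caller): DoFns with no side inputs, no
  # extra process()/process_batch() defaults and no state get the fast
  # SimpleInvoker, everything else gets PerWindowInvoker.
  #
  #   invoker = DoFnInvoker.create_invoker(
  #       DoFnSignature(MyDoFn()), handler, context=ctx)
  #   invoker.invoke_setup()
  #   invoker.invoke_start_bundle()
  #   invoker.invoke_process(GlobalWindows.windowed_value(element))
  #   invoker.invoke_finish_bundle()
  #   invoker.invoke_teardown()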
def invoke_process(self,
windowed_value, # type: WindowedValue
restriction=None,
watermark_estimator_state=None,
additional_args=None,
additional_kwargs=None
):
# type: (...) -> Iterable[SplitResultResidual]
"""Invokes the DoFn.process() function.
Args:
windowed_value: a WindowedValue object that gives the element for which
process() method should be invoked along with the window
the element belongs to.
restriction: The restriction to use when executing this splittable DoFn.
Should only be specified for splittable DoFns.
watermark_estimator_state: The watermark estimator state to use when
executing this splittable DoFn. Should only
be specified for splittable DoFns.
additional_args: additional arguments to be passed to the current
`DoFn.process()` invocation, usually as side inputs.
additional_kwargs: additional keyword arguments to be passed to the
current `DoFn.process()` invocation.
"""
raise NotImplementedError
def invoke_process_batch(self,
windowed_batch, # type: WindowedBatch
additional_args=None,
additional_kwargs=None
):
# type: (...) -> None
"""Invokes the DoFn.process() function.
Args:
windowed_batch: a WindowedBatch object that gives a batch of elements for
which process_batch() method should be invoked, along with
the window each element belongs to.
additional_args: additional arguments to be passed to the current
`DoFn.process()` invocation, usually as side inputs.
additional_kwargs: additional keyword arguments to be passed to the
current `DoFn.process()` invocation.
"""
raise NotImplementedError
def invoke_setup(self):
# type: () -> None
"""Invokes the DoFn.setup() method
"""
self.signature.setup_lifecycle_method.method_value()
def invoke_start_bundle(self):
# type: () -> None
"""Invokes the DoFn.start_bundle() method.
"""
self.output_handler.start_bundle_outputs(
self.signature.start_bundle_method.method_value())
def invoke_finish_bundle(self):
# type: () -> None
"""Invokes the DoFn.finish_bundle() method.
"""
self.output_handler.finish_bundle_outputs(
self.signature.finish_bundle_method.method_value())
def invoke_teardown(self):
# type: () -> None
"""Invokes the DoFn.teardown() method
"""
self.signature.teardown_lifecycle_method.method_value()
def invoke_user_timer(
self, timer_spec, key, window, timestamp, pane_info, dynamic_timer_tag):
# self.output_handler is Optional, but in practice it won't be None here
self.output_handler.handle_process_outputs(
WindowedValue(None, timestamp, (window, )),
self.signature.timer_methods[timer_spec].invoke_timer_callback(
self.user_state_context,
key,
window,
timestamp,
pane_info,
dynamic_timer_tag))
def invoke_create_watermark_estimator(self, estimator_state):
return self.signature.create_watermark_estimator_method.method_value(
estimator_state)
def invoke_split(self, element, restriction):
return self.signature.split_method.method_value(element, restriction)
def invoke_initial_restriction(self, element):
return self.signature.initial_restriction_method.method_value(element)
def invoke_create_tracker(self, restriction):
return self.signature.create_tracker_method.method_value(restriction)
class SimpleInvoker(DoFnInvoker):
"""An invoker that processes elements ignoring windowing information."""
def __init__(self,
output_handler, # type: OutputHandler
signature # type: DoFnSignature
):
# type: (...) -> None
super().__init__(output_handler, signature)
self.process_method = signature.process_method.method_value
self.process_batch_method = signature.process_batch_method.method_value
def invoke_process(self,
windowed_value, # type: WindowedValue
restriction=None,
watermark_estimator_state=None,
additional_args=None,
additional_kwargs=None
):
# type: (...) -> Iterable[SplitResultResidual]
self.output_handler.handle_process_outputs(
windowed_value, self.process_method(windowed_value.value))
return []
def invoke_process_batch(self,
windowed_batch, # type: WindowedBatch
restriction=None,
watermark_estimator_state=None,
additional_args=None,
additional_kwargs=None
):
# type: (...) -> None
self.output_handler.handle_process_batch_outputs(
windowed_batch, self.process_batch_method(windowed_batch.values))
def _get_arg_placeholders(
method: MethodWrapper,
input_args: Optional[List[Any]],
    input_kwargs: Optional[Dict[str, Any]]):
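  """Computes placeholder positions for a DoFn method's positional arguments.

  Returns a tuple (placeholders, args_with_placeholders, input_kwargs) where
  args_with_placeholders is the positional argument list with the caller's
  input_args spliced in and ArgPlaceholder markers left for per-element values
  (element, key, window, timestamp, state, timers, ...), and placeholders is
  the list of (index, placeholder) pairs that must be filled on every
  invocation.
  """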
input_args = input_args if input_args else []
input_kwargs = input_kwargs if input_kwargs else {}
arg_names = method.args
default_arg_values = method.defaults
# Create placeholder for element parameter of DoFn.process() method.
# Not to be confused with ArgumentPlaceHolder, which may be passed in
# input_args and is a placeholder for side-inputs.
class ArgPlaceholder(object):
def __init__(self, placeholder):
self.placeholder = placeholder
if all(core.DoFn.ElementParam != arg for arg in default_arg_values):
# TODO(https://github.com/apache/beam/issues/19631): Handle cases in which
# len(arg_names) == len(default_arg_values).
args_to_pick = len(arg_names) - len(default_arg_values) - 1
# Positional argument values for process(), with placeholders for special
# values such as the element, timestamp, etc.
args_with_placeholders = ([ArgPlaceholder(core.DoFn.ElementParam)] +
input_args[:args_to_pick])
else:
args_to_pick = len(arg_names) - len(default_arg_values)
args_with_placeholders = input_args[:args_to_pick]
  # Fill in the placeholders for the other special parameters such as the key,
  # window, timestamp and pane info.
remaining_args_iter = iter(input_args[args_to_pick:])
for a, d in zip(arg_names[-len(default_arg_values):], default_arg_values):
if core.DoFn.ElementParam == d:
args_with_placeholders.append(ArgPlaceholder(d))
elif core.DoFn.KeyParam == d:
args_with_placeholders.append(ArgPlaceholder(d))
elif core.DoFn.WindowParam == d:
args_with_placeholders.append(ArgPlaceholder(d))
elif core.DoFn.TimestampParam == d:
args_with_placeholders.append(ArgPlaceholder(d))
elif core.DoFn.PaneInfoParam == d:
args_with_placeholders.append(ArgPlaceholder(d))
elif core.DoFn.SideInputParam == d:
# If no more args are present then the value must be passed via kwarg
try:
args_with_placeholders.append(next(remaining_args_iter))
except StopIteration:
if a not in input_kwargs:
raise ValueError("Value for sideinput %s not provided" % a)
elif isinstance(d, core.DoFn.StateParam):
args_with_placeholders.append(ArgPlaceholder(d))
elif isinstance(d, core.DoFn.TimerParam):
args_with_placeholders.append(ArgPlaceholder(d))
elif isinstance(d, type) and core.DoFn.BundleFinalizerParam == d:
args_with_placeholders.append(ArgPlaceholder(d))
else:
# If no more args are present then the value must be passed via kwarg
try:
args_with_placeholders.append(next(remaining_args_iter))
except StopIteration:
pass
args_with_placeholders.extend(list(remaining_args_iter))
# Stash the list of placeholder positions for performance
placeholders = [(i, x.placeholder)
for (i, x) in enumerate(args_with_placeholders)
if isinstance(x, ArgPlaceholder)]
return placeholders, args_with_placeholders, input_kwargs
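# Illustrative example (hypothetical DoFn, not part of this module): for
#
#   def process(self, element, side, w=DoFn.WindowParam): ...
#
# with input_args == [side_value], _get_arg_placeholders produces positional
# args [ArgPlaceholder(ElementParam), side_value, ArgPlaceholder(WindowParam)]
# and placeholders == [(0, ElementParam), (2, WindowParam)], so only positions
# 0 and 2 need to be refilled per element.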
class PerWindowInvoker(DoFnInvoker):
"""An invoker that processes elements considering windowing information."""
def __init__(self,
output_handler, # type: OutputHandler
signature, # type: DoFnSignature
context, # type: DoFnContext
side_inputs, # type: Iterable[sideinputs.SideInputMap]
input_args,
input_kwargs,
user_state_context, # type: Optional[userstate.UserStateContext]
bundle_finalizer_param # type: Optional[core._BundleFinalizerParam]
):
super().__init__(output_handler, signature)
self.side_inputs = side_inputs
self.context = context
self.process_method = signature.process_method.method_value
default_arg_values = signature.process_method.defaults
self.has_windowed_inputs = (
not all(si.is_globally_windowed() for si in side_inputs) or any(
core.DoFn.WindowParam == arg
for arg in signature.process_method.defaults) or any(
core.DoFn.WindowParam == arg
for arg in signature.process_batch_method.defaults) or
signature.is_stateful_dofn())
self.user_state_context = user_state_context
self.is_splittable = signature.is_splittable_dofn()
self.is_key_param_required = any(
core.DoFn.KeyParam == arg for arg in default_arg_values)
self.threadsafe_restriction_tracker = None # type: Optional[ThreadsafeRestrictionTracker]
self.threadsafe_watermark_estimator = None # type: Optional[ThreadsafeWatermarkEstimator]
self.current_windowed_value = None # type: Optional[WindowedValue]
self.bundle_finalizer_param = bundle_finalizer_param
if self.is_splittable:
self.splitting_lock = threading.Lock()
self.current_window_index = None
self.stop_window_index = None
# TODO(https://github.com/apache/beam/issues/28776): Remove caching after
# fully rolling out.
    # If true, always recalculate window args. If false, has_cached_window_args
    # and has_cached_window_batch_args will be set to true once the
    # corresponding self.args_for_process / self.args_for_process_batch have
    # been updated and should be reused directly.
self.recalculate_window_args = (
self.has_windowed_inputs or 'disable_global_windowed_args_caching' in
RuntimeValueProvider.experiments)
self.has_cached_window_args = False
self.has_cached_window_batch_args = False
    # Try to prepare all the arguments that can just be filled in
    # without any additional work in the process function.
# Also cache all the placeholders needed in the process function.
input_args = list(input_args)
(
self.placeholders_for_process,
self.args_for_process,
self.kwargs_for_process) = _get_arg_placeholders(
signature.process_method, input_args, input_kwargs)
self.process_batch_method = signature.process_batch_method.method_value
(
self.placeholders_for_process_batch,
self.args_for_process_batch,
self.kwargs_for_process_batch) = _get_arg_placeholders(
signature.process_batch_method, input_args, input_kwargs)
def invoke_process(self,
windowed_value, # type: WindowedValue
restriction=None,
watermark_estimator_state=None,
additional_args=None,
additional_kwargs=None
):
# type: (...) -> Iterable[SplitResultResidual]
if not additional_args:
additional_args = []
if not additional_kwargs:
additional_kwargs = {}
self.context.set_element(windowed_value)
    # Call the process function once per window if the DoFn has windowed side
    # inputs or accesses the window parameter; otherwise a single call
    # suffices, as none of the arguments change across windows.
residuals = []
if self.is_splittable:
if restriction is None:
        # This may be an SDF invoked as an ordinary DoFn on runners that don't
        # understand SDF. See, e.g., BEAM-11472.
# In this case, processing the element is simply processing it against
# the entire initial restriction.
restriction = self.signature.initial_restriction_method.method_value(
windowed_value.value)
with self.splitting_lock:
self.current_windowed_value = windowed_value
self.restriction = restriction
self.watermark_estimator_state = watermark_estimator_state
try:
if self.has_windowed_inputs and len(windowed_value.windows) > 1:
for i, w in enumerate(windowed_value.windows):
if not self._should_process_window_for_sdf(
windowed_value, additional_kwargs, i):
break
residual = self._invoke_process_per_window(
WindowedValue(
windowed_value.value, windowed_value.timestamp, (w, )),
additional_args,
additional_kwargs)
if residual:
residuals.append(residual)
else:
if self._should_process_window_for_sdf(windowed_value,
additional_kwargs):
residual = self._invoke_process_per_window(
windowed_value, additional_args, additional_kwargs)
if residual:
residuals.append(residual)
finally:
with self.splitting_lock:
self.current_windowed_value = None
self.restriction = None
self.watermark_estimator_state = None
self.current_window_index = None
self.threadsafe_restriction_tracker = None
self.threadsafe_watermark_estimator = None
elif self.has_windowed_inputs and len(windowed_value.windows) != 1:
for w in windowed_value.windows:
self._invoke_process_per_window(
WindowedValue(
windowed_value.value, windowed_value.timestamp, (w, )),
additional_args,
additional_kwargs)
else:
self._invoke_process_per_window(
windowed_value, additional_args, additional_kwargs)
return residuals
def invoke_process_batch(self,
windowed_batch, # type: WindowedBatch
additional_args=None,
additional_kwargs=None
):
# type: (...) -> None
if not additional_args:
additional_args = []
if not additional_kwargs:
additional_kwargs = {}
assert isinstance(windowed_batch, HomogeneousWindowedBatch)
if self.has_windowed_inputs and len(windowed_batch.windows) != 1:
for w in windowed_batch.windows:
self._invoke_process_batch_per_window(
HomogeneousWindowedBatch.of(
windowed_batch.values,
windowed_batch.timestamp, (w, ),
windowed_batch.pane_info),
additional_args,
additional_kwargs)
else:
self._invoke_process_batch_per_window(
windowed_batch, additional_args, additional_kwargs)
def _should_process_window_for_sdf(
self,
windowed_value, # type: WindowedValue
additional_kwargs,
window_index=None, # type: Optional[int]
):
restriction_tracker = self.invoke_create_tracker(self.restriction)
watermark_estimator = self.invoke_create_watermark_estimator(
self.watermark_estimator_state)
with self.splitting_lock:
      if window_index is not None:
self.current_window_index = window_index
if window_index == 0:
self.stop_window_index = len(windowed_value.windows)
if window_index == self.stop_window_index:
return False
self.threadsafe_restriction_tracker = ThreadsafeRestrictionTracker(
restriction_tracker)
self.threadsafe_watermark_estimator = (
ThreadsafeWatermarkEstimator(watermark_estimator))
restriction_tracker_param = (
self.signature.process_method.restriction_provider_arg_name)
if not restriction_tracker_param:
raise ValueError(
'DoFn is splittable but DoFn does not have a '
'RestrictionTrackerParam defined')
additional_kwargs[restriction_tracker_param] = (
RestrictionTrackerView(self.threadsafe_restriction_tracker))
watermark_param = (
self.signature.process_method.watermark_estimator_provider_arg_name)
# When the watermark_estimator is a NoOpWatermarkEstimator, the system
# will not add watermark_param into the DoFn param list.
if watermark_param is not None:
additional_kwargs[watermark_param] = self.threadsafe_watermark_estimator
return True
def _invoke_process_per_window(self,
windowed_value, # type: WindowedValue
additional_args,
additional_kwargs,
):
# type: (...) -> Optional[SplitResultResidual]
if self.has_cached_window_args:
args_for_process, kwargs_for_process = (
self.args_for_process, self.kwargs_for_process)
else:
if self.has_windowed_inputs:
assert len(windowed_value.windows) <= 1
window, = windowed_value.windows
else:
window = GlobalWindow()
side_inputs = [si[window] for si in self.side_inputs]
side_inputs.extend(additional_args)
args_for_process, kwargs_for_process = util.insert_values_in_args(
self.args_for_process, self.kwargs_for_process, side_inputs)
if not self.recalculate_window_args:
self.args_for_process, self.kwargs_for_process = (
args_for_process, kwargs_for_process)
self.has_cached_window_args = True
    # Extract the key in the case of a stateful DoFn. Note that for a stateful
    # DoFn we set self.has_windowed_inputs to True during __init__, so windows
    # are exploded before reaching this method and we can rely on the window
    # variable being set above.
if self.user_state_context or self.is_key_param_required:
try:
key, unused_value = windowed_value.value
except (TypeError, ValueError):
raise ValueError((
'Input value to a stateful DoFn or KeyParam must be a KV tuple; '
'instead, got \'%s\'.') % (windowed_value.value, ))
for i, p in self.placeholders_for_process:
if core.DoFn.ElementParam == p:
args_for_process[i] = windowed_value.value
elif core.DoFn.KeyParam == p:
args_for_process[i] = key
elif core.DoFn.WindowParam == p:
args_for_process[i] = window
elif core.DoFn.TimestampParam == p:
args_for_process[i] = windowed_value.timestamp
elif core.DoFn.PaneInfoParam == p:
args_for_process[i] = windowed_value.pane_info
elif isinstance(p, core.DoFn.StateParam):
assert self.user_state_context is not None
args_for_process[i] = (
self.user_state_context.get_state(p.state_spec, key, window))
elif isinstance(p, core.DoFn.TimerParam):
assert self.user_state_context is not None
args_for_process[i] = (
self.user_state_context.get_timer(
p.timer_spec,
key,
window,
windowed_value.timestamp,
windowed_value.pane_info))
elif core.DoFn.BundleFinalizerParam == p:
args_for_process[i] = self.bundle_finalizer_param
kwargs_for_process = kwargs_for_process or {}
if additional_kwargs:
kwargs_for_process.update(additional_kwargs)
self.output_handler.handle_process_outputs(
windowed_value,
self.process_method(*args_for_process, **kwargs_for_process),
self.threadsafe_watermark_estimator)
if self.is_splittable:
assert self.threadsafe_restriction_tracker is not None
self.threadsafe_restriction_tracker.check_done()
deferred_status = self.threadsafe_restriction_tracker.deferred_status()
if deferred_status:
deferred_restriction, deferred_timestamp = deferred_status
element = windowed_value.value
size = self.signature.get_restriction_provider().restriction_size(
element, deferred_restriction)
if size < 0:
raise ValueError('Expected size >= 0 but received %s.' % size)
current_watermark = (
self.threadsafe_watermark_estimator.current_watermark())
estimator_state = (
self.threadsafe_watermark_estimator.get_estimator_state())
residual_value = ((element, (deferred_restriction, estimator_state)),
size)
return SplitResultResidual(
residual_value=windowed_value.with_value(residual_value),
current_watermark=current_watermark,
deferred_timestamp=deferred_timestamp)
return None
def _invoke_process_batch_per_window(
self,
windowed_batch: WindowedBatch,
additional_args,
additional_kwargs,
):
# type: (...) -> Optional[SplitResultResidual]
if self.has_cached_window_batch_args:
args_for_process_batch, kwargs_for_process_batch = (
self.args_for_process_batch, self.kwargs_for_process_batch)
else:
if self.has_windowed_inputs:
assert isinstance(windowed_batch, HomogeneousWindowedBatch)
assert len(windowed_batch.windows) <= 1
window, = windowed_batch.windows
else:
window = GlobalWindow()
side_inputs = [si[window] for si in self.side_inputs]
side_inputs.extend(additional_args)
args_for_process_batch, kwargs_for_process_batch = (
util.insert_values_in_args(
self.args_for_process_batch,
self.kwargs_for_process_batch,
side_inputs,
)
)
if not self.recalculate_window_args:
self.args_for_process_batch, self.kwargs_for_process_batch = (
args_for_process_batch, kwargs_for_process_batch)
self.has_cached_window_batch_args = True
for i, p in self.placeholders_for_process_batch:
if core.DoFn.ElementParam == p:
args_for_process_batch[i] = windowed_batch.values
elif core.DoFn.KeyParam == p:
raise NotImplementedError(
'https://github.com/apache/beam/issues/21653: Per-key process_batch'
)
elif core.DoFn.WindowParam == p:
args_for_process_batch[i] = window
elif core.DoFn.TimestampParam == p:
args_for_process_batch[i] = windowed_batch.timestamp
elif core.DoFn.PaneInfoParam == p:
assert isinstance(windowed_batch, HomogeneousWindowedBatch)
args_for_process_batch[i] = windowed_batch.pane_info
elif isinstance(p, core.DoFn.StateParam):
raise NotImplementedError(
"https://github.com/apache/beam/issues/21653: "
"Per-key process_batch")
elif isinstance(p, core.DoFn.TimerParam):
raise NotImplementedError(
"https://github.com/apache/beam/issues/21653: "
"Per-key process_batch")
kwargs_for_process_batch = kwargs_for_process_batch or {}
if additional_kwargs:
kwargs_for_process_batch.update(additional_kwargs)
self.output_handler.handle_process_batch_outputs(
windowed_batch,
self.process_batch_method(
*args_for_process_batch, **kwargs_for_process_batch),
self.threadsafe_watermark_estimator)
@staticmethod
def _try_split(fraction,
window_index, # type: Optional[int]
stop_window_index, # type: Optional[int]
windowed_value, # type: WindowedValue
restriction,
watermark_estimator_state,
restriction_provider, # type: RestrictionProvider
restriction_tracker, # type: RestrictionTracker
watermark_estimator, # type: WatermarkEstimator
):
# type: (...) -> Optional[Tuple[Iterable[SplitResultPrimary], Iterable[SplitResultResidual], Optional[int]]]
"""Try to split returning a primaries, residuals and a new stop index.
For non-window observing splittable DoFns we split the current restriction
and assign the primary and residual to all the windows.
For window observing splittable DoFns, we:
1) return a split at a window boundary if the fraction lies outside of the
current window.
2) attempt to split the current restriction, if successful then return
the primary and residual for the current window and an additional
primary and residual for any fully processed and fully unprocessed
windows.
3) fall back to returning a split at the window boundary if possible
Args:
window_index: the current index of the window being processed or None
if the splittable DoFn is not window observing.
stop_window_index: the current index to stop processing at or None
if the splittable DoFn is not window observing.
windowed_value: the current windowed value
restriction: the initial restriction when processing was started.
watermark_estimator_state: the initial watermark estimator state when
processing was started.
restriction_provider: the DoFn's restriction provider
restriction_tracker: the current restriction tracker
watermark_estimator: the current watermark estimator
Returns:
A tuple containing (primaries, residuals, new_stop_index) or None if
splitting was not possible. new_stop_index will only be set if the
splittable DoFn is window observing otherwise it will be None.
"""
def compute_whole_window_split(to_index, from_index):
restriction_size = restriction_provider.restriction_size(
windowed_value, restriction)
if restriction_size < 0:
raise ValueError(
'Expected size >= 0 but received %s.' % restriction_size)
# The primary and residual both share the same value only differing
# by the set of windows they are in.
value = ((windowed_value.value, (restriction, watermark_estimator_state)),
restriction_size)
primary_restriction = SplitResultPrimary(
primary_value=WindowedValue(
value,
windowed_value.timestamp,
windowed_value.windows[:to_index])) if to_index > 0 else None
# Don't report any updated watermarks for the residual since they have
# not processed any part of the restriction.
residual_restriction = SplitResultResidual(
residual_value=WindowedValue(
value,
windowed_value.timestamp,
windowed_value.windows[from_index:stop_window_index]),
current_watermark=None,
deferred_timestamp=None) if from_index < stop_window_index else None
return (primary_restriction, residual_restriction)
primary_restrictions = []
residual_restrictions = []
window_observing = window_index is not None
# If we are processing each window separately and we aren't on the last
# window then compute whether the split lies within the current window
# or a future window.
if window_observing and window_index != stop_window_index - 1:
progress = restriction_tracker.current_progress()
if not progress:
# Assume no work has been completed for the current window if progress
# is unavailable.
from apache_beam.io.iobase import RestrictionProgress
progress = RestrictionProgress(completed=0, remaining=1)
scaled_progress = PerWindowInvoker._scale_progress(
progress, window_index, stop_window_index)
# Compute the fraction of the remainder relative to the scaled progress.
# If the value is greater than or equal to progress.remaining_work then we
# should split at the closest window boundary.
fraction_of_remainder = scaled_progress.remaining_work * fraction
if fraction_of_remainder >= progress.remaining_work:
# The fraction is outside of the current window and hence we will
# split at the closest window boundary. Favor a split and return the
# last window if we would have rounded up to the end of the window
# based upon the fraction.
new_stop_window_index = min(
stop_window_index - 1,
window_index + max(
1,
int(
round((
progress.completed_work +
scaled_progress.remaining_work * fraction) /
progress.total_work))))
primary, residual = compute_whole_window_split(
new_stop_window_index, new_stop_window_index)
assert primary is not None
assert residual is not None
return ([primary], [residual], new_stop_window_index)
else:
# The fraction is within the current window being processed so compute
# the updated fraction based upon the number of windows being processed.
new_stop_window_index = window_index + 1
fraction = fraction_of_remainder / progress.remaining_work
# Attempt to split below, if we can't then we'll compute a split
# using only window boundaries
else:
# We aren't splitting within multiple windows so we don't change our
# stop index.
new_stop_window_index = stop_window_index
# Temporary workaround for [BEAM-7473]: get current_watermark before
# split, in case watermark gets advanced before getting split results.
# In worst case, current_watermark is always stale, which is ok.
current_watermark = (watermark_estimator.current_watermark())
current_estimator_state = (watermark_estimator.get_estimator_state())
split = restriction_tracker.try_split(fraction)
if split:
primary, residual = split
element = windowed_value.value
primary_size = restriction_provider.restriction_size(
windowed_value.value, primary)
if primary_size < 0:
raise ValueError('Expected size >= 0 but received %s.' % primary_size)
residual_size = restriction_provider.restriction_size(
windowed_value.value, residual)
if residual_size < 0:
raise ValueError('Expected size >= 0 but received %s.' % residual_size)
# We use the watermark estimator state for the original process call
# for the primary and the updated watermark estimator state for the
# residual for the split.
primary_split_value = ((element, (primary, watermark_estimator_state)),
primary_size)
residual_split_value = ((element, (residual, current_estimator_state)),
residual_size)
windows = (
windowed_value.windows[window_index],
) if window_observing else windowed_value.windows
primary_restrictions.append(
SplitResultPrimary(
primary_value=WindowedValue(
primary_split_value, windowed_value.timestamp, windows)))
residual_restrictions.append(
SplitResultResidual(
residual_value=WindowedValue(
residual_split_value, windowed_value.timestamp, windows),
current_watermark=current_watermark,
deferred_timestamp=None))
if window_observing:
assert new_stop_window_index == window_index + 1
primary, residual = compute_whole_window_split(
window_index, window_index + 1)
if primary:
primary_restrictions.append(primary)
if residual:
residual_restrictions.append(residual)
return (
primary_restrictions, residual_restrictions, new_stop_window_index)
elif new_stop_window_index and new_stop_window_index != stop_window_index:
# If we failed to split but have a new stop index then return a split
# at the window boundary.
primary, residual = compute_whole_window_split(
new_stop_window_index, new_stop_window_index)
assert primary is not None
assert residual is not None
return ([primary], [residual], new_stop_window_index)
else:
return None
def try_split(self, fraction):
# type: (...) -> Optional[Tuple[Iterable[SplitResultPrimary], Iterable[SplitResultResidual]]]
if not self.is_splittable:
return None
with self.splitting_lock:
if not self.threadsafe_restriction_tracker:
return None
      # While holding the lock, pass the member variables that may change
      # during processing to _try_split so that it operates on a consistent
      # view of all the references.
result = PerWindowInvoker._try_split(
fraction,
self.current_window_index,
self.stop_window_index,
self.current_windowed_value,
self.restriction,
self.watermark_estimator_state,
self.signature.get_restriction_provider(),
self.threadsafe_restriction_tracker,
self.threadsafe_watermark_estimator)
if not result:
return None
      primaries, residuals, self.stop_window_index = result
      return (primaries, residuals)
@staticmethod
def _scale_progress(progress, window_index, stop_window_index):
# We scale progress based upon the amount of work we will do for one
# window and have it apply for all windows.
completed = window_index * progress.total_work + progress.completed_work
remaining = (
stop_window_index -
(window_index + 1)) * progress.total_work + progress.remaining_work
from apache_beam.io.iobase import RestrictionProgress
return RestrictionProgress(completed=completed, remaining=remaining)
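  # Illustrative example: with stop_window_index == 3, window_index == 1 and
  # per-window progress of 0.5 completed / 0.5 remaining, the scaled progress
  # is completed == 1.5 and remaining == 1.5, i.e. half of the three-window
  # element has been processed.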
def current_element_progress(self):
# type: () -> Optional[RestrictionProgress]
if not self.is_splittable:
return None
with self.splitting_lock:
current_window_index = self.current_window_index
stop_window_index = self.stop_window_index
threadsafe_restriction_tracker = self.threadsafe_restriction_tracker
if not threadsafe_restriction_tracker:
return None
progress = threadsafe_restriction_tracker.current_progress()
if not current_window_index or not progress:
return progress
# stop_window_index should always be set if current_window_index is set,
# it is an error otherwise.
assert stop_window_index
return PerWindowInvoker._scale_progress(
progress, current_window_index, stop_window_index)
class DoFnRunner:
"""For internal use only; no backwards-compatibility guarantees.
A helper class for executing ParDo operations.
"""
def __init__(self,
fn, # type: core.DoFn
args,
kwargs,
side_inputs, # type: Iterable[sideinputs.SideInputMap]
windowing,
tagged_receivers, # type: Mapping[Optional[str], Receiver]
step_name=None, # type: Optional[str]
logging_context=None,
state=None,
scoped_metrics_container=None,
operation_name=None,
transform_id=None,
user_state_context=None, # type: Optional[userstate.UserStateContext]
):
"""Initializes a DoFnRunner.
Args:
fn: user DoFn to invoke
args: positional side input arguments (static and placeholder), if any
kwargs: keyword side input arguments (static and placeholder), if any
side_inputs: list of sideinput.SideInputMaps for deferred side inputs
windowing: windowing properties of the output PCollection(s)
tagged_receivers: a dict of tag name to Receiver objects
step_name: the name of this step
logging_context: DEPRECATED [BEAM-4728]
state: handle for accessing DoFn state
scoped_metrics_container: DEPRECATED
operation_name: The system name assigned by the runner for this operation.
transform_id: The PTransform Id in the pipeline proto for this DoFn.
user_state_context: The UserStateContext instance for the current
Stateful DoFn.
"""
# Need to support multiple iterations.
side_inputs = list(side_inputs)
self.step_name = step_name
self.transform_id = transform_id
self.context = DoFnContext(step_name, state=state)
self.bundle_finalizer_param = DoFn.BundleFinalizerParam()
self.execution_context = None # type: Optional[ExecutionContext]
do_fn_signature = DoFnSignature(fn)
# Optimize for the common case.
main_receivers = tagged_receivers[None]
# TODO(https://github.com/apache/beam/issues/18886): Remove if block after
# output counter released.
if 'outputs_per_element_counter' in RuntimeValueProvider.experiments:
# TODO(BEAM-3955): Make step_name and operation_name less confused.
output_counter_name = (
CounterName('per-element-output-count', step_name=operation_name))
per_element_output_counter = state._counter_factory.get_counter(
output_counter_name, Counter.DATAFLOW_DISTRIBUTION).accumulator
else:
per_element_output_counter = None
output_handler = _OutputHandler(
windowing.windowfn,
main_receivers,
tagged_receivers,
per_element_output_counter,
getattr(fn, 'output_batch_converter', None),
getattr(
do_fn_signature.process_method.method_value,
'_beam_yields_batches',
False),
getattr(
do_fn_signature.process_batch_method.method_value,
'_beam_yields_elements',
False),
)
if do_fn_signature.is_stateful_dofn() and not user_state_context:
raise Exception(
'Requested execution of a stateful DoFn, but no user state context '
'is available. This likely means that the current runner does not '
'support the execution of stateful DoFns.')
self.do_fn_invoker = DoFnInvoker.create_invoker(
do_fn_signature,
output_handler,
self.context,
side_inputs,
args,
kwargs,
user_state_context=user_state_context,
bundle_finalizer_param=self.bundle_finalizer_param)
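  # Typical invocation order (a sketch of how runner harnesses drive this
  # class; exact call sites vary by runner):
  #
  #   runner.setup()
  #   runner.start()
  #   for windowed_value in bundle:
  #     runner.process(windowed_value)
  #   runner.finish()
  #   runner.finalize()
  #   runner.teardown()
  #
  # For splittable DoFns, try_split() and current_element_progress() may be
  # called concurrently while process() is running.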
def process(self, windowed_value):
# type: (WindowedValue) -> Iterable[SplitResultResidual]
try:
return self.do_fn_invoker.invoke_process(windowed_value)
except BaseException as exn:
self._reraise_augmented(exn, windowed_value)
return []
  def _maybe_sample_exception(
      self, exc_info: Tuple[Any, Any, Any],
      windowed_value: Optional[WindowedValue]) -> None:
if self.execution_context is None:
return
output_sampler = self.execution_context.output_sampler
if output_sampler is None:
return
output_sampler.sample_exception(
windowed_value,
        exc_info,
self.transform_id,
self.execution_context.instruction_id)
def process_batch(self, windowed_batch):
# type: (WindowedBatch) -> None
try:
self.do_fn_invoker.invoke_process_batch(windowed_batch)
except BaseException as exn:
self._reraise_augmented(exn)
def process_with_sized_restriction(self, windowed_value):
# type: (WindowedValue) -> Iterable[SplitResultResidual]
(element, (restriction, estimator_state)), _ = windowed_value.value
return self.do_fn_invoker.invoke_process(
windowed_value.with_value(element),
restriction=restriction,
watermark_estimator_state=estimator_state)
def try_split(self, fraction):
# type: (...) -> Optional[Tuple[Iterable[SplitResultPrimary], Iterable[SplitResultResidual]]]
assert isinstance(self.do_fn_invoker, PerWindowInvoker)
return self.do_fn_invoker.try_split(fraction)
def current_element_progress(self):
# type: () -> Optional[RestrictionProgress]
assert isinstance(self.do_fn_invoker, PerWindowInvoker)
return self.do_fn_invoker.current_element_progress()
def process_user_timer(
self, timer_spec, key, window, timestamp, pane_info, dynamic_timer_tag):
try:
self.do_fn_invoker.invoke_user_timer(
timer_spec, key, window, timestamp, pane_info, dynamic_timer_tag)
except BaseException as exn:
self._reraise_augmented(exn)
def _invoke_bundle_method(self, bundle_method):
try:
self.context.set_element(None)
bundle_method()
except BaseException as exn:
self._reraise_augmented(exn)
def _invoke_lifecycle_method(self, lifecycle_method):
try:
self.context.set_element(None)
lifecycle_method()
except BaseException as exn:
self._reraise_augmented(exn)
def setup(self):
# type: () -> None
self._invoke_lifecycle_method(self.do_fn_invoker.invoke_setup)
def start(self):
# type: () -> None
self._invoke_bundle_method(self.do_fn_invoker.invoke_start_bundle)
def finish(self):
# type: () -> None
self._invoke_bundle_method(self.do_fn_invoker.invoke_finish_bundle)
def teardown(self):
# type: () -> None
self._invoke_lifecycle_method(self.do_fn_invoker.invoke_teardown)
def finalize(self):
# type: () -> None
self.bundle_finalizer_param.finalize_bundle()
def _reraise_augmented(self, exn, windowed_value=None):
if getattr(exn, '_tagged_with_step', False) or not self.step_name:
raise exn
step_annotation = " [while running '%s']" % self.step_name
# To emulate exception chaining (not available in Python 2).
try:
# Attempt to construct the same kind of exception
# with an augmented message.
new_exn = type(exn)(exn.args[0] + step_annotation, *exn.args[1:])
new_exn._tagged_with_step = True # Could raise attribute error.
except: # pylint: disable=bare-except
# If anything goes wrong, construct a RuntimeError whose message
# records the original exception's type and message.
new_exn = RuntimeError(
traceback.format_exception_only(type(exn), exn)[-1].strip() +
step_annotation)
new_exn._tagged_with_step = True
exc_info = sys.exc_info()
_, _, tb = exc_info
new_exn = new_exn.with_traceback(tb)
self._maybe_sample_exception(exc_info, windowed_value)
_LOGGER.exception(new_exn)
raise new_exn
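  # Illustrative example: an exception raised by a user DoFn at step
  # 'ParseRecords' is re-raised by _reraise_augmented as the same exception
  # type with " [while running 'ParseRecords']" appended to its message (or as
  # a RuntimeError when the original type cannot be reconstructed).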
class OutputHandler(object):
def handle_process_outputs(
self, windowed_input_element, results, watermark_estimator=None):
# type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> None
raise NotImplementedError
def handle_process_batch_outputs(
self, windowed_input_element, results, watermark_estimator=None):
# type: (WindowedBatch, Iterable[Any], Optional[WatermarkEstimator]) -> None
raise NotImplementedError
class _OutputHandler(OutputHandler):
"""Processes output produced by DoFn method invocations."""
def __init__(self,
window_fn,
main_receivers, # type: Receiver
tagged_receivers, # type: Mapping[Optional[str], Receiver]
per_element_output_counter,
output_batch_converter, # type: Optional[BatchConverter]
process_yields_batches, # type: bool
process_batch_yields_elements, # type: bool
):
"""Initializes ``_OutputHandler``.
Args:
window_fn: a windowing function (WindowFn).
main_receivers: a dict of tag name to Receiver objects.
tagged_receivers: main receiver object.
per_element_output_counter: per_element_output_counter of one work_item.
could be none if experimental flag turn off
"""
self.window_fn = window_fn
self.main_receivers = main_receivers
self.tagged_receivers = tagged_receivers
if (per_element_output_counter is not None and
per_element_output_counter.is_cythonized):
self.per_element_output_counter = per_element_output_counter
else:
self.per_element_output_counter = None
self.output_batch_converter = output_batch_converter
self._process_yields_batches = process_yields_batches
self._process_batch_yields_elements = process_batch_yields_elements
def handle_process_outputs(
self, windowed_input_element, results, watermark_estimator=None):
# type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> None
"""Dispatch the result of process computation to the appropriate receivers.
A value wrapped in a TaggedOutput object will be unwrapped and
then dispatched to the appropriate indexed output.
"""
if results is None:
results = []
# TODO(https://github.com/apache/beam/issues/20404): Verify that the
# results object is a valid iterable type if
# performance_runtime_type_check is active, without harming performance
output_element_count = 0
for result in results:
tag, result = self._handle_tagged_output(result)
if not self._process_yields_batches:
# process yields elements
windowed_value = self._maybe_propagate_windowing_info(
windowed_input_element, result)
output_element_count += 1
self._write_value_to_tag(tag, windowed_value, watermark_estimator)
else: # process yields batches
self._verify_batch_output(result)
if isinstance(result, WindowedBatch):
assert isinstance(result, HomogeneousWindowedBatch)
windowed_batch = result
if (windowed_input_element is not None and
len(windowed_input_element.windows) != 1):
windowed_batch.windows *= len(windowed_input_element.windows)
else:
windowed_batch = (
HomogeneousWindowedBatch.from_batch_and_windowed_value(
batch=result, windowed_value=windowed_input_element))
output_element_count += self.output_batch_converter.get_length(
windowed_batch.values)
self._write_batch_to_tag(tag, windowed_batch, watermark_estimator)
# TODO(https://github.com/apache/beam/issues/18886): Remove if block after
# output counter released. Only enable per_element_output_counter when
# counter cythonized
if self.per_element_output_counter is not None:
self.per_element_output_counter.add_input(output_element_count)
def handle_process_batch_outputs(
self, windowed_input_batch, results, watermark_estimator=None):
# type: (WindowedBatch, Iterable[Any], Optional[WatermarkEstimator]) -> None
"""Dispatch the result of process_batch computation to the appropriate
receivers.
A value wrapped in a TaggedOutput object will be unwrapped and
then dispatched to the appropriate indexed output.
"""
if results is None:
results = []
output_element_count = 0
for result in results:
tag, result = self._handle_tagged_output(result)
if not self._process_batch_yields_elements:
# process_batch yields batches
assert self.output_batch_converter is not None
self._verify_batch_output(result)
if isinstance(result, WindowedBatch):
assert isinstance(result, HomogeneousWindowedBatch)
windowed_batch = result
if (windowed_input_batch is not None and
len(windowed_input_batch.windows) != 1):
windowed_batch.windows *= len(windowed_input_batch.windows)
else:
windowed_batch = windowed_input_batch.with_values(result)
output_element_count += self.output_batch_converter.get_length(
windowed_batch.values)
self._write_batch_to_tag(tag, windowed_batch, watermark_estimator)
else: # process_batch yields elements
assert isinstance(windowed_input_batch, HomogeneousWindowedBatch)
windowed_value = self._maybe_propagate_windowing_info(
windowed_input_batch.as_empty_windowed_value(), result)
output_element_count += 1
self._write_value_to_tag(tag, windowed_value, watermark_estimator)
# TODO(https://github.com/apache/beam/issues/18886): Remove this if-block
# once the output counter is released. per_element_output_counter is only
# enabled when the counter is cythonized.
if self.per_element_output_counter is not None:
self.per_element_output_counter.add_input(output_element_count)
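# Illustrative sketch (comment only, hypothetical name): a batch-aware DoFn
# whose process_batch results handle_process_batch_outputs would dispatch.
# Its numpy type hints follow Beam's batched DoFn API.
#
#   from typing import Iterator
#   import numpy as np
#   import apache_beam as beam
#
#   class DoubleBatchesDoFn(beam.DoFn):
#     def process_batch(self, batch: np.ndarray) -> Iterator[np.ndarray]:
#       # Yields batches, so _process_batch_yields_elements is False here.
#       yield batch * 2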
def _maybe_propagate_windowing_info(self, windowed_input_element, result):
# type: (WindowedValue, Any) -> WindowedValue
if isinstance(result, WindowedValue):
windowed_value = result
if (windowed_input_element is not None and
len(windowed_input_element.windows) != 1):
windowed_value.windows *= len(windowed_input_element.windows)
return windowed_value
elif isinstance(result, TimestampedValue):
assign_context = WindowFn.AssignContext(result.timestamp, result.value)
windowed_value = WindowedValue(
result.value, result.timestamp, self.window_fn.assign(assign_context))
if len(windowed_input_element.windows) != 1:
windowed_value.windows *= len(windowed_input_element.windows)
return windowed_value
else:
return windowed_input_element.with_value(result)
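# The three branches above, in terms of what a DoFn may yield:
#   * a WindowedValue passes through (its windows are replicated when the
#     input element carries more than one window),
#   * a TimestampedValue is assigned windows via self.window_fn at its
#     timestamp,
#   * any other value inherits the input element's timestamp and windows.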
def _handle_tagged_output(self, result):
if isinstance(result, TaggedOutput):
tag = result.tag
if not isinstance(tag, str):
raise TypeError('In %s, tag %s is not a string' % (self, tag))
return tag, result.value
return None, result
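# For example (hypothetical values): TaggedOutput('large', 42) is split into
# ('large', 42), while a bare 42 becomes (None, 42).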
def _write_value_to_tag(self, tag, windowed_value, watermark_estimator):
if watermark_estimator is not None:
watermark_estimator.observe_timestamp(windowed_value.timestamp)
if tag is None:
self.main_receivers.receive(windowed_value)
else:
self.tagged_receivers[tag].receive(windowed_value)
def _write_batch_to_tag(self, tag, windowed_batch, watermark_estimator):
if watermark_estimator is not None:
for timestamp in windowed_batch.timestamps:
watermark_estimator.observe_timestamp(timestamp)
if tag is None:
self.main_receivers.receive_batch(windowed_batch)
else:
self.tagged_receivers[tag].receive_batch(windowed_batch)
def _verify_batch_output(self, result):
if isinstance(result, (WindowedValue, TimestampedValue)):
raise TypeError(
f"Received {type(result).__name__} from DoFn that was "
"expected to produce a batch.")
def start_bundle_outputs(self, results):
"""Validate that start_bundle does not output any elements"""
if results is None:
return
raise RuntimeError(
'Start Bundle should not output any elements but got %s' % results)
def finish_bundle_outputs(self, results):
"""Dispatch the result of finish_bundle to the appropriate receivers.
A value wrapped in a TaggedOutput object will be unwrapped and
then dispatched to the appropriate indexed output.
"""
if results is None:
return
for result in results:
tag = None
if isinstance(result, TaggedOutput):
tag = result.tag
if not isinstance(tag, str):
raise TypeError('In %s, tag %s is not a string' % (self, tag))
result = result.value
if isinstance(result, WindowedValue):
windowed_value = result
else:
raise RuntimeError(
'Finish Bundle should only output WindowedValue type but got %s' %
type(result))
if tag is None:
self.main_receivers.receive(windowed_value)
else:
self.tagged_receivers[tag].receive(windowed_value)
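# Illustrative sketch (comment only, hypothetical name): a finish_bundle
# implementation whose outputs finish_bundle_outputs would dispatch; note it
# may only yield WindowedValue instances.
#
#   import apache_beam as beam
#   from apache_beam.transforms.window import GlobalWindow
#   from apache_beam.utils.timestamp import MIN_TIMESTAMP
#   from apache_beam.utils.windowed_value import WindowedValue
#
#   class FlushOnFinishDoFn(beam.DoFn):
#     def finish_bundle(self):
#       yield WindowedValue('flushed', MIN_TIMESTAMP, [GlobalWindow()])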
class _NoContext(WindowFn.AssignContext):
"""An uninspectable WindowFn.AssignContext."""
NO_VALUE = object()
def __init__(self, value, timestamp=NO_VALUE):
self.value = value
self._timestamp = timestamp
@property
def timestamp(self):
if self._timestamp is self.NO_VALUE:
raise ValueError('No timestamp in this context.')
else:
return self._timestamp
@property
def existing_windows(self):
raise ValueError('No existing_windows in this context.')
class DoFnState(object):
"""For internal use only; no backwards-compatibility guarantees.
Keeps track of state that DoFns want, currently, user counters.
"""
def __init__(self, counter_factory):
self.step_name = ''
self._counter_factory = counter_factory
def counter_for(self, aggregator):
"""Looks up the counter for this aggregator, creating one if necessary."""
return self._counter_factory.get_aggregator_counter(
self.step_name, aggregator)
# TODO(robertwb): Replace core.DoFnContext with this.
class DoFnContext(object):
"""For internal use only; no backwards-compatibility guarantees."""
def __init__(self, label, element=None, state=None):
self.label = label
self.state = state
if element is not None:
self.set_element(element)
def set_element(self, windowed_value):
# type: (Optional[WindowedValue]) -> None
self.windowed_value = windowed_value
@property
def element(self):
if self.windowed_value is None:
raise AttributeError('element not accessible in this context')
else:
return self.windowed_value.value
@property
def timestamp(self):
if self.windowed_value is None:
raise AttributeError('timestamp not accessible in this context')
else:
return self.windowed_value.timestamp
@property
def windows(self):
if self.windowed_value is None:
raise AttributeError('windows not accessible in this context')
else:
return self.windowed_value.windows
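# For example (hypothetical): DoFnContext('my_step', element=windowed_value)
# exposes that value's .element, .timestamp and .windows; accessing them on a
# context without an element raises AttributeError.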
def group_by_key_input_visitor(deterministic_key_coders=True):
"""Returns a pipeline visitor that coerces GroupByKey input types to KV."""
# Importing here to avoid a circular dependency
# pylint: disable=wrong-import-order, wrong-import-position
from apache_beam.pipeline import PipelineVisitor
from apache_beam.transforms.core import GroupByKey
class GroupByKeyInputVisitor(PipelineVisitor):
"""A visitor that replaces `Any` element type for input `PCollection` of
a `GroupByKey` with a `KV` type.
TODO(BEAM-115): Once Python SDK is compatible with the new Runner API,
we could directly replace the coder instead of mutating the element type.
"""
def __init__(self, deterministic_key_coders=True):
self.deterministic_key_coders = deterministic_key_coders
def enter_composite_transform(self, transform_node):
self.visit_transform(transform_node)
def visit_transform(self, transform_node):
if isinstance(transform_node.transform, GroupByKey):
pcoll = transform_node.inputs[0]
pcoll.element_type = typehints.coerce_to_kv_type(
pcoll.element_type, transform_node.full_label)
pcoll.requires_deterministic_key_coder = (
self.deterministic_key_coders and transform_node.full_label)
key_type, value_type = pcoll.element_type.tuple_types
if transform_node.outputs:
key = next(iter(transform_node.outputs.keys()))
transform_node.outputs[key].element_type = typehints.KV[
key_type, typehints.Iterable[value_type]]
transform_node.outputs[key].requires_deterministic_key_coder = (
self.deterministic_key_coders and transform_node.full_label)
return GroupByKeyInputVisitor(deterministic_key_coders)
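# Illustrative usage (hypothetical pipeline object `p`): runners apply this
# visitor before translation so GroupByKey inputs carry KV element types.
#
#   p.visit(group_by_key_input_visitor(deterministic_key_coders=True))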
def validate_pipeline_graph(pipeline_proto):
"""Ensures this is a correctly constructed Beam pipeline.
"""
def get_coder(pcoll_id):
return pipeline_proto.components.coders[
pipeline_proto.components.pcollections[pcoll_id].coder_id]
def validate_transform(transform_id):
transform_proto = pipeline_proto.components.transforms[transform_id]
# Currently the only validation we perform is that GBK operations have
# their coders set properly.
if transform_proto.spec.urn == common_urns.primitives.GROUP_BY_KEY.urn:
if len(transform_proto.inputs) != 1:
raise ValueError("Unexpected number of inputs: %s" % transform_proto)
if len(transform_proto.outputs) != 1:
raise ValueError("Unexpected number of outputs: %s" % transform_proto)
input_coder = get_coder(next(iter(transform_proto.inputs.values())))
output_coder = get_coder(next(iter(transform_proto.outputs.values())))
if input_coder.spec.urn != common_urns.coders.KV.urn:
raise ValueError(
"Bad coder for input of %s: %s" % (transform_id, input_coder))
if output_coder.spec.urn != common_urns.coders.KV.urn:
raise ValueError(
"Bad coder for output of %s: %s" % (transform_id, output_coder))
output_values_coder = pipeline_proto.components.coders[
output_coder.component_coder_ids[1]]
if (input_coder.component_coder_ids[0] !=
output_coder.component_coder_ids[0] or
output_values_coder.spec.urn != common_urns.coders.ITERABLE.urn or
output_values_coder.component_coder_ids[0] !=
input_coder.component_coder_ids[1]):
raise ValueError(
"Incompatible input coder %s and output coder %s for transform %s" %
(input_coder, output_coder, transform_id))
elif transform_proto.spec.urn == common_urns.primitives.ASSIGN_WINDOWS.urn:
if not transform_proto.inputs:
raise ValueError("Missing input for transform: %s" % transform_proto)
elif transform_proto.spec.urn == common_urns.primitives.PAR_DO.urn:
if not transform_proto.inputs:
raise ValueError("Missing input for transform: %s" % transform_proto)
for t in transform_proto.subtransforms:
validate_transform(t)
for t in pipeline_proto.root_transform_ids:
validate_transform(t)
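# Illustrative usage (hypothetical pipeline object `p`):
#
#   validate_pipeline_graph(p.to_runner_api())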
def merge_common_environments(pipeline_proto, inplace=False):
"""Deduplicates equivalent environments in pipeline_proto.
All references to each set of equivalent environments are remapped to a
single canonical environment id, and the now-unused duplicates are deleted.
"""
def dep_key(dep):
if dep.type_urn == common_urns.artifact_types.FILE.urn:
payload = beam_runner_api_pb2.ArtifactFilePayload.FromString(
dep.type_payload)
if payload.sha256:
type_info = 'sha256', payload.sha256
else:
type_info = 'path', payload.path
elif dep.type_urn == common_urns.artifact_types.URL.urn:
payload = beam_runner_api_pb2.ArtifactUrlPayload.FromString(
dep.type_payload)
if payload.sha256:
type_info = 'sha256', payload.sha256
else:
type_info = 'url', payload.url
else:
type_info = dep.type_urn, dep.type_payload
return type_info, dep.role_urn, dep.role_payload
def base_env_key(env):
return (
env.urn,
env.payload,
tuple(sorted(env.capabilities)),
tuple(sorted(env.resource_hints.items())),
tuple(sorted(dep_key(dep) for dep in env.dependencies)))
def env_key(env):
return tuple(
sorted(
base_env_key(e)
for e in environments.expand_anyof_environments(env)))
canonical_environments = collections.defaultdict(list)
for env_id, env in pipeline_proto.components.environments.items():
canonical_environments[env_key(env)].append(env_id)
if len(canonical_environments) == len(pipeline_proto.components.environments):
# All environments are already sufficiently distinct.
return pipeline_proto
environment_remappings = {
e: es[0]
for es in canonical_environments.values() for e in es
}
if not inplace:
pipeline_proto = copy.copy(pipeline_proto)
for t in pipeline_proto.components.transforms.values():
if t.environment_id not in pipeline_proto.components.environments:
# TODO(https://github.com/apache/beam/issues/30876): Remove this
# workaround.
continue
if t.environment_id:
t.environment_id = environment_remappings[t.environment_id]
for w in pipeline_proto.components.windowing_strategies.values():
if w.environment_id not in pipeline_proto.components.environments:
# TODO(https://github.com/apache/beam/issues/30876): Remove this
# workaround.
continue
if w.environment_id:
w.environment_id = environment_remappings[w.environment_id]
for e in set(pipeline_proto.components.environments.keys()) - set(
environment_remappings.values()):
del pipeline_proto.components.environments[e]
return pipeline_proto
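# Illustrative usage (hypothetical variable names): deduplicate equivalent
# environments in a pipeline proto before submission.
#
#   pipeline_proto = p.to_runner_api()
#   deduped_proto = merge_common_environments(pipeline_proto)
#
# With inplace=False (the default) the input proto is copied before any
# environment ids are rewritten.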