blob: 528b2d1f576bb1c6d73bad48bf79021869347297 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""This module contains Splittable DoFn logic that is specific to DirectRunner.
"""
# pytype: skip-file
import uuid
from threading import Lock
from threading import Timer
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
from typing import Optional
import apache_beam as beam
from apache_beam import TimeDomain
from apache_beam import pvalue
from apache_beam.coders import typecoders
from apache_beam.pipeline import AppliedPTransform
from apache_beam.pipeline import PTransformOverride
from apache_beam.runners.common import DoFnContext
from apache_beam.runners.common import DoFnInvoker
from apache_beam.runners.common import DoFnSignature
from apache_beam.runners.common import OutputHandler
from apache_beam.runners.direct.evaluation_context import DirectStepContext
from apache_beam.runners.direct.util import KeyedWorkItem
from apache_beam.runners.direct.watermark_manager import WatermarkManager
from apache_beam.transforms.core import ParDo
from apache_beam.transforms.core import ProcessContinuation
from apache_beam.transforms.ptransform import PTransform
from apache_beam.transforms.trigger import _ReadModifyWriteStateTag
from apache_beam.utils.windowed_value import WindowedValue
if TYPE_CHECKING:
from apache_beam.iobase import WatermarkEstimator
class SplittableParDoOverride(PTransformOverride):
"""A transform override for ParDo transformss of SplittableDoFns.
Replaces the ParDo transform with a SplittableParDo transform that performs
SDF specific logic.
"""
def matches(self, applied_ptransform):
assert isinstance(applied_ptransform, AppliedPTransform)
transform = applied_ptransform.transform
if isinstance(transform, ParDo):
signature = DoFnSignature(transform.fn)
return signature.is_splittable_dofn()
def get_replacement_transform_for_applied_ptransform(
self, applied_ptransform):
ptransform = applied_ptransform.transform
assert isinstance(ptransform, ParDo)
do_fn = ptransform.fn
signature = DoFnSignature(do_fn)
if signature.is_splittable_dofn():
return SplittableParDo(ptransform)
else:
return ptransform
class SplittableParDo(PTransform):
"""A transform that processes a PCollection using a Splittable DoFn."""
def __init__(self, ptransform):
assert isinstance(ptransform, ParDo)
self._ptransform = ptransform
def expand(self, pcoll):
sdf = self._ptransform.fn
signature = DoFnSignature(sdf)
restriction_coder = signature.get_restriction_coder()
element_coder = typecoders.registry.get_coder(pcoll.element_type)
keyed_elements = (
pcoll
| 'pair' >> ParDo(PairWithRestrictionFn(sdf))
| 'split' >> ParDo(SplitRestrictionFn(sdf))
| 'explode' >> ParDo(ExplodeWindowsFn())
| 'random' >> ParDo(RandomUniqueKeyFn()))
return keyed_elements | ProcessKeyedElements(
sdf,
element_coder,
restriction_coder,
pcoll.windowing,
self._ptransform.args,
self._ptransform.kwargs,
self._ptransform.side_inputs)
class ElementAndRestriction(object):
"""A holder for an element, restriction, and watermark estimator state."""
def __init__(self, element, restriction, watermark_estimator_state):
self.element = element
self.restriction = restriction
self.watermark_estimator_state = watermark_estimator_state
class PairWithRestrictionFn(beam.DoFn):
"""A transform that pairs each element with a restriction."""
def __init__(self, do_fn):
self._signature = DoFnSignature(do_fn)
def start_bundle(self):
self._invoker = DoFnInvoker.create_invoker(
self._signature,
output_handler=_NoneShallPassOutputHandler(),
process_invocation=False)
def process(self, element, window=beam.DoFn.WindowParam, *args, **kwargs):
initial_restriction = self._invoker.invoke_initial_restriction(element)
watermark_estimator_state = (
self._signature.process_method.watermark_estimator_provider.
initial_estimator_state(element, initial_restriction))
yield ElementAndRestriction(
element, initial_restriction, watermark_estimator_state)
class SplitRestrictionFn(beam.DoFn):
"""A transform that perform initial splitting of Splittable DoFn inputs."""
def __init__(self, do_fn):
self._do_fn = do_fn
def start_bundle(self):
signature = DoFnSignature(self._do_fn)
self._invoker = DoFnInvoker.create_invoker(
signature,
output_handler=_NoneShallPassOutputHandler(),
process_invocation=False)
def process(self, element_and_restriction, *args, **kwargs):
element = element_and_restriction.element
restriction = element_and_restriction.restriction
restriction_parts = self._invoker.invoke_split(element, restriction)
for part in restriction_parts:
yield ElementAndRestriction(
element, part, element_and_restriction.watermark_estimator_state)
class ExplodeWindowsFn(beam.DoFn):
"""A transform that forces the runner to explode windows.
This is done to make sure that Splittable DoFn proceses an element for each of
the windows that element belongs to.
"""
def process(self, element, window=beam.DoFn.WindowParam, *args, **kwargs):
yield element
class RandomUniqueKeyFn(beam.DoFn):
"""A transform that assigns a unique key to each element."""
def process(self, element, window=beam.DoFn.WindowParam, *args, **kwargs):
# We ignore UUID collisions here since they are extremely rare.
yield (uuid.uuid4().bytes, element)
class ProcessKeyedElements(PTransform):
"""A primitive transform that performs SplittableDoFn magic.
Input to this transform should be a PCollection of keyed ElementAndRestriction
objects.
"""
def __init__(
self,
sdf,
element_coder,
restriction_coder,
windowing_strategy,
ptransform_args,
ptransform_kwargs,
ptransform_side_inputs):
self.sdf = sdf
self.element_coder = element_coder
self.restriction_coder = restriction_coder
self.windowing_strategy = windowing_strategy
self.ptransform_args = ptransform_args
self.ptransform_kwargs = ptransform_kwargs
self.ptransform_side_inputs = ptransform_side_inputs
def expand(self, pcoll):
return pvalue.PCollection.from_(pcoll)
class ProcessKeyedElementsViaKeyedWorkItemsOverride(PTransformOverride):
"""A transform override for ProcessElements transform."""
def matches(self, applied_ptransform):
return isinstance(applied_ptransform.transform, ProcessKeyedElements)
def get_replacement_transform_for_applied_ptransform(
self, applied_ptransform):
return ProcessKeyedElementsViaKeyedWorkItems(applied_ptransform.transform)
class ProcessKeyedElementsViaKeyedWorkItems(PTransform):
"""A transform that processes Splittable DoFn input via KeyedWorkItems."""
def __init__(self, process_keyed_elements_transform):
self._process_keyed_elements_transform = process_keyed_elements_transform
def expand(self, pcoll):
process_elements = ProcessElements(self._process_keyed_elements_transform)
process_elements.args = (
self._process_keyed_elements_transform.ptransform_args)
process_elements.kwargs = (
self._process_keyed_elements_transform.ptransform_kwargs)
process_elements.side_inputs = (
self._process_keyed_elements_transform.ptransform_side_inputs)
return pcoll | beam.core.GroupByKey() | process_elements
class ProcessElements(PTransform):
"""A primitive transform for processing keyed elements or KeyedWorkItems.
Will be evaluated by
`runners.direct.transform_evaluator._ProcessElementsEvaluator`.
"""
def __init__(self, process_keyed_elements_transform):
self._process_keyed_elements_transform = process_keyed_elements_transform
self.sdf = self._process_keyed_elements_transform.sdf
def expand(self, pcoll):
return pvalue.PCollection.from_(pcoll)
def new_process_fn(self, sdf):
return ProcessFn(
sdf,
self._process_keyed_elements_transform.ptransform_args,
self._process_keyed_elements_transform.ptransform_kwargs)
class ProcessFn(beam.DoFn):
"""A `DoFn` that executes machineary for invoking a Splittable `DoFn`.
Input to the `ParDo` step that includes a `ProcessFn` will be a `PCollection`
of `ElementAndRestriction` objects.
This class is mainly responsible for following.
(1) setup environment for properly invoking a Splittable `DoFn`.
(2) invoke `process()` method of a Splittable `DoFn`.
(3) after the `process()` invocation of the Splittable `DoFn`, determine if a
re-invocation of the element is needed. If this is the case, set state and
a timer for a re-invocation and hold output watermark till this
re-invocation.
(4) after the final invocation of a given element clear any previous state set
for re-invoking the element and release the output watermark.
"""
def __init__(self, sdf, args_for_invoker, kwargs_for_invoker):
self.sdf = sdf
self._element_tag = _ReadModifyWriteStateTag('element')
self._restriction_tag = _ReadModifyWriteStateTag('restriction')
self._watermark_state_tag = _ReadModifyWriteStateTag(
'watermark_estimator_state')
self.watermark_hold_tag = _ReadModifyWriteStateTag('watermark_hold')
self._process_element_invoker = None
self._output_processor = _OutputHandler()
self.sdf_invoker = DoFnInvoker.create_invoker(
DoFnSignature(self.sdf),
context=DoFnContext('unused_context'),
output_handler=self._output_processor,
input_args=args_for_invoker,
input_kwargs=kwargs_for_invoker)
self._step_context = None
@property
def step_context(self):
return self._step_context
@step_context.setter
def step_context(self, step_context):
assert isinstance(step_context, DirectStepContext)
self._step_context = step_context
def set_process_element_invoker(self, process_element_invoker):
assert isinstance(process_element_invoker, SDFProcessElementInvoker)
self._process_element_invoker = process_element_invoker
def process(
self,
element,
timestamp=beam.DoFn.TimestampParam,
window=beam.DoFn.WindowParam,
*args,
**kwargs):
if isinstance(element, KeyedWorkItem):
# Must be a timer firing.
key = element.encoded_key
else:
key, values = element
values = list(values)
assert len(values) == 1
# Value here will either be a WindowedValue or an ElementAndRestriction
# object.
# TODO: handle key collisions here.
assert len(values) == 1, 'Internal error. Processing of splittable ' \
'DoFn cannot continue since elements did not ' \
'have unique keys.'
value = values[0]
if len(values) != 1:
raise ValueError('')
state = self._step_context.get_keyed_state(key)
element_state = state.get_state(window, self._element_tag)
# Initially element_state is an empty list.
is_seed_call = not element_state
if not is_seed_call:
element = state.get_state(window, self._element_tag)
restriction = state.get_state(window, self._restriction_tag)
watermark_estimator_state = state.get_state(
window, self._watermark_state_tag)
windowed_element = WindowedValue(element, timestamp, [window])
else:
# After values iterator is expanded above we should have gotten a list
# with a single ElementAndRestriction object.
assert isinstance(value, ElementAndRestriction)
element_and_restriction = value
element = element_and_restriction.element
restriction = element_and_restriction.restriction
watermark_estimator_state = (
element_and_restriction.watermark_estimator_state)
if isinstance(value, WindowedValue):
windowed_element = WindowedValue(
element, value.timestamp, value.windows)
else:
windowed_element = WindowedValue(element, timestamp, [window])
assert self._process_element_invoker
assert isinstance(self._process_element_invoker, SDFProcessElementInvoker)
output_values = self._process_element_invoker.invoke_process_element(
self.sdf_invoker,
self._output_processor,
windowed_element,
restriction,
watermark_estimator_state,
*args,
**kwargs)
sdf_result = None
for output in output_values:
if isinstance(output, SDFProcessElementInvoker.Result):
# SDFProcessElementInvoker.Result should be the last item yielded.
sdf_result = output
break
yield output
assert sdf_result, ('SDFProcessElementInvoker must return a '
'SDFProcessElementInvoker.Result object as the last '
'value of a SDF invoke_process_element() invocation.')
if not sdf_result.residual_restriction:
# All work for current residual and restriction pair is complete.
state.clear_state(window, self._element_tag)
state.clear_state(window, self._restriction_tag)
state.clear_state(window, self._watermark_state_tag)
# Releasing output watermark by setting it to positive infinity.
state.add_state(
window, self.watermark_hold_tag, WatermarkManager.WATERMARK_POS_INF)
else:
state.add_state(window, self._element_tag, element)
state.add_state(
window, self._restriction_tag, sdf_result.residual_restriction)
state.add_state(
window, self._watermark_state_tag, watermark_estimator_state)
# Holding output watermark by setting it to negative infinity.
state.add_state(
window, self.watermark_hold_tag, WatermarkManager.WATERMARK_NEG_INF)
# Setting a timer to be reinvoked to continue processing the element.
# Currently Python SDK only supports setting timers based on watermark. So
# forcing a reinvocation by setting a timer for watermark negative
# infinity.
# TODO(chamikara): update this by setting a timer for the proper
# processing time when Python SDK supports that.
state.set_timer(
window, '', TimeDomain.WATERMARK, WatermarkManager.WATERMARK_NEG_INF)
class SDFProcessElementInvoker(object):
"""A utility that invokes SDF `process()` method and requests checkpoints.
This class is responsible for invoking the `process()` method of a Splittable
`DoFn` and making sure that invocation terminated properly. Based on the input
configuration, this class may decide to request a checkpoint for a `process()`
execution so that runner can process current output and resume the invocation
at a later time.
More specifically, when initializing a `SDFProcessElementInvoker`, caller may
specify the number of output elements or processing time after which a
checkpoint should be requested. This class is responsible for properly
requesting a checkpoint based on either of these criteria.
When the `process()` call of Splittable `DoFn` ends, this class performs
validations to make sure that processing ended gracefully and returns a
`SDFProcessElementInvoker.Result` that contains information which can be used
by the caller to perform another `process()` invocation for the residual.
A `process()` invocation may decide to give up processing voluntarily by
returning a `ProcessContinuation` object (see documentation of
`ProcessContinuation` for more details). So if a 'ProcessContinuation' is
produced this class ends the execution and performs steps to finalize the
current invocation.
"""
class Result(object):
def __init__(
self,
residual_restriction=None,
process_continuation=None,
future_output_watermark=None):
"""Returned as a result of a `invoke_process_element()` invocation.
Args:
residual_restriction: a restriction for the unprocessed part of the
element.
process_continuation: a `ProcessContinuation` if one was returned as the
last element of the SDF `process()` invocation.
future_output_watermark: output watermark of the results that will be
produced when invoking the Splittable `DoFn`
for the current element with
`residual_restriction`.
"""
self.residual_restriction = residual_restriction
self.process_continuation = process_continuation
self.future_output_watermark = future_output_watermark
def __init__(self, max_num_outputs, max_duration):
self._max_num_outputs = max_num_outputs
self._max_duration = max_duration
self._checkpoint_lock = Lock()
def test_method(self):
raise ValueError
def invoke_process_element(
self,
sdf_invoker,
output_processor,
element,
restriction,
watermark_estimator_state,
*args,
**kwargs):
"""Invokes `process()` method of a Splittable `DoFn` for a given element.
Args:
sdf_invoker: a `DoFnInvoker` for the Splittable `DoFn`.
element: the element to process
Returns:
a `SDFProcessElementInvoker.Result` object.
"""
assert isinstance(sdf_invoker, DoFnInvoker)
class CheckpointState(object):
def __init__(self):
self.checkpointed = None
self.residual_restriction = None
checkpoint_state = CheckpointState()
def initiate_checkpoint():
with self._checkpoint_lock:
if checkpoint_state.checkpointed:
return
checkpoint_state.checkpointed = object()
split = sdf_invoker.try_split(0)
if split:
_, checkpoint_state.residual_restriction = split
else:
# Clear the checkpoint if the split didn't happen. This counters
# a very unlikely race condition that the Timer attempted to initiate
# a checkpoint before invoke_process set the current element allowing
# for another attempt to checkpoint.
checkpoint_state.checkpointed = None
output_processor.reset()
Timer(self._max_duration, initiate_checkpoint).start()
sdf_invoker.invoke_process(
element,
additional_args=args,
restriction=restriction,
watermark_estimator_state=watermark_estimator_state)
assert output_processor.output_iter is not None
output_count = 0
# We have to expand and re-yield here to support ending execution for a
# given number of output elements as well as to capture the
# ProcessContinuation of one was returned.
process_continuation = None
for output in output_processor.output_iter:
# A ProcessContinuation, if returned, should be the last element.
assert not process_continuation
if isinstance(output, ProcessContinuation):
# Taking a checkpoint so that we can determine primary and residual
# restrictions.
initiate_checkpoint()
# A ProcessContinuation should always be the last element produced by
# the output iterator.
# TODO: support continuing after the specified amount of delay.
# Continuing here instead of breaking to enforce that this is the last
# element.
process_continuation = output
continue
yield output
output_count += 1
if self._max_num_outputs and output_count >= self._max_num_outputs:
initiate_checkpoint()
result = (
SDFProcessElementInvoker.Result(
residual_restriction=checkpoint_state.residual_restriction)
if checkpoint_state.residual_restriction else
SDFProcessElementInvoker.Result())
yield result
class _OutputHandler(OutputHandler):
def __init__(self):
self.output_iter = None
def handle_process_outputs(
self, windowed_input_element, output_iter, watermark_estimator=None):
# type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> None
self.output_iter = output_iter
def reset(self):
self.output_iter = None
class _NoneShallPassOutputHandler(OutputHandler):
def handle_process_outputs(
self, windowed_input_element, output_iter, watermark_estimator=None):
# type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> None
raise RuntimeError()