#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""SDK harness for executing Python Fns via the Fn API."""
# pytype: skip-file
import base64
import bisect
import collections
import copy
import json
import logging
import random
import threading
from dataclasses import dataclass
from dataclasses import field
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Container
from typing import DefaultDict
from typing import Dict
from typing import FrozenSet
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Mapping
from typing import Optional
from typing import Set
from typing import Tuple
from typing import Type
from typing import TypeVar
from typing import Union
from typing import cast
from google.protobuf import duration_pb2
from google.protobuf import timestamp_pb2
import apache_beam as beam
from apache_beam import coders
from apache_beam.coders import WindowedValueCoder
from apache_beam.coders import coder_impl
from apache_beam.internal import pickler
from apache_beam.io import iobase
from apache_beam.metrics import monitoring_infos
from apache_beam.portability import common_urns
from apache_beam.portability import python_urns
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners import common
from apache_beam.runners import pipeline_context
from apache_beam.runners.worker import data_sampler
from apache_beam.runners.worker import operation_specs
from apache_beam.runners.worker import operations
from apache_beam.runners.worker import statesampler
from apache_beam.transforms import TimeDomain
from apache_beam.transforms import core
from apache_beam.transforms import environments
from apache_beam.transforms import sideinputs
from apache_beam.transforms import userstate
from apache_beam.transforms import window
from apache_beam.utils import counters
from apache_beam.utils import proto_utils
from apache_beam.utils import timestamp
from apache_beam.utils.windowed_value import WindowedValue
if TYPE_CHECKING:
from google.protobuf import message # pylint: disable=ungrouped-imports
from apache_beam import pvalue
from apache_beam.portability.api import metrics_pb2
from apache_beam.runners.sdf_utils import SplitResultPrimary
from apache_beam.runners.sdf_utils import SplitResultResidual
from apache_beam.runners.worker import data_plane
from apache_beam.runners.worker import sdk_worker
from apache_beam.transforms.core import Windowing
from apache_beam.transforms.window import BoundedWindow
from apache_beam.utils import windowed_value
T = TypeVar('T')
ConstructorFn = Callable[[
'BeamTransformFactory',
Any,
beam_runner_api_pb2.PTransform,
Union['message.Message', bytes],
Dict[str, List[operations.Operation]]
],
operations.Operation]
OperationT = TypeVar('OperationT', bound=operations.Operation)
FnApiUserRuntimeStateTypes = Union['ReadModifyWriteRuntimeState',
'CombiningValueRuntimeState',
'SynchronousSetRuntimeState',
'SynchronousBagRuntimeState']
DATA_INPUT_URN = 'beam:runner:source:v1'
DATA_OUTPUT_URN = 'beam:runner:sink:v1'
SYNTHETIC_DATA_SAMPLING_URN = 'beam:internal:sampling:v1'
IDENTITY_DOFN_URN = 'beam:dofn:identity:0.1'
# TODO(vikasrk): Fix this once runner sends appropriate common_urns.
OLD_DATAFLOW_RUNNER_HARNESS_PARDO_URN = 'beam:dofn:javasdk:0.1'
OLD_DATAFLOW_RUNNER_HARNESS_READ_URN = 'beam:source:java:0.1'
URNS_NEEDING_PCOLLECTIONS = set([
monitoring_infos.ELEMENT_COUNT_URN, monitoring_infos.SAMPLED_BYTE_SIZE_URN
])
_LOGGER = logging.getLogger(__name__)
class RunnerIOOperation(operations.Operation):
"""Common baseclass for runner harness IO operations."""
def __init__(self,
name_context, # type: common.NameContext
step_name, # type: Any
consumers, # type: Mapping[Any, Iterable[operations.Operation]]
counter_factory, # type: counters.CounterFactory
state_sampler, # type: statesampler.StateSampler
windowed_coder, # type: coders.Coder
transform_id, # type: str
data_channel # type: data_plane.DataChannel
):
# type: (...) -> None
super().__init__(name_context, None, counter_factory, state_sampler)
self.windowed_coder = windowed_coder
self.windowed_coder_impl = windowed_coder.get_impl()
# transform_id represents the consumer for the bytes in the data plane for a
# DataInputOperation or a producer of these bytes for a DataOutputOperation.
self.transform_id = transform_id
self.data_channel = data_channel
for _, consumer_ops in consumers.items():
for consumer in consumer_ops:
self.add_receiver(consumer, 0)
class DataOutputOperation(RunnerIOOperation):
"""A sink-like operation that gathers outputs to be sent back to the runner.
"""
def set_output_stream(self, output_stream):
# type: (data_plane.ClosableOutputStream) -> None
self.output_stream = output_stream
def process(self, windowed_value):
# type: (windowed_value.WindowedValue) -> None
self.windowed_coder_impl.encode_to_stream(
windowed_value, self.output_stream, True)
self.output_stream.maybe_flush()
def finish(self):
# type: () -> None
super().finish()
self.output_stream.close()
class DataInputOperation(RunnerIOOperation):
"""A source-like operation that gathers input from the runner."""
def __init__(self,
operation_name, # type: common.NameContext
step_name,
consumers, # type: Mapping[Any, List[operations.Operation]]
counter_factory, # type: counters.CounterFactory
state_sampler, # type: statesampler.StateSampler
windowed_coder, # type: coders.Coder
transform_id,
data_channel # type: data_plane.GrpcClientDataChannel
):
# type: (...) -> None
super().__init__(
operation_name,
step_name,
consumers,
counter_factory,
state_sampler,
windowed_coder,
transform_id=transform_id,
data_channel=data_channel)
self.consumer = next(iter(consumers.values()))
self.splitting_lock = threading.Lock()
self.index = -1
self.stop = float('inf')
self.started = False
def setup(self, data_sampler=None):
super().setup(data_sampler)
# We must do this manually as we don't have a spec or spec.output_coders.
self.receivers = [
operations.ConsumerSet.create(
counter_factory=self.counter_factory,
step_name=self.name_context.step_name,
output_index=0,
consumers=self.consumer,
coder=self.windowed_coder,
producer_type_hints=self._get_runtime_performance_hints(),
producer_batch_converter=self.get_output_batch_converter())
]
def start(self):
# type: () -> None
super().start()
with self.splitting_lock:
self.started = True
def process(self, windowed_value):
# type: (windowed_value.WindowedValue) -> None
self.output(windowed_value)
def process_encoded(self, encoded_windowed_values):
# type: (bytes) -> None
input_stream = coder_impl.create_InputStream(encoded_windowed_values)
while input_stream.size() > 0:
with self.splitting_lock:
if self.index == self.stop - 1:
return
self.index += 1
try:
decoded_value = self.windowed_coder_impl.decode_from_stream(
input_stream, True)
except Exception as exn:
raise ValueError(
"Error decoding input stream with coder " +
str(self.windowed_coder)) from exn
self.output(decoded_value)
def monitoring_infos(self, transform_id, tag_to_pcollection_id):
# type: (str, Dict[str, str]) -> Dict[FrozenSet, metrics_pb2.MonitoringInfo]
all_monitoring_infos = super().monitoring_infos(
transform_id, tag_to_pcollection_id)
read_progress_info = monitoring_infos.int64_counter(
monitoring_infos.DATA_CHANNEL_READ_INDEX,
self.index,
ptransform=transform_id)
all_monitoring_infos[monitoring_infos.to_key(
read_progress_info)] = read_progress_info
return all_monitoring_infos
# TODO(https://github.com/apache/beam/issues/19737): typing not compatible
# with super type
def try_split( # type: ignore[override]
self, fraction_of_remainder, total_buffer_size, allowed_split_points):
# type: (...) -> Optional[Tuple[int, Iterable[operations.SdfSplitResultsPrimary], Iterable[operations.SdfSplitResultsResidual], int]]
with self.splitting_lock:
if not self.started:
return None
if self.index == -1:
# We are "finished" with the (non-existent) previous element.
current_element_progress = 1.0
else:
current_element_progress_object = (
self.receivers[0].current_element_progress())
if current_element_progress_object is None:
current_element_progress = 0.5
else:
current_element_progress = (
current_element_progress_object.fraction_completed)
# Now figure out where to split.
split = self._compute_split(
self.index,
current_element_progress,
self.stop,
fraction_of_remainder,
total_buffer_size,
allowed_split_points,
self.receivers[0].try_split)
if split:
self.stop = split[-1]
return split
@staticmethod
def _compute_split(
index,
current_element_progress,
stop,
fraction_of_remainder,
total_buffer_size,
allowed_split_points=(),
try_split=lambda fraction: None):
def is_valid_split_point(index):
return not allowed_split_points or index in allowed_split_points
if total_buffer_size < index + 1:
total_buffer_size = index + 1
elif total_buffer_size > stop:
total_buffer_size = stop
# The units here (except for keep_of_element_remainder) are all in
# terms of number of (possibly fractional) elements.
remainder = total_buffer_size - index - current_element_progress
keep = remainder * fraction_of_remainder
if current_element_progress < 1:
keep_of_element_remainder = keep / (1 - current_element_progress)
# If it's less than what's left of the current element,
# try splitting at the current element.
if (keep_of_element_remainder < 1 and is_valid_split_point(index) and
is_valid_split_point(index + 1)):
split = try_split(
keep_of_element_remainder
) # type: Optional[Tuple[Iterable[operations.SdfSplitResultsPrimary], Iterable[operations.SdfSplitResultsResidual]]]
if split:
element_primaries, element_residuals = split
return index - 1, element_primaries, element_residuals, index + 1
# Otherwise, split at the closest element boundary.
# pylint: disable=bad-option-value
stop_index = index + max(1, int(round(current_element_progress + keep)))
if allowed_split_points and stop_index not in allowed_split_points:
# Choose the closest allowed split point.
allowed_split_points = sorted(allowed_split_points)
closest = bisect.bisect(allowed_split_points, stop_index)
if closest == 0:
stop_index = allowed_split_points[0]
elif closest == len(allowed_split_points):
stop_index = allowed_split_points[-1]
else:
prev = allowed_split_points[closest - 1]
next = allowed_split_points[closest]
if index < prev and stop_index - prev < next - stop_index:
stop_index = prev
else:
stop_index = next
if index < stop_index < stop:
return stop_index - 1, [], [], stop_index
else:
return None
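# Worked example (hypothetical values): with index=2,
# current_element_progress=0.5, stop=inf, fraction_of_remainder=0.5 and
# total_buffer_size=10, the remainder is 10 - 2 - 0.5 = 7.5 and keep is
# 3.75. Splitting within the current element would keep
# 3.75 / (1 - 0.5) = 7.5 of it, which is not < 1, so we fall through to an
# element boundary: stop_index = 2 + round(0.5 + 3.75) = 6, returning
# (5, [], [], 6), i.e. index 5 is the last primary element and index 6
# starts the residual.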
def finish(self):
# type: () -> None
super().finish()
with self.splitting_lock:
self.index += 1
self.started = False
def reset(self):
# type: () -> None
with self.splitting_lock:
self.index = -1
self.stop = float('inf')
super().reset()
class _StateBackedIterable(object):
def __init__(self,
state_handler, # type: sdk_worker.CachingStateHandler
state_key, # type: beam_fn_api_pb2.StateKey
coder_or_impl, # type: Union[coders.Coder, coder_impl.CoderImpl]
):
# type: (...) -> None
self._state_handler = state_handler
self._state_key = state_key
if isinstance(coder_or_impl, coders.Coder):
self._coder_impl = coder_or_impl.get_impl()
else:
self._coder_impl = coder_or_impl
def __iter__(self):
# type: () -> Iterator[Any]
return iter(
self._state_handler.blocking_get(self._state_key, self._coder_impl))
def __reduce__(self):
return list, (list(self), )
coder_impl.FastPrimitivesCoderImpl.register_iterable_like_type(
_StateBackedIterable)
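# Note: pickling a _StateBackedIterable materializes it into a plain list
# (see __reduce__ above), since the state handler connection itself cannot
# be serialized; registering it as an iterable-like type lets
# FastPrimitivesCoderImpl encode it element by element without that
# materialization.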
class StateBackedSideInputMap(object):
_BULK_READ_LIMIT = 100
_BULK_READ_FULLY = "fully"
_BULK_READ_PARTIALLY = "partially"
def __init__(self,
state_handler, # type: sdk_worker.CachingStateHandler
transform_id, # type: str
tag, # type: Optional[str]
side_input_data, # type: pvalue.SideInputData
coder, # type: WindowedValueCoder
use_bulk_read = False, # type: bool
):
# type: (...) -> None
self._state_handler = state_handler
self._transform_id = transform_id
self._tag = tag
self._side_input_data = side_input_data
self._element_coder = coder.wrapped_value_coder
self._target_window_coder = coder.window_coder
# TODO(robertwb): Limit the cache size.
self._cache = {} # type: Dict[BoundedWindow, Any]
self._use_bulk_read = use_bulk_read
def __getitem__(self, window):
target_window = self._side_input_data.window_mapping_fn(window)
if target_window not in self._cache:
state_handler = self._state_handler
access_pattern = self._side_input_data.access_pattern
if access_pattern == common_urns.side_inputs.ITERABLE.urn:
state_key = beam_fn_api_pb2.StateKey(
iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput(
transform_id=self._transform_id,
side_input_id=self._tag,
window=self._target_window_coder.encode(target_window)))
raw_view = _StateBackedIterable(
state_handler, state_key, self._element_coder)
elif access_pattern == common_urns.side_inputs.MULTIMAP.urn:
state_key = beam_fn_api_pb2.StateKey(
multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
transform_id=self._transform_id,
side_input_id=self._tag,
window=self._target_window_coder.encode(target_window),
key=b''))
kv_iter_state_key = beam_fn_api_pb2.StateKey(
multimap_keys_values_side_input=beam_fn_api_pb2.StateKey.
MultimapKeysValuesSideInput(
transform_id=self._transform_id,
side_input_id=self._tag,
window=self._target_window_coder.encode(target_window)))
cache = {}
key_coder = self._element_coder.key_coder()
key_coder_impl = key_coder.get_impl()
value_coder = self._element_coder.value_coder()
use_bulk_read = self._use_bulk_read
class MultiMap(object):
_bulk_read = None
_lock = threading.Lock()
def __getitem__(self, key):
if use_bulk_read:
if self._bulk_read is None:
with self._lock:
if self._bulk_read is None:
try:
# Attempt to bulk read the key-values over the iterable
# protocol which, if supported, can be much more efficient
# than point lookups if it fits into memory.
for ix, (k, vs) in enumerate(_StateBackedIterable(
state_handler,
kv_iter_state_key,
coders.TupleCoder(
(key_coder, coders.IterableCoder(value_coder))))):
cache[k] = vs
if ix > StateBackedSideInputMap._BULK_READ_LIMIT:
self._bulk_read = (
StateBackedSideInputMap._BULK_READ_PARTIALLY)
break
else:
# We reached the end of the iteration without breaking.
self._bulk_read = (
StateBackedSideInputMap._BULK_READ_FULLY)
except Exception:
_LOGGER.error(
"Iterable access of map side inputs unsupported.",
exc_info=True)
self._bulk_read = (
StateBackedSideInputMap._BULK_READ_PARTIALLY)
if (self._bulk_read == StateBackedSideInputMap._BULK_READ_FULLY):
return cache.get(key, [])
if key not in cache:
keyed_state_key = beam_fn_api_pb2.StateKey()
keyed_state_key.CopyFrom(state_key)
keyed_state_key.multimap_side_input.key = (
key_coder_impl.encode_nested(key))
cache[key] = _StateBackedIterable(
state_handler, keyed_state_key, value_coder)
return cache[key]
def __reduce__(self):
# TODO(robertwb): Figure out how to support this.
raise TypeError(common_urns.side_inputs.MULTIMAP.urn)
raw_view = MultiMap()
else:
raise ValueError("Unknown access pattern: '%s'" % access_pattern)
self._cache[target_window] = self._side_input_data.view_fn(raw_view)
return self._cache[target_window]
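# Illustrative flow: for a multimap side input (e.g. beam.pvalue.AsDict),
# __getitem__ above first maps the main-input window to the side-input
# window, then either bulk-reads all key-value pairs over the
# multimap_keys_values_side_input protocol (when the runner advertises that
# capability) or falls back to lazy per-key point lookups, caching results
# either way.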
def is_globally_windowed(self):
# type: () -> bool
return (
self._side_input_data.window_mapping_fn ==
sideinputs._global_window_mapping_fn)
def reset(self):
# type: () -> None
# TODO(BEAM-5428): Cross-bundle caching respecting cache tokens.
self._cache = {}
class ReadModifyWriteRuntimeState(userstate.ReadModifyWriteRuntimeState):
def __init__(self, underlying_bag_state):
self._underlying_bag_state = underlying_bag_state
def read(self): # type: () -> Any
values = list(self._underlying_bag_state.read())
if not values:
return None
return values[0]
def write(self, value): # type: (Any) -> None
self.clear()
self._underlying_bag_state.add(value)
def clear(self): # type: () -> None
self._underlying_bag_state.clear()
def commit(self): # type: () -> None
self._underlying_bag_state.commit()
class CombiningValueRuntimeState(userstate.CombiningValueRuntimeState):
def __init__(self, underlying_bag_state, combinefn):
# type: (userstate.AccumulatingRuntimeState, core.CombineFn) -> None
self._combinefn = combinefn
self._combinefn.setup()
self._underlying_bag_state = underlying_bag_state
self._finalized = False
def _read_accumulator(self, rewrite=True):
merged_accumulator = self._combinefn.merge_accumulators(
self._underlying_bag_state.read())
if rewrite:
self._underlying_bag_state.clear()
self._underlying_bag_state.add(merged_accumulator)
return merged_accumulator
def read(self):
# type: () -> Iterable[Any]
return self._combinefn.extract_output(self._read_accumulator())
def add(self, value):
# type: (Any) -> None
# Prefer blind writes, but don't let them grow unboundedly.
# This should be tuned to be much lower, but for now exercise
# both paths well.
if random.random() < 0.5:
accumulator = self._read_accumulator(False)
self._underlying_bag_state.clear()
else:
accumulator = self._combinefn.create_accumulator()
self._underlying_bag_state.add(
self._combinefn.add_input(accumulator, value))
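# Example: with a summing CombineFn, add(3) either first merges the stored
# accumulators into one (the random() < 0.5 branch) or blindly appends
# add_input(create_accumulator(), 3) as an extra accumulator; read() merges
# everything before extract_output, so both paths yield the same value.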
def clear(self):
# type: () -> None
self._underlying_bag_state.clear()
def commit(self):
self._underlying_bag_state.commit()
def finalize(self):
if not self._finalized:
self._combinefn.teardown()
self._finalized = True
class _ConcatIterable(object):
"""An iterable that is the concatination of two iterables.
Unlike itertools.chain, this allows reiteration.
"""
def __init__(self, first, second):
# type: (Iterable[Any], Iterable[Any]) -> None
self.first = first
self.second = second
def __iter__(self):
# type: () -> Iterator[Any]
for elem in self.first:
yield elem
for elem in self.second:
yield elem
coder_impl.FastPrimitivesCoderImpl.register_iterable_like_type(_ConcatIterable)
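# Example: unlike itertools.chain, _ConcatIterable supports reiteration:
#
#   c = _ConcatIterable([1, 2], [3])
#   assert list(c) == [1, 2, 3] == list(c)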
class SynchronousBagRuntimeState(userstate.BagRuntimeState):
def __init__(self,
state_handler, # type: sdk_worker.CachingStateHandler
state_key, # type: beam_fn_api_pb2.StateKey
value_coder # type: coders.Coder
):
# type: (...) -> None
self._state_handler = state_handler
self._state_key = state_key
self._value_coder = value_coder
self._cleared = False
self._added_elements = [] # type: List[Any]
def read(self):
# type: () -> Iterable[Any]
return _ConcatIterable([] if self._cleared else cast(
'Iterable[Any]',
_StateBackedIterable(
self._state_handler, self._state_key, self._value_coder)),
self._added_elements)
def add(self, value):
# type: (Any) -> None
self._added_elements.append(value)
def clear(self):
# type: () -> None
self._cleared = True
self._added_elements = []
def commit(self):
# type: () -> None
to_await = None
if self._cleared:
to_await = self._state_handler.clear(self._state_key)
if self._added_elements:
to_await = self._state_handler.extend(
self._state_key, self._value_coder.get_impl(), self._added_elements)
if to_await:
# To commit, we need to wait on the last state request future to complete.
to_await.get()
class SynchronousSetRuntimeState(userstate.SetRuntimeState):
def __init__(self,
state_handler, # type: sdk_worker.CachingStateHandler
state_key, # type: beam_fn_api_pb2.StateKey
value_coder # type: coders.Coder
):
# type: (...) -> None
self._state_handler = state_handler
self._state_key = state_key
self._value_coder = value_coder
self._cleared = False
self._added_elements = set() # type: Set[Any]
def _compact_data(self, rewrite=True):
accumulator = set(
_ConcatIterable(
set() if self._cleared else _StateBackedIterable(
self._state_handler, self._state_key, self._value_coder),
self._added_elements))
if rewrite and accumulator:
self._state_handler.clear(self._state_key)
self._state_handler.extend(
self._state_key, self._value_coder.get_impl(), accumulator)
# Since everything is already committed, we can safely reinitialize
# added_elements here.
self._added_elements = set()
return accumulator
def read(self):
# type: () -> Set[Any]
return self._compact_data(rewrite=False)
def add(self, value):
# type: (Any) -> None
if self._cleared:
# This is a good time to explicitly clear.
self._state_handler.clear(self._state_key)
self._cleared = False
self._added_elements.add(value)
if random.random() > 0.5:
self._compact_data()
def clear(self):
# type: () -> None
self._cleared = True
self._added_elements = set()
def commit(self):
# type: () -> None
to_await = None
if self._cleared:
to_await = self._state_handler.clear(self._state_key)
if self._added_elements:
to_await = self._state_handler.extend(
self._state_key, self._value_coder.get_impl(), self._added_elements)
if to_await:
# To commit, we need to wait on the last state request future to complete.
to_await.get()
class OutputTimer(userstate.BaseTimer):
def __init__(self,
key,
window, # type: BoundedWindow
timestamp, # type: timestamp.Timestamp
paneinfo, # type: windowed_value.PaneInfo
time_domain, # type: str
timer_family_id, # type: str
timer_coder_impl, # type: coder_impl.TimerCoderImpl
output_stream # type: data_plane.ClosableOutputStream
):
self._key = key
self._window = window
self._input_timestamp = timestamp
self._paneinfo = paneinfo
self._time_domain = time_domain
self._timer_family_id = timer_family_id
self._output_stream = output_stream
self._timer_coder_impl = timer_coder_impl
def set(self, ts: timestamp.TimestampTypes, dynamic_timer_tag='') -> None:
ts = timestamp.Timestamp.of(ts)
timer = userstate.Timer(
user_key=self._key,
dynamic_timer_tag=dynamic_timer_tag,
windows=(self._window, ),
clear_bit=False,
fire_timestamp=ts,
hold_timestamp=ts if TimeDomain.is_event_time(self._time_domain) else
self._input_timestamp,
paneinfo=self._paneinfo)
self._timer_coder_impl.encode_to_stream(timer, self._output_stream, True)
self._output_stream.maybe_flush()
def clear(self, dynamic_timer_tag='') -> None:
timer = userstate.Timer(
user_key=self._key,
dynamic_timer_tag=dynamic_timer_tag,
windows=(self._window, ),
clear_bit=True,
fire_timestamp=None,
hold_timestamp=None,
paneinfo=None)
self._timer_coder_impl.encode_to_stream(timer, self._output_stream, True)
self._output_stream.maybe_flush()
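# Example (hypothetical): inside a timer callback, timer.set(Timestamp(100))
# emits a userstate.Timer with fire_timestamp=100 on the timer output
# stream; for an event-time timer the hold timestamp is also 100, while for
# a processing-time timer the hold stays at the input element's timestamp,
# as computed in set() above.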
class TimerInfo(object):
"""A data class to store information related to a timer."""
def __init__(self, timer_coder_impl, output_stream=None):
self.timer_coder_impl = timer_coder_impl
self.output_stream = output_stream
class FnApiUserStateContext(userstate.UserStateContext):
"""Interface for state and timers from SDK to Fn API servicer of state.."""
def __init__(self,
state_handler, # type: sdk_worker.CachingStateHandler
transform_id, # type: str
key_coder, # type: coders.Coder
window_coder, # type: coders.Coder
):
# type: (...) -> None
"""Initialize a ``FnApiUserStateContext``.
Args:
state_handler: A StateServicer object.
transform_id: The name of the PTransform this context is associated with.
key_coder: Coder for the key type.
window_coder: Coder for the window type.
"""
self._state_handler = state_handler
self._transform_id = transform_id
self._key_coder = key_coder
self._window_coder = window_coder
# A mapping of {timer_family_id: TimerInfo}
self._timers_info = {} # type: Dict[str, TimerInfo]
self._all_states = {} # type: Dict[tuple, FnApiUserRuntimeStateTypes]
def add_timer_info(self, timer_family_id, timer_info):
# type: (str, TimerInfo) -> None
self._timers_info[timer_family_id] = timer_info
def get_timer(
self, timer_spec: userstate.TimerSpec, key, window, timestamp,
pane) -> OutputTimer:
assert self._timers_info[timer_spec.name].output_stream is not None
timer_coder_impl = self._timers_info[timer_spec.name].timer_coder_impl
output_stream = self._timers_info[timer_spec.name].output_stream
return OutputTimer(
key,
window,
timestamp,
pane,
timer_spec.time_domain,
timer_spec.name,
timer_coder_impl,
output_stream)
def get_state(self, *args):
# type: (*Any) -> FnApiUserRuntimeStateTypes
state_handle = self._all_states.get(args)
if state_handle is None:
state_handle = self._all_states[args] = self._create_state(*args)
return state_handle
def _create_state(self,
state_spec, # type: userstate.StateSpec
key,
window # type: BoundedWindow
):
# type: (...) -> FnApiUserRuntimeStateTypes
if isinstance(state_spec,
(userstate.BagStateSpec,
userstate.CombiningValueStateSpec,
userstate.ReadModifyWriteStateSpec)):
bag_state = SynchronousBagRuntimeState(
self._state_handler,
state_key=beam_fn_api_pb2.StateKey(
bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
transform_id=self._transform_id,
user_state_id=state_spec.name,
window=self._window_coder.encode(window),
# State keys are expected in nested encoding format
key=self._key_coder.encode_nested(key))),
value_coder=state_spec.coder)
if isinstance(state_spec, userstate.BagStateSpec):
return bag_state
elif isinstance(state_spec, userstate.ReadModifyWriteStateSpec):
return ReadModifyWriteRuntimeState(bag_state)
else:
return CombiningValueRuntimeState(
bag_state, copy.deepcopy(state_spec.combine_fn))
elif isinstance(state_spec, userstate.SetStateSpec):
return SynchronousSetRuntimeState(
self._state_handler,
state_key=beam_fn_api_pb2.StateKey(
bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
transform_id=self._transform_id,
user_state_id=state_spec.name,
window=self._window_coder.encode(window),
# State keys are expected in nested encoding format
key=self._key_coder.encode_nested(key))),
value_coder=state_spec.coder)
else:
raise NotImplementedError(state_spec)
def commit(self):
# type: () -> None
for state in self._all_states.values():
state.commit()
def reset(self):
# type: () -> None
for state in self._all_states.values():
state.finalize()
self._all_states = {}
def memoize(func):
cache = {}
missing = object()
def wrapper(*args):
result = cache.get(args, missing)
if result is missing:
result = cache[args] = func(*args)
return result
return wrapper
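# memoize caches by the positional-argument tuple, so in
# create_execution_tree below get_operation(transform_id) builds each
# operation exactly once and later calls return the same instance.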
def only_element(iterable):
# type: (Iterable[T]) -> T
element, = iterable
return element
def _environments_compatible(submission, runtime):
# type: (str, str) -> bool
if submission == runtime:
return True
if 'rc' in submission and runtime in submission:
# TODO(https://github.com/apache/beam/issues/28084): Loosen
# the check for RCs until RC containers install the matching version.
return True
return False
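# Example: a submission capability such as
# 'beam:version:sdk_base:apache/beam_python3.10_sdk:2.50.0rc1' (string
# illustrative) is considered compatible with the runtime capability
# 'beam:version:sdk_base:apache/beam_python3.10_sdk:2.50.0', since the
# runtime version is a substring of the RC submission version.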
def _verify_descriptor_created_in_a_compatible_env(process_bundle_descriptor):
# type: (beam_fn_api_pb2.ProcessBundleDescriptor) -> None
runtime_sdk = environments.sdk_base_version_capability()
for t in process_bundle_descriptor.transforms.values():
env = process_bundle_descriptor.environments[t.environment_id]
for c in env.capabilities:
if (c.startswith(environments.SDK_VERSION_CAPABILITY_PREFIX) and
not _environments_compatible(c, runtime_sdk)):
raise RuntimeError(
"Pipeline construction environment and pipeline runtime "
"environment are not compatible. If you use a custom "
"container image, check that the Python interpreter minor version "
"and the Apache Beam version in your image match the versions "
"used at pipeline construction time. "
f"Submission environment: {c}. "
f"Runtime environment: {runtime_sdk}.")
# TODO: Consider warning on mismatches in versions of installed packages.
class BundleProcessor(object):
""" A class for processing bundles of elements. """
def __init__(self,
runner_capabilities, # type: FrozenSet[str]
process_bundle_descriptor, # type: beam_fn_api_pb2.ProcessBundleDescriptor
state_handler, # type: sdk_worker.CachingStateHandler
data_channel_factory, # type: data_plane.DataChannelFactory
data_sampler=None, # type: Optional[data_sampler.DataSampler]
):
# type: (...) -> None
"""Initialize a bundle processor.
Args:
runner_capabilities (``FrozenSet[str]``): The set of capabilities of the
runner with which we will be interacting
process_bundle_descriptor (``beam_fn_api_pb2.ProcessBundleDescriptor``):
a description of the stage that this ``BundleProcessor`` is to execute.
state_handler (CachingStateHandler).
data_channel_factory (``data_plane.DataChannelFactory``).
"""
self.runner_capabilities = runner_capabilities
self.process_bundle_descriptor = process_bundle_descriptor
self.state_handler = state_handler
self.data_channel_factory = data_channel_factory
self.data_sampler = data_sampler
self.current_instruction_id = None # type: Optional[str]
_verify_descriptor_created_in_a_compatible_env(process_bundle_descriptor)
# There is no guarantee that the runner only sets
# timer_api_service_descriptor when there are timers. So this field cannot
# be used as an indicator of timers.
if self.process_bundle_descriptor.timer_api_service_descriptor.url:
self.timer_data_channel = (
data_channel_factory.create_data_channel_from_url(
self.process_bundle_descriptor.timer_api_service_descriptor.url))
else:
self.timer_data_channel = None
# A mapping of
# {(transform_id, timer_family_id): TimerInfo}
# The mapping is empty when there are no timer_family_specs in the
# ProcessBundleDescriptor.
self.timers_info = {} # type: Dict[Tuple[str, str], TimerInfo]
# TODO(robertwb): Figure out the correct prefix to use for output counters
# from StateSampler.
self.counter_factory = counters.CounterFactory()
self.state_sampler = statesampler.StateSampler(
'fnapi-step-%s' % self.process_bundle_descriptor.id,
self.counter_factory)
self.ops = self.create_execution_tree(self.process_bundle_descriptor)
for op in reversed(self.ops.values()):
op.setup(self.data_sampler)
self.splitting_lock = threading.Lock()
def create_execution_tree(
self,
descriptor # type: beam_fn_api_pb2.ProcessBundleDescriptor
):
# type: (...) -> collections.OrderedDict[str, operations.DoOperation]
transform_factory = BeamTransformFactory(
self.runner_capabilities,
descriptor,
self.data_channel_factory,
self.counter_factory,
self.state_sampler,
self.state_handler,
self.data_sampler,
)
self.timers_info = transform_factory.extract_timers_info()
def is_side_input(transform_proto, tag):
if transform_proto.spec.urn == common_urns.primitives.PAR_DO.urn:
return tag in proto_utils.parse_Bytes(
transform_proto.spec.payload,
beam_runner_api_pb2.ParDoPayload).side_inputs
pcoll_consumers = collections.defaultdict(
list) # type: DefaultDict[str, List[str]]
for transform_id, transform_proto in descriptor.transforms.items():
for tag, pcoll_id in transform_proto.inputs.items():
if not is_side_input(transform_proto, tag):
pcoll_consumers[pcoll_id].append(transform_id)
@memoize
def get_operation(transform_id):
# type: (str) -> operations.Operation
transform_consumers = {
tag: [get_operation(op) for op in pcoll_consumers[pcoll_id]]
for tag,
pcoll_id in descriptor.transforms[transform_id].outputs.items()
}
# Initialize transform-specific state in the Data Sampler.
if self.data_sampler:
self.data_sampler.initialize_samplers(
transform_id, descriptor, transform_factory.get_coder)
return transform_factory.create_operation(
transform_id, transform_consumers)
# Operations must be started (hence returned) in order.
@memoize
def topological_height(transform_id):
# type: (str) -> int
return 1 + max([0] + [
topological_height(consumer)
for pcoll in descriptor.transforms[transform_id].outputs.values()
for consumer in pcoll_consumers[pcoll]
])
return collections.OrderedDict([(
transform_id,
cast(operations.DoOperation,
get_operation(transform_id))) for transform_id in sorted(
descriptor.transforms, key=topological_height, reverse=True)])
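# Worked example: for a linear stage Read -> ParDo -> Write,
# topological_height is 3, 2 and 1 respectively, so the OrderedDict above
# lists operations upstream-first; process_bundle then starts them in
# reverse (downstream-first) so every consumer exists before its producer
# emits elements.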
def reset(self):
# type: () -> None
self.counter_factory.reset()
self.state_sampler.reset()
# Side input caches.
for op in self.ops.values():
op.reset()
def process_bundle(self, instruction_id):
# type: (str) -> Tuple[List[beam_fn_api_pb2.DelayedBundleApplication], bool]
expected_input_ops = [] # type: List[DataInputOperation]
for op in self.ops.values():
if isinstance(op, DataOutputOperation):
# TODO(robertwb): Is there a better way to pass the instruction id to
# the operation?
op.set_output_stream(
op.data_channel.output_stream(instruction_id, op.transform_id))
elif isinstance(op, DataInputOperation):
# We must wait until we receive "end of stream" for each of these ops.
expected_input_ops.append(op)
try:
execution_context = ExecutionContext(instruction_id=instruction_id)
self.current_instruction_id = instruction_id
self.state_sampler.start()
# Start all operations.
for op in reversed(self.ops.values()):
_LOGGER.debug('start %s', op)
op.execution_context = execution_context
op.start()
# Each data_channel is mapped to a list of expected inputs which includes
# both data input and timer input. The data input is identified by
# transform_id. The timer input is identified by
# (transform_id, timer_family_id).
data_channels = collections.defaultdict(
list
) # type: DefaultDict[data_plane.DataChannel, List[Union[str, Tuple[str, str]]]]
# Add expected data inputs for each data channel.
input_op_by_transform_id = {}
for input_op in expected_input_ops:
data_channels[input_op.data_channel].append(input_op.transform_id)
input_op_by_transform_id[input_op.transform_id] = input_op
# Update timer_data channel with expected timer inputs.
if self.timer_data_channel:
data_channels[self.timer_data_channel].extend(
list(self.timers_info.keys()))
# Set up timer output stream for DoOperation.
for ((transform_id, timer_family_id),
timer_info) in self.timers_info.items():
output_stream = self.timer_data_channel.output_timer_stream(
instruction_id, transform_id, timer_family_id)
timer_info.output_stream = output_stream
self.ops[transform_id].add_timer_info(timer_family_id, timer_info)
# Process data and timer inputs
for data_channel, expected_inputs in data_channels.items():
for element in data_channel.input_elements(instruction_id,
expected_inputs):
if isinstance(element, beam_fn_api_pb2.Elements.Timers):
timer_coder_impl = (
self.timers_info[(
element.transform_id,
element.timer_family_id)].timer_coder_impl)
for timer_data in timer_coder_impl.decode_all(element.timers):
self.ops[element.transform_id].process_timer(
element.timer_family_id, timer_data)
elif isinstance(element, beam_fn_api_pb2.Elements.Data):
input_op_by_transform_id[element.transform_id].process_encoded(
element.data)
# Finish all operations.
for op in self.ops.values():
_LOGGER.debug('finish %s', op)
op.finish()
# Close every timer output stream
for timer_info in self.timers_info.values():
assert timer_info.output_stream is not None
timer_info.output_stream.close()
return ([
self.delayed_bundle_application(op, residual) for op,
residual in execution_context.delayed_applications
],
self.requires_finalization())
finally:
# Ensure any in-flight split attempts complete.
with self.splitting_lock:
self.current_instruction_id = None
self.state_sampler.stop_if_still_running()
def finalize_bundle(self):
# type: () -> beam_fn_api_pb2.FinalizeBundleResponse
for op in self.ops.values():
op.finalize_bundle()
return beam_fn_api_pb2.FinalizeBundleResponse()
def requires_finalization(self):
# type: () -> bool
return any(op.needs_finalization() for op in self.ops.values())
def try_split(self, bundle_split_request):
# type: (beam_fn_api_pb2.ProcessBundleSplitRequest) -> beam_fn_api_pb2.ProcessBundleSplitResponse
split_response = beam_fn_api_pb2.ProcessBundleSplitResponse()
with self.splitting_lock:
if bundle_split_request.instruction_id != self.current_instruction_id:
# This may be a delayed split for a former bundle, see BEAM-12475.
return split_response
for op in self.ops.values():
if isinstance(op, DataInputOperation):
desired_split = bundle_split_request.desired_splits.get(
op.transform_id)
if desired_split:
split = op.try_split(
desired_split.fraction_of_remainder,
desired_split.estimated_input_elements,
desired_split.allowed_split_points)
if split:
(
primary_end,
element_primaries,
element_residuals,
residual_start,
) = split
for element_primary in element_primaries:
split_response.primary_roots.add().CopyFrom(
self.bundle_application(*element_primary))
for element_residual in element_residuals:
split_response.residual_roots.add().CopyFrom(
self.delayed_bundle_application(*element_residual))
split_response.channel_splits.extend([
beam_fn_api_pb2.ProcessBundleSplitResponse.ChannelSplit(
transform_id=op.transform_id,
last_primary_element=primary_end,
first_residual_element=residual_start)
])
return split_response
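# For example, a DataInputOperation split of (5, [], [], 6) becomes a
# ChannelSplit telling the runner that element 5 is the last primary
# element on this data channel and element 6 is the first residual one.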
def delayed_bundle_application(self,
op, # type: operations.DoOperation
deferred_remainder # type: SplitResultResidual
):
# type: (...) -> beam_fn_api_pb2.DelayedBundleApplication
assert op.input_info is not None
# TODO(SDF): For non-root nodes, need main_input_coder + residual_coder.
(element_and_restriction, current_watermark, deferred_timestamp) = (
deferred_remainder)
if deferred_timestamp:
assert isinstance(deferred_timestamp, timestamp.Duration)
proto_deferred_watermark = proto_utils.from_micros(
duration_pb2.Duration,
deferred_timestamp.micros) # type: Optional[duration_pb2.Duration]
else:
proto_deferred_watermark = None
return beam_fn_api_pb2.DelayedBundleApplication(
requested_time_delay=proto_deferred_watermark,
application=self.construct_bundle_application(
op.input_info, current_watermark, element_and_restriction))
def bundle_application(self,
op, # type: operations.DoOperation
primary # type: SplitResultPrimary
):
# type: (...) -> beam_fn_api_pb2.BundleApplication
assert op.input_info is not None
return self.construct_bundle_application(
op.input_info, None, primary.primary_value)
def construct_bundle_application(self,
op_input_info, # type: operations.OpInputInfo
output_watermark, # type: Optional[timestamp.Timestamp]
element
):
# type: (...) -> beam_fn_api_pb2.BundleApplication
transform_id, main_input_tag, main_input_coder, outputs = op_input_info
if output_watermark:
proto_output_watermark = proto_utils.from_micros(
timestamp_pb2.Timestamp, output_watermark.micros)
output_watermarks = {
output: proto_output_watermark
for output in outputs
} # type: Optional[Dict[str, timestamp_pb2.Timestamp]]
else:
output_watermarks = None
return beam_fn_api_pb2.BundleApplication(
transform_id=transform_id,
input_id=main_input_tag,
output_watermarks=output_watermarks,
element=main_input_coder.get_impl().encode_nested(element))
def monitoring_infos(self):
# type: () -> List[metrics_pb2.MonitoringInfo]
"""Returns the list of MonitoringInfos collected processing this bundle."""
# Construct a new dict first to remove duplicates.
all_monitoring_infos_dict = {}
for transform_id, op in self.ops.items():
tag_to_pcollection_id = self.process_bundle_descriptor.transforms[
transform_id].outputs
all_monitoring_infos_dict.update(
op.monitoring_infos(transform_id, dict(tag_to_pcollection_id)))
return list(all_monitoring_infos_dict.values())
def shutdown(self):
# type: () -> None
for op in self.ops.values():
op.teardown()
@dataclass
class ExecutionContext:
# Any splits to be processed later.
delayed_applications: List[Tuple[operations.DoOperation,
common.SplitResultResidual]] = field(
default_factory=list)
# The exception sampler for the currently executing PTransform.
output_sampler: Optional[data_sampler.OutputSampler] = None
# The current instruction being executed.
instruction_id: Optional[str] = None
class BeamTransformFactory(object):
"""Factory for turning transform_protos into executable operations."""
def __init__(self,
runner_capabilities, # type: FrozenSet[str]
descriptor, # type: beam_fn_api_pb2.ProcessBundleDescriptor
data_channel_factory, # type: data_plane.DataChannelFactory
counter_factory, # type: counters.CounterFactory
state_sampler, # type: statesampler.StateSampler
state_handler, # type: sdk_worker.CachingStateHandler
data_sampler, # type: Optional[data_sampler.DataSampler]
):
self.runner_capabilities = runner_capabilities
self.descriptor = descriptor
self.data_channel_factory = data_channel_factory
self.counter_factory = counter_factory
self.state_sampler = state_sampler
self.state_handler = state_handler
self.context = pipeline_context.PipelineContext(
descriptor,
iterable_state_read=lambda token,
element_coder_impl: _StateBackedIterable(
state_handler,
beam_fn_api_pb2.StateKey(
runner=beam_fn_api_pb2.StateKey.Runner(key=token)),
element_coder_impl))
self.data_sampler = data_sampler
_known_urns = {
} # type: Dict[str, Tuple[ConstructorFn, Union[Type[message.Message], Type[bytes], None]]]
@classmethod
def register_urn(
cls,
urn, # type: str
parameter_type # type: Optional[Type[T]]
):
# type: (...) -> Callable[[Callable[[BeamTransformFactory, str, beam_runner_api_pb2.PTransform, T, Dict[str, List[operations.Operation]]], operations.Operation]], Callable[[BeamTransformFactory, str, beam_runner_api_pb2.PTransform, T, Dict[str, List[operations.Operation]]], operations.Operation]]
def wrapper(func):
cls._known_urns[urn] = func, parameter_type
return func
return wrapper
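# Example (hypothetical URN): a constructor registered via
#
#   @BeamTransformFactory.register_urn(
#       'my:custom:urn:v1', beam_runner_api_pb2.ParDoPayload)
#   def create_my_operation(
#       factory, transform_id, transform_proto, payload, consumers):
#     ...
#
# is looked up by create_operation below whenever a transform with that
# spec.urn appears in the ProcessBundleDescriptor.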
def create_operation(self,
transform_id, # type: str
consumers # type: Dict[str, List[operations.Operation]]
):
# type: (...) -> operations.Operation
transform_proto = self.descriptor.transforms[transform_id]
if not transform_proto.unique_name:
_LOGGER.debug("No unique name set for transform %s" % transform_id)
transform_proto.unique_name = transform_id
creator, parameter_type = self._known_urns[transform_proto.spec.urn]
payload = proto_utils.parse_Bytes(
transform_proto.spec.payload, parameter_type)
return creator(self, transform_id, transform_proto, payload, consumers)
def extract_timers_info(self):
# type: () -> Dict[Tuple[str, str], TimerInfo]
timers_info = {}
for transform_id, transform_proto in self.descriptor.transforms.items():
if transform_proto.spec.urn == common_urns.primitives.PAR_DO.urn:
pardo_payload = proto_utils.parse_Bytes(
transform_proto.spec.payload, beam_runner_api_pb2.ParDoPayload)
for (timer_family_id,
timer_family_spec) in pardo_payload.timer_family_specs.items():
timer_coder_impl = self.get_coder(
timer_family_spec.timer_family_coder_id).get_impl()
# The output_stream should be updated when processing a bundle.
timers_info[(transform_id, timer_family_id)] = TimerInfo(
timer_coder_impl=timer_coder_impl)
return timers_info
def get_coder(self, coder_id):
# type: (str) -> coders.Coder
if coder_id not in self.descriptor.coders:
raise KeyError("No such coder: %s" % coder_id)
coder_proto = self.descriptor.coders[coder_id]
if coder_proto.spec.urn:
return self.context.coders.get_by_id(coder_id)
else:
# No URN, assume cloud object encoding json bytes.
return operation_specs.get_coder_from_spec(
json.loads(coder_proto.spec.payload.decode('utf-8')))
def get_windowed_coder(self, pcoll_id):
# type: (str) -> WindowedValueCoder
coder = self.get_coder(self.descriptor.pcollections[pcoll_id].coder_id)
# TODO(robertwb): Remove this condition once all runners are consistent.
if not isinstance(coder, WindowedValueCoder):
windowing_strategy = self.descriptor.windowing_strategies[
self.descriptor.pcollections[pcoll_id].windowing_strategy_id]
return WindowedValueCoder(
coder, self.get_coder(windowing_strategy.window_coder_id))
else:
return coder
def get_output_coders(self, transform_proto):
# type: (beam_runner_api_pb2.PTransform) -> Dict[str, coders.Coder]
return {
tag: self.get_windowed_coder(pcoll_id)
for tag,
pcoll_id in transform_proto.outputs.items()
}
def get_only_output_coder(self, transform_proto):
# type: (beam_runner_api_pb2.PTransform) -> coders.Coder
return only_element(self.get_output_coders(transform_proto).values())
def get_input_coders(self, transform_proto):
# type: (beam_runner_api_pb2.PTransform) -> Dict[str, coders.WindowedValueCoder]
return {
tag: self.get_windowed_coder(pcoll_id)
for tag,
pcoll_id in transform_proto.inputs.items()
}
def get_only_input_coder(self, transform_proto):
# type: (beam_runner_api_pb2.PTransform) -> coders.Coder
return only_element(list(self.get_input_coders(transform_proto).values()))
def get_input_windowing(self, transform_proto):
# type: (beam_runner_api_pb2.PTransform) -> Windowing
pcoll_id = only_element(transform_proto.inputs.values())
windowing_strategy_id = self.descriptor.pcollections[
pcoll_id].windowing_strategy_id
return self.context.windowing_strategies.get_by_id(windowing_strategy_id)
# TODO(robertwb): Update all operations to take these in the constructor.
@staticmethod
def augment_oldstyle_op(
op, # type: OperationT
step_name, # type: str
consumers, # type: Mapping[str, Iterable[operations.Operation]]
tag_list=None # type: Optional[List[str]]
):
# type: (...) -> OperationT
op.step_name = step_name
for tag, op_consumers in consumers.items():
for consumer in op_consumers:
op.add_receiver(consumer, tag_list.index(tag) if tag_list else 0)
return op
@BeamTransformFactory.register_urn(
DATA_INPUT_URN, beam_fn_api_pb2.RemoteGrpcPort)
def create_source_runner(
factory, # type: BeamTransformFactory
transform_id, # type: str
transform_proto, # type: beam_runner_api_pb2.PTransform
grpc_port, # type: beam_fn_api_pb2.RemoteGrpcPort
consumers # type: Dict[str, List[operations.Operation]]
):
# type: (...) -> DataInputOperation
output_coder = factory.get_coder(grpc_port.coder_id)
return DataInputOperation(
common.NameContext(transform_proto.unique_name, transform_id),
transform_proto.unique_name,
consumers,
factory.counter_factory,
factory.state_sampler,
output_coder,
transform_id=transform_id,
data_channel=factory.data_channel_factory.create_data_channel(grpc_port))
@BeamTransformFactory.register_urn(
DATA_OUTPUT_URN, beam_fn_api_pb2.RemoteGrpcPort)
def create_sink_runner(
factory, # type: BeamTransformFactory
transform_id, # type: str
transform_proto, # type: beam_runner_api_pb2.PTransform
grpc_port, # type: beam_fn_api_pb2.RemoteGrpcPort
consumers # type: Dict[str, List[operations.Operation]]
):
# type: (...) -> DataOutputOperation
output_coder = factory.get_coder(grpc_port.coder_id)
return DataOutputOperation(
common.NameContext(transform_proto.unique_name, transform_id),
transform_proto.unique_name,
consumers,
factory.counter_factory,
factory.state_sampler,
output_coder,
transform_id=transform_id,
data_channel=factory.data_channel_factory.create_data_channel(grpc_port))
@BeamTransformFactory.register_urn(OLD_DATAFLOW_RUNNER_HARNESS_READ_URN, None)
def create_source_java(
factory, # type: BeamTransformFactory
transform_id, # type: str
transform_proto, # type: beam_runner_api_pb2.PTransform
parameter,
consumers # type: Dict[str, List[operations.Operation]]
):
# type: (...) -> operations.ReadOperation
# The Dataflow runner harness strips the base64 encoding.
source = pickler.loads(base64.b64encode(parameter))
spec = operation_specs.WorkerRead(
iobase.SourceBundle(1.0, source, None, None),
[factory.get_only_output_coder(transform_proto)])
return factory.augment_oldstyle_op(
operations.ReadOperation(
common.NameContext(transform_proto.unique_name, transform_id),
spec,
factory.counter_factory,
factory.state_sampler),
transform_proto.unique_name,
consumers)
@BeamTransformFactory.register_urn(
common_urns.deprecated_primitives.READ.urn, beam_runner_api_pb2.ReadPayload)
def create_deprecated_read(
factory, # type: BeamTransformFactory
transform_id, # type: str
transform_proto, # type: beam_runner_api_pb2.PTransform
parameter, # type: beam_runner_api_pb2.ReadPayload
consumers # type: Dict[str, List[operations.Operation]]
):
# type: (...) -> operations.ReadOperation
source = iobase.BoundedSource.from_runner_api(
parameter.source, factory.context)
spec = operation_specs.WorkerRead(
iobase.SourceBundle(1.0, source, None, None),
[WindowedValueCoder(source.default_output_coder())])
return factory.augment_oldstyle_op(
operations.ReadOperation(
common.NameContext(transform_proto.unique_name, transform_id),
spec,
factory.counter_factory,
factory.state_sampler),
transform_proto.unique_name,
consumers)
@BeamTransformFactory.register_urn(
python_urns.IMPULSE_READ_TRANSFORM, beam_runner_api_pb2.ReadPayload)
def create_read_from_impulse_python(
factory, # type: BeamTransformFactory
transform_id, # type: str
transform_proto, # type: beam_runner_api_pb2.PTransform
parameter, # type: beam_runner_api_pb2.ReadPayload
consumers # type: Dict[str, List[operations.Operation]]
):
# type: (...) -> operations.ImpulseReadOperation
return operations.ImpulseReadOperation(
common.NameContext(transform_proto.unique_name, transform_id),
factory.counter_factory,
factory.state_sampler,
consumers,
iobase.BoundedSource.from_runner_api(parameter.source, factory.context),
factory.get_only_output_coder(transform_proto))
@BeamTransformFactory.register_urn(OLD_DATAFLOW_RUNNER_HARNESS_PARDO_URN, None)
def create_dofn_javasdk(
factory, # type: BeamTransformFactory
transform_id, # type: str
transform_proto, # type: beam_runner_api_pb2.PTransform
serialized_fn,
consumers # type: Dict[str, List[operations.Operation]]
):
return _create_pardo_operation(
factory, transform_id, transform_proto, consumers, serialized_fn)
@BeamTransformFactory.register_urn(
common_urns.sdf_components.PAIR_WITH_RESTRICTION.urn,
beam_runner_api_pb2.ParDoPayload)
def create_pair_with_restriction(*args):
class PairWithRestriction(beam.DoFn):
def __init__(self, fn, restriction_provider, watermark_estimator_provider):
self.restriction_provider = restriction_provider
self.watermark_estimator_provider = watermark_estimator_provider
def process(self, element, *args, **kwargs):
# TODO(SDF): Do we want to allow mutation of the element?
# (E.g. it could be nice to shift bulky description to the portion
# that can be distributed.)
initial_restriction = self.restriction_provider.initial_restriction(
element)
initial_estimator_state = (
self.watermark_estimator_provider.initial_estimator_state(
element, initial_restriction))
yield (element, (initial_restriction, initial_estimator_state))
return _create_sdf_operation(PairWithRestriction, *args)
@BeamTransformFactory.register_urn(
common_urns.sdf_components.SPLIT_AND_SIZE_RESTRICTIONS.urn,
beam_runner_api_pb2.ParDoPayload)
def create_split_and_size_restrictions(*args):
class SplitAndSizeRestrictions(beam.DoFn):
def __init__(self, fn, restriction_provider, watermark_estimator_provider):
self.restriction_provider = restriction_provider
self.watermark_estimator_provider = watermark_estimator_provider
def process(self, element_restriction, *args, **kwargs):
element, (restriction, _) = element_restriction
for part, size in self.restriction_provider.split_and_size(
element, restriction):
if size < 0:
raise ValueError('Expected size >= 0 but received %s.' % size)
estimator_state = (
self.watermark_estimator_provider.initial_estimator_state(
element, part))
yield ((element, (part, estimator_state)), size)
return _create_sdf_operation(SplitAndSizeRestrictions, *args)
@BeamTransformFactory.register_urn(
common_urns.sdf_components.TRUNCATE_SIZED_RESTRICTION.urn,
beam_runner_api_pb2.ParDoPayload)
def create_truncate_sized_restriction(*args):
class TruncateAndSizeRestriction(beam.DoFn):
def __init__(self, fn, restriction_provider, watermark_estimator_provider):
self.restriction_provider = restriction_provider
def process(self, element_restriction, *args, **kwargs):
((element, (restriction, estimator_state)), _) = element_restriction
truncated_restriction = self.restriction_provider.truncate(
element, restriction)
if truncated_restriction:
truncated_restriction_size = (
self.restriction_provider.restriction_size(
element, truncated_restriction))
if truncated_restriction_size < 0:
raise ValueError(
'Expected size >= 0 but received %s.' %
truncated_restriction_size)
yield ((element, (truncated_restriction, estimator_state)),
truncated_restriction_size)
return _create_sdf_operation(
TruncateAndSizeRestriction,
*args,
operation_cls=operations.SdfTruncateSizedRestrictions)
@BeamTransformFactory.register_urn(
common_urns.sdf_components.PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS.urn,
beam_runner_api_pb2.ParDoPayload)
def create_process_sized_elements_and_restrictions(
factory, # type: BeamTransformFactory
transform_id, # type: str
transform_proto, # type: beam_runner_api_pb2.PTransform
parameter, # type: beam_runner_api_pb2.ParDoPayload
consumers # type: Dict[str, List[operations.Operation]]
):
return _create_pardo_operation(
factory,
transform_id,
transform_proto,
consumers,
core.DoFnInfo.from_runner_api(parameter.do_fn,
factory.context).serialized_dofn_data(),
parameter,
operation_cls=operations.SdfProcessSizedElements)
def _create_sdf_operation(
proxy_dofn,
factory,
transform_id,
transform_proto,
parameter,
consumers,
operation_cls=operations.DoOperation):
dofn_data = pickler.loads(parameter.do_fn.payload)
dofn = dofn_data[0]
restriction_provider = common.DoFnSignature(dofn).get_restriction_provider()
watermark_estimator_provider = (
common.DoFnSignature(dofn).get_watermark_estimator_provider())
serialized_fn = pickler.dumps(
(proxy_dofn(dofn, restriction_provider, watermark_estimator_provider), ) +
dofn_data[1:])
return _create_pardo_operation(
factory,
transform_id,
transform_proto,
consumers,
serialized_fn,
parameter,
operation_cls=operation_cls)
@BeamTransformFactory.register_urn(
common_urns.primitives.PAR_DO.urn, beam_runner_api_pb2.ParDoPayload)
def create_par_do(
factory, # type: BeamTransformFactory
transform_id, # type: str
transform_proto, # type: beam_runner_api_pb2.PTransform
parameter, # type: beam_runner_api_pb2.ParDoPayload
consumers # type: Dict[str, List[operations.Operation]]
):
# type: (...) -> operations.DoOperation
return _create_pardo_operation(
factory,
transform_id,
transform_proto,
consumers,
core.DoFnInfo.from_runner_api(parameter.do_fn,
factory.context).serialized_dofn_data(),
parameter)
def _create_pardo_operation(
factory, # type: BeamTransformFactory
transform_id, # type: str
transform_proto, # type: beam_runner_api_pb2.PTransform
consumers,
serialized_fn,
pardo_proto=None, # type: Optional[beam_runner_api_pb2.ParDoPayload]
operation_cls=operations.DoOperation
):
if pardo_proto and pardo_proto.side_inputs:
input_tags_to_coders = factory.get_input_coders(transform_proto)
tagged_side_inputs = [
(tag, beam.pvalue.SideInputData.from_runner_api(si, factory.context))
for tag,
si in pardo_proto.side_inputs.items()
]
tagged_side_inputs.sort(
key=lambda tag_si: sideinputs.get_sideinput_index(tag_si[0]))
side_input_maps = [
StateBackedSideInputMap(
factory.state_handler,
transform_id,
tag,
si,
input_tags_to_coders[tag],
use_bulk_read=(
common_urns.runner_protocols.MULTIMAP_KEYS_VALUES_SIDE_INPUT.urn
in factory.runner_capabilities))
for (tag, si) in tagged_side_inputs
]
else:
side_input_maps = []
output_tags = list(transform_proto.outputs.keys())
dofn_data = pickler.loads(serialized_fn)
if not dofn_data[-1]:
# Windowing not set.
if pardo_proto:
other_input_tags = set.union(
set(pardo_proto.side_inputs),
set(pardo_proto.timer_family_specs)) # type: Container[str]
else:
other_input_tags = ()
pcoll_id, = [pcoll for tag, pcoll in transform_proto.inputs.items()
if tag not in other_input_tags]
windowing = factory.context.windowing_strategies.get_by_id(
factory.descriptor.pcollections[pcoll_id].windowing_strategy_id)
serialized_fn = pickler.dumps(dofn_data[:-1] + (windowing, ))
if pardo_proto and (pardo_proto.timer_family_specs or pardo_proto.state_specs
or pardo_proto.restriction_coder_id):
found_input_coder = None
for tag, pcoll_id in transform_proto.inputs.items():
if tag in pardo_proto.side_inputs:
pass
else:
# Must be the main input
assert found_input_coder is None
main_input_tag = tag
found_input_coder = factory.get_windowed_coder(pcoll_id)
assert found_input_coder is not None
main_input_coder = found_input_coder
if pardo_proto.timer_family_specs or pardo_proto.state_specs:
user_state_context = FnApiUserStateContext(
factory.state_handler,
transform_id,
main_input_coder.key_coder(),
main_input_coder.window_coder
) # type: Optional[FnApiUserStateContext]
else:
user_state_context = None
else:
user_state_context = None
output_coders = factory.get_output_coders(transform_proto)
spec = operation_specs.WorkerDoFn(
serialized_fn=serialized_fn,
output_tags=output_tags,
input=None,
side_inputs=None, # Fn API uses proto definitions and the Fn State API
output_coders=[output_coders[tag] for tag in output_tags])
result = factory.augment_oldstyle_op(
operation_cls(
common.NameContext(transform_proto.unique_name, transform_id),
spec,
factory.counter_factory,
factory.state_sampler,
side_input_maps,
user_state_context),
transform_proto.unique_name,
consumers,
output_tags)
if pardo_proto and pardo_proto.restriction_coder_id:
result.input_info = operations.OpInputInfo(
transform_id,
main_input_tag,
main_input_coder,
transform_proto.outputs.keys())
return result
def _create_simple_pardo_operation(
    factory,  # type: BeamTransformFactory
    transform_id,  # type: str
    transform_proto,  # type: beam_runner_api_pb2.PTransform
    consumers,  # type: Dict[str, List[operations.Operation]]
    dofn,  # type: beam.DoFn
):
serialized_fn = pickler.dumps((dofn, (), {}, [], None))
return _create_pardo_operation(
factory, transform_id, transform_proto, consumers, serialized_fn)
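

# Window assignment is executed as an ordinary ParDo: the DoFn below asks
# the windowing strategy's WindowFn to assign new windows to each element
# and re-emits the element into those windows.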
@BeamTransformFactory.register_urn(
common_urns.primitives.ASSIGN_WINDOWS.urn,
beam_runner_api_pb2.WindowingStrategy)
def create_assign_windows(
factory, # type: BeamTransformFactory
transform_id, # type: str
transform_proto, # type: beam_runner_api_pb2.PTransform
parameter, # type: beam_runner_api_pb2.WindowingStrategy
consumers # type: Dict[str, List[operations.Operation]]
):
  # Imported here, ahead of the class definition, so that both names are
  # bound in the enclosing scope by the time process() resolves WindowFn.
  from apache_beam.transforms.core import Windowing
  from apache_beam.transforms.window import WindowFn

  class WindowIntoDoFn(beam.DoFn):
    def __init__(self, windowing):
      self.windowing = windowing

    def process(
        self,
        element,
        timestamp=beam.DoFn.TimestampParam,
        window=beam.DoFn.WindowParam):
      new_windows = self.windowing.windowfn.assign(
          WindowFn.AssignContext(timestamp, element=element, window=window))
      yield WindowedValue(element, timestamp, new_windows)

  windowing = Windowing.from_runner_api(parameter, factory.context)
return _create_simple_pardo_operation(
factory,
transform_id,
transform_proto,
consumers,
WindowIntoDoFn(windowing))
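

# The identity DoFn is executed as a Flatten with a single input: Flatten
# simply forwards each element to its consumers, which is exactly identity
# semantics here.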
@BeamTransformFactory.register_urn(IDENTITY_DOFN_URN, None)
def create_identity_dofn(
factory, # type: BeamTransformFactory
transform_id, # type: str
transform_proto, # type: beam_runner_api_pb2.PTransform
parameter,
consumers # type: Dict[str, List[operations.Operation]]
):
# type: (...) -> operations.FlattenOperation
return factory.augment_oldstyle_op(
operations.FlattenOperation(
common.NameContext(transform_proto.unique_name, transform_id),
operation_specs.WorkerFlatten(
None, [factory.get_only_output_coder(transform_proto)]),
factory.counter_factory,
factory.state_sampler),
transform_proto.unique_name,
consumers)
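

# First phase of a lifted CombinePerKey: run the CombineFn eagerly on this
# worker (a PGBKCVOperation, combining values per key and window in an
# in-memory table) so that less data crosses the subsequent GroupByKey.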
@BeamTransformFactory.register_urn(
common_urns.combine_components.COMBINE_PER_KEY_PRECOMBINE.urn,
beam_runner_api_pb2.CombinePayload)
def create_combine_per_key_precombine(
factory, # type: BeamTransformFactory
transform_id, # type: str
transform_proto, # type: beam_runner_api_pb2.PTransform
payload, # type: beam_runner_api_pb2.CombinePayload
consumers # type: Dict[str, List[operations.Operation]]
):
# type: (...) -> operations.PGBKCVOperation
  serialized_combine_fn = pickler.dumps(
      (beam.CombineFn.from_runner_api(payload.combine_fn, factory.context),
       [], {}))
return factory.augment_oldstyle_op(
operations.PGBKCVOperation(
common.NameContext(transform_proto.unique_name, transform_id),
operation_specs.WorkerPartialGroupByKey(
serialized_combine_fn,
None, [factory.get_only_output_coder(transform_proto)]),
factory.counter_factory,
factory.state_sampler,
factory.get_input_windowing(transform_proto)),
transform_proto.unique_name,
consumers)
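

# The remaining combine phases (merge accumulators, extract outputs,
# convert to accumulators) and the unlifted combine of grouped values all
# delegate to _create_combine_phase_operation below, differing only in the
# phase string passed to the CombineOperation.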
@BeamTransformFactory.register_urn(
common_urns.combine_components.COMBINE_PER_KEY_MERGE_ACCUMULATORS.urn,
beam_runner_api_pb2.CombinePayload)
def create_combine_per_key_merge_accumulators(
factory, # type: BeamTransformFactory
transform_id, # type: str
transform_proto, # type: beam_runner_api_pb2.PTransform
payload, # type: beam_runner_api_pb2.CombinePayload
consumers # type: Dict[str, List[operations.Operation]]
):
return _create_combine_phase_operation(
factory, transform_id, transform_proto, payload, consumers, 'merge')


@BeamTransformFactory.register_urn(
common_urns.combine_components.COMBINE_PER_KEY_EXTRACT_OUTPUTS.urn,
beam_runner_api_pb2.CombinePayload)
def create_combine_per_key_extract_outputs(
factory, # type: BeamTransformFactory
transform_id, # type: str
transform_proto, # type: beam_runner_api_pb2.PTransform
payload, # type: beam_runner_api_pb2.CombinePayload
consumers # type: Dict[str, List[operations.Operation]]
):
return _create_combine_phase_operation(
factory, transform_id, transform_proto, payload, consumers, 'extract')


@BeamTransformFactory.register_urn(
common_urns.combine_components.COMBINE_PER_KEY_CONVERT_TO_ACCUMULATORS.urn,
beam_runner_api_pb2.CombinePayload)
def create_combine_per_key_convert_to_accumulators(
factory, # type: BeamTransformFactory
transform_id, # type: str
transform_proto, # type: beam_runner_api_pb2.PTransform
payload, # type: beam_runner_api_pb2.CombinePayload
consumers # type: Dict[str, List[operations.Operation]]
):
return _create_combine_phase_operation(
factory, transform_id, transform_proto, payload, consumers, 'convert')


@BeamTransformFactory.register_urn(
common_urns.combine_components.COMBINE_GROUPED_VALUES.urn,
beam_runner_api_pb2.CombinePayload)
def create_combine_grouped_values(
factory, # type: BeamTransformFactory
transform_id, # type: str
transform_proto, # type: beam_runner_api_pb2.PTransform
payload, # type: beam_runner_api_pb2.CombinePayload
consumers # type: Dict[str, List[operations.Operation]]
):
return _create_combine_phase_operation(
factory, transform_id, transform_proto, payload, consumers, 'all')
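

# Builds a CombineOperation for a single phase of a lifted CombinePerKey,
# or for the whole combine when phase is 'all'. The CombineFn is
# re-serialized in the (fn, args, kwargs) form the worker expects.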
def _create_combine_phase_operation(
factory, transform_id, transform_proto, payload, consumers, phase):
# type: (...) -> operations.CombineOperation
  serialized_combine_fn = pickler.dumps(
      (beam.CombineFn.from_runner_api(payload.combine_fn, factory.context),
       [], {}))
return factory.augment_oldstyle_op(
operations.CombineOperation(
common.NameContext(transform_proto.unique_name, transform_id),
operation_specs.WorkerCombineFn(
serialized_combine_fn,
phase,
None, [factory.get_only_output_coder(transform_proto)]),
factory.counter_factory,
factory.state_sampler),
transform_proto.unique_name,
consumers)
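

# Flatten needs no user code: the operation forwards every element from
# each of its inputs to the common output. The WorkerFlatten spec carries
# only the output coder; inputs are wired up through the consumers map.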
@BeamTransformFactory.register_urn(common_urns.primitives.FLATTEN.urn, None)
def create_flatten(
factory, # type: BeamTransformFactory
transform_id, # type: str
transform_proto, # type: beam_runner_api_pb2.PTransform
payload,
consumers # type: Dict[str, List[operations.Operation]]
):
# type: (...) -> operations.FlattenOperation
return factory.augment_oldstyle_op(
operations.FlattenOperation(
common.NameContext(transform_proto.unique_name, transform_id),
operation_specs.WorkerFlatten(
None, [factory.get_only_output_coder(transform_proto)]),
factory.counter_factory,
factory.state_sampler),
transform_proto.unique_name,
consumers)
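

# MAP_WINDOWS applies a runner-supplied window mapping fn (used, for
# example, to map main-input windows onto side-input windows). Elements
# arrive as (key, window) pairs and are re-emitted with the mapped window.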
@BeamTransformFactory.register_urn(
common_urns.primitives.MAP_WINDOWS.urn, beam_runner_api_pb2.FunctionSpec)
def create_map_windows(
factory, # type: BeamTransformFactory
transform_id, # type: str
transform_proto, # type: beam_runner_api_pb2.PTransform
mapping_fn_spec, # type: beam_runner_api_pb2.FunctionSpec
consumers # type: Dict[str, List[operations.Operation]]
):
assert mapping_fn_spec.urn == python_urns.PICKLED_WINDOW_MAPPING_FN
window_mapping_fn = pickler.loads(mapping_fn_spec.payload)

  class MapWindows(beam.DoFn):
    def process(self, element):
      key, window = element
      return [(key, window_mapping_fn(window))]

return _create_simple_pardo_operation(
factory, transform_id, transform_proto, consumers, MapWindows())
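

# MERGE_WINDOWS runs a custom WindowFn's merge logic on behalf of the
# runner. For each (nonce, windows) element it yields, roughly:
#
#   (nonce, (windows_left_unmerged,
#            {merge_result: originals_it_subsumed}.items()))
#
# The RecordingMergeContext below records the requested merges rather than
# applying them to a bundle directly.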
@BeamTransformFactory.register_urn(
common_urns.primitives.MERGE_WINDOWS.urn, beam_runner_api_pb2.FunctionSpec)
def create_merge_windows(
factory, # type: BeamTransformFactory
transform_id, # type: str
transform_proto, # type: beam_runner_api_pb2.PTransform
mapping_fn_spec, # type: beam_runner_api_pb2.FunctionSpec
consumers # type: Dict[str, List[operations.Operation]]
):
assert mapping_fn_spec.urn == python_urns.PICKLED_WINDOWFN
window_fn = pickler.loads(mapping_fn_spec.payload)

  class MergeWindows(beam.DoFn):
    def process(self, element):
      nonce, windows = element
      original_windows = set(windows)  # type: Set[window.BoundedWindow]
      merged_windows = collections.defaultdict(
          set
      )  # type: MutableMapping[window.BoundedWindow, Set[window.BoundedWindow]]  # noqa: F821

      class RecordingMergeContext(window.WindowFn.MergeContext):
        def merge(
            self,
            to_be_merged,  # type: Iterable[window.BoundedWindow]
            merge_result,  # type: window.BoundedWindow
        ):
          originals = merged_windows[merge_result]
          # Named `w`, not `window`, to avoid shadowing the window module
          # referenced by the type comments above.
          for w in to_be_merged:
            if w in original_windows:
              originals.add(w)
              original_windows.remove(w)
            else:
              # A window produced by an earlier merge: fold the original
              # windows it absorbed into this merge result.
              originals.update(merged_windows.pop(w))

      window_fn.merge(RecordingMergeContext(windows))
      yield nonce, (original_windows, merged_windows.items())
return _create_simple_pardo_operation(
factory, transform_id, transform_proto, consumers, MergeWindows())
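

# TO_STRING is a trivial ParDo that stringifies the value of each
# (key, value) pair.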
@BeamTransformFactory.register_urn(common_urns.primitives.TO_STRING.urn, None)
def create_to_string_fn(
factory, # type: BeamTransformFactory
transform_id, # type: str
transform_proto, # type: beam_runner_api_pb2.PTransform
    payload,  # no payload type is registered for TO_STRING (always None)
consumers # type: Dict[str, List[operations.Operation]]
):

  class ToString(beam.DoFn):
    def process(self, element):
      key, value = element
      return [(key, str(value))]

return _create_simple_pardo_operation(
factory, transform_id, transform_proto, consumers, ToString())