| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| """SDK harness for executing Python Fns via the Fn API.""" |
| |
| # pytype: skip-file |
| |
| from __future__ import annotations |
| |
| import base64 |
| import bisect |
| import collections |
| import concurrent.futures |
| import copy |
| import heapq |
| import itertools |
| import json |
| import logging |
| import random |
| import threading |
| from dataclasses import dataclass |
| from dataclasses import field |
| from itertools import chain |
| from typing import TYPE_CHECKING |
| from typing import Any |
| from typing import Callable |
| from typing import Container |
| from typing import DefaultDict |
| from typing import Dict |
| from typing import FrozenSet |
| from typing import Iterable |
| from typing import Iterator |
| from typing import List |
| from typing import Mapping |
| from typing import MutableMapping |
| from typing import Optional |
| from typing import Set |
| from typing import Tuple |
| from typing import Type |
| from typing import TypeVar |
| from typing import Union |
| from typing import cast |
| |
| from google.protobuf import duration_pb2 |
| from google.protobuf import timestamp_pb2 |
| from sortedcontainers import SortedDict |
| from sortedcontainers import SortedList |
| |
| import apache_beam as beam |
| from apache_beam import coders |
| from apache_beam.coders import WindowedValueCoder |
| from apache_beam.coders import coder_impl |
| from apache_beam.internal import pickler |
| from apache_beam.io import iobase |
| from apache_beam.metrics import monitoring_infos |
| from apache_beam.portability import common_urns |
| from apache_beam.portability import python_urns |
| from apache_beam.portability.api import beam_fn_api_pb2 |
| from apache_beam.portability.api import beam_runner_api_pb2 |
| from apache_beam.runners import common |
| from apache_beam.runners import pipeline_context |
| from apache_beam.runners.worker import data_sampler |
| from apache_beam.runners.worker import operation_specs |
| from apache_beam.runners.worker import operations |
| from apache_beam.runners.worker import statesampler |
| from apache_beam.runners.worker.worker_status import thread_dump |
| from apache_beam.transforms import TimeDomain |
| from apache_beam.transforms import core |
| from apache_beam.transforms import environments |
| from apache_beam.transforms import sideinputs |
| from apache_beam.transforms import userstate |
| from apache_beam.transforms import window |
| from apache_beam.utils import counters |
| from apache_beam.utils import proto_utils |
| from apache_beam.utils import timestamp |
| from apache_beam.utils.windowed_value import WindowedValue |
| |
| if TYPE_CHECKING: |
| from google.protobuf import message # pylint: disable=ungrouped-imports |
| |
| from apache_beam import pvalue |
| from apache_beam.portability.api import metrics_pb2 |
| from apache_beam.runners.sdf_utils import SplitResultPrimary |
| from apache_beam.runners.sdf_utils import SplitResultResidual |
| from apache_beam.runners.worker import data_plane |
| from apache_beam.runners.worker import sdk_worker |
| from apache_beam.transforms.core import Windowing |
| from apache_beam.transforms.window import BoundedWindow |
| from apache_beam.utils import windowed_value |
| |
| T = TypeVar('T') |
| ConstructorFn = Callable[[ |
| 'BeamTransformFactory', |
| Any, |
| beam_runner_api_pb2.PTransform, |
| Union['message.Message', bytes], |
| Dict[str, List[operations.Operation]] |
| ], |
| operations.Operation] |
| OperationT = TypeVar('OperationT', bound=operations.Operation) |
| FnApiUserRuntimeStateTypes = Union['ReadModifyWriteRuntimeState', |
| 'CombiningValueRuntimeState', |
| 'SynchronousSetRuntimeState', |
| 'SynchronousBagRuntimeState', |
| 'SynchronousOrderedListRuntimeState'] |
| |
| DATA_INPUT_URN = 'beam:runner:source:v1' |
| DATA_OUTPUT_URN = 'beam:runner:sink:v1' |
| SYNTHETIC_DATA_SAMPLING_URN = 'beam:internal:sampling:v1' |
| IDENTITY_DOFN_URN = 'beam:dofn:identity:0.1' |
| # TODO(vikasrk): Fix this once runner sends appropriate common_urns. |
| OLD_DATAFLOW_RUNNER_HARNESS_PARDO_URN = 'beam:dofn:javasdk:0.1' |
| OLD_DATAFLOW_RUNNER_HARNESS_READ_URN = 'beam:source:java:0.1' |
| URNS_NEEDING_PCOLLECTIONS = set([ |
| monitoring_infos.ELEMENT_COUNT_URN, monitoring_infos.SAMPLED_BYTE_SIZE_URN |
| ]) |
| |
| _LOGGER = logging.getLogger(__name__) |
| |
| |
| class RunnerIOOperation(operations.Operation): |
| """Common baseclass for runner harness IO operations.""" |
| def __init__( |
| self, |
| name_context: common.NameContext, |
| step_name: Any, |
| consumers: Mapping[Any, Iterable[operations.Operation]], |
| counter_factory: counters.CounterFactory, |
| state_sampler: statesampler.StateSampler, |
| windowed_coder: coders.Coder, |
| transform_id: str, |
| data_channel: data_plane.DataChannel) -> None: |
| super().__init__(name_context, None, counter_factory, state_sampler) |
| self.windowed_coder = windowed_coder |
| self.windowed_coder_impl = windowed_coder.get_impl() |
| # transform_id identifies the consumer of these bytes in the data plane |
| # for a DataInputOperation, or their producer for a DataOutputOperation. |
| self.transform_id = transform_id |
| self.data_channel = data_channel |
| for _, consumer_ops in consumers.items(): |
| for consumer in consumer_ops: |
| self.add_receiver(consumer, 0) |
| |
| |
| class DataOutputOperation(RunnerIOOperation): |
| """A sink-like operation that gathers outputs to be sent back to the runner. |
| """ |
| def set_output_stream( |
| self, output_stream: data_plane.ClosableOutputStream) -> None: |
| self.output_stream = output_stream |
| |
| def process(self, windowed_value: windowed_value.WindowedValue) -> None: |
| self.windowed_coder_impl.encode_to_stream( |
| windowed_value, self.output_stream, True) |
| self.output_stream.maybe_flush() |
| |
| def finish(self) -> None: |
| super().finish() |
| self.output_stream.close() |
| |
| |
| class DataInputOperation(RunnerIOOperation): |
| """A source-like operation that gathers input from the runner.""" |
| def __init__( |
| self, |
| operation_name: common.NameContext, |
| step_name, |
| consumers: Mapping[Any, List[operations.Operation]], |
| counter_factory: counters.CounterFactory, |
| state_sampler: statesampler.StateSampler, |
| windowed_coder: coders.Coder, |
| transform_id, |
| data_channel: data_plane.GrpcClientDataChannel) -> None: |
| super().__init__( |
| operation_name, |
| step_name, |
| consumers, |
| counter_factory, |
| state_sampler, |
| windowed_coder, |
| transform_id=transform_id, |
| data_channel=data_channel) |
| |
| self.consumer = next(iter(consumers.values())) |
| self.splitting_lock = threading.Lock() |
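| # self.index is the index of the element currently being processed (-1 |
| # before the first element is read); self.stop is the exclusive upper |
| # bound on elements to process, lowered when the runner accepts a split. |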
| self.index = -1 |
| self.stop = float('inf') |
| self.started = False |
| |
| def setup(self, data_sampler=None): |
| super().setup(data_sampler) |
| # We must do this manually as we don't have a spec or spec.output_coders. |
| self.receivers = [ |
| operations.ConsumerSet.create( |
| counter_factory=self.counter_factory, |
| step_name=self.name_context.step_name, |
| output_index=0, |
| consumers=self.consumer, |
| coder=self.windowed_coder, |
| producer_type_hints=self._get_runtime_performance_hints(), |
| producer_batch_converter=self.get_output_batch_converter()) |
| ] |
| |
| def start(self) -> None: |
| super().start() |
| with self.splitting_lock: |
| self.started = True |
| |
| def process(self, windowed_value: windowed_value.WindowedValue) -> None: |
| self.output(windowed_value) |
| |
| def process_encoded(self, encoded_windowed_values: bytes) -> None: |
| input_stream = coder_impl.create_InputStream(encoded_windowed_values) |
| while input_stream.size() > 0: |
| with self.splitting_lock: |
| if self.index == self.stop - 1: |
| return |
| self.index += 1 |
| try: |
| decoded_value = self.windowed_coder_impl.decode_from_stream( |
| input_stream, True) |
| except Exception as exn: |
| coder = str(self.windowed_coder) |
| step = self.name_context.step_name |
| raise ValueError( |
| f"Error decoding input stream with coder {coder} in step {step}" |
| ) from exn |
| self.output(decoded_value) |
| |
| def monitoring_infos( |
| self, transform_id: str, tag_to_pcollection_id: Dict[str, str] |
| ) -> Dict[FrozenSet, metrics_pb2.MonitoringInfo]: |
| all_monitoring_infos = super().monitoring_infos( |
| transform_id, tag_to_pcollection_id) |
| read_progress_info = monitoring_infos.int64_counter( |
| monitoring_infos.DATA_CHANNEL_READ_INDEX, |
| self.index, |
| ptransform=transform_id) |
| all_monitoring_infos[monitoring_infos.to_key( |
| read_progress_info)] = read_progress_info |
| return all_monitoring_infos |
| |
| # TODO(https://github.com/apache/beam/issues/19737): typing not compatible |
| # with super type |
| def try_split( # type: ignore[override] |
| self, fraction_of_remainder, total_buffer_size, allowed_split_points |
| ) -> Optional[ |
| Tuple[ |
| int, |
| Iterable[operations.SdfSplitResultsPrimary], |
| Iterable[operations.SdfSplitResultsResidual], |
| int]]: |
| with self.splitting_lock: |
| if not self.started: |
| return None |
| if self.index == -1: |
| # We are "finished" with the (non-existent) previous element. |
| current_element_progress = 1.0 |
| else: |
| current_element_progress_object = ( |
| self.receivers[0].current_element_progress()) |
| if current_element_progress_object is None: |
| current_element_progress = 0.5 |
| else: |
| current_element_progress = ( |
| current_element_progress_object.fraction_completed) |
| # Now figure out where to split. |
| split = self._compute_split( |
| self.index, |
| current_element_progress, |
| self.stop, |
| fraction_of_remainder, |
| total_buffer_size, |
| allowed_split_points, |
| self.receivers[0].try_split) |
| if split: |
| self.stop = split[-1] |
| return split |
| |
| @staticmethod |
| def _compute_split( |
| index, |
| current_element_progress, |
| stop, |
| fraction_of_remainder, |
| total_buffer_size, |
| allowed_split_points=(), |
| try_split=lambda fraction: None): |
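| # Illustrative walkthrough with hypothetical values: for index=3, |
| # current_element_progress=0.5, stop=inf, fraction_of_remainder=0.2 and |
| # total_buffer_size=10, the remainder is 10 - 3 - 0.5 = 6.5 elements and |
| # keep = 6.5 * 0.2 = 1.3. Keeping that much of the current element would |
| # require 1.3 / 0.5 = 2.6 of its remainder, so no sub-element split is |
| # attempted; instead the split falls on the element boundary |
| # stop_index = 3 + round(0.5 + 1.3) = 5, returning (4, [], [], 5). |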
| def is_valid_split_point(index): |
| return not allowed_split_points or index in allowed_split_points |
| |
| if total_buffer_size < index + 1: |
| total_buffer_size = index + 1 |
| elif total_buffer_size > stop: |
| total_buffer_size = stop |
| # The units here (except for keep_of_element_remainder) are all in |
| # terms of number of (possibly fractional) elements. |
| remainder = total_buffer_size - index - current_element_progress |
| keep = remainder * fraction_of_remainder |
| if current_element_progress < 1: |
| keep_of_element_remainder = keep / (1 - current_element_progress) |
| # If it's less than what's left of the current element, |
| # try splitting at the current element. |
| if (keep_of_element_remainder < 1 and is_valid_split_point(index) and |
| is_valid_split_point(index + 1)): |
| split: Optional[Tuple[ |
| Iterable[operations.SdfSplitResultsPrimary], |
| Iterable[operations.SdfSplitResultsResidual]]] = try_split( |
| keep_of_element_remainder) |
| if split: |
| element_primaries, element_residuals = split |
| return index - 1, element_primaries, element_residuals, index + 1 |
| # Otherwise, split at the closest element boundary. |
| # pylint: disable=bad-option-value |
| stop_index = index + max(1, int(round(current_element_progress + keep))) |
| if allowed_split_points and stop_index not in allowed_split_points: |
| # Choose the closest allowed split point. |
| allowed_split_points = sorted(allowed_split_points) |
| closest = bisect.bisect(allowed_split_points, stop_index) |
| if closest == 0: |
| stop_index = allowed_split_points[0] |
| elif closest == len(allowed_split_points): |
| stop_index = allowed_split_points[-1] |
| else: |
| prev = allowed_split_points[closest - 1] |
| next_point = allowed_split_points[closest] |
| if index < prev and stop_index - prev < next_point - stop_index: |
| stop_index = prev |
| else: |
| stop_index = next_point |
| if index < stop_index < stop: |
| return stop_index - 1, [], [], stop_index |
| else: |
| return None |
| |
| def finish(self) -> None: |
| super().finish() |
| with self.splitting_lock: |
| self.index += 1 |
| self.started = False |
| |
| def reset(self) -> None: |
| with self.splitting_lock: |
| self.index = -1 |
| self.stop = float('inf') |
| super().reset() |
| |
| |
| class _StateBackedIterable(object): |
| def __init__( |
| self, |
| state_handler: sdk_worker.CachingStateHandler, |
| state_key: beam_fn_api_pb2.StateKey, |
| coder_or_impl: Union[coders.Coder, coder_impl.CoderImpl], |
| ) -> None: |
| self._state_handler = state_handler |
| self._state_key = state_key |
| if isinstance(coder_or_impl, coders.Coder): |
| self._coder_impl = coder_or_impl.get_impl() |
| else: |
| self._coder_impl = coder_or_impl |
| |
| def __iter__(self) -> Iterator[Any]: |
| return iter( |
| self._state_handler.blocking_get(self._state_key, self._coder_impl)) |
| |
| def __reduce__(self): |
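| # Pickling materializes the values into a plain list; the backing |
| # state handler itself is not picklable. |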
| return list, (list(self), ) |
| |
| |
| coder_impl.FastPrimitivesCoderImpl.register_iterable_like_type( |
| _StateBackedIterable) |
| |
| |
| class StateBackedSideInputMap(object): |
| |
| _BULK_READ_LIMIT = 100 |
| _BULK_READ_FULLY = "fully" |
| _BULK_READ_PARTIALLY = "partially" |
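| # A bulk read over the iterable protocol is attempted once per window; if |
| # more than _BULK_READ_LIMIT keys are seen, the read is marked partial and |
| # any keys not already cached fall back to point lookups. |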
| |
| def __init__( |
| self, |
| state_handler: sdk_worker.CachingStateHandler, |
| transform_id: str, |
| tag: Optional[str], |
| side_input_data: pvalue.SideInputData, |
| coder: WindowedValueCoder, |
| use_bulk_read: bool = False, |
| ) -> None: |
| self._state_handler = state_handler |
| self._transform_id = transform_id |
| self._tag = tag |
| self._side_input_data = side_input_data |
| self._element_coder = coder.wrapped_value_coder |
| self._target_window_coder = coder.window_coder |
| # TODO(robertwb): Limit the cache size. |
| self._cache: Dict[BoundedWindow, Any] = {} |
| self._use_bulk_read = use_bulk_read |
| |
| def __getitem__(self, window): |
| target_window = self._side_input_data.window_mapping_fn(window) |
| if target_window not in self._cache: |
| state_handler = self._state_handler |
| access_pattern = self._side_input_data.access_pattern |
| |
| if access_pattern == common_urns.side_inputs.ITERABLE.urn: |
| state_key = beam_fn_api_pb2.StateKey( |
| iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput( |
| transform_id=self._transform_id, |
| side_input_id=self._tag, |
| window=self._target_window_coder.encode(target_window))) |
| raw_view = _StateBackedIterable( |
| state_handler, state_key, self._element_coder) |
| |
| elif access_pattern == common_urns.side_inputs.MULTIMAP.urn: |
| state_key = beam_fn_api_pb2.StateKey( |
| multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput( |
| transform_id=self._transform_id, |
| side_input_id=self._tag, |
| window=self._target_window_coder.encode(target_window), |
| key=b'')) |
| kv_iter_state_key = beam_fn_api_pb2.StateKey( |
| multimap_keys_values_side_input=beam_fn_api_pb2.StateKey. |
| MultimapKeysValuesSideInput( |
| transform_id=self._transform_id, |
| side_input_id=self._tag, |
| window=self._target_window_coder.encode(target_window))) |
| cache = {} |
| key_coder = self._element_coder.key_coder() |
| key_coder_impl = key_coder.get_impl() |
| value_coder = self._element_coder.value_coder() |
| use_bulk_read = self._use_bulk_read |
| |
| class MultiMap(object): |
| _bulk_read = None |
| _lock = threading.Lock() |
| |
| def __getitem__(self, key): |
| if use_bulk_read: |
| if self._bulk_read is None: |
| with self._lock: |
| if self._bulk_read is None: |
| try: |
| # Attempt to bulk read the key-values over the iterable |
| # protocol which, if supported, can be much more efficient |
| # than point lookups if it fits into memory. |
| for ix, (k, vs) in enumerate(_StateBackedIterable( |
| state_handler, |
| kv_iter_state_key, |
| coders.TupleCoder( |
| (key_coder, coders.IterableCoder(value_coder))))): |
| cache[k] = vs |
| if ix > StateBackedSideInputMap._BULK_READ_LIMIT: |
| self._bulk_read = ( |
| StateBackedSideInputMap._BULK_READ_PARTIALLY) |
| break |
| else: |
| # We reached the end of the iteration without breaking. |
| self._bulk_read = ( |
| StateBackedSideInputMap._BULK_READ_FULLY) |
| except Exception: |
| _LOGGER.error( |
| "Iterable access of map side inputs unsupported.", |
| exc_info=True) |
| self._bulk_read = ( |
| StateBackedSideInputMap._BULK_READ_PARTIALLY) |
| |
| if (self._bulk_read == StateBackedSideInputMap._BULK_READ_FULLY): |
| return cache.get(key, []) |
| |
| if key not in cache: |
| keyed_state_key = beam_fn_api_pb2.StateKey() |
| keyed_state_key.CopyFrom(state_key) |
| keyed_state_key.multimap_side_input.key = ( |
| key_coder_impl.encode_nested(key)) |
| cache[key] = _StateBackedIterable( |
| state_handler, keyed_state_key, value_coder) |
| |
| return cache[key] |
| |
| def __reduce__(self): |
| # TODO(robertwb): Figure out how to support this. |
| raise TypeError(common_urns.side_inputs.MULTIMAP.urn) |
| |
| raw_view = MultiMap() |
| |
| else: |
| raise ValueError("Unknown access pattern: '%s'" % access_pattern) |
| |
| self._cache[target_window] = self._side_input_data.view_fn(raw_view) |
| return self._cache[target_window] |
| |
| def is_globally_windowed(self) -> bool: |
| return ( |
| self._side_input_data.window_mapping_fn == |
| sideinputs._global_window_mapping_fn) |
| |
| def reset(self) -> None: |
| # TODO(BEAM-5428): Cross-bundle caching respecting cache tokens. |
| self._cache = {} |
| |
| |
| class ReadModifyWriteRuntimeState(userstate.ReadModifyWriteRuntimeState): |
| def __init__(self, underlying_bag_state): |
| self._underlying_bag_state = underlying_bag_state |
| |
| def read(self) -> Any: |
| values = list(self._underlying_bag_state.read()) |
| if not values: |
| return None |
| return values[0] |
| |
| def write(self, value: Any) -> None: |
| self.clear() |
| self._underlying_bag_state.add(value) |
| |
| def clear(self) -> None: |
| self._underlying_bag_state.clear() |
| |
| def commit(self) -> None: |
| self._underlying_bag_state.commit() |
| |
| |
| class CombiningValueRuntimeState(userstate.CombiningValueRuntimeState): |
| def __init__( |
| self, |
| underlying_bag_state: userstate.AccumulatingRuntimeState, |
| combinefn: core.CombineFn) -> None: |
| self._combinefn = combinefn |
| self._combinefn.setup() |
| self._underlying_bag_state = underlying_bag_state |
| self._finalized = False |
| |
| def _read_accumulator(self, rewrite=True): |
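| # When rewrite is True, the merged accumulator replaces the individual |
| # accumulators stored in state (a compaction). |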
| merged_accumulator = self._combinefn.merge_accumulators( |
| self._underlying_bag_state.read()) |
| if rewrite: |
| self._underlying_bag_state.clear() |
| self._underlying_bag_state.add(merged_accumulator) |
| return merged_accumulator |
| |
| def read(self) -> Iterable[Any]: |
| return self._combinefn.extract_output(self._read_accumulator()) |
| |
| def add(self, value: Any) -> None: |
| # Prefer blind writes, but don't let them grow unboundedly. |
| # This should be tuned to be much lower, but for now exercise |
| # both paths well. |
| if random.random() < 0.5: |
| accumulator = self._read_accumulator(False) |
| self._underlying_bag_state.clear() |
| else: |
| accumulator = self._combinefn.create_accumulator() |
| self._underlying_bag_state.add( |
| self._combinefn.add_input(accumulator, value)) |
| |
| def clear(self) -> None: |
| self._underlying_bag_state.clear() |
| |
| def commit(self): |
| self._underlying_bag_state.commit() |
| |
| def finalize(self): |
| if not self._finalized: |
| self._combinefn.teardown() |
| self._finalized = True |
| |
| |
| class _ConcatIterable(object): |
| """An iterable that is the concatination of two iterables. |
| |
| Unlike itertools.chain, this allows reiteration. |
| """ |
| def __init__(self, first: Iterable[Any], second: Iterable[Any]) -> None: |
| self.first = first |
| self.second = second |
| |
| def __iter__(self) -> Iterator[Any]: |
| for elem in self.first: |
| yield elem |
| for elem in self.second: |
| yield elem |
| |
| |
| coder_impl.FastPrimitivesCoderImpl.register_iterable_like_type(_ConcatIterable) |
| |
| |
| class SynchronousBagRuntimeState(userstate.BagRuntimeState): |
| def __init__( |
| self, |
| state_handler: sdk_worker.CachingStateHandler, |
| state_key: beam_fn_api_pb2.StateKey, |
| value_coder: coders.Coder) -> None: |
| self._state_handler = state_handler |
| self._state_key = state_key |
| self._value_coder = value_coder |
| self._cleared = False |
| self._added_elements: List[Any] = [] |
| |
| def read(self) -> Iterable[Any]: |
| return _ConcatIterable([] if self._cleared else cast( |
| 'Iterable[Any]', |
| _StateBackedIterable( |
| self._state_handler, self._state_key, self._value_coder)), |
| self._added_elements) |
| |
| def add(self, value: Any) -> None: |
| self._added_elements.append(value) |
| |
| def clear(self) -> None: |
| self._cleared = True |
| self._added_elements = [] |
| |
| def commit(self) -> None: |
| to_await = None |
| if self._cleared: |
| to_await = self._state_handler.clear(self._state_key) |
| if self._added_elements: |
| to_await = self._state_handler.extend( |
| self._state_key, self._value_coder.get_impl(), self._added_elements) |
| if to_await: |
| # To commit, we need to wait on the last state request future to complete. |
| to_await.get() |
| |
| |
| class SynchronousSetRuntimeState(userstate.SetRuntimeState): |
| def __init__( |
| self, |
| state_handler: sdk_worker.CachingStateHandler, |
| state_key: beam_fn_api_pb2.StateKey, |
| value_coder: coders.Coder) -> None: |
| self._state_handler = state_handler |
| self._state_key = state_key |
| self._value_coder = value_coder |
| self._cleared = False |
| self._added_elements: Set[Any] = set() |
| |
| def _compact_data(self, rewrite=True): |
| accumulator = set( |
| _ConcatIterable( |
| set() if self._cleared else _StateBackedIterable( |
| self._state_handler, self._state_key, self._value_coder), |
| self._added_elements)) |
| |
| if rewrite and accumulator: |
| self._state_handler.clear(self._state_key) |
| self._state_handler.extend( |
| self._state_key, self._value_coder.get_impl(), accumulator) |
| |
| # Since everything is already committed, we can safely reinitialize |
| # _added_elements here. |
| self._added_elements = set() |
| |
| return accumulator |
| |
| def read(self) -> Set[Any]: |
| return self._compact_data(rewrite=False) |
| |
| def add(self, value: Any) -> None: |
| if self._cleared: |
| # This is a good time to explicitly clear the underlying state. |
| self._state_handler.clear(self._state_key) |
| self._cleared = False |
| |
| self._added_elements.add(value) |
| if random.random() > 0.5: |
| self._compact_data() |
| |
| def clear(self) -> None: |
| self._cleared = True |
| self._added_elements = set() |
| |
| def commit(self) -> None: |
| to_await = None |
| if self._cleared: |
| to_await = self._state_handler.clear(self._state_key) |
| if self._added_elements: |
| to_await = self._state_handler.extend( |
| self._state_key, self._value_coder.get_impl(), self._added_elements) |
| if to_await: |
| # To commit, we need to wait on the last state request future to complete. |
| to_await.get() |
| |
| |
| class RangeSet: |
| """For Internal Use only. A simple range set for ranges of [x,y).""" |
| def __init__(self) -> None: |
| # The start points and end points are stored separately in order. |
| self._sorted_starts = SortedList() |
| self._sorted_ends = SortedList() |
| |
| def add(self, start: int, end: int) -> None: |
| if start >= end: |
| return |
| |
| # ranges[:min_idx] and ranges[max_idx:] are unaffected by this insertion. |
| # the first range whose end point >= the start of the new range |
| min_idx = self._sorted_ends.bisect_left(start) |
| # the first range whose start point > the end point of the new range |
| max_idx = self._sorted_starts.bisect_right(end) |
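| # Example: with existing ranges [(0, 5), (10, 15)], add(3, 12) yields |
| # min_idx=0 and max_idx=2, so both ranges are merged with the new one |
| # into (0, 15). |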
| |
| if min_idx >= len(self._sorted_starts) or max_idx <= 0: |
| # the new range is beyond any current ranges |
| new_start = start |
| new_end = end |
| else: |
| # the new range overlaps with ranges[min_idx:max_idx] |
| new_start = min(start, self._sorted_starts[min_idx]) |
| new_end = max(end, self._sorted_ends[max_idx - 1]) |
| |
| del self._sorted_starts[min_idx:max_idx] |
| del self._sorted_ends[min_idx:max_idx] |
| |
| self._sorted_starts.add(new_start) |
| self._sorted_ends.add(new_end) |
| |
| def __contains__(self, key: int) -> bool: |
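| # key is contained iff it is the start point of some range, or lies |
| # strictly inside the range that starts before it. |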
| idx = self._sorted_starts.bisect_left(key) |
| return (idx < len(self._sorted_starts) and self._sorted_starts[idx] |
| == key) or (idx > 0 and self._sorted_ends[idx - 1] > key) |
| |
| def __len__(self) -> int: |
| assert len(self._sorted_starts) == len(self._sorted_ends) |
| return len(self._sorted_starts) |
| |
| def __iter__(self) -> Iterator[Tuple[int, int]]: |
| return zip(self._sorted_starts, self._sorted_ends) |
| |
| def __str__(self) -> str: |
| return str(list(zip(self._sorted_starts, self._sorted_ends))) |
| |
| |
| class SynchronousOrderedListRuntimeState(userstate.OrderedListRuntimeState): |
| RANGE_MIN = -(1 << 63) |
| RANGE_MAX = (1 << 63) - 1 |
| TIMESTAMP_RANGE_MIN = timestamp.Timestamp(micros=RANGE_MIN) |
| TIMESTAMP_RANGE_MAX = timestamp.Timestamp(micros=RANGE_MAX) |
| |
| def __init__( |
| self, |
| state_handler: sdk_worker.CachingStateHandler, |
| state_key: beam_fn_api_pb2.StateKey, |
| value_coder: coders.Coder) -> None: |
| self._state_handler = state_handler |
| self._state_key = state_key |
| self._elem_coder = beam.coders.TupleCoder( |
| [coders.VarIntCoder(), coders.coders.LengthPrefixCoder(value_coder)]) |
| self._cleared = False |
| self._pending_adds = SortedDict() |
| self._pending_removes = RangeSet() |
| |
| def add(self, elem: Tuple[timestamp.Timestamp, Any]) -> None: |
| assert len(elem) == 2 |
| key_ts, value = elem |
| key = key_ts.micros |
| |
| if key >= self.RANGE_MAX or key < self.RANGE_MIN: |
| raise ValueError("key value %d is out of range" % key) |
| self._pending_adds.setdefault(key, []).append(value) |
| |
| def read(self) -> Iterable[Tuple[timestamp.Timestamp, Any]]: |
| return self.read_range(self.TIMESTAMP_RANGE_MIN, self.TIMESTAMP_RANGE_MAX) |
| |
| def read_range( |
| self, |
| min_timestamp: timestamp.Timestamp, |
| limit_timestamp: timestamp.Timestamp |
| ) -> Iterable[Tuple[timestamp.Timestamp, Any]]: |
| # convert timestamp to int, as sort keys are stored as int internally. |
| min_key = min_timestamp.micros |
| limit_key = limit_timestamp.micros |
| |
| keys_to_add = self._pending_adds.irange( |
| min_key, limit_key, inclusive=(True, False)) |
| |
| # Use a list comprehension here to construct the actual list |
| # of iterators over the selected range. |
| local_items = chain.from_iterable([ |
| itertools.islice( |
| zip(itertools.cycle([ |
| k, |
| ]), self._pending_adds[k]), |
| len(self._pending_adds[k])) for k in keys_to_add |
| ]) |
| |
| if not self._cleared: |
| range_query_state_key = beam_fn_api_pb2.StateKey() |
| range_query_state_key.CopyFrom(self._state_key) |
| range_query_state_key.ordered_list_user_state.range.start = min_key |
| range_query_state_key.ordered_list_user_state.range.end = limit_key |
| |
| # Make a deep copy here, because other operations could occur in the |
| # middle of an iteration and change _pending_removes. |
| pending_removes_snapshot = copy.deepcopy(self._pending_removes) |
| persistent_items = filter( |
| lambda kv: kv[0] not in pending_removes_snapshot, |
| _StateBackedIterable( |
| self._state_handler, range_query_state_key, self._elem_coder)) |
| |
| return map( |
| lambda x: (timestamp.Timestamp(micros=x[0]), x[1]), |
| heapq.merge(persistent_items, local_items)) |
| |
| return map(lambda x: (timestamp.Timestamp(micros=x[0]), x[1]), local_items) |
| |
| def clear(self) -> None: |
| self._cleared = True |
| self._pending_adds = SortedDict() |
| self._pending_removes = RangeSet() |
| self._pending_removes.add(self.RANGE_MIN, self.RANGE_MAX) |
| |
| def clear_range( |
| self, |
| min_timestamp: timestamp.Timestamp, |
| limit_timestamp: timestamp.Timestamp) -> None: |
| min_key = min_timestamp.micros |
| limit_key = limit_timestamp.micros |
| |
| # materialize the keys to remove before the actual removal |
| keys_to_remove = list( |
| self._pending_adds.irange(min_key, limit_key, inclusive=(True, False))) |
| for k in keys_to_remove: |
| del self._pending_adds[k] |
| |
| if not self._cleared: |
| self._pending_removes.add(min_key, limit_key) |
| |
| def commit(self) -> None: |
| futures = [] |
| if self._pending_removes: |
| for start, end in self._pending_removes: |
| range_query_state_key = beam_fn_api_pb2.StateKey() |
| range_query_state_key.CopyFrom(self._state_key) |
| range_query_state_key.ordered_list_user_state.range.start = start |
| range_query_state_key.ordered_list_user_state.range.end = end |
| futures.append(self._state_handler.clear(range_query_state_key)) |
| |
| self._pending_removes = RangeSet() |
| |
| if self._pending_adds: |
| items_to_add = [] |
| for k in self._pending_adds: |
| items_to_add.extend(zip(itertools.cycle([ |
| k, |
| ]), self._pending_adds[k])) |
| futures.append( |
| self._state_handler.extend( |
| self._state_key, self._elem_coder.get_impl(), items_to_add)) |
| self._pending_adds = SortedDict() |
| |
| if futures: |
| # To commit, we need to wait on every state request future to complete. |
| for to_await in futures: |
| to_await.get() |
| |
| self._cleared = False |
| |
| |
| class OutputTimer(userstate.BaseTimer): |
| def __init__( |
| self, |
| key, |
| window: BoundedWindow, |
| timestamp: timestamp.Timestamp, |
| paneinfo: windowed_value.PaneInfo, |
| time_domain: str, |
| timer_family_id: str, |
| timer_coder_impl: coder_impl.TimerCoderImpl, |
| output_stream: data_plane.ClosableOutputStream): |
| self._key = key |
| self._window = window |
| self._input_timestamp = timestamp |
| self._paneinfo = paneinfo |
| self._time_domain = time_domain |
| self._timer_family_id = timer_family_id |
| self._output_stream = output_stream |
| self._timer_coder_impl = timer_coder_impl |
| |
| def set(self, ts: timestamp.TimestampTypes, dynamic_timer_tag='') -> None: |
| ts = timestamp.Timestamp.of(ts) |
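| # Event-time timers hold the output watermark at their fire timestamp; |
| # processing-time timers hold it at the input element's timestamp. |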
| timer = userstate.Timer( |
| user_key=self._key, |
| dynamic_timer_tag=dynamic_timer_tag, |
| windows=(self._window, ), |
| clear_bit=False, |
| fire_timestamp=ts, |
| hold_timestamp=ts if TimeDomain.is_event_time(self._time_domain) else |
| self._input_timestamp, |
| paneinfo=self._paneinfo) |
| self._timer_coder_impl.encode_to_stream(timer, self._output_stream, True) |
| self._output_stream.maybe_flush() |
| |
| def clear(self, dynamic_timer_tag='') -> None: |
| timer = userstate.Timer( |
| user_key=self._key, |
| dynamic_timer_tag=dynamic_timer_tag, |
| windows=(self._window, ), |
| clear_bit=True, |
| fire_timestamp=None, |
| hold_timestamp=None, |
| paneinfo=None) |
| self._timer_coder_impl.encode_to_stream(timer, self._output_stream, True) |
| self._output_stream.maybe_flush() |
| |
| |
| class TimerInfo(object): |
| """A data class to store information related to a timer.""" |
| def __init__(self, timer_coder_impl, output_stream=None): |
| self.timer_coder_impl = timer_coder_impl |
| self.output_stream = output_stream |
| |
| |
| class FnApiUserStateContext(userstate.UserStateContext): |
| """Interface for state and timers from SDK to Fn API servicer of state..""" |
| def __init__( |
| self, |
| state_handler: sdk_worker.CachingStateHandler, |
| transform_id: str, |
| key_coder: coders.Coder, |
| window_coder: coders.Coder, |
| ) -> None: |
| """Initialize a ``FnApiUserStateContext``. |
| |
| Args: |
| state_handler: A StateServicer object. |
| transform_id: The name of the PTransform that this context is |
| associated with. |
| key_coder: Coder for the key type. |
| window_coder: Coder for the window type. |
| """ |
| self._state_handler = state_handler |
| self._transform_id = transform_id |
| self._key_coder = key_coder |
| self._window_coder = window_coder |
| # A mapping of {timer_family_id: TimerInfo} |
| self._timers_info: Dict[str, TimerInfo] = {} |
| self._all_states: Dict[tuple, FnApiUserRuntimeStateTypes] = {} |
| |
| def add_timer_info(self, timer_family_id: str, timer_info: TimerInfo) -> None: |
| self._timers_info[timer_family_id] = timer_info |
| |
| def get_timer( |
| self, timer_spec: userstate.TimerSpec, key, window, timestamp, |
| pane) -> OutputTimer: |
| assert self._timers_info[timer_spec.name].output_stream is not None |
| timer_coder_impl = self._timers_info[timer_spec.name].timer_coder_impl |
| output_stream = self._timers_info[timer_spec.name].output_stream |
| return OutputTimer( |
| key, |
| window, |
| timestamp, |
| pane, |
| timer_spec.time_domain, |
| timer_spec.name, |
| timer_coder_impl, |
| output_stream) |
| |
| def get_state(self, *args: Any) -> FnApiUserRuntimeStateTypes: |
| state_handle = self._all_states.get(args) |
| if state_handle is None: |
| state_handle = self._all_states[args] = self._create_state(*args) |
| return state_handle |
| |
| def _create_state( |
| self, state_spec: userstate.StateSpec, key, |
| window: BoundedWindow) -> FnApiUserRuntimeStateTypes: |
| if isinstance(state_spec, |
| (userstate.BagStateSpec, |
| userstate.CombiningValueStateSpec, |
| userstate.ReadModifyWriteStateSpec)): |
| bag_state = SynchronousBagRuntimeState( |
| self._state_handler, |
| state_key=beam_fn_api_pb2.StateKey( |
| bag_user_state=beam_fn_api_pb2.StateKey.BagUserState( |
| transform_id=self._transform_id, |
| user_state_id=state_spec.name, |
| window=self._window_coder.encode(window), |
| # State keys are expected in nested encoding format |
| key=self._key_coder.encode_nested(key))), |
| value_coder=state_spec.coder) |
| if isinstance(state_spec, userstate.BagStateSpec): |
| return bag_state |
| elif isinstance(state_spec, userstate.ReadModifyWriteStateSpec): |
| return ReadModifyWriteRuntimeState(bag_state) |
| else: |
| return CombiningValueRuntimeState( |
| bag_state, copy.deepcopy(state_spec.combine_fn)) |
| elif isinstance(state_spec, userstate.SetStateSpec): |
| return SynchronousSetRuntimeState( |
| self._state_handler, |
| state_key=beam_fn_api_pb2.StateKey( |
| bag_user_state=beam_fn_api_pb2.StateKey.BagUserState( |
| transform_id=self._transform_id, |
| user_state_id=state_spec.name, |
| window=self._window_coder.encode(window), |
| # State keys are expected in nested encoding format |
| key=self._key_coder.encode_nested(key))), |
| value_coder=state_spec.coder) |
| elif isinstance(state_spec, userstate.OrderedListStateSpec): |
| return SynchronousOrderedListRuntimeState( |
| self._state_handler, |
| state_key=beam_fn_api_pb2.StateKey( |
| ordered_list_user_state=beam_fn_api_pb2.StateKey. |
| OrderedListUserState( |
| transform_id=self._transform_id, |
| user_state_id=state_spec.name, |
| window=self._window_coder.encode(window), |
| key=self._key_coder.encode_nested(key))), |
| value_coder=state_spec.coder) |
| else: |
| raise NotImplementedError(state_spec) |
| |
| def commit(self) -> None: |
| for state in self._all_states.values(): |
| state.commit() |
| |
| def reset(self) -> None: |
| for state in self._all_states.values(): |
| state.finalize() |
| self._all_states = {} |
| |
| |
| def memoize(func): |
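| # Unbounded per-function cache keyed on positional args (must be hashable). |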
| cache = {} |
| missing = object() |
| |
| def wrapper(*args): |
| result = cache.get(args, missing) |
| if result is missing: |
| result = cache[args] = func(*args) |
| return result |
| |
| return wrapper |
| |
| |
| def only_element(iterable: Iterable[T]) -> T: |
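| # Tuple unpacking asserts that the iterable contains exactly one element. |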
| element, = iterable |
| return element |
| |
| |
| def _environments_compatible(submission: str, runtime: str) -> bool: |
| if submission == runtime: |
| return True |
| if 'rc' in submission and runtime in submission: |
| # TODO(https://github.com/apache/beam/issues/28084): Loosen |
| # the check for RCs until RC containers install the matching version. |
| return True |
| return False |
| |
| |
| def _verify_descriptor_created_in_a_compatible_env( |
| process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor) -> None: |
| |
| runtime_sdk = environments.sdk_base_version_capability() |
| for t in process_bundle_descriptor.transforms.values(): |
| env = process_bundle_descriptor.environments[t.environment_id] |
| for c in env.capabilities: |
| if (c.startswith(environments.SDK_VERSION_CAPABILITY_PREFIX) and |
| not _environments_compatible(c, runtime_sdk)): |
| raise RuntimeError( |
| "Pipeline construction environment and pipeline runtime " |
| "environment are not compatible. If you use a custom " |
| "container image, check that the Python interpreter minor version " |
| "and the Apache Beam version in your image match the versions " |
| "used at pipeline construction time. " |
| f"Submission environment: {c}. " |
| f"Runtime environment: {runtime_sdk}.") |
| |
| # TODO: Consider warning on mismatches in versions of installed packages. |
| |
| |
| class BundleProcessor(object): |
| """ A class for processing bundles of elements. """ |
| def __init__( |
| self, |
| runner_capabilities: FrozenSet[str], |
| process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor, |
| state_handler: sdk_worker.CachingStateHandler, |
| data_channel_factory: data_plane.DataChannelFactory, |
| data_sampler: Optional[data_sampler.DataSampler] = None, |
| ) -> None: |
| """Initialize a bundle processor. |
| |
| Args: |
| runner_capabilities (``FrozenSet[str]``): The set of capabilities of the |
| runner with which we will be interacting. |
| process_bundle_descriptor (``beam_fn_api_pb2.ProcessBundleDescriptor``): |
| a description of the stage that this ``BundleProcessor`` is to execute. |
| state_handler (``sdk_worker.CachingStateHandler``): Handler for state |
| requests made while processing the bundle. |
| data_channel_factory (``data_plane.DataChannelFactory``): Factory for |
| the data channels over which elements are received and sent. |
| """ |
| self.runner_capabilities = runner_capabilities |
| self.process_bundle_descriptor = process_bundle_descriptor |
| self.state_handler = state_handler |
| self.data_channel_factory = data_channel_factory |
| self.data_sampler = data_sampler |
| self.current_instruction_id: Optional[str] = None |
| # Represents whether the SDK is consuming received data. |
| self.consuming_received_data = False |
| |
| _verify_descriptor_created_in_a_compatible_env(process_bundle_descriptor) |
| # There is no guarantee that the runner sets timer_api_service_descriptor |
| # only when there are timers, so this field cannot be used as an |
| # indicator of timers. |
| if self.process_bundle_descriptor.timer_api_service_descriptor.url: |
| self.timer_data_channel = ( |
| data_channel_factory.create_data_channel_from_url( |
| self.process_bundle_descriptor.timer_api_service_descriptor.url)) |
| else: |
| self.timer_data_channel = None |
| |
| # A mapping of |
| # {(transform_id, timer_family_id): TimerInfo} |
| # The mapping is empty when there are no timer_family_specs in the |
| # ProcessBundleDescriptor. |
| self.timers_info: Dict[Tuple[str, str], TimerInfo] = {} |
| |
| # TODO(robertwb): Figure out the correct prefix to use for output counters |
| # from StateSampler. |
| self.counter_factory = counters.CounterFactory() |
| self.state_sampler = statesampler.StateSampler( |
| 'fnapi-step-%s' % self.process_bundle_descriptor.id, |
| self.counter_factory) |
| |
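| # Build the execution tree on a separate thread so that a permanently |
| # stuck deserialization can be detected via the timeout below. |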
| with concurrent.futures.ThreadPoolExecutor( |
| max_workers=1, thread_name_prefix='ExecutionTreeCreator') as executor: |
| future = executor.submit( |
| self.create_execution_tree, self.process_bundle_descriptor) |
| try: |
| self.ops = future.result(timeout=3600) |
| except concurrent.futures.TimeoutError: |
| # In rare cases, unpickling a DoFn might get permanently stuck, |
| # for example when unpickling involves importing a module and |
| # a subprocess is launched during the import operation. |
| _LOGGER.error( |
| 'Timed out while reconstructing a pipeline fragment for: %s.\n' |
| 'This is likely a transient error. The SDK harness ' |
| 'will self-terminate, and the runner can retry the operation. ' |
| 'If the error is frequent, check whether the hang happens ' |
| 'while deserializing (unpickling) a dependency of your pipeline ' |
| 'in the stacktrace below: \n%s\n', |
| self.process_bundle_descriptor.id, |
| thread_dump('ExecutionTreeCreator')) |
| # Raising an exception here doesn't interrupt the left-over thread. |
| # Out of caution, terminate the SDK harness process. |
| from apache_beam.runners.worker.sdk_worker_main import terminate_sdk_harness |
| terminate_sdk_harness() |
| |
| for op in reversed(self.ops.values()): |
| op.setup(self.data_sampler) |
| self.splitting_lock = threading.Lock() |
| |
| def create_execution_tree( |
| self, descriptor: beam_fn_api_pb2.ProcessBundleDescriptor |
| ) -> collections.OrderedDict[str, operations.DoOperation]: |
| transform_factory = BeamTransformFactory( |
| self.runner_capabilities, |
| descriptor, |
| self.data_channel_factory, |
| self.counter_factory, |
| self.state_sampler, |
| self.state_handler, |
| self.data_sampler, |
| ) |
| |
| self.timers_info = transform_factory.extract_timers_info() |
| |
| def is_side_input(transform_proto, tag): |
| if transform_proto.spec.urn == common_urns.primitives.PAR_DO.urn: |
| return tag in proto_utils.parse_Bytes( |
| transform_proto.spec.payload, |
| beam_runner_api_pb2.ParDoPayload).side_inputs |
| |
| pcoll_consumers: DefaultDict[str, List[str]] = collections.defaultdict(list) |
| for transform_id, transform_proto in descriptor.transforms.items(): |
| for tag, pcoll_id in transform_proto.inputs.items(): |
| if not is_side_input(transform_proto, tag): |
| pcoll_consumers[pcoll_id].append(transform_id) |
| |
| @memoize |
| def get_operation(transform_id: str) -> operations.Operation: |
| transform_consumers = { |
| tag: [get_operation(op) for op in pcoll_consumers[pcoll_id]] |
| for tag, pcoll_id in |
| descriptor.transforms[transform_id].outputs.items() |
| } |
| |
| # Initialize transform-specific state in the Data Sampler. |
| if self.data_sampler: |
| self.data_sampler.initialize_samplers( |
| transform_id, descriptor, transform_factory.get_coder) |
| |
| return transform_factory.create_operation( |
| transform_id, transform_consumers) |
| |
| # Operations must be started (hence returned) in order. |
| @memoize |
| def topological_height(transform_id: str) -> int: |
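| # A transform's height is one more than the tallest of its consumers, |
| # so sorting by descending height puts producers before consumers. |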
| return 1 + max([0] + [ |
| topological_height(consumer) |
| for pcoll in descriptor.transforms[transform_id].outputs.values() |
| for consumer in pcoll_consumers[pcoll] |
| ]) |
| |
| return collections.OrderedDict([( |
| transform_id, |
| cast(operations.DoOperation, |
| get_operation(transform_id))) for transform_id in sorted( |
| descriptor.transforms, key=topological_height, reverse=True)]) |
| |
| def reset(self) -> None: |
| self.counter_factory.reset() |
| self.state_sampler.reset() |
| # Side input caches. |
| for op in self.ops.values(): |
| op.reset() |
| |
| def process_bundle( |
| self, instruction_id: str |
| ) -> Tuple[List[beam_fn_api_pb2.DelayedBundleApplication], bool]: |
| |
| expected_input_ops: List[DataInputOperation] = [] |
| |
| for op in self.ops.values(): |
| if isinstance(op, DataOutputOperation): |
| # TODO(robertwb): Is there a better way to pass the instruction id to |
| # the operation? |
| op.set_output_stream( |
| op.data_channel.output_stream(instruction_id, op.transform_id)) |
| elif isinstance(op, DataInputOperation): |
| # We must wait until we receive "end of stream" for each of these ops. |
| expected_input_ops.append(op) |
| |
| try: |
| execution_context = ExecutionContext(instruction_id=instruction_id) |
| self.current_instruction_id = instruction_id |
| self.state_sampler.start() |
| # Start all operations. |
| for op in reversed(self.ops.values()): |
| _LOGGER.debug('start %s', op) |
| op.execution_context = execution_context |
| op.start() |
| |
| # Each data_channel is mapped to a list of expected inputs, which |
| # includes both data inputs and timer inputs. A data input is identified |
| # by its transform_id; a timer input is identified by |
| # (transform_id, timer_family_id). |
| data_channels: DefaultDict[data_plane.DataChannel, |
| List[Union[str, Tuple[ |
| str, |
| str]]]] = collections.defaultdict(list) |
| |
| # Add expected data inputs for each data channel. |
| input_op_by_transform_id = {} |
| for input_op in expected_input_ops: |
| data_channels[input_op.data_channel].append(input_op.transform_id) |
| input_op_by_transform_id[input_op.transform_id] = input_op |
| |
| # Update timer_data channel with expected timer inputs. |
| if self.timer_data_channel: |
| data_channels[self.timer_data_channel].extend( |
| list(self.timers_info.keys())) |
| |
| # Set up timer output stream for DoOperation. |
| for ((transform_id, timer_family_id), |
| timer_info) in self.timers_info.items(): |
| output_stream = self.timer_data_channel.output_timer_stream( |
| instruction_id, transform_id, timer_family_id) |
| timer_info.output_stream = output_stream |
| self.ops[transform_id].add_timer_info(timer_family_id, timer_info) |
| |
| # Process data and timer inputs |
| # We are currently not consuming received data. |
| self.consuming_received_data = False |
| for data_channel, expected_inputs in data_channels.items(): |
| for element in data_channel.input_elements(instruction_id, |
| expected_inputs): |
| # We have received a set of elements and are now consuming them. |
| self.consuming_received_data = True |
| if isinstance(element, beam_fn_api_pb2.Elements.Timers): |
| timer_coder_impl = ( |
| self.timers_info[( |
| element.transform_id, |
| element.timer_family_id)].timer_coder_impl) |
| for timer_data in timer_coder_impl.decode_all(element.timers): |
| self.ops[element.transform_id].process_timer( |
| element.timer_family_id, timer_data) |
| elif isinstance(element, beam_fn_api_pb2.Elements.Data): |
| input_op_by_transform_id[element.transform_id].process_encoded( |
| element.data) |
| # We are done consuming the set of elements. |
| self.consuming_received_data = False |
| |
| # Finish all operations. |
| for op in self.ops.values(): |
| _LOGGER.debug('finish %s', op) |
| op.finish() |
| |
| # Close every timer output stream |
| for timer_info in self.timers_info.values(): |
| assert timer_info.output_stream is not None |
| timer_info.output_stream.close() |
| |
| return ([ |
| self.delayed_bundle_application(op, residual) |
| for op, residual in execution_context.delayed_applications |
| ], |
| self.requires_finalization()) |
| |
| finally: |
| self.consuming_received_data = False |
| # Ensure any in-flight split attempts complete. |
| with self.splitting_lock: |
| self.current_instruction_id = None |
| self.state_sampler.stop_if_still_running() |
| |
| def finalize_bundle(self) -> beam_fn_api_pb2.FinalizeBundleResponse: |
| for op in self.ops.values(): |
| op.finalize_bundle() |
| return beam_fn_api_pb2.FinalizeBundleResponse() |
| |
| def requires_finalization(self) -> bool: |
| return any(op.needs_finalization() for op in self.ops.values()) |
| |
| def try_split( |
| self, bundle_split_request: beam_fn_api_pb2.ProcessBundleSplitRequest |
| ) -> beam_fn_api_pb2.ProcessBundleSplitResponse: |
| split_response = beam_fn_api_pb2.ProcessBundleSplitResponse() |
| with self.splitting_lock: |
| if bundle_split_request.instruction_id != self.current_instruction_id: |
| # This may be a delayed split for a former bundle, see BEAM-12475. |
| return split_response |
| |
| for op in self.ops.values(): |
| if isinstance(op, DataInputOperation): |
| desired_split = bundle_split_request.desired_splits.get( |
| op.transform_id) |
| if desired_split: |
| split = op.try_split( |
| desired_split.fraction_of_remainder, |
| desired_split.estimated_input_elements, |
| desired_split.allowed_split_points) |
| if split: |
| ( |
| primary_end, |
| element_primaries, |
| element_residuals, |
| residual_start, |
| ) = split |
| for element_primary in element_primaries: |
| split_response.primary_roots.add().CopyFrom( |
| self.bundle_application(*element_primary)) |
| for element_residual in element_residuals: |
| split_response.residual_roots.add().CopyFrom( |
| self.delayed_bundle_application(*element_residual)) |
| split_response.channel_splits.extend([ |
| beam_fn_api_pb2.ProcessBundleSplitResponse.ChannelSplit( |
| transform_id=op.transform_id, |
| last_primary_element=primary_end, |
| first_residual_element=residual_start) |
| ]) |
| |
| return split_response |
| |
| def delayed_bundle_application( |
| self, op: operations.DoOperation, deferred_remainder: SplitResultResidual |
| ) -> beam_fn_api_pb2.DelayedBundleApplication: |
| assert op.input_info is not None |
| # TODO(SDF): For non-root nodes, need main_input_coder + residual_coder. |
| (element_and_restriction, current_watermark, deferred_timestamp) = ( |
| deferred_remainder) |
| if deferred_timestamp: |
| assert isinstance(deferred_timestamp, timestamp.Duration) |
| proto_deferred_watermark: Optional[ |
| duration_pb2.Duration] = proto_utils.from_micros( |
| duration_pb2.Duration, deferred_timestamp.micros) |
| else: |
| proto_deferred_watermark = None |
| return beam_fn_api_pb2.DelayedBundleApplication( |
| requested_time_delay=proto_deferred_watermark, |
| application=self.construct_bundle_application( |
| op.input_info, current_watermark, element_and_restriction)) |
| |
| def bundle_application( |
| self, op: operations.DoOperation, |
| primary: SplitResultPrimary) -> beam_fn_api_pb2.BundleApplication: |
| assert op.input_info is not None |
| return self.construct_bundle_application( |
| op.input_info, None, primary.primary_value) |
| |
| def construct_bundle_application( |
| self, |
| op_input_info: operations.OpInputInfo, |
| output_watermark: Optional[timestamp.Timestamp], |
| element) -> beam_fn_api_pb2.BundleApplication: |
| transform_id, main_input_tag, main_input_coder, outputs = op_input_info |
| if output_watermark: |
| proto_output_watermark = proto_utils.from_micros( |
| timestamp_pb2.Timestamp, output_watermark.micros) |
| output_watermarks: Optional[Dict[str, timestamp_pb2.Timestamp]] = { |
| output: proto_output_watermark |
| for output in outputs |
| } |
| else: |
| output_watermarks = None |
| return beam_fn_api_pb2.BundleApplication( |
| transform_id=transform_id, |
| input_id=main_input_tag, |
| output_watermarks=output_watermarks, |
| element=main_input_coder.get_impl().encode_nested(element)) |
| |
| def monitoring_infos(self) -> List[metrics_pb2.MonitoringInfo]: |
| """Returns the list of MonitoringInfos collected processing this bundle.""" |
| # Construct a new dict first to remove duplicates. |
| all_monitoring_infos_dict = {} |
| for transform_id, op in self.ops.items(): |
| tag_to_pcollection_id = self.process_bundle_descriptor.transforms[ |
| transform_id].outputs |
| all_monitoring_infos_dict.update( |
| op.monitoring_infos(transform_id, dict(tag_to_pcollection_id))) |
| |
| return list(all_monitoring_infos_dict.values()) |
| |
| def shutdown(self) -> None: |
| for op in self.ops.values(): |
| op.teardown() |
| |
| |
| @dataclass |
| class ExecutionContext: |
| # Any splits to be processed later. |
| delayed_applications: List[Tuple[operations.DoOperation, |
| common.SplitResultResidual]] = field( |
| default_factory=list) |
| |
| # The exception sampler for the currently executing PTransform. |
| output_sampler: Optional[data_sampler.OutputSampler] = None |
| |
| # The current instruction being executed. |
| instruction_id: Optional[str] = None |
| |
| |
| class BeamTransformFactory(object): |
| """Factory for turning transform_protos into executable operations.""" |
| def __init__( |
| self, |
| runner_capabilities: FrozenSet[str], |
| descriptor: beam_fn_api_pb2.ProcessBundleDescriptor, |
| data_channel_factory: data_plane.DataChannelFactory, |
| counter_factory: counters.CounterFactory, |
| state_sampler: statesampler.StateSampler, |
| state_handler: sdk_worker.CachingStateHandler, |
| data_sampler: Optional[data_sampler.DataSampler], |
| ): |
| self.runner_capabilities = runner_capabilities |
| self.descriptor = descriptor |
| self.data_channel_factory = data_channel_factory |
| self.counter_factory = counter_factory |
| self.state_sampler = state_sampler |
| self.state_handler = state_handler |
| self.context = pipeline_context.PipelineContext( |
| descriptor, |
| iterable_state_read=lambda token, element_coder_impl: |
| _StateBackedIterable( |
| state_handler, beam_fn_api_pb2.StateKey( |
| runner=beam_fn_api_pb2.StateKey.Runner(key=token)), |
| element_coder_impl)) |
| self.data_sampler = data_sampler |
| |
| _known_urns: Dict[str, |
| Tuple[ConstructorFn, |
| Union[Type[message.Message], Type[bytes], |
| None]]] = {} |
| |
| @classmethod |
| def register_urn( |
| cls, urn: str, parameter_type: Optional[Type[T]] |
| ) -> Callable[[ |
| Callable[[ |
| BeamTransformFactory, |
| str, |
| beam_runner_api_pb2.PTransform, |
| T, |
| Dict[str, List[operations.Operation]] |
| ], |
| operations.Operation] |
| ], |
| Callable[[ |
| BeamTransformFactory, |
| str, |
| beam_runner_api_pb2.PTransform, |
| T, |
| Dict[str, List[operations.Operation]] |
| ], |
| operations.Operation]]: |
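| # Registers a constructor for transforms with the given URN; used as a |
| # decorator on the create_* factory functions below. |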
| def wrapper(func): |
| cls._known_urns[urn] = func, parameter_type |
| return func |
| |
| return wrapper |
| |
| def create_operation( |
| self, transform_id: str, |
| consumers: Dict[str, List[operations.Operation]]) -> operations.Operation: |
| transform_proto = self.descriptor.transforms[transform_id] |
| if not transform_proto.unique_name: |
| _LOGGER.debug("No unique name set for transform %s" % transform_id) |
| transform_proto.unique_name = transform_id |
| creator, parameter_type = self._known_urns[transform_proto.spec.urn] |
| payload = proto_utils.parse_Bytes( |
| transform_proto.spec.payload, parameter_type) |
| return creator(self, transform_id, transform_proto, payload, consumers) |
| |
| def extract_timers_info(self) -> Dict[Tuple[str, str], TimerInfo]: |
| timers_info = {} |
| for transform_id, transform_proto in self.descriptor.transforms.items(): |
| if transform_proto.spec.urn == common_urns.primitives.PAR_DO.urn: |
| pardo_payload = proto_utils.parse_Bytes( |
| transform_proto.spec.payload, beam_runner_api_pb2.ParDoPayload) |
| for (timer_family_id, |
| timer_family_spec) in pardo_payload.timer_family_specs.items(): |
| timer_coder_impl = self.get_coder( |
| timer_family_spec.timer_family_coder_id).get_impl() |
| # The output_stream should be updated when processing a bundle. |
| timers_info[(transform_id, timer_family_id)] = TimerInfo( |
| timer_coder_impl=timer_coder_impl) |
| return timers_info |
| |
| def get_coder(self, coder_id: str) -> coders.Coder: |
| if coder_id not in self.descriptor.coders: |
| raise KeyError("No such coder: %s" % coder_id) |
| coder_proto = self.descriptor.coders[coder_id] |
| if coder_proto.spec.urn: |
| return self.context.coders.get_by_id(coder_id) |
| else: |
| # No URN; assume a cloud object spec encoded as JSON bytes. |
| return operation_specs.get_coder_from_spec( |
| json.loads(coder_proto.spec.payload.decode('utf-8'))) |
| |
| def get_windowed_coder(self, pcoll_id: str) -> WindowedValueCoder: |
| coder = self.get_coder(self.descriptor.pcollections[pcoll_id].coder_id) |
| # TODO(robertwb): Remove this condition once all runners are consistent. |
| if not isinstance(coder, WindowedValueCoder): |
| windowing_strategy = self.descriptor.windowing_strategies[ |
| self.descriptor.pcollections[pcoll_id].windowing_strategy_id] |
| return WindowedValueCoder( |
| coder, self.get_coder(windowing_strategy.window_coder_id)) |
| else: |
| return coder |
| |
| def get_output_coders( |
| self, transform_proto: beam_runner_api_pb2.PTransform |
| ) -> Dict[str, coders.Coder]: |
| return { |
| tag: self.get_windowed_coder(pcoll_id) |
| for tag, pcoll_id in transform_proto.outputs.items() |
| } |
| |
| def get_only_output_coder( |
| self, transform_proto: beam_runner_api_pb2.PTransform) -> coders.Coder: |
| return only_element(self.get_output_coders(transform_proto).values()) |
| |
| def get_input_coders( |
| self, transform_proto: beam_runner_api_pb2.PTransform |
| ) -> Dict[str, coders.WindowedValueCoder]: |
| return { |
| tag: self.get_windowed_coder(pcoll_id) |
| for tag, pcoll_id in transform_proto.inputs.items() |
| } |
| |
| def get_only_input_coder( |
| self, transform_proto: beam_runner_api_pb2.PTransform) -> coders.Coder: |
| return only_element(list(self.get_input_coders(transform_proto).values())) |
| |
| def get_input_windowing( |
| self, transform_proto: beam_runner_api_pb2.PTransform) -> Windowing: |
| pcoll_id = only_element(transform_proto.inputs.values()) |
| windowing_strategy_id = self.descriptor.pcollections[ |
| pcoll_id].windowing_strategy_id |
| return self.context.windowing_strategies.get_by_id(windowing_strategy_id) |
| |
| # TODO(robertwb): Update all operations to take these in the constructor. |
| @staticmethod |
| def augment_oldstyle_op( |
| op: OperationT, |
| step_name: str, |
| consumers: Mapping[str, Iterable[operations.Operation]], |
| tag_list: Optional[List[str]] = None) -> OperationT: |
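    """Attaches the given consumers to an old-style operation as receivers.

    Receiver indices are looked up in tag_list when provided; otherwise
    every consumer is attached to output 0.
    """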
| op.step_name = step_name |
| for tag, op_consumers in consumers.items(): |
| for consumer in op_consumers: |
| op.add_receiver(consumer, tag_list.index(tag) if tag_list else 0) |
| return op |
| |
| |
| @BeamTransformFactory.register_urn( |
| DATA_INPUT_URN, beam_fn_api_pb2.RemoteGrpcPort) |
| def create_source_runner( |
| factory: BeamTransformFactory, |
| transform_id: str, |
| transform_proto: beam_runner_api_pb2.PTransform, |
| grpc_port: beam_fn_api_pb2.RemoteGrpcPort, |
| consumers: Dict[str, List[operations.Operation]]) -> DataInputOperation: |
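  """Creates an operation reading input data over the given gRPC port."""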
| |
| output_coder = factory.get_coder(grpc_port.coder_id) |
| return DataInputOperation( |
| common.NameContext(transform_proto.unique_name, transform_id), |
| transform_proto.unique_name, |
| consumers, |
| factory.counter_factory, |
| factory.state_sampler, |
| output_coder, |
| transform_id=transform_id, |
| data_channel=factory.data_channel_factory.create_data_channel(grpc_port)) |
| |
| |
| @BeamTransformFactory.register_urn( |
| DATA_OUTPUT_URN, beam_fn_api_pb2.RemoteGrpcPort) |
| def create_sink_runner( |
| factory: BeamTransformFactory, |
| transform_id: str, |
| transform_proto: beam_runner_api_pb2.PTransform, |
| grpc_port: beam_fn_api_pb2.RemoteGrpcPort, |
| consumers: Dict[str, List[operations.Operation]]) -> DataOutputOperation: |
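  """Creates an operation writing output data over the given gRPC port."""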
| output_coder = factory.get_coder(grpc_port.coder_id) |
| return DataOutputOperation( |
| common.NameContext(transform_proto.unique_name, transform_id), |
| transform_proto.unique_name, |
| consumers, |
| factory.counter_factory, |
| factory.state_sampler, |
| output_coder, |
| transform_id=transform_id, |
| data_channel=factory.data_channel_factory.create_data_channel(grpc_port)) |
| |
| |
| @BeamTransformFactory.register_urn(OLD_DATAFLOW_RUNNER_HARNESS_READ_URN, None) |
| def create_source_java( |
| factory: BeamTransformFactory, |
| transform_id: str, |
| transform_proto: beam_runner_api_pb2.PTransform, |
| parameter, |
| consumers: Dict[str, |
| List[operations.Operation]]) -> operations.ReadOperation: |
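  """Creates a ReadOperation from a Dataflow-harness pickled source."""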
  # The Dataflow runner harness strips the base64 encoding, so re-encode
  # before unpickling (the pickler expects base64-encoded data).
| source = pickler.loads(base64.b64encode(parameter)) |
| spec = operation_specs.WorkerRead( |
| iobase.SourceBundle(1.0, source, None, None), |
| [factory.get_only_output_coder(transform_proto)]) |
| return factory.augment_oldstyle_op( |
| operations.ReadOperation( |
| common.NameContext(transform_proto.unique_name, transform_id), |
| spec, |
| factory.counter_factory, |
| factory.state_sampler), |
| transform_proto.unique_name, |
| consumers) |
| |
| |
| @BeamTransformFactory.register_urn( |
| common_urns.deprecated_primitives.READ.urn, beam_runner_api_pb2.ReadPayload) |
| def create_deprecated_read( |
| factory: BeamTransformFactory, |
| transform_id: str, |
| transform_proto: beam_runner_api_pb2.PTransform, |
| parameter: beam_runner_api_pb2.ReadPayload, |
| consumers: Dict[str, |
| List[operations.Operation]]) -> operations.ReadOperation: |
| source = iobase.BoundedSource.from_runner_api( |
| parameter.source, factory.context) |
| spec = operation_specs.WorkerRead( |
| iobase.SourceBundle(1.0, source, None, None), |
| [WindowedValueCoder(source.default_output_coder())]) |
| return factory.augment_oldstyle_op( |
| operations.ReadOperation( |
| common.NameContext(transform_proto.unique_name, transform_id), |
| spec, |
| factory.counter_factory, |
| factory.state_sampler), |
| transform_proto.unique_name, |
| consumers) |
| |
| |
| @BeamTransformFactory.register_urn( |
| python_urns.IMPULSE_READ_TRANSFORM, beam_runner_api_pb2.ReadPayload) |
| def create_read_from_impulse_python( |
| factory: BeamTransformFactory, |
| transform_id: str, |
| transform_proto: beam_runner_api_pb2.PTransform, |
| parameter: beam_runner_api_pb2.ReadPayload, |
| consumers: Dict[str, List[operations.Operation]] |
| ) -> operations.ImpulseReadOperation: |
| return operations.ImpulseReadOperation( |
| common.NameContext(transform_proto.unique_name, transform_id), |
| factory.counter_factory, |
| factory.state_sampler, |
| consumers, |
| iobase.BoundedSource.from_runner_api(parameter.source, factory.context), |
| factory.get_only_output_coder(transform_proto)) |
| |
| |
| @BeamTransformFactory.register_urn(OLD_DATAFLOW_RUNNER_HARNESS_PARDO_URN, None) |
| def create_dofn_javasdk( |
| factory: BeamTransformFactory, |
| transform_id: str, |
| transform_proto: beam_runner_api_pb2.PTransform, |
| serialized_fn, |
| consumers: Dict[str, List[operations.Operation]]): |
| return _create_pardo_operation( |
| factory, transform_id, transform_proto, consumers, serialized_fn) |
| |
| |
| @BeamTransformFactory.register_urn( |
| common_urns.sdf_components.PAIR_WITH_RESTRICTION.urn, |
| beam_runner_api_pb2.ParDoPayload) |
| def create_pair_with_restriction(*args): |
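  """Creates the SDF expansion step that pairs each element with its
  initial restriction and initial watermark estimator state.
  """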
| class PairWithRestriction(beam.DoFn): |
| def __init__(self, fn, restriction_provider, watermark_estimator_provider): |
| self.restriction_provider = restriction_provider |
| self.watermark_estimator_provider = watermark_estimator_provider |
| |
| def process(self, element, *args, **kwargs): |
| # TODO(SDF): Do we want to allow mutation of the element? |
| # (E.g. it could be nice to shift bulky description to the portion |
| # that can be distributed.) |
| initial_restriction = self.restriction_provider.initial_restriction( |
| element) |
| initial_estimator_state = ( |
| self.watermark_estimator_provider.initial_estimator_state( |
| element, initial_restriction)) |
| yield (element, (initial_restriction, initial_estimator_state)) |
| |
| return _create_sdf_operation(PairWithRestriction, *args) |
| |
| |
| @BeamTransformFactory.register_urn( |
| common_urns.sdf_components.SPLIT_AND_SIZE_RESTRICTIONS.urn, |
| beam_runner_api_pb2.ParDoPayload) |
| def create_split_and_size_restrictions(*args): |
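  """Creates the SDF expansion step that splits each restriction into
  sized sub-restrictions, each with a fresh watermark estimator state.
  """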
| class SplitAndSizeRestrictions(beam.DoFn): |
| def __init__(self, fn, restriction_provider, watermark_estimator_provider): |
| self.restriction_provider = restriction_provider |
| self.watermark_estimator_provider = watermark_estimator_provider |
| |
| def process(self, element_restriction, *args, **kwargs): |
| element, (restriction, _) = element_restriction |
| for part, size in self.restriction_provider.split_and_size( |
| element, restriction): |
| if size < 0: |
| raise ValueError('Expected size >= 0 but received %s.' % size) |
| estimator_state = ( |
| self.watermark_estimator_provider.initial_estimator_state( |
| element, part)) |
| yield ((element, (part, estimator_state)), size) |
| |
| return _create_sdf_operation(SplitAndSizeRestrictions, *args) |
| |
| |
| @BeamTransformFactory.register_urn( |
| common_urns.sdf_components.TRUNCATE_SIZED_RESTRICTION.urn, |
| beam_runner_api_pb2.ParDoPayload) |
| def create_truncate_sized_restriction(*args): |
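  """Creates the SDF expansion step that truncates restrictions (used when
  draining); restrictions that truncate to nothing are dropped.
  """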
| class TruncateAndSizeRestriction(beam.DoFn): |
| def __init__(self, fn, restriction_provider, watermark_estimator_provider): |
| self.restriction_provider = restriction_provider |
| |
| def process(self, element_restriction, *args, **kwargs): |
| ((element, (restriction, estimator_state)), _) = element_restriction |
| truncated_restriction = self.restriction_provider.truncate( |
| element, restriction) |
| if truncated_restriction: |
| truncated_restriction_size = ( |
| self.restriction_provider.restriction_size( |
| element, truncated_restriction)) |
| if truncated_restriction_size < 0: |
| raise ValueError( |
| 'Expected size >= 0 but received %s.' % |
| truncated_restriction_size) |
| yield ((element, (truncated_restriction, estimator_state)), |
| truncated_restriction_size) |
| |
| return _create_sdf_operation( |
| TruncateAndSizeRestriction, |
| *args, |
| operation_cls=operations.SdfTruncateSizedRestrictions) |
| |
| |
| @BeamTransformFactory.register_urn( |
| common_urns.sdf_components.PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS.urn, |
| beam_runner_api_pb2.ParDoPayload) |
| def create_process_sized_elements_and_restrictions( |
| factory: BeamTransformFactory, |
| transform_id: str, |
| transform_proto: beam_runner_api_pb2.PTransform, |
| parameter: beam_runner_api_pb2.ParDoPayload, |
| consumers: Dict[str, List[operations.Operation]]): |
| return _create_pardo_operation( |
| factory, |
| transform_id, |
| transform_proto, |
| consumers, |
| core.DoFnInfo.from_runner_api(parameter.do_fn, |
| factory.context).serialized_dofn_data(), |
| parameter, |
| operation_cls=operations.SdfProcessSizedElements) |
| |
| |
| def _create_sdf_operation( |
| proxy_dofn, |
| factory, |
| transform_id, |
| transform_proto, |
| parameter, |
| consumers, |
| operation_cls=operations.DoOperation): |
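  """Creates a ParDo operation executing the given proxy DoFn.

  The user DoFn is unpickled, its restriction and watermark estimator
  providers are extracted from its signature, and the proxy DoFn wrapping
  all three is re-pickled in its place.
  """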
| |
| dofn_data = pickler.loads(parameter.do_fn.payload) |
| dofn = dofn_data[0] |
| restriction_provider = common.DoFnSignature(dofn).get_restriction_provider() |
| watermark_estimator_provider = ( |
| common.DoFnSignature(dofn).get_watermark_estimator_provider()) |
| serialized_fn = pickler.dumps( |
| (proxy_dofn(dofn, restriction_provider, watermark_estimator_provider), ) + |
| dofn_data[1:]) |
| return _create_pardo_operation( |
| factory, |
| transform_id, |
| transform_proto, |
| consumers, |
| serialized_fn, |
| parameter, |
| operation_cls=operation_cls) |
| |
| |
| @BeamTransformFactory.register_urn( |
| common_urns.primitives.PAR_DO.urn, beam_runner_api_pb2.ParDoPayload) |
| def create_par_do( |
| factory: BeamTransformFactory, |
| transform_id: str, |
| transform_proto: beam_runner_api_pb2.PTransform, |
| parameter: beam_runner_api_pb2.ParDoPayload, |
| consumers: Dict[str, List[operations.Operation]]) -> operations.DoOperation: |
| return _create_pardo_operation( |
| factory, |
| transform_id, |
| transform_proto, |
| consumers, |
| core.DoFnInfo.from_runner_api(parameter.do_fn, |
| factory.context).serialized_dofn_data(), |
| parameter) |
| |
| |
| def _create_pardo_operation( |
| factory: BeamTransformFactory, |
| transform_id: str, |
| transform_proto: beam_runner_api_pb2.PTransform, |
| consumers, |
| serialized_fn, |
| pardo_proto: Optional[beam_runner_api_pb2.ParDoPayload] = None, |
| operation_cls=operations.DoOperation): |
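  """Creates a (subclass of) DoOperation from a serialized DoFn.

  Resolves side inputs into state-backed side input maps, fills in the
  windowing strategy of the main input if the serialized DoFn lacks one,
  and sets up user state and timer support when the payload requires it.
  """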
| |
| if pardo_proto and pardo_proto.side_inputs: |
| input_tags_to_coders = factory.get_input_coders(transform_proto) |
| tagged_side_inputs = [ |
| (tag, beam.pvalue.SideInputData.from_runner_api(si, factory.context)) |
| for tag, si in pardo_proto.side_inputs.items() |
| ] |
| tagged_side_inputs.sort( |
| key=lambda tag_si: sideinputs.get_sideinput_index(tag_si[0])) |
| side_input_maps = [ |
| StateBackedSideInputMap( |
| factory.state_handler, |
| transform_id, |
| tag, |
| si, |
| input_tags_to_coders[tag], |
| use_bulk_read=( |
| common_urns.runner_protocols.MULTIMAP_KEYS_VALUES_SIDE_INPUT.urn |
| in factory.runner_capabilities)) |
| for (tag, si) in tagged_side_inputs |
| ] |
| else: |
| side_input_maps = [] |
| |
| output_tags = list(transform_proto.outputs.keys()) |
| |
| dofn_data = pickler.loads(serialized_fn) |
| if not dofn_data[-1]: |
| # Windowing not set. |
| if pardo_proto: |
| other_input_tags: Container[str] = set.union( |
| set(pardo_proto.side_inputs), set(pardo_proto.timer_family_specs)) |
| else: |
| other_input_tags = () |
| pcoll_id, = [pcoll for tag, pcoll in transform_proto.inputs.items() |
| if tag not in other_input_tags] |
| windowing = factory.context.windowing_strategies.get_by_id( |
| factory.descriptor.pcollections[pcoll_id].windowing_strategy_id) |
| serialized_fn = pickler.dumps(dofn_data[:-1] + (windowing, )) |
| |
| if pardo_proto and (pardo_proto.timer_family_specs or pardo_proto.state_specs |
| or pardo_proto.restriction_coder_id): |
| found_input_coder = None |
    for tag, pcoll_id in transform_proto.inputs.items():
      if tag in pardo_proto.side_inputs:
        continue
      # This must be the main input.
      assert found_input_coder is None
      main_input_tag = tag
      found_input_coder = factory.get_windowed_coder(pcoll_id)
| assert found_input_coder is not None |
| main_input_coder = found_input_coder |
| |
| if pardo_proto.timer_family_specs or pardo_proto.state_specs: |
| user_state_context: Optional[ |
| FnApiUserStateContext] = FnApiUserStateContext( |
| factory.state_handler, |
| transform_id, |
| main_input_coder.key_coder(), |
| main_input_coder.window_coder) |
| else: |
| user_state_context = None |
| else: |
| user_state_context = None |
| |
| output_coders = factory.get_output_coders(transform_proto) |
| spec = operation_specs.WorkerDoFn( |
| serialized_fn=serialized_fn, |
| output_tags=output_tags, |
| input=None, |
| side_inputs=None, # Fn API uses proto definitions and the Fn State API |
| output_coders=[output_coders[tag] for tag in output_tags]) |
| |
| result = factory.augment_oldstyle_op( |
| operation_cls( |
| common.NameContext(transform_proto.unique_name, transform_id), |
| spec, |
| factory.counter_factory, |
| factory.state_sampler, |
| side_input_maps, |
| user_state_context), |
| transform_proto.unique_name, |
| consumers, |
| output_tags) |
| if pardo_proto and pardo_proto.restriction_coder_id: |
| result.input_info = operations.OpInputInfo( |
| transform_id, |
| main_input_tag, |
| main_input_coder, |
| transform_proto.outputs.keys()) |
| return result |
| |
| |
| def _create_simple_pardo_operation( |
| factory: BeamTransformFactory, |
| transform_id, |
| transform_proto, |
| consumers, |
| dofn: beam.DoFn, |
| ): |
| serialized_fn = pickler.dumps((dofn, (), {}, [], None)) |
| return _create_pardo_operation( |
| factory, transform_id, transform_proto, consumers, serialized_fn) |
| |
| |
| @BeamTransformFactory.register_urn( |
| common_urns.primitives.ASSIGN_WINDOWS.urn, |
| beam_runner_api_pb2.WindowingStrategy) |
| def create_assign_windows( |
| factory: BeamTransformFactory, |
| transform_id: str, |
| transform_proto: beam_runner_api_pb2.PTransform, |
| parameter: beam_runner_api_pb2.WindowingStrategy, |
| consumers: Dict[str, List[operations.Operation]]): |
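  """Creates an operation that assigns each element to new windows
  according to the given windowing strategy's WindowFn.
  """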
| class WindowIntoDoFn(beam.DoFn): |
| def __init__(self, windowing): |
| self.windowing = windowing |
| |
| def process( |
| self, |
| element, |
| timestamp=beam.DoFn.TimestampParam, |
| window=beam.DoFn.WindowParam): |
| new_windows = self.windowing.windowfn.assign( |
| WindowFn.AssignContext(timestamp, element=element, window=window)) |
| yield WindowedValue(element, timestamp, new_windows) |
| |
| from apache_beam.transforms.core import Windowing |
| from apache_beam.transforms.window import WindowFn |
| windowing = Windowing.from_runner_api(parameter, factory.context) |
| return _create_simple_pardo_operation( |
| factory, |
| transform_id, |
| transform_proto, |
| consumers, |
| WindowIntoDoFn(windowing)) |
| |
| |
| @BeamTransformFactory.register_urn(IDENTITY_DOFN_URN, None) |
| def create_identity_dofn( |
| factory: BeamTransformFactory, |
| transform_id: str, |
| transform_proto: beam_runner_api_pb2.PTransform, |
| parameter, |
| consumers: Dict[str, List[operations.Operation]] |
| ) -> operations.FlattenOperation: |
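  """Creates an identity operation, implemented as a single-input flatten."""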
| return factory.augment_oldstyle_op( |
| operations.FlattenOperation( |
| common.NameContext(transform_proto.unique_name, transform_id), |
| operation_specs.WorkerFlatten( |
| None, [factory.get_only_output_coder(transform_proto)]), |
| factory.counter_factory, |
| factory.state_sampler), |
| transform_proto.unique_name, |
| consumers) |
| |
| |
| @BeamTransformFactory.register_urn( |
| common_urns.combine_components.COMBINE_PER_KEY_PRECOMBINE.urn, |
| beam_runner_api_pb2.CombinePayload) |
| def create_combine_per_key_precombine( |
| factory: BeamTransformFactory, |
| transform_id: str, |
| transform_proto: beam_runner_api_pb2.PTransform, |
| payload: beam_runner_api_pb2.CombinePayload, |
| consumers: Dict[str, |
| List[operations.Operation]]) -> operations.PGBKCVOperation: |
| serialized_combine_fn = pickler.dumps(( |
| beam.CombineFn.from_runner_api(payload.combine_fn, |
| factory.context), [], {})) |
| return factory.augment_oldstyle_op( |
| operations.PGBKCVOperation( |
| common.NameContext(transform_proto.unique_name, transform_id), |
| operation_specs.WorkerPartialGroupByKey( |
| serialized_combine_fn, |
| None, [factory.get_only_output_coder(transform_proto)]), |
| factory.counter_factory, |
| factory.state_sampler, |
| factory.get_input_windowing(transform_proto)), |
| transform_proto.unique_name, |
| consumers) |
| |
| |
| @BeamTransformFactory.register_urn( |
| common_urns.combine_components.COMBINE_PER_KEY_MERGE_ACCUMULATORS.urn, |
| beam_runner_api_pb2.CombinePayload) |
def create_combine_per_key_merge_accumulators(
| factory: BeamTransformFactory, |
| transform_id: str, |
| transform_proto: beam_runner_api_pb2.PTransform, |
| payload: beam_runner_api_pb2.CombinePayload, |
| consumers: Dict[str, List[operations.Operation]]): |
| return _create_combine_phase_operation( |
| factory, transform_id, transform_proto, payload, consumers, 'merge') |
| |
| |
| @BeamTransformFactory.register_urn( |
| common_urns.combine_components.COMBINE_PER_KEY_EXTRACT_OUTPUTS.urn, |
| beam_runner_api_pb2.CombinePayload) |
| def create_combine_per_key_extract_outputs( |
| factory: BeamTransformFactory, |
| transform_id: str, |
| transform_proto: beam_runner_api_pb2.PTransform, |
| payload: beam_runner_api_pb2.CombinePayload, |
| consumers: Dict[str, List[operations.Operation]]): |
| return _create_combine_phase_operation( |
| factory, transform_id, transform_proto, payload, consumers, 'extract') |
| |
| |
| @BeamTransformFactory.register_urn( |
| common_urns.combine_components.COMBINE_PER_KEY_CONVERT_TO_ACCUMULATORS.urn, |
| beam_runner_api_pb2.CombinePayload) |
| def create_combine_per_key_convert_to_accumulators( |
| factory: BeamTransformFactory, |
| transform_id: str, |
| transform_proto: beam_runner_api_pb2.PTransform, |
| payload: beam_runner_api_pb2.CombinePayload, |
| consumers: Dict[str, List[operations.Operation]]): |
| return _create_combine_phase_operation( |
| factory, transform_id, transform_proto, payload, consumers, 'convert') |
| |
| |
| @BeamTransformFactory.register_urn( |
| common_urns.combine_components.COMBINE_GROUPED_VALUES.urn, |
| beam_runner_api_pb2.CombinePayload) |
| def create_combine_grouped_values( |
| factory: BeamTransformFactory, |
| transform_id: str, |
| transform_proto: beam_runner_api_pb2.PTransform, |
| payload: beam_runner_api_pb2.CombinePayload, |
| consumers: Dict[str, List[operations.Operation]]): |
| return _create_combine_phase_operation( |
| factory, transform_id, transform_proto, payload, consumers, 'all') |
| |
| |
| def _create_combine_phase_operation( |
| factory, transform_id, transform_proto, payload, consumers, |
| phase) -> operations.CombineOperation: |
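  """Creates a CombineOperation executing one phase ('merge', 'extract',
  'convert' or 'all') of a lifted combine.
  """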
| serialized_combine_fn = pickler.dumps(( |
| beam.CombineFn.from_runner_api(payload.combine_fn, |
| factory.context), [], {})) |
| return factory.augment_oldstyle_op( |
| operations.CombineOperation( |
| common.NameContext(transform_proto.unique_name, transform_id), |
| operation_specs.WorkerCombineFn( |
| serialized_combine_fn, |
| phase, |
| None, [factory.get_only_output_coder(transform_proto)]), |
| factory.counter_factory, |
| factory.state_sampler), |
| transform_proto.unique_name, |
| consumers) |
| |
| |
| @BeamTransformFactory.register_urn(common_urns.primitives.FLATTEN.urn, None) |
| def create_flatten( |
| factory: BeamTransformFactory, |
| transform_id: str, |
| transform_proto: beam_runner_api_pb2.PTransform, |
| payload, |
| consumers: Dict[str, List[operations.Operation]] |
| ) -> operations.FlattenOperation: |
| return factory.augment_oldstyle_op( |
| operations.FlattenOperation( |
| common.NameContext(transform_proto.unique_name, transform_id), |
| operation_specs.WorkerFlatten( |
| None, [factory.get_only_output_coder(transform_proto)]), |
| factory.counter_factory, |
| factory.state_sampler), |
| transform_proto.unique_name, |
| consumers) |
| |
| |
| @BeamTransformFactory.register_urn( |
| common_urns.primitives.MAP_WINDOWS.urn, beam_runner_api_pb2.FunctionSpec) |
| def create_map_windows( |
| factory: BeamTransformFactory, |
| transform_id: str, |
| transform_proto: beam_runner_api_pb2.PTransform, |
| mapping_fn_spec: beam_runner_api_pb2.FunctionSpec, |
| consumers: Dict[str, List[operations.Operation]]): |
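  """Creates an operation applying the pickled window mapping fn to the
  window of each (key, window) input pair.
  """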
| assert mapping_fn_spec.urn == python_urns.PICKLED_WINDOW_MAPPING_FN |
| window_mapping_fn = pickler.loads(mapping_fn_spec.payload) |
| |
| class MapWindows(beam.DoFn): |
| def process(self, element): |
| key, window = element |
| return [(key, window_mapping_fn(window))] |
| |
| return _create_simple_pardo_operation( |
| factory, transform_id, transform_proto, consumers, MapWindows()) |
| |
| |
| @BeamTransformFactory.register_urn( |
| common_urns.primitives.MERGE_WINDOWS.urn, beam_runner_api_pb2.FunctionSpec) |
| def create_merge_windows( |
| factory: BeamTransformFactory, |
| transform_id: str, |
| transform_proto: beam_runner_api_pb2.PTransform, |
| mapping_fn_spec: beam_runner_api_pb2.FunctionSpec, |
| consumers: Dict[str, List[operations.Operation]]): |
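  """Creates an operation that merges windows with the pickled WindowFn.

  For each (nonce, windows) input, records which windows the WindowFn
  merges and yields the still-unmerged original windows together with a
  mapping from each merged result to the original windows it subsumed.
  """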
| assert mapping_fn_spec.urn == python_urns.PICKLED_WINDOWFN |
| window_fn = pickler.loads(mapping_fn_spec.payload) |
| |
| class MergeWindows(beam.DoFn): |
| def process(self, element): |
| nonce, windows = element |
| |
| original_windows: Set[window.BoundedWindow] = set(windows) |
| merged_windows: MutableMapping[ |
| window.BoundedWindow, |
| Set[window.BoundedWindow]] = collections.defaultdict( |
| set) # noqa: F821 |
| |
| class RecordingMergeContext(window.WindowFn.MergeContext): |
| def merge( |
| self, |
| to_be_merged: Iterable[window.BoundedWindow], |
| merge_result: window.BoundedWindow, |
| ): |
| originals = merged_windows[merge_result] |
| for w in to_be_merged: |
| if w in original_windows: |
| originals.add(w) |
| original_windows.remove(w) |
| else: |
| originals.update(merged_windows.pop(w)) |
| |
| window_fn.merge(RecordingMergeContext(windows)) |
| yield nonce, (original_windows, merged_windows.items()) |
| |
| return _create_simple_pardo_operation( |
| factory, transform_id, transform_proto, consumers, MergeWindows()) |
| |
| |
| @BeamTransformFactory.register_urn(common_urns.primitives.TO_STRING.urn, None) |
| def create_to_string_fn( |
| factory: BeamTransformFactory, |
| transform_id: str, |
| transform_proto: beam_runner_api_pb2.PTransform, |
| mapping_fn_spec: beam_runner_api_pb2.FunctionSpec, |
| consumers: Dict[str, List[operations.Operation]]): |
| class ToString(beam.DoFn): |
| def process(self, element): |
| key, value = element |
| return [(key, str(value))] |
| |
| return _create_simple_pardo_operation( |
| factory, transform_id, transform_proto, consumers, ToString()) |