blob: ecf93d77106331294a98579cf3a6700edbed0a0b [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""SDK harness for executing Python Fns via the Fn API."""
# pytype: skip-file
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import collections
import contextlib
import functools
import logging
import queue
import sys
import threading
import time
import traceback
from builtins import object
from concurrent import futures
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import DefaultDict
from typing import Dict
from typing import FrozenSet
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Mapping
from typing import Optional
from typing import Tuple
import grpc
from future.utils import raise_
from future.utils import with_metaclass
from apache_beam.coders import coder_impl
from apache_beam.metrics import monitoring_infos
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_fn_api_pb2_grpc
from apache_beam.portability.api import metrics_pb2
from apache_beam.runners.worker import bundle_processor
from apache_beam.runners.worker import data_plane
from apache_beam.runners.worker import statesampler
from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
from apache_beam.runners.worker.data_plane import PeriodicThread
from apache_beam.runners.worker.statecache import StateCache
from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor
from apache_beam.runners.worker.worker_status import FnApiWorkerStatusHandler
from apache_beam.utils import thread_pool_executor
if TYPE_CHECKING:
from apache_beam.portability.api import endpoints_pb2
from apache_beam.utils.profiler import Profile
_LOGGER = logging.getLogger(__name__)
# This SDK harness will (by default), log a "lull" in processing if it sees no
# transitions in over 5 minutes.
# 5 minutes * 60 seconds * 1000 millis * 1000 micros * 1000 nanoseconds
DEFAULT_LOG_LULL_TIMEOUT_NS = 5 * 60 * 1000 * 1000 * 1000
# Seconds a descriptor's cached bundle processors may sit unused before the
# periodic cleanup shuts them down; the same value is handed to the cleanup
# PeriodicThread when it is created.
DEFAULT_BUNDLE_PROCESSOR_CACHE_SHUTDOWN_THRESHOLD_S = 60
class ShortIdCache(object):
  """Registry mapping MonitoringInfos to compact "short ids" and back.

  Ids are handed out from a monotonically increasing counter rendered as
  lower-case hexadecimal. Assignment is serialized by an internal lock.
  """
  def __init__(self):
    self._lock = threading.Lock()
    # Counter backing the most recently assigned id.
    self._lastShortId = 0
    self._infoKeyToShortId = {}  # type: Dict[FrozenSet, str]
    self._shortIdToInfo = {}  # type: Dict[str, metrics_pb2.MonitoringInfo]

  def getShortId(self, monitoring_info):
    # type: (metrics_pb2.MonitoringInfo) -> str

    """Return the short id for ``monitoring_info``, allocating one on first use.

    The first time a key is seen, a copy of the MonitoringInfo with its
    ``payload`` field cleared is stored so it can later be served by
    ``getInfos``.
    """
    info_key = monitoring_infos.to_key(monitoring_info)
    with self._lock:
      if info_key in self._infoKeyToShortId:
        return self._infoKeyToShortId[info_key]

      self._lastShortId += 1
      # Hex without the '0x' prefix keeps the ids a little shorter.
      new_id = format(self._lastShortId, 'x')

      stripped_info = metrics_pb2.MonitoringInfo()
      stripped_info.CopyFrom(monitoring_info)
      stripped_info.ClearField('payload')

      self._infoKeyToShortId[info_key] = new_id
      self._shortIdToInfo[new_id] = stripped_info
      return new_id

  def getInfos(self, short_ids):
    #type: (Iterable[str]) -> List[metrics_pb2.MonitoringInfo]

    """Return the payload-cleared MonitoringInfo for each given short id.

    Raises:
      KeyError: if any of the ids has never been assigned.
    """
    infos = []
    for short_id in short_ids:
      infos.append(self._shortIdToInfo[short_id])
    return infos


# Module-level cache instance used by the SdkWorker response builders.
SHORT_ID_CACHE = ShortIdCache()
class SdkHarness(object):
  """Drives the Fn API control loop for this SDK worker process.

  The harness connects to the control service, reads ``InstructionRequest``s
  from the ``Control`` stream and dispatches each one to the matching
  ``_request_<type>`` method. Responses are put on an internal queue which is
  streamed back to the control service.
  """
  REQUEST_METHOD_PREFIX = '_request_'

  def __init__(self,
               control_address,  # type: str
               credentials=None,
               worker_id=None,  # type: Optional[str]
               # Caching is disabled by default
               state_cache_size=0,
               # time-based data buffering is disabled by default
               data_buffer_time_limit_ms=0,
               profiler_factory=None,  # type: Optional[Callable[..., Profile]]
               status_address=None,  # type: Optional[str]
               ):
    """Opens the control channel and builds the shared factories and caches.

    Args:
      control_address: address of the control service to connect to.
      credentials: gRPC channel credentials; when None an insecure channel is
        created.
      worker_id: id attached to the control channel through
        ``WorkerIdInterceptor`` and passed to the data channel factory.
      state_cache_size: size passed to ``StateCache``.
      data_buffer_time_limit_ms: passed to the data channel factory.
      profiler_factory: optional callable handed to every ``SdkWorker``.
      status_address: when set, a ``FnApiWorkerStatusHandler`` is created for
        this address. Failure to create it is logged and otherwise ignored.

    Blocks for up to 60 seconds waiting for the control channel to be ready.
    """
    self._alive = True
    self._worker_index = 0
    self._worker_id = worker_id
    self._state_cache = StateCache(state_cache_size)
    options = [('grpc.max_receive_message_length', -1),
               ('grpc.max_send_message_length', -1)]
    if credentials is None:
      _LOGGER.info('Creating insecure control channel for %s.', control_address)
      self._control_channel = GRPCChannelFactory.insecure_channel(
          control_address, options=options)
    else:
      _LOGGER.info('Creating secure control channel for %s.', control_address)
      self._control_channel = GRPCChannelFactory.secure_channel(
          control_address, credentials, options=options)
    grpc.channel_ready_future(self._control_channel).result(timeout=60)
    _LOGGER.info('Control channel established.')

    self._control_channel = grpc.intercept_channel(
        self._control_channel, WorkerIdInterceptor(self._worker_id))
    self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
        credentials, self._worker_id, data_buffer_time_limit_ms)
    self._state_handler_factory = GrpcStateHandlerFactory(
        self._state_cache, credentials)
    self._profiler_factory = profiler_factory
    # Descriptors that were not registered up front are fetched on demand.
    # self._control_stub is only assigned in run(), so lookups of unknown ids
    # are expected to happen after run() has started.
    self._fns = KeyedDefaultDict(
        lambda id: self._control_stub.GetProcessBundleDescriptor(
            beam_fn_api_pb2.GetProcessBundleDescriptorRequest(
                process_bundle_descriptor_id=id))
    )  # type: Mapping[str, beam_fn_api_pb2.ProcessBundleDescriptor]
    # BundleProcessor cache across all workers.
    self._bundle_processor_cache = BundleProcessorCache(
        state_handler_factory=self._state_handler_factory,
        data_channel_factory=self._data_channel_factory,
        fns=self._fns)

    # Always define the attribute: run() reads it unconditionally, so it must
    # exist even when creating the status handler fails below.
    self._status_handler = None  # type: Optional[FnApiWorkerStatusHandler]
    if status_address:
      try:
        self._status_handler = FnApiWorkerStatusHandler(
            status_address, self._bundle_processor_cache)
      except Exception:  # pylint: disable=broad-except
        # Status reporting is best effort; the harness keeps working without.
        _LOGGER.warning(
            'Error creating worker status request handler, '
            'skipping status report. Trace back: %s',
            traceback.format_exc())

    # TODO(BEAM-8998) use common
    # thread_pool_executor.shared_unbounded_instance() to process bundle
    # progress once dataflow runner's excessive progress polling is removed.
    self._report_progress_executor = futures.ThreadPoolExecutor(max_workers=1)
    self._worker_thread_pool = thread_pool_executor.shared_unbounded_instance()
    self._responses = queue.Queue(
    )  # type: queue.Queue[beam_fn_api_pb2.InstructionResponse]
    _LOGGER.info('Initializing SDKHarness with unbounded number of workers.')

  def run(self):
    """Consumes the control stream until it ends, then releases resources.

    Every incoming request is dispatched to ``_request_<request type>``. Once
    the stream is exhausted (or raises), the worker thread pool is shut down,
    the response stream is terminated and the data channel factory, state
    handler factory, bundle processor cache and status handler are closed.
    """
    self._control_stub = beam_fn_api_pb2_grpc.BeamFnControlStub(
        self._control_channel)
    # Sentinel telling the response generator to finish.
    no_more_work = object()

    def get_responses():
      # type: () -> Iterator[beam_fn_api_pb2.InstructionResponse]
      while True:
        response = self._responses.get()
        if response is no_more_work:
          return
        yield response

    self._alive = True

    try:
      for work_request in self._control_stub.Control(get_responses()):
        _LOGGER.debug('Got work %s', work_request.instruction_id)
        request_type = work_request.WhichOneof('request')
        # Name spacing the request method with 'request_'. The called method
        # will be like self.request_register(request)
        getattr(self, SdkHarness.REQUEST_METHOD_PREFIX + request_type)(
            work_request)
    finally:
      self._alive = False

    _LOGGER.info('No more requests from control plane')
    _LOGGER.info('SDK Harness waiting for in-flight requests to complete')
    # Wait until existing requests are processed.
    self._worker_thread_pool.shutdown()
    # get_responses may be blocked on responses.get(), but we need to return
    # control to its caller.
    self._responses.put(no_more_work)
    # Stop all the workers and clean all the associated resources
    self._data_channel_factory.close()
    self._state_handler_factory.close()
    self._bundle_processor_cache.shutdown()
    if self._status_handler:
      self._status_handler.close()
    _LOGGER.info('Done consuming work.')

  def _execute(self,
               task,  # type: Callable[[], beam_fn_api_pb2.InstructionResponse]
               request  # type: beam_fn_api_pb2.InstructionRequest
              ):
    # type: (...) -> None

    """Runs ``task`` and enqueues its response.

    Any ``Exception`` raised by ``task`` is printed to stderr, logged, and
    turned into an error ``InstructionResponse`` carrying the traceback, so a
    response is always queued for the request.
    """
    with statesampler.instruction_id(request.instruction_id):
      try:
        response = task()
      except Exception:  # pylint: disable=broad-except
        traceback_string = traceback.format_exc()
        print(traceback_string, file=sys.stderr)
        _LOGGER.error(
            'Error processing instruction %s. Original traceback is\n%s\n',
            request.instruction_id,
            traceback_string)
        response = beam_fn_api_pb2.InstructionResponse(
            instruction_id=request.instruction_id, error=traceback_string)
      self._responses.put(response)

  def _request_register(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None

    """Handles a register request on the calling thread."""
    # registration request is handled synchronously
    self._execute(lambda: self.create_worker().do_instruction(request), request)

  def _request_process_bundle(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None

    """Schedules a process-bundle request on the worker thread pool."""
    self._request_execute(request)

  def _request_process_bundle_split(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None

    """Schedules a split request on the progress executor."""
    self._request_process_bundle_action(request)

  def _request_process_bundle_progress(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None

    """Schedules a progress request on the progress executor."""
    self._request_process_bundle_action(request)

  def _request_process_bundle_action(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None

    """Handles a progress/split request on the single progress thread.

    The request is only forwarded to a worker when the bundle it refers to is
    currently active; otherwise an error response naming the unknown bundle
    instruction is queued.
    """
    def task():
      # Id of the bundle the action refers to, not of the action request.
      instruction_id = getattr(
          request, request.WhichOneof('request')).instruction_id
      # only process progress/split request when a bundle is in processing.
      if (instruction_id in
          self._bundle_processor_cache.active_bundle_processors):
        self._execute(
            lambda: self.create_worker().do_instruction(request), request)
      else:
        self._execute(
            lambda: beam_fn_api_pb2.InstructionResponse(
                instruction_id=request.instruction_id,
                error=('Unknown process bundle instruction {}').format(
                    instruction_id)),
            request)

    self._report_progress_executor.submit(task)

  def _request_finalize_bundle(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None

    """Schedules a finalize-bundle request on the worker thread pool."""
    self._request_execute(request)

  def _request_execute(self, request):
    """Submits ``request`` to the shared worker thread pool."""
    def task():
      self._execute(
          lambda: self.create_worker().do_instruction(request), request)

    self._worker_thread_pool.submit(task)
    # Lazy logging arguments: no string is built unless DEBUG is enabled.
    _LOGGER.debug(
        "Currently using %s threads.", len(self._worker_thread_pool._workers))

  def create_worker(self):
    """Returns a new ``SdkWorker`` bound to the shared bundle processor cache."""
    return SdkWorker(
        self._bundle_processor_cache,
        state_cache_metrics_fn=self._state_cache.get_monitoring_infos,
        profiler_factory=self._profiler_factory)
class BundleProcessorCache(object):
  """A cache for ``BundleProcessor``s.

  ``BundleProcessor`` objects are cached by the id of their
  ``beam_fn_api_pb2.ProcessBundleDescriptor``.

  Attributes:
    fns (dict): A dictionary that maps bundle descriptor IDs to instances of
      ``beam_fn_api_pb2.ProcessBundleDescriptor``.
    state_handler_factory (``StateHandlerFactory``): Used to create state
      handlers to be used by a ``bundle_processor.BundleProcessor`` during
      processing.
    data_channel_factory (``data_plane.DataChannelFactory``)
    active_bundle_processors (dict): A dictionary, indexed by instruction IDs,
      containing ``(descriptor id, bundle_processor.BundleProcessor)`` tuples
      for processors that are currently active processing the corresponding
      instruction.
    cached_bundle_processors (dict): A dictionary, indexed by bundle processor
      id, of cached ``bundle_processor.BundleProcessor`` that are not currently
      performing processing.
    last_access_times (dict): A dictionary, indexed by bundle descriptor id,
      of the ``time.time()`` at which a processor was last released.
    periodic_shutdown: The background thread that shuts down idle cached
      processors, or None after ``shutdown()``.
  """
  def __init__(self,
               state_handler_factory,  # type: StateHandlerFactory
               data_channel_factory,  # type: data_plane.DataChannelFactory
               fns  # type: Mapping[str, beam_fn_api_pb2.ProcessBundleDescriptor]
              ):
    """Creates an empty cache and starts the periodic cleanup thread."""
    self.fns = fns
    self.state_handler_factory = state_handler_factory
    self.data_channel_factory = data_channel_factory
    self.active_bundle_processors = {
    }  # type: Dict[str, Tuple[str, bundle_processor.BundleProcessor]]
    self.cached_bundle_processors = collections.defaultdict(
        list)  # type: DefaultDict[str, List[bundle_processor.BundleProcessor]]
    self.last_access_times = collections.defaultdict(
        float)  # type: DefaultDict[str, float]
    self._schedule_periodic_shutdown()

  def register(self, bundle_descriptor):
    # type: (beam_fn_api_pb2.ProcessBundleDescriptor) -> None

    """Register a ``beam_fn_api_pb2.ProcessBundleDescriptor`` by its id."""
    self.fns[bundle_descriptor.id] = bundle_descriptor

  def get(self, instruction_id, bundle_descriptor_id):
    # type: (str, str) -> bundle_processor.BundleProcessor

    """
    Return the requested ``BundleProcessor``, creating it if necessary.

    Moves the ``BundleProcessor`` from the inactive to the active cache.
    """
    try:
      # pop() is threadsafe
      processor = self.cached_bundle_processors[bundle_descriptor_id].pop()
    except IndexError:
      # Nothing cached for this descriptor: build a fresh processor.
      processor = bundle_processor.BundleProcessor(
          self.fns[bundle_descriptor_id],
          self.state_handler_factory.create_state_handler(
              self.fns[bundle_descriptor_id].state_api_service_descriptor),
          self.data_channel_factory)
    self.active_bundle_processors[
        instruction_id] = bundle_descriptor_id, processor
    return processor

  def lookup(self, instruction_id):
    # type: (str) -> Optional[bundle_processor.BundleProcessor]

    """
    Return the active ``BundleProcessor`` for the instruction, or None.
    """
    return self.active_bundle_processors.get(instruction_id, (None, None))[-1]

  def discard(self, instruction_id):
    # type: (str) -> None

    """
    Shut down the active ``BundleProcessor`` and remove it from the cache.

    Raises KeyError if the instruction has no active processor.
    """
    self.active_bundle_processors[instruction_id][1].shutdown()
    del self.active_bundle_processors[instruction_id]

  def release(self, instruction_id):
    # type: (str) -> None

    """
    Release the requested ``BundleProcessor``.

    Resets the ``BundleProcessor`` and moves it from the active to the
    inactive cache.
    """
    descriptor_id, processor = self.active_bundle_processors.pop(instruction_id)
    processor.reset()
    self.last_access_times[descriptor_id] = time.time()
    self.cached_bundle_processors[descriptor_id].append(processor)

  def shutdown(self):
    """
    Shutdown all ``BundleProcessor``s in the cache.

    Stops the periodic cleanup thread first, then shuts down every active and
    every cached processor.
    """
    if self.periodic_shutdown:
      self.periodic_shutdown.cancel()
      self.periodic_shutdown.join()
      self.periodic_shutdown = None

    # Iterate over a snapshot of the keys: deleting entries while iterating
    # the dict itself raises "dictionary changed size during iteration".
    for instruction_id in list(self.active_bundle_processors):
      self.active_bundle_processors[instruction_id][1].shutdown()
      del self.active_bundle_processors[instruction_id]
    for cached_bundle_processors in self.cached_bundle_processors.values():
      BundleProcessorCache._shutdown_cached_bundle_processors(
          cached_bundle_processors)

  def _schedule_periodic_shutdown(self):
    """Starts a daemon thread that shuts down idle cached processors."""
    def shutdown_inactive_bundle_processors():
      # Snapshot the items: release() may add descriptor ids from other
      # threads while this loop is running.
      for descriptor_id, last_access_time in list(
          self.last_access_times.items()):
        if (time.time() - last_access_time >
            DEFAULT_BUNDLE_PROCESSOR_CACHE_SHUTDOWN_THRESHOLD_S):
          BundleProcessorCache._shutdown_cached_bundle_processors(
              self.cached_bundle_processors[descriptor_id])

    self.periodic_shutdown = PeriodicThread(
        DEFAULT_BUNDLE_PROCESSOR_CACHE_SHUTDOWN_THRESHOLD_S,
        shutdown_inactive_bundle_processors)
    self.periodic_shutdown.daemon = True
    self.periodic_shutdown.start()

  @staticmethod
  def _shutdown_cached_bundle_processors(cached_bundle_processors):
    """Pops and shuts down every processor in the given list."""
    try:
      while True:
        # pop() is threadsafe
        processor = cached_bundle_processors.pop()
        processor.shutdown()
    except IndexError:
      # The list is drained.
      pass
class SdkWorker(object):
  """Executes a single Fn API instruction against the bundle processor cache.

  ``do_instruction`` dispatches on the request's oneof field name to the
  method of the same name on this class.
  """
  def __init__(self,
               bundle_processor_cache,  # type: BundleProcessorCache
               state_cache_metrics_fn=list,
               profiler_factory=None,  # type: Optional[Callable[..., Profile]]
               log_lull_timeout_ns=None,
              ):
    """
    Args:
      bundle_processor_cache: cache used to obtain, look up, release and
        discard bundle processors.
      state_cache_metrics_fn: zero-argument callable whose result is appended
        to the monitoring infos of a processed bundle.
      profiler_factory: optional callable taking an instruction id and
        returning a context manager (or a falsy value for no profiling).
      log_lull_timeout_ns: nanoseconds without a state transition after which
        a lull is logged; a falsy value selects the module default.
    """
    self.bundle_processor_cache = bundle_processor_cache
    self.state_cache_metrics_fn = state_cache_metrics_fn
    self.profiler_factory = profiler_factory
    self.log_lull_timeout_ns = (
        log_lull_timeout_ns or DEFAULT_LOG_LULL_TIMEOUT_NS)

  def do_instruction(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> beam_fn_api_pb2.InstructionResponse

    """Dispatches ``request`` to the method named after its oneof field.

    Raises:
      NotImplementedError: if no request field is set.
    """
    request_type = request.WhichOneof('request')
    if request_type:
      # E.g. if register is set, this will call self.register(request.register))
      return getattr(self, request_type)(
          getattr(request, request_type), request.instruction_id)
    else:
      raise NotImplementedError

  def register(self,
               request,  # type: beam_fn_api_pb2.RegisterRequest
               instruction_id  # type: str
              ):
    # type: (...) -> beam_fn_api_pb2.InstructionResponse

    """Registers a set of ``beam_fn_api_pb2.ProcessBundleDescriptor``s.

    This set of ``beam_fn_api_pb2.ProcessBundleDescriptor`` come as part of a
    ``beam_fn_api_pb2.RegisterRequest``, which the runner sends to the SDK
    worker before starting processing to register stages.
    """
    for process_bundle_descriptor in request.process_bundle_descriptor:
      self.bundle_processor_cache.register(process_bundle_descriptor)
    return beam_fn_api_pb2.InstructionResponse(
        instruction_id=instruction_id,
        register=beam_fn_api_pb2.RegisterResponse())

  def process_bundle(self,
                     request,  # type: beam_fn_api_pb2.ProcessBundleRequest
                     instruction_id  # type: str
                    ):
    # type: (...) -> beam_fn_api_pb2.InstructionResponse

    """Processes one bundle and returns its response.

    The processor is released back to the cache on success unless the bundle
    requested finalization; on any failure it is discarded and the error is
    re-raised.
    """
    bundle_processor = self.bundle_processor_cache.get(
        instruction_id, request.process_bundle_descriptor_id)
    try:
      with bundle_processor.state_handler.process_instruction_id(
          instruction_id, request.cache_tokens):
        with self.maybe_profile(instruction_id):
          delayed_applications, requests_finalization = (
              bundle_processor.process_bundle(instruction_id))
          monitoring_infos = bundle_processor.monitoring_infos()
          monitoring_infos.extend(self.state_cache_metrics_fn())
          response = beam_fn_api_pb2.InstructionResponse(
              instruction_id=instruction_id,
              process_bundle=beam_fn_api_pb2.ProcessBundleResponse(
                  residual_roots=delayed_applications,
                  monitoring_infos=monitoring_infos,
                  monitoring_data={
                      SHORT_ID_CACHE.getShortId(info): info.payload
                      for info in monitoring_infos
                  },
                  requires_finalization=requests_finalization))
      # Don't release here if finalize is needed.
      if not requests_finalization:
        self.bundle_processor_cache.release(instruction_id)
      return response
    except:  # pylint: disable=bare-except
      # Don't re-use bundle processors on failure.
      self.bundle_processor_cache.discard(instruction_id)
      raise

  def process_bundle_split(self,
                           request,  # type: beam_fn_api_pb2.ProcessBundleSplitRequest
                           instruction_id  # type: str
                          ):
    # type: (...) -> beam_fn_api_pb2.InstructionResponse

    """Tries to split the active bundle named by ``request.instruction_id``.

    Returns an error response when that bundle is not currently active.
    """
    processor = self.bundle_processor_cache.lookup(request.instruction_id)
    if processor:
      return beam_fn_api_pb2.InstructionResponse(
          instruction_id=instruction_id,
          process_bundle_split=processor.try_split(request))
    else:
      # Report the bundle that was looked up, not the id of this split
      # request (which is, by definition, running).
      return beam_fn_api_pb2.InstructionResponse(
          instruction_id=instruction_id,
          error='Instruction not running: %s' % request.instruction_id)

  def _log_lull_in_bundle_processor(self, processor):
    """Logs a warning if the processor has not changed state for too long.

    When the sampler info exposes a ``tracked_thread``, the current stack of
    that thread is included in the message.
    """
    state_sampler = processor.state_sampler
    sampler_info = state_sampler.get_info()
    if (sampler_info and sampler_info.time_since_transition and
        sampler_info.time_since_transition > self.log_lull_timeout_ns):
      step_name = sampler_info.state_name.step_name
      state_name = sampler_info.state_name.name
      state_lull_log = (
          'Operation ongoing for over %.2f seconds in state %s' %
          (sampler_info.time_since_transition / 1e9, state_name))
      step_name_log = (' in step %s ' % step_name) if step_name else ''

      exec_thread = getattr(sampler_info, 'tracked_thread', None)
      if exec_thread is not None:
        thread_frame = sys._current_frames().get(exec_thread.ident)  # pylint: disable=protected-access
        stack_trace = '\n'.join(
            traceback.format_stack(thread_frame)) if thread_frame else ''
      else:
        stack_trace = '-NOT AVAILABLE-'

      _LOGGER.warning(
          '%s%s without returning. Current Traceback:\n%s',
          state_lull_log,
          step_name_log,
          stack_trace)

  def process_bundle_progress(self,
                              request,  # type: beam_fn_api_pb2.ProcessBundleProgressRequest
                              instruction_id  # type: str
                             ):
    # type: (...) -> beam_fn_api_pb2.InstructionResponse

    """Reports monitoring infos of the bundle named by the request.

    When the bundle is not active the response carries no monitoring infos.
    """
    # It is an error to get progress for a not-in-flight bundle.
    processor = self.bundle_processor_cache.lookup(request.instruction_id)
    if processor:
      self._log_lull_in_bundle_processor(processor)
    monitoring_infos = processor.monitoring_infos() if processor else []
    return beam_fn_api_pb2.InstructionResponse(
        instruction_id=instruction_id,
        process_bundle_progress=beam_fn_api_pb2.ProcessBundleProgressResponse(
            monitoring_infos=monitoring_infos,
            monitoring_data={
                SHORT_ID_CACHE.getShortId(info): info.payload
                for info in monitoring_infos
            }))

  def process_bundle_progress_metadata_request(self,
                                               request,  # type: beam_fn_api_pb2.ProcessBundleProgressMetadataRequest
                                               instruction_id  # type: str
                                              ):
    """Returns the cached MonitoringInfos for the requested short ids.

    NOTE(review): the metadata response is stored in the
    ``process_bundle_progress`` field although it is built as a
    ``ProcessBundleProgressMetadataResponse`` -- confirm the intended field
    against the InstructionResponse proto.
    """
    return beam_fn_api_pb2.InstructionResponse(
        instruction_id=instruction_id,
        process_bundle_progress=beam_fn_api_pb2.
        ProcessBundleProgressMetadataResponse(
            monitoring_info=SHORT_ID_CACHE.getInfos(
                request.monitoring_info_id)))

  def finalize_bundle(self,
                      request,  # type: beam_fn_api_pb2.FinalizeBundleRequest
                      instruction_id  # type: str
                     ):
    # type: (...) -> beam_fn_api_pb2.InstructionResponse

    """Finalizes the bundle named by ``request.instruction_id``.

    The processor is released on success and discarded (with the error
    re-raised) on failure. Returns an error response when the bundle is not
    currently active.
    """
    processor = self.bundle_processor_cache.lookup(request.instruction_id)
    if processor:
      try:
        finalize_response = processor.finalize_bundle()
        self.bundle_processor_cache.release(request.instruction_id)
        return beam_fn_api_pb2.InstructionResponse(
            instruction_id=instruction_id, finalize_bundle=finalize_response)
      except:  # pylint: disable=bare-except
        # Don't re-use bundle processors on failure.
        self.bundle_processor_cache.discard(request.instruction_id)
        raise
    else:
      # Report the bundle that was looked up, not the id of this finalize
      # request.
      return beam_fn_api_pb2.InstructionResponse(
          instruction_id=instruction_id,
          error='Instruction not running: %s' % request.instruction_id)

  @contextlib.contextmanager
  def maybe_profile(self, instruction_id):
    """Context manager that profiles the body when a profiler is available.

    Profiling happens only if a profiler factory is configured and it returns
    a truthy profiler for ``instruction_id``; otherwise the body runs as is.
    """
    if self.profiler_factory:
      profiler = self.profiler_factory(instruction_id)
      if profiler:
        with profiler:
          yield
      else:
        yield
    else:
      yield
class StateHandler(with_metaclass(abc.ABCMeta, object)):  # type: ignore[misc]
  """An abstract object representing a ``StateHandler``.

  Concrete handlers implement raw (byte-level) read, append and clear
  operations addressed by a ``beam_fn_api_pb2.StateKey``.
  """
  @abc.abstractmethod
  def get_raw(self,
              state_key,  # type: beam_fn_api_pb2.StateKey
              continuation_token=None  # type: Optional[bytes]
             ):
    # type: (...) -> Tuple[bytes, Optional[bytes]]

    """Returns a tuple of (data bytes, optional continuation token)."""
    raise NotImplementedError(type(self))

  @abc.abstractmethod
  def append_raw(
      self,
      state_key,  # type: beam_fn_api_pb2.StateKey
      data  # type: bytes
  ):
    # type: (...) -> _Future

    """Appends ``data`` to the state at ``state_key``; returns a ``_Future``."""
    raise NotImplementedError(type(self))

  @abc.abstractmethod
  def clear(self, state_key):
    # type: (beam_fn_api_pb2.StateKey) -> _Future

    """Clears the state at ``state_key``; returns a ``_Future``."""
    raise NotImplementedError(type(self))
class StateHandlerFactory(with_metaclass(abc.ABCMeta,
                                         object)):  # type: ignore[misc]
  """An abstract factory for creating state handlers."""
  @abc.abstractmethod
  def create_state_handler(self, api_service_descriptor):
    # type: (endpoints_pb2.ApiServiceDescriptor) -> CachingStateHandler

    """Returns a ``StateHandler`` from the given ApiServiceDescriptor."""
    raise NotImplementedError(type(self))

  @abc.abstractmethod
  def close(self):
    # type: () -> None

    """Close all channels that this factory owns."""
    raise NotImplementedError(type(self))
class GrpcStateHandlerFactory(StateHandlerFactory):
  """Builds gRPC-backed state handlers, one per state service url.

  Handlers are created on first request for a url and reused afterwards.
  """
  def __init__(self, state_cache, credentials=None):
    self._state_cache = state_cache
    self._credentials = credentials
    self._lock = threading.Lock()
    self._state_handler_cache = {}  # type: Dict[str, CachingStateHandler]
    self._throwing_state_handler = ThrowingStateHandler()

  def create_state_handler(self, api_service_descriptor):
    # type: (endpoints_pb2.ApiServiceDescriptor) -> CachingStateHandler

    """Returns the handler for the descriptor's url, creating it if needed.

    A falsy descriptor yields the shared handler that raises on every request.
    """
    if not api_service_descriptor:
      return self._throwing_state_handler

    url = api_service_descriptor.url
    if url not in self._state_handler_cache:
      with self._lock:
        # Check again under the lock: another thread may have created the
        # handler between the first check and acquiring the lock.
        if url not in self._state_handler_cache:
          self._state_handler_cache[url] = self._new_caching_handler(url)
    return self._state_handler_cache[url]

  def _new_caching_handler(self, url):
    # type: (str) -> CachingStateHandler

    """Opens a state channel to ``url`` and wraps it in a caching handler."""
    # Options to have no limits (-1) on the size of the messages
    # received or sent over the data plane. The actual buffer size is
    # controlled in a layer above.
    channel_options = [('grpc.max_receive_message_length', -1),
                       ('grpc.max_send_message_length', -1)]
    if self._credentials is not None:
      _LOGGER.info('Creating secure state channel for %s.', url)
      channel = GRPCChannelFactory.secure_channel(
          url, self._credentials, options=channel_options)
    else:
      _LOGGER.info('Creating insecure state channel for %s.', url)
      channel = GRPCChannelFactory.insecure_channel(
          url, options=channel_options)
    _LOGGER.info('State channel established.')
    # Add workerId to the grpc channel
    channel = grpc.intercept_channel(channel, WorkerIdInterceptor())
    state_stub = beam_fn_api_pb2_grpc.BeamFnStateStub(channel)
    return CachingStateHandler(self._state_cache, GrpcStateHandler(state_stub))

  def close(self):
    # type: () -> None

    """Signals every cached handler to finish and empties both caches."""
    _LOGGER.info('Closing all cached gRPC state handlers.')
    for state_handler in self._state_handler_cache.values():
      state_handler.done()
    self._state_handler_cache.clear()
    self._state_cache.evict_all()
class ThrowingStateHandler(StateHandler):
  """A state handler that errors on any requests.

  The method signatures follow the abstract ``StateHandler`` interface so that
  interface-shaped calls reach the intended ``RuntimeError`` instead of
  failing with a ``TypeError`` about argument counts.
  """
  def get_raw(self, state_key, continuation_token=None):
    """Always raises ``RuntimeError``; ``continuation_token`` is ignored."""
    raise RuntimeError(
        'Unable to handle state requests for ProcessBundleDescriptor without '
        'state ApiServiceDescriptor for state key %s.' % state_key)

  def append_raw(self, state_key, data, *unused_legacy_args):
    """Always raises ``RuntimeError``.

    Extra positional arguments are accepted (and ignored) so that the former
    three-argument call shape keeps working.
    """
    raise RuntimeError(
        'Unable to handle state requests for ProcessBundleDescriptor without '
        'state ApiServiceDescriptor for state key %s.' % state_key)

  def clear(self, state_key):
    """Always raises ``RuntimeError``."""
    raise RuntimeError(
        'Unable to handle state requests for ProcessBundleDescriptor without '
        'state ApiServiceDescriptor for state key %s.' % state_key)
class GrpcStateHandler(StateHandler):
  """A ``StateHandler`` that talks to a state service over a gRPC stream.

  Requests are put on a queue that feeds the ``State`` stream; a daemon thread
  reads the responses and completes the ``_Future`` registered for each
  request id.
  """

  # Sentinel put on the request queue to end the request stream.
  _DONE = object()

  def __init__(self, state_stub):
    # type: (beam_fn_api_pb2_grpc.BeamFnStateStub) -> None

    """Creates the handler and immediately starts the streaming call."""
    self._lock = threading.Lock()
    self._state_stub = state_stub
    self._requests = queue.Queue(
    )  # type: queue.Queue[beam_fn_api_pb2.StateRequest]
    self._responses_by_id = {}  # type: Dict[str, _Future]
    self._last_id = 0
    # exc_info of a failure in the response reader thread, if any.
    self._exc_info = None
    # Per-thread binding of the instruction id stamped on requests.
    self._context = threading.local()
    self.start()

  @contextlib.contextmanager
  def process_instruction_id(self, bundle_id):
    """Binds ``bundle_id`` to the calling thread for the duration of the body.

    Raises:
      RuntimeError: if the calling thread is already bound to an id.
    """
    if getattr(self._context, 'process_instruction_id', None) is not None:
      raise RuntimeError(
          'Already bound to %r' % self._context.process_instruction_id)
    self._context.process_instruction_id = bundle_id
    try:
      yield
    finally:
      self._context.process_instruction_id = None

  def start(self):
    """Opens the ``State`` stream and starts the response reader thread."""
    self._done = False

    def request_iter():
      while True:
        request = self._requests.get()
        if request is self._DONE or self._done:
          break
        yield request

    responses = self._state_stub.State(request_iter())

    def pull_responses():
      try:
        for response in responses:
          # Popping an item from a dictionary is atomic in cPython
          future = self._responses_by_id.pop(response.id)
          future.set(response)
          if self._done:
            break
      except:  # pylint: disable=bare-except
        # Remember the failure so blocked requesters can re-raise it.
        self._exc_info = sys.exc_info()
        raise

    reader = threading.Thread(target=pull_responses, name='read_state')
    reader.daemon = True
    reader.start()

  def done(self):
    """Marks the handler as finished and unblocks the request stream."""
    self._done = True
    self._requests.put(self._DONE)

  def get_raw(self,
              state_key,  # type: beam_fn_api_pb2.StateKey
              continuation_token=None  # type: Optional[bytes]
             ):
    # type: (...) -> Tuple[bytes, Optional[bytes]]

    """Blocks for a get response; returns (data, continuation token)."""
    response = self._blocking_request(
        beam_fn_api_pb2.StateRequest(
            state_key=state_key,
            get=beam_fn_api_pb2.StateGetRequest(
                continuation_token=continuation_token)))
    return response.get.data, response.get.continuation_token

  def append_raw(self,
                 state_key,  # type: Optional[beam_fn_api_pb2.StateKey]
                 data  # type: bytes
                ):
    # type: (...) -> _Future

    """Sends an append request without waiting; returns its ``_Future``."""
    return self._request(
        beam_fn_api_pb2.StateRequest(
            state_key=state_key,
            append=beam_fn_api_pb2.StateAppendRequest(data=data)))

  def clear(self, state_key):
    # type: (Optional[beam_fn_api_pb2.StateKey]) -> _Future

    """Sends a clear request without waiting; returns its ``_Future``."""
    return self._request(
        beam_fn_api_pb2.StateRequest(
            state_key=state_key, clear=beam_fn_api_pb2.StateClearRequest()))

  def _request(self, request):
    # type: (beam_fn_api_pb2.StateRequest) -> _Future

    """Stamps ``request`` with a fresh id and the bound instruction id, then
    queues it and returns the future that will receive the response."""
    request.id = self._next_id()
    request.instruction_id = self._context.process_instruction_id
    # Adding a new item to a dictionary is atomic in cPython
    self._responses_by_id[request.id] = future = _Future()
    # Request queue is thread-safe
    self._requests.put(request)
    return future

  def _blocking_request(self, request):
    # type: (beam_fn_api_pb2.StateRequest) -> beam_fn_api_pb2.StateResponse

    """Sends ``request`` and waits for its response.

    Raises:
      RuntimeError: if the handler is closed while waiting, or if the response
        carries an error. A failure recorded by the reader thread is re-raised
        with its original type and traceback.
    """
    req_future = self._request(request)
    # Wake up every second to notice a dead reader thread or a closed handler.
    while not req_future.wait(timeout=1):
      if self._exc_info:
        t, v, tb = self._exc_info
        raise_(t, v, tb)
      elif self._done:
        raise RuntimeError(
            'State handler was closed while waiting for the response to '
            'state request %s.' % request.id)
    response = req_future.get()
    if response.error:
      raise RuntimeError(response.error)
    else:
      return response

  def _next_id(self):
    # type: () -> str

    """Returns the next request id as a string."""
    with self._lock:
      # Use a lock here because this GrpcStateHandler is shared across all
      # requests which have the same process bundle descriptor. State requests
      # can concurrently access this section if a Runner uses threads / workers
      # (aka "parallelism") to send data to this SdkHarness and its workers.
      self._last_id += 1
      request_id = self._last_id
    return str(request_id)
class CachingStateHandler(object):
""" A State handler which retrieves and caches state.
If caching is activated, caches across bundles using a supplied cache token.
If activated but no cache token is supplied, caching is done at the bundle
level.
"""
def __init__(self,
global_state_cache, # type: StateCache
underlying_state # type: StateHandler
):
self._underlying = underlying_state
self._state_cache = global_state_cache
self._context = threading.local()
@contextlib.contextmanager
def process_instruction_id(self, bundle_id, cache_tokens):
if getattr(self._context, 'user_state_cache_token', None) is not None:
raise RuntimeError(
'Cache tokens already set to %s' %
self._context.user_state_cache_token)
self._context.side_input_cache_tokens = {}
user_state_cache_token = None
for cache_token_struct in cache_tokens:
if cache_token_struct.HasField("user_state"):
# There should only be one user state token present
assert not user_state_cache_token
user_state_cache_token = cache_token_struct.token
elif cache_token_struct.HasField("side_input"):
self._context.side_input_cache_tokens[
cache_token_struct.side_input.transform_id,
cache_token_struct.side_input.
side_input_id] = cache_token_struct.token
# TODO: Consider a two-level cache to avoid extra logic and locking
# for items cached at the bundle level.
self._context.bundle_cache_token = bundle_id
try:
self._state_cache.initialize_metrics()
self._context.user_state_cache_token = user_state_cache_token
with self._underlying.process_instruction_id(bundle_id):
yield
finally:
self._context.side_input_cache_tokens = {}
self._context.user_state_cache_token = None
self._context.bundle_cache_token = None
def blocking_get(self,
state_key, # type: beam_fn_api_pb2.StateKey
coder, # type: coder_impl.CoderImpl
):
# type: (...) -> Iterable[Any]
cache_token = self._get_cache_token(state_key)
if not cache_token:
# Cache disabled / no cache token. Can't do a lookup/store in the cache.
# Fall back to lazily materializing the state, one element at a time.
return self._lazy_iterator(state_key, coder)
# Cache lookup
cache_state_key = self._convert_to_cache_key(state_key)
cached_value = self._state_cache.get(cache_state_key, cache_token)
if cached_value is None:
# Cache miss, need to retrieve from the Runner
# Further size estimation or the use of the continuation token on the
# runner side could fall back to materializing one item at a time.
# https://jira.apache.org/jira/browse/BEAM-8297
materialized = cached_value = (
self._partially_cached_iterable(state_key, coder))
if isinstance(materialized, (list, self.ContinuationIterable)):
self._state_cache.put(cache_state_key, cache_token, materialized)
else:
_LOGGER.error(
"Uncacheable type %s for key %s. Not caching.",
materialized,
state_key)
return cached_value
def extend(self,
           state_key,  # type: beam_fn_api_pb2.StateKey
           coder,  # type: coder_impl.CoderImpl
           elements,  # type: Iterable[Any]
          ):
  # type: (...) -> _Future

  """Appends elements to the state under state_key, updating the cache.

  When a cache token applies, a fully cached list is extended in place so it
  stays consistent with what is written to the underlying handler. The
  elements are then encoded with the coder and appended via the underlying
  state handler.

  :param state_key: key identifying the state to append to.
  :param coder: coder used to encode each element.
  :param elements: the elements to append; materialized into a list when the
    cached value is a list, because they are iterated twice in that case.
  :return The result of the underlying handler's append_raw call.
  :raises Exception: if the cached value is neither a list nor a
    ContinuationIterable.
  """
  cache_token = self._get_cache_token(state_key)
  if cache_token:
    # Update the cache
    cache_key = self._convert_to_cache_key(state_key)
    cached_value = self._state_cache.get(cache_key, cache_token)
    # Keep in mind that the state for this key can be evicted
    # while executing this function. Either read or write to the cache
    # but never do both here!
    if cached_value is None:
      # We have never cached this key before, first retrieve state
      cached_value = self.blocking_get(state_key, coder)
    # Just extend the already cached value
    if isinstance(cached_value, list):
      # Materialize provided iterable to ensure reproducible iterations,
      # here and when writing to the state handler below.
      elements = list(elements)
      # The state is fully cached and can be extended
      cached_value.extend(elements)
    elif isinstance(cached_value, self.ContinuationIterable):
      # The state is too large to be fully cached (continuation token used),
      # only the first part is cached, the rest is enumerated via the runner.
      pass
    else:
      # When a corrupt value made it into the cache, we have to fail.
      # The value is wrapped in a 1-tuple so that a tuple-valued cache entry
      # is formatted as a whole instead of being unpacked as format arguments
      # (which would raise a misleading TypeError).
      raise Exception("Unexpected cached value: %s" % (cached_value, ))
  # Write to state handler
  out = coder_impl.create_OutputStream()
  for element in elements:
    coder.encode_to_stream(element, out, True)
  return self._underlying.append_raw(state_key, out.get())
def clear(self, state_key):
  # type: (beam_fn_api_pb2.StateKey) -> _Future

  """Clears the state under state_key, in the cache and in the handler.

  The cache entry is only touched when a cache token applies to the key.

  :param state_key: key identifying the state to clear.
  :return The result of the underlying handler's clear call.
  """
  token = self._get_cache_token(state_key)
  if token:
    self._state_cache.clear(self._convert_to_cache_key(state_key), token)
  return self._underlying.clear(state_key)
def done(self):
  # type: () -> None

  """Forwards the done() call to the underlying state handler."""
  underlying = self._underlying
  underlying.done()
def _lazy_iterator(
    self,
    state_key,  # type: beam_fn_api_pb2.StateKey
    coder,  # type: coder_impl.CoderImpl
    continuation_token=None  # type: Optional[bytes]
):
  # type: (...) -> Iterator[Any]

  """Streams the state element by element, fetching one page at a time.

  Each page is requested from the underlying handler only once the previous
  page has been fully decoded; iteration ends after the first page that comes
  back without a continuation token.

  :param state_key: key identifying the state to read.
  :param coder: coder used to decode each element from the raw page.
  :param continuation_token: token to resume from, or None to start at the
    beginning.
  :return A generator which returns the next element if advanced.
  """
  next_token = continuation_token
  has_more = True
  while has_more:
    page, next_token = self._underlying.get_raw(state_key, next_token)
    stream = coder_impl.create_InputStream(page)
    while stream.size() > 0:
      yield coder.decode_from_stream(stream, True)
    # A falsy token means the page just decoded was the last one.
    has_more = bool(next_token)
def _get_cache_token(self, state_key):
  """Returns the cache token that applies to state_key, or None.

  None is returned when the cache is disabled, or when the key is neither bag
  user state nor one of the ``*_side_input`` types (including a key whose
  ``type`` oneof is not set at all).

  For bag user state the user state cache token is preferred, with the
  bundle cache token as a fallback. For side inputs the token registered for
  the (transform_id, side_input_id) pair is used, again falling back to the
  bundle cache token.

  :param state_key: the state key to find a cache token for.
  :return The applicable cache token, or None if the key must not be cached.
  """
  if not self._state_cache.is_cache_enabled():
    return None
  if state_key.HasField('bag_user_state'):
    return (
        self._context.user_state_cache_token or
        self._context.bundle_cache_token)
  # WhichOneof returns None when no member of the oneof is set; guard against
  # that instead of failing with an AttributeError on None.endswith.
  state_type = state_key.WhichOneof('type')
  if state_type is not None and state_type.endswith('_side_input'):
    side_input = getattr(state_key, state_type)
    return self._context.side_input_cache_tokens.get(
        (side_input.transform_id, side_input.side_input_id),
        self._context.bundle_cache_token)
  return None
def _partially_cached_iterable(
    self,
    state_key,  # type: beam_fn_api_pb2.StateKey
    coder  # type: coder_impl.CoderImpl
):
  # type: (...) -> Iterable[Any]

  """Materializes the first page of data, concatenated with a lazy iterable
  of the rest, if any.

  :param state_key: key identifying the state to read.
  :param coder: coder used to decode the raw state bytes.
  :return A plain list when the first page holds the complete state,
    otherwise a ContinuationIterable wrapping that list and a callable that
    lazily iterates the remaining pages.
  """
  first_page, next_token = self._underlying.get_raw(state_key, None)
  stream = coder_impl.create_InputStream(first_page)
  decoded = []
  while stream.size() > 0:
    decoded.append(coder.decode_from_stream(stream, True))
  if next_token:
    # More pages exist: defer fetching them until the iterable is consumed.
    remainder = functools.partial(
        self._lazy_iterator, state_key, coder, next_token)
    return self.ContinuationIterable(decoded, remainder)
  return decoded
class ContinuationIterable(object):
  """Iterable over an in-memory head followed by a lazily produced remainder.

  Every iteration first yields the items of ``head`` and only then calls
  ``continue_iterator_fn`` to obtain an iterable over the remaining items.
  """
  def __init__(self, head, continue_iterator_fn):
    # head: the items that are already materialized.
    # continue_iterator_fn: zero-argument callable returning an iterable over
    # the items that follow head.
    self.head, self.continue_iterator_fn = head, continue_iterator_fn

  def __iter__(self):
    for materialized_item in self.head:
      yield materialized_item
    # The remainder is requested only after the head has been exhausted.
    remainder = self.continue_iterator_fn()
    for fetched_item in remainder:
      yield fetched_item
@staticmethod
def _convert_to_cache_key(state_key):
  """Returns the serialized form of state_key, used as the cache key."""
  serialized_key = state_key.SerializeToString()
  return serialized_key
class _Future(object):
  """Minimal future used to implement blocking requests.

  A value is published once with set(); get() blocks until it is available
  or the timeout expires.
  """
  def __init__(self):
    # Signalled by set() once the value has been stored.
    self._event = threading.Event()

  def wait(self, timeout=None):
    """Blocks until a value was set or the timeout elapsed; returns whether
    a value is available."""
    return self._event.wait(timeout)

  def get(self, timeout=None):
    """Returns the value, raising LookupError if none was set in time."""
    if not self.wait(timeout):
      raise LookupError()
    return self._value

  def set(self, value):
    """Stores the value and wakes up all waiters."""
    # Assign before signalling so that woken readers observe the value.
    self._value = value
    self._event.set()

  @classmethod
  def done(cls):
    # type: () -> _Future

    """Returns a shared future that is already completed with None."""
    try:
      return cls.DONE  # type: ignore[attr-defined]
    except AttributeError:
      pass
    completed = _Future()
    completed.set(None)
    cls.DONE = completed  # type: ignore[attr-defined]
    return cls.DONE  # type: ignore[attr-defined]
class KeyedDefaultDict(collections.defaultdict):
  """A defaultdict whose default_factory is called with the missing key.

  Unlike collections.defaultdict, whose factory takes no arguments, the
  factory here receives the key being looked up, so the default value can
  depend on it. The created value is stored under the key and returned.
  """
  def __missing__(self, key):
    """Creates, stores and returns the default value for a missing key.

    :raises KeyError: if no default_factory was configured, matching the
      behavior of collections.defaultdict.
    """
    if self.default_factory is None:
      raise KeyError(key)
    value = self.default_factory(key)
    self[key] = value
    return value