| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| """SDK harness for executing Python Fns via the Fn API.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import abc |
| import collections |
| import contextlib |
| import logging |
| import queue |
| import sys |
| import threading |
| import time |
| import traceback |
| from builtins import object |
| from builtins import range |
| from concurrent import futures |
| |
| import grpc |
| from future.utils import raise_ |
| from future.utils import with_metaclass |
| |
| from apache_beam.coders import coder_impl |
| from apache_beam.portability.api import beam_fn_api_pb2 |
| from apache_beam.portability.api import beam_fn_api_pb2_grpc |
| from apache_beam.runners.worker import bundle_processor |
| from apache_beam.runners.worker import data_plane |
| from apache_beam.runners.worker.channel_factory import GRPCChannelFactory |
| from apache_beam.runners.worker.statecache import StateCache |
| from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor |
| |
| # This SDK harness will (by default), log a "lull" in processing if it sees no |
| # transitions in over 5 minutes. |
# 5 minutes * 60 seconds * 1000 millis * 1000 micros * 1000 nanoseconds
| DEFAULT_LOG_LULL_TIMEOUT_NS = 5 * 60 * 1000 * 1000 * 1000 |
| |
| |
class SdkHarness(object):
  """Manages the Fn API control connection for an SDK worker process.

  Reads InstructionRequests from the runner's control channel, dispatches
  each to a handler method named after the request's oneof field (prefixed
  with ``_request_``), and streams InstructionResponses back to the runner.
  """

  REQUEST_METHOD_PREFIX = '_request_'
  SCHEDULING_DELAY_THRESHOLD_SEC = 5*60  # 5 Minutes

  def __init__(
      self, control_address, worker_count,
      credentials=None,
      worker_id=None,
      # Caching is disabled by default
      state_cache_size=0,
      profiler_factory=None):
    """Args:
      control_address: gRPC address of the runner's control service.
      worker_count: number of ``SdkWorker``s (and process threads) to create.
      credentials: optional gRPC credentials; insecure channel when None.
      worker_id: optional worker id attached to channels via interceptors.
      state_cache_size: maximum number of cached state values (0 disables).
      profiler_factory: optional callable instruction_id -> profiler, passed
        through to each ``SdkWorker``.
    """
    self._alive = True
    self._worker_count = worker_count
    self._worker_index = 0
    self._worker_id = worker_id
    if credentials is None:
      logging.info('Creating insecure control channel for %s.', control_address)
      self._control_channel = GRPCChannelFactory.insecure_channel(
          control_address)
    else:
      logging.info('Creating secure control channel for %s.', control_address)
      self._control_channel = GRPCChannelFactory.secure_channel(
          control_address, credentials)
    # Block until the runner's control service is reachable.
    grpc.channel_ready_future(self._control_channel).result(timeout=60)
    logging.info('Control channel established.')

    self._control_channel = grpc.intercept_channel(
        self._control_channel, WorkerIdInterceptor(self._worker_id))
    self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
        credentials, self._worker_id)
    self._state_handler_factory = GrpcStateHandlerFactory(state_cache_size,
                                                          credentials)
    self._profiler_factory = profiler_factory
    self._fns = {}
    # BundleProcessor cache across all workers.
    self._bundle_processor_cache = BundleProcessorCache(
        state_handler_factory=self._state_handler_factory,
        data_channel_factory=self._data_channel_factory,
        fns=self._fns)
    # workers for process/finalize bundle.
    self.workers = queue.Queue()
    # one worker for progress/split request.
    self.progress_worker = SdkWorker(self._bundle_processor_cache,
                                     profiler_factory=self._profiler_factory)
    # one thread is enough for getting the progress report.
    # Assumption:
    # Progress report generation should not do IO or wait on other resources.
    #  Without wait, having multiple threads will not improve performance and
    #  will only add complexity.
    self._progress_thread_pool = futures.ThreadPoolExecutor(max_workers=1)
    # finalize and process share one thread pool.
    self._process_thread_pool = futures.ThreadPoolExecutor(
        max_workers=self._worker_count)
    self._responses = queue.Queue()
    self._process_bundle_queue = queue.Queue()
    # instruction_id -> enqueue time for bundles not yet picked up by a
    # worker thread; consumed by the scheduling-delay monitor thread.
    self._unscheduled_process_bundle = {}
    logging.info('Initializing SDKHarness with %s workers.', self._worker_count)

  def run(self):
    """Runs the control loop until the runner closes the control stream."""
    control_stub = beam_fn_api_pb2_grpc.BeamFnControlStub(self._control_channel)
    no_more_work = object()

    # Create process workers
    for _ in range(self._worker_count):
      # SdkHarness manages function registration and shares self._fns with all
      # the workers. This is needed because function registration (register)
      # and execution (process_bundle) are sent over different requests and we
      # do not really know which worker is going to process bundle
      # for a function till we get process_bundle request. Moreover
      # same function is reused by different process bundle calls and
      # potentially get executed by different worker. Hence we need a
      # centralized function list shared among all the workers.
      self.workers.put(
          SdkWorker(self._bundle_processor_cache,
                    profiler_factory=self._profiler_factory))

    def get_responses():
      # Feeds responses to the runner; terminates once the no_more_work
      # sentinel is queued during shutdown below.
      while True:
        response = self._responses.get()
        if response is no_more_work:
          return
        yield response

    self._alive = True
    monitoring_thread = threading.Thread(name='SdkHarness_monitor',
                                         target=self._monitor_process_bundle)
    monitoring_thread.daemon = True
    monitoring_thread.start()

    try:
      for work_request in control_stub.Control(get_responses()):
        logging.debug('Got work %s', work_request.instruction_id)
        request_type = work_request.WhichOneof('request')
        # Name spacing the request method with 'request_'. The called method
        # will be like self.request_register(request)
        getattr(self, SdkHarness.REQUEST_METHOD_PREFIX + request_type)(
            work_request)
    finally:
      self._alive = False

    logging.info('No more requests from control plane')
    logging.info('SDK Harness waiting for in-flight requests to complete')
    # Wait until existing requests are processed.
    self._progress_thread_pool.shutdown()
    self._process_thread_pool.shutdown()
    # get_responses may be blocked on responses.get(), but we need to return
    # control to its caller.
    self._responses.put(no_more_work)
    # Stop all the workers and clean all the associated resources
    self._data_channel_factory.close()
    self._state_handler_factory.close()
    logging.info('Done consuming work.')

  def _execute(self, task, request):
    """Runs task() and queues its InstructionResponse; on failure queues an
    error response carrying the formatted traceback instead."""
    try:
      response = task()
    except Exception:  # pylint: disable=broad-except
      traceback_string = traceback.format_exc()
      print(traceback_string, file=sys.stderr)
      logging.error(
          'Error processing instruction %s. Original traceback is\n%s\n',
          request.instruction_id, traceback_string)
      response = beam_fn_api_pb2.InstructionResponse(
          instruction_id=request.instruction_id, error=traceback_string)
    self._responses.put(response)

  def _request_register(self, request):
    """Handles a register request by storing the bundle descriptors."""

    def task():
      for process_bundle_descriptor in getattr(
          request, request.WhichOneof('request')).process_bundle_descriptor:
        self._fns[process_bundle_descriptor.id] = process_bundle_descriptor

      return beam_fn_api_pb2.InstructionResponse(
          instruction_id=request.instruction_id,
          register=beam_fn_api_pb2.RegisterResponse())

    self._execute(task, request)

  def _request_process_bundle(self, request):
    """Schedules a process_bundle request onto the worker thread pool."""

    def task():
      # Take the free worker. Wait till a worker is free.
      worker = self.workers.get()
      # Get the first work item in the queue
      work = self._process_bundle_queue.get()
      # The bundle is now scheduled; stop tracking it as unscheduled.
      self._unscheduled_process_bundle.pop(work.instruction_id, None)
      try:
        self._execute(lambda: worker.do_instruction(work), work)
      finally:
        # Put the worker back in the free worker pool
        self.workers.put(worker)
    # Create a task for each process_bundle request and schedule it
    self._process_bundle_queue.put(request)
    self._unscheduled_process_bundle[request.instruction_id] = time.time()
    self._process_thread_pool.submit(task)
    # Use lazy %-style args so the message is only rendered when DEBUG is on.
    logging.debug(
        'Currently using %s threads.', len(self._process_thread_pool._threads))

  def _request_process_bundle_split(self, request):
    self._request_process_bundle_action(request)

  def _request_process_bundle_progress(self, request):
    self._request_process_bundle_action(request)

  def _request_process_bundle_action(self, request):
    """Handles a progress or split request on the dedicated progress thread."""

    def task():
      instruction_id = getattr(
          request, request.WhichOneof('request')).instruction_id
      # only process progress/split request when a bundle is in processing.
      if (instruction_id in
          self._bundle_processor_cache.active_bundle_processors):
        self._execute(
            lambda: self.progress_worker.do_instruction(request), request)
      else:
        self._execute(lambda: beam_fn_api_pb2.InstructionResponse(
            instruction_id=request.instruction_id, error=(
                'Process bundle request not yet scheduled for instruction {}' if
                instruction_id in self._unscheduled_process_bundle else
                'Unknown process bundle instruction {}').format(
                    instruction_id)), request)

    self._progress_thread_pool.submit(task)

  def _request_finalize_bundle(self, request):
    """Schedules a finalize_bundle request onto the worker thread pool."""

    def task():
      # Get one available worker.
      worker = self.workers.get()
      try:
        self._execute(
            lambda: worker.do_instruction(request), request)
      finally:
        # Put the worker back in the free worker pool.
        self.workers.put(worker)

    self._process_thread_pool.submit(task)

  def _monitor_process_bundle(self):
    """
    Monitor the unscheduled bundles and log if a bundle is not scheduled for
    more than SCHEDULING_DELAY_THRESHOLD_SEC.
    """
    while self._alive:
      time.sleep(SdkHarness.SCHEDULING_DELAY_THRESHOLD_SEC)
      # Check for bundles to be scheduled.
      if self._unscheduled_process_bundle:
        current_time = time.time()
        # Iterate over a snapshot of the keys: worker threads pop entries
        # concurrently, and mutating a dict while iterating it raises
        # RuntimeError in Python 3.
        for instruction_id in list(self._unscheduled_process_bundle):
          request_time = None
          try:
            request_time = self._unscheduled_process_bundle[instruction_id]
          except KeyError:
            # Scheduled between the snapshot and this lookup; nothing to log.
            pass
          if request_time:
            scheduling_delay = current_time - request_time
            if scheduling_delay > SdkHarness.SCHEDULING_DELAY_THRESHOLD_SEC:
              # logging.warn is deprecated; use warning().
              logging.warning('Unable to schedule instruction %s for %s',
                              instruction_id, scheduling_delay)
| |
| |
class BundleProcessorCache(object):
  """A cache for ``BundleProcessor``s.

  ``BundleProcessor`` objects are cached by the id of their
  ``beam_fn_api_pb2.ProcessBundleDescriptor``.

  Attributes:
    fns (dict): A dictionary that maps bundle descriptor IDs to instances of
      ``beam_fn_api_pb2.ProcessBundleDescriptor``.
    state_handler_factory (``StateHandlerFactory``): Used to create state
      handlers to be used by a ``bundle_processor.BundleProcessor`` during
      processing.
    data_channel_factory (``data_plane.DataChannelFactory``)
    active_bundle_processors (dict): A dictionary, indexed by instruction IDs,
      containing ``bundle_processor.BundleProcessor`` objects that are currently
      active processing the corresponding instruction.
    cached_bundle_processors (dict): A dictionary, indexed by bundle processor
      id, of cached ``bundle_processor.BundleProcessor`` that are not currently
      performing processing.
  """

  def __init__(self, state_handler_factory, data_channel_factory, fns):
    self.fns = fns
    self.state_handler_factory = state_handler_factory
    self.data_channel_factory = data_channel_factory
    self.active_bundle_processors = {}
    self.cached_bundle_processors = collections.defaultdict(list)

  def register(self, bundle_descriptor):
    """Register a ``beam_fn_api_pb2.ProcessBundleDescriptor`` by its id."""
    self.fns[bundle_descriptor.id] = bundle_descriptor

  def get(self, instruction_id, bundle_descriptor_id):
    """Returns an idle (or newly created) processor for the descriptor and
    marks it active for the given instruction."""
    try:
      # pop() is threadsafe
      processor = self.cached_bundle_processors[bundle_descriptor_id].pop()
    except IndexError:
      # No idle processor available; build a fresh one.
      processor = bundle_processor.BundleProcessor(
          self.fns[bundle_descriptor_id],
          self.state_handler_factory.create_state_handler(
              self.fns[bundle_descriptor_id].state_api_service_descriptor),
          self.data_channel_factory)
    self.active_bundle_processors[
        instruction_id] = bundle_descriptor_id, processor
    return processor

  def lookup(self, instruction_id):
    """Returns the active processor for the instruction, or None."""
    return self.active_bundle_processors.get(instruction_id, (None, None))[-1]

  def discard(self, instruction_id):
    """Shuts down and forgets the processor (used after failures)."""
    self.active_bundle_processors[instruction_id][1].shutdown()
    del self.active_bundle_processors[instruction_id]

  def release(self, instruction_id):
    """Resets the processor and returns it to the idle cache for reuse."""
    descriptor_id, processor = self.active_bundle_processors.pop(instruction_id)
    processor.reset()
    self.cached_bundle_processors[descriptor_id].append(processor)

  def shutdown(self):
    """Shuts down all active and cached processors."""
    # Iterate over a snapshot of the keys: deleting entries while iterating
    # the dict itself raises RuntimeError in Python 3.
    for instruction_id in list(self.active_bundle_processors):
      self.active_bundle_processors[instruction_id][1].shutdown()
      del self.active_bundle_processors[instruction_id]
    for cached_bundle_processors in self.cached_bundle_processors.values():
      while len(cached_bundle_processors) > 0:
        cached_bundle_processors.pop().shutdown()
| |
| |
class SdkWorker(object):
  """Executes single instructions against the shared ``BundleProcessorCache``.

  A worker handles register, process_bundle, progress, split and finalize
  requests; the harness runs several of these concurrently, all sharing one
  ``BundleProcessorCache``.
  """

  def __init__(self,
               bundle_processor_cache,
               profiler_factory=None,
               log_lull_timeout_ns=None):
    # Shared cache of active/idle BundleProcessors (see BundleProcessorCache).
    self.bundle_processor_cache = bundle_processor_cache
    # Optional callable instruction_id -> profiler context manager (or None);
    # consulted by maybe_profile().
    self.profiler_factory = profiler_factory
    # Nanoseconds without a state transition before a "lull" is logged.
    self.log_lull_timeout_ns = (log_lull_timeout_ns
                                or DEFAULT_LOG_LULL_TIMEOUT_NS)

  def do_instruction(self, request):
    """Dispatches the request to the method named after its set oneof field.

    Raises:
      NotImplementedError: if no oneof field is set on the request.
    """
    request_type = request.WhichOneof('request')
    if request_type:
      # E.g. if register is set, this will call self.register(request.register))
      return getattr(self, request_type)(getattr(request, request_type),
                                         request.instruction_id)
    else:
      raise NotImplementedError

  def register(self, request, instruction_id):
    """Registers a set of ``beam_fn_api_pb2.ProcessBundleDescriptor``s.

    This set of ``beam_fn_api_pb2.ProcessBundleDescriptor`` come as part of a
    ``beam_fn_api_pb2.RegisterRequest``, which the runner sends to the SDK
    worker before starting processing to register stages.
    """

    for process_bundle_descriptor in request.process_bundle_descriptor:
      self.bundle_processor_cache.register(process_bundle_descriptor)
    return beam_fn_api_pb2.InstructionResponse(
        instruction_id=instruction_id,
        register=beam_fn_api_pb2.RegisterResponse())

  def process_bundle(self, request, instruction_id):
    """Processes one bundle and returns its ProcessBundleResponse.

    On success the processor is returned to the cache for reuse — unless the
    bundle requested finalization, in which case it stays active until
    finalize_bundle() releases it. On any error the processor is discarded.
    """
    bundle_processor = self.bundle_processor_cache.get(
        instruction_id, request.process_bundle_descriptor_id)
    try:
      with bundle_processor.state_handler.process_instruction_id(
          instruction_id, request.cache_tokens):
        with self.maybe_profile(instruction_id):
          delayed_applications, requests_finalization = (
              bundle_processor.process_bundle(instruction_id))
          response = beam_fn_api_pb2.InstructionResponse(
              instruction_id=instruction_id,
              process_bundle=beam_fn_api_pb2.ProcessBundleResponse(
                  residual_roots=delayed_applications,
                  metrics=bundle_processor.metrics(),
                  monitoring_infos=bundle_processor.monitoring_infos(),
                  requires_finalization=requests_finalization))
      # Don't release here if finalize is needed.
      if not requests_finalization:
        self.bundle_processor_cache.release(instruction_id)
      return response
    except:  # pylint: disable=broad-except
      # Don't re-use bundle processors on failure.
      self.bundle_processor_cache.discard(instruction_id)
      raise

  def process_bundle_split(self, request, instruction_id):
    """Attempts to split the named in-flight bundle; errors if not running."""
    processor = self.bundle_processor_cache.lookup(
        request.instruction_id)
    if processor:
      return beam_fn_api_pb2.InstructionResponse(
          instruction_id=instruction_id,
          process_bundle_split=processor.try_split(request))
    else:
      return beam_fn_api_pb2.InstructionResponse(
          instruction_id=instruction_id,
          error='Instruction not running: %s' % instruction_id)

  def _log_lull_in_bundle_processor(self, processor):
    """Logs a warning (with the stuck thread's stack, when available) if the
    processor has not made a state transition for log_lull_timeout_ns."""
    state_sampler = processor.state_sampler
    sampler_info = state_sampler.get_info()
    if (sampler_info
        and sampler_info.time_since_transition
        and sampler_info.time_since_transition > self.log_lull_timeout_ns):
      step_name = sampler_info.state_name.step_name
      state_name = sampler_info.state_name.name
      state_lull_log = (
          'There has been a processing lull of over %.2f seconds in state %s'
          % (sampler_info.time_since_transition / 1e9, state_name))
      step_name_log = (' in step %s ' % step_name) if step_name else ''

      exec_thread = getattr(sampler_info, 'tracked_thread', None)
      if exec_thread is not None:
        thread_frame = sys._current_frames().get(exec_thread.ident)  # pylint: disable=protected-access
        stack_trace = '\n'.join(
            traceback.format_stack(thread_frame)) if thread_frame else ''
      else:
        stack_trace = '-NOT AVAILABLE-'

      logging.warning(
          '%s%s. Traceback:\n%s', state_lull_log, step_name_log, stack_trace)

  def process_bundle_progress(self, request, instruction_id):
    # It is an error to get progress for a not-in-flight bundle.
    processor = self.bundle_processor_cache.lookup(request.instruction_id)
    if processor:
      self._log_lull_in_bundle_processor(processor)
    # NOTE: when the bundle is not in flight an empty (not error) response
    # is returned, with None/[] metrics.
    return beam_fn_api_pb2.InstructionResponse(
        instruction_id=instruction_id,
        process_bundle_progress=beam_fn_api_pb2.ProcessBundleProgressResponse(
            metrics=processor.metrics() if processor else None,
            monitoring_infos=processor.monitoring_infos() if processor else []))

  def finalize_bundle(self, request, instruction_id):
    """Finalizes a bundle awaiting finalization, then releases its processor.

    The processor is discarded if finalization raises; an error response is
    returned if the instruction is not active.
    """
    processor = self.bundle_processor_cache.lookup(
        request.instruction_id)
    if processor:
      try:
        finalize_response = processor.finalize_bundle()
        self.bundle_processor_cache.release(request.instruction_id)
        return beam_fn_api_pb2.InstructionResponse(
            instruction_id=instruction_id,
            finalize_bundle=finalize_response)
      except:  # pylint: disable=bare-except
        self.bundle_processor_cache.discard(request.instruction_id)
        raise
    else:
      return beam_fn_api_pb2.InstructionResponse(
          instruction_id=instruction_id,
          error='Instruction not running: %s' % instruction_id)

  def stop(self):
    """Shuts down all processors held by the shared cache."""
    self.bundle_processor_cache.shutdown()

  @contextlib.contextmanager
  def maybe_profile(self, instruction_id):
    """Wraps the body in a profiler from profiler_factory, when configured."""
    if self.profiler_factory:
      profiler = self.profiler_factory(instruction_id)
      if profiler:
        with profiler:
          yield
      else:
        yield
    else:
      yield
| |
| |
class StateHandlerFactory(with_metaclass(abc.ABCMeta, object)):
  """An abstract factory for creating ``StateHandler``s.

  (The previous docstring incorrectly said ``DataChannel``; this factory
  produces state handlers from ApiServiceDescriptors.)
  """

  @abc.abstractmethod
  def create_state_handler(self, api_service_descriptor):
    """Returns a ``StateHandler`` from the given ApiServiceDescriptor."""
    raise NotImplementedError(type(self))

  @abc.abstractmethod
  def close(self):
    """Close all channels that this factory owns."""
    raise NotImplementedError(type(self))
| |
| |
class GrpcStateHandlerFactory(StateHandlerFactory):
  """A factory for ``GrpcStateHandler``.

  Caches the created channels by ``state descriptor url``.
  """

  def __init__(self, state_cache_size, credentials=None):
    # url -> CachingMaterializingStateHandler, guarded by self._lock on write.
    self._state_handler_cache = {}
    self._lock = threading.Lock()
    # Returned when a descriptor has no state service; raises on any request.
    self._throwing_state_handler = ThrowingStateHandler()
    self._credentials = credentials
    self._state_cache = StateCache(state_cache_size)

  def create_state_handler(self, api_service_descriptor):
    """Returns a (cached) state handler for the given ApiServiceDescriptor.

    Without a descriptor, returns a handler that errors on every request.
    """
    if not api_service_descriptor:
      return self._throwing_state_handler
    url = api_service_descriptor.url
    if url not in self._state_handler_cache:
      with self._lock:
        # Double-checked locking: re-test under the lock so only one thread
        # creates the channel for a given url.
        if url not in self._state_handler_cache:
          # Options to have no limits (-1) on the size of the messages
          # received or sent over the data plane. The actual buffer size is
          # controlled in a layer above.
          options = [('grpc.max_receive_message_length', -1),
                     ('grpc.max_send_message_length', -1)]
          if self._credentials is None:
            logging.info('Creating insecure state channel for %s.', url)
            grpc_channel = GRPCChannelFactory.insecure_channel(
                url, options=options)
          else:
            logging.info('Creating secure state channel for %s.', url)
            grpc_channel = GRPCChannelFactory.secure_channel(
                url, self._credentials, options=options)
          logging.info('State channel established.')
          # Add workerId to the grpc channel
          grpc_channel = grpc.intercept_channel(grpc_channel,
                                                WorkerIdInterceptor())
          self._state_handler_cache[url] = CachingMaterializingStateHandler(
              self._state_cache,
              GrpcStateHandler(
                  beam_fn_api_pb2_grpc.BeamFnStateStub(grpc_channel)))
    return self._state_handler_cache[url]

  def close(self):
    """Shuts down all cached state handlers and evicts the value cache."""
    logging.info('Closing all cached gRPC state handlers.')
    # Only the handler objects are needed; iterate .values() rather than
    # unpacking and discarding the keys.
    for state_handler in self._state_handler_cache.values():
      state_handler.done()
    self._state_handler_cache.clear()
    self._state_cache.evict_all()
| |
| |
class ThrowingStateHandler(object):
  """A state handler that errors on any requests."""

  def _fail(self, state_key):
    # Every operation fails identically: without a state
    # ApiServiceDescriptor there is no state service to talk to.
    raise RuntimeError(
        'Unable to handle state requests for ProcessBundleDescriptor without '
        'state ApiServiceDescriptor for state key %s.' % state_key)

  def blocking_get(self, state_key, coder):
    self._fail(state_key)

  def append(self, state_key, coder, elements):
    self._fail(state_key)

  def clear(self, state_key):
    self._fail(state_key)
| |
| |
class GrpcStateHandler(object):
  """Low-level client for the Fn API State service.

  Requests are queued and sent over a single bi-directional gRPC stream; a
  background reader thread pairs each response with the pending ``_Future``
  registered under the same request id.
  """

  # Sentinel placed on the request queue to unblock the sender on shutdown.
  _DONE = object()

  def __init__(self, state_stub):
    self._state_stub = state_stub
    # Outgoing StateRequests, consumed by request_iter() in start().
    self._requests = queue.Queue()
    # request id -> _Future awaiting the matching StateResponse.
    self._responses_by_id = {}
    self._last_id = 0
    # Set by the reader thread on stream failure; re-raised on the caller's
    # thread by _blocking_request.
    self._exc_info = None
    # Thread-local holder for the instruction id bound via
    # process_instruction_id().
    self._context = threading.local()
    self.start()

  @contextlib.contextmanager
  def process_instruction_id(self, bundle_id):
    """Tags state requests made in this block (on this thread) with the
    given instruction id; nesting on the same thread is an error."""
    if getattr(self._context, 'process_instruction_id', None) is not None:
      raise RuntimeError(
          'Already bound to %r' % self._context.process_instruction_id)
    self._context.process_instruction_id = bundle_id
    try:
      yield
    finally:
      self._context.process_instruction_id = None

  def start(self):
    """Opens the State stream and starts the daemon response-reader thread."""
    self._done = False

    def request_iter():
      # Feeds queued requests to the gRPC stream until the _DONE sentinel
      # arrives (or done() has been called).
      while True:
        request = self._requests.get()
        if request is self._DONE or self._done:
          break
        yield request

    responses = self._state_stub.State(request_iter())

    def pull_responses():
      # Resolves each response's future; remembers any stream failure so
      # _blocking_request can surface it to callers.
      try:
        for response in responses:
          future = self._responses_by_id.pop(response.id)
          future.set(response)
          if self._done:
            break
      except:  # pylint: disable=bare-except
        self._exc_info = sys.exc_info()
        raise

    reader = threading.Thread(target=pull_responses, name='read_state')
    reader.daemon = True
    reader.start()

  def done(self):
    """Signals shutdown to both the sender and reader sides of the stream."""
    self._done = True
    self._requests.put(self._DONE)

  def get_raw(self, state_key, continuation_token=None):
    """Fetches one page of raw state data; returns (data, next_token)."""
    response = self._blocking_request(
        beam_fn_api_pb2.StateRequest(
            state_key=state_key,
            get=beam_fn_api_pb2.StateGetRequest(
                continuation_token=continuation_token)))
    return response.get.data, response.get.continuation_token

  def append_raw(self, state_key, data):
    """Appends encoded bytes to the state; returns a _Future (non-blocking)."""
    return self._request(
        beam_fn_api_pb2.StateRequest(
            state_key=state_key,
            append=beam_fn_api_pb2.StateAppendRequest(data=data)))

  def clear(self, state_key):
    """Clears the state for the key; returns a _Future (non-blocking)."""
    return self._request(
        beam_fn_api_pb2.StateRequest(
            state_key=state_key,
            clear=beam_fn_api_pb2.StateClearRequest()))

  def _request(self, request):
    """Assigns an id and instruction id, queues the request, and returns the
    _Future that will hold its response."""
    request.id = self._next_id()
    request.instruction_id = self._context.process_instruction_id
    self._responses_by_id[request.id] = future = _Future()
    self._requests.put(request)
    return future

  def _blocking_request(self, request):
    """Sends the request and waits for its response, polling so that stream
    failures and shutdown are noticed; raises on error responses."""
    req_future = self._request(request)
    while not req_future.wait(timeout=1):
      if self._exc_info:
        t, v, tb = self._exc_info
        raise_(t, v, tb)
      elif self._done:
        raise RuntimeError()
    response = req_future.get()
    if response.error:
      raise RuntimeError(response.error)
    else:
      return response

  def _next_id(self):
    # Ids are stringified counters; uniqueness per handler instance only.
    self._last_id += 1
    return str(self._last_id)
| |
| |
class CachingMaterializingStateHandler(object):
  """ A State handler which retrieves and caches state. """

  def __init__(self, global_state_cache, underlying_state):
    self._underlying = underlying_state
    self._state_cache = global_state_cache
    # Thread-local slot for the user-state cache token of the current bundle.
    self._context = threading.local()

  @contextlib.contextmanager
  def process_instruction_id(self, bundle_id, cache_tokens):
    """Binds the bundle's user-state cache token for the duration of the
    block and delegates instruction-id binding to the underlying handler."""
    if getattr(self._context, 'cache_token', None) is not None:
      raise RuntimeError(
          'Cache tokens already set to %s' % self._context.cache_token)
    # TODO Also handle cache tokens for side input, if present:
    # https://issues.apache.org/jira/browse/BEAM-8298
    user_token = None
    for token_struct in cache_tokens:
      if token_struct.HasField("user_state"):
        # There should only be one user state token present
        assert not user_token
        user_token = token_struct.token
    try:
      self._context.cache_token = user_token
      with self._underlying.process_instruction_id(bundle_id):
        yield
    finally:
      self._context.cache_token = None

  def blocking_get(self, state_key, coder, is_cached=False):
    """Reads state, serving from (and populating) the cache when allowed."""
    if self._should_be_cached(is_cached):
      cache_key = self._convert_to_cache_key(state_key)
      token = self._context.cache_token
      values = self._state_cache.get(cache_key, token)
      if values is None:
        # Cache miss: materialize everything from the runner and remember it.
        # TODO If caching is enabled, this materializes the entire state.
        # Further size estimation or the use of the continuation token on the
        # runner side could fall back to materializing one item at a time.
        # https://jira.apache.org/jira/browse/BEAM-8297
        values = list(self._materialize_iter(state_key, coder))
        self._state_cache.put(cache_key, token, values)
      return iter(values)
    # Caching disabled or no token: stream the state lazily, one element
    # at a time.
    return self._materialize_iter(state_key, coder)

  def extend(self, state_key, coder, elements, is_cached=False):
    """Appends elements to the state, mirroring the write into the cache."""
    if self._should_be_cached(is_cached):
      cache_key = self._convert_to_cache_key(state_key)
      self._state_cache.extend(cache_key, self._context.cache_token, elements)
    # Encode and forward the elements to the underlying state handler.
    stream = coder_impl.create_OutputStream()
    for element in elements:
      coder.encode_to_stream(element, stream, True)
    return self._underlying.append_raw(state_key, stream.get())

  def clear(self, state_key, is_cached=False):
    """Clears the state, evicting any cached copy first."""
    if self._should_be_cached(is_cached):
      self._state_cache.clear(self._convert_to_cache_key(state_key),
                              self._context.cache_token)
    return self._underlying.clear(state_key)

  def done(self):
    self._underlying.done()

  def _materialize_iter(self, state_key, coder):
    """Materializes the state lazily, one element at a time.
    :return A generator which returns the next element if advanced.
    """
    token = None
    while True:
      data, token = self._underlying.get_raw(state_key, token)
      input_stream = coder_impl.create_InputStream(data)
      while input_stream.size() > 0:
        yield coder.decode_from_stream(input_stream, True)
      if not token:
        return

  def _should_be_cached(self, request_is_cached):
    # All three must hold: caching enabled globally, requested for this
    # access, and a valid cache token bound for the current bundle.
    return (self._state_cache.is_cache_enabled()
            and request_is_cached
            and self._context.cache_token)

  @staticmethod
  def _convert_to_cache_key(state_key):
    return state_key.SerializeToString()
| |
| |
| class _Future(object): |
| """A simple future object to implement blocking requests. |
| """ |
| |
| def __init__(self): |
| self._event = threading.Event() |
| |
| def wait(self, timeout=None): |
| return self._event.wait(timeout) |
| |
| def get(self, timeout=None): |
| if self.wait(timeout): |
| return self._value |
| else: |
| raise LookupError() |
| |
| def set(self, value): |
| self._value = value |
| self._event.set() |
| |
| @classmethod |
| def done(cls): |
| if not hasattr(cls, 'DONE'): |
| done_future = _Future() |
| done_future.set(None) |
| cls.DONE = done_future |
| return cls.DONE |