| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| """SDK harness for executing Python Fns via the Fn API.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import abc |
| import collections |
| import contextlib |
| import logging |
| import queue |
| import sys |
| import threading |
| import time |
| import traceback |
| from builtins import object |
| from builtins import range |
| from concurrent import futures |
| |
| import grpc |
| from future.utils import raise_ |
| from future.utils import with_metaclass |
| |
| from apache_beam.portability.api import beam_fn_api_pb2 |
| from apache_beam.portability.api import beam_fn_api_pb2_grpc |
| from apache_beam.runners.worker import bundle_processor |
| from apache_beam.runners.worker import data_plane |
| from apache_beam.runners.worker.channel_factory import GRPCChannelFactory |
| from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor |
| |
| |
class SdkHarness(object):
  """SDK-side harness servicing Fn API instruction requests from a runner.

  Connects to the runner's control service, reads instruction requests off
  the bidirectional control stream, and dispatches each one to a pool of
  ``SdkWorker`` threads. All workers share one ``BundleProcessorCache`` so
  registered functions and cached processors are visible to every worker.
  """

  REQUEST_METHOD_PREFIX = '_request_'
  SCHEDULING_DELAY_THRESHOLD_SEC = 5*60  # 5 Minutes

  def __init__(
      self, control_address, worker_count, credentials=None, worker_id=None,
      profiler_factory=None):
    """Connects the control channel and prepares worker/thread pools.

    Args:
      control_address: host:port of the runner's control service.
      worker_count: number of threads used for process/finalize bundle work.
      credentials: optional gRPC credentials; an insecure channel is created
        when None.
      worker_id: optional worker id attached to outgoing requests via
        ``WorkerIdInterceptor``.
      profiler_factory: optional callable mapping an instruction id to a
        profiler context manager (or None).
    """
    self._alive = True
    self._worker_count = worker_count
    self._worker_index = 0
    self._worker_id = worker_id
    if credentials is None:
      logging.info('Creating insecure control channel for %s.', control_address)
      self._control_channel = GRPCChannelFactory.insecure_channel(
          control_address)
    else:
      logging.info('Creating secure control channel for %s.', control_address)
      self._control_channel = GRPCChannelFactory.secure_channel(
          control_address, credentials)
    # Fail fast if the runner's control endpoint never becomes reachable.
    grpc.channel_ready_future(self._control_channel).result(timeout=60)
    logging.info('Control channel established.')

    self._control_channel = grpc.intercept_channel(
        self._control_channel, WorkerIdInterceptor(self._worker_id))
    self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
        credentials)
    self._state_handler_factory = GrpcStateHandlerFactory(credentials)
    self._profiler_factory = profiler_factory
    self._fns = {}
    # BundleProcessor cache across all workers.
    self._bundle_processor_cache = BundleProcessorCache(
        state_handler_factory=self._state_handler_factory,
        data_channel_factory=self._data_channel_factory,
        fns=self._fns)
    # workers for process/finalize bundle.
    self.workers = queue.Queue()
    # one worker for progress/split request.
    self.progress_worker = SdkWorker(self._bundle_processor_cache,
                                     profiler_factory=self._profiler_factory)
    # one thread is enough for getting the progress report.
    # Assumption:
    # Progress report generation should not do IO or wait on other resources.
    # Without wait, having multiple threads will not improve performance and
    # will only add complexity.
    self._progress_thread_pool = futures.ThreadPoolExecutor(max_workers=1)
    # finalize and process share one thread pool.
    self._process_thread_pool = futures.ThreadPoolExecutor(
        max_workers=self._worker_count)
    self._responses = queue.Queue()
    self._process_bundle_queue = queue.Queue()
    # Maps instruction id -> enqueue time for bundles that have been queued
    # but not yet picked up by a worker; used only for delay monitoring.
    self._unscheduled_process_bundle = {}
    logging.info('Initializing SDKHarness with %s workers.', self._worker_count)

  def run(self):
    """Runs the control loop until the runner closes the control stream."""
    control_stub = beam_fn_api_pb2_grpc.BeamFnControlStub(self._control_channel)
    # Sentinel pushed onto self._responses to terminate get_responses().
    no_more_work = object()

    # Create process workers
    for _ in range(self._worker_count):
      # SdkHarness manages function registration and shares self._fns with
      # all the workers. This is needed because function registration
      # (register) and execution (process_bundle) are sent over different
      # requests and we do not really know which worker is going to process
      # the bundle for a function until we get a process_bundle request.
      # Moreover the same function is reused by different process bundle
      # calls and potentially gets executed by different workers. Hence we
      # need a centralized function list shared among all the workers.
      self.workers.put(
          SdkWorker(self._bundle_processor_cache,
                    profiler_factory=self._profiler_factory))

    def get_responses():
      # Generator feeding the outgoing half of the Control stream.
      while True:
        response = self._responses.get()
        if response is no_more_work:
          return
        yield response

    self._alive = True
    monitoring_thread = threading.Thread(target=self._monitor_process_bundle)
    monitoring_thread.daemon = True
    monitoring_thread.start()

    try:
      for work_request in control_stub.Control(get_responses()):
        logging.debug('Got work %s', work_request.instruction_id)
        request_type = work_request.WhichOneof('request')
        # Name spacing the request method with 'request_'. The called method
        # will be like self.request_register(request)
        getattr(self, SdkHarness.REQUEST_METHOD_PREFIX + request_type)(
            work_request)
    finally:
      self._alive = False

    logging.info('No more requests from control plane')
    logging.info('SDK Harness waiting for in-flight requests to complete')
    # Wait until existing requests are processed.
    self._progress_thread_pool.shutdown()
    self._process_thread_pool.shutdown()
    # get_responses may be blocked on responses.get(), but we need to return
    # control to its caller.
    self._responses.put(no_more_work)
    # Stop all the workers and clean all the associated resources
    self._data_channel_factory.close()
    self._state_handler_factory.close()
    logging.info('Done consuming work.')

  def _execute(self, task, request):
    """Runs task() and queues its response (or an error response) for the
    control stream."""
    try:
      response = task()
    except Exception:  # pylint: disable=broad-except
      traceback_string = traceback.format_exc()
      print(traceback_string, file=sys.stderr)
      logging.error(
          'Error processing instruction %s. Original traceback is\n%s\n',
          request.instruction_id, traceback_string)
      response = beam_fn_api_pb2.InstructionResponse(
          instruction_id=request.instruction_id, error=traceback_string)
    self._responses.put(response)

  def _request_register(self, request):
    """Records the request's bundle descriptors in the shared self._fns."""

    def task():
      for process_bundle_descriptor in getattr(
          request, request.WhichOneof('request')).process_bundle_descriptor:
        self._fns[process_bundle_descriptor.id] = process_bundle_descriptor

      return beam_fn_api_pb2.InstructionResponse(
          instruction_id=request.instruction_id,
          register=beam_fn_api_pb2.RegisterResponse())

    self._execute(task, request)

  def _request_process_bundle(self, request):
    """Queues a process_bundle request and schedules a worker task for it."""

    def task():
      # Take the free worker. Wait till a worker is free.
      worker = self.workers.get()
      # Get the first work item in the queue
      work = self._process_bundle_queue.get()
      # The bundle is now scheduled; stop tracking it for delay monitoring.
      self._unscheduled_process_bundle.pop(work.instruction_id, None)
      try:
        self._execute(lambda: worker.do_instruction(work), work)
      finally:
        # Put the worker back in the free worker pool
        self.workers.put(worker)
    # Create a task for each process_bundle request and schedule it
    self._process_bundle_queue.put(request)
    self._unscheduled_process_bundle[request.instruction_id] = time.time()
    self._process_thread_pool.submit(task)
    # Lazy %-args: only formatted when DEBUG logging is actually enabled.
    logging.debug(
        'Currently using %s threads.', len(self._process_thread_pool._threads))

  def _request_process_bundle_split(self, request):
    self._request_process_bundle_action(request)

  def _request_process_bundle_progress(self, request):
    self._request_process_bundle_action(request)

  def _request_process_bundle_action(self, request):
    """Handles a progress/split request for an in-flight bundle."""

    def task():
      instruction_reference = getattr(
          request, request.WhichOneof('request')).instruction_reference
      # only process progress/split request when a bundle is in processing.
      if (instruction_reference in
          self._bundle_processor_cache.active_bundle_processors):
        self._execute(
            lambda: self.progress_worker.do_instruction(request), request)
      else:
        self._execute(lambda: beam_fn_api_pb2.InstructionResponse(
            instruction_id=request.instruction_id, error=(
                'Process bundle request not yet scheduled for instruction {}' if
                instruction_reference in self._unscheduled_process_bundle else
                'Unknown process bundle instruction {}').format(
                    instruction_reference)), request)

    self._progress_thread_pool.submit(task)

  def _request_finalize_bundle(self, request):
    """Schedules deferred finalization of a previously processed bundle."""

    def task():
      # Get one available worker.
      worker = self.workers.get()
      try:
        self._execute(
            lambda: worker.do_instruction(request), request)
      finally:
        # Put the worker back in the free worker pool.
        self.workers.put(worker)

    self._process_thread_pool.submit(task)

  def _monitor_process_bundle(self):
    """
    Monitor the unscheduled bundles and log if a bundle is not scheduled for
    more than SCHEDULING_DELAY_THRESHOLD_SEC.
    """
    while self._alive:
      time.sleep(SdkHarness.SCHEDULING_DELAY_THRESHOLD_SEC)
      # Check for bundles to be scheduled.
      if self._unscheduled_process_bundle:
        current_time = time.time()
        # Snapshot the keys: worker threads pop entries concurrently, and
        # mutating a dict while iterating it raises RuntimeError.
        for instruction_id in list(self._unscheduled_process_bundle):
          # The entry may be scheduled (removed) between the snapshot and
          # this lookup; dict.get is atomic under the GIL and returns None
          # for missing keys.
          request_time = self._unscheduled_process_bundle.get(instruction_id)
          if request_time:
            scheduling_delay = current_time - request_time
            if scheduling_delay > SdkHarness.SCHEDULING_DELAY_THRESHOLD_SEC:
              # logging.warn is a deprecated alias of logging.warning.
              logging.warning('Unable to schedule instruction %s for %s',
                              instruction_id, scheduling_delay)
| |
| |
class BundleProcessorCache(object):
  """A cache for ``BundleProcessor``s.

  ``BundleProcessor`` objects are cached by the id of their
  ``beam_fn_api_pb2.ProcessBundleDescriptor``.

  Attributes:
    fns (dict): A dictionary that maps bundle descriptor IDs to instances of
      ``beam_fn_api_pb2.ProcessBundleDescriptor``.
    state_handler_factory (``StateHandlerFactory``): Used to create state
      handlers to be used by a ``bundle_processor.BundleProcessor`` during
      processing.
    data_channel_factory (``data_plane.DataChannelFactory``)
    active_bundle_processors (dict): A dictionary, indexed by instruction IDs,
      containing ``bundle_processor.BundleProcessor`` objects that are
      currently active processing the corresponding instruction.
    cached_bundle_processors (dict): A dictionary, indexed by bundle processor
      id, of cached ``bundle_processor.BundleProcessor`` that are not
      currently performing processing.
  """

  def __init__(self, state_handler_factory, data_channel_factory, fns):
    self.fns = fns
    self.state_handler_factory = state_handler_factory
    self.data_channel_factory = data_channel_factory
    self.active_bundle_processors = {}
    self.cached_bundle_processors = collections.defaultdict(list)

  def register(self, bundle_descriptor):
    """Register a ``beam_fn_api_pb2.ProcessBundleDescriptor`` by its id."""
    self.fns[bundle_descriptor.id] = bundle_descriptor

  def get(self, instruction_id, bundle_descriptor_id):
    """Returns a processor for the descriptor, reusing a cached one if any,
    and marks it active for *instruction_id*."""
    try:
      # pop() is threadsafe
      processor = self.cached_bundle_processors[bundle_descriptor_id].pop()
    except IndexError:
      processor = bundle_processor.BundleProcessor(
          self.fns[bundle_descriptor_id],
          self.state_handler_factory.create_state_handler(
              self.fns[bundle_descriptor_id].state_api_service_descriptor),
          self.data_channel_factory)
    self.active_bundle_processors[
        instruction_id] = bundle_descriptor_id, processor
    return processor

  def lookup(self, instruction_id):
    """Returns the active processor for the instruction, or None."""
    return self.active_bundle_processors.get(instruction_id, (None, None))[-1]

  def discard(self, instruction_id):
    """Shuts down and drops the processor for a failed instruction."""
    self.active_bundle_processors[instruction_id][1].shutdown()
    del self.active_bundle_processors[instruction_id]

  def release(self, instruction_id):
    """Resets the instruction's processor and returns it to the cache for
    reuse by a later bundle with the same descriptor."""
    descriptor_id, processor = self.active_bundle_processors.pop(instruction_id)
    processor.reset()
    self.cached_bundle_processors[descriptor_id].append(processor)

  def shutdown(self):
    """Shuts down every active and cached ``BundleProcessor``."""
    # Snapshot the keys first: deleting entries from a dict while iterating
    # over it raises "RuntimeError: dictionary changed size during iteration".
    for instruction_id in list(self.active_bundle_processors):
      self.active_bundle_processors[instruction_id][1].shutdown()
      del self.active_bundle_processors[instruction_id]
    for cached_bundle_processors in self.cached_bundle_processors.values():
      while cached_bundle_processors:
        cached_bundle_processors.pop().shutdown()
| |
| |
class SdkWorker(object):
  """Executes individual instruction requests against the shared
  ``BundleProcessorCache``."""

  def __init__(self, bundle_processor_cache, profiler_factory=None):
    """
    Args:
      bundle_processor_cache: the ``BundleProcessorCache`` shared by all
        workers of the harness.
      profiler_factory: optional callable mapping an instruction id to a
        profiler context manager (or None).
    """
    self.bundle_processor_cache = bundle_processor_cache
    self.profiler_factory = profiler_factory

  def do_instruction(self, request):
    """Dispatches *request* to the method named after its request type."""
    request_type = request.WhichOneof('request')
    if request_type:
      # E.g. if register is set, this will call self.register(request.register))
      return getattr(self, request_type)(getattr(request, request_type),
                                         request.instruction_id)
    else:
      raise NotImplementedError

  def register(self, request, instruction_id):
    """Registers a set of ``beam_fn_api_pb2.ProcessBundleDescriptor``s.

    This set of ``beam_fn_api_pb2.ProcessBundleDescriptor`` come as part of a
    ``beam_fn_api_pb2.RegisterRequest``, which the runner sends to the SDK
    worker before starting processing to register stages.
    """

    for process_bundle_descriptor in request.process_bundle_descriptor:
      self.bundle_processor_cache.register(process_bundle_descriptor)
    return beam_fn_api_pb2.InstructionResponse(
        instruction_id=instruction_id,
        register=beam_fn_api_pb2.RegisterResponse())

  def process_bundle(self, request, instruction_id):
    """Processes one bundle, releasing or discarding its processor after."""
    # Named `processor` so the local does not shadow the imported
    # `bundle_processor` module.
    processor = self.bundle_processor_cache.get(
        instruction_id, request.process_bundle_descriptor_reference)
    try:
      with processor.state_handler.process_instruction_id(
          instruction_id):
        with self.maybe_profile(instruction_id):
          delayed_applications, requests_finalization = (
              processor.process_bundle(instruction_id))
          response = beam_fn_api_pb2.InstructionResponse(
              instruction_id=instruction_id,
              process_bundle=beam_fn_api_pb2.ProcessBundleResponse(
                  residual_roots=delayed_applications,
                  metrics=processor.metrics(),
                  monitoring_infos=processor.monitoring_infos(),
                  requires_finalization=requests_finalization))
      # Don't release here if finalize is needed.
      if not requests_finalization:
        self.bundle_processor_cache.release(instruction_id)
      return response
    except:  # pylint: disable=bare-except
      # Don't re-use bundle processors on failure.
      self.bundle_processor_cache.discard(instruction_id)
      raise

  def process_bundle_split(self, request, instruction_id):
    """Attempts to split the referenced in-flight bundle."""
    processor = self.bundle_processor_cache.lookup(
        request.instruction_reference)
    if processor:
      return beam_fn_api_pb2.InstructionResponse(
          instruction_id=instruction_id,
          process_bundle_split=processor.try_split(request))
    else:
      # Report the id of the bundle that is not running (the referenced
      # instruction), not this split request's own instruction id.
      return beam_fn_api_pb2.InstructionResponse(
          instruction_id=instruction_id,
          error='Instruction not running: %s' % request.instruction_reference)

  def process_bundle_progress(self, request, instruction_id):
    # It is an error to get progress for a not-in-flight bundle.
    processor = self.bundle_processor_cache.lookup(
        request.instruction_reference)
    return beam_fn_api_pb2.InstructionResponse(
        instruction_id=instruction_id,
        process_bundle_progress=beam_fn_api_pb2.ProcessBundleProgressResponse(
            metrics=processor.metrics() if processor else None,
            monitoring_infos=processor.monitoring_infos() if processor else []))

  def finalize_bundle(self, request, instruction_id):
    """Runs deferred finalization for a previously processed bundle."""
    processor = self.bundle_processor_cache.lookup(
        request.instruction_reference)
    if processor:
      try:
        finalize_response = processor.finalize_bundle()
        self.bundle_processor_cache.release(request.instruction_reference)
        return beam_fn_api_pb2.InstructionResponse(
            instruction_id=instruction_id,
            finalize_bundle=finalize_response)
      except:  # pylint: disable=bare-except
        # Don't re-use the bundle processor if finalization failed.
        self.bundle_processor_cache.discard(request.instruction_reference)
        raise
    else:
      # Report the id of the bundle that is not running (the referenced
      # instruction), not this finalize request's own instruction id.
      return beam_fn_api_pb2.InstructionResponse(
          instruction_id=instruction_id,
          error='Instruction not running: %s' % request.instruction_reference)

  def stop(self):
    """Shuts down all processors held by the shared cache."""
    self.bundle_processor_cache.shutdown()

  @contextlib.contextmanager
  def maybe_profile(self, instruction_id):
    """Context manager that profiles the bundle when a profiler exists."""
    profiler = (
        self.profiler_factory(instruction_id) if self.profiler_factory
        else None)
    if profiler:
      with profiler:
        yield
    else:
      yield
| |
| |
class StateHandlerFactory(with_metaclass(abc.ABCMeta, object)):
  """An abstract factory for creating ``StateHandler``s.

  Concrete factories own the channels behind the handlers they create and
  must release them in ``close``.
  """

  @abc.abstractmethod
  def create_state_handler(self, api_service_descriptor):
    """Returns a ``StateHandler`` from the given ApiServiceDescriptor."""
    raise NotImplementedError(type(self))

  @abc.abstractmethod
  def close(self):
    """Close all channels that this factory owns."""
    raise NotImplementedError(type(self))
| |
| |
class GrpcStateHandlerFactory(StateHandlerFactory):
  """A factory for ``GrpcStateHandler``.

  Caches the created channels by ``state descriptor url``.
  """

  def __init__(self, credentials=None):
    # Maps service url -> GrpcStateHandler; inserts happen under self._lock.
    self._state_handler_cache = {}
    self._lock = threading.Lock()
    # Shared handler returned for descriptors that declare no state service.
    self._throwing_state_handler = ThrowingStateHandler()
    self._credentials = credentials

  def create_state_handler(self, api_service_descriptor):
    """Returns a (cached) state handler for the given service descriptor.

    Descriptors without a state service get a ``ThrowingStateHandler`` that
    fails every state request.
    """
    if not api_service_descriptor:
      return self._throwing_state_handler
    url = api_service_descriptor.url
    # Double-checked locking: the unlocked membership test is safe under the
    # GIL and keeps the common cache-hit path from serializing on the lock.
    if url not in self._state_handler_cache:
      with self._lock:
        if url not in self._state_handler_cache:
          # Options to have no limits (-1) on the size of the messages
          # received or sent over the data plane. The actual buffer size is
          # controlled in a layer above.
          options = [('grpc.max_receive_message_length', -1),
                     ('grpc.max_send_message_length', -1)]
          if self._credentials is None:
            logging.info('Creating insecure state channel for %s.', url)
            grpc_channel = GRPCChannelFactory.insecure_channel(
                url, options=options)
          else:
            logging.info('Creating secure state channel for %s.', url)
            grpc_channel = GRPCChannelFactory.secure_channel(
                url, self._credentials, options=options)
          logging.info('State channel established.')
          # Add workerId to the grpc channel
          grpc_channel = grpc.intercept_channel(grpc_channel,
                                                WorkerIdInterceptor())
          self._state_handler_cache[url] = GrpcStateHandler(
              beam_fn_api_pb2_grpc.BeamFnStateStub(grpc_channel))
    return self._state_handler_cache[url]

  def close(self):
    """Shuts down every cached state handler and clears the cache."""
    logging.info('Closing all cached gRPC state handlers.')
    # Only the handlers are used; iterate values() instead of unpacking and
    # discarding the keys from items().
    for state_handler in self._state_handler_cache.values():
      state_handler.done()
    self._state_handler_cache.clear()
| |
| |
class ThrowingStateHandler(object):
  """A state handler that errors on any request.

  Used for bundles whose ProcessBundleDescriptor carries no state
  ApiServiceDescriptor, so no state channel exists to service requests.
  """

  def _raise(self, state_key, instruction_reference):
    # All operations fail identically; the original per-method messages also
    # had their format arguments swapped (state key reported as instruction
    # and vice versa), fixed here.
    raise RuntimeError(
        'Unable to handle state requests for ProcessBundleDescriptor '
        'without state ApiServiceDescriptor for instruction %s and state '
        'key %s.' % (instruction_reference, state_key))

  def blocking_get(self, state_key, instruction_reference):
    self._raise(state_key, instruction_reference)

  def blocking_append(self, state_key, data, instruction_reference):
    self._raise(state_key, instruction_reference)

  def blocking_clear(self, state_key, instruction_reference):
    self._raise(state_key, instruction_reference)
| |
| |
class GrpcStateHandler(object):
  """Client for the Fn API State service over one bidirectional gRPC stream.

  Outgoing requests are queued and consumed by the stream's request
  generator; a daemon reader thread matches responses back to blocked
  callers by request id via ``_Future`` objects.
  """

  # Sentinel queued by done() to terminate the outgoing request stream.
  _DONE = object()

  def __init__(self, state_stub):
    self._lock = threading.Lock()
    self._state_stub = state_stub
    # Outgoing StateRequests, drained by the request stream generator.
    self._requests = queue.Queue()
    # Maps request id -> _Future that the response will be delivered to.
    self._responses_by_id = {}
    self._last_id = 0
    # Set by the reader thread on failure; re-raised in the blocked caller.
    self._exc_info = None
    # Thread-local state: each worker thread binds its own instruction id.
    self._context = threading.local()
    self.start()

  @contextlib.contextmanager
  def process_instruction_id(self, bundle_id):
    """Binds *bundle_id* to the calling thread for the scope of a bundle.

    State requests issued from this thread inside the context carry this id
    as their instruction_reference. Nested binding is an error.
    """
    if getattr(self._context, 'process_instruction_id', None) is not None:
      raise RuntimeError(
          'Already bound to %r' % self._context.process_instruction_id)
    self._context.process_instruction_id = bundle_id
    try:
      yield
    finally:
      self._context.process_instruction_id = None

  def start(self):
    """Opens the State stream and starts the response-reader thread."""
    self._done = False

    def request_iter():
      # Feeds the outgoing half of the stream until the sentinel arrives.
      while True:
        request = self._requests.get()
        if request is self._DONE or self._done:
          break
        yield request

    responses = self._state_stub.State(request_iter())

    def pull_responses():
      try:
        for response in responses:
          # Wake the caller blocked on this request id.
          self._responses_by_id[response.id].set(response)
          if self._done:
            break
      except:  # pylint: disable=bare-except
        # Stash the failure so _blocking_request can re-raise it in the
        # caller's thread.
        self._exc_info = sys.exc_info()
        raise

    reader = threading.Thread(target=pull_responses, name='read_state')
    reader.daemon = True
    reader.start()

  def done(self):
    """Shuts the stream down and unblocks the request generator."""
    self._done = True
    self._requests.put(self._DONE)

  def blocking_get(self, state_key, continuation_token=None):
    """Fetches one page of state; returns (data, continuation_token)."""
    response = self._blocking_request(
        beam_fn_api_pb2.StateRequest(
            state_key=state_key,
            get=beam_fn_api_pb2.StateGetRequest(
                continuation_token=continuation_token)))
    return response.get.data, response.get.continuation_token

  def blocking_append(self, state_key, data):
    """Appends *data* to the state cell, blocking until acknowledged."""
    self._blocking_request(
        beam_fn_api_pb2.StateRequest(
            state_key=state_key,
            append=beam_fn_api_pb2.StateAppendRequest(data=data)))

  def blocking_clear(self, state_key):
    """Clears the state cell, blocking until acknowledged."""
    self._blocking_request(
        beam_fn_api_pb2.StateRequest(
            state_key=state_key,
            clear=beam_fn_api_pb2.StateClearRequest()))

  def _blocking_request(self, request):
    """Sends *request* on the stream and blocks until its response arrives.

    Re-raises the reader thread's exception if the stream failed, or raises
    RuntimeError if the handler shut down while waiting. Raises RuntimeError
    with the server's message if the response carries an error.
    """
    request.id = self._next_id()
    request.instruction_reference = self._context.process_instruction_id
    self._responses_by_id[request.id] = future = _Future()
    self._requests.put(request)
    # Poll with a short timeout so stream failure or shutdown is noticed
    # promptly instead of blocking forever.
    while not future.wait(timeout=1):
      if self._exc_info:
        t, v, tb = self._exc_info
        raise_(t, v, tb)
      elif self._done:
        raise RuntimeError()
    del self._responses_by_id[request.id]
    response = future.get()
    if response.error:
      raise RuntimeError(response.error)
    else:
      return response

  def _next_id(self):
    # Request ids are stringified, monotonically increasing integers.
    self._last_id += 1
    return str(self._last_id)
| |
| |
| class _Future(object): |
| """A simple future object to implement blocking requests. |
| """ |
| |
| def __init__(self): |
| self._event = threading.Event() |
| |
| def wait(self, timeout=None): |
| return self._event.wait(timeout) |
| |
| def get(self, timeout=None): |
| if self.wait(timeout): |
| return self._value |
| else: |
| raise LookupError() |
| |
| def set(self, value): |
| self._value = value |
| self._event.set() |