| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| """A PipelineRunner using the SDK harness. |
| """ |
| from __future__ import absolute_import |
| from __future__ import print_function |
| |
| import collections |
| import contextlib |
| import copy |
| import itertools |
| import logging |
| import os |
| import queue |
| import subprocess |
| import sys |
| import threading |
| import time |
| import uuid |
| from builtins import object |
| from concurrent import futures |
| |
| import grpc |
| |
| import apache_beam as beam # pylint: disable=ungrouped-imports |
| from apache_beam import coders |
| from apache_beam.coders.coder_impl import create_InputStream |
| from apache_beam.coders.coder_impl import create_OutputStream |
| from apache_beam.metrics import metric |
| from apache_beam.metrics import monitoring_infos |
| from apache_beam.metrics.execution import MetricResult |
| from apache_beam.metrics.execution import MetricsEnvironment |
| from apache_beam.options import pipeline_options |
| from apache_beam.options.value_provider import RuntimeValueProvider |
| from apache_beam.portability import common_urns |
| from apache_beam.portability import python_urns |
| from apache_beam.portability.api import beam_artifact_api_pb2 |
| from apache_beam.portability.api import beam_artifact_api_pb2_grpc |
| from apache_beam.portability.api import beam_fn_api_pb2 |
| from apache_beam.portability.api import beam_fn_api_pb2_grpc |
| from apache_beam.portability.api import beam_provision_api_pb2 |
| from apache_beam.portability.api import beam_provision_api_pb2_grpc |
| from apache_beam.portability.api import beam_runner_api_pb2 |
| from apache_beam.portability.api import endpoints_pb2 |
| from apache_beam.runners import pipeline_context |
| from apache_beam.runners import runner |
| from apache_beam.runners.portability import artifact_service |
| from apache_beam.runners.portability import fn_api_runner_transforms |
| from apache_beam.runners.portability import portable_metrics |
| from apache_beam.runners.portability.fn_api_runner_transforms import create_buffer_id |
| from apache_beam.runners.portability.fn_api_runner_transforms import only_element |
| from apache_beam.runners.portability.fn_api_runner_transforms import split_buffer_id |
| from apache_beam.runners.portability.fn_api_runner_transforms import unique_name |
| from apache_beam.runners.worker import bundle_processor |
| from apache_beam.runners.worker import data_plane |
| from apache_beam.runners.worker import sdk_worker |
| from apache_beam.runners.worker.channel_factory import GRPCChannelFactory |
| from apache_beam.runners.worker.sdk_worker import _Future |
| from apache_beam.runners.worker.statecache import StateCache |
| from apache_beam.transforms import environments |
| from apache_beam.transforms import trigger |
| from apache_beam.transforms.window import GlobalWindow |
| from apache_beam.transforms.window import GlobalWindows |
| from apache_beam.utils import profiler |
| from apache_beam.utils import proto_utils |
| from apache_beam.utils import windowed_value |
| |
| # This module is experimental. No backwards-compatibility guarantees. |
| |
| ENCODED_IMPULSE_VALUE = beam.coders.WindowedValueCoder( |
| beam.coders.BytesCoder(), |
| beam.coders.coders.GlobalWindowCoder()).get_impl().encode_nested( |
| beam.transforms.window.GlobalWindows.windowed_value(b'')) |
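# ENCODED_IMPULSE_VALUE is the nested encoding of a single empty-bytes element
# in the global window; it is used as the input buffer for stages that are fed
# by an Impulse primitive.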
| |
| # State caching is enabled in the fn_api_runner for testing, except for one |
| # test which runs without state caching (FnApiRunnerTestWithDisabledCaching). |
| # The cache is disabled in production for other runners. |
| STATE_CACHE_SIZE = 100 |
| |
| |
| class ControlConnection(object): |
| |
| _uid_counter = 0 |
| _lock = threading.Lock() |
| |
| def __init__(self): |
| self._push_queue = queue.Queue() |
| self._input = None |
| self._futures_by_id = dict() |
| self._read_thread = threading.Thread( |
| name='beam_control_read', target=self._read) |
| self._state = BeamFnControlServicer.UNSTARTED_STATE |
| |
| def _read(self): |
| for data in self._input: |
| self._futures_by_id.pop(data.instruction_id).set(data) |
| |
| def push(self, req): |
    if req is BeamFnControlServicer._DONE_MARKER:
| self._push_queue.put(req) |
| return None |
| if not req.instruction_id: |
| with ControlConnection._lock: |
| ControlConnection._uid_counter += 1 |
| req.instruction_id = 'control_%s' % ControlConnection._uid_counter |
| future = ControlFuture(req.instruction_id) |
| self._futures_by_id[req.instruction_id] = future |
| self._push_queue.put(req) |
| return future |
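  # Illustrative usage (a sketch): push() enqueues a request for the worker
  # and returns a future that the _read thread completes with the matching
  # response, e.g.
  #   future = conn.push(beam_fn_api_pb2.InstructionRequest(...))
  #   response = future.get()  # blocks until the worker responds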
| |
| def get_req(self): |
| return self._push_queue.get() |
| |
| def set_input(self, input): |
| with ControlConnection._lock: |
| if self._input: |
| raise RuntimeError('input is already set.') |
| self._input = input |
| self._read_thread.start() |
| self._state = BeamFnControlServicer.STARTED_STATE |
| |
| def close(self): |
| with ControlConnection._lock: |
| if self._state == BeamFnControlServicer.STARTED_STATE: |
| self.push(BeamFnControlServicer._DONE_MARKER) |
| self._read_thread.join() |
| self._state = BeamFnControlServicer.DONE_STATE |
| |
| |
| class BeamFnControlServicer(beam_fn_api_pb2_grpc.BeamFnControlServicer): |
| """Implementation of BeamFnControlServicer for clients.""" |
| |
| UNSTARTED_STATE = 'unstarted' |
| STARTED_STATE = 'started' |
| DONE_STATE = 'done' |
| |
| _DONE_MARKER = object() |
| |
| def __init__(self): |
| self._lock = threading.Lock() |
| self._uid_counter = 0 |
| self._state = self.UNSTARTED_STATE |
    # The following self._req_* variables are used for debugging purposes;
    # data is added to them only when self._log_req is True.
| self._req_sent = collections.defaultdict(int) |
| self._req_worker_mapping = {} |
| self._log_req = logging.getLogger().getEffectiveLevel() <= logging.DEBUG |
| self._connections_by_worker_id = collections.defaultdict(ControlConnection) |
| |
| def get_conn_by_worker_id(self, worker_id): |
| with self._lock: |
| return self._connections_by_worker_id[worker_id] |
| |
| def Control(self, iterator, context): |
| with self._lock: |
| if self._state == self.DONE_STATE: |
| return |
| else: |
| self._state = self.STARTED_STATE |
| |
| worker_id = dict(context.invocation_metadata()).get('worker_id') |
| if not worker_id: |
      raise RuntimeError('All workers communicating through gRPC must have '
                         'a worker_id. Received None.')
| |
| control_conn = self.get_conn_by_worker_id(worker_id) |
| control_conn.set_input(iterator) |
| |
| while True: |
| to_push = control_conn.get_req() |
| if to_push is self._DONE_MARKER: |
| return |
| yield to_push |
| if self._log_req: |
| self._req_sent[to_push.instruction_id] += 1 |
| |
| def done(self): |
| self._state = self.DONE_STATE |
| logging.debug('Runner: Requests sent by runner: %s', |
| [(str(req), cnt) for req, cnt in self._req_sent.items()]) |
| logging.debug('Runner: Requests multiplexing info: %s', |
| [(str(req), worker) for req, worker |
| in self._req_worker_mapping.items()]) |
| |
| |
| class _ListBuffer(list): |
| """Used to support parititioning of a list.""" |
| def partition(self, n): |
| return [self[k::n] for k in range(n)] |
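  # For example, _ListBuffer([1, 2, 3, 4, 5]).partition(2) returns the
  # round-robin slices [[1, 3, 5], [2, 4]].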
| |
| |
| class _GroupingBuffer(object): |
| """Used to accumulate groupded (shuffled) results.""" |
| def __init__(self, pre_grouped_coder, post_grouped_coder, windowing): |
| self._key_coder = pre_grouped_coder.key_coder() |
| self._pre_grouped_coder = pre_grouped_coder |
| self._post_grouped_coder = post_grouped_coder |
| self._table = collections.defaultdict(list) |
| self._windowing = windowing |
| self._grouped_output = None |
| |
| def append(self, elements_data): |
| if self._grouped_output: |
| raise RuntimeError('Grouping table append after read.') |
| input_stream = create_InputStream(elements_data) |
| coder_impl = self._pre_grouped_coder.get_impl() |
| key_coder_impl = self._key_coder.get_impl() |
| # TODO(robertwb): We could optimize this even more by using a |
| # window-dropping coder for the data plane. |
| is_trivial_windowing = self._windowing.is_default() |
| while input_stream.size() > 0: |
| windowed_key_value = coder_impl.decode_from_stream(input_stream, True) |
| key, value = windowed_key_value.value |
| self._table[key_coder_impl.encode(key)].append( |
| value if is_trivial_windowing |
| else windowed_key_value.with_value(value)) |
| |
| def partition(self, n): |
| """ It is used to partition _GroupingBuffer to N parts. Once it is |
| partitioned, it would not be re-partitioned with diff N. Re-partition |
| is not supported now. |
| """ |
| if not self._grouped_output: |
| if self._windowing.is_default(): |
| globally_window = GlobalWindows.windowed_value( |
| None, |
| timestamp=GlobalWindow().max_timestamp(), |
| pane_info=windowed_value.PaneInfo( |
| is_first=True, |
| is_last=True, |
| timing=windowed_value.PaneInfoTiming.ON_TIME, |
| index=0, |
| nonspeculative_index=0)).with_value |
| windowed_key_values = lambda key, values: [ |
| globally_window((key, values))] |
| else: |
| # TODO(pabloem, BEAM-7514): Trigger driver needs access to the clock |
| # note that this only comes through if windowing is default - but what |
| # about having multiple firings on the global window. |
| # May need to revise. |
| trigger_driver = trigger.create_trigger_driver(self._windowing, True) |
| windowed_key_values = trigger_driver.process_entire_key |
| coder_impl = self._post_grouped_coder.get_impl() |
| key_coder_impl = self._key_coder.get_impl() |
| self._grouped_output = [[] for _ in range(n)] |
| output_stream_list = [] |
| for _ in range(n): |
| output_stream_list.append(create_OutputStream()) |
| for idx, (encoded_key, windowed_values) in enumerate(self._table.items()): |
| key = key_coder_impl.decode(encoded_key) |
| for wkvs in windowed_key_values(key, windowed_values): |
| coder_impl.encode_to_stream(wkvs, output_stream_list[idx % n], True) |
| for ix, output_stream in enumerate(output_stream_list): |
| self._grouped_output[ix] = [output_stream.get()] |
| self._table = None |
| return self._grouped_output |
| |
| def __iter__(self): |
| """ Since partition() returns a list of lists, add this __iter__ to return |
| a list to simplify code when we need to iterate through ALL elements of |
| _GroupingBuffer. |
| """ |
| return itertools.chain(*self.partition(1)) |
| |
| |
| class _WindowGroupingBuffer(object): |
| """Used to partition windowed side inputs.""" |
| def __init__(self, access_pattern, coder): |
| # Here's where we would use a different type of partitioning |
| # (e.g. also by key) for a different access pattern. |
| if access_pattern.urn == common_urns.side_inputs.ITERABLE.urn: |
      self._kv_extractor = lambda value: ('', value)
| self._key_coder = coders.SingletonCoder('') |
| self._value_coder = coder.wrapped_value_coder |
| elif access_pattern.urn == common_urns.side_inputs.MULTIMAP.urn: |
      self._kv_extractor = lambda value: value
| self._key_coder = coder.wrapped_value_coder.key_coder() |
| self._value_coder = ( |
| coder.wrapped_value_coder.value_coder()) |
| else: |
| raise ValueError( |
| "Unknown access pattern: '%s'" % access_pattern.urn) |
| self._windowed_value_coder = coder |
| self._window_coder = coder.window_coder |
| self._values_by_window = collections.defaultdict(list) |
| |
| def append(self, elements_data): |
| input_stream = create_InputStream(elements_data) |
| while input_stream.size() > 0: |
| windowed_value = self._windowed_value_coder.get_impl( |
| ).decode_from_stream(input_stream, True) |
      key, value = self._kv_extractor(windowed_value.value)
| for window in windowed_value.windows: |
| self._values_by_window[key, window].append(value) |
| |
| def encoded_items(self): |
| value_coder_impl = self._value_coder.get_impl() |
| key_coder_impl = self._key_coder.get_impl() |
| for (key, window), values in self._values_by_window.items(): |
| encoded_window = self._window_coder.encode(window) |
| encoded_key = key_coder_impl.encode_nested(key) |
| output_stream = create_OutputStream() |
| for value in values: |
| value_coder_impl.encode_to_stream(value, output_stream, True) |
| yield encoded_key, encoded_window, output_stream.get() |
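  # Illustrative round trip (a sketch): after append()ing encoded windowed
  # values, encoded_items() yields one (encoded key, encoded window, encoded
  # values) triple per (key, window) pair, ready to be written as multimap
  # side input state.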
| |
| |
| class FnApiRunner(runner.PipelineRunner): |
| |
| def __init__( |
| self, |
| default_environment=None, |
| bundle_repeat=0, |
| use_state_iterables=False, |
| provision_info=None, |
| progress_request_frequency=None): |
| """Creates a new Fn API Runner. |
| |
| Args: |
| default_environment: the default environment to use for UserFns. |
| bundle_repeat: replay every bundle this many extra times, for profiling |
| and debugging |
| use_state_iterables: Intentionally split gbk iterables over state API |
| (for testing) |
| provision_info: provisioning info to make available to workers, or None |
| progress_request_frequency: The frequency (in seconds) that the runner |
| waits before requesting progress from the SDK. |
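
    Example (illustrative; runs the pipeline on this in-process runner)::

      with beam.Pipeline(runner=FnApiRunner()) as p:
        p | beam.Create([1, 2, 3]) | beam.Map(print)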
| """ |
| super(FnApiRunner, self).__init__() |
| self._last_uid = -1 |
| self._default_environment = ( |
| default_environment |
| or environments.EmbeddedPythonEnvironment()) |
| self._bundle_repeat = bundle_repeat |
| self._num_workers = 1 |
| self._progress_frequency = progress_request_frequency |
| self._profiler_factory = None |
| self._use_state_iterables = use_state_iterables |
| self._provision_info = provision_info or ExtendedProvisionInfo( |
| beam_provision_api_pb2.ProvisionInfo( |
| job_id='unknown-job-id', |
| job_name='unknown-job-name', |
| retrieval_token='unused-retrieval-token')) |
| |
| def _next_uid(self): |
| self._last_uid += 1 |
| return str(self._last_uid) |
| |
| def run_pipeline(self, pipeline, options): |
| MetricsEnvironment.set_metrics_supported(False) |
| RuntimeValueProvider.set_runtime_options({}) |
| |
| # Setup "beam_fn_api" experiment options if lacked. |
| experiments = (options.view_as(pipeline_options.DebugOptions).experiments |
| or []) |
    if 'beam_fn_api' not in experiments:
| experiments.append('beam_fn_api') |
| options.view_as(pipeline_options.DebugOptions).experiments = experiments |
| |
| # This is sometimes needed if type checking is disabled |
| # to enforce that the inputs (and outputs) of GroupByKey operations |
| # are known to be KVs. |
| from apache_beam.runners.dataflow.dataflow_runner import DataflowRunner |
| # TODO: Move group_by_key_input_visitor() to a non-dataflow specific file. |
| pipeline.visit(DataflowRunner.group_by_key_input_visitor()) |
| self._bundle_repeat = self._bundle_repeat or options.view_as( |
| pipeline_options.DirectOptions).direct_runner_bundle_repeat |
| self._num_workers = options.view_as( |
| pipeline_options.DirectOptions).direct_num_workers or self._num_workers |
| self._profiler_factory = profiler.Profile.factory_from_options( |
| options.view_as(pipeline_options.ProfilingOptions)) |
| |
| if 'use_sdf_bounded_source' in experiments: |
| pipeline.replace_all(DataflowRunner._SDF_PTRANSFORM_OVERRIDES) |
| |
| self._latest_run_result = self.run_via_runner_api(pipeline.to_runner_api( |
| default_environment=self._default_environment)) |
| return self._latest_run_result |
| |
| def run_via_runner_api(self, pipeline_proto): |
| stage_context, stages = self.create_stages(pipeline_proto) |
| # TODO(pabloem, BEAM-7514): Create a watermark manager (that has access to |
| # the teststream (if any), and all the stages). |
| return self.run_stages(stage_context, stages) |
| |
| @contextlib.contextmanager |
| def maybe_profile(self): |
| if self._profiler_factory: |
| try: |
| profile_id = 'direct-' + subprocess.check_output( |
| ['git', 'rev-parse', '--abbrev-ref', 'HEAD'] |
| ).decode(errors='ignore').strip() |
| except subprocess.CalledProcessError: |
| profile_id = 'direct-unknown' |
| profiler = self._profiler_factory(profile_id, time_prefix='') |
| else: |
| profiler = None |
| |
| if profiler: |
| with profiler: |
| yield |
| if not self._bundle_repeat: |
| logging.warning( |
| 'The --direct_runner_bundle_repeat option is not set; ' |
| 'a significant portion of the profile may be one-time overhead.') |
| path = profiler.profile_output |
| print('CPU Profile written to %s' % path) |
| try: |
| import gprof2dot # pylint: disable=unused-import |
| if not subprocess.call([ |
| sys.executable, '-m', 'gprof2dot', |
| '-f', 'pstats', path, '-o', path + '.dot']): |
| if not subprocess.call( |
| ['dot', '-Tsvg', '-o', path + '.svg', path + '.dot']): |
| print('CPU Profile rendering at file://%s.svg' |
| % os.path.abspath(path)) |
| except ImportError: |
| # pylint: disable=superfluous-parens |
| print('Please install gprof2dot and dot for profile renderings.') |
| |
| else: |
| # Empty context. |
| yield |
| |
| def create_stages(self, pipeline_proto): |
| return fn_api_runner_transforms.create_and_optimize_stages( |
| copy.deepcopy(pipeline_proto), |
| phases=[fn_api_runner_transforms.annotate_downstream_side_inputs, |
| fn_api_runner_transforms.fix_side_input_pcoll_coders, |
| fn_api_runner_transforms.lift_combiners, |
| fn_api_runner_transforms.expand_sdf, |
| fn_api_runner_transforms.expand_gbk, |
| fn_api_runner_transforms.sink_flattens, |
| fn_api_runner_transforms.greedily_fuse, |
| fn_api_runner_transforms.read_to_impulse, |
| fn_api_runner_transforms.impulse_to_input, |
| fn_api_runner_transforms.inject_timer_pcollections, |
| fn_api_runner_transforms.sort_stages, |
| fn_api_runner_transforms.window_pcollection_coders], |
| known_runner_urns=frozenset([ |
| common_urns.primitives.FLATTEN.urn, |
| common_urns.primitives.GROUP_BY_KEY.urn]), |
| use_state_iterables=self._use_state_iterables) |
| |
| def run_stages(self, stage_context, stages): |
| """Run a list of topologically-sorted stages in batch mode. |
| |
| Args: |
| stage_context (fn_api_runner_transforms.TransformContext) |
| stages (list[fn_api_runner_transforms.Stage]) |
| """ |
| worker_handler_manager = WorkerHandlerManager( |
| stage_context.components.environments, self._provision_info) |
| metrics_by_stage = {} |
| monitoring_infos_by_stage = {} |
| |
| try: |
| with self.maybe_profile(): |
| pcoll_buffers = collections.defaultdict(_ListBuffer) |
| for stage in stages: |
| stage_results = self._run_stage( |
| worker_handler_manager.get_worker_handlers, |
| stage_context.components, |
| stage, |
| pcoll_buffers, |
| stage_context.safe_coders) |
| metrics_by_stage[stage.name] = stage_results.process_bundle.metrics |
| monitoring_infos_by_stage[stage.name] = ( |
| stage_results.process_bundle.monitoring_infos) |
| finally: |
| worker_handler_manager.close_all() |
| return RunnerResult( |
| runner.PipelineState.DONE, monitoring_infos_by_stage, metrics_by_stage) |
| |
| def _store_side_inputs_in_state(self, |
| worker_handler, |
| context, |
| pipeline_components, |
| data_side_input, |
| pcoll_buffers, |
| safe_coders): |
| for (transform_id, tag), (buffer_id, si) in data_side_input.items(): |
| _, pcoll_id = split_buffer_id(buffer_id) |
| value_coder = context.coders[safe_coders[ |
| pipeline_components.pcollections[pcoll_id].coder_id]] |
| elements_by_window = _WindowGroupingBuffer(si, value_coder) |
| for element_data in pcoll_buffers[buffer_id]: |
| elements_by_window.append(element_data) |
| for key, window, elements_data in elements_by_window.encoded_items(): |
| state_key = beam_fn_api_pb2.StateKey( |
| multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput( |
| transform_id=transform_id, |
| side_input_id=tag, |
| window=window, |
| key=key)) |
| worker_handler.state.append_raw(state_key, elements_data) |
| |
| def _run_bundle_multiple_times_for_testing( |
| self, worker_handler_list, process_bundle_descriptor, data_input, |
| data_output, get_input_coder_callable, cache_token_generator): |
| |
    # All workers share state, so any worker_handler will do.
| worker_handler = worker_handler_list[0] |
| for k in range(self._bundle_repeat): |
| try: |
| worker_handler.state.checkpoint() |
| testing_bundle_manager = ParallelBundleManager( |
| worker_handler_list, lambda pcoll_id: [], |
| get_input_coder_callable, process_bundle_descriptor, |
| self._progress_frequency, k, |
| num_workers=self._num_workers, |
| cache_token_generator=cache_token_generator |
| ) |
| testing_bundle_manager.process_bundle(data_input, data_output) |
| finally: |
| worker_handler.state.restore() |
| |
| def _collect_written_timers_and_add_to_deferred_inputs(self, |
| context, |
| pipeline_components, |
| stage, |
| get_buffer_callable, |
| deferred_inputs): |
| |
| for transform_id, timer_writes in stage.timer_pcollections: |
| |
| # Queue any set timers as new inputs. |
| windowed_timer_coder_impl = context.coders[ |
| pipeline_components.pcollections[timer_writes].coder_id].get_impl() |
| written_timers = get_buffer_callable( |
| create_buffer_id(timer_writes, kind='timers')) |
| if written_timers: |
| # Keep only the "last" timer set per key and window. |
| timers_by_key_and_window = {} |
| for elements_data in written_timers: |
| input_stream = create_InputStream(elements_data) |
| while input_stream.size() > 0: |
| windowed_key_timer = windowed_timer_coder_impl.decode_from_stream( |
| input_stream, True) |
| key, _ = windowed_key_timer.value |
| # TODO: Explode and merge windows. |
| assert len(windowed_key_timer.windows) == 1 |
| timers_by_key_and_window[ |
| key, windowed_key_timer.windows[0]] = windowed_key_timer |
| out = create_OutputStream() |
| for windowed_key_timer in timers_by_key_and_window.values(): |
| windowed_timer_coder_impl.encode_to_stream( |
| windowed_key_timer, out, True) |
| deferred_inputs[transform_id] = _ListBuffer([out.get()]) |
| written_timers[:] = [] |
| |
| def _add_residuals_and_channel_splits_to_deferred_inputs( |
| self, splits, get_input_coder_callable, |
| input_for_callable, last_sent, deferred_inputs): |
| prev_stops = {} |
| for split in splits: |
| for delayed_application in split.residual_roots: |
| deferred_inputs[ |
| input_for_callable( |
| delayed_application.application.transform_id, |
| delayed_application.application.input_id) |
| ].append(delayed_application.application.element) |
| for channel_split in split.channel_splits: |
| coder_impl = get_input_coder_callable(channel_split.transform_id) |
      # TODO(SDF): This requires deterministic ordering of buffer iteration.
| # TODO(SDF): The return split is in terms of indices. Ideally, |
| # a runner could map these back to actual positions to effectively |
| # describe the two "halves" of the now-split range. Even if we have |
| # to buffer each element we send (or at the very least a bit of |
| # metadata, like position, about each of them) this should be doable |
| # if they're already in memory and we are bounding the buffer size |
| # (e.g. to 10mb plus whatever is eagerly read from the SDK). In the |
| # case of non-split-points, we can either immediately replay the |
| # "non-split-position" elements or record them as we do the other |
| # delayed applications. |
| |
| # Decode and recode to split the encoded buffer by element index. |
| all_elements = list(coder_impl.decode_all(b''.join(last_sent[ |
| channel_split.transform_id]))) |
| residual_elements = all_elements[ |
| channel_split.first_residual_element : prev_stops.get( |
| channel_split.transform_id, len(all_elements)) + 1] |
| if residual_elements: |
| deferred_inputs[channel_split.transform_id].append( |
| coder_impl.encode_all(residual_elements)) |
| prev_stops[ |
| channel_split.transform_id] = channel_split.last_primary_element |
| |
| @staticmethod |
| def _extract_stage_data_endpoints( |
| stage, pipeline_components, data_api_service_descriptor, pcoll_buffers): |
| # Returns maps of transform names to PCollection identifiers. |
| # Also mutates IO stages to point to the data ApiServiceDescriptor. |
| data_input = {} |
| data_side_input = {} |
| data_output = {} |
| for transform in stage.transforms: |
| if transform.spec.urn in (bundle_processor.DATA_INPUT_URN, |
| bundle_processor.DATA_OUTPUT_URN): |
| pcoll_id = transform.spec.payload |
| if transform.spec.urn == bundle_processor.DATA_INPUT_URN: |
| target = transform.unique_name, only_element(transform.outputs) |
| if pcoll_id == fn_api_runner_transforms.IMPULSE_BUFFER: |
| data_input[target] = _ListBuffer([ENCODED_IMPULSE_VALUE]) |
| else: |
| data_input[target] = pcoll_buffers[pcoll_id] |
| coder_id = pipeline_components.pcollections[ |
| only_element(transform.outputs.values())].coder_id |
| elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN: |
| target = transform.unique_name, only_element(transform.inputs) |
| data_output[target] = pcoll_id |
| coder_id = pipeline_components.pcollections[ |
| only_element(transform.inputs.values())].coder_id |
| else: |
| raise NotImplementedError |
| data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id) |
| if data_api_service_descriptor: |
| data_spec.api_service_descriptor.url = ( |
| data_api_service_descriptor.url) |
| transform.spec.payload = data_spec.SerializeToString() |
| elif transform.spec.urn in fn_api_runner_transforms.PAR_DO_URNS: |
| payload = proto_utils.parse_Bytes( |
| transform.spec.payload, beam_runner_api_pb2.ParDoPayload) |
| for tag, si in payload.side_inputs.items(): |
| data_side_input[transform.unique_name, tag] = ( |
| create_buffer_id(transform.inputs[tag]), si.access_pattern) |
| return data_input, data_side_input, data_output |
| |
| def _run_stage(self, |
| worker_handler_factory, |
| pipeline_components, |
| stage, |
| pcoll_buffers, |
| safe_coders): |
| """Run an individual stage. |
| |
| Args: |
      worker_handler_factory: A ``callable`` that takes in an environment id
        and a number of workers, and returns a list of ``WorkerHandler``s.
| pipeline_components (beam_runner_api_pb2.Components): TODO |
| stage (fn_api_runner_transforms.Stage) |
      pcoll_buffers (collections.defaultdict of str: list): Mapping of
        PCollection IDs to lists that serve as buffers for the
        ``beam.PCollection``.
| safe_coders (dict): TODO |
| """ |
| def iterable_state_write(values, element_coder_impl): |
| token = unique_name(None, 'iter').encode('ascii') |
| out = create_OutputStream() |
| for element in values: |
| element_coder_impl.encode_to_stream(element, out, True) |
| worker_handler.state.append_raw( |
| beam_fn_api_pb2.StateKey( |
| runner=beam_fn_api_pb2.StateKey.Runner(key=token)), |
| out.get()) |
| return token |
| |
| worker_handler_list = worker_handler_factory( |
| stage.environment, self._num_workers) |
| |
    # All worker_handlers share the same gRPC server, so we can read the
    # server info from any of them; use the first worker_handler.
| worker_handler = next(iter(worker_handler_list)) |
| context = pipeline_context.PipelineContext( |
| pipeline_components, iterable_state_write=iterable_state_write) |
| data_api_service_descriptor = worker_handler.data_api_service_descriptor() |
| |
| logging.info('Running %s', stage.name) |
| data_input, data_side_input, data_output = self._extract_endpoints( |
| stage, pipeline_components, data_api_service_descriptor, pcoll_buffers) |
| |
| process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor( |
| id=self._next_uid(), |
| transforms={transform.unique_name: transform |
| for transform in stage.transforms}, |
| pcollections=dict(pipeline_components.pcollections.items()), |
| coders=dict(pipeline_components.coders.items()), |
| windowing_strategies=dict( |
| pipeline_components.windowing_strategies.items()), |
| environments=dict(pipeline_components.environments.items())) |
| |
| if worker_handler.state_api_service_descriptor(): |
| process_bundle_descriptor.state_api_service_descriptor.url = ( |
| worker_handler.state_api_service_descriptor().url) |
| |
| # Store the required side inputs into state so it is accessible for the |
| # worker when it runs this bundle. |
| self._store_side_inputs_in_state(worker_handler, |
| context, |
| pipeline_components, |
| data_side_input, |
| pcoll_buffers, |
| safe_coders) |
| |
| def get_buffer(buffer_id): |
| """Returns the buffer for a given (operation_type, PCollection ID). |
| |
| For grouping-typed operations, we produce a ``_GroupingBuffer``. For |
| others, we produce a ``_ListBuffer``. |
| """ |
| kind, name = split_buffer_id(buffer_id) |
| if kind in ('materialize', 'timers'): |
| # If `buffer_id` is not a key in `pcoll_buffers`, it will be added by |
| # the `defaultdict`. |
| return pcoll_buffers[buffer_id] |
| elif kind == 'group': |
| # This is a grouping write, create a grouping buffer if needed. |
| if buffer_id not in pcoll_buffers: |
| original_gbk_transform = name |
| transform_proto = pipeline_components.transforms[ |
| original_gbk_transform] |
| input_pcoll = only_element(list(transform_proto.inputs.values())) |
| output_pcoll = only_element(list(transform_proto.outputs.values())) |
| pre_gbk_coder = context.coders[safe_coders[ |
| pipeline_components.pcollections[input_pcoll].coder_id]] |
| post_gbk_coder = context.coders[safe_coders[ |
| pipeline_components.pcollections[output_pcoll].coder_id]] |
| windowing_strategy = context.windowing_strategies[ |
| pipeline_components |
| .pcollections[output_pcoll].windowing_strategy_id] |
| pcoll_buffers[buffer_id] = _GroupingBuffer( |
| pre_gbk_coder, post_gbk_coder, windowing_strategy) |
| else: |
| # These should be the only two identifiers we produce for now, |
| # but special side input writes may go here. |
| raise NotImplementedError(buffer_id) |
| return pcoll_buffers[buffer_id] |
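    # For example (illustrative): a buffer id of b'group:<gbk transform id>'
    # maps to a _GroupingBuffer keyed off the original GroupByKey transform,
    # while b'materialize:<pcoll id>' maps to the plain _ListBuffer holding
    # that PCollection's encoded elements.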
| |
| def get_input_coder_impl(transform_id): |
| return context.coders[safe_coders[ |
| beam_fn_api_pb2.RemoteGrpcPort.FromString( |
| process_bundle_descriptor.transforms[transform_id].spec.payload |
| ).coder_id |
| ]].get_impl() |
| |
| # Change cache token across bundle repeats |
| cache_token_generator = FnApiRunner.get_cache_token_generator(static=False) |
| |
| self._run_bundle_multiple_times_for_testing( |
| worker_handler_list, process_bundle_descriptor, data_input, data_output, |
| get_input_coder_impl, cache_token_generator=cache_token_generator) |
| |
| bundle_manager = ParallelBundleManager( |
| worker_handler_list, get_buffer, get_input_coder_impl, |
| process_bundle_descriptor, self._progress_frequency, |
| num_workers=self._num_workers, |
| cache_token_generator=cache_token_generator) |
| |
| result, splits = bundle_manager.process_bundle(data_input, data_output) |
| |
| def input_for(transform_id, input_id): |
| input_pcoll = process_bundle_descriptor.transforms[ |
| transform_id].inputs[input_id] |
| for read_id, proto in process_bundle_descriptor.transforms.items(): |
| if (proto.spec.urn == bundle_processor.DATA_INPUT_URN |
| and input_pcoll in proto.outputs.values()): |
| return read_id |
| raise RuntimeError( |
| 'No IO transform feeds %s' % transform_id) |
| |
| last_result = result |
| last_sent = data_input |
| |
| while True: |
| deferred_inputs = collections.defaultdict(_ListBuffer) |
| |
| self._collect_written_timers_and_add_to_deferred_inputs( |
| context, pipeline_components, stage, get_buffer, deferred_inputs) |
| |
| # Queue any process-initiated delayed bundle applications. |
| for delayed_application in last_result.process_bundle.residual_roots: |
| deferred_inputs[ |
| input_for( |
| delayed_application.application.transform_id, |
| delayed_application.application.input_id) |
| ].append(delayed_application.application.element) |
| |
| # Queue any runner-initiated delayed bundle applications. |
| self._add_residuals_and_channel_splits_to_deferred_inputs( |
| splits, get_input_coder_impl, input_for, last_sent, deferred_inputs) |
| |
| if deferred_inputs: |
| # The worker will be waiting on these inputs as well. |
| for other_input in data_input: |
| if other_input not in deferred_inputs: |
| deferred_inputs[other_input] = _ListBuffer([]) |
| # TODO(robertwb): merge results |
        # We cannot split deferred_inputs until residual_roots are included
        # in the merged results; without them, the pipeline stops early and
        # we may miss some data.
| bundle_manager._num_workers = 1 |
| bundle_manager._skip_registration = True |
| last_result, splits = bundle_manager.process_bundle( |
| deferred_inputs, data_output) |
| last_sent = deferred_inputs |
| result = beam_fn_api_pb2.InstructionResponse( |
| process_bundle=beam_fn_api_pb2.ProcessBundleResponse( |
| monitoring_infos=monitoring_infos.consolidate( |
| itertools.chain( |
| result.process_bundle.monitoring_infos, |
| last_result.process_bundle.monitoring_infos))), |
| error=result.error or last_result.error) |
| else: |
| break |
| |
| return result |
| |
| @staticmethod |
| def _extract_endpoints(stage, |
| pipeline_components, |
| data_api_service_descriptor, |
| pcoll_buffers): |
| """Returns maps of transform names to PCollection identifiers. |
| |
| Also mutates IO stages to point to the data ApiServiceDescriptor. |
| |
| Args: |
| stage (fn_api_runner_transforms.Stage): The stage to extract endpoints |
| for. |
| pipeline_components (beam_runner_api_pb2.Components): Components of the |
| pipeline to include coders, transforms, PCollections, etc. |
| data_api_service_descriptor: A GRPC endpoint descriptor for data plane. |
| pcoll_buffers (dict): A dictionary containing buffers for PCollection |
| elements. |
| Returns: |
| A tuple of (data_input, data_side_input, data_output) dictionaries. |
      `data_input` is a dictionary mapping a transform's unique name to a
      PCollection buffer; `data_output` is a dictionary mapping a transform's
      unique name to a PCollection ID.
| """ |
| data_input = {} |
| data_side_input = {} |
| data_output = {} |
| for transform in stage.transforms: |
| if transform.spec.urn in (bundle_processor.DATA_INPUT_URN, |
| bundle_processor.DATA_OUTPUT_URN): |
| pcoll_id = transform.spec.payload |
| if transform.spec.urn == bundle_processor.DATA_INPUT_URN: |
| if pcoll_id == fn_api_runner_transforms.IMPULSE_BUFFER: |
| data_input[transform.unique_name] = _ListBuffer( |
| [ENCODED_IMPULSE_VALUE]) |
| else: |
| data_input[transform.unique_name] = pcoll_buffers[pcoll_id] |
| coder_id = pipeline_components.pcollections[ |
| only_element(transform.outputs.values())].coder_id |
| elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN: |
| data_output[transform.unique_name] = pcoll_id |
| coder_id = pipeline_components.pcollections[ |
| only_element(transform.inputs.values())].coder_id |
| else: |
| raise NotImplementedError |
| data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id) |
| if data_api_service_descriptor: |
| data_spec.api_service_descriptor.url = ( |
| data_api_service_descriptor.url) |
| transform.spec.payload = data_spec.SerializeToString() |
| elif transform.spec.urn in fn_api_runner_transforms.PAR_DO_URNS: |
| payload = proto_utils.parse_Bytes( |
| transform.spec.payload, beam_runner_api_pb2.ParDoPayload) |
| for tag, si in payload.side_inputs.items(): |
| data_side_input[transform.unique_name, tag] = ( |
| create_buffer_id(transform.inputs[tag]), si.access_pattern) |
| return data_input, data_side_input, data_output |
| |
| # These classes are used to interact with the worker. |
| |
| class StateServicer(beam_fn_api_pb2_grpc.BeamFnStateServicer): |
| |
| class CopyOnWriteState(object): |
| def __init__(self, underlying): |
| self._underlying = underlying |
| self._overlay = {} |
| |
| def __getitem__(self, key): |
| if key in self._overlay: |
| return self._overlay[key] |
| else: |
| return FnApiRunner.StateServicer.CopyOnWriteList( |
| self._underlying, self._overlay, key) |
| |
| def __delitem__(self, key): |
| self._overlay[key] = [] |
| |
| def commit(self): |
| self._underlying.update(self._overlay) |
| return self._underlying |
| |
| class CopyOnWriteList(object): |
| def __init__(self, underlying, overlay, key): |
| self._underlying = underlying |
| self._overlay = overlay |
| self._key = key |
| |
| def __iter__(self): |
| if self._key in self._overlay: |
| return iter(self._overlay[self._key]) |
| else: |
| return iter(self._underlying[self._key]) |
| |
| def append(self, item): |
| if self._key not in self._overlay: |
| self._overlay[self._key] = list(self._underlying[self._key]) |
| self._overlay[self._key].append(item) |
| |
| def __init__(self): |
| self._lock = threading.Lock() |
| self._state = collections.defaultdict(list) |
| self._checkpoint = None |
| self._use_continuation_tokens = False |
| self._continuations = {} |
| |
| def checkpoint(self): |
| assert self._checkpoint is None |
| self._checkpoint = self._state |
| self._state = FnApiRunner.StateServicer.CopyOnWriteState(self._state) |
| |
| def commit(self): |
| self._state.commit() |
| self._state = self._checkpoint.commit() |
| self._checkpoint = None |
| |
| def restore(self): |
| self._state = self._checkpoint |
| self._checkpoint = None |
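    # Checkpoint protocol sketch (illustrative): checkpoint() snapshots the
    # current state and routes later writes into a copy-on-write overlay;
    # restore() drops that overlay (used when replaying bundles for testing),
    # while commit() folds it back into the underlying state.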
| |
| @contextlib.contextmanager |
| def process_instruction_id(self, unused_instruction_id): |
| yield |
| |
| def get_raw(self, state_key, continuation_token=None): |
| with self._lock: |
| full_state = self._state[self._to_key(state_key)] |
| if self._use_continuation_tokens: |
| # The token is "nonce:index". |
| if not continuation_token: |
| token_base = 'token_%x' % len(self._continuations) |
| self._continuations[token_base] = tuple(full_state) |
| return b'', '%s:0' % token_base |
| else: |
| token_base, index = continuation_token.split(':') |
| ix = int(index) |
| full_state = self._continuations[token_base] |
| if ix == len(full_state): |
| return b'', None |
| else: |
| return full_state[ix], '%s:%d' % (token_base, ix + 1) |
| else: |
| assert not continuation_token |
| return b''.join(full_state), None |
| |
| def append_raw(self, state_key, data): |
| with self._lock: |
| self._state[self._to_key(state_key)].append(data) |
| return _Future.done() |
| |
| def clear(self, state_key): |
| with self._lock: |
| try: |
| del self._state[self._to_key(state_key)] |
| except KeyError: |
| # This may happen with the caching layer across bundles. Caching may |
| # skip this storage layer for a blocking_get(key) request. Without |
| # the caching, the state for a key would be initialized via the |
| # defaultdict that _state uses. |
| pass |
| return _Future.done() |
| |
| @staticmethod |
| def _to_key(state_key): |
| return state_key.SerializeToString() |
| |
| class GrpcStateServicer(beam_fn_api_pb2_grpc.BeamFnStateServicer): |
| def __init__(self, state): |
| self._state = state |
| |
| def State(self, request_stream, context=None): |
| # Note that this eagerly mutates state, assuming any failures are fatal. |
| # Thus it is safe to ignore instruction_id. |
| for request in request_stream: |
| request_type = request.WhichOneof('request') |
| if request_type == 'get': |
| data, continuation_token = self._state.get_raw( |
| request.state_key, request.get.continuation_token) |
| yield beam_fn_api_pb2.StateResponse( |
| id=request.id, |
| get=beam_fn_api_pb2.StateGetResponse( |
| data=data, continuation_token=continuation_token)) |
| elif request_type == 'append': |
| self._state.append_raw(request.state_key, request.append.data) |
| yield beam_fn_api_pb2.StateResponse( |
| id=request.id, |
| append=beam_fn_api_pb2.StateAppendResponse()) |
| elif request_type == 'clear': |
| self._state.clear(request.state_key) |
| yield beam_fn_api_pb2.StateResponse( |
| id=request.id, |
| clear=beam_fn_api_pb2.StateClearResponse()) |
| else: |
| raise NotImplementedError('Unknown state request: %s' % request_type) |
| |
| class SingletonStateHandlerFactory(sdk_worker.StateHandlerFactory): |
| """A singleton cache for a StateServicer.""" |
| |
| def __init__(self, state_handler): |
| self._state_handler = state_handler |
| |
| def create_state_handler(self, api_service_descriptor): |
| """Returns the singleton state handler.""" |
| return self._state_handler |
| |
| def close(self): |
| """Does nothing.""" |
| pass |
| |
| @staticmethod |
| def get_cache_token_generator(static=True): |
| """A generator for cache tokens. |
| :arg static If True, generator always returns the same cache token |
| If False, generator returns a new cache token each time |
| :return A generator which returns a cache token on next(generator) |
| """ |
| def generate_token(identifier): |
| return beam_fn_api_pb2.ProcessBundleRequest.CacheToken( |
| user_state=beam_fn_api_pb2 |
| .ProcessBundleRequest.CacheToken.UserState(), |
| token="cache_token_{}".format(identifier).encode("utf-8")) |
| |
| class StaticGenerator(object): |
| def __init__(self): |
| self._token = generate_token(1) |
| |
| def __iter__(self): |
| # pylint: disable=non-iterator-returned |
| return self |
| |
| def __next__(self): |
| return self._token |
| |
| class DynamicGenerator(object): |
| def __init__(self): |
| self._counter = 0 |
| self._lock = threading.Lock() |
| |
| def __iter__(self): |
| # pylint: disable=non-iterator-returned |
| return self |
| |
| def __next__(self): |
| with self._lock: |
| self._counter += 1 |
| return generate_token(self._counter) |
| |
| return StaticGenerator() if static else DynamicGenerator() |
| |
| |
| class WorkerHandler(object): |
| """worker_handler for a worker. |
| |
| It provides utilities to start / stop the worker, provision any resources for |
| it, as well as provide descriptors for the data, state and logging APIs for |
| it. |
| """ |
| |
| _registered_environments = {} |
| _worker_id_counter = -1 |
| _lock = threading.Lock() |
| |
| def __init__( |
| self, control_handler, data_plane_handler, state, provision_info): |
| """Initialize a WorkerHandler. |
| |
| Args: |
| control_handler: |
| data_plane_handler (data_plane.DataChannel): |
| state: |
| provision_info: |
| """ |
| self.control_handler = control_handler |
| self.data_plane_handler = data_plane_handler |
| self.state = state |
| self.provision_info = provision_info |
| |
| with WorkerHandler._lock: |
| WorkerHandler._worker_id_counter += 1 |
| self.worker_id = 'worker_%s' % WorkerHandler._worker_id_counter |
| |
| def close(self): |
| self.stop_worker() |
| |
| def start_worker(self): |
| raise NotImplementedError |
| |
| def stop_worker(self): |
| raise NotImplementedError |
| |
| def data_api_service_descriptor(self): |
| raise NotImplementedError |
| |
| def state_api_service_descriptor(self): |
| raise NotImplementedError |
| |
| def logging_api_service_descriptor(self): |
| raise NotImplementedError |
| |
| @classmethod |
| def register_environment(cls, urn, payload_type): |
| def wrapper(constructor): |
| cls._registered_environments[urn] = constructor, payload_type |
| return constructor |
| return wrapper |
| |
| @classmethod |
| def create(cls, environment, state, provision_info, grpc_server): |
| constructor, payload_type = cls._registered_environments[environment.urn] |
| return constructor( |
| proto_utils.parse_Bytes(environment.payload, payload_type), |
| state, |
| provision_info, |
| grpc_server) |
| |
| |
| @WorkerHandler.register_environment(python_urns.EMBEDDED_PYTHON, None) |
| class EmbeddedWorkerHandler(WorkerHandler): |
| """An in-memory worker_handler for fn API control, state and data planes.""" |
| |
| def __init__(self, unused_payload, state, provision_info, |
| unused_grpc_server=None): |
| super(EmbeddedWorkerHandler, self).__init__( |
| self, data_plane.InMemoryDataChannel(), state, provision_info) |
| self.control_conn = self |
| self.data_conn = self.data_plane_handler |
| state_cache = StateCache(STATE_CACHE_SIZE) |
| self.worker = sdk_worker.SdkWorker( |
| sdk_worker.BundleProcessorCache( |
| FnApiRunner.SingletonStateHandlerFactory( |
| sdk_worker.CachingStateHandler(state_cache, state)), |
| data_plane.InMemoryDataChannelFactory( |
| self.data_plane_handler.inverse()), |
| {}), state_cache_metrics_fn=state_cache.get_monitoring_infos) |
| self._uid_counter = 0 |
| |
| def push(self, request): |
| if not request.instruction_id: |
| self._uid_counter += 1 |
| request.instruction_id = 'control_%s' % self._uid_counter |
| response = self.worker.do_instruction(request) |
| return ControlFuture(request.instruction_id, response) |
| |
| def start_worker(self): |
| pass |
| |
| def stop_worker(self): |
| self.worker.stop() |
| |
| def done(self): |
| pass |
| |
| def data_api_service_descriptor(self): |
| return None |
| |
| def state_api_service_descriptor(self): |
| return None |
| |
| def logging_api_service_descriptor(self): |
| return None |
| |
| |
| class BasicLoggingService(beam_fn_api_pb2_grpc.BeamFnLoggingServicer): |
| |
| LOG_LEVEL_MAP = { |
| beam_fn_api_pb2.LogEntry.Severity.CRITICAL: logging.CRITICAL, |
| beam_fn_api_pb2.LogEntry.Severity.ERROR: logging.ERROR, |
| beam_fn_api_pb2.LogEntry.Severity.WARN: logging.WARNING, |
| beam_fn_api_pb2.LogEntry.Severity.NOTICE: logging.INFO + 1, |
| beam_fn_api_pb2.LogEntry.Severity.INFO: logging.INFO, |
| beam_fn_api_pb2.LogEntry.Severity.DEBUG: logging.DEBUG, |
| beam_fn_api_pb2.LogEntry.Severity.TRACE: logging.DEBUG - 1, |
| beam_fn_api_pb2.LogEntry.Severity.UNSPECIFIED: logging.NOTSET, |
| } |
| |
| def Logging(self, log_messages, context=None): |
| yield beam_fn_api_pb2.LogControl() |
| for log_message in log_messages: |
| for log in log_message.log_entries: |
| logging.log(self.LOG_LEVEL_MAP[log.severity], str(log)) |
| |
| |
| class BasicProvisionService( |
| beam_provision_api_pb2_grpc.ProvisionServiceServicer): |
| |
| def __init__(self, info): |
| self._info = info |
| |
| def GetProvisionInfo(self, request, context=None): |
| return beam_provision_api_pb2.GetProvisionInfoResponse( |
| info=self._info) |
| |
| |
| class EmptyArtifactRetrievalService( |
| beam_artifact_api_pb2_grpc.ArtifactRetrievalServiceServicer): |
| |
| def GetManifest(self, request, context=None): |
| return beam_artifact_api_pb2.GetManifestResponse( |
| manifest=beam_artifact_api_pb2.Manifest()) |
| |
| def GetArtifact(self, request, context=None): |
| raise ValueError('No artifacts staged.') |
| |
| |
| class GrpcServer(object): |
| |
| _DEFAULT_SHUTDOWN_TIMEOUT_SECS = 5 |
| |
| def __init__(self, state, provision_info, max_workers): |
| self.state = state |
| self.provision_info = provision_info |
| self.max_workers = max_workers |
| self.control_server = grpc.server( |
| futures.ThreadPoolExecutor(max_workers=self.max_workers)) |
| self.control_port = self.control_server.add_insecure_port('[::]:0') |
| self.control_address = 'localhost:%s' % self.control_port |
| |
| # Options to have no limits (-1) on the size of the messages |
| # received or sent over the data plane. The actual buffer size |
| # is controlled in a layer above. |
| no_max_message_sizes = [("grpc.max_receive_message_length", -1), |
| ("grpc.max_send_message_length", -1)] |
| self.data_server = grpc.server( |
| futures.ThreadPoolExecutor(max_workers=self.max_workers), |
| options=no_max_message_sizes) |
| self.data_port = self.data_server.add_insecure_port('[::]:0') |
| |
| self.state_server = grpc.server( |
| futures.ThreadPoolExecutor(max_workers=self.max_workers), |
| options=no_max_message_sizes) |
| self.state_port = self.state_server.add_insecure_port('[::]:0') |
| |
| self.control_handler = BeamFnControlServicer() |
| beam_fn_api_pb2_grpc.add_BeamFnControlServicer_to_server( |
| self.control_handler, self.control_server) |
| |
| # If we have provision info, serve these off the control port as well. |
| if self.provision_info: |
| if self.provision_info.provision_info: |
| provision_info = self.provision_info.provision_info |
| if not provision_info.worker_id: |
| provision_info = copy.copy(provision_info) |
| provision_info.worker_id = str(uuid.uuid4()) |
| beam_provision_api_pb2_grpc.add_ProvisionServiceServicer_to_server( |
| BasicProvisionService(self.provision_info.provision_info), |
| self.control_server) |
| |
| if self.provision_info.artifact_staging_dir: |
| service = artifact_service.BeamFilesystemArtifactService( |
| self.provision_info.artifact_staging_dir) |
| else: |
| service = EmptyArtifactRetrievalService() |
| beam_artifact_api_pb2_grpc.add_ArtifactRetrievalServiceServicer_to_server( |
| service, self.control_server) |
| |
| self.data_plane_handler = data_plane.BeamFnDataServicer() |
| beam_fn_api_pb2_grpc.add_BeamFnDataServicer_to_server( |
| self.data_plane_handler, self.data_server) |
| |
| beam_fn_api_pb2_grpc.add_BeamFnStateServicer_to_server( |
| FnApiRunner.GrpcStateServicer(state), |
| self.state_server) |
| |
| self.logging_server = grpc.server( |
| futures.ThreadPoolExecutor(max_workers=2), |
| options=no_max_message_sizes) |
| self.logging_port = self.logging_server.add_insecure_port('[::]:0') |
| beam_fn_api_pb2_grpc.add_BeamFnLoggingServicer_to_server( |
| BasicLoggingService(), |
| self.logging_server) |
| |
| logging.info('starting control server on port %s', self.control_port) |
| logging.info('starting data server on port %s', self.data_port) |
| logging.info('starting state server on port %s', self.state_port) |
| logging.info('starting logging server on port %s', self.logging_port) |
| self.logging_server.start() |
| self.state_server.start() |
| self.data_server.start() |
| self.control_server.start() |
| |
| def close(self): |
| self.control_handler.done() |
| to_wait = [ |
| self.control_server.stop(self._DEFAULT_SHUTDOWN_TIMEOUT_SECS), |
| self.data_server.stop(self._DEFAULT_SHUTDOWN_TIMEOUT_SECS), |
| self.state_server.stop(self._DEFAULT_SHUTDOWN_TIMEOUT_SECS), |
| self.logging_server.stop(self._DEFAULT_SHUTDOWN_TIMEOUT_SECS) |
| ] |
| for w in to_wait: |
| w.wait() |
| |
| |
| class GrpcWorkerHandler(WorkerHandler): |
| """An grpc based worker_handler for fn API control, state and data planes.""" |
| |
| def __init__(self, state, provision_info, grpc_server): |
| self._grpc_server = grpc_server |
| super(GrpcWorkerHandler, self).__init__( |
| self._grpc_server.control_handler, self._grpc_server.data_plane_handler, |
| state, provision_info) |
| self.state = state |
| |
| self.control_address = self.port_from_worker(self._grpc_server.control_port) |
| self.control_conn = self._grpc_server.control_handler.get_conn_by_worker_id( |
| self.worker_id) |
| |
| self.data_conn = self._grpc_server.data_plane_handler.get_conn_by_worker_id( |
| self.worker_id) |
| |
| def data_api_service_descriptor(self): |
| return endpoints_pb2.ApiServiceDescriptor( |
| url=self.port_from_worker(self._grpc_server.data_port)) |
| |
| def state_api_service_descriptor(self): |
| return endpoints_pb2.ApiServiceDescriptor( |
| url=self.port_from_worker(self._grpc_server.state_port)) |
| |
| def logging_api_service_descriptor(self): |
| return endpoints_pb2.ApiServiceDescriptor( |
| url=self.port_from_worker(self._grpc_server.logging_port)) |
| |
| def close(self): |
| self.control_conn.close() |
| self.data_conn.close() |
| super(GrpcWorkerHandler, self).close() |
| |
| def port_from_worker(self, port): |
| return '%s:%s' % (self.host_from_worker(), port) |
| |
| def host_from_worker(self): |
| return 'localhost' |
| |
| |
| @WorkerHandler.register_environment( |
| common_urns.environments.EXTERNAL.urn, beam_runner_api_pb2.ExternalPayload) |
| class ExternalWorkerHandler(GrpcWorkerHandler): |
| def __init__(self, external_payload, state, provision_info, grpc_server): |
| super(ExternalWorkerHandler, self).__init__(state, provision_info, |
| grpc_server) |
| self._external_payload = external_payload |
| |
| def start_worker(self): |
| stub = beam_fn_api_pb2_grpc.BeamFnExternalWorkerPoolStub( |
| GRPCChannelFactory.insecure_channel( |
| self._external_payload.endpoint.url)) |
| response = stub.StartWorker( |
| beam_fn_api_pb2.StartWorkerRequest( |
| worker_id=self.worker_id, |
| control_endpoint=endpoints_pb2.ApiServiceDescriptor( |
| url=self.control_address), |
| logging_endpoint=self.logging_api_service_descriptor(), |
| params=self._external_payload.params)) |
| if response.error: |
| raise RuntimeError("Error starting worker: %s" % response.error) |
| |
| def stop_worker(self): |
| pass |
| |
| def host_from_worker(self): |
| import socket |
| return socket.getfqdn() |
| |
| |
| @WorkerHandler.register_environment(python_urns.EMBEDDED_PYTHON_GRPC, bytes) |
| class EmbeddedGrpcWorkerHandler(GrpcWorkerHandler): |
| def __init__(self, payload, state, provision_info, grpc_server): |
| super(EmbeddedGrpcWorkerHandler, self).__init__(state, provision_info, |
| grpc_server) |
| if payload: |
| num_workers, state_cache_size = payload.decode('ascii').split(',') |
| self._num_threads = int(num_workers) |
| self._state_cache_size = int(state_cache_size) |
| else: |
| self._num_threads = 1 |
| self._state_cache_size = STATE_CACHE_SIZE |
| |
| def start_worker(self): |
| self.worker = sdk_worker.SdkHarness( |
| self.control_address, worker_count=self._num_threads, |
| state_cache_size=self._state_cache_size, worker_id=self.worker_id) |
| self.worker_thread = threading.Thread( |
| name='run_worker', target=self.worker.run) |
| self.worker_thread.daemon = True |
| self.worker_thread.start() |
| |
| def stop_worker(self): |
| self.worker_thread.join() |
| |
| |
# The subprocess module is not thread-safe on Python 2.7. Use this lock to
# prevent concurrent calls to Popen().
| SUBPROCESS_LOCK = threading.Lock() |
| |
| |
| @WorkerHandler.register_environment(python_urns.SUBPROCESS_SDK, bytes) |
| class SubprocessSdkWorkerHandler(GrpcWorkerHandler): |
| def __init__(self, worker_command_line, state, provision_info, grpc_server): |
| super(SubprocessSdkWorkerHandler, self).__init__(state, provision_info, |
| grpc_server) |
| self._worker_command_line = worker_command_line |
| |
| def start_worker(self): |
| from apache_beam.runners.portability import local_job_service |
| self.worker = local_job_service.SubprocessSdkWorker( |
| self._worker_command_line, self.control_address, self.worker_id) |
| self.worker_thread = threading.Thread( |
| name='run_worker', target=self.worker.run) |
| self.worker_thread.start() |
| |
| def stop_worker(self): |
| self.worker_thread.join() |
| |
| |
| @WorkerHandler.register_environment(common_urns.environments.DOCKER.urn, |
| beam_runner_api_pb2.DockerPayload) |
| class DockerSdkWorkerHandler(GrpcWorkerHandler): |
| def __init__(self, payload, state, provision_info, grpc_server): |
| super(DockerSdkWorkerHandler, self).__init__(state, provision_info, |
| grpc_server) |
| self._container_image = payload.container_image |
| self._container_id = None |
| |
| def host_from_worker(self): |
| if sys.platform == "darwin": |
| # See https://docs.docker.com/docker-for-mac/networking/ |
| return 'host.docker.internal' |
| else: |
| return super(DockerSdkWorkerHandler, self).host_from_worker() |
| |
| def start_worker(self): |
| with SUBPROCESS_LOCK: |
| try: |
| subprocess.check_call(['docker', 'pull', self._container_image]) |
| except Exception: |
        logging.info('Unable to pull image %s', self._container_image)
| self._container_id = subprocess.check_output( |
| ['docker', |
| 'run', |
| '-d', |
| # TODO: credentials |
| '--network=host', |
| self._container_image, |
| '--id=%s' % self.worker_id, |
| '--logging_endpoint=%s' % self.logging_api_service_descriptor().url, |
| '--control_endpoint=%s' % self.control_address, |
| '--artifact_endpoint=%s' % self.control_address, |
| '--provision_endpoint=%s' % self.control_address, |
| ]).strip() |
| while True: |
| status = subprocess.check_output([ |
| 'docker', |
| 'inspect', |
| '-f', |
| '{{.State.Status}}', |
| self._container_id]).strip() |
        logging.info('Waiting for docker to start up. Current status is %s',
                     status)
| if status == b'running': |
| logging.info('Docker container is running. container_id = %s, ' |
| 'worker_id = %s', self._container_id, self.worker_id) |
| break |
| elif status in (b'dead', b'exited'): |
| subprocess.call([ |
| 'docker', |
| 'container', |
| 'logs', |
| self._container_id]) |
| raise RuntimeError('SDK failed to start. Final status is %s' % status) |
| time.sleep(1) |
| |
| def stop_worker(self): |
| if self._container_id: |
| with SUBPROCESS_LOCK: |
| subprocess.call([ |
| 'docker', |
| 'kill', |
| self._container_id]) |
| |
| |
| class WorkerHandlerManager(object): |
| def __init__(self, environments, job_provision_info): |
| self._environments = environments |
| self._job_provision_info = job_provision_info |
| self._cached_handlers = collections.defaultdict(list) |
    # TODO: Rename self._state to something more descriptive.
    self._state = FnApiRunner.StateServicer()
| self._grpc_server = None |
| |
| def get_worker_handlers(self, environment_id, num_workers): |
| if environment_id is None: |
| # Any environment will do, pick one arbitrarily. |
| environment_id = next(iter(self._environments.keys())) |
| environment = self._environments[environment_id] |
| max_total_workers = num_workers * len(self._environments) |
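    # E.g. two environments with num_workers=4 allow up to 8 workers in total.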
| |
    # Assume all environments except EMBEDDED_PYTHON use gRPC.
| if environment.urn == python_urns.EMBEDDED_PYTHON: |
| pass # no need for a gRPC server |
| elif self._grpc_server is None: |
| self._grpc_server = GrpcServer(self._state, self._job_provision_info, |
| max_total_workers) |
| elif max_total_workers > self._grpc_server.max_workers: |
      # Each gRPC server runs with a fixed number of threads
      # (max_total_workers), determined by the first call to
      # get_worker_handlers(). We assume each worker holds one connection to
      # the gRPC server. If a stage tries to add more workers than
      # max_total_workers, some of them will be unable to connect and the
      # pipeline will hang, so raise an error here instead.
      raise RuntimeError(
          'The gRPC server is running with %s threads; cannot attach %s '
          'workers.' % (self._grpc_server.max_workers, max_total_workers))
| |
| worker_handler_list = self._cached_handlers[environment_id] |
| if len(worker_handler_list) < num_workers: |
| for _ in range(len(worker_handler_list), num_workers): |
| worker_handler = WorkerHandler.create( |
| environment, self._state, self._job_provision_info, |
| self._grpc_server) |
| logging.info("Created Worker handler %s for environment %s", |
| worker_handler, environment) |
| self._cached_handlers[environment_id].append(worker_handler) |
| worker_handler.start_worker() |
| return self._cached_handlers[environment_id][:num_workers] |
| |
| def close_all(self): |
| for worker_handler_list in self._cached_handlers.values(): |
| for worker_handler in set(worker_handler_list): |
| try: |
| worker_handler.close() |
| except Exception: |
| logging.error("Error closing worker_handler %s" % worker_handler, |
| exc_info=True) |
| self._cached_handlers = {} |
| if self._grpc_server is not None: |
| self._grpc_server.close() |
| self._grpc_server = None |
| |
| |
| class ExtendedProvisionInfo(object): |
| def __init__(self, provision_info=None, artifact_staging_dir=None): |
| self.provision_info = ( |
| provision_info or beam_provision_api_pb2.ProvisionInfo()) |
| self.artifact_staging_dir = artifact_staging_dir |
| |
| |
| _split_managers = [] |
| |
| |
| @contextlib.contextmanager |
| def split_manager(stage_name, split_manager): |
| """Registers a split manager to control the flow of elements to a given stage. |
| |
| Used for testing. |
| |
| A split manager should be a coroutine yielding desired split fractions, |
| receiving the corresponding split results. Currently, only one input is |
| supported. |
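
  For example, a split manager that requests a single half-split and then
  stops might look like the following (a minimal sketch; the stage name and
  generator are hypothetical):

    def halve_once(num_elements):
      # Ask the runner to split off half of the remaining work, then stop.
      unused_split_result = yield 0.5

    with split_manager('MyStage', halve_once):
      ...  # run the pipeline under test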
| """ |
| try: |
| _split_managers.append((stage_name, split_manager)) |
| yield |
| finally: |
| _split_managers.pop() |
| |
| |
| class BundleManager(object): |
| """Manages the execution of a bundle from the runner-side. |
| |
| This class receives a bundle descriptor, and performs the following tasks: |
| - Registration of the bundle with the worker. |
  - Splitting of the bundle.
  - Setting up any other bundle requirements (e.g. side inputs).
  - Submitting the bundle to the worker for execution.
  - Passing bundle input data to the worker.
  - Collecting bundle output data from the worker.
| - Finalizing the bundle. |
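
  A hedged usage sketch (the argument values are whatever the runner already
  holds; the names here are illustrative):

    manager = BundleManager(
        worker_handler_list, get_buffer, get_input_coder_impl,
        process_bundle_descriptor)
    result, split_results = manager.process_bundle(inputs, expected_outputs)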
| """ |
| |
| _uid_counter = 0 |
| _lock = threading.Lock() |
| |
| def __init__( |
| self, worker_handler_list, get_buffer, get_input_coder_impl, |
| bundle_descriptor, progress_frequency=None, skip_registration=False, |
| cache_token_generator=FnApiRunner.get_cache_token_generator()): |
| """Set up a bundle manager. |
| |
| Args: |
| worker_handler_list |
| get_buffer (Callable[[str], list]) |
| get_input_coder_impl (Callable[[str], Coder]) |
| bundle_descriptor (beam_fn_api_pb2.ProcessBundleDescriptor) |
| progress_frequency |
| skip_registration |
| """ |
| self._worker_handler_list = worker_handler_list |
| self._get_buffer = get_buffer |
| self._get_input_coder_impl = get_input_coder_impl |
| self._bundle_descriptor = bundle_descriptor |
| self._registered = skip_registration |
| self._progress_frequency = progress_frequency |
| self._worker_handler = None |
| self._cache_token_generator = cache_token_generator |
| |
| def _send_input_to_worker(self, |
| process_bundle_id, |
| read_transform_id, |
| byte_streams): |
| data_out = self._worker_handler.data_conn.output_stream( |
| process_bundle_id, read_transform_id) |
| for byte_stream in byte_streams: |
| data_out.write(byte_stream) |
| data_out.close() |
| |
| def _register_bundle_descriptor(self): |
| if self._registered: |
| registration_future = None |
| else: |
| process_bundle_registration = beam_fn_api_pb2.InstructionRequest( |
| register=beam_fn_api_pb2.RegisterRequest( |
| process_bundle_descriptor=[self._bundle_descriptor])) |
| registration_future = self._worker_handler.control_conn.push( |
| process_bundle_registration) |
| self._registered = True |
| |
| return registration_future |
| |
| def _select_split_manager(self): |
| """TODO(pabloem) WHAT DOES THIS DO""" |
| unique_names = set( |
| t.unique_name for t in self._bundle_descriptor.transforms.values()) |
| for stage_name, candidate in reversed(_split_managers): |
| if (stage_name in unique_names |
| or (stage_name + '/Process') in unique_names): |
| split_manager = candidate |
| break |
| else: |
| split_manager = None |
| |
| return split_manager |
| |
| def _generate_splits_for_testing(self, |
| split_manager, |
| inputs, |
| process_bundle_id): |
| split_results = [] |
| read_transform_id, buffer_data = only_element(inputs.items()) |
| |
| byte_stream = b''.join(buffer_data) |
| num_elements = len(list( |
| self._get_input_coder_impl(read_transform_id).decode_all(byte_stream))) |
| |
| # Start the split manager in case it wants to set any breakpoints. |
| split_manager_generator = split_manager(num_elements) |
| try: |
| split_fraction = next(split_manager_generator) |
| done = False |
| except StopIteration: |
| done = True |
| |
| # Send all the data. |
| self._send_input_to_worker( |
| process_bundle_id, read_transform_id, [byte_stream]) |
| |
| # Execute the requested splits. |
| while not done: |
| if split_fraction is None: |
| split_result = None |
| else: |
| split_request = beam_fn_api_pb2.InstructionRequest( |
| process_bundle_split= |
| beam_fn_api_pb2.ProcessBundleSplitRequest( |
| instruction_id=process_bundle_id, |
| desired_splits={ |
| read_transform_id: |
| beam_fn_api_pb2.ProcessBundleSplitRequest.DesiredSplit( |
| fraction_of_remainder=split_fraction, |
| estimated_input_elements=num_elements) |
| })) |
| split_response = self._worker_handler.control_conn.push( |
| split_request).get() |
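        # The worker may not have started processing the bundle yet; retry
        # the split request a few times with increasing backoff.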
| for t in (0.05, 0.1, 0.2): |
| waiting = ('Instruction not running', 'not yet scheduled') |
| if any(msg in split_response.error for msg in waiting): |
| time.sleep(t) |
| split_response = self._worker_handler.control_conn.push( |
| split_request).get() |
| if 'Unknown process bundle' in split_response.error: |
| # It may have finished too fast. |
| split_result = None |
| elif split_response.error: |
| raise RuntimeError(split_response.error) |
| else: |
| split_result = split_response.process_bundle_split |
| split_results.append(split_result) |
| try: |
| split_fraction = split_manager_generator.send(split_result) |
| except StopIteration: |
| break |
| return split_results |
| |
| def process_bundle(self, inputs, expected_outputs): |
| # Unique id for the instruction processing this bundle. |
| with BundleManager._lock: |
| BundleManager._uid_counter += 1 |
| process_bundle_id = 'bundle_%s' % BundleManager._uid_counter |
| self._worker_handler = self._worker_handler_list[ |
| BundleManager._uid_counter % len(self._worker_handler_list)] |
| |
| # Register the bundle descriptor, if needed - noop if already registered. |
| registration_future = self._register_bundle_descriptor() |
| # Check that the bundle was successfully registered. |
| if registration_future and registration_future.get().error: |
| raise RuntimeError(registration_future.get().error) |
| |
| split_manager = self._select_split_manager() |
| if not split_manager: |
| # If there is no split_manager, write all input data to the channel. |
| for transform_id, elements in inputs.items(): |
| self._send_input_to_worker( |
| process_bundle_id, transform_id, elements) |
| |
| # Actually start the bundle. |
| process_bundle_req = beam_fn_api_pb2.InstructionRequest( |
| instruction_id=process_bundle_id, |
| process_bundle=beam_fn_api_pb2.ProcessBundleRequest( |
| process_bundle_descriptor_id=self._bundle_descriptor.id, |
| cache_tokens=[next(self._cache_token_generator)])) |
| result_future = self._worker_handler.control_conn.push(process_bundle_req) |
| |
| split_results = [] |
| with ProgressRequester( |
| self._worker_handler, process_bundle_id, self._progress_frequency): |
| |
| if split_manager: |
| split_results = self._generate_splits_for_testing( |
| split_manager, inputs, process_bundle_id) |
| |
| # Gather all output data. |
| for output in self._worker_handler.data_conn.input_elements( |
| process_bundle_id, |
| expected_outputs.keys(), |
| abort_callback=lambda: (result_future.is_done() |
| and result_future.get().error)): |
| if output.transform_id in expected_outputs: |
| with BundleManager._lock: |
| self._get_buffer( |
| expected_outputs[output.transform_id]).append(output.data) |
| |
    logging.debug('Waiting for bundle %s to finish.', process_bundle_id)
| result = result_future.get() |
| |
| if result.error: |
| raise RuntimeError(result.error) |
| |
| if result.process_bundle.requires_finalization: |
| finalize_request = beam_fn_api_pb2.InstructionRequest( |
| finalize_bundle= |
| beam_fn_api_pb2.FinalizeBundleRequest( |
| instruction_id=process_bundle_id |
| )) |
| self._worker_handler.control_conn.push(finalize_request) |
| |
| return result, split_results |
| |
| |
| class ParallelBundleManager(BundleManager): |
| |
| def __init__( |
| self, worker_handler_list, get_buffer, get_input_coder_impl, |
| bundle_descriptor, progress_frequency=None, skip_registration=False, |
| cache_token_generator=None, **kwargs): |
| super(ParallelBundleManager, self).__init__( |
| worker_handler_list, get_buffer, get_input_coder_impl, |
| bundle_descriptor, progress_frequency, skip_registration, |
| cache_token_generator=cache_token_generator) |
| self._num_workers = kwargs.pop('num_workers', 1) |
| |
| def process_bundle(self, inputs, expected_outputs): |
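    # Partition each input buffer into (up to) one part per worker.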
| part_inputs = [{} for _ in range(self._num_workers)] |
    for name, input_data in inputs.items():
      for ix, part in enumerate(input_data.partition(self._num_workers)):
        part_inputs[ix][name] = part
| |
| merged_result = None |
| split_result_list = [] |
| with futures.ThreadPoolExecutor(max_workers=self._num_workers) as executor: |
| for result, split_result in executor.map(lambda part: BundleManager( |
| self._worker_handler_list, self._get_buffer, |
| self._get_input_coder_impl, self._bundle_descriptor, |
| self._progress_frequency, self._registered, |
| cache_token_generator=self._cache_token_generator).process_bundle( |
| part, expected_outputs), part_inputs): |
| |
| split_result_list += split_result |
| if merged_result is None: |
| merged_result = result |
| else: |
| merged_result = beam_fn_api_pb2.InstructionResponse( |
| process_bundle=beam_fn_api_pb2.ProcessBundleResponse( |
| monitoring_infos=monitoring_infos.consolidate( |
| itertools.chain( |
| result.process_bundle.monitoring_infos, |
| merged_result.process_bundle.monitoring_infos))), |
| error=result.error or merged_result.error) |
| |
| return merged_result, split_result_list |
| |
| |
| class ProgressRequester(threading.Thread): |
| """ Thread that asks SDK Worker for progress reports with a certain frequency. |
| |
  A callback can be passed in and will be invoked with each progress update.
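
  Typically used as a context manager around bundle execution, e.g. (a
  sketch; the handler, instruction id, and frequency come from the caller):

    with ProgressRequester(worker_handler, process_bundle_id, 0.5):
      result = result_future.get()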
| """ |
| |
| def __init__(self, worker_handler, instruction_id, frequency, callback=None): |
| super(ProgressRequester, self).__init__() |
| self._worker_handler = worker_handler |
| self._instruction_id = instruction_id |
| self._frequency = frequency |
| self._done = False |
| self._latest_progress = None |
| self._callback = callback |
| self.daemon = True |
| |
| def __enter__(self): |
| if self._frequency: |
| self.start() |
| |
| def __exit__(self, *unused_exc_info): |
| if self._frequency: |
| self.stop() |
| |
| def run(self): |
| while not self._done: |
| try: |
| progress_result = self._worker_handler.control_conn.push( |
| beam_fn_api_pb2.InstructionRequest( |
| process_bundle_progress= |
| beam_fn_api_pb2.ProcessBundleProgressRequest( |
| instruction_id=self._instruction_id))).get() |
| self._latest_progress = progress_result.process_bundle_progress |
| if self._callback: |
| self._callback(self._latest_progress) |
| except Exception as exn: |
| logging.error("Bad progress: %s", exn) |
| time.sleep(self._frequency) |
| |
| def stop(self): |
| self._done = True |
| |
| |
| class ControlFuture(object): |
| def __init__(self, instruction_id, response=None): |
| self.instruction_id = instruction_id |
    self._response = response
| self._condition = threading.Condition() |
| |
| def is_done(self): |
| return self._response is not None |
| |
| def set(self, response): |
| with self._condition: |
| self._response = response |
| self._condition.notify_all() |
| |
| def get(self, timeout=None): |
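    # Double-checked locking: take the lock only if no response has been set
    # yet, then wait on the condition until set() delivers one. Note that if
    # the timeout elapses first, None is returned.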
| if not self._response: |
| with self._condition: |
| if not self._response: |
| self._condition.wait(timeout) |
| return self._response |
| |
| |
| class FnApiMetrics(metric.MetricResults): |
| def __init__(self, step_monitoring_infos, user_metrics_only=True): |
| """Used for querying metrics from the PipelineResult object. |
| |
| step_monitoring_infos: Per step metrics specified as MonitoringInfos. |
| user_metrics_only: If true, includes user metrics only. |
| """ |
| self._counters = {} |
| self._distributions = {} |
| self._gauges = {} |
| self._user_metrics_only = user_metrics_only |
| self._monitoring_infos = step_monitoring_infos |
| |
| for smi in step_monitoring_infos.values(): |
| counters, distributions, gauges = \ |
| portable_metrics.from_monitoring_infos(smi, user_metrics_only) |
| self._counters.update(counters) |
| self._distributions.update(distributions) |
| self._gauges.update(gauges) |
| |
| def query(self, filter=None): |
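    # This runner does not distinguish between attempted and committed
    # values, so each MetricResult reports the same value for both.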
| counters = [MetricResult(k, v, v) |
| for k, v in self._counters.items() |
| if self.matches(filter, k)] |
| distributions = [MetricResult(k, v, v) |
| for k, v in self._distributions.items() |
| if self.matches(filter, k)] |
| gauges = [MetricResult(k, v, v) |
| for k, v in self._gauges.items() |
| if self.matches(filter, k)] |
| |
| return {self.COUNTERS: counters, |
| self.DISTRIBUTIONS: distributions, |
| self.GAUGES: gauges} |
| |
| def monitoring_infos(self): |
| return [item for sublist in self._monitoring_infos.values() for item in |
| sublist] |
| |
| |
| class RunnerResult(runner.PipelineResult): |
| def __init__(self, state, monitoring_infos_by_stage, metrics_by_stage): |
| super(RunnerResult, self).__init__(state) |
| self._monitoring_infos_by_stage = monitoring_infos_by_stage |
| self._metrics_by_stage = metrics_by_stage |
| self._metrics = None |
| self._monitoring_metrics = None |
| |
| def wait_until_finish(self, duration=None): |
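    # The FnApiRunner executes the pipeline synchronously, so by the time a
    # RunnerResult exists there is nothing left to wait for.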
| return self._state |
| |
| def metrics(self): |
| """Returns a queryable object including user metrics only.""" |
| if self._metrics is None: |
| self._metrics = FnApiMetrics( |
| self._monitoring_infos_by_stage, user_metrics_only=True) |
| return self._metrics |
| |
| def monitoring_metrics(self): |
| """Returns a queryable object including all metrics.""" |
| if self._monitoring_metrics is None: |
| self._monitoring_metrics = FnApiMetrics( |
| self._monitoring_infos_by_stage, user_metrics_only=False) |
| return self._monitoring_metrics |