| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| # TODO: https://github.com/apache/beam/issues/21822 |
| # mypy: ignore-errors |
| |
| """An extensible run inference transform. |
| |
| Users of this module can extend the ModelHandler class for any machine learning |
| framework. A ModelHandler implementation is a required parameter of |
| RunInference. |
| |
The transform handles standard inference functionality, like metric
collection, sharing models between threads, and batching elements.
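
For example, a minimal custom handler might look like the following sketch.
``MyInput``, ``MyModel``, ``my_model.load``, ``model.predict``, and
``examples`` are illustrative placeholders for your framework's own APIs and
your input data::

  class MyModelHandler(ModelHandler[MyInput, PredictionResult, MyModel]):
    def load_model(self) -> MyModel:
      return my_model.load('/path/to/model')

    def run_inference(self, batch, model, inference_args=None):
      for example in batch:
        yield PredictionResult(example, model.predict(example))

  with beam.Pipeline() as p:
    _ = (
        p
        | beam.Create(examples)
        | RunInference(MyModelHandler()))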
| """ |
| |
| import logging |
| import os |
| import pickle |
| import sys |
| import threading |
| import time |
| import uuid |
| from collections import OrderedDict |
| from collections import defaultdict |
| from copy import deepcopy |
| from dataclasses import dataclass |
| from typing import Any |
| from typing import Callable |
| from typing import Dict |
| from typing import Generic |
| from typing import Iterable |
| from typing import List |
| from typing import Mapping |
| from typing import NamedTuple |
| from typing import Optional |
| from typing import Sequence |
| from typing import Tuple |
| from typing import TypeVar |
| from typing import Union |
| |
| import apache_beam as beam |
| from apache_beam.utils import multi_process_shared |
| from apache_beam.utils import shared |
| |
| try: |
| # pylint: disable=wrong-import-order, wrong-import-position |
| import resource |
| except ImportError: |
| resource = None # type: ignore[assignment] |
| |
| _NANOSECOND_TO_MILLISECOND = 1_000_000 |
| _NANOSECOND_TO_MICROSECOND = 1_000 |
| |
| ModelT = TypeVar('ModelT') |
| ExampleT = TypeVar('ExampleT') |
| PreProcessT = TypeVar('PreProcessT') |
| PredictionT = TypeVar('PredictionT') |
| PostProcessT = TypeVar('PostProcessT') |
| _INPUT_TYPE = TypeVar('_INPUT_TYPE') |
| _OUTPUT_TYPE = TypeVar('_OUTPUT_TYPE') |
| KeyT = TypeVar('KeyT') |
| |
| |
# We use NamedTuple to define the structure of PredictionResult. However,
# because support for generic NamedTuples is not available in Python versions
# prior to 3.11, we use the __new__ method to provide a default value for the
# model_id field while maintaining backwards compatibility.
| class PredictionResult(NamedTuple('PredictionResult', |
| [('example', _INPUT_TYPE), |
| ('inference', _OUTPUT_TYPE), |
| ('model_id', Optional[str])])): |
| __slots__ = () |
| |
| def __new__(cls, example, inference, model_id=None): |
| return super().__new__(cls, example, inference, model_id) |
| |
| |
| PredictionResult.__doc__ = """A NamedTuple containing both input and output |
| from the inference.""" |
| PredictionResult.example.__doc__ = """The input example.""" |
| PredictionResult.inference.__doc__ = """Results for the inference on the model |
| for the given example.""" |
| PredictionResult.model_id.__doc__ = """Model ID used to run the prediction.""" |
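
# For example (purely illustrative values):
#   PredictionResult(example={'f1': 1.0}, inference=0.92, model_id='model_v1')
# where `inference` is whatever the ModelHandler's run_inference yields for
# the corresponding input `example`.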
| |
| |
| class ModelMetadata(NamedTuple): |
| model_id: str |
| model_name: str |
| |
| |
| class RunInferenceDLQ(NamedTuple): |
| failed_inferences: beam.PCollection |
| failed_preprocessing: Sequence[beam.PCollection] |
| failed_postprocessing: Sequence[beam.PCollection] |
| |
| |
| class _ModelLoadStats(NamedTuple): |
| model_tag: str |
| load_latency: Optional[int] |
| byte_size: Optional[int] |
| |
| |
| ModelMetadata.model_id.__doc__ = """Unique identifier for the model. This can be |
| a file path or a URL where the model can be accessed. It is used to load |
| the model for inference.""" |
| ModelMetadata.model_name.__doc__ = """Human-readable name for the model. This |
| can be used to identify the model in the metrics generated by the |
| RunInference transform.""" |
| |
| |
| def _to_milliseconds(time_ns: int) -> int: |
| return int(time_ns / _NANOSECOND_TO_MILLISECOND) |
| |
| |
| def _to_microseconds(time_ns: int) -> int: |
| return int(time_ns / _NANOSECOND_TO_MICROSECOND) |
| |
| |
| @dataclass(frozen=True) |
| class KeyModelPathMapping(Generic[KeyT]): |
| """ |
| Dataclass for mapping 1 or more keys to 1 model path. This is used in |
| conjunction with a KeyedModelHandler with many model handlers to update |
| a set of keys' model handlers with the new path. Given |
| `KeyModelPathMapping(keys: ['key1', 'key2'], update_path: 'updated/path', |
| model_id: 'id1')`, all examples with keys `key1` or `key2` will have their |
| corresponding model handler's update_model function called with |
| 'updated/path' and their metrics will correspond with 'id1'. For more |
  information see the KeyedModelHandler documentation at
  https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.KeyedModelHandler
  and the website section on model updates at
| https://beam.apache.org/documentation/ml/about-ml/#automatic-model-refresh |
| """ |
| keys: List[KeyT] |
| update_path: str |
| model_id: str = '' |
| |
| |
| class ModelHandler(Generic[ExampleT, PredictionT, ModelT]): |
| """Has the ability to load and apply an ML model.""" |
| def __init__(self): |
| """Environment variables are set using a dict named 'env_vars' before |
| loading the model. Child classes can accept this dict as a kwarg.""" |
| self._env_vars = {} |
| |
| def load_model(self) -> ModelT: |
| """Loads and initializes a model for processing.""" |
| raise NotImplementedError(type(self)) |
| |
| def run_inference( |
| self, |
| batch: Sequence[ExampleT], |
| model: ModelT, |
| inference_args: Optional[Dict[str, Any]] = None) -> Iterable[PredictionT]: |
| """Runs inferences on a batch of examples. |
| |
| Args: |
| batch: A sequence of examples or features. |
| model: The model used to make inferences. |
| inference_args: Extra arguments for models whose inference call requires |
| extra parameters. |
| |
| Returns: |
| An Iterable of Predictions. |
| """ |
| raise NotImplementedError(type(self)) |
| |
| def get_num_bytes(self, batch: Sequence[ExampleT]) -> int: |
| """ |
| Returns: |
| The number of bytes of data for a batch. |
| """ |
| return len(pickle.dumps(batch)) |
| |
| def get_metrics_namespace(self) -> str: |
| """ |
| Returns: |
| A namespace for metrics collected by the RunInference transform. |
| """ |
| return 'RunInference' |
| |
| def get_resource_hints(self) -> dict: |
| """ |
| Returns: |
| Resource hints for the transform. |
| """ |
| return {} |
| |
| def batch_elements_kwargs(self) -> Mapping[str, Any]: |
| """ |
| Returns: |
| kwargs suitable for beam.BatchElements. |
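      For example, a handler might return
      ``{'min_batch_size': 1, 'max_batch_size': 32}`` (illustrative values)
      to bound the batch sizes produced by beam.BatchElements.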
| """ |
| return {} |
| |
| def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]): |
| """Validates inference_args passed in the inference call. |
| |
| Because most frameworks do not need extra arguments in their predict() call, |
| the default behavior is to error out if inference_args are present. |
| """ |
| if inference_args: |
| raise ValueError( |
| 'inference_args were provided, but should be None because this ' |
| 'framework does not expect extra arguments on inferences.') |
| |
| def update_model_path(self, model_path: Optional[str] = None): |
| """ |
| Update the model path produced by side inputs. update_model_path should be |
| used when a ModelHandler represents a single model, not multiple models. |
| This will be true in most cases. For more information see the website |
| section on model updates |
| https://beam.apache.org/documentation/ml/about-ml/#automatic-model-refresh |
| """ |
| pass |
| |
| def update_model_paths( |
| self, |
| model: ModelT, |
| model_paths: Optional[Union[str, List[KeyModelPathMapping]]] = None): |
| """ |
| Update the model paths produced by side inputs. update_model_paths should |
| be used when updating multiple models at once (e.g. when using a |
| KeyedModelHandler that holds multiple models). For more information see |
    the KeyedModelHandler documentation at
    https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.KeyedModelHandler
    and the website section on model updates at
| https://beam.apache.org/documentation/ml/about-ml/#automatic-model-refresh |
| """ |
| pass |
| |
| def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]: |
| """Gets all preprocessing functions to be run before batching/inference. |
    Functions are listed in the order in which they should be applied."""
| return [] |
| |
| def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]: |
| """Gets all postprocessing functions to be run after inference. |
    Functions are listed in the order in which they should be applied."""
| return [] |
| |
| def set_environment_vars(self): |
| """Sets environment variables using a dictionary provided via kwargs. |
| Keys are the env variable name, and values are the env variable value. |
| Child ModelHandler classes should set _env_vars via kwargs in __init__, |
| or else call super().__init__().""" |
| env_vars = getattr(self, '_env_vars', {}) |
| for env_variable, env_value in env_vars.items(): |
| os.environ[env_variable] = env_value |
| |
| def with_preprocess_fn( |
| self, fn: Callable[[PreProcessT], ExampleT] |
  ) -> 'ModelHandler[PreProcessT, PredictionT, ModelT]':
| """Returns a new ModelHandler with a preprocessing function |
| associated with it. The preprocessing function will be run |
| before batching/inference and should map your input PCollection |
| to the base ModelHandler's input type. If you apply multiple |
| preprocessing functions, they will be run on your original |
| PCollection in order from last applied to first applied.""" |
| return _PreProcessingModelHandler(self, fn) |
| |
| def with_postprocess_fn( |
| self, fn: Callable[[PredictionT], PostProcessT] |
  ) -> 'ModelHandler[ExampleT, PostProcessT, ModelT]':
| """Returns a new ModelHandler with a postprocessing function |
| associated with it. The postprocessing function will be run |
| after inference and should map the base ModelHandler's output |
| type to your desired output type. If you apply multiple |
| postprocessing functions, they will be run on your original |
| inference result in order from first applied to last applied.""" |
| return _PostProcessingModelHandler(self, fn) |
| |
| def share_model_across_processes(self) -> bool: |
| """Returns a boolean representing whether or not a model should |
| be shared across multiple processes instead of being loaded per process. |
| This is primary useful for large models that can't fit multiple copies in |
| memory. Multi-process support may vary by runner, but this will fallback to |
| loading per process as necessary. See |
| https://beam.apache.org/releases/pydoc/current/apache_beam.utils.multi_process_shared.html""" |
| return False |
| |
| def override_metrics(self, metrics_namespace: str = '') -> bool: |
| """Returns a boolean representing whether or not a model handler will |
| override metrics reporting. If True, RunInference will not report any |
| metrics.""" |
| return False |
| |
| |
| class _ModelManager: |
| """ |
| A class for efficiently managing copies of multiple models. Will load a |
| single copy of each model into a multi_process_shared object and then |
| return a lookup key for that object. |
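
  A rough usage sketch (the handler names are illustrative)::

    manager = _ModelManager({'key1': handler1, 'key2': handler2})
    manager.increment_max_models(2)  # optional LRU cap
    stats = manager.load('key1')     # returns a _ModelLoadStats
    model = multi_process_shared.MultiProcessShared(
        handler1.load_model, tag=stats.model_tag).acquire()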
| """ |
| def __init__(self, mh_map: Dict[str, ModelHandler]): |
| """ |
| Args: |
| mh_map: A map from keys to model handlers which can be used to load a |
| model. |
| """ |
| self._max_models = None |
| # Map keys to model handlers |
| self._mh_map: Dict[str, ModelHandler] = mh_map |
| # Map keys to the last updated model path for that key |
| self._key_to_last_update: Dict[str, str] = defaultdict(str) |
| # Map key for a model to a unique tag that will persist for the life of |
| # that model in memory. A new tag will be generated if a model is swapped |
| # out of memory and reloaded. |
| self._tag_map: Dict[str, str] = OrderedDict() |
    # Map a tag to a MultiProcessShared model object for that tag. Each entry
| # of this map should last as long as the corresponding entry in _tag_map. |
| self._proxy_map: Dict[str, multi_process_shared.MultiProcessShared] = {} |
| |
| def load(self, key: str) -> _ModelLoadStats: |
| """ |
| Loads the appropriate model for the given key into memory. |
| Args: |
| key: the key associated with the model we'd like to load. |
| Returns: |
| _ModelLoadStats with tag, byte size, and latency to load the model. If |
| the model was already loaded, byte size/latency will be None. |
| """ |
    # Map the key for a model to a unique tag that will persist until the
    # model is released. The tag needs to be unique between releasing and
    # reacquiring the model because otherwise the ProxyManager will try to
    # reuse a model that has already been released and deleted.
| if key in self._tag_map: |
| self._tag_map.move_to_end(key) |
| return _ModelLoadStats(self._tag_map[key], None, None) |
| else: |
| self._tag_map[key] = uuid.uuid4().hex |
| |
| tag = self._tag_map[key] |
| mh = self._mh_map[key] |
| |
| if self._max_models is not None and self._max_models < len(self._tag_map): |
      # If we're about to exceed our LRU size, release the least recently
      # used model.
| tag_to_remove = self._tag_map.popitem(last=False)[1] |
| shared_handle, model_to_remove = self._proxy_map[tag_to_remove] |
| shared_handle.release(model_to_remove) |
| del self._proxy_map[tag_to_remove] |
| |
| # Load the new model |
| memory_before = _get_current_process_memory_in_bytes() |
| start_time = _to_milliseconds(time.time_ns()) |
| shared_handle = multi_process_shared.MultiProcessShared( |
| mh.load_model, tag=tag) |
| model_reference = shared_handle.acquire() |
| self._proxy_map[tag] = (shared_handle, model_reference) |
| memory_after = _get_current_process_memory_in_bytes() |
| end_time = _to_milliseconds(time.time_ns()) |
| |
| return _ModelLoadStats( |
| tag, end_time - start_time, memory_after - memory_before) |
| |
| def increment_max_models(self, increment: int): |
| """ |
| Increments the number of models that this instance of a _ModelManager is |
| able to hold. If it is never called, no limit is imposed. |
| Args: |
| increment: the amount by which we are incrementing the number of models. |
| """ |
| if self._max_models is None: |
| self._max_models = 0 |
| self._max_models += increment |
| |
| def update_model_handler(self, key: str, model_path: str, previous_key: str): |
| """ |
| Updates the model path of this model handler and removes it from memory so |
| that it can be reloaded with the updated path. No-ops if no model update |
| needs to be applied. |
| Args: |
| key: the key associated with the model we'd like to update. |
| model_path: the new path to the model we'd like to load. |
| previous_key: the key that is associated with the old version of this |
| model. This will often be the same as the current key, but sometimes |
| we will want to keep both the old and new models to serve different |
| cohorts. In that case, the keys should be different. |
| """ |
| if self._key_to_last_update[key] == model_path: |
| return |
| self._key_to_last_update[key] = model_path |
| if key not in self._mh_map: |
| self._mh_map[key] = deepcopy(self._mh_map[previous_key]) |
| self._mh_map[key].update_model_path(model_path) |
| if key in self._tag_map: |
| tag_to_remove = self._tag_map[key] |
| shared_handle, model_to_remove = self._proxy_map[tag_to_remove] |
| shared_handle.release(model_to_remove) |
| del self._tag_map[key] |
| del self._proxy_map[tag_to_remove] |
| |
| |
# Use a regular class instead of a NamedTuple because NamedTuples and generics
# don't mix well across the board for all versions:
# https://github.com/python/typing/issues/653
| class KeyModelMapping(Generic[KeyT, ExampleT, PredictionT, ModelT]): |
| """ |
  Maps 1 or more keys to 1 model handler. Given
| `KeyModelMapping(['key1', 'key2'], myMh)`, all examples with keys `key1` |
| or `key2` will be run against the model defined by the `myMh` ModelHandler. |
| """ |
| def __init__( |
| self, keys: List[KeyT], mh: ModelHandler[ExampleT, PredictionT, ModelT]): |
| self.keys = keys |
| self.mh = mh |
| |
| |
| class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT], |
| ModelHandler[Tuple[KeyT, ExampleT], |
| Tuple[KeyT, PredictionT], |
| Union[ModelT, _ModelManager]]): |
| def __init__( |
| self, |
| unkeyed: Union[ModelHandler[ExampleT, PredictionT, ModelT], |
| List[KeyModelMapping[KeyT, ExampleT, PredictionT, |
| ModelT]]], |
| max_models_per_worker_hint: Optional[int] = None): |
| """A ModelHandler that takes keyed examples and returns keyed predictions. |
| |
| For example, if the original model is used with RunInference to take a |
| PCollection[E] to a PCollection[P], this ModelHandler would take a |
| PCollection[Tuple[K, E]] to a PCollection[Tuple[K, P]], making it possible |
| to use the key to associate the outputs with the inputs. KeyedModelHandler |
| is able to accept either a single unkeyed ModelHandler or many different |
| model handlers corresponding to the keys for which that ModelHandler should |
| be used. For example, the following configuration could be used to map keys |
| 1-3 to ModelHandler1 and keys 4-5 to ModelHandler2: |
| |
| k1 = ['k1', 'k2', 'k3'] |
| k2 = ['k4', 'k5'] |
| KeyedModelHandler([KeyModelMapping(k1, mh1), KeyModelMapping(k2, mh2)]) |
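
    Elements must then be keyed with those same keys; a minimal pipeline
    sketch (``examples_by_key`` is an illustrative PCollection of
    (key, example) tuples) could look like::

      keyed_handler = KeyedModelHandler(
          [KeyModelMapping(k1, mh1), KeyModelMapping(k2, mh2)])
      predictions = examples_by_key | RunInference(keyed_handler)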
| |
    Note that a single copy of each of these models may be held in memory at
    the same time; be careful not to load too many large models or your
    pipeline may throw Out of Memory exceptions.
| |
| KeyedModelHandlers support Automatic Model Refresh to update your model |
| to a newer version without stopping your streaming pipeline. For an |
| overview of this feature, see |
| https://beam.apache.org/documentation/ml/about-ml/#automatic-model-refresh |
| |
| |
    To use this feature with a KeyedModelHandler that holds many models,
| you can pass in a list of KeyModelPathMapping objects to define your new |
| model paths. For example, passing in the side input of |
| |
| [KeyModelPathMapping(keys=['k1', 'k2'], update_path='update/path/1'), |
| KeyModelPathMapping(keys=['k3'], update_path='update/path/2')] |
| |
| will update the model corresponding to keys 'k1' and 'k2' with path |
| 'update/path/1' and the model corresponding to 'k3' with 'update/path/2'. |
| In order to do a side input update: (1) all restrictions mentioned in |
| https://beam.apache.org/documentation/ml/about-ml/#automatic-model-refresh |
| must be met, (2) all update_paths must be non-empty, even if they are not |
| being updated from their original values, and (3) the set of keys |
| originally defined cannot change. This means that if originally you have |
| defined model handlers for 'key1', 'key2', and 'key3', all 3 of those keys |
| must appear in your list of KeyModelPathMappings exactly once. No |
| additional keys can be added. |
| |
| When using many models defined per key, metrics about inference and model |
| loading will be gathered on an aggregate basis for all keys. These will be |
| reported with no prefix. Metrics will also be gathered on a per key basis. |
| Since some keys can share the same model, only one set of metrics will be |
| reported per key 'cohort'. These will be reported in the form: |
| `<cohort_key>-<metric_name>`, where `<cohort_key>` can be any key selected |
| from the cohort. When model updates occur, the metrics will be reported in |
| the form `<cohort_key>-<model id>-<metric_name>`. |
| |
| Loading multiple models at the same time can increase the risk of an out of |
| memory (OOM) exception. To avoid this issue, use the parameter |
| `max_models_per_worker_hint` to limit the number of models that are loaded |
| at the same time. For more information about memory management, see |
    `Use a keyed ModelHandler <https://beam.apache.org/documentation/ml/about-ml/#use-a-keyed-modelhandler-object>`_. # pylint: disable=line-too-long
| |
| |
| Args: |
| unkeyed: Either (a) an implementation of ModelHandler that does not |
| require keys or (b) a list of KeyModelMappings mapping lists of keys to |
| unkeyed ModelHandlers. |
| max_models_per_worker_hint: A hint to the runner indicating how many |
| models can be held in memory at one time per worker process. For |
        example, if your worker has 8 GB of memory provisioned and your models
| take up 1 GB each, you should set this to 7 to allow all models to sit |
| in memory with some buffer. For more information about memory management, |
        see `Use a keyed ModelHandler <https://beam.apache.org/documentation/ml/about-ml/#use-a-keyed-modelhandler-object>`_. # pylint: disable=line-too-long
| """ |
| self._metrics_collectors: Dict[str, _MetricsCollector] = {} |
    self._default_metrics_collector: Optional[_MetricsCollector] = None
| self._metrics_namespace = '' |
| self._single_model = not isinstance(unkeyed, list) |
| if self._single_model: |
| if len(unkeyed.get_preprocess_fns()) or len( |
| unkeyed.get_postprocess_fns()): |
        raise Exception(
            'Cannot make an unkeyed model handler with pre or '
            'postprocessing functions defined into a keyed model handler. All '
            'pre/postprocessing functions must be defined on the outer model '
            'handler.')
| self._env_vars = getattr(unkeyed, '_env_vars', {}) |
| self._unkeyed = unkeyed |
| return |
| |
| self._max_models_per_worker_hint = max_models_per_worker_hint |
| # To maintain an efficient representation, we will map all keys in a given |
| # KeyModelMapping to a single id (the first key in the KeyModelMapping |
| # list). We will then map that key to a ModelHandler. This will allow us to |
| # quickly look up the appropriate ModelHandler for any given key. |
| self._id_to_mh_map: Dict[str, ModelHandler[ExampleT, PredictionT, |
| ModelT]] = {} |
| self._key_to_id_map: Dict[str, str] = {} |
| for mh_tuple in unkeyed: |
| mh = mh_tuple.mh |
| keys = mh_tuple.keys |
| if len(mh.get_preprocess_fns()) or len(mh.get_postprocess_fns()): |
        raise ValueError(
            'Cannot use an unkeyed model handler with pre or '
            'postprocessing functions defined in a keyed model handler. All '
            'pre/postprocessing functions must be defined on the outer model '
            'handler.')
| hints = mh.get_resource_hints() |
| if len(hints) > 0: |
        logging.warning(
            'mh %s defines the following resource hints, which will be '
            'ignored: %s. Resource hints are not respected when more than one '
            'model handler is used in a KeyedModelHandler. If you would like '
            'to specify resource hints, you can do so by overriding the '
            'KeyedModelHandler.get_resource_hints() method.',
            mh,
            hints)
| batch_kwargs = mh.batch_elements_kwargs() |
| if len(batch_kwargs) > 0: |
        logging.warning(
            'mh %s defines the following batching kwargs, which will be '
            'ignored: %s. Batching kwargs are not respected when more than '
            'one model handler is used in a KeyedModelHandler. If you would '
            'like to specify batching parameters, you can do so by '
            'overriding the KeyedModelHandler.batch_elements_kwargs() method.',
            mh,
            batch_kwargs)
| env_vars = getattr(mh, '_env_vars', {}) |
| if len(env_vars) > 0: |
| logging.warning( |
| 'mh %s defines the following _env_vars which will be ignored %s. ' |
| '_env_vars are not respected when more than one model handler is ' |
| 'used in a KeyedModelHandler. If you need env vars set at ' |
| 'inference time, you can do so with ' |
| 'a custom inference function.', |
| mh, |
| env_vars) |
| |
| if len(keys) == 0: |
| raise ValueError( |
| f'Empty list maps to model handler {mh}. All model handlers must ' |
| 'have one or more associated keys.') |
| self._id_to_mh_map[keys[0]] = mh |
| for key in keys: |
| if key in self._key_to_id_map: |
| raise ValueError( |
| f'key {key} maps to multiple model handlers. All keys must map ' |
| 'to exactly one model handler.') |
| self._key_to_id_map[key] = keys[0] |
| |
| def load_model(self) -> Union[ModelT, _ModelManager]: |
| if self._single_model: |
| return self._unkeyed.load_model() |
| return _ModelManager(self._id_to_mh_map) |
| |
| def run_inference( |
| self, |
| batch: Sequence[Tuple[KeyT, ExampleT]], |
| model: Union[ModelT, _ModelManager], |
| inference_args: Optional[Dict[str, Any]] = None |
| ) -> Iterable[Tuple[KeyT, PredictionT]]: |
| if self._single_model: |
| keys, unkeyed_batch = zip(*batch) |
| return zip( |
| keys, |
| self._unkeyed.run_inference(unkeyed_batch, model, inference_args)) |
| |
| # The first time a MultiProcessShared ModelManager is used for inference |
| # from this process, we should increment its max model count |
| if self._max_models_per_worker_hint is not None: |
| lock = threading.Lock() |
| if lock.acquire(blocking=False): |
| model.increment_max_models(self._max_models_per_worker_hint) |
| self._max_models_per_worker_hint = None |
| |
| batch_by_key = defaultdict(list) |
| key_by_id = defaultdict(set) |
| for key, example in batch: |
| batch_by_key[key].append(example) |
| key_by_id[self._key_to_id_map[key]].add(key) |
| |
| predictions = [] |
| for id, keys in key_by_id.items(): |
| mh = self._id_to_mh_map[id] |
| loaded_model = model.load(id) |
| keyed_model_tag = loaded_model.model_tag |
| if loaded_model.byte_size is not None: |
| self._metrics_collectors[id].update_load_model_metrics( |
| loaded_model.load_latency, loaded_model.byte_size) |
| self._default_metrics_collector.update_load_model_metrics( |
| loaded_model.load_latency, loaded_model.byte_size) |
| keyed_model_shared_handle = multi_process_shared.MultiProcessShared( |
| mh.load_model, tag=keyed_model_tag) |
| keyed_model = keyed_model_shared_handle.acquire() |
| start_time = _to_microseconds(time.time_ns()) |
| num_bytes = 0 |
| num_elements = 0 |
| try: |
| for key in keys: |
| unkeyed_batches = batch_by_key[key] |
| try: |
| for inf in mh.run_inference(unkeyed_batches, |
| keyed_model, |
| inference_args): |
| predictions.append((key, inf)) |
| except BaseException as e: |
| self._metrics_collectors[id].failed_batches_counter.inc() |
| self._default_metrics_collector.failed_batches_counter.inc() |
| raise e |
| num_bytes += mh.get_num_bytes(unkeyed_batches) |
| num_elements += len(unkeyed_batches) |
| finally: |
| keyed_model_shared_handle.release(keyed_model) |
| end_time = _to_microseconds(time.time_ns()) |
| inference_latency = end_time - start_time |
| self._metrics_collectors[id].update( |
| num_elements, num_bytes, inference_latency) |
| self._default_metrics_collector.update( |
| num_elements, num_bytes, inference_latency) |
| |
| return predictions |
| |
| def get_num_bytes(self, batch: Sequence[Tuple[KeyT, ExampleT]]) -> int: |
| keys, unkeyed_batch = zip(*batch) |
| batch_bytes = len(pickle.dumps(keys)) |
| if self._single_model: |
| return batch_bytes + self._unkeyed.get_num_bytes(unkeyed_batch) |
| |
| batch_by_key = defaultdict(list) |
| for key, examples in batch: |
| batch_by_key[key].append(examples) |
| |
| for key, examples in batch_by_key.items(): |
| mh_id = self._key_to_id_map[key] |
| batch_bytes += self._id_to_mh_map[mh_id].get_num_bytes(examples) |
| return batch_bytes |
| |
| def get_metrics_namespace(self) -> str: |
| if self._single_model: |
| return self._unkeyed.get_metrics_namespace() |
| return 'BeamML_KeyedModels' |
| |
| def get_resource_hints(self): |
| if self._single_model: |
| return self._unkeyed.get_resource_hints() |
| return {} |
| |
| def batch_elements_kwargs(self): |
| if self._single_model: |
| return self._unkeyed.batch_elements_kwargs() |
| return {} |
| |
| def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]): |
| if self._single_model: |
| return self._unkeyed.validate_inference_args(inference_args) |
| for mh in self._id_to_mh_map.values(): |
| mh.validate_inference_args(inference_args) |
| |
| def update_model_paths( |
| self, |
| model: Union[ModelT, _ModelManager], |
      model_paths: Optional[List[KeyModelPathMapping[KeyT]]] = None):
| # When there are many models, the keyed model handler is responsible for |
| # reorganizing the model handlers into cohorts and telling the model |
| # manager to update every cohort's associated model handler. The model |
| # manager is responsible for performing the updates and tracking which |
| # updates have already been applied. |
| if model_paths is None or len(model_paths) == 0 or model is None: |
| return |
| if self._single_model: |
| raise RuntimeError( |
| 'Invalid model update: sent many model paths to ' |
| 'update, but KeyedModelHandler is wrapping a single ' |
| 'model.') |
| # Map cohort ids to a dictionary mapping new model paths to the keys that |
| # were originally in that cohort. We will use this to construct our new |
| # cohorts. |
| # cohort_path_mapping will be structured as follows: |
| # { |
| # original_cohort_id: { |
    #     'update/path/1': ['key1FromOriginalCohort', 'key2FromOriginalCohort'],
    #     'update/path/2': ['key3FromOriginalCohort', 'key4FromOriginalCohort'],
| # } |
| # } |
| cohort_path_mapping: Dict[KeyT, Dict[str, List[KeyT]]] = {} |
| key_modelid_mapping: Dict[KeyT, str] = {} |
| seen_keys = set() |
| for mp in model_paths: |
| keys = mp.keys |
| update_path = mp.update_path |
| model_id = mp.model_id |
| if len(update_path) == 0: |
| raise ValueError(f'Invalid model update, path for {keys} is empty') |
| for key in keys: |
| if key in seen_keys: |
| raise ValueError( |
| f'Invalid model update: {key} appears in multiple ' |
| 'update lists. A single model update must provide exactly one ' |
| 'updated path per key.') |
| seen_keys.add(key) |
| if key not in self._key_to_id_map: |
| raise ValueError( |
| f'Invalid model update: {key} appears in ' |
| 'update, but not in the original configuration.') |
| key_modelid_mapping[key] = model_id |
| cohort_id = self._key_to_id_map[key] |
| if cohort_id not in cohort_path_mapping: |
| cohort_path_mapping[cohort_id] = defaultdict(list) |
| cohort_path_mapping[cohort_id][update_path].append(key) |
| for key in self._key_to_id_map: |
| if key not in seen_keys: |
| raise ValueError( |
| f'Invalid model update: {key} appears in the ' |
| 'original configuration, but not the update.') |
| |
| # We now have our new set of cohorts. For each one, update our local model |
| # handler configuration and send the results to the ModelManager |
| for old_cohort_id, path_key_mapping in cohort_path_mapping.items(): |
| for updated_path, keys in path_key_mapping.items(): |
| cohort_id = old_cohort_id |
| if old_cohort_id not in keys: |
| # Create new cohort |
| cohort_id = keys[0] |
| for key in keys: |
| self._key_to_id_map[key] = cohort_id |
| mh = self._id_to_mh_map[old_cohort_id] |
| self._id_to_mh_map[cohort_id] = deepcopy(mh) |
| self._id_to_mh_map[cohort_id].update_model_path(updated_path) |
| model.update_model_handler(cohort_id, updated_path, old_cohort_id) |
| model_id = key_modelid_mapping[cohort_id] |
| self._metrics_collectors[cohort_id] = _MetricsCollector( |
| self._metrics_namespace, f'{cohort_id}-{model_id}-') |
| |
| def update_model_path(self, model_path: Optional[str] = None): |
| if self._single_model: |
| return self._unkeyed.update_model_path(model_path=model_path) |
| if model_path is not None: |
| raise RuntimeError( |
| 'Model updates are currently not supported for ' + |
| 'KeyedModelHandlers with multiple different per-key ' + |
| 'ModelHandlers.') |
| |
| def share_model_across_processes(self) -> bool: |
| if self._single_model: |
| return self._unkeyed.share_model_across_processes() |
| return True |
| |
| def override_metrics(self, metrics_namespace: str = '') -> bool: |
| if self._single_model: |
| return self._unkeyed.override_metrics(metrics_namespace) |
| |
| self._metrics_namespace = metrics_namespace |
| self._default_metrics_collector = _MetricsCollector(metrics_namespace) |
| for cohort_id in self._id_to_mh_map: |
| self._metrics_collectors[cohort_id] = _MetricsCollector( |
| metrics_namespace, f'{cohort_id}-') |
| |
| return True |
| |
| |
| class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT], |
| ModelHandler[Union[ExampleT, Tuple[KeyT, |
| ExampleT]], |
| Union[PredictionT, |
| Tuple[KeyT, PredictionT]], |
| ModelT]): |
| def __init__(self, unkeyed: ModelHandler[ExampleT, PredictionT, ModelT]): |
| """A ModelHandler that takes examples that might have keys and returns |
| predictions that might have keys. |
| |
| For example, if the original model is used with RunInference to take a |
| PCollection[E] to a PCollection[P], this ModelHandler would take either |
| PCollection[E] to a PCollection[P] or PCollection[Tuple[K, E]] to a |
    PCollection[Tuple[K, P]], depending on whether the elements are
    tuples. This pattern makes it possible to associate the outputs with the
    inputs based on the key.
| |
| Note that you cannot use this ModelHandler if E is a tuple type. |
| In addition, either all examples should be keyed, or none of them. |
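
    For example, with an illustrative unkeyed handler ``mh`` and pipeline
    ``p``, both of the following are valid::

      p | beam.Create([example1, example2]) | RunInference(
          MaybeKeyedModelHandler(mh))
      p | beam.Create([('k1', example1)]) | RunInference(
          MaybeKeyedModelHandler(mh))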
| |
| Args: |
| unkeyed: An implementation of ModelHandler that does not require keys. |
| """ |
    if len(unkeyed.get_preprocess_fns()) or len(unkeyed.get_postprocess_fns()):
      raise Exception(
          'Cannot make an unkeyed model handler with pre or '
          'postprocessing functions defined into a keyed model handler. All '
          'pre/postprocessing functions must be defined on the outer model '
          'handler.')
| self._unkeyed = unkeyed |
| self._env_vars = getattr(unkeyed, '_env_vars', {}) |
| |
| def load_model(self) -> ModelT: |
| return self._unkeyed.load_model() |
| |
| def run_inference( |
| self, |
| batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]], |
| model: ModelT, |
| inference_args: Optional[Dict[str, Any]] = None |
| ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]: |
| # Really the input should be |
| # Union[Sequence[ExampleT], Sequence[Tuple[KeyT, ExampleT]]] |
| # but there's not a good way to express (or check) that. |
| if isinstance(batch[0], tuple): |
| is_keyed = True |
| keys, unkeyed_batch = zip(*batch) # type: ignore[arg-type] |
| else: |
| is_keyed = False |
| unkeyed_batch = batch # type: ignore[assignment] |
| unkeyed_results = self._unkeyed.run_inference( |
| unkeyed_batch, model, inference_args) |
| if is_keyed: |
| return zip(keys, unkeyed_results) |
| else: |
| return unkeyed_results |
| |
| def get_num_bytes( |
| self, batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]]) -> int: |
| # MyPy can't follow the branching logic. |
| if isinstance(batch[0], tuple): |
| keys, unkeyed_batch = zip(*batch) # type: ignore[arg-type] |
| return len( |
| pickle.dumps(keys)) + self._unkeyed.get_num_bytes(unkeyed_batch) |
| else: |
| return self._unkeyed.get_num_bytes(batch) # type: ignore[arg-type] |
| |
| def get_metrics_namespace(self) -> str: |
| return self._unkeyed.get_metrics_namespace() |
| |
| def get_resource_hints(self): |
| return self._unkeyed.get_resource_hints() |
| |
| def batch_elements_kwargs(self): |
| return self._unkeyed.batch_elements_kwargs() |
| |
| def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]): |
| return self._unkeyed.validate_inference_args(inference_args) |
| |
| def update_model_path(self, model_path: Optional[str] = None): |
| return self._unkeyed.update_model_path(model_path=model_path) |
| |
| def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]: |
| return self._unkeyed.get_preprocess_fns() |
| |
| def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]: |
| return self._unkeyed.get_postprocess_fns() |
| |
| def share_model_across_processes(self) -> bool: |
| return self._unkeyed.share_model_across_processes() |
| |
| |
| class _PreProcessingModelHandler(Generic[ExampleT, |
| PredictionT, |
| ModelT, |
| PreProcessT], |
| ModelHandler[PreProcessT, PredictionT, |
| ModelT]): |
| def __init__( |
| self, |
| base: ModelHandler[ExampleT, PredictionT, ModelT], |
| preprocess_fn: Callable[[PreProcessT], ExampleT]): |
| """A ModelHandler that has a preprocessing function associated with it. |
| |
| Args: |
| base: An implementation of the underlying model handler. |
| preprocess_fn: the preprocessing function to use. |
| """ |
| self._base = base |
| self._env_vars = getattr(base, '_env_vars', {}) |
| self._preprocess_fn = preprocess_fn |
| |
| def load_model(self) -> ModelT: |
| return self._base.load_model() |
| |
| def run_inference( |
| self, |
| batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]], |
| model: ModelT, |
| inference_args: Optional[Dict[str, Any]] = None |
| ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]: |
| return self._base.run_inference(batch, model, inference_args) |
| |
| def get_num_bytes( |
| self, batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]]) -> int: |
| return self._base.get_num_bytes(batch) |
| |
| def get_metrics_namespace(self) -> str: |
| return self._base.get_metrics_namespace() |
| |
| def get_resource_hints(self): |
| return self._base.get_resource_hints() |
| |
| def batch_elements_kwargs(self): |
| return self._base.batch_elements_kwargs() |
| |
| def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]): |
| return self._base.validate_inference_args(inference_args) |
| |
| def update_model_path(self, model_path: Optional[str] = None): |
| return self._base.update_model_path(model_path=model_path) |
| |
| def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]: |
| return [self._preprocess_fn] + self._base.get_preprocess_fns() |
| |
| def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]: |
| return self._base.get_postprocess_fns() |
| |
| |
| class _PostProcessingModelHandler(Generic[ExampleT, |
| PredictionT, |
| ModelT, |
| PostProcessT], |
| ModelHandler[ExampleT, PostProcessT, ModelT]): |
| def __init__( |
| self, |
| base: ModelHandler[ExampleT, PredictionT, ModelT], |
| postprocess_fn: Callable[[PredictionT], PostProcessT]): |
| """A ModelHandler that has a preprocessing function associated with it. |
| |
| Args: |
| base: An implementation of the underlying model handler. |
| postprocess_fn: the preprocessing function to use. |
| """ |
| self._base = base |
| self._env_vars = getattr(base, '_env_vars', {}) |
| self._postprocess_fn = postprocess_fn |
| |
| def load_model(self) -> ModelT: |
| return self._base.load_model() |
| |
| def run_inference( |
| self, |
| batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]], |
| model: ModelT, |
| inference_args: Optional[Dict[str, Any]] = None |
| ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]: |
| return self._base.run_inference(batch, model, inference_args) |
| |
| def get_num_bytes( |
| self, batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]]) -> int: |
| return self._base.get_num_bytes(batch) |
| |
| def get_metrics_namespace(self) -> str: |
| return self._base.get_metrics_namespace() |
| |
| def get_resource_hints(self): |
| return self._base.get_resource_hints() |
| |
| def batch_elements_kwargs(self): |
| return self._base.batch_elements_kwargs() |
| |
| def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]): |
| return self._base.validate_inference_args(inference_args) |
| |
| def update_model_path(self, model_path: Optional[str] = None): |
| return self._base.update_model_path(model_path=model_path) |
| |
| def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]: |
| return self._base.get_preprocess_fns() |
| |
| def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]: |
| return self._base.get_postprocess_fns() + [self._postprocess_fn] |
| |
| |
| class RunInference(beam.PTransform[beam.PCollection[ExampleT], |
| beam.PCollection[PredictionT]]): |
| def __init__( |
| self, |
| model_handler: ModelHandler[ExampleT, PredictionT, Any], |
| clock=time, |
| inference_args: Optional[Dict[str, Any]] = None, |
| metrics_namespace: Optional[str] = None, |
| *, |
| model_metadata_pcoll: beam.PCollection[ModelMetadata] = None, |
| watch_model_pattern: Optional[str] = None, |
| **kwargs): |
| """ |
| A transform that takes a PCollection of examples (or features) for use |
| on an ML model. The transform then outputs inferences (or predictions) for |
| those examples in a PCollection of PredictionResults that contains the input |
| examples and the output inferences. |
| |
| Models for supported frameworks can be loaded using a URI. Supported |
| services can also be used. |
| |
| This transform attempts to batch examples using the beam.BatchElements |
| transform. Batching can be configured using the ModelHandler. |
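
    A minimal usage sketch (``my_examples`` and ``my_model_handler`` are
    placeholders for your data and ModelHandler implementation)::

      with beam.Pipeline() as p:
        predictions = (
            p
            | beam.Create(my_examples)
            | RunInference(my_model_handler))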
| |
| Args: |
| model_handler: An implementation of ModelHandler. |
| clock: A clock implementing time_ns. *Used for unit testing.* |
| inference_args: Extra arguments for models whose inference call requires |
| extra parameters. |
| metrics_namespace: Namespace of the transform to collect metrics. |
      model_metadata_pcoll: PCollection that emits a singleton ModelMetadata
        containing the model path and model name, which is used as a side
        input to the _RunInferenceDoFn.
| watch_model_pattern: A glob pattern used to watch a directory |
| for automatic model refresh. |
| """ |
| self._model_handler = model_handler |
| self._inference_args = inference_args |
| self._clock = clock |
| self._metrics_namespace = metrics_namespace |
| self._model_metadata_pcoll = model_metadata_pcoll |
| self._with_exception_handling = False |
| self._watch_model_pattern = watch_model_pattern |
| self._kwargs = kwargs |
| # Generate a random tag to use for shared.py and multi_process_shared.py to |
| # allow us to effectively disambiguate in multi-model settings. |
| self._model_tag = uuid.uuid4().hex |
| |
| def annotations(self): |
| return { |
| 'model_handler': str(self._model_handler), |
| 'model_handler_type': ( |
| f'{self._model_handler.__class__.__module__}' |
| f'.{self._model_handler.__class__.__qualname__}'), |
| **super().annotations() |
| } |
| |
| def _get_model_metadata_pcoll(self, pipeline): |
| # avoid circular imports. |
| # pylint: disable=wrong-import-position |
| from apache_beam.ml.inference.utils import WatchFilePattern |
| extra_params = {} |
| if 'interval' in self._kwargs: |
| extra_params['interval'] = self._kwargs['interval'] |
| if 'stop_timestamp' in self._kwargs: |
| extra_params['stop_timestamp'] = self._kwargs['stop_timestamp'] |
| |
| return ( |
| pipeline | WatchFilePattern( |
| file_pattern=self._watch_model_pattern, **extra_params)) |
| |
| # TODO(BEAM-14046): Add and link to help documentation. |
| @classmethod |
| def from_callable(cls, model_handler_provider, **kwargs): |
| """Multi-language friendly constructor. |
| |
| Use this constructor with fully_qualified_named_transform to |
| initialize the RunInference transform from PythonCallableSource provided |
| by foreign SDKs. |
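
    For example, a provider can be as simple as the following sketch
    (``MySklearnHandler`` and its ``model_uri`` argument are illustrative)::

      def model_handler_provider(model_uri):
        return MySklearnHandler(model_uri=model_uri)

      RunInference.from_callable(
          model_handler_provider, model_uri='gs://bucket/model')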
| |
| Args: |
| model_handler_provider: A callable object that returns ModelHandler. |
| kwargs: Keyword arguments for model_handler_provider. |
| """ |
| return cls(model_handler_provider(**kwargs)) |
| |
| def _apply_fns( |
| self, |
| pcoll: beam.PCollection, |
| fns: Iterable[Callable[[Any], Any]], |
| step_prefix: str) -> Tuple[beam.PCollection, Iterable[beam.PCollection]]: |
| bad_preprocessed = [] |
| for idx in range(len(fns)): |
| fn = fns[idx] |
| if self._with_exception_handling: |
| pcoll, bad = (pcoll |
| | f"{step_prefix}-{idx}" >> beam.Map( |
| fn).with_exception_handling( |
| exc_class=self._exc_class, |
| use_subprocess=self._use_subprocess, |
| threshold=self._threshold)) |
| bad_preprocessed.append(bad) |
| else: |
| pcoll = pcoll | f"{step_prefix}-{idx}" >> beam.Map(fn) |
| |
| return pcoll, bad_preprocessed |
| |
  # TODO(https://github.com/apache/beam/issues/21447): Add batch_size backoff
  # in case there are functional reasons that large batch sizes cannot be
  # handled.
| def expand( |
| self, pcoll: beam.PCollection[ExampleT]) -> beam.PCollection[PredictionT]: |
| self._model_handler.validate_inference_args(self._inference_args) |
| # DLQ pcollections |
| bad_preprocessed = [] |
| bad_inference = None |
| bad_postprocessed = [] |
| preprocess_fns = self._model_handler.get_preprocess_fns() |
| postprocess_fns = self._model_handler.get_postprocess_fns() |
| |
| pcoll, bad_preprocessed = self._apply_fns( |
| pcoll, preprocess_fns, 'BeamML_RunInference_Preprocess') |
| |
| resource_hints = self._model_handler.get_resource_hints() |
| |
| # check for the side input |
| if self._watch_model_pattern: |
| self._model_metadata_pcoll = self._get_model_metadata_pcoll( |
| pcoll.pipeline) |
| |
| batched_elements_pcoll = ( |
| pcoll |
| # TODO(https://github.com/apache/beam/issues/21440): Hook into the |
| # batching DoFn APIs. |
| | beam.BatchElements(**self._model_handler.batch_elements_kwargs())) |
| |
| run_inference_pardo = beam.ParDo( |
| _RunInferenceDoFn( |
| self._model_handler, |
| self._clock, |
| self._metrics_namespace, |
| self._model_metadata_pcoll is not None, |
| self._model_tag), |
| self._inference_args, |
| beam.pvalue.AsSingleton( |
| self._model_metadata_pcoll, |
| ) if self._model_metadata_pcoll else None).with_resource_hints( |
| **resource_hints) |
| |
| if self._with_exception_handling: |
| results, bad_inference = ( |
| batched_elements_pcoll |
| | 'BeamML_RunInference' >> |
| run_inference_pardo.with_exception_handling( |
| exc_class=self._exc_class, |
| use_subprocess=self._use_subprocess, |
| threshold=self._threshold)) |
| else: |
| results = ( |
| batched_elements_pcoll |
| | 'BeamML_RunInference' >> run_inference_pardo) |
| |
| results, bad_postprocessed = self._apply_fns( |
| results, postprocess_fns, 'BeamML_RunInference_Postprocess') |
| |
| if self._with_exception_handling: |
| dlq = RunInferenceDLQ(bad_inference, bad_preprocessed, bad_postprocessed) |
| return results, dlq |
| |
| return results |
| |
| def with_exception_handling( |
| self, *, exc_class=Exception, use_subprocess=False, threshold=1): |
| """Automatically provides a dead letter output for skipping bad records. |
| This can allow a pipeline to continue successfully rather than fail or |
| continuously throw errors on retry when bad elements are encountered. |
| |
| This returns a tagged output with two PCollections, the first being the |
| results of successfully processing the input PCollection, and the second |
| being the set of bad batches of records (those which threw exceptions |
| during processing) along with information about the errors raised. |
| |
| For example, one would write:: |
| |
| main, other = RunInference( |
| maybe_error_raising_model_handler |
| ).with_exception_handling() |
| |
| and `main` will be a PCollection of PredictionResults and `other` will |
| contain a `RunInferenceDLQ` object with PCollections containing failed |
| records for each failed inference, preprocess operation, or postprocess |
    operation. To access each collection of failed records, one would write::
| |
| failed_inferences = other.failed_inferences |
| failed_preprocessing = other.failed_preprocessing |
| failed_postprocessing = other.failed_postprocessing |
| |
| failed_inferences is in the form |
| PCollection[Tuple[failed batch, exception]]. |
| |
| failed_preprocessing is in the form |
    List[PCollection[Tuple[failed record, exception]]], where each element of
| the list corresponds to a preprocess function. These PCollections are |
| in the same order that the preprocess functions are applied. |
| |
| failed_postprocessing is in the form |
    List[PCollection[Tuple[failed record, exception]]], where each element of
| the list corresponds to a postprocess function. These PCollections are |
| in the same order that the postprocess functions are applied. |
| |
| |
| Args: |
| exc_class: An exception class, or tuple of exception classes, to catch. |
| Optional, defaults to 'Exception'. |
| use_subprocess: Whether to execute the DoFn logic in a subprocess. This |
| allows one to recover from errors that can crash the calling process |
| (e.g. from an underlying library causing a segfault), but is |
| slower as elements and results must cross a process boundary. Note |
| that this starts up a long-running process that is used to handle |
| all the elements (until hard failure, which should be rare) rather |
| than a new process per element, so the overhead should be minimal |
| (and can be amortized if there's any per-process or per-bundle |
| initialization that needs to be done). Optional, defaults to False. |
| threshold: An upper bound on the ratio of records that can be bad before |
| aborting the entire pipeline. Optional, defaults to 1.0 (meaning |
| up to 100% of records can be bad and the pipeline will still succeed). |
| """ |
| self._with_exception_handling = True |
| self._exc_class = exc_class |
| self._use_subprocess = use_subprocess |
| self._threshold = threshold |
| return self |
| |
| |
| class _MetricsCollector: |
| """ |
| A metrics collector that tracks ML related performance and memory usage. |
| """ |
| def __init__(self, namespace: str, prefix: str = ''): |
| """ |
| Args: |
| namespace: Namespace for the metrics. |
| prefix: Unique identifier for metrics, used when models |
| are updated using side input. |
| """ |
| # Metrics |
| if prefix: |
| prefix = f'{prefix}_' |
| self._inference_counter = beam.metrics.Metrics.counter( |
| namespace, prefix + 'num_inferences') |
| self.failed_batches_counter = beam.metrics.Metrics.counter( |
| namespace, prefix + 'failed_batches_counter') |
| self._inference_request_batch_size = beam.metrics.Metrics.distribution( |
| namespace, prefix + 'inference_request_batch_size') |
| self._inference_request_batch_byte_size = ( |
| beam.metrics.Metrics.distribution( |
| namespace, prefix + 'inference_request_batch_byte_size')) |
| # Batch inference latency in microseconds. |
| self._inference_batch_latency_micro_secs = ( |
| beam.metrics.Metrics.distribution( |
| namespace, prefix + 'inference_batch_latency_micro_secs')) |
| self._model_byte_size = beam.metrics.Metrics.distribution( |
| namespace, prefix + 'model_byte_size') |
| # Model load latency in milliseconds. |
| self._load_model_latency_milli_secs = beam.metrics.Metrics.distribution( |
| namespace, prefix + 'load_model_latency_milli_secs') |
| |
| # Metrics cache |
| self._load_model_latency_milli_secs_cache = None |
| self._model_byte_size_cache = None |
| |
| def update_metrics_with_cache(self): |
| if self._load_model_latency_milli_secs_cache is not None: |
| self._load_model_latency_milli_secs.update( |
| self._load_model_latency_milli_secs_cache) |
| self._load_model_latency_milli_secs_cache = None |
| if self._model_byte_size_cache is not None: |
| self._model_byte_size.update(self._model_byte_size_cache) |
| self._model_byte_size_cache = None |
| |
| def cache_load_model_metrics(self, load_model_latency_ms, model_byte_size): |
| self._load_model_latency_milli_secs_cache = load_model_latency_ms |
| self._model_byte_size_cache = model_byte_size |
| |
| def update_load_model_metrics(self, load_model_latency_ms, model_byte_size): |
| self._load_model_latency_milli_secs.update(load_model_latency_ms) |
| self._model_byte_size.update(model_byte_size) |
| |
| def update( |
| self, |
| examples_count: int, |
| examples_byte_size: int, |
| latency_micro_secs: int): |
| self._inference_batch_latency_micro_secs.update(latency_micro_secs) |
| self._inference_counter.inc(examples_count) |
| self._inference_request_batch_size.update(examples_count) |
| self._inference_request_batch_byte_size.update(examples_byte_size) |
| |
| |
| class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]): |
| def __init__( |
| self, |
| model_handler: ModelHandler[ExampleT, PredictionT, Any], |
| clock, |
| metrics_namespace, |
| enable_side_input_loading: bool = False, |
| model_tag: str = "RunInference"): |
| """A DoFn implementation generic to frameworks. |
| |
| Args: |
| model_handler: An implementation of ModelHandler. |
| clock: A clock implementing time_ns. *Used for unit testing.* |
| metrics_namespace: Namespace of the transform to collect metrics. |
      enable_side_input_loading: Bool to indicate whether the model is
        updated via side inputs.
| model_tag: Tag to use to disambiguate models in multi-model settings. |
| """ |
| self._model_handler = model_handler |
| self._shared_model_handle = shared.Shared() |
| self._clock = clock |
| self._model = None |
| self._metrics_namespace = metrics_namespace |
| self._enable_side_input_loading = enable_side_input_loading |
| self._side_input_path = None |
| self._model_tag = model_tag |
| |
| def _load_model( |
| self, |
| side_input_model_path: Optional[Union[str, |
| List[KeyModelPathMapping]]] = None): |
| def load(): |
| """Function for constructing shared LoadedModel.""" |
| memory_before = _get_current_process_memory_in_bytes() |
| start_time = _to_milliseconds(self._clock.time_ns()) |
| if isinstance(side_input_model_path, str): |
| self._model_handler.update_model_path(side_input_model_path) |
| else: |
| self._model_handler.update_model_paths( |
| self._model, side_input_model_path) |
| model = self._model_handler.load_model() |
| end_time = _to_milliseconds(self._clock.time_ns()) |
| memory_after = _get_current_process_memory_in_bytes() |
| load_model_latency_ms = end_time - start_time |
| model_byte_size = memory_after - memory_before |
| if self._metrics_collector: |
| self._metrics_collector.cache_load_model_metrics( |
| load_model_latency_ms, model_byte_size) |
| return model |
| |
| # TODO(https://github.com/apache/beam/issues/21443): Investigate releasing |
| # model. |
| model_tag = self._model_tag |
| if isinstance(side_input_model_path, str) and side_input_model_path != '': |
| model_tag = side_input_model_path |
| if self._model_handler.share_model_across_processes(): |
| model = multi_process_shared.MultiProcessShared( |
| load, tag=model_tag, always_proxy=True).acquire() |
| else: |
| model = self._shared_model_handle.acquire(load, tag=model_tag) |
    # Since shared_model_handle is shared across threads, the model path
    # might not get updated in the model handler because we may directly get
    # a cached weak-ref model from the shared cache instead of calling
    # load(). As a sanity check, call update_model_path again.
| if isinstance(side_input_model_path, str): |
| self._model_handler.update_model_path(side_input_model_path) |
| else: |
| self._model_handler.update_model_paths(self._model, side_input_model_path) |
| return model |
| |
| def get_metrics_collector(self, prefix: str = ''): |
| """ |
| Args: |
| prefix: Unique identifier for metrics, used when models |
| are updated using side input. |
| """ |
| metrics_namespace = ( |
| self._metrics_namespace) if self._metrics_namespace else ( |
| self._model_handler.get_metrics_namespace()) |
| if self._model_handler.override_metrics(metrics_namespace): |
| return None |
| return _MetricsCollector(metrics_namespace, prefix=prefix) |
| |
| def setup(self): |
| self._metrics_collector = self.get_metrics_collector() |
| self._model_handler.set_environment_vars() |
| if not self._enable_side_input_loading: |
| self._model = self._load_model() |
| |
| def update_model( |
| self, |
| side_input_model_path: Optional[Union[str, |
| List[KeyModelPathMapping]]] = None): |
| self._model = self._load_model(side_input_model_path=side_input_model_path) |
| |
| def _run_inference(self, batch, inference_args): |
| start_time = _to_microseconds(self._clock.time_ns()) |
| try: |
| result_generator = self._model_handler.run_inference( |
| batch, self._model, inference_args) |
| except BaseException as e: |
| if self._metrics_collector: |
| self._metrics_collector.failed_batches_counter.inc() |
| raise e |
| predictions = list(result_generator) |
| |
| end_time = _to_microseconds(self._clock.time_ns()) |
| inference_latency = end_time - start_time |
| num_bytes = self._model_handler.get_num_bytes(batch) |
| num_elements = len(batch) |
| if self._metrics_collector: |
| self._metrics_collector.update(num_elements, num_bytes, inference_latency) |
| |
| return predictions |
| |
| def process( |
| self, |
| batch, |
| inference_args, |
| si_model_metadata: Optional[Union[ModelMetadata, |
| List[ModelMetadata], |
| List[KeyModelPathMapping]]]): |
| """ |
| When side input is enabled: |
| The method checks if the side input model has been updated, and if so, |
| updates the model and runs inference on the batch of data. If the |
| side input is empty or the model has not been updated, the method |
| simply runs inference on the batch of data. |
| """ |
| if not si_model_metadata: |
| return self._run_inference(batch, inference_args) |
| |
| if isinstance(si_model_metadata, beam.pvalue.EmptySideInput): |
| self.update_model(side_input_model_path=None) |
| elif isinstance(si_model_metadata, List) and hasattr(si_model_metadata[0], |
| 'keys'): |
| # TODO(https://github.com/apache/beam/issues/27628): Update metrics here |
| self.update_model(si_model_metadata) |
| elif self._side_input_path != si_model_metadata.model_id: |
| self._side_input_path = si_model_metadata.model_id |
| self._metrics_collector = self.get_metrics_collector( |
| prefix=si_model_metadata.model_name) |
| lock = threading.Lock() |
| with lock: |
| self.update_model(si_model_metadata.model_id) |
| return self._run_inference(batch, inference_args) |
| |
| return self._run_inference(batch, inference_args) |
| |
| def finish_bundle(self): |
| # TODO(https://github.com/apache/beam/issues/21435): Figure out why there |
| # is a cache. |
| if self._metrics_collector: |
| self._metrics_collector.update_metrics_with_cache() |
| |
| |
| def _is_darwin() -> bool: |
| return sys.platform == 'darwin' |
| |
| |
| def _get_current_process_memory_in_bytes(): |
| """ |
| Returns: |
| memory usage in bytes. |
| """ |
| |
| if resource is not None: |
| usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss |
| if _is_darwin(): |
| return usage |
| return usage * 1024 |
| else: |
| logging.warning( |
| 'Resource module is not available for current platform, ' |
| 'memory usage cannot be fetched.') |
| return 0 |