| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| # TODO: https://github.com/apache/beam/issues/21822 |
| # mypy: ignore-errors |
| |
| """An extensible run inference transform. |
| |
| Users of this module can extend the ModelHandler class for any machine learning |
| framework. A ModelHandler implementation is a required parameter of |
| RunInference. |
| |
| The transform handles standard inference functionality, like metric |
| collection, sharing models between threads, and batching elements. |
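| |
| For example, a minimal (and purely illustrative) handler and pipeline could |
| look like the following sketch, where `MyModel` and `load_my_model` are |
| hypothetical stand-ins for a user's own framework code: |
| |
|   class MyModelHandler(ModelHandler[str, PredictionResult, MyModel]): |
|     def load_model(self) -> MyModel: |
|       return load_my_model()  # user-provided loading logic |
| |
|     def run_inference(self, batch, model, inference_args=None): |
|       return [PredictionResult(x, model.predict(x)) for x in batch] |
| |
|   with beam.Pipeline() as p: |
|     _ = (p | beam.Create(['an example']) |
|            | RunInference(MyModelHandler())) |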
| """ |
| |
| import functools |
| import logging |
| import os |
| import pickle |
| import sys |
| import threading |
| import time |
| import uuid |
| from abc import ABC |
| from abc import abstractmethod |
| from collections import OrderedDict |
| from collections import defaultdict |
| from collections.abc import Callable |
| from collections.abc import Iterable |
| from collections.abc import Mapping |
| from collections.abc import Sequence |
| from copy import deepcopy |
| from dataclasses import dataclass |
| from datetime import datetime |
| from datetime import timedelta |
| from typing import Any |
| from typing import Generic |
| from typing import NamedTuple |
| from typing import Optional |
| from typing import TypeVar |
| from typing import Union |
| |
| import apache_beam as beam |
| from apache_beam.io.components.adaptive_throttler import ReactiveThrottler |
| from apache_beam.utils import multi_process_shared |
| from apache_beam.utils import retry |
| from apache_beam.utils import shared |
| |
| try: |
| from apache_beam.io.components.rate_limiter import RateLimiter |
| except ImportError: |
| RateLimiter = None |
| |
| try: |
| # pylint: disable=wrong-import-order, wrong-import-position |
| import resource |
| except ImportError: |
| resource = None # type: ignore[assignment] |
| |
| _NANOSECOND_TO_MILLISECOND = 1_000_000 |
| _NANOSECOND_TO_MICROSECOND = 1_000 |
| _MILLISECOND_TO_SECOND = 1_000 |
| |
| ModelT = TypeVar('ModelT') |
| ExampleT = TypeVar('ExampleT') |
| PreProcessT = TypeVar('PreProcessT') |
| PredictionT = TypeVar('PredictionT') |
| PostProcessT = TypeVar('PostProcessT') |
| _INPUT_TYPE = TypeVar('_INPUT_TYPE') |
| _OUTPUT_TYPE = TypeVar('_OUTPUT_TYPE') |
| KeyT = TypeVar('KeyT') |
| |
| |
| # We use NamedTuple to define the structure of the PredictionResult. |
| # However, since generic NamedTuple classes are not supported in Python |
| # versions prior to 3.11, we subclass the functional NamedTuple form and |
| # override __new__ to provide a default value for model_id while |
| # maintaining backwards compatibility. |
| class PredictionResult(NamedTuple('PredictionResult', |
| [('example', _INPUT_TYPE), |
| ('inference', _OUTPUT_TYPE), |
| ('model_id', Optional[str])])): |
| __slots__ = () |
| |
| def __new__(cls, example, inference, model_id=None): |
| return super().__new__(cls, example, inference, model_id) |
| |
| |
| PredictionResult.__doc__ = """A NamedTuple containing both input and output |
| from the inference.""" |
| PredictionResult.example.__doc__ = """The input example.""" |
| PredictionResult.inference.__doc__ = """Results for the inference on the model |
| for the given example.""" |
| PredictionResult.model_id.__doc__ = """Model ID used to run the prediction.""" |
| |
| |
| class RateLimitExceeded(RuntimeError): |
| """RateLimit Exceeded to process a batch of requests.""" |
| pass |
| |
| |
| class ModelMetadata(NamedTuple): |
| model_id: str |
| model_name: str |
| |
| |
| class RunInferenceDLQ(NamedTuple): |
| failed_inferences: beam.PCollection |
| failed_preprocessing: Sequence[beam.PCollection] |
| failed_postprocessing: Sequence[beam.PCollection] |
| |
| |
| class _ModelLoadStats(NamedTuple): |
| model_tag: str |
| load_latency: Optional[int] |
| byte_size: Optional[int] |
| |
| |
| ModelMetadata.model_id.__doc__ = """Unique identifier for the model. This can be |
| a file path or a URL where the model can be accessed. It is used to load |
| the model for inference.""" |
| ModelMetadata.model_name.__doc__ = """Human-readable name for the model. This |
| can be used to identify the model in the metrics generated by the |
| RunInference transform.""" |
| |
| |
| def _to_milliseconds(time_ns: int) -> int: |
| return int(time_ns / _NANOSECOND_TO_MILLISECOND) |
| |
| |
| def _to_microseconds(time_ns: int) -> int: |
| return int(time_ns / _NANOSECOND_TO_MICROSECOND) |
| |
| |
| @dataclass(frozen=True) |
| class KeyModelPathMapping(Generic[KeyT]): |
| """ |
| Dataclass for mapping 1 or more keys to 1 model path. This is used in |
| conjunction with a KeyedModelHandler that wraps many model handlers to |
| update a set of keys' model handlers with the new path. Given |
| `KeyModelPathMapping(keys=['key1', 'key2'], update_path='updated/path', |
| model_id='id1')`, all examples with keys `key1` or `key2` will have their |
| corresponding model handler's update_model function called with |
| 'updated/path' and their metrics will correspond with 'id1'. For more |
| information, see the KeyedModelHandler documentation |
| https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.KeyedModelHandler |
| and the website section on model updates |
| https://beam.apache.org/documentation/ml/about-ml/#automatic-model-refresh |
| """ |
| keys: list[KeyT] |
| update_path: str |
| model_id: str = '' |
| |
| |
| class ModelHandler(Generic[ExampleT, PredictionT, ModelT]): |
| """Has the ability to load and apply an ML model.""" |
| def __init__( |
| self, |
| *, |
| min_batch_size: Optional[int] = None, |
| max_batch_size: Optional[int] = None, |
| max_batch_duration_secs: Optional[int] = None, |
| max_batch_weight: Optional[int] = None, |
| element_size_fn: Optional[Callable[[Any], int]] = None, |
| large_model: bool = False, |
| model_copies: Optional[int] = None, |
| **kwargs): |
| """Initializes the ModelHandler. |
| |
| Args: |
| min_batch_size: the minimum batch size to use when batching inputs. |
| max_batch_size: the maximum batch size to use when batching inputs. |
| max_batch_duration_secs: the maximum amount of time to buffer a batch |
| before emitting; used in streaming contexts. |
| max_batch_weight: the maximum weight of a batch. Requires element_size_fn. |
| element_size_fn: a function that returns the size (weight) of an element. |
| large_model: set to true if your model is large enough to run into |
| memory pressure if you load multiple copies. |
| model_copies: The exact number of models that you would like loaded |
| onto your machine. |
| kwargs: 'env_vars' can be used to set environment variables |
| before loading the model. |
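| |
| For example, a subclass can opt into batching by forwarding these |
| arguments to this constructor (a sketch; `MyHandler` and `MyModel` are |
| hypothetical): |
| |
|   class MyHandler(ModelHandler[str, PredictionResult, MyModel]): |
|     def __init__(self): |
|       # Ask RunInference to batch roughly 8-64 elements at a time. |
|       super().__init__(min_batch_size=8, max_batch_size=64) |
| |
| RunInference will then use these values when invoking beam.BatchElements. |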
| """ |
| self._env_vars = kwargs.get('env_vars', {}) |
| self._batching_kwargs: dict[str, Any] = {} |
| if min_batch_size is not None: |
| self._batching_kwargs['min_batch_size'] = min_batch_size |
| if max_batch_size is not None: |
| self._batching_kwargs['max_batch_size'] = max_batch_size |
| if max_batch_duration_secs is not None: |
| self._batching_kwargs['max_batch_duration_secs'] = max_batch_duration_secs |
| if max_batch_weight is not None: |
| self._batching_kwargs['max_batch_weight'] = max_batch_weight |
| if element_size_fn is not None: |
| self._batching_kwargs['element_size_fn'] = element_size_fn |
| self._large_model = large_model |
| self._model_copies = model_copies |
| self._share_across_processes = large_model or (model_copies is not None) |
| |
| def load_model(self) -> ModelT: |
| """Loads and initializes a model for processing.""" |
| raise NotImplementedError(type(self)) |
| |
| def run_inference( |
| self, |
| batch: Sequence[ExampleT], |
| model: ModelT, |
| inference_args: Optional[dict[str, Any]] = None) -> Iterable[PredictionT]: |
| """Runs inferences on a batch of examples. |
| |
| Args: |
| batch: A sequence of examples or features. |
| model: The model used to make inferences. |
| inference_args: Extra arguments for models whose inference call requires |
| extra parameters. |
| |
| Returns: |
| An Iterable of Predictions. |
| """ |
| raise NotImplementedError(type(self)) |
| |
| def get_num_bytes(self, batch: Sequence[ExampleT]) -> int: |
| """ |
| Returns: |
| The number of bytes of data for a batch. |
| """ |
| return len(pickle.dumps(batch)) |
| |
| def get_metrics_namespace(self) -> str: |
| """ |
| Returns: |
| A namespace for metrics collected by the RunInference transform. |
| """ |
| return 'RunInference' |
| |
| def get_resource_hints(self) -> dict: |
| """ |
| Returns: |
| Resource hints for the transform. |
| """ |
| return {} |
| |
| def batch_elements_kwargs(self) -> Mapping[str, Any]: |
| """ |
| Returns: |
| kwargs suitable for beam.BatchElements. |
| """ |
| return getattr(self, '_batching_kwargs', {}) |
| |
| def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): |
| """ |
| Allows model handlers to provide validation to make sure passed-in |
| inference args are valid. Some ModelHandlers raise an exception here to |
| disallow inference args altogether. |
| """ |
| pass |
| |
| def update_model_path(self, model_path: Optional[str] = None): |
| """ |
| Update the model path produced by side inputs. update_model_path should be |
| used when a ModelHandler represents a single model, not multiple models. |
| This will be true in most cases. For more information see the website |
| section on model updates |
| https://beam.apache.org/documentation/ml/about-ml/#automatic-model-refresh |
| """ |
| pass |
| |
| def update_model_paths( |
| self, |
| model: ModelT, |
| model_paths: Optional[Union[str, list[KeyModelPathMapping]]] = None): |
| """ |
| Update the model paths produced by side inputs. update_model_paths should |
| be used when updating multiple models at once (e.g. when using a |
| KeyedModelHandler that holds multiple models). For more information see |
| the KeyedModelHandler documentation |
| https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.KeyedModelHandler |
| documentation and the website section on model updates |
| https://beam.apache.org/documentation/ml/about-ml/#automatic-model-refresh |
| """ |
| pass |
| |
| def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]: |
| """Gets all preprocessing functions to be run before batching/inference. |
| Functions are returned in the order in which they should be applied.""" |
| return [] |
| |
| def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]: |
| """Gets all postprocessing functions to be run after inference. |
| Functions are returned in the order in which they should be applied.""" |
| return [] |
| |
| def should_skip_batching(self) -> bool: |
| """Whether RunInference's batching should be skipped. Can be flipped to |
| True by using `with_no_batching`""" |
| return False |
| |
| def set_environment_vars(self): |
| """Sets environment variables using a dictionary provided via kwargs. |
| Keys are the env variable name, and values are the env variable value. |
| Child ModelHandler classes should set _env_vars via kwargs in __init__, |
| or else call super().__init__().""" |
| env_vars = getattr(self, '_env_vars', {}) |
| for env_variable, env_value in env_vars.items(): |
| os.environ[env_variable] = env_value |
| |
| def with_preprocess_fn( |
| self, fn: Callable[[PreProcessT], ExampleT] |
| ) -> 'ModelHandler[PreProcessT, PredictionT, ModelT]': |
| """Returns a new ModelHandler with a preprocessing function |
| associated with it. The preprocessing function will be run |
| before batching/inference and should map your input PCollection |
| to the base ModelHandler's input type. If you apply multiple |
| preprocessing functions, they will be run on your original |
| PCollection in order from last applied to first applied.""" |
| return _PreProcessingModelHandler(self, fn) |
| |
| def with_postprocess_fn( |
| self, fn: Callable[[PredictionT], PostProcessT] |
| ) -> 'ModelHandler[ExampleT, PostProcessT, ModelT]': |
| """Returns a new ModelHandler with a postprocessing function |
| associated with it. The postprocessing function will be run |
| after inference and should map the base ModelHandler's output |
| type to your desired output type. If you apply multiple |
| postprocessing functions, they will be run on your original |
| inference result in order from first applied to last applied.""" |
| return _PostProcessingModelHandler(self, fn) |
| |
| def with_no_batching( |
| self |
| ) -> """ModelHandler[Union[ |
| ExampleT, Iterable[ExampleT]], PostProcessT, ModelT, PostProcessT]""": |
| """Returns a new ModelHandler which does not require batching |
| of inputs so that RunInference will skip this step. RunInference will |
| expect the input to be pre-batched and passed in as an Iterable of records. |
| If you skip batching, any preprocessing functions should accept a batch of |
| data, not just a single record. |
| |
| This option is only recommended if you want to do custom batching yourself. |
| If you just want to pass in records without a batching dimension, it is |
| recommended to (1) add `max_batch_size=1` to `batch_elements_kwargs` and |
| (2) remove the batching dimension as part of your inference call (for |
| example, by taking `record = batch[0]`).""" |
| return _PrebatchedModelHandler(self) |
| |
| def share_model_across_processes(self) -> bool: |
| """Returns a boolean representing whether or not a model should |
| be shared across multiple processes instead of being loaded per process. |
| This is primarily useful for large models that can't fit multiple copies |
| in memory. Multi-process support may vary by runner, but this will fall |
| back to loading per process as necessary. See |
| https://beam.apache.org/releases/pydoc/current/apache_beam.utils.multi_process_shared.html""" |
| return getattr(self, '_share_across_processes', False) |
| |
| def model_copies(self) -> int: |
| """Returns the maximum number of model copies that should be loaded at one |
| time. This only impacts model handlers that are using |
| share_model_across_processes to share their model across processes instead |
| of being loaded per process.""" |
| return getattr(self, '_model_copies', None) or 1 |
| |
| def override_metrics(self, metrics_namespace: str = '') -> bool: |
| """Returns a boolean representing whether or not a model handler will |
| override metrics reporting. If True, RunInference will not report any |
| metrics.""" |
| return False |
| |
| def should_garbage_collect_on_timeout(self) -> bool: |
| """Whether the model should be garbage collected if model loading or |
| inference timeout, or if it should be left for future calls. Usually should |
| not be overriden unless the model handler implements other mechanisms for |
| garbage collection.""" |
| return self.share_model_across_processes() |
| |
| |
| class RemoteModelHandler(ABC, ModelHandler[ExampleT, PredictionT, ModelT]): |
| """Has the ability to call a model at a remote endpoint.""" |
| def __init__( |
| self, |
| namespace: str = '', |
| num_retries: int = 5, |
| throttle_delay_secs: int = 5, |
| retry_filter: Callable[[Exception], bool] = lambda x: True, |
| *, |
| window_ms: int = 1 * _MILLISECOND_TO_SECOND, |
| bucket_ms: int = 1 * _MILLISECOND_TO_SECOND, |
| overload_ratio: float = 2, |
| rate_limiter: Optional[RateLimiter] = None): |
| """Initializes a ReactiveThrottler class for enabling |
| client-side throttling for remote calls to an inference service. Also wraps |
| provided calls to the service with retry logic. |
| |
| See https://s.apache.org/beam-client-side-throttling for more details |
| on the configuration of the throttling and retry |
| mechanics. |
| |
| Args: |
| namespace: the metrics and logging namespace |
| num_retries: the maximum number of times to retry a request on retriable |
| errors before failing |
| throttle_delay_secs: the amount of time to throttle when the client-side |
| elects to throttle |
| retry_filter: a function accepting an exception as an argument and |
| returning a boolean. On a true return, the run_inference call will |
| be retried. Defaults to always retrying. |
| window_ms: length of history to consider, in ms, to set throttling. |
| bucket_ms: granularity of time buckets that we store data in, in ms. |
| overload_ratio: the target ratio between requests sent and successful |
| requests. This is "K" in the formula in |
| https://landing.google.com/sre/book/chapters/handling-overload.html. |
| rate_limiter: A RateLimiter object for setting a global rate limit. |
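| |
| A typical subclass only implements create_client and request; for |
| example (a sketch, where `my_client` is a hypothetical client library): |
| |
|   class MyRemoteHandler(RemoteModelHandler[str, PredictionResult, |
|                                            my_client.Client]): |
|     def create_client(self): |
|       return my_client.Client(endpoint='https://example.com/predict') |
| |
|     def request(self, batch, model, inference_args=None): |
|       responses = model.predict(batch)  # one remote call per batch |
|       return [PredictionResult(x, r) for x, r in zip(batch, responses)] |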
| """ |
| # Configure ReactiveThrottler for client-side throttling behavior. |
| self.throttler = ReactiveThrottler( |
| window_ms=window_ms, |
| bucket_ms=bucket_ms, |
| overload_ratio=overload_ratio, |
| namespace=namespace, |
| throttle_delay_secs=throttle_delay_secs) |
| self.logger = logging.getLogger(namespace) |
| self.num_retries = num_retries |
| self.retry_filter = retry_filter |
| self._rate_limiter = rate_limiter |
| self._shared_rate_limiter = None |
| self._shared_handle = shared.Shared() |
| |
| def __init_subclass__(cls): |
| if cls.load_model is not RemoteModelHandler.load_model: |
| raise Exception( |
| "Cannot override RemoteModelHandler.load_model, ", |
| "implement create_client instead.") |
| if cls.run_inference is not RemoteModelHandler.run_inference: |
| raise Exception( |
| "Cannot override RemoteModelHandler.run_inference, ", |
| "implement request instead.") |
| |
| @abstractmethod |
| def create_client(self) -> ModelT: |
| """Creates the client that is used to make the remote inference request |
| in request(). All relevant arguments should be passed to __init__(). |
| """ |
| raise NotImplementedError(type(self)) |
| |
| def load_model(self) -> ModelT: |
| return self.create_client() |
| |
| def retry_on_exception(func): |
| @functools.wraps(func) |
| def wrapper(self, *args, **kwargs): |
| return retry.with_exponential_backoff( |
| num_retries=self.num_retries, |
| retry_filter=self.retry_filter)(func)(self, *args, **kwargs) |
| |
| return wrapper |
| |
| @retry_on_exception |
| def run_inference( |
| self, |
| batch: Sequence[ExampleT], |
| model: ModelT, |
| inference_args: Optional[dict[str, Any]] = None) -> Iterable[PredictionT]: |
| """Runs inferences on a batch of examples. Calls a remote model for |
| predictions and will retry if a retryable exception is raised. |
| |
| Args: |
| batch: A sequence of examples or features. |
| model: The model used to make inferences. |
| inference_args: Extra arguments for models whose inference call requires |
| extra parameters. |
| |
| Returns: |
| An Iterable of Predictions. |
| """ |
| if self._rate_limiter: |
| if self._shared_rate_limiter is None: |
| |
| def init_limiter(): |
| return self._rate_limiter |
| |
| self._shared_rate_limiter = self._shared_handle.acquire(init_limiter) |
| |
| if not self._shared_rate_limiter.allow(hits_added=len(batch)): |
| raise RateLimitExceeded( |
| "Rate Limit Exceeded, " |
| "Could not process this batch.") |
| |
| self.throttler.throttle() |
| |
| try: |
| req_time = time.time() |
| predictions = self.request(batch, model, inference_args) |
| self.throttler.successful_request(req_time * _MILLISECOND_TO_SECOND) |
| return predictions |
| except Exception as e: |
| self.logger.error("exception raised as part of request, got %s", e) |
| raise |
| |
| @abstractmethod |
| def request( |
| self, |
| batch: Sequence[ExampleT], |
| model: ModelT, |
| inference_args: Optional[dict[str, Any]] = None) -> Iterable[PredictionT]: |
| """Makes a request to a remote inference service and returns the response. |
| Should raise an exception of some kind if there is an error to enable the |
| retry and client-side throttling logic to work. Returns an iterable of the |
| desired prediction type. This method should return the values directly, as |
| handling return values as a generator can prevent the retry logic from |
| functioning correctly. |
| |
| Args: |
| batch: A sequence of examples or features. |
| model: The model used to make inferences. |
| inference_args: Extra arguments for models whose inference call requires |
| extra parameters. |
| |
| Returns: |
| An Iterable of Predictions. |
| """ |
| raise NotImplementedError(type(self)) |
| |
| |
| class _ModelManager: |
| """ |
| A class for efficiently managing copies of multiple models. Will load a |
| single copy of each model into a multi_process_shared object and then |
| return a lookup key for that object. |
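| |
| Intended usage from a worker (a sketch; mh1 and mh2 are hypothetical |
| model handlers): |
| |
|   mm = _ModelManager({'key1': mh1, 'key2': mh2}) |
|   mm.increment_max_models(2)  # optionally bound the number of live models |
|   stats = mm.load('key1')     # loads the model, returns its shared tag |
|   model = multi_process_shared.MultiProcessShared( |
|       mh1.load_model, tag=stats.model_tag).acquire() |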
| """ |
| def __init__(self, mh_map: dict[str, ModelHandler]): |
| """ |
| Args: |
| mh_map: A map from keys to model handlers which can be used to load a |
| model. |
| """ |
| self._max_models = None |
| # Map keys to model handlers |
| self._mh_map: dict[str, ModelHandler] = mh_map |
| # Map keys to the last updated model path for that key |
| self._key_to_last_update: dict[str, str] = defaultdict(str) |
| # Map key for a model to a unique tag that will persist for the life of |
| # that model in memory. A new tag will be generated if a model is swapped |
| # out of memory and reloaded. |
| self._tag_map: dict[str, str] = OrderedDict() |
| # Map a tag to a multiprocessshared model object for that tag. Each entry |
| # of this map should last as long as the corresponding entry in _tag_map. |
| self._proxy_map: dict[str, multi_process_shared.MultiProcessShared] = {} |
| |
| def load(self, key: str) -> _ModelLoadStats: |
| """ |
| Loads the appropriate model for the given key into memory. |
| Args: |
| key: the key associated with the model we'd like to load. |
| Returns: |
| _ModelLoadStats with tag, byte size, and latency to load the model. If |
| the model was already loaded, byte size/latency will be None. |
| """ |
| # Map the key for a model to a unique tag that will persist until the model |
| # is released. This needs to be unique between releasing/reacquiring the |
| # model because otherwise the ProxyManager will try to reuse the model that |
| # has been released and deleted. |
| if key in self._tag_map: |
| self._tag_map.move_to_end(key) |
| return _ModelLoadStats(self._tag_map[key], None, None) |
| else: |
| self._tag_map[key] = uuid.uuid4().hex |
| |
| tag = self._tag_map[key] |
| mh = self._mh_map[key] |
| |
| if self._max_models is not None and self._max_models < len(self._tag_map): |
| # If we're about to exceed our LRU size, release the least recently used |
| # model. |
| tag_to_remove = self._tag_map.popitem(last=False)[1] |
| shared_handle, model_to_remove = self._proxy_map[tag_to_remove] |
| shared_handle.release(model_to_remove) |
| del self._proxy_map[tag_to_remove] |
| |
| # Load the new model |
| memory_before = _get_current_process_memory_in_bytes() |
| start_time = _to_milliseconds(time.time_ns()) |
| shared_handle = multi_process_shared.MultiProcessShared( |
| mh.load_model, tag=tag) |
| model_reference = shared_handle.acquire() |
| self._proxy_map[tag] = (shared_handle, model_reference) |
| memory_after = _get_current_process_memory_in_bytes() |
| end_time = _to_milliseconds(time.time_ns()) |
| |
| return _ModelLoadStats( |
| tag, end_time - start_time, memory_after - memory_before) |
| |
| def increment_max_models(self, increment: int): |
| """ |
| Increments the number of models that this instance of a _ModelManager is |
| able to hold. If it is never called, no limit is imposed. |
| Args: |
| increment: the amount by which we are incrementing the number of models. |
| """ |
| if self._max_models is None: |
| self._max_models = 0 |
| self._max_models += increment |
| |
| def update_model_handler(self, key: str, model_path: str, previous_key: str): |
| """ |
| Updates the model path of this model handler and removes it from memory so |
| that it can be reloaded with the updated path. No-ops if no model update |
| needs to be applied. |
| Args: |
| key: the key associated with the model we'd like to update. |
| model_path: the new path to the model we'd like to load. |
| previous_key: the key that is associated with the old version of this |
| model. This will often be the same as the current key, but sometimes |
| we will want to keep both the old and new models to serve different |
| cohorts. In that case, the keys should be different. |
| """ |
| if self._key_to_last_update[key] == model_path: |
| return |
| self._key_to_last_update[key] = model_path |
| if key not in self._mh_map: |
| self._mh_map[key] = deepcopy(self._mh_map[previous_key]) |
| self._mh_map[key].update_model_path(model_path) |
| if key in self._tag_map: |
| tag_to_remove = self._tag_map[key] |
| shared_handle, model_to_remove = self._proxy_map[tag_to_remove] |
| shared_handle.release(model_to_remove) |
| del self._tag_map[key] |
| del self._proxy_map[tag_to_remove] |
| |
| |
| # Use a regular class instead of a NamedTuple or dataclass because |
| # NamedTuples and generics don't mix well across the board for all versions: |
| # https://github.com/python/typing/issues/653 |
| class KeyModelMapping(Generic[KeyT, ExampleT, PredictionT, ModelT]): |
| """ |
| Class for mapping 1 or more keys to 1 model handler. Given |
| `KeyModelMapping(['key1', 'key2'], myMh)`, all examples with keys `key1` |
| or `key2` will be run against the model defined by the `myMh` ModelHandler. |
| """ |
| def __init__( |
| self, keys: list[KeyT], mh: ModelHandler[ExampleT, PredictionT, ModelT]): |
| self.keys = keys |
| self.mh = mh |
| |
| |
| class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT], |
| ModelHandler[tuple[KeyT, ExampleT], |
| tuple[KeyT, PredictionT], |
| Union[ModelT, _ModelManager]]): |
| def __init__( |
| self, |
| unkeyed: Union[ModelHandler[ExampleT, PredictionT, ModelT], |
| list[KeyModelMapping[KeyT, ExampleT, PredictionT, |
| ModelT]]], |
| max_models_per_worker_hint: Optional[int] = None): |
| """A ModelHandler that takes keyed examples and returns keyed predictions. |
| |
| For example, if the original model is used with RunInference to take a |
| PCollection[E] to a PCollection[P], this ModelHandler would take a |
| PCollection[tuple[K, E]] to a PCollection[tuple[K, P]], making it possible |
| to use the key to associate the outputs with the inputs. KeyedModelHandler |
| is able to accept either a single unkeyed ModelHandler or many different |
| model handlers corresponding to the keys for which that ModelHandler should |
| be used. For example, the following configuration could be used to map keys |
| 1-3 to ModelHandler1 and keys 4-5 to ModelHandler2: |
| |
| k1 = ['k1', 'k2', 'k3'] |
| k2 = ['k4', 'k5'] |
| KeyedModelHandler([KeyModelMapping(k1, mh1), KeyModelMapping(k2, mh2)]) |
| |
| Note that a single copy of each of these models may all be held in memory |
| at the same time; be careful not to load too many large models or your |
| pipeline may cause Out of Memory exceptions. |
| |
| KeyedModelHandlers support Automatic Model Refresh to update your model |
| to a newer version without stopping your streaming pipeline. For an |
| overview of this feature, see |
| https://beam.apache.org/documentation/ml/about-ml/#automatic-model-refresh |
| |
| |
| To use this feature with a KeyedModelHandler that has many models per key, |
| you can pass in a list of KeyModelPathMapping objects to define your new |
| model paths. For example, passing in the side input of |
| |
| [KeyModelPathMapping(keys=['k1', 'k2'], update_path='update/path/1'), |
| KeyModelPathMapping(keys=['k3'], update_path='update/path/2')] |
| |
| will update the model corresponding to keys 'k1' and 'k2' with path |
| 'update/path/1' and the model corresponding to 'k3' with 'update/path/2'. |
| In order to do a side input update: (1) all restrictions mentioned in |
| https://beam.apache.org/documentation/ml/about-ml/#automatic-model-refresh |
| must be met, (2) all update_paths must be non-empty, even if they are not |
| being updated from their original values, and (3) the set of keys |
| originally defined cannot change. This means that if originally you have |
| defined model handlers for 'key1', 'key2', and 'key3', all 3 of those keys |
| must appear in your list of KeyModelPathMappings exactly once. No |
| additional keys can be added. |
| |
| When using many models defined per key, metrics about inference and model |
| loading will be gathered on an aggregate basis for all keys. These will be |
| reported with no prefix. Metrics will also be gathered on a per key basis. |
| Since some keys can share the same model, only one set of metrics will be |
| reported per key 'cohort'. These will be reported in the form: |
| `<cohort_key>-<metric_name>`, where `<cohort_key>` can be any key selected |
| from the cohort. When model updates occur, the metrics will be reported in |
| the form `<cohort_key>-<model id>-<metric_name>`. |
| |
| Loading multiple models at the same time can increase the risk of an out of |
| memory (OOM) exception. To avoid this issue, use the parameter |
| `max_models_per_worker_hint` to limit the number of models that are loaded |
| at the same time. For more information about memory management, see |
| `Use a keyed ModelHandler <https://beam.apache.org/documentation/ml/about-ml/#use-a-keyed-modelhandler-object>`_. # pylint: disable=line-too-long |
| |
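| Putting this together, a sketch of end-to-end usage (mh1, mh2, and the |
| example inputs are hypothetical): |
| |
|   keyed_mh = KeyedModelHandler( |
|       [KeyModelMapping(['k1', 'k2', 'k3'], mh1), |
|        KeyModelMapping(['k4', 'k5'], mh2)], |
|       max_models_per_worker_hint=2) |
|   with beam.Pipeline() as p: |
|     _ = (p | beam.Create([('k1', 'some input'), ('k4', 'other input')]) |
|            | RunInference(keyed_mh)) |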
| |
| Args: |
| unkeyed: Either (a) an implementation of ModelHandler that does not |
| require keys or (b) a list of KeyModelMappings mapping lists of keys to |
| unkeyed ModelHandlers. |
| max_models_per_worker_hint: A hint to the runner indicating how many |
| models can be held in memory at one time per worker process. For |
| example, if your worker has 8 GB of memory provisioned and your models |
| take up 1 GB each, you should set this to 7 to allow all models to sit |
| in memory with some buffer. For more information about memory management, |
| see `Use a keyed ModelHandler <https://beam.apache.org/documentation/ml/about-ml/#use-a-keyed-modelhandler-object>`_. # pylint: disable=line-too-long |
| """ |
| self._metrics_collectors: dict[str, _MetricsCollector] = {} |
| self._default_metrics_collector: _MetricsCollector = None |
| self._metrics_namespace = '' |
| self._single_model = not isinstance(unkeyed, list) |
| if self._single_model: |
| if len(unkeyed.get_preprocess_fns()) or len( |
| unkeyed.get_postprocess_fns()): |
| raise Exception( |
| 'Cannot wrap an unkeyed model handler that has pre or ' |
| 'postprocessing functions defined in a keyed model handler. All ' |
| 'pre/postprocessing functions must be defined on the outer model ' |
| 'handler.') |
| self._env_vars = getattr(unkeyed, '_env_vars', {}) |
| self._unkeyed = unkeyed |
| return |
| |
| self._max_models_per_worker_hint = max_models_per_worker_hint |
| # To maintain an efficient representation, we will map all keys in a given |
| # KeyModelMapping to a single id (the first key in the KeyModelMapping |
| # list). We will then map that key to a ModelHandler. This will allow us to |
| # quickly look up the appropriate ModelHandler for any given key. |
| self._id_to_mh_map: dict[str, ModelHandler[ExampleT, PredictionT, |
| ModelT]] = {} |
| self._key_to_id_map: dict[str, str] = {} |
| for mh_tuple in unkeyed: |
| mh = mh_tuple.mh |
| keys = mh_tuple.keys |
| if len(mh.get_preprocess_fns()) or len(mh.get_postprocess_fns()): |
| raise ValueError( |
| 'Cannot use an unkeyed model handler with pre or ' |
| 'postprocessing functions defined in a keyed model handler. All ' |
| 'pre/postprocessing functions must be defined on the outer model ' |
| 'handler.') |
| hints = mh.get_resource_hints() |
| if len(hints) > 0: |
| logging.warning( |
| 'mh %s defines the following resource hints, which will be ' |
| 'ignored: %s. Resource hints are not respected when more than one ' |
| 'model handler is used in a KeyedModelHandler. If you would like ' |
| 'to specify resource hints, you can do so by overriding the ' |
| 'KeyedModelHandler.get_resource_hints() method.', |
| mh, |
| hints) |
| batch_kwargs = mh.batch_elements_kwargs() |
| if len(batch_kwargs) > 0: |
| logging.warning( |
| 'mh %s defines the following batching kwargs which will be ' |
| 'ignored %s. Batching kwargs are not respected when ' |
| 'more than one model handler is used in a KeyedModelHandler. If ' |
| 'you would like to specify batching kwargs, you can do so by ' |
| 'overriding the KeyedModelHandler.batch_elements_kwargs() method.', |
| mh, |
| batch_kwargs) |
| env_vars = getattr(mh, '_env_vars', {}) |
| if len(env_vars) > 0: |
| logging.warning( |
| 'mh %s defines the following _env_vars which will be ignored %s. ' |
| '_env_vars are not respected when more than one model handler is ' |
| 'used in a KeyedModelHandler. If you need env vars set at ' |
| 'inference time, you can do so with ' |
| 'a custom inference function.', |
| mh, |
| env_vars) |
| |
| if len(keys) == 0: |
| raise ValueError( |
| f'Empty list maps to model handler {mh}. All model handlers must ' |
| 'have one or more associated keys.') |
| self._id_to_mh_map[keys[0]] = mh |
| for key in keys: |
| if key in self._key_to_id_map: |
| raise ValueError( |
| f'key {key} maps to multiple model handlers. All keys must map ' |
| 'to exactly one model handler.') |
| self._key_to_id_map[key] = keys[0] |
| |
| def load_model(self) -> Union[ModelT, _ModelManager]: |
| if self._single_model: |
| return self._unkeyed.load_model() |
| return _ModelManager(self._id_to_mh_map) |
| |
| def run_inference( |
| self, |
| batch: Sequence[tuple[KeyT, ExampleT]], |
| model: Union[ModelT, _ModelManager], |
| inference_args: Optional[dict[str, Any]] = None |
| ) -> Iterable[tuple[KeyT, PredictionT]]: |
| if self._single_model: |
| keys, unkeyed_batch = zip(*batch) |
| return zip( |
| keys, |
| self._unkeyed.run_inference(unkeyed_batch, model, inference_args)) |
| |
| # The first time a MultiProcessShared ModelManager is used for inference |
| # from this process, we should increment its max model count |
| if self._max_models_per_worker_hint is not None: |
| lock = threading.Lock() |
| if lock.acquire(blocking=False): |
| model.increment_max_models(self._max_models_per_worker_hint) |
| self._max_models_per_worker_hint = None |
| |
| batch_by_key = defaultdict(list) |
| key_by_id = defaultdict(set) |
| for key, example in batch: |
| batch_by_key[key].append(example) |
| key_by_id[self._key_to_id_map[key]].add(key) |
| |
| predictions = [] |
| for id, keys in key_by_id.items(): |
| mh = self._id_to_mh_map[id] |
| loaded_model = model.load(id) |
| keyed_model_tag = loaded_model.model_tag |
| if loaded_model.byte_size is not None: |
| self._metrics_collectors[id].update_load_model_metrics( |
| loaded_model.load_latency, loaded_model.byte_size) |
| self._default_metrics_collector.update_load_model_metrics( |
| loaded_model.load_latency, loaded_model.byte_size) |
| keyed_model_shared_handle = multi_process_shared.MultiProcessShared( |
| mh.load_model, tag=keyed_model_tag) |
| keyed_model = keyed_model_shared_handle.acquire() |
| start_time = _to_microseconds(time.time_ns()) |
| num_bytes = 0 |
| num_elements = 0 |
| try: |
| for key in keys: |
| unkeyed_batches = batch_by_key[key] |
| try: |
| for inf in mh.run_inference(unkeyed_batches, |
| keyed_model, |
| inference_args): |
| predictions.append((key, inf)) |
| except BaseException as e: |
| self._metrics_collectors[id].failed_batches_counter.inc() |
| self._default_metrics_collector.failed_batches_counter.inc() |
| raise e |
| num_bytes += mh.get_num_bytes(unkeyed_batches) |
| num_elements += len(unkeyed_batches) |
| finally: |
| keyed_model_shared_handle.release(keyed_model) |
| end_time = _to_microseconds(time.time_ns()) |
| inference_latency = end_time - start_time |
| self._metrics_collectors[id].update( |
| num_elements, num_bytes, inference_latency) |
| self._default_metrics_collector.update( |
| num_elements, num_bytes, inference_latency) |
| |
| return predictions |
| |
| def get_num_bytes(self, batch: Sequence[tuple[KeyT, ExampleT]]) -> int: |
| keys, unkeyed_batch = zip(*batch) |
| batch_bytes = len(pickle.dumps(keys)) |
| if self._single_model: |
| return batch_bytes + self._unkeyed.get_num_bytes(unkeyed_batch) |
| |
| batch_by_key = defaultdict(list) |
| for key, examples in batch: |
| batch_by_key[key].append(examples) |
| |
| for key, examples in batch_by_key.items(): |
| mh_id = self._key_to_id_map[key] |
| batch_bytes += self._id_to_mh_map[mh_id].get_num_bytes(examples) |
| return batch_bytes |
| |
| def get_metrics_namespace(self) -> str: |
| if self._single_model: |
| return self._unkeyed.get_metrics_namespace() |
| return 'BeamML_KeyedModels' |
| |
| def get_resource_hints(self): |
| if self._single_model: |
| return self._unkeyed.get_resource_hints() |
| return {} |
| |
| def batch_elements_kwargs(self): |
| if self._single_model: |
| return self._unkeyed.batch_elements_kwargs() |
| return {} |
| |
| def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): |
| if self._single_model: |
| return self._unkeyed.validate_inference_args(inference_args) |
| for mh in self._id_to_mh_map.values(): |
| mh.validate_inference_args(inference_args) |
| |
| def update_model_paths( |
| self, |
| model: Union[ModelT, _ModelManager], |
| model_paths: list[KeyModelPathMapping[KeyT]] = None): |
| # When there are many models, the keyed model handler is responsible for |
| # reorganizing the model handlers into cohorts and telling the model |
| # manager to update every cohort's associated model handler. The model |
| # manager is responsible for performing the updates and tracking which |
| # updates have already been applied. |
| if model_paths is None or len(model_paths) == 0 or model is None: |
| return |
| if self._single_model: |
| raise RuntimeError( |
| 'Invalid model update: sent many model paths to ' |
| 'update, but KeyedModelHandler is wrapping a single ' |
| 'model.') |
| # Map cohort ids to a dictionary mapping new model paths to the keys that |
| # were originally in that cohort. We will use this to construct our new |
| # cohorts. |
| # cohort_path_mapping will be structured as follows: |
| # { |
| # original_cohort_id: { |
| # 'update/path/1': ['key1FromOriginalCohort', 'key2FromOriginalCohort'], |
| # 'update/path/2': ['key3FromOriginalCohort', 'key4FromOriginalCohort'], |
| # } |
| # } |
| cohort_path_mapping: dict[KeyT, dict[str, list[KeyT]]] = {} |
| key_modelid_mapping: dict[KeyT, str] = {} |
| seen_keys = set() |
| for mp in model_paths: |
| keys = mp.keys |
| update_path = mp.update_path |
| model_id = mp.model_id |
| if len(update_path) == 0: |
| raise ValueError(f'Invalid model update, path for {keys} is empty') |
| for key in keys: |
| if key in seen_keys: |
| raise ValueError( |
| f'Invalid model update: {key} appears in multiple ' |
| 'update lists. A single model update must provide exactly one ' |
| 'updated path per key.') |
| seen_keys.add(key) |
| if key not in self._key_to_id_map: |
| raise ValueError( |
| f'Invalid model update: {key} appears in ' |
| 'update, but not in the original configuration.') |
| key_modelid_mapping[key] = model_id |
| cohort_id = self._key_to_id_map[key] |
| if cohort_id not in cohort_path_mapping: |
| cohort_path_mapping[cohort_id] = defaultdict(list) |
| cohort_path_mapping[cohort_id][update_path].append(key) |
| for key in self._key_to_id_map: |
| if key not in seen_keys: |
| raise ValueError( |
| f'Invalid model update: {key} appears in the ' |
| 'original configuration, but not the update.') |
| |
| # We now have our new set of cohorts. For each one, update our local model |
| # handler configuration and send the results to the ModelManager |
| for old_cohort_id, path_key_mapping in cohort_path_mapping.items(): |
| for updated_path, keys in path_key_mapping.items(): |
| cohort_id = old_cohort_id |
| if old_cohort_id not in keys: |
| # Create new cohort |
| cohort_id = keys[0] |
| for key in keys: |
| self._key_to_id_map[key] = cohort_id |
| mh = self._id_to_mh_map[old_cohort_id] |
| self._id_to_mh_map[cohort_id] = deepcopy(mh) |
| self._id_to_mh_map[cohort_id].update_model_path(updated_path) |
| model.update_model_handler(cohort_id, updated_path, old_cohort_id) |
| model_id = key_modelid_mapping[cohort_id] |
| self._metrics_collectors[cohort_id] = _MetricsCollector( |
| self._metrics_namespace, f'{cohort_id}-{model_id}-') |
| |
| def update_model_path(self, model_path: Optional[str] = None): |
| if self._single_model: |
| return self._unkeyed.update_model_path(model_path=model_path) |
| if model_path is not None: |
| raise RuntimeError( |
| 'Model updates are currently not supported for ' + |
| 'KeyedModelHandlers with multiple different per-key ' + |
| 'ModelHandlers.') |
| |
| def share_model_across_processes(self) -> bool: |
| if self._single_model: |
| return self._unkeyed.share_model_across_processes() |
| return True |
| |
| def model_copies(self) -> int: |
| if self._single_model: |
| return self._unkeyed.model_copies() |
| for mh in self._id_to_mh_map.values(): |
| if mh.model_copies() != 1: |
| raise ValueError( |
| 'KeyedModelHandler cannot map records to multiple ' |
| 'models if one or more of its ModelHandlers ' |
| 'require multiple model copies (set via ' |
| 'model_copies). To fix, verify that each ' |
| 'ModelHandler is not set to load multiple copies of ' |
| 'its model.') |
| |
| return 1 |
| |
| def override_metrics(self, metrics_namespace: str = '') -> bool: |
| if self._single_model: |
| return self._unkeyed.override_metrics(metrics_namespace) |
| |
| self._metrics_namespace = metrics_namespace |
| self._default_metrics_collector = _MetricsCollector(metrics_namespace) |
| for cohort_id in self._id_to_mh_map: |
| self._metrics_collectors[cohort_id] = _MetricsCollector( |
| metrics_namespace, f'{cohort_id}-') |
| |
| return True |
| |
| def should_garbage_collect_on_timeout(self) -> bool: |
| return self._single_model and self.share_model_across_processes() |
| |
| |
| class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT], |
| ModelHandler[Union[ExampleT, tuple[KeyT, |
| ExampleT]], |
| Union[PredictionT, |
| tuple[KeyT, PredictionT]], |
| ModelT]): |
| def __init__(self, unkeyed: ModelHandler[ExampleT, PredictionT, ModelT]): |
| """A ModelHandler that takes examples that might have keys and returns |
| predictions that might have keys. |
| |
| For example, if the original model is used with RunInference to take a |
| PCollection[E] to a PCollection[P], this ModelHandler would take either |
| PCollection[E] to a PCollection[P] or PCollection[tuple[K, E]] to a |
| PCollection[tuple[K, P]], depending on whether the elements are |
| tuples. This pattern makes it possible to associate the outputs with the |
| inputs based on the key. |
| |
| Note that you cannot use this ModelHandler if E is a tuple type. |
| In addition, either all examples should be keyed, or none of them. |
| |
| Args: |
| unkeyed: An implementation of ModelHandler that does not require keys. |
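| |
| For example (a sketch, where `mh` is a hypothetical unkeyed handler), the |
| same wrapped handler works for both keyed and unkeyed inputs: |
| |
|   maybe_keyed = MaybeKeyedModelHandler(mh) |
|   # unkeyed_pcoll is a PCollection[E], keyed_pcoll a PCollection[tuple[K, E]] |
|   predictions = unkeyed_pcoll | RunInference(maybe_keyed) |
|   keyed_predictions = keyed_pcoll | RunInference(maybe_keyed) |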
| """ |
| if len(unkeyed.get_preprocess_fns()) or len(unkeyed.get_postprocess_fns()): |
| raise Exception( |
| 'Cannot wrap an unkeyed model handler that has pre or ' |
| 'postprocessing functions defined in a keyed model handler. All ' |
| 'pre/postprocessing functions must be defined on the outer model ' |
| 'handler.') |
| self._unkeyed = unkeyed |
| self._env_vars = getattr(unkeyed, '_env_vars', {}) |
| |
| def load_model(self) -> ModelT: |
| return self._unkeyed.load_model() |
| |
| def run_inference( |
| self, |
| batch: Sequence[Union[ExampleT, tuple[KeyT, ExampleT]]], |
| model: ModelT, |
| inference_args: Optional[dict[str, Any]] = None |
| ) -> Union[Iterable[PredictionT], Iterable[tuple[KeyT, PredictionT]]]: |
| # Really the input should be |
| # Union[Sequence[ExampleT], Sequence[tuple[KeyT, ExampleT]]] |
| # but there's not a good way to express (or check) that. |
| if isinstance(batch[0], tuple): |
| is_keyed = True |
| keys, unkeyed_batch = zip(*batch) # type: ignore[arg-type] |
| else: |
| is_keyed = False |
| unkeyed_batch = batch # type: ignore[assignment] |
| unkeyed_results = self._unkeyed.run_inference( |
| unkeyed_batch, model, inference_args) |
| if is_keyed: |
| return zip(keys, unkeyed_results) |
| else: |
| return unkeyed_results |
| |
| def get_num_bytes( |
| self, batch: Sequence[Union[ExampleT, tuple[KeyT, ExampleT]]]) -> int: |
| # MyPy can't follow the branching logic. |
| if isinstance(batch[0], tuple): |
| keys, unkeyed_batch = zip(*batch) # type: ignore[arg-type] |
| return len( |
| pickle.dumps(keys)) + self._unkeyed.get_num_bytes(unkeyed_batch) |
| else: |
| return self._unkeyed.get_num_bytes(batch) # type: ignore[arg-type] |
| |
| def get_metrics_namespace(self) -> str: |
| return self._unkeyed.get_metrics_namespace() |
| |
| def get_resource_hints(self): |
| return self._unkeyed.get_resource_hints() |
| |
| def batch_elements_kwargs(self): |
| return self._unkeyed.batch_elements_kwargs() |
| |
| def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): |
| return self._unkeyed.validate_inference_args(inference_args) |
| |
| def update_model_path(self, model_path: Optional[str] = None): |
| return self._unkeyed.update_model_path(model_path=model_path) |
| |
| def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]: |
| return self._unkeyed.get_preprocess_fns() |
| |
| def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]: |
| return self._unkeyed.get_postprocess_fns() |
| |
| def should_skip_batching(self) -> bool: |
| return self._unkeyed.should_skip_batching() |
| |
| def share_model_across_processes(self) -> bool: |
| return self._unkeyed.share_model_across_processes() |
| |
| def model_copies(self) -> int: |
| return self._unkeyed.model_copies() |
| |
| |
| class _PrebatchedModelHandler(Generic[ExampleT, PredictionT, ModelT], |
| ModelHandler[Sequence[ExampleT], |
| PredictionT, |
| ModelT]): |
| def __init__(self, base: ModelHandler[ExampleT, PredictionT, ModelT]): |
| """A ModelHandler that skips batching in RunInference. |
| |
| Args: |
| base: An implementation of the underlying model handler. |
| """ |
| self._base = base |
| self._env_vars = getattr(base, '_env_vars', {}) |
| |
| def load_model(self) -> ModelT: |
| return self._base.load_model() |
| |
| def run_inference( |
| self, |
| batch: Sequence[Union[ExampleT, tuple[KeyT, ExampleT]]], |
| model: ModelT, |
| inference_args: Optional[dict[str, Any]] = None |
| ) -> Union[Iterable[PredictionT], Iterable[tuple[KeyT, PredictionT]]]: |
| return self._base.run_inference(batch, model, inference_args) |
| |
| def get_num_bytes( |
| self, batch: Sequence[Union[ExampleT, tuple[KeyT, ExampleT]]]) -> int: |
| return self._base.get_num_bytes(batch) |
| |
| def get_metrics_namespace(self) -> str: |
| return self._base.get_metrics_namespace() |
| |
| def get_resource_hints(self): |
| return self._base.get_resource_hints() |
| |
| def batch_elements_kwargs(self): |
| return self._base.batch_elements_kwargs() |
| |
| def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): |
| return self._base.validate_inference_args(inference_args) |
| |
| def update_model_path(self, model_path: Optional[str] = None): |
| return self._base.update_model_path(model_path=model_path) |
| |
| def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]: |
| return self._base.get_preprocess_fns() |
| |
| def should_skip_batching(self) -> bool: |
| return True |
| |
| def share_model_across_processes(self) -> bool: |
| return self._base.share_model_across_processes() |
| |
| def model_copies(self) -> int: |
| return self._base.model_copies() |
| |
| def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]: |
| return self._base.get_postprocess_fns() |
| |
| |
| class _PreProcessingModelHandler(Generic[ExampleT, |
| PredictionT, |
| ModelT, |
| PreProcessT], |
| ModelHandler[PreProcessT, PredictionT, |
| ModelT]): |
| def __init__( |
| self, |
| base: ModelHandler[ExampleT, PredictionT, ModelT], |
| preprocess_fn: Callable[[PreProcessT], ExampleT]): |
| """A ModelHandler that has a preprocessing function associated with it. |
| |
| Args: |
| base: An implementation of the underlying model handler. |
| preprocess_fn: the preprocessing function to use. |
| """ |
| self._base = base |
| self._env_vars = getattr(base, '_env_vars', {}) |
| self._preprocess_fn = preprocess_fn |
| |
| def load_model(self) -> ModelT: |
| return self._base.load_model() |
| |
| def run_inference( |
| self, |
| batch: Sequence[Union[ExampleT, tuple[KeyT, ExampleT]]], |
| model: ModelT, |
| inference_args: Optional[dict[str, Any]] = None |
| ) -> Union[Iterable[PredictionT], Iterable[tuple[KeyT, PredictionT]]]: |
| return self._base.run_inference(batch, model, inference_args) |
| |
| def get_num_bytes( |
| self, batch: Sequence[Union[ExampleT, tuple[KeyT, ExampleT]]]) -> int: |
| return self._base.get_num_bytes(batch) |
| |
| def get_metrics_namespace(self) -> str: |
| return self._base.get_metrics_namespace() |
| |
| def get_resource_hints(self): |
| return self._base.get_resource_hints() |
| |
| def batch_elements_kwargs(self): |
| return self._base.batch_elements_kwargs() |
| |
| def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): |
| return self._base.validate_inference_args(inference_args) |
| |
| def update_model_path(self, model_path: Optional[str] = None): |
| return self._base.update_model_path(model_path=model_path) |
| |
| def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]: |
| return [self._preprocess_fn] + self._base.get_preprocess_fns() |
| |
| def should_skip_batching(self) -> bool: |
| return self._base.should_skip_batching() |
| |
| def share_model_across_processes(self) -> bool: |
| return self._base.share_model_across_processes() |
| |
| def model_copies(self) -> int: |
| return self._base.model_copies() |
| |
| def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]: |
| return self._base.get_postprocess_fns() |
| |
| |
| class _PostProcessingModelHandler(Generic[ExampleT, |
| PredictionT, |
| ModelT, |
| PostProcessT], |
| ModelHandler[ExampleT, PostProcessT, ModelT]): |
| def __init__( |
| self, |
| base: ModelHandler[ExampleT, PredictionT, ModelT], |
| postprocess_fn: Callable[[PredictionT], PostProcessT]): |
| """A ModelHandler that has a preprocessing function associated with it. |
| |
| Args: |
| base: An implementation of the underlying model handler. |
| postprocess_fn: the postprocessing function to use. |
| """ |
| self._base = base |
| self._env_vars = getattr(base, '_env_vars', {}) |
| self._postprocess_fn = postprocess_fn |
| |
| def load_model(self) -> ModelT: |
| return self._base.load_model() |
| |
| def run_inference( |
| self, |
| batch: Sequence[Union[ExampleT, tuple[KeyT, ExampleT]]], |
| model: ModelT, |
| inference_args: Optional[dict[str, Any]] = None |
| ) -> Union[Iterable[PredictionT], Iterable[tuple[KeyT, PredictionT]]]: |
| return self._base.run_inference(batch, model, inference_args) |
| |
| def get_num_bytes( |
| self, batch: Sequence[Union[ExampleT, tuple[KeyT, ExampleT]]]) -> int: |
| return self._base.get_num_bytes(batch) |
| |
| def get_metrics_namespace(self) -> str: |
| return self._base.get_metrics_namespace() |
| |
| def get_resource_hints(self): |
| return self._base.get_resource_hints() |
| |
| def batch_elements_kwargs(self): |
| return self._base.batch_elements_kwargs() |
| |
| def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): |
| return self._base.validate_inference_args(inference_args) |
| |
| def update_model_path(self, model_path: Optional[str] = None): |
| return self._base.update_model_path(model_path=model_path) |
| |
| def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]: |
| return self._base.get_preprocess_fns() |
| |
| def should_skip_batching(self) -> bool: |
| return self._base.should_skip_batching() |
| |
| def share_model_across_processes(self) -> bool: |
| return self._base.share_model_across_processes() |
| |
| def model_copies(self) -> int: |
| return self._base.model_copies() |
| |
| def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]: |
| return self._base.get_postprocess_fns() + [self._postprocess_fn] |
| |
| |
| class RunInference(beam.PTransform[beam.PCollection[Union[ExampleT, |
| Iterable[ExampleT]]], |
| beam.PCollection[PredictionT]]): |
| def __init__( |
| self, |
| model_handler: ModelHandler[ExampleT, PredictionT, Any], |
| clock=time, |
| inference_args: Optional[dict[str, Any]] = None, |
| metrics_namespace: Optional[str] = None, |
| *, |
| model_metadata_pcoll: beam.PCollection[ModelMetadata] = None, |
| watch_model_pattern: Optional[str] = None, |
| model_identifier: Optional[str] = None, |
| **kwargs): |
| """ |
| A transform that takes a PCollection of examples (or features) for use |
| on an ML model. The transform then outputs inferences (or predictions) for |
| those examples in a PCollection of PredictionResults, each of which contains |
| the input example and the output inference. |
| |
| Models for supported frameworks can be loaded using a URI. Remote |
| inference services for supported frameworks can also be used. |
| |
| This transform attempts to batch examples using the beam.BatchElements |
| transform. Batching can be configured using the ModelHandler. |
| |
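| Typical usage (a sketch; `my_handler` and `examples` are hypothetical): |
| |
|   with beam.Pipeline() as p: |
|     predictions = ( |
|         p |
|         | beam.Create(examples) |
|         | RunInference( |
|             my_handler, |
|             metrics_namespace='my_inference_metrics')) |
| |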
| Args: |
| model_handler: An implementation of ModelHandler. |
| clock: A clock implementing time_ns. *Used for unit testing.* |
| inference_args: Extra arguments for models whose inference call requires |
| extra parameters. |
| metrics_namespace: Namespace of the transform to collect metrics. |
| model_metadata_pcoll: PCollection that emits a singleton ModelMetadata |
| containing the model path and model name, which is used as a side input |
| to the _RunInferenceDoFn. |
| watch_model_pattern: A glob pattern used to watch a directory |
| for automatic model refresh. |
| model_identifier: A string used to identify the model being loaded. You |
| can set this if you want to reuse the same model across multiple |
| RunInference steps and don't want to load it more than once. Note that |
| using |
| the same tag for different models will lead to non-deterministic |
| results, so exercise caution when using this parameter. This only |
| impacts models which are already being shared across processes. |
| """ |
| self._model_handler = model_handler |
| self._inference_args = inference_args |
| self._clock = clock |
| self._metrics_namespace = metrics_namespace |
| self._model_metadata_pcoll = model_metadata_pcoll |
| self._with_exception_handling = False |
| self._exception_handling_timeout = None |
| self._timeout = None |
| self._watch_model_pattern = watch_model_pattern |
| self._kwargs = kwargs |
| # Generate a random tag to use for shared.py and multi_process_shared.py to |
| # allow us to effectively disambiguate in multi-model settings. Only use |
| # the same tag if the model being loaded across multiple steps is actually |
| # the same. |
| self._model_tag = model_identifier |
| if model_identifier is None: |
| self._model_tag = uuid.uuid4().hex |
| |
| def annotations(self): |
| return { |
| 'model_handler': str(self._model_handler), |
| 'model_handler_type': ( |
| f'{self._model_handler.__class__.__module__}' |
| f'.{self._model_handler.__class__.__qualname__}'), |
| **super().annotations() |
| } |
| |
| def _get_model_metadata_pcoll(self, pipeline): |
| # avoid circular imports. |
| # pylint: disable=wrong-import-position |
| from apache_beam.ml.inference.utils import WatchFilePattern |
| extra_params = {} |
| if 'interval' in self._kwargs: |
| extra_params['interval'] = self._kwargs['interval'] |
| if 'stop_timestamp' in self._kwargs: |
| extra_params['stop_timestamp'] = self._kwargs['stop_timestamp'] |
| |
| return ( |
| pipeline | WatchFilePattern( |
| file_pattern=self._watch_model_pattern, **extra_params)) |
| |
| # TODO(BEAM-14046): Add and link to help documentation. |
| @classmethod |
| def from_callable(cls, model_handler_provider, **kwargs): |
| """Multi-language friendly constructor. |
| |
| Use this constructor with fully_qualified_named_transform to |
| initialize the RunInference transform from PythonCallableSource provided |
| by foreign SDKs. |
| |
| Args: |
| model_handler_provider: A callable object that returns ModelHandler. |
| kwargs: Keyword arguments for model_handler_provider. |
| """ |
| return cls(model_handler_provider(**kwargs)) |
| |
| def _apply_fns( |
| self, |
| pcoll: beam.PCollection, |
| fns: Iterable[Callable[[Any], Any]], |
| step_prefix: str) -> tuple[beam.PCollection, Iterable[beam.PCollection]]: |
| bad_preprocessed = [] |
| for idx, fn in enumerate(fns): |
| if self._with_exception_handling: |
| pcoll, bad = (pcoll |
| | f"{step_prefix}-{idx}" >> beam.Map( |
| fn).with_exception_handling( |
| exc_class=self._exc_class, |
| use_subprocess=self._use_subprocess, |
| threshold=self._threshold, |
| timeout=self._timeout)) |
| bad_preprocessed.append(bad) |
| else: |
| pcoll = pcoll | f"{step_prefix}-{idx}" >> beam.Map(fn) |
| |
| return pcoll, bad_preprocessed |
| |
| # TODO(https://github.com/apache/beam/issues/21447): Add batch_size backoff |
| # in case there are functional reasons large batch sizes cannot be |
| # handled. |
| def expand( |
| self, pcoll: beam.PCollection[ExampleT]) -> beam.PCollection[PredictionT]: |
| self._model_handler.validate_inference_args(self._inference_args) |
| # DLQ pcollections |
| bad_preprocessed = [] |
| bad_inference = None |
| bad_postprocessed = [] |
| preprocess_fns = self._model_handler.get_preprocess_fns() |
| postprocess_fns = self._model_handler.get_postprocess_fns() |
| |
| pcoll, bad_preprocessed = self._apply_fns( |
| pcoll, preprocess_fns, 'BeamML_RunInference_Preprocess') |
| |
| resource_hints = self._model_handler.get_resource_hints() |
| |
| # check for the side input |
| if self._watch_model_pattern: |
| self._model_metadata_pcoll = self._get_model_metadata_pcoll( |
| pcoll.pipeline) |
| |
| if self._model_handler.should_skip_batching(): |
| batched_elements_pcoll = pcoll |
| else: |
| batched_elements_pcoll = ( |
| pcoll |
| # TODO(https://github.com/apache/beam/issues/21440): Hook into the |
| # batching DoFn APIs. |
| | beam.BatchElements(**self._model_handler.batch_elements_kwargs())) |
| |
| # Skip loading in setup if we depend on side inputs or we want to |
| # enforce a timeout, since neither of these is available in a helpful way |
| # in setup. |
| load_model_at_runtime = ( |
| self._model_metadata_pcoll is not None or self._timeout is not None) |
| run_inference_pardo = beam.ParDo( |
| _RunInferenceDoFn( |
| self._model_handler, |
| self._clock, |
| self._metrics_namespace, |
| load_model_at_runtime, |
| self._model_tag), |
| self._inference_args, |
| beam.pvalue.AsSingleton( |
| self._model_metadata_pcoll, |
| ) if self._model_metadata_pcoll else None).with_resource_hints( |
| **resource_hints) |
| |
| if self._with_exception_handling: |
| # On timeouts, report back to the central model metadata |
| # that the model is invalid |
| model_tag = self._model_tag |
| share_across_processes = ( |
| self._model_handler.share_model_across_processes()) |
| timeout = self._timeout |
| |
| def failure_callback(exception: Exception, element: Any): |
| if type(exception) is not TimeoutError: |
| return |
| model_metadata = load_model_status(model_tag, share_across_processes) |
| model_metadata.try_mark_current_model_invalid(timeout) |
| logging.warning( |
| "Inference failed with a timeout, marking the current " + |
| "model for garbage collection") |
| |
| callback = None |
| if (self._timeout is not None and |
| self._model_handler.should_garbage_collect_on_timeout()): |
| callback = failure_callback |
| results, bad_inference = ( |
| batched_elements_pcoll |
| | 'BeamML_RunInference' >> |
| run_inference_pardo.with_exception_handling( |
| exc_class=self._exc_class, |
| use_subprocess=self._use_subprocess, |
| threshold=self._threshold, |
| timeout=self._timeout, |
| on_failure_callback=callback)) |
| else: |
| results = ( |
| batched_elements_pcoll |
| | 'BeamML_RunInference' >> run_inference_pardo) |
| |
| results, bad_postprocessed = self._apply_fns( |
| results, postprocess_fns, 'BeamML_RunInference_Postprocess') |
| |
| if self._with_exception_handling: |
| dlq = RunInferenceDLQ(bad_inference, bad_preprocessed, bad_postprocessed) |
| return results, dlq |
| |
| return results |
| |
| def with_exception_handling( |
| self, |
| *, |
| exc_class=Exception, |
| use_subprocess=False, |
| threshold=1, |
| timeout: Optional[int] = None): |
| """Automatically provides a dead letter output for skipping bad records. |
| This can allow a pipeline to continue successfully rather than fail or |
| continuously throw errors on retry when bad elements are encountered. |
| |
| This returns a tagged output with two PCollections, the first being the |
| results of successfully processing the input PCollection, and the second |
| being the set of bad batches of records (those which threw exceptions |
| during processing) along with information about the errors raised. |
| |
| For example, one would write:: |
| |
| main, other = RunInference( |
| maybe_error_raising_model_handler |
| ).with_exception_handling() |
| |
| and `main` will be a PCollection of PredictionResults and `other` will |
| contain a `RunInferenceDLQ` object with PCollections containing failed |
| records for each failed inference, preprocess operation, or postprocess |
| operation. To access each collection of failed records, one would write:: |
| |
| failed_inferences = other.failed_inferences |
| failed_preprocessing = other.failed_preprocessing |
| failed_postprocessing = other.failed_postprocessing |
| |
| failed_inferences is in the form |
| PCollection[tuple[failed batch, exception]]. |
| |
| failed_preprocessing is in the form |
| list[PCollection[tuple[failed record, exception]]], where each element of |
| the list corresponds to a preprocess function. These PCollections are |
| in the same order that the preprocess functions are applied. |
| |
| failed_postprocessing is in the form |
| list[PCollection[tuple[failed record, exception]]], where each element of |
| the list corresponds to a postprocess function. These PCollections are |
| in the same order that the postprocess functions are applied. |
| |
| Args: |
| exc_class: An exception class, or tuple of exception classes, to catch. |
| Optional, defaults to 'Exception'. |
| use_subprocess: Whether to execute the DoFn logic in a subprocess. This |
| allows one to recover from errors that can crash the calling process |
| (e.g. from an underlying library causing a segfault), but is |
| slower as elements and results must cross a process boundary. Note |
| that this starts up a long-running process that is used to handle |
| all the elements (until hard failure, which should be rare) rather |
| than a new process per element, so the overhead should be minimal |
| (and can be amortized if there's any per-process or per-bundle |
| initialization that needs to be done). Optional, defaults to False. |
| threshold: An upper bound on the ratio of records that can be bad before |
| aborting the entire pipeline. Optional, defaults to 1.0 (meaning |
| up to 100% of records can be bad and the pipeline will still succeed). |
| timeout: The maximum amount of time in seconds given to load a model, run |
| inference on a batch of elements, and perform any pre/postprocessing |
| operations. Since the timeout applies in multiple places, it should |
| be equal to the maximum possible timeout for any of these operations. |
| Note in particular that model load and inference on a single batch |
| count toward the same timeout value. When an inference fails, all related |
| resources, including the model, will be deleted and reloaded. As a |
| result, it is recommended to leave significant buffer and set the |
| timeout to at least `2 * (time to load model + time to run |
| inference on a batch of data)`. |
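| |
| For example, if loading the model takes roughly 60 seconds and running |
| inference on a batch takes roughly 5 seconds, a timeout of at least |
| 2 * (60 + 5) = 130 seconds would be a reasonable starting point. A |
| sketch of typical usage (values are illustrative only):: |
| |
| main, other = RunInference( |
| maybe_error_raising_model_handler |
| ).with_exception_handling(threshold=0.05, timeout=130) |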
| """ |
| self._with_exception_handling = True |
| self._exc_class = exc_class |
| self._use_subprocess = use_subprocess |
| self._threshold = threshold |
| self._timeout = timeout |
| return self |
| |
| |
| class _MetricsCollector: |
| """ |
| A metrics collector that tracks ML related performance and memory usage. |
| """ |
| def __init__(self, namespace: str, prefix: str = ''): |
| """ |
| Args: |
| namespace: Namespace for the metrics. |
| prefix: Unique identifier for metrics, used when models |
| are updated using a side input. |
| """ |
| # Metrics |
| if prefix: |
| prefix = f'{prefix}_' |
| self._inference_counter = beam.metrics.Metrics.counter( |
| namespace, prefix + 'num_inferences') |
| self.failed_batches_counter = beam.metrics.Metrics.counter( |
| namespace, prefix + 'failed_batches_counter') |
| self._inference_request_batch_size = beam.metrics.Metrics.distribution( |
| namespace, prefix + 'inference_request_batch_size') |
| self._inference_request_batch_byte_size = ( |
| beam.metrics.Metrics.distribution( |
| namespace, prefix + 'inference_request_batch_byte_size')) |
| # Batch inference latency in microseconds. |
| self._inference_batch_latency_micro_secs = ( |
| beam.metrics.Metrics.distribution( |
| namespace, prefix + 'inference_batch_latency_micro_secs')) |
| self._model_byte_size = beam.metrics.Metrics.distribution( |
| namespace, prefix + 'model_byte_size') |
| # Model load latency in milliseconds. |
| self._load_model_latency_milli_secs = beam.metrics.Metrics.distribution( |
| namespace, prefix + 'load_model_latency_milli_secs') |
| |
| # Metrics cache |
| self._load_model_latency_milli_secs_cache = None |
| self._model_byte_size_cache = None |
| |
| def update_metrics_with_cache(self): |
| if self._load_model_latency_milli_secs_cache is not None: |
| self._load_model_latency_milli_secs.update( |
| self._load_model_latency_milli_secs_cache) |
| self._load_model_latency_milli_secs_cache = None |
| if self._model_byte_size_cache is not None: |
| self._model_byte_size.update(self._model_byte_size_cache) |
| self._model_byte_size_cache = None |
| |
| def cache_load_model_metrics(self, load_model_latency_ms, model_byte_size): |
| self._load_model_latency_milli_secs_cache = load_model_latency_ms |
| self._model_byte_size_cache = model_byte_size |
| |
| def update_load_model_metrics(self, load_model_latency_ms, model_byte_size): |
| self._load_model_latency_milli_secs.update(load_model_latency_ms) |
| self._model_byte_size.update(model_byte_size) |
| |
| def update( |
| self, |
| examples_count: int, |
| examples_byte_size: int, |
| latency_micro_secs: int): |
| self._inference_batch_latency_micro_secs.update(latency_micro_secs) |
| self._inference_counter.inc(examples_count) |
| self._inference_request_batch_size.update(examples_count) |
| self._inference_request_batch_byte_size.update(examples_byte_size) |
| |
| |
| class _ModelRoutingStrategy(): |
| """A class meant to sit in a shared location for mapping incoming batches to |
| different models. Currently only supports round-robin, but can be extended |
| to support other protocols if needed. |
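| |
| Illustrative behavior (note that the index advances before it is |
| returned, so the first call with two models returns index 1):: |
| |
| strategy = _ModelRoutingStrategy() |
| strategy.next_model_index(2) # -> 1 |
| strategy.next_model_index(2) # -> 0 |
| strategy.next_model_index(2) # -> 1 |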
| """ |
| def __init__(self): |
| self._cur_index = 0 |
| |
| def next_model_index(self, num_models): |
| self._cur_index = (self._cur_index + 1) % num_models |
| return self._cur_index |
| |
| |
| class _ModelStatus(): |
| """A class holding any metadata about a model required by RunInference. |
| |
| Currently, this only includes whether or not the model is valid. Uses the |
| model tag to map models to metadata. |
| """ |
| def __init__(self, share_model_across_processes: bool): |
| self._active_tags = set() |
| self._invalid_tags = set() |
| self._tag_mapping = {} |
| self._model_first_seen = {} |
| self._pending_hard_delete = [] |
| self._share_model_across_process = share_model_across_processes |
| |
| def try_mark_current_model_invalid(self, min_model_life_seconds): |
| """Mark the current model invalid. |
| |
| Since we don't have sufficient information to determine which model should |
| be marked invalid, and there may be multiple active models, we mark all |
| models currently in use as invalid so that they all get reloaded. To |
| avoid thrashing, however, we only mark models as invalid if they've |
| been active for at least min_model_life_seconds seconds. |
| """ |
| cutoff_time = datetime.now() - timedelta(seconds=min_model_life_seconds) |
| for tag in list(self._active_tags): |
| if cutoff_time >= self._model_first_seen[tag]: |
| self._invalid_tags.add(tag) |
| # Delete old models after a grace period of 2 * the model life. |
| # This already happens automatically for shared.Shared models, so |
| # cleanup is only necessary for multi_process_shared models. |
| if self._share_model_across_process: |
| self._pending_hard_delete.append(( |
| tag, |
| datetime.now() + 2 * timedelta(seconds=min_model_life_seconds))) |
| self._active_tags.remove(tag) |
| |
| def get_valid_tag(self, tag: str) -> str: |
| """Takes in a proposed valid tag and returns a valid one. |
| |
| Will always return a valid tag. If the passed-in tag is valid, this |
| function simply returns it; otherwise it deterministically generates |
| a new tag to use instead. The new tag is the original tag with an |
| incrementing suffix (e.g. `my_tag_reload_1`, `my_tag_reload_2`) for |
| each reload. |
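| |
| A minimal sketch of the expected behavior:: |
| |
| status = _ModelStatus(share_model_across_processes=False) |
| status.get_valid_tag('my_tag') # -> 'my_tag' |
| status.try_mark_current_model_invalid(min_model_life_seconds=0) |
| status.get_valid_tag('my_tag') # -> 'my_tag_reload_1' |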
| """ |
| if tag not in self._invalid_tags: |
| if tag not in self._model_first_seen: |
| self._model_first_seen[tag] = datetime.now() |
| self._active_tags.add(tag) |
| return tag |
| if (tag in self._tag_mapping and |
| self._tag_mapping[tag] not in self._invalid_tags): |
| return self._tag_mapping[tag] |
| i = 1 |
| new_tag = f'{tag}_reload_{i}' |
| while new_tag in self._invalid_tags: |
| i += 1 |
| new_tag = f'{tag}_reload_{i}' |
| self._tag_mapping[tag] = new_tag |
| self._model_first_seen[new_tag] = datetime.now() |
| self._active_tags.add(new_tag) |
| return new_tag |
| |
| def is_valid_tag(self, tag: str) -> bool: |
| return tag == self.get_valid_tag(tag) |
| |
| def get_tags_for_garbage_collection(self) -> list[str]: |
| # Since this function may run in multi_process_shared space, delegate model |
| # deletion to the calling process (which is not in that space) to avoid |
| # holding a multi_process_shared reference from within multi_process_shared |
| # space, which can create issues with python's multiprocessing module. |
| # We rely on the calling process to report back deleted models so that |
| # we can confirm deletion. |
| to_delete = [] |
| cur_time = datetime.now() |
| for i in range(len(self._pending_hard_delete)): |
| delete_time = self._pending_hard_delete[i][1] |
| tag = self._pending_hard_delete[i][0] |
| if delete_time <= cur_time: |
| to_delete.append(tag) |
| else: |
| # Early return once we hit a model whose delete time is in the future; |
| # since models are appended in order, all later entries are too. |
| return to_delete |
| |
| return to_delete |
| |
| def mark_tags_deleted(self, deleted_tags: set[str]): |
| while len(self._pending_hard_delete) > 0: |
| tag = self._pending_hard_delete[0][0] |
| if tag in deleted_tags: |
| self._pending_hard_delete.pop(0) |
| else: |
| return |
| |
| |
| def load_model_status( |
| model_tag: str, share_across_processes: bool) -> _ModelStatus: |
| tag = f'{model_tag}_model_status' |
| if share_across_processes: |
| return multi_process_shared.MultiProcessShared( |
| lambda: _ModelStatus(True), tag=tag, always_proxy=True).acquire() |
| return shared.Shared().acquire(lambda: _ModelStatus(False), tag=tag) |
| |
| |
| class _SharedModelWrapper(): |
| """A router class to map incoming calls to the correct model. |
| |
| This allows us to round robin calls to models sitting in different |
| processes so that we can more efficiently use resources (e.g. GPUs). |
| """ |
| def __init__(self, models: list[Any], model_tag: str): |
| self.models = models |
| if len(models) > 1: |
| self.model_router = multi_process_shared.MultiProcessShared( |
| lambda: _ModelRoutingStrategy(), |
| tag=f'{model_tag}_counter', |
| always_proxy=True).acquire() |
| |
| def next_model(self): |
| if len(self.models) == 1: |
| # Short circuit if there's no routing strategy needed in order to |
| # avoid the cross-process call |
| return self.models[0] |
| |
| return self.models[self.model_router.next_model_index(len(self.models))] |
| |
| def all_models(self): |
| return self.models |
| |
| |
| class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]): |
| def __init__( |
| self, |
| model_handler: ModelHandler[ExampleT, PredictionT, Any], |
| clock, |
| metrics_namespace, |
| load_model_at_runtime: bool = False, |
| model_tag: str = "RunInference"): |
| """A DoFn implementation generic to frameworks. |
| |
| Args: |
| model_handler: An implementation of ModelHandler. |
| clock: A clock implementing time_ns. *Used for unit testing.* |
| metrics_namespace: Namespace of the transform to collect metrics. |
| load_model_at_runtime: Bool to indicate if model loading should be |
| deferred to runtime - for example if we are depending on side |
| inputs to get the model path or we want to enforce a timeout on |
| model loading. |
| model_tag: Tag to use to disambiguate models in multi-model settings. |
| """ |
| self._model_handler = model_handler |
| self._shared_model_handle = shared.Shared() |
| self._clock = clock |
| self._model = None |
| self._metrics_namespace = metrics_namespace |
| self._load_model_at_runtime = load_model_at_runtime |
| self._side_input_path = None |
| # _model_tag is the original tag passed in. |
| # _cur_tag is the tag of the model that is actually loaded. |
| self._model_tag = model_tag |
| self._cur_tag = model_tag |
| |
| def _load_model( |
| self, |
| side_input_model_path: Optional[Union[str, |
| list[KeyModelPathMapping]]] = None |
| ) -> _SharedModelWrapper: |
| def load(): |
| """Function for constructing shared LoadedModel.""" |
| memory_before = _get_current_process_memory_in_bytes() |
| start_time = _to_milliseconds(self._clock.time_ns()) |
| if isinstance(side_input_model_path, str): |
| self._model_handler.update_model_path(side_input_model_path) |
| else: |
| if self._model is not None: |
| models = self._model.all_models() |
| for m in models: |
| self._model_handler.update_model_paths(m, side_input_model_path) |
| model = self._model_handler.load_model() |
| end_time = _to_milliseconds(self._clock.time_ns()) |
| memory_after = _get_current_process_memory_in_bytes() |
| load_model_latency_ms = end_time - start_time |
| model_byte_size = memory_after - memory_before |
| if self._metrics_collector: |
| self._metrics_collector.cache_load_model_metrics( |
| load_model_latency_ms, model_byte_size) |
| return model |
| |
| # TODO(https://github.com/apache/beam/issues/21443): Investigate releasing |
| # model. |
| model_tag = self._model_tag |
| if isinstance(side_input_model_path, str) and side_input_model_path != '': |
| model_tag = side_input_model_path |
| # Ensure the tag we're loading is valid; if not, replace it with a valid tag. |
| self._cur_tag = self._model_metadata.get_valid_tag(model_tag) |
| if self._model_handler.share_model_across_processes(): |
| models = [] |
| for copy_tag in _get_tags_for_copies(self._cur_tag, |
| self._model_handler.model_copies()): |
| models.append( |
| multi_process_shared.MultiProcessShared( |
| load, tag=copy_tag, always_proxy=True).acquire()) |
| model_wrapper = _SharedModelWrapper(models, self._cur_tag) |
| else: |
| model = self._shared_model_handle.acquire(load, tag=self._cur_tag) |
| model_wrapper = _SharedModelWrapper([model], self._cur_tag) |
| # Since shared_model_handle is shared across threads, the model path |
| # might not get updated in the model handler, because we get the cached |
| # weak-ref model directly from the shared cache instead of calling load(). |
| # As a sanity check, call update_model_path again. |
| if isinstance(side_input_model_path, str): |
| self._model_handler.update_model_path(side_input_model_path) |
| else: |
| if self._model is not None: |
| models = self._model.all_models() |
| for m in models: |
| self._model_handler.update_model_paths(m, side_input_model_path) |
| return model_wrapper |
| |
| def get_metrics_collector(self, prefix: str = ''): |
| """ |
| Args: |
| prefix: Unique identifier for metrics, used when models |
| are updated using a side input. |
| """ |
| metrics_namespace = ( |
| self._metrics_namespace) if self._metrics_namespace else ( |
| self._model_handler.get_metrics_namespace()) |
| if self._model_handler.override_metrics(metrics_namespace): |
| return None |
| return _MetricsCollector(metrics_namespace, prefix=prefix) |
| |
| def setup(self): |
| self._metrics_collector = self.get_metrics_collector() |
| self._model_handler.set_environment_vars() |
| self._model_metadata = load_model_status( |
| self._model_tag, self._model_handler.share_model_across_processes()) |
| if not self._load_model_at_runtime: |
| self._model = self._load_model() |
| |
| def update_model( |
| self, |
| side_input_model_path: Optional[Union[str, |
| list[KeyModelPathMapping]]] = None): |
| self._model = self._load_model(side_input_model_path=side_input_model_path) |
| |
| def _run_inference(self, batch, inference_args): |
| start_time = _to_microseconds(self._clock.time_ns()) |
| try: |
| model = self._model.next_model() |
| result_generator = self._model_handler.run_inference( |
| batch, model, inference_args) |
| except BaseException as e: |
| if self._metrics_collector: |
| self._metrics_collector.failed_batches_counter.inc() |
| if (isinstance(e, pickle.PickleError) and |
| self._model_handler.share_model_across_processes()): |
| raise TypeError( |
| 'Pickling error encountered while running inference. ' |
| 'This may be caused by trying to send unpickleable ' |
| 'data to a model which is shared across processes. ' |
| 'For more information, see ' |
| 'https://beam.apache.org/documentation/ml/large-language-modeling/#pickling-errors' # pylint: disable=line-too-long |
| ) from e |
| raise e |
| predictions = list(result_generator) |
| |
| end_time = _to_microseconds(self._clock.time_ns()) |
| inference_latency = end_time - start_time |
| num_bytes = self._model_handler.get_num_bytes(batch) |
| num_elements = len(batch) |
| if self._metrics_collector: |
| self._metrics_collector.update(num_elements, num_bytes, inference_latency) |
| |
| return predictions |
| |
| def process( |
| self, |
| batch, |
| inference_args, |
| si_model_metadata: Optional[Union[ModelMetadata, |
| list[ModelMetadata], |
| list[KeyModelPathMapping]]]): |
| """ |
| When side input is enabled: |
| The method checks if the side input model has been updated, and if so, |
| updates the model and runs inference on the batch of data. If the |
| side input is empty or the model has not been updated, the method |
| simply runs inference on the batch of data. |
| """ |
| if not si_model_metadata: |
| if (not self._model_metadata.is_valid_tag(self._cur_tag) or |
| self._model is None): |
| self.update_model(side_input_model_path=None) |
| return self._run_inference(batch, inference_args) |
| |
| if isinstance(si_model_metadata, beam.pvalue.EmptySideInput): |
| self.update_model(side_input_model_path=None) |
| elif isinstance(si_model_metadata, list) and hasattr(si_model_metadata[0], |
| 'keys'): |
| # TODO(https://github.com/apache/beam/issues/27628): Update metrics here |
| self.update_model(si_model_metadata) |
| elif self._side_input_path != si_model_metadata.model_id: |
| self._side_input_path = si_model_metadata.model_id |
| self._metrics_collector = self.get_metrics_collector( |
| prefix=si_model_metadata.model_name) |
| lock = threading.Lock() |
| with lock: |
| self.update_model(si_model_metadata.model_id) |
| return self._run_inference(batch, inference_args) |
| |
| return self._run_inference(batch, inference_args) |
| |
| def finish_bundle(self): |
| # TODO(https://github.com/apache/beam/issues/21435): Figure out why there |
| # is a cache. |
| if self._metrics_collector: |
| self._metrics_collector.update_metrics_with_cache() |
| |
| # Do garbage collection of old models |
| tags_to_gc = self._model_metadata.get_tags_for_garbage_collection() |
| if len(tags_to_gc) > 0: |
| for unprefixed_tag in tags_to_gc: |
| for tag in _get_tags_for_copies(unprefixed_tag, |
| self._model_handler.model_copies()): |
| multi_process_shared.MultiProcessShared(lambda: None, |
| tag).unsafe_hard_delete() |
| self._model_metadata.mark_tags_deleted(tags_to_gc) |
| |
| |
| def _is_darwin() -> bool: |
| return sys.platform == 'darwin' |
| |
| |
| def _get_current_process_memory_in_bytes(): |
| """ |
| Returns: |
| memory usage in bytes. |
| """ |
| |
| if resource is not None: |
| usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss |
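| # ru_maxrss is reported in bytes on macOS (darwin) but in kilobytes on |
| # Linux and other Unix platforms, hence the conversion below. |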
| if _is_darwin(): |
| return usage |
| return usage * 1024 |
| else: |
| logging.warning( |
| 'Resource module is not available for current platform, ' |
| 'memory usage cannot be fetched.') |
| return 0 |
| |
| |
| def _get_tags_for_copies(base_tag, num_copies): |
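| # For example, base_tag 'mytag' with num_copies=3 yields |
| # ['mytag0', 'mytag1', 'mytag2']. |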
| tags = [] |
| for i in range(num_copies): |
| tags.append(f'{base_tag}{i}') |
| return tags |