| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| import json |
| import logging |
| from collections.abc import Callable |
| from collections.abc import Iterable |
| from collections.abc import Mapping |
| from collections.abc import Sequence |
| from typing import Any |
| from typing import Optional |
| |
| from google.api_core.exceptions import ServerError |
| from google.api_core.exceptions import TooManyRequests |
| from google.cloud import aiplatform |
| |
| from apache_beam.ml.inference import utils |
| from apache_beam.ml.inference.base import PredictionResult |
| from apache_beam.ml.inference.base import RemoteModelHandler |
| |
| LOGGER = logging.getLogger("VertexAIModelHandlerJSON") |
| |
| # pylint: disable=line-too-long |
| |
| |
| def _retry_on_appropriate_gcp_error(exception): |
| """ |
| Retry filter that returns True if a returned HTTP error code is 5xx or 429. |
| This is used to retry remote requests that fail, most notably 429 |
| (TooManyRequests). |
| |
| Args: |
| exception: the returned exception encountered during the request/response |
| loop. |
| |
| Returns: |
| A boolean indicating whether the exception is a ServerError (5xx) or |
| a TooManyRequests (429) error. |
| """ |
| return isinstance(exception, (TooManyRequests, ServerError)) |
| |
| |
| class VertexAIModelHandlerJSON(RemoteModelHandler[Any, |
| PredictionResult, |
| aiplatform.Endpoint]): |
| def __init__( |
| self, |
| endpoint_id: str, |
| project: str, |
| location: str, |
| experiment: Optional[str] = None, |
| network: Optional[str] = None, |
| private: bool = False, |
| invoke_route: Optional[str] = None, |
| *, |
| min_batch_size: Optional[int] = None, |
| max_batch_size: Optional[int] = None, |
| max_batch_duration_secs: Optional[int] = None, |
| max_batch_weight: Optional[int] = None, |
| element_size_fn: Optional[Callable[[Any], int]] = None, |
| **kwargs): |
| """Implementation of the ModelHandler interface for Vertex AI. |
| **NOTE:** This API and its implementation are under development and |
| do not provide backward compatibility guarantees. |
| Unlike other ModelHandler implementations, this does not load the model |
| being used onto the worker and instead makes remote queries to a |
| Vertex AI endpoint. In that way it functions more like a mid-pipeline |
| IO. Public Vertex AI endpoints have a maximum request size of 1.5 MB. |
| If you wish to make larger requests and use a private endpoint, provide |
| the Compute Engine network you wish to use and set `private=True`. |
| |
| Args: |
| endpoint_id: the numerical ID of the Vertex AI endpoint to query |
| project: the GCP project name where the endpoint is deployed |
| location: the GCP location where the endpoint is deployed |
| experiment: optional. experiment label to apply to the |
| queries. See |
| https://cloud.google.com/vertex-ai/docs/experiments/intro-vertex-ai-experiments |
| for more information. |
| network: optional. the full name of the Compute Engine |
| network the endpoint is deployed on; used for private |
| endpoints. The network or subnetwork Dataflow pipeline |
| option must be set and match this network for pipeline |
| execution. |
| Ex: "projects/12345/global/networks/myVPC" |
| private: optional. if the deployed Vertex AI endpoint is |
| private, set to true. Requires a network to be provided |
| as well. |
| invoke_route: optional. the custom route path to use when invoking |
| endpoints with arbitrary prediction routes. When specified, uses |
| `Endpoint.invoke()` instead of `Endpoint.predict()`. The route |
| should start with a forward slash, e.g., "/predict/v1". |
| See https://cloud.google.com/vertex-ai/docs/predictions/use-arbitrary-custom-routes |
| for more information. |
| min_batch_size: optional. the minimum batch size to use when batching |
| inputs. |
| max_batch_size: optional. the maximum batch size to use when batching |
| inputs. |
| max_batch_duration_secs: optional. the maximum amount of time to buffer |
| a batch before emitting; used in streaming contexts. |
| max_batch_weight: optional. the maximum total weight of a batch. |
| element_size_fn: optional. a function that returns the size (weight) |
| of an element. |
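| |
| Example usage, a minimal sketch (assumes `import apache_beam as beam` and |
| `from apache_beam.ml.inference.base import RunInference`; the endpoint ID, |
| project, location, and `examples` input below are placeholder values):: |
| |
|   model_handler = VertexAIModelHandlerJSON( |
|       endpoint_id='1234567890', |
|       project='my-project', |
|       location='us-central1') |
|   with beam.Pipeline() as p: |
|     _ = ( |
|         p |
|         | beam.Create(examples) |
|         | RunInference(model_handler)) |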
| """ |
| self._batching_kwargs = {} |
| self._env_vars = kwargs.get('env_vars', {}) |
| self._invoke_route = invoke_route |
| if min_batch_size is not None: |
| self._batching_kwargs["min_batch_size"] = min_batch_size |
| if max_batch_size is not None: |
| self._batching_kwargs["max_batch_size"] = max_batch_size |
| if max_batch_duration_secs is not None: |
| self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs |
| if max_batch_weight is not None: |
| self._batching_kwargs["max_batch_weight"] = max_batch_weight |
| if element_size_fn is not None: |
| self._batching_kwargs['element_size_fn'] = element_size_fn |
| |
| if private and network is None: |
| raise ValueError( |
| "A VPC network must be provided to use a private endpoint.") |
| |
| # TODO: support the full list of options for aiplatform.init() |
| # See https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform#google_cloud_aiplatform_init |
| aiplatform.init( |
| project=project, |
| location=location, |
| experiment=experiment, |
| network=network) |
| |
| # Check for liveness here but don't try to actually store the endpoint |
| # in the class yet |
| self.endpoint_name = endpoint_id |
| self.location = location |
| self.is_private = private |
| |
| _ = self._retrieve_endpoint( |
| self.endpoint_name, self.location, self.is_private) |
| |
| super().__init__( |
| namespace='VertexAIModelHandlerJSON', |
| retry_filter=_retry_on_appropriate_gcp_error, |
| **kwargs) |
| |
| def _retrieve_endpoint( |
| self, endpoint_id: str, location: str, |
| is_private: bool) -> aiplatform.Endpoint: |
| """Retrieves an AI Platform endpoint and queries it for liveness/deployed |
| models. |
| |
| Args: |
| endpoint_id: the numerical ID of the Vertex AI endpoint to retrieve. |
| location: the GCP location where the endpoint is deployed. |
| is_private: a boolean indicating if the Vertex AI endpoint is a private |
| endpoint. |
| Returns: |
| An aiplatform.Endpoint object |
| Raises: |
| ValueError: if endpoint is inactive or has no models deployed to it. |
| """ |
| if is_private: |
| endpoint: aiplatform.Endpoint = aiplatform.PrivateEndpoint( |
| endpoint_name=endpoint_id, location=location) |
| LOGGER.debug("Treating endpoint %s as private", endpoint_id) |
| else: |
| endpoint = aiplatform.Endpoint( |
| endpoint_name=endpoint_id, location=location) |
| LOGGER.debug("Treating endpoint %s as public", endpoint_id) |
| |
| try: |
| mod_list = endpoint.list_models() |
| except Exception as e: |
| raise ValueError( |
| f"Failed to contact endpoint {endpoint_id}, got exception: {e}") |
| |
| if len(mod_list) == 0: |
| raise ValueError("Endpoint %s has no models deployed to it.", endpoint_id) |
| |
| return endpoint |
| |
| def create_client(self) -> aiplatform.Endpoint: |
| """Loads the Endpoint object used to build and send prediction request to |
| Vertex AI. |
| """ |
| # Re-check that the endpoint is still active; it was last verified at |
| # pipeline construction time. |
| ep = self._retrieve_endpoint( |
| self.endpoint_name, self.location, self.is_private) |
| return ep |
| |
| def request( |
| self, |
| batch: Sequence[Any], |
| model: aiplatform.Endpoint, |
| inference_args: Optional[dict[str, Any]] = None |
| ) -> Iterable[PredictionResult]: |
| """ Sends a prediction request to a Vertex AI endpoint containing batch |
| of inputs and matches that input with the prediction response from |
| the endpoint as an iterable of PredictionResults. |
| |
| Args: |
| batch: a sequence of any values to be passed to the Vertex AI endpoint. |
| Should be encoded as the model expects. |
| model: an aiplatform.Endpoint object configured to access the desired |
| model. |
| inference_args: any additional arguments to send as part of the |
| prediction request. |
| |
| Returns: |
| An iterable of PredictionResults. |
| """ |
| if self._invoke_route: |
| # Use invoke() for endpoints with custom prediction routes |
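| # The body below follows the standard Vertex AI prediction request |
| # schema, e.g. {"instances": [...], "parameters": {...}}, serialized |
| # as JSON. |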
| request_body: dict[str, Any] = {"instances": list(batch)} |
| if inference_args: |
| request_body["parameters"] = inference_args |
| response = model.invoke( |
| request_path=self._invoke_route, |
| body=json.dumps(request_body).encode("utf-8"), |
| headers={"Content-Type": "application/json"}) |
| if hasattr(response, "content"): |
| return self._parse_invoke_response(batch, response.content) |
| return self._parse_invoke_response(batch, bytes(response)) |
| else: |
| prediction = model.predict( |
| instances=list(batch), parameters=inference_args) |
| return utils._convert_to_result( |
| batch, prediction.predictions, prediction.deployed_model_id) |
| |
| def _parse_invoke_response(self, batch: Sequence[Any], |
| response: bytes) -> Iterable[PredictionResult]: |
| """Parses the response from Endpoint.invoke() into PredictionResults. |
| |
| Args: |
| batch: the original batch of inputs. |
| response: the raw bytes response from invoke(). |
| |
| Returns: |
| An iterable of PredictionResults. |
| """ |
| try: |
| response_json = json.loads(response.decode("utf-8")) |
| except (json.JSONDecodeError, UnicodeDecodeError) as e: |
| LOGGER.warning( |
| "Failed to decode invoke response as JSON, returning raw bytes: %s", |
| e) |
| # Return raw response for each batch item |
| return [ |
| PredictionResult(example=example, inference=response) |
| for example in batch |
| ] |
| |
| # Handle standard Vertex AI response format with "predictions" key |
| if isinstance(response_json, dict) and "predictions" in response_json: |
| predictions = response_json["predictions"] |
| model_id = response_json.get("deployedModelId") |
| return utils._convert_to_result(batch, predictions, model_id) |
| |
| # Handle response as a list of predictions (one per input) |
| if isinstance(response_json, list) and len(response_json) == len(batch): |
| return utils._convert_to_result(batch, response_json, None) |
| |
| # Handle single prediction response |
| if len(batch) == 1: |
| return [PredictionResult(example=batch[0], inference=response_json)] |
| |
| # Fallback: return the full response for each batch item |
| return [ |
| PredictionResult(example=example, inference=response_json) |
| for example in batch |
| ] |
| |
| def batch_elements_kwargs(self) -> Mapping[str, Any]: |
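| """Returns kwargs controlling how RunInference batches input elements.""" |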
| return self._batching_kwargs |