blob: 8aa638ebe7c2049d721a44f0fefe495215238ae4 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pytype: skip-file
import unittest
try:
from google.api_core.exceptions import TooManyRequests
from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON
from apache_beam.ml.inference.vertex_ai_inference import _retry_on_appropriate_gcp_error
except ImportError:
raise unittest.SkipTest('VertexAI dependencies are not installed')
class RetryOnClientErrorTest(unittest.TestCase):
  """Exercises the `_retry_on_appropriate_gcp_error` retry filter."""
  def test_retry_on_client_error_positive(self):
    """A `TooManyRequests` error must be reported as retryable."""
    rate_limit_error = TooManyRequests(message="fake service rate limiting")
    should_retry = _retry_on_appropriate_gcp_error(rate_limit_error)
    self.assertTrue(should_retry)

  def test_retry_on_client_error_negative(self):
    """A plain `ValueError` must not be reported as retryable."""
    unrelated_error = ValueError()
    should_retry = _retry_on_appropriate_gcp_error(unrelated_error)
    self.assertFalse(should_retry)
class ModelHandlerArgConditions(unittest.TestCase):
  """Validates constructor argument checks of `VertexAIModelHandlerJSON`."""
  def test_exception_on_private_without_network(self):
    """Constructing with `private=True` and no network raises ValueError."""
    handler_kwargs = {
        "endpoint_id": "1",
        "project": "testproject",
        "location": "us-central1",
        "private": True,
    }
    with self.assertRaises(ValueError):
      VertexAIModelHandlerJSON(**handler_kwargs)
class ParseInvokeResponseTest(unittest.TestCase):
  """Covers `VertexAIModelHandlerJSON._parse_invoke_response`."""
  def _create_handler_with_invoke_route(self, invoke_route="/test"):
    """Builds a handler configured with `invoke_route`, with patches applied.

    Args:
      invoke_route: value forwarded to the handler's `invoke_route` argument.

    Returns:
      The constructed `VertexAIModelHandlerJSON` instance.
    """
    from unittest import mock

    # Both `_retrieve_endpoint` and `aiplatform.init` are patched while the
    # handler is constructed so the test does not pollute global aiplatform
    # state.
    endpoint_patcher = mock.patch.object(
        VertexAIModelHandlerJSON, '_retrieve_endpoint', return_value=None)
    init_patcher = mock.patch('google.cloud.aiplatform.init')
    with endpoint_patcher, init_patcher:
      return VertexAIModelHandlerJSON(
          endpoint_id="1",
          project="testproject",
          location="us-central1",
          invoke_route=invoke_route)

  def test_parse_invoke_response_with_predictions_key(self):
    """A JSON object with a 'predictions' list yields one result per item."""
    model_handler = self._create_handler_with_invoke_route()
    examples = [{"input": "test1"}, {"input": "test2"}]
    raw_response = (
        b'{"predictions": ["result1", "result2"], '
        b'"deployedModelId": "model123"}')

    parsed = list(model_handler._parse_invoke_response(examples, raw_response))

    self.assertEqual(len(parsed), 2)
    first, second = parsed
    self.assertEqual(first.example, {"input": "test1"})
    self.assertEqual(first.inference, "result1")
    self.assertEqual(second.example, {"input": "test2"})
    self.assertEqual(second.inference, "result2")

  def test_parse_invoke_response_list_format(self):
    """A bare JSON list is treated as the list of predictions."""
    model_handler = self._create_handler_with_invoke_route()
    examples = [{"input": "test1"}, {"input": "test2"}]
    raw_response = b'["result1", "result2"]'

    parsed = list(model_handler._parse_invoke_response(examples, raw_response))

    self.assertEqual(len(parsed), 2)
    first, second = parsed
    self.assertEqual(first.inference, "result1")
    self.assertEqual(second.inference, "result2")

  def test_parse_invoke_response_single_prediction(self):
    """A JSON object without 'predictions' becomes the single inference."""
    model_handler = self._create_handler_with_invoke_route()
    examples = [{"input": "test1"}]
    raw_response = b'{"output": "single result"}'

    parsed = list(model_handler._parse_invoke_response(examples, raw_response))

    self.assertEqual(len(parsed), 1)
    self.assertEqual(parsed[0].inference, {"output": "single result"})

  def test_parse_invoke_response_non_json(self):
    """A payload that is not JSON is passed through as the raw bytes."""
    model_handler = self._create_handler_with_invoke_route()
    examples = [{"input": "test1"}]
    raw_response = b'not valid json'

    parsed = list(model_handler._parse_invoke_response(examples, raw_response))

    self.assertEqual(len(parsed), 1)
    self.assertEqual(parsed[0].inference, raw_response)
# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
  unittest.main()