| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| # pytype: skip-file |
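
"""Unit tests for apache_beam.ml.transforms.base."""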
| |
| import os |
| import secrets |
| import shutil |
| import tempfile |
| import time |
| import unittest |
| from collections.abc import Sequence |
| from dataclasses import dataclass |
| from typing import Any |
| from typing import Optional |
| |
| import numpy as np |
| from parameterized import param |
| from parameterized import parameterized |
| |
| import apache_beam as beam |
| from apache_beam.metrics.metric import MetricsFilter |
| from apache_beam.ml.inference.base import ModelHandler |
| from apache_beam.ml.inference.base import RunInference |
| from apache_beam.ml.transforms import base |
| from apache_beam.testing.util import assert_that |
| from apache_beam.testing.util import equal_to |
| |
| # pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports |
| try: |
| from apache_beam.ml.transforms import tft |
| from apache_beam.ml.transforms.handlers import TFTProcessHandler |
| from apache_beam.ml.transforms.tft import TFTOperation |
| except ImportError: |
| tft = None # type: ignore |
| |
| try: |
| import PIL |
| from PIL.Image import Image as PIL_Image |
| except ImportError: |
| PIL = None |
| PIL_Image = Any |
| |
# _FakeOperation subclasses TFTOperation, which is only defined when
# tensorflow_transform is installed, so guard the definition accordingly.
try:
| |
| class _FakeOperation(TFTOperation): |
| def __init__(self, name, *args, **kwargs): |
| super().__init__(*args, **kwargs) |
| self.name = name |
| |
| def apply_transform(self, inputs, output_column_name, **kwargs): |
| return {output_column_name: inputs} |
| except: # pylint: disable=bare-except |
| pass |
| |
| try: |
| from apache_beam.runners.dataflow.internal import apiclient |
| except ImportError: |
| apiclient = None # type: ignore |
| |
| |
| class BaseMLTransformTest(unittest.TestCase): |
| def setUp(self) -> None: |
| self.artifact_location = tempfile.mkdtemp() |
| |
| def tearDown(self): |
| shutil.rmtree(self.artifact_location) |
| |
  def test_ml_transform_no_read_or_write_artifact_location(self):
| with self.assertRaises(ValueError): |
| _ = base.MLTransform(transforms=[]) |
| |
| @unittest.skipIf(tft is None, 'tft module is not installed.') |
| def test_ml_transform_appends_transforms_to_process_handler_correctly(self): |
| fake_fn_1 = _FakeOperation(name='fake_fn_1', columns=['x']) |
| transforms = [fake_fn_1] |
| ml_transform = base.MLTransform( |
| transforms=transforms, write_artifact_location=self.artifact_location) |
| ml_transform = ml_transform.with_transform( |
| transform=_FakeOperation(name='fake_fn_2', columns=['x'])) |
| |
| self.assertEqual(len(ml_transform.transforms), 2) |
| self.assertEqual(ml_transform.transforms[0].name, 'fake_fn_1') |
| self.assertEqual(ml_transform.transforms[1].name, 'fake_fn_2') |
| |
| @unittest.skipIf(tft is None, 'tft module is not installed.') |
| def test_ml_transform_on_dict(self): |
| transforms = [tft.ScaleTo01(columns=['x'])] |
| data = [{'x': 1}, {'x': 2}] |
| with beam.Pipeline() as p: |
| result = ( |
| p |
| | beam.Create(data) |
| | base.MLTransform( |
| write_artifact_location=self.artifact_location, |
| transforms=transforms)) |
| expected_output = [ |
| np.array([0.0], dtype=np.float32), |
| np.array([1.0], dtype=np.float32), |
| ] |
| actual_output = result | beam.Map(lambda x: x.x) |
| assert_that( |
| actual_output, equal_to(expected_output, equals_fn=np.array_equal)) |
| |
| @unittest.skipIf(tft is None, 'tft module is not installed.') |
| def test_ml_transform_on_list_dict(self): |
| transforms = [tft.ScaleTo01(columns=['x'])] |
| data = [{'x': [1, 2, 3]}, {'x': [4, 5, 6]}] |
| with beam.Pipeline() as p: |
| result = ( |
| p |
| | beam.Create(data) |
| | base.MLTransform( |
| transforms=transforms, |
| write_artifact_location=self.artifact_location)) |
| expected_output = [ |
| np.array([0, 0.2, 0.4], dtype=np.float32), |
| np.array([0.6, 0.8, 1], dtype=np.float32), |
| ] |
| actual_output = result | beam.Map(lambda x: x.x) |
| assert_that( |
| actual_output, equal_to(expected_output, equals_fn=np.array_equal)) |
| |
| @parameterized.expand([ |
| param( |
| input_data=[{ |
| 'x': 1, |
| 'y': 2.0, |
| }], |
| input_types={ |
| 'x': int, 'y': float |
| }, |
| expected_dtype={ |
| 'x': Sequence[np.float32], |
| 'y': Sequence[np.float32], |
| }, |
| ), |
| param( |
| input_data=[{ |
| 'x': np.array([1], dtype=np.int64), |
| 'y': np.array([2.0], dtype=np.float32), |
| }], |
| input_types={ |
| 'x': np.int32, 'y': np.float32 |
| }, |
| expected_dtype={ |
| 'x': Sequence[np.float32], |
| 'y': Sequence[np.float32], |
| }, |
| ), |
| param( |
| input_data=[{ |
| 'x': [1, 2, 3], 'y': [2.0, 3.0, 4.0] |
| }], |
| input_types={ |
| 'x': list[int], 'y': list[float] |
| }, |
| expected_dtype={ |
| 'x': Sequence[np.float32], |
| 'y': Sequence[np.float32], |
| }, |
| ), |
| param( |
| input_data=[{ |
| 'x': [1, 2, 3], 'y': [2.0, 3.0, 4.0] |
| }], |
| input_types={ |
| 'x': Sequence[int], |
| 'y': Sequence[float], |
| }, |
| expected_dtype={ |
| 'x': Sequence[np.float32], |
| 'y': Sequence[np.float32], |
| }, |
| ), |
| ]) |
| @unittest.skipIf(tft is None, 'tft module is not installed.') |
| def test_ml_transform_dict_output_pcoll_schema( |
| self, input_data, input_types, expected_dtype): |
| transforms = [tft.ScaleTo01(columns=['x'])] |
| with beam.Pipeline() as p: |
| schema_data = ( |
| p |
| | beam.Create(input_data) |
| | beam.Map(lambda x: beam.Row(**x)).with_output_types( |
| beam.row_type.RowTypeConstraint.from_fields( |
| list(input_types.items())))) |
| transformed_data = schema_data | base.MLTransform( |
| write_artifact_location=self.artifact_location, transforms=transforms) |
| for name, typ in transformed_data.element_type._fields: |
| if name in expected_dtype: |
| self.assertEqual(expected_dtype[name], typ) |
| |
| @unittest.skipIf(tft is None, 'tft module is not installed.') |
| def test_ml_transform_fail_for_non_global_windows_in_produce_mode(self): |
| transforms = [tft.ScaleTo01(columns=['x'])] |
| with beam.Pipeline() as p: |
| with self.assertRaises(RuntimeError): |
| _ = ( |
| p |
| | beam.Create([{ |
| 'x': 1, 'y': 2.0 |
| }]) |
| | beam.WindowInto(beam.window.FixedWindows(1)) |
| | base.MLTransform( |
| transforms=transforms, |
| write_artifact_location=self.artifact_location, |
| )) |
| |
| @unittest.skipIf(tft is None, 'tft module is not installed.') |
| def test_ml_transform_on_multiple_columns_single_transform(self): |
| transforms = [tft.ScaleTo01(columns=['x', 'y'])] |
| data = [{'x': [1, 2, 3], 'y': [1.0, 10.0, 20.0]}] |
| with beam.Pipeline() as p: |
| result = ( |
| p |
| | beam.Create(data) |
| | base.MLTransform( |
| transforms=transforms, |
| write_artifact_location=self.artifact_location)) |
| expected_output_x = [ |
| np.array([0, 0.5, 1], dtype=np.float32), |
| ] |
| expected_output_y = [np.array([0, 0.47368422, 1], dtype=np.float32)] |
| actual_output_x = result | beam.Map(lambda x: x.x) |
| actual_output_y = result | beam.Map(lambda x: x.y) |
| assert_that( |
| actual_output_x, |
| equal_to(expected_output_x, equals_fn=np.array_equal)) |
| assert_that( |
| actual_output_y, |
| equal_to(expected_output_y, equals_fn=np.array_equal), |
| label='y') |
| |
| @unittest.skipIf(tft is None, 'tft module is not installed.') |
| def test_ml_transforms_on_multiple_columns_multiple_transforms(self): |
| transforms = [ |
| tft.ScaleTo01(columns=['x']), |
| tft.ComputeAndApplyVocabulary(columns=['y']) |
| ] |
| data = [{'x': [1, 2, 3], 'y': ['a', 'b', 'c']}] |
| with beam.Pipeline() as p: |
| result = ( |
| p |
| | beam.Create(data) |
| | base.MLTransform( |
| transforms=transforms, |
| write_artifact_location=self.artifact_location)) |
| expected_output_x = [ |
| np.array([0, 0.5, 1], dtype=np.float32), |
| ] |
| expected_output_y = [np.array([2, 1, 0])] |
| actual_output_x = result | beam.Map(lambda x: x.x) |
| actual_output_y = result | beam.Map(lambda x: x.y) |
| |
| assert_that( |
| actual_output_x, |
| equal_to(expected_output_x, equals_fn=np.array_equal)) |
| assert_that( |
| actual_output_y, |
| equal_to(expected_output_y, equals_fn=np.array_equal), |
| label='actual_output_y') |
| |
| @unittest.skipIf(tft is None, 'tft module is not installed.') |
| def test_mltransform_with_counter(self): |
| transforms = [ |
| tft.ComputeAndApplyVocabulary(columns=['y']), |
| tft.ScaleTo01(columns=['x']) |
| ] |
| data = [{'x': [1, 2, 3], 'y': ['a', 'b', 'c']}] |
| with beam.Pipeline() as p: |
| _ = ( |
| p | beam.Create(data) |
| | base.MLTransform( |
| transforms=transforms, |
| write_artifact_location=self.artifact_location)) |
| scale_to_01_counter = MetricsFilter().with_name('BeamML_ScaleTo01') |
| vocab_counter = MetricsFilter().with_name( |
| 'BeamML_ComputeAndApplyVocabulary') |
| mltransform_counter = MetricsFilter().with_name('BeamML_MLTransform') |
| result = p.result |
| self.assertEqual( |
| result.metrics().query(scale_to_01_counter)['counters'][0].result, 1) |
| self.assertEqual( |
| result.metrics().query(vocab_counter)['counters'][0].result, 1) |
| self.assertEqual( |
| result.metrics().query(mltransform_counter)['counters'][0].result, 1) |
| |
  def test_non_ptransform_provider_class_to_mltransform(self):
| class Add: |
| def __call__(self, x): |
| return x + 1 |
| |
| with self.assertRaisesRegex(TypeError, 'transform must be a subclass of'): |
| with beam.Pipeline() as p: |
| _ = ( |
| p |
| | beam.Create([{ |
| 'x': 1 |
| }]) |
| | base.MLTransform( |
| write_artifact_location=self.artifact_location).with_transform( |
| Add())) |
| |
| def test_read_mode_with_transforms(self): |
| with self.assertRaises(ValueError): |
| _ = base.MLTransform( |
| # fake callable |
| transforms=[lambda x: x], |
| read_artifact_location=self.artifact_location) |
| |
| |
| class FakeModel: |
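  """Fake model that reverses each input string and rejects non-strings."""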
| def __call__(self, example: list[str]) -> list[str]: |
| for i in range(len(example)): |
| if not isinstance(example[i], str): |
| raise TypeError('Input must be a string') |
| example[i] = example[i][::-1] |
| return example |
| |
| |
| class FakeModelHandler(ModelHandler): |
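  """Minimal ModelHandler that runs FakeModel on batches of strings."""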
| def run_inference( |
| self, |
| batch: Sequence[str], |
| model: Any, |
| inference_args: Optional[dict[str, Any]] = None): |
| return model(batch) |
| |
| def load_model(self): |
| return FakeModel() |
| |
| |
| class FakeEmbeddingsManager(base.EmbeddingsManager): |
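  """EmbeddingsManager stub that routes text columns through FakeModel."""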
| def __init__(self, columns, **kwargs): |
| super().__init__(columns=columns, **kwargs) |
| |
| def get_model_handler(self) -> ModelHandler: |
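    # Patching the handler's repr lets tests that inspect transform names
    # (e.g. the DLQ tests below) observe this manager's name.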
| FakeModelHandler.__repr__ = lambda x: 'FakeEmbeddingsManager' # type: ignore[method-assign] |
| return FakeModelHandler() |
| |
| def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: |
    return RunInference(model_handler=base._TextEmbeddingHandler(self))
| |
| def __repr__(self): |
| return 'FakeEmbeddingsManager' |
| |
| |
| class InvalidEmbeddingsManager(base.EmbeddingsManager): |
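  """EmbeddingsManager stub built without columns to exercise validation."""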
| def __init__(self, **kwargs): |
| super().__init__(**kwargs) |
| |
| def get_model_handler(self) -> ModelHandler: |
| InvalidEmbeddingsManager.__repr__ = lambda x: 'InvalidEmbeddingsManager' # type: ignore[method-assign] |
| return FakeModelHandler() |
| |
| def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: |
    return RunInference(model_handler=base._TextEmbeddingHandler(self))
| |
| def __repr__(self): |
| return 'InvalidEmbeddingsManager' |
| |
| |
| class TextEmbeddingHandlerTest(unittest.TestCase): |
| def setUp(self) -> None: |
    self.embedding_config = FakeEmbeddingsManager(columns=['x'])
| self.artifact_location = tempfile.mkdtemp() |
| |
| def tearDown(self) -> None: |
| shutil.rmtree(self.artifact_location) |
| |
| def test_no_columns_or_type_adapter(self): |
| with self.assertRaises(ValueError): |
| _ = InvalidEmbeddingsManager() |
| |
| def test_handler_with_incompatible_datatype(self): |
| text_handler = base._TextEmbeddingHandler( |
        embeddings_manager=self.embedding_config)
| data = [ |
| ('x', 1), |
| ('x', 2), |
| ('x', 3), |
| ] |
| with self.assertRaises(TypeError): |
| text_handler.run_inference(data, None, None) |
| |
| def test_handler_with_dict_inputs(self): |
| data = [ |
| { |
| 'x': "Hello world" |
| }, |
| { |
| 'x': "Apache Beam" |
| }, |
| ] |
| expected_data = [{ |
| key: value[::-1] |
| for key, value in d.items() |
| } for d in data] |
| with beam.Pipeline() as p: |
| result = ( |
| p |
| | beam.Create(data) |
| | base.MLTransform( |
| write_artifact_location=self.artifact_location).with_transform( |
                  self.embedding_config))
| assert_that( |
| result, |
| equal_to(expected_data), |
| ) |
| |
| def test_handler_with_batch_sizes(self): |
    self.embedding_config.max_batch_size = 100
    self.embedding_config.min_batch_size = 10
| data = [ |
| { |
| 'x': "Hello world" |
| }, |
| { |
| 'x': "Apache Beam" |
| }, |
| ] * 100 |
| expected_data = [{ |
| key: value[::-1] |
| for key, value in d.items() |
| } for d in data] |
| with beam.Pipeline() as p: |
| result = ( |
| p |
| | beam.Create(data) |
| | base.MLTransform( |
| write_artifact_location=self.artifact_location).with_transform( |
                  self.embedding_config))
| assert_that( |
| result, |
| equal_to(expected_data), |
| ) |
| |
| def test_handler_on_multiple_columns(self): |
| data = [ |
| { |
| 'x': "Hello world", 'y': "Apache Beam", 'z': 'unchanged' |
| }, |
| { |
| 'x': "Apache Beam", 'y': "Hello world", 'z': 'unchanged' |
| }, |
| ] |
| embedding_config = FakeEmbeddingsManager(columns=['x', 'y']) |
| expected_data = [{ |
| key: (value[::-1] if key in embedding_config.columns else value) |
| for key, value in d.items() |
| } for d in data] |
| with beam.Pipeline() as p: |
| result = ( |
| p |
| | beam.Create(data) |
| | base.MLTransform(write_artifact_location=self.artifact_location). |
| with_transform(embedding_config)) |
| assert_that( |
| result, |
| equal_to(expected_data), |
| ) |
| |
| def test_handler_on_columns_not_exist_in_input_data(self): |
| data = [ |
| { |
| 'x': "Hello world", 'y': "Apache Beam" |
| }, |
| { |
| 'x': "Apache Beam", 'y': "Hello world" |
| }, |
| ] |
| embedding_config = FakeEmbeddingsManager(columns=['x', 'y', 'a']) |
| |
| with self.assertRaises(RuntimeError): |
| with beam.Pipeline() as p: |
| _ = ( |
| p |
| | beam.Create(data) |
| | base.MLTransform(write_artifact_location=self.artifact_location). |
| with_transform(embedding_config)) |
| |
| def test_handler_with_list_data(self): |
| data = [{ |
| 'x': ['Hello world', 'Apache Beam'], |
| }, { |
| 'x': ['Apache Beam', 'Hello world'], |
| }] |
| with self.assertRaisesRegex(Exception, "Embeddings can only be generated"): |
| with beam.Pipeline() as p: |
| _ = ( |
| p |
| | beam.Create(data) |
| | base.MLTransform( |
| write_artifact_location=self.artifact_location).with_transform( |
                    self.embedding_config))
| |
| def test_handler_with_inconsistent_keys(self): |
| data = [ |
| { |
| 'x': 'foo', 'y': 'bar', 'z': 'baz' |
| }, |
| { |
| 'x': 'foo2', 'y': 'bar2' |
| }, |
| { |
| 'x': 'foo3', 'y': 'bar3', 'z': 'baz3' |
| }, |
| ] |
    self.embedding_config.min_batch_size = 2
| with self.assertRaises(RuntimeError): |
| with beam.Pipeline() as p: |
| _ = ( |
| p |
| | beam.Create(data) |
| | base.MLTransform( |
| write_artifact_location=self.artifact_location).with_transform( |
                    self.embedding_config))
| |
| |
| class FakeImageModel: |
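  """Fake model that checks inputs are PIL images and echoes them back."""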
| def __call__(self, example: list[PIL_Image]) -> list[PIL_Image]: |
| for i in range(len(example)): |
| if not isinstance(example[i], PIL_Image): |
| raise TypeError('Input must be an Image') |
| return example |
| |
| |
| class FakeImageModelHandler(ModelHandler): |
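  """Minimal ModelHandler that runs FakeImageModel on batches of images."""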
| def run_inference( |
| self, |
| batch: Sequence[PIL_Image], |
| model: Any, |
| inference_args: Optional[dict[str, Any]] = None): |
| return model(batch) |
| |
| def load_model(self): |
| return FakeImageModel() |
| |
| |
| class FakeImageEmbeddingsManager(base.EmbeddingsManager): |
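  """EmbeddingsManager stub wired to the fake image model handler."""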
| def __init__(self, columns, **kwargs): |
| super().__init__(columns=columns, **kwargs) |
| |
| def get_model_handler(self) -> ModelHandler: |
    FakeImageModelHandler.__repr__ = lambda x: 'FakeImageEmbeddingsManager'  # type: ignore[method-assign]
| return FakeImageModelHandler() |
| |
| def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: |
    return RunInference(model_handler=base._ImageEmbeddingHandler(self))
| |
| def __repr__(self): |
| return 'FakeImageEmbeddingsManager' |
| |
| |
| class TestImageEmbeddingHandler(unittest.TestCase): |
| def setUp(self) -> None: |
| self.embedding_config = FakeImageEmbeddingsManager(columns=['x']) |
| self.artifact_location = tempfile.mkdtemp() |
| |
| def tearDown(self) -> None: |
| shutil.rmtree(self.artifact_location) |
| |
| @unittest.skipIf(PIL is None, 'PIL module is not installed.') |
| def test_handler_with_non_dict_datatype(self): |
| image_handler = base._ImageEmbeddingHandler( |
| embeddings_manager=self.embedding_config) |
| data = [ |
| ('x', 'hi there'), |
| ('x', 'not an image'), |
| ('x', 'image_path.jpg'), |
| ] |
| with self.assertRaises(TypeError): |
| image_handler.run_inference(data, None, None) |
| |
| @unittest.skipIf(PIL is None, 'PIL module is not installed.') |
| def test_handler_with_non_image_datatype(self): |
| image_handler = base._ImageEmbeddingHandler( |
| embeddings_manager=self.embedding_config) |
| data = [ |
| { |
| 'x': 'hi there' |
| }, |
| { |
| 'x': 'not an image' |
| }, |
| { |
| 'x': 'image_path.jpg' |
| }, |
| ] |
| with self.assertRaises(TypeError): |
| image_handler.run_inference(data, None, None) |
| |
| @unittest.skipIf(PIL is None, 'PIL module is not installed.') |
| def test_handler_with_dict_inputs(self): |
| img_one = PIL.Image.new(mode='RGB', size=(1, 1)) |
| img_two = PIL.Image.new(mode='RGB', size=(1, 1)) |
| data = [ |
| { |
| 'x': img_one |
| }, |
| { |
| 'x': img_two |
| }, |
| ] |
| expected_data = [{key: value for key, value in d.items()} for d in data] |
| with beam.Pipeline() as p: |
| result = ( |
| p |
| | beam.Create(data) |
| | base.MLTransform( |
| write_artifact_location=self.artifact_location).with_transform( |
| self.embedding_config)) |
| assert_that( |
| result, |
| equal_to(expected_data), |
| ) |
| |
| |
| @dataclass |
| class FakeMultiModalInput: |
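  """Holds the optional image, video, and text fields of one example."""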
| image: Optional[PIL_Image] = None |
| video: Optional[Any] = None |
| text: Optional[str] = None |
| |
| |
| class FakeMultiModalModel: |
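  """Fake model that validates and echoes FakeMultiModalInput examples."""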
| def __call__(self, |
| example: list[FakeMultiModalInput]) -> list[FakeMultiModalInput]: |
| for i in range(len(example)): |
| if not isinstance(example[i], FakeMultiModalInput): |
| raise TypeError('Input must be a MultiModalInput') |
| return example |
| |
| |
| class FakeMultiModalModelHandler(ModelHandler): |
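  """Minimal ModelHandler that runs FakeMultiModalModel on batches."""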
| def run_inference( |
| self, |
| batch: Sequence[FakeMultiModalInput], |
| model: Any, |
| inference_args: Optional[dict[str, Any]] = None): |
| return model(batch) |
| |
| def load_model(self): |
| return FakeMultiModalModel() |
| |
| |
| class FakeMultiModalEmbeddingsManager(base.EmbeddingsManager): |
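  """EmbeddingsManager stub wired to the fake multimodal model handler."""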
| def __init__(self, columns, **kwargs): |
| super().__init__(columns=columns, **kwargs) |
| |
| def get_model_handler(self) -> ModelHandler: |
    FakeMultiModalModelHandler.__repr__ = lambda x: 'FakeMultiModalEmbeddingsManager'  # type: ignore[method-assign]
| return FakeMultiModalModelHandler() |
| |
| def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: |
    return RunInference(model_handler=base._MultiModalEmbeddingHandler(self))
| |
| def __repr__(self): |
| return 'FakeMultiModalEmbeddingsManager' |
| |
| |
| class TestMultiModalEmbeddingHandler(unittest.TestCase): |
| def setUp(self) -> None: |
| self.embedding_config = FakeMultiModalEmbeddingsManager(columns=['x']) |
| self.artifact_location = tempfile.mkdtemp() |
| |
| def tearDown(self) -> None: |
| shutil.rmtree(self.artifact_location) |
| |
| @unittest.skipIf(PIL is None, 'PIL module is not installed.') |
| def test_handler_with_non_dict_datatype(self): |
| image_handler = base._MultiModalEmbeddingHandler( |
| embeddings_manager=self.embedding_config) |
| data = [ |
| ('x', 'hi there'), |
| ('x', 'not an image'), |
| ('x', 'image_path.jpg'), |
| ] |
| with self.assertRaises(TypeError): |
| image_handler.run_inference(data, None, None) |
| |
| @unittest.skipIf(PIL is None, 'PIL module is not installed.') |
| def test_handler_with_incorrect_datatype(self): |
| image_handler = base._MultiModalEmbeddingHandler( |
| embeddings_manager=self.embedding_config) |
| data = [ |
| { |
| 'x': 'hi there' |
| }, |
| { |
| 'x': 'not an image' |
| }, |
| { |
| 'x': 'image_path.jpg' |
| }, |
| ] |
| with self.assertRaises(TypeError): |
| image_handler.run_inference(data, None, None) |
| |
| @unittest.skipIf(PIL is None, 'PIL module is not installed.') |
| def test_handler_with_dict_inputs(self): |
| input_one = FakeMultiModalInput( |
| image=PIL.Image.new(mode='RGB', size=(1, 1)), text="test image one") |
| input_two = FakeMultiModalInput( |
| image=PIL.Image.new(mode='RGB', size=(1, 1)), text="test image two") |
| input_three = FakeMultiModalInput( |
| image=PIL.Image.new(mode='RGB', size=(1, 1)), |
| video=bytes.fromhex('2Ef0 F1f2 '), |
| text="test image three with video") |
| data = [ |
| { |
| 'x': input_one |
| }, |
| { |
| 'x': input_two |
| }, |
| { |
| 'x': input_three |
| }, |
| ] |
| expected_data = [{key: value for key, value in d.items()} for d in data] |
| with beam.Pipeline() as p: |
| result = ( |
| p |
| | beam.Create(data) |
| | base.MLTransform( |
| write_artifact_location=self.artifact_location).with_transform( |
| self.embedding_config)) |
| assert_that( |
| result, |
| equal_to(expected_data), |
| ) |
| |
| |
| class TestUtilFunctions(unittest.TestCase): |
| def test_dict_input_fn_normal(self): |
| input_list = [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}] |
| columns = ['a', 'b'] |
| |
| expected_output = [1, 2, 3, 4] |
| self.assertEqual(base._dict_input_fn(columns, input_list), expected_output) |
| |
| def test_dict_output_fn_normal(self): |
| input_list = [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}] |
| columns = ['a', 'b'] |
| embeddings = [1.1, 2.2, 3.3, 4.4] |
| |
| expected_output = [{'a': 1.1, 'b': 2.2}, {'a': 3.3, 'b': 4.4}] |
| self.assertEqual( |
| base._dict_output_fn(columns, input_list, embeddings), expected_output) |
| |
| def test_dict_input_fn_on_list_inputs(self): |
| input_list = [{'a': [1, 2, 10], 'b': 3}, {'a': [1], 'b': 5}] |
| columns = ['a', 'b'] |
| |
| expected_output = [[1, 2, 10], 3, [1], 5] |
| self.assertEqual(base._dict_input_fn(columns, input_list), expected_output) |
| |
| def test_dict_output_fn_on_list_inputs(self): |
| input_list = [{'a': [1, 2, 10], 'b': 3}, {'a': [1], 'b': 5}] |
| columns = ['a', 'b'] |
| embeddings = [1.1, 2.2, 3.3, 4.4] |
| |
| expected_output = [{'a': 1.1, 'b': 2.2}, {'a': 3.3, 'b': 4.4}] |
| self.assertEqual( |
| base._dict_output_fn(columns, input_list, embeddings), expected_output) |
| |
| |
| class TestJsonPickleTransformAttributeManager(unittest.TestCase): |
| def setUp(self): |
| self.attribute_manager = base._transform_attribute_manager |
| self.artifact_location = tempfile.mkdtemp() |
| |
| def tearDown(self): |
| shutil.rmtree(self.artifact_location) |
| |
| @unittest.skipIf(tft is None, 'tft module is not installed.') |
| def test_save_tft_process_handler(self): |
| transforms = [ |
| tft.ScaleTo01(columns=['x']), |
| tft.ComputeAndApplyVocabulary(columns=['y']) |
| ] |
| process_handler = TFTProcessHandler( |
| transforms=transforms, |
| artifact_location=self.artifact_location, |
| ) |
| self.attribute_manager.save_attributes( |
| ptransform_list=[process_handler], |
| artifact_location=self.artifact_location, |
| ) |
| |
| files = os.listdir(self.artifact_location) |
    self.assertEqual(len(files), 1)
    self.assertEqual(files[0], base._ATTRIBUTE_FILE_NAME)
| |
| def test_save_run_inference(self): |
| self.attribute_manager.save_attributes( |
| ptransform_list=[RunInference(model_handler=FakeModelHandler())], |
| artifact_location=self.artifact_location, |
| ) |
| files = os.listdir(self.artifact_location) |
    self.assertEqual(len(files), 1)
    self.assertEqual(files[0], base._ATTRIBUTE_FILE_NAME)
| |
| def test_save_and_load_run_inference(self): |
| ptransform_list = [RunInference(model_handler=FakeModelHandler())] |
| self.attribute_manager.save_attributes( |
| ptransform_list=ptransform_list, |
| artifact_location=self.artifact_location, |
| ) |
| loaded_ptransform_list = self.attribute_manager.load_attributes( |
| artifact_location=self.artifact_location, |
| ) |
| |
    self.assertEqual(len(loaded_ptransform_list), len(ptransform_list))
| self.assertListEqual( |
| list(loaded_ptransform_list[0].__dict__.keys()), |
| list(ptransform_list[0].__dict__.keys())) |
| |
| get_keys = lambda x: list(x.__dict__.keys()) |
| for i, transform in enumerate(ptransform_list): |
| self.assertListEqual( |
| get_keys(transform), get_keys(loaded_ptransform_list[i])) |
| if hasattr(transform, 'model_handler'): |
| model_handler = transform.model_handler |
| loaded_model_handler = loaded_ptransform_list[i].model_handler |
| self.assertListEqual( |
| get_keys(model_handler), get_keys(loaded_model_handler)) |
| |
| def test_mltransform_to_ptransform_wrapper(self): |
| transforms = [ |
| FakeEmbeddingsManager(columns=['x']), |
| FakeEmbeddingsManager(columns=['y', 'z']), |
| ] |
| ptransform_mapper = base._MLTransformToPTransformMapper( |
| transforms=transforms, |
| artifact_location=self.artifact_location, |
| artifact_mode=None) |
| |
| ptransform_list = ptransform_mapper.create_ptransform_list() |
    self.assertEqual(len(ptransform_list), 2)
| |
| self.assertEqual(type(ptransform_list[0]), RunInference) |
| expected_columns = [['x'], ['y', 'z']] |
| for i in range(len(ptransform_list)): |
| self.assertEqual(type(ptransform_list[i]), RunInference) |
| self.assertEqual( |
| type(ptransform_list[i]._model_handler), base._TextEmbeddingHandler) |
| self.assertEqual( |
| ptransform_list[i]._model_handler.columns, expected_columns[i]) |
| |
| @unittest.skipIf(apiclient is None, 'apache_beam[gcp] is not installed.') |
| def test_with_gcs_location_with_none_options(self): |
| path = f"gs://fake_path_{secrets.token_hex(3)}_{int(time.time())}" |
| with self.assertRaises(RuntimeError): |
| self.attribute_manager.save_attributes( |
| ptransform_list=[], artifact_location=path, options=None) |
| with self.assertRaises(RuntimeError): |
| self.attribute_manager.save_attributes( |
| ptransform_list=[], artifact_location=path) |
| |
| def test_with_same_local_artifact_location(self): |
| artifact_location = self.artifact_location |
| attribute_manager = base._JsonPickleTransformAttributeManager() |
| |
| ptransform_list = [RunInference(model_handler=FakeModelHandler())] |
| |
| attribute_manager.save_attributes( |
| ptransform_list, artifact_location=artifact_location) |
| |
| with self.assertRaises(FileExistsError): |
| attribute_manager.save_attributes([lambda x: x], |
| artifact_location=artifact_location) |
| |
| |
| class MLTransformDLQTest(unittest.TestCase): |
| def setUp(self) -> None: |
| self.artifact_location = tempfile.mkdtemp() |
| |
| def tearDown(self): |
| shutil.rmtree(self.artifact_location) |
| |
| def test_dlq_with_embeddings(self): |
| with beam.Pipeline() as p: |
| good, bad = ( |
| p |
| | beam.Create([{ |
| 'x': 1 |
| }, |
| { |
| 'x': 3, |
| }, |
| { |
| 'x': 'Hello' |
| } |
| ], |
| ) |
| | base.MLTransform( |
| write_artifact_location=self.artifact_location).with_transform( |
| FakeEmbeddingsManager( |
| columns=['x'])).with_exception_handling()) |
| |
| good_expected_elements = [{'x': 'olleH'}] |
| |
| assert_that( |
| good, |
| equal_to(good_expected_elements), |
| label='good', |
| ) |
| |
    # Batching happens in RunInference, so elements in the bad PCollection
    # arrive grouped in lists.
| bad_expected_elements = [[{'x': 1}], [{'x': 3}]] |
| |
| assert_that( |
| bad | beam.Map(lambda x: x.element), |
| equal_to(bad_expected_elements), |
| label='bad', |
| ) |
| |
  def test_mltransform_with_dlq_and_extract_transform_name(self):
| with beam.Pipeline() as p: |
| good, bad = ( |
| p |
| | beam.Create([{ |
| 'x': 1 |
| }, |
| { |
| 'x': 3, |
| }, |
| { |
| 'x': 'Hello' |
| } |
| ], |
| ) |
| | base.MLTransform( |
| write_artifact_location=self.artifact_location).with_transform( |
| FakeEmbeddingsManager( |
| columns=['x'])).with_exception_handling()) |
| |
| good_expected_elements = [{'x': 'olleH'}] |
| assert_that( |
| good, |
| equal_to(good_expected_elements), |
| label='good', |
| ) |
| |
| bad_expected_transform_name = [ |
| 'FakeEmbeddingsManager', 'FakeEmbeddingsManager' |
| ] |
| |
| assert_that( |
| bad | beam.Map(lambda x: x.transform_name), |
| equal_to(bad_expected_transform_name), |
| ) |
| |
| |
| if __name__ == '__main__': |
| unittest.main() |