Merge pull request #2 from apache/users/damccorm/runinference_example
diff --git a/README.md b/README.md
index ace1d79..3e4e791 100644
--- a/README.md
+++ b/README.md
@@ -33,7 +33,7 @@
 
 # Activate the virtual environment managed by Poetry
 # Alternatively, you can prefix commands with `poetry run`
-poetry shell
+eval $(poetry env activate)
 ```
 
 ## Overview
diff --git a/examples/provider_listing.yaml b/examples/provider_listing.yaml
index 2ccfbc9..e9a69e7 100644
--- a/examples/provider_listing.yaml
+++ b/examples/provider_listing.yaml
@@ -11,3 +11,4 @@
   FromRoman: "my_provider.FromRomanNumerals"
   Stringify: "my_provider.StringifyRow"
   MultiInputMultiOutput: "my_provider.MultiInputMultiOutput"
+  RunHuggingFaceInference: "my_provider.RunHuggingFaceInference"
diff --git a/examples/run_inference.yaml b/examples/run_inference.yaml
new file mode 100644
index 0000000..87740b4
--- /dev/null
+++ b/examples/run_inference.yaml
@@ -0,0 +1,27 @@
+
+pipeline:
+  type: composite
+  transforms:
+
+    # Create some input.
+    - type: Create
+      name: Prompts
+      config:
+        elements:
+          - {example: "translate English to Spanish: How are you doing?"}
+          - {example: "translate English to Spanish: This is the Apache Beam project."}
+
+    - type: chain
+      input: Prompts
+      transforms:
+        - type: RunHuggingFaceInference
+          config:
+            task: 'translation_XX_to_YY'
+            model: 'google/flan-t5-small'
+            load_pipeline_args: '{"framework": "pt"}'
+            inference_args: '{"max_length": 200}'
+
+        - type: LogForTesting
+
+providers:
+  - include: "./provider_listing.yaml"
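Note: a minimal sketch of exercising this example locally, assuming Beam's YAML entry point `apache_beam.yaml.main` (present in recent Beam releases) and a working directory at the repo root:

```python
# Sketch: run the new YAML pipeline programmatically; intended to be
# equivalent to
# `python -m apache_beam.yaml.main --yaml_pipeline_file=examples/run_inference.yaml`.
from apache_beam.yaml import main

main.run(argv=["--yaml_pipeline_file=examples/run_inference.yaml"])
```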
diff --git a/my_provider.py b/my_provider.py
index 05f4f89..067bcdf 100644
--- a/my_provider.py
+++ b/my_provider.py
@@ -1,8 +1,11 @@
 import contextlib
+import json
 from typing import Optional
 import apache_beam as beam
 import apache_beam.transforms.error_handling
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.huggingface_inference import HuggingFacePipelineModelHandler
 
 try:
   from apache_beam.yaml.yaml_errors import map_errors_to_standard_format
 except ImportError:
@@ -140,3 +143,45 @@
           state=result.state, temperature=result.temperature)))
   # As will output pcolls.
   return dict(avg_per_city=avg_per_city, max_per_state=max_per_state)
+
+
+class RunHuggingFaceInference(beam.PTransform):
+  # An example using Beam's RunInference transform with Hugging Face.
+  # It lets YAML pipelines run the same inference as the notebook at:
+  # https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_huggingface.ipynb
+  def __init__(
+      self,
+      task: str,
+      model: str,
+      load_pipeline_args: Optional[str] = None,
+      inference_args: Optional[str] = None):
+    """
+    Runs Hugging Face pipeline inference on the `example` field of each row.
+
+    Args:
+      task: A task supported by Hugging Face pipelines. Accepts any string task supported by Hugging Face; see the full list at https://github.com/apache/beam/blob/6d0e00ea617f2c5eeb354e2b3a304445afeec669/sdks/python/apache_beam/ml/inference/huggingface_inference.py#L75
+      model: Path to a pretrained model, or a model-id on the Hugging Face Models Hub, to use a custom model for the chosen task.
+      load_pipeline_args: JSON-encoded keyword arguments providing load options when loading pipelines from Hugging Face. Defaults to None.
+      inference_args: JSON-encoded non-batchable arguments required as inputs to the model's inference function. Defaults to None.
+    """
+    self.task = task
+    self.model = model
+    self.load_pipeline_args = {}
+    if load_pipeline_args is not None:
+      self.load_pipeline_args = json.loads(load_pipeline_args)
+    self.inference_args = {}
+    if inference_args is not None:
+      self.inference_args = json.loads(inference_args)
+
+  def expand(self, pcoll):
+    model_handler = HuggingFacePipelineModelHandler(
+        task=self.task,
+        model=self.model,
+        load_pipeline_args=self.load_pipeline_args,
+        inference_args=self.inference_args)
+    return (
+        pcoll
+        | beam.Map(lambda row: row.example)
+        | RunInference(model_handler)
+        | beam.Map(lambda result: beam.Row(
+            example=result.example, inference=str(result.inference))))
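For context, a sketch of using the new transform from a plain Python pipeline rather than YAML (it mirrors the config in examples/run_inference.yaml; the model is downloaded from the Hugging Face Hub on first use):

```python
import apache_beam as beam

import my_provider

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create(
          [beam.Row(example="translate English to Spanish: How are you doing?")])
      | my_provider.RunHuggingFaceInference(
          task='translation_XX_to_YY',
          model='google/flan-t5-small',
          load_pipeline_args='{"framework": "pt"}',
          inference_args='{"max_length": 200}')
      # Each output is a Row with `example` and `inference` fields.
      | beam.Map(print))
```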
diff --git a/my_provider_test.py b/my_provider_test.py
index 88f7034..87bc954 100644
--- a/my_provider_test.py
+++ b/my_provider_test.py
@@ -104,3 +104,17 @@
             beam.Row(state='WA', temperature=52),
             beam.Row(state='NY', temperature=43),
         ]))
+
+  def test_run_hugging_face_inference(self):
+    # Just verify that inference runs, since its output is not deterministic.
+    with beam.Pipeline() as p:
+      pcoll = p | beam.Create([
+          beam.Row(example="translate English to Spanish: How are you doing?"),
+          beam.Row(example="translate English to Spanish: This is the Apache Beam project."),
+      ])
+      result = pcoll | my_provider.RunHuggingFaceInference(
+          task='translation_XX_to_YY',
+          model='google/flan-t5-small',
+          load_pipeline_args='{"framework": "pt"}',
+          inference_args='{"max_length": 200}'
+      )
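The test deliberately asserts nothing about translation content. If stronger checks are wanted later, one option is asserting on the output shape with Beam's testing utilities; a sketch (the `_check_shape` helper below is hypothetical, not part of this PR):

```python
from apache_beam.testing.util import assert_that

def _check_shape(rows):
  # Expect one output row per input, each carrying a non-empty string inference.
  rows = list(rows)
  assert len(rows) == 2, rows
  assert all(isinstance(row.inference, str) and row.inference for row in rows), rows

# Inside the test body, after computing `result`:
# assert_that(result, _check_shape)
```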
diff --git a/pyproject.toml b/pyproject.toml
index 999c9b4..11ddc7b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,6 +14,9 @@
 python = "^3.9"
 apache-beam = {extras = ["gcp", "yaml"], version = "^2.61.0"}
 roman = "^5.0"
+torch = "^2.6"
+tensorflow = "^2.18"
+transformers = "4.44.2"
 
 [tool.poetry.group.dev.dependencies]
 pytest = "^8.3.4"