Merge pull request #2 from apache/users/damccorm/runinference_example

diff --git a/README.md b/README.md
index ace1d79..3e4e791 100644
--- a/README.md
+++ b/README.md
@@ -33,7 +33,7 @@
 
 # Activate the virtual environment managed by Poetry
 # Alternatively, you can prefix commands with `poetry run`
-poetry shell
+eval $(poetry env activate)
 ```
 
 ## Overview
diff --git a/examples/provider_listing.yaml b/examples/provider_listing.yaml
index 2ccfbc9..e9a69e7 100644
--- a/examples/provider_listing.yaml
+++ b/examples/provider_listing.yaml
@@ -11,3 +11,4 @@
     FromRoman: "my_provider.FromRomanNumerals"
     Stringify: "my_provider.StringifyRow"
     MultiInputMultiOutput: "my_provider.MultiInputMultiOutput"
+    RunHuggingFaceInference: "my_provider.RunHuggingFaceInference"
diff --git a/examples/run_inference.yaml b/examples/run_inference.yaml
new file mode 100644
index 0000000..87740b4
--- /dev/null
+++ b/examples/run_inference.yaml
@@ -0,0 +1,27 @@
+
+pipeline:
+  type: composite
+  transforms:
+
+    # Create the input prompts; each element is a row with a single `example` field.
+    - type: Create
+      name: Prompts
+      config:
+        elements:
+          - {example: "translate English to Spanish: How are you doing?"}
+          - {example: "translate English to Spanish: This is the Apache Beam project."}
+
+    - type: chain
+      input: Prompts
+      transforms:
+        - type: RunHuggingFaceInference
+          config:
+            task: 'translation_XX_to_YY'
+            model: 'google/flan-t5-small'
+            load_pipeline_args: '{"framework": "pt"}'
+            inference_args: '{"max_length": 200}'
+
+        - type: LogForTesting
+
+providers:
+  - include: "./provider_listing.yaml"
diff --git a/my_provider.py b/my_provider.py
index 05f4f89..067bcdf 100644
--- a/my_provider.py
+++ b/my_provider.py
@@ -1,8 +1,11 @@
 import contextlib
+import json
 from typing import Optional
 
 import apache_beam as beam
 import apache_beam.transforms.error_handling
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.huggingface_inference import HuggingFacePipelineModelHandler
 try:
   from apache_beam.yaml.yaml_errors import map_errors_to_standard_format
 except ImportError:
@@ -140,3 +143,45 @@
                 state=result.state, temperature=result.temperature)))
     # As will output pcolls.
     return dict(avg_per_city=avg_per_city, max_per_state=max_per_state)
+
+
+class RunHuggingFaceInference(beam.PTransform):
+  # Example PTransform wrapping Beam's RunInference with a Hugging Face pipeline
+  # model handler, so yaml pipelines can run inference like this notebook does:
+  # https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_huggingface.ipynb
+  def __init__(
+      self,
+      task: str,
+      model: str,
+      load_pipeline_args: Optional[str] = None,
+      inference_args: Optional[str] = None):
+    """
+    Configures Hugging Face pipeline inference over the `example` field of each row.
+
+    Args:
+      task: Hugging Face pipeline task name, e.g. 'translation_XX_to_YY'. Full list here - https://github.com/apache/beam/blob/6d0e00ea617f2c5eeb354e2b3a304445afeec669/sdks/python/apache_beam/ml/inference/huggingface_inference.py#L75
+      model: id of the pretrained model on the Hugging Face Models Hub to use for the chosen task, e.g. 'google/flan-t5-small'.
+      load_pipeline_args: JSON-encoded object of keyword arguments used when loading the Hugging Face pipeline, e.g. '{"framework": "pt"}'. Defaults to None (no extra arguments).
+      inference_args: JSON-encoded object of non-batchable arguments passed to the model's inference function, e.g. '{"max_length": 200}'. Defaults to None (no extra arguments).
+    """
+    self.task = task
+    self.model = model
+    self.load_pipeline_args = {}
+    if load_pipeline_args is not None:
+      self.load_pipeline_args = json.loads(load_pipeline_args)  # raises ValueError (JSONDecodeError) on malformed JSON
+    self.inference_args = {}
+    if inference_args is not None:
+      self.inference_args = json.loads(inference_args)  # raises ValueError (JSONDecodeError) on malformed JSON
+
+  def expand(self, pcoll):
+    model_handler = HuggingFacePipelineModelHandler(
+        task=self.task,
+        model = self.model,
+        load_pipeline_args=self.load_pipeline_args,
+        inference_args=self.inference_args
+    )
+    return (pcoll
+    | beam.Map(lambda row: row.example)  # hand only the `example` field to RunInference
+    | RunInference(model_handler)
+    | beam.Map(lambda result: beam.Row(example=result.example, inference=str(result.inference)))
+    )  # output rows carry `example` and the stringified `inference`
diff --git a/my_provider_test.py b/my_provider_test.py
index 88f7034..87bc954 100644
--- a/my_provider_test.py
+++ b/my_provider_test.py
@@ -104,3 +104,17 @@
               beam.Row(state='WA', temperature=52),
               beam.Row(state='NY', temperature=43),
           ]))
+
+  def test_run_hugging_face_inference(self):
+    # Smoke test: only checks that the pipeline runs; the model output is not deterministic, so nothing is asserted on `result`.
+    with beam.Pipeline() as p:
+      pcoll = p | beam.Create([
+          beam.Row(example="translate English to Spanish: How are you doing?"),
+          beam.Row(example="translate English to Spanish: This is the Apache Beam project."),
+      ])
+      result = pcoll | my_provider.RunHuggingFaceInference(
+        task='translation_XX_to_YY',
+        model='google/flan-t5-small',
+        load_pipeline_args='{"framework": "pt"}',
+        inference_args='{"max_length": 200}'
+      )
diff --git a/pyproject.toml b/pyproject.toml
index 999c9b4..11ddc7b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,6 +14,9 @@
 python = "^3.9"
 apache-beam = {extras = ["gcp", "yaml"], version = "^2.61.0"}
 roman = "^5.0"
+torch = "^2.6"
+tensorflow = "^2.18"
+transformers = "4.44.2"
 
 [tool.poetry.group.dev.dependencies]
 pytest = "^8.3.4"