blob: e17a5d66fe91180c6d1e030ea6bdee5f38d86c72 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RAG-specific embedding implementations using HuggingFace models."""
import io
from collections.abc import Sequence
from typing import Optional
import apache_beam as beam
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.rag.embeddings.base import _add_embedding_fn
from apache_beam.ml.rag.embeddings.base import create_text_adapter
from apache_beam.ml.rag.types import EmbeddableItem
from apache_beam.ml.transforms.base import EmbeddingsManager
from apache_beam.ml.transforms.base import EmbeddingTypeAdapter
from apache_beam.ml.transforms.base import _ImageEmbeddingHandler
from apache_beam.ml.transforms.base import _TextEmbeddingHandler
from apache_beam.ml.transforms.embeddings.huggingface import _SentenceTransformerModelHandler
try:
from sentence_transformers import SentenceTransformer
except ImportError:
SentenceTransformer = None
try:
from PIL import Image as PILImage
except ImportError:
PILImage = None # type: ignore[assignment]
class HuggingfaceTextEmbeddings(EmbeddingsManager):
def __init__(
self, model_name: str, *, max_seq_length: Optional[int] = None, **kwargs):
"""HuggingFace text embeddings for RAG pipelines.
Args:
model_name: Name of the sentence-transformers model to use.
max_seq_length: Maximum sequence length for the model.
**kwargs: Additional arguments passed to
:class:`~apache_beam.ml.transforms.base.EmbeddingsManager`,
including:
- ``load_model_args``: dict passed to
``SentenceTransformer()`` constructor
(e.g. ``device``, ``cache_folder``).
- ``min_batch_size`` / ``max_batch_size``:
Control batching for inference.
- ``large_model``: If True, share the model
across processes to reduce memory usage.
- ``inference_args``: dict passed to
``model.encode()``
(e.g. ``normalize_embeddings``).
"""
if not SentenceTransformer:
raise ImportError(
"sentence-transformers is required to use "
"HuggingfaceTextEmbeddings."
"Please install it with using `pip install sentence-transformers`.")
super().__init__(type_adapter=create_text_adapter(), **kwargs)
self.model_name = model_name
self.max_seq_length = max_seq_length
self.model_class = SentenceTransformer
def get_model_handler(self):
"""Returns model handler configured with RAG adapter."""
return _SentenceTransformerModelHandler(
model_class=self.model_class,
max_seq_length=self.max_seq_length,
model_name=self.model_name,
load_model_args=self.load_model_args,
min_batch_size=self.min_batch_size,
max_batch_size=self.max_batch_size,
large_model=self.large_model)
def get_ptransform_for_processing(
self, **kwargs
) -> beam.PTransform[beam.PCollection[EmbeddableItem],
beam.PCollection[EmbeddableItem]]:
"""Returns PTransform that uses the RAG adapter."""
return RunInference(
model_handler=_TextEmbeddingHandler(self),
inference_args=self.inference_args).with_output_types(EmbeddableItem)
def _extract_images(items: Sequence[EmbeddableItem]) -> list:
"""Extract images from items and convert to PIL.Image objects.
Supports raw bytes, local file paths, and remote URIs
(e.g. gs://, s3://) via Beam's FileSystems.
"""
images = []
for item in items:
if not item.content.image:
raise ValueError(
"Expected image content in "
f"{type(item).__name__} {item.id}, "
"got None")
img_data = item.content.image
if isinstance(img_data, bytes):
img = PILImage.open(io.BytesIO(img_data))
else:
with FileSystems.open(img_data, 'rb') as f:
img = PILImage.open(f)
img.load()
images.append(img.convert('RGB'))
return images
def _create_hf_image_adapter(
) -> EmbeddingTypeAdapter[EmbeddableItem, EmbeddableItem]:
"""Creates adapter for HuggingFace image embedding.
Extracts content.image from EmbeddableItems and converts
to PIL.Image objects. Supports both raw bytes and file paths.
Returns:
EmbeddingTypeAdapter for HuggingFace image embedding.
"""
return EmbeddingTypeAdapter(
input_fn=_extract_images, output_fn=_add_embedding_fn)
class HuggingfaceImageEmbeddings(EmbeddingsManager):
def __init__(
self, model_name: str, *, max_seq_length: Optional[int] = None, **kwargs):
"""HuggingFace image embeddings for RAG pipelines.
Generates embeddings for images using sentence-transformers
models that support image input (e.g. clip-ViT-B-32).
Args:
model_name: Name of the sentence-transformers model.
Must be an image-text model. See
https://www.sbert.net/docs/sentence_transformer/pretrained_models.html#image-text-models
max_seq_length: Maximum sequence length for the model
if applicable.
**kwargs: Additional arguments passed to
:class:`~apache_beam.ml.transforms.base.EmbeddingsManager`,
including:
- ``load_model_args``: dict passed to
``SentenceTransformer()`` constructor
(e.g. ``device``, ``cache_folder``,
``trust_remote_code``).
- ``min_batch_size`` / ``max_batch_size``:
Control batching for inference.
- ``large_model``: If True, share the model
across processes to reduce memory usage.
- ``inference_args``: dict passed to
``model.encode()``
(e.g. ``normalize_embeddings``).
"""
if not SentenceTransformer:
raise ImportError(
"sentence-transformers is required to use "
"HuggingfaceImageEmbeddings. "
"Please install it with `pip install sentence-transformers`.")
if not PILImage:
raise ImportError(
"Pillow is required to use HuggingfaceImageEmbeddings. "
"Please install it with `pip install pillow`.")
super().__init__(type_adapter=_create_hf_image_adapter(), **kwargs)
self.model_name = model_name
self.max_seq_length = max_seq_length
self.model_class = SentenceTransformer
def get_model_handler(self):
"""Returns model handler configured with RAG adapter."""
return _SentenceTransformerModelHandler(
model_class=self.model_class,
max_seq_length=self.max_seq_length,
model_name=self.model_name,
load_model_args=self.load_model_args,
min_batch_size=self.min_batch_size,
max_batch_size=self.max_batch_size,
large_model=self.large_model)
def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
"""Returns PTransform for image embedding."""
return RunInference(
model_handler=_ImageEmbeddingHandler(self),
inference_args=self.inference_args).with_output_types(EmbeddableItem)