# blob: 65c21771524dc51db0e7113d5604eea9fb567364 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for apache_beam.ml.rag.embeddings.huggingface."""
import io
import os
import shutil
import tempfile
import unittest
import pytest
import apache_beam as beam
from apache_beam.ml.rag.embeddings.huggingface import HuggingfaceImageEmbeddings
from apache_beam.ml.rag.embeddings.huggingface import HuggingfaceTextEmbeddings
from apache_beam.ml.rag.embeddings.huggingface import _create_hf_image_adapter
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.rag.types import Content
from apache_beam.ml.rag.types import EmbeddableItem
from apache_beam.ml.rag.types import Embedding
from apache_beam.ml.transforms.base import MLTransform
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
# pylint: disable=unused-import
try:
from sentence_transformers import SentenceTransformer
SENTENCE_TRANSFORMERS_AVAILABLE = True
except ImportError:
SENTENCE_TRANSFORMERS_AVAILABLE = False
try:
from PIL import Image
PIL_AVAILABLE = True
except ImportError:
PIL_AVAILABLE = False
def chunk_approximately_equals(expected, actual):
  """Loosely compare two embedded chunks, ignoring the embedding values.

  Intended as the ``equals_fn`` passed to ``equal_to``. The numeric contents
  of the embeddings are deliberately NOT compared; only their shape and
  element type are checked. Two chunks match when:

  * both are ``Chunk`` instances,
  * their ``id``, ``metadata`` and ``content`` are equal,
  * their dense embeddings have the same length, and
  * every value in the actual dense embedding is a ``float``.

  Args:
    expected: The reference chunk. Only the length of its dense embedding is
      used, so placeholder values (for example zeros) are sufficient.
    actual: The chunk produced by the code under test.

  Returns:
    True if the chunks match under the rules above. False otherwise,
    including when either argument is not a ``Chunk`` or when either chunk
    has no dense embedding.
  """
  if not isinstance(expected, Chunk) or not isinstance(actual, Chunk):
    return False

  # A chunk without an embedding (or without a dense embedding) is treated
  # as a plain mismatch. Dereferencing it directly would raise
  # AttributeError/TypeError from inside the matcher instead of producing a
  # normal assertion failure.
  expected_vector = getattr(expected.embedding, 'dense_embedding', None)
  actual_vector = getattr(actual.embedding, 'dense_embedding', None)
  if expected_vector is None or actual_vector is None:
    return False

  return (
      expected.id == actual.id and expected.metadata == actual.metadata and
      expected.content == actual.content and
      len(expected_vector) == len(actual_vector) and
      all(isinstance(x, float) for x in actual_vector))
@pytest.mark.uses_transformers
@unittest.skipIf(
    not SENTENCE_TRANSFORMERS_AVAILABLE, "sentence-transformers not available")
class HuggingfaceTextEmbeddingsTest(unittest.TestCase):
  """Runs HuggingfaceTextEmbeddings through MLTransform on two text chunks."""
  def setUp(self):
    # Scratch directory handed to MLTransform as write_artifact_location;
    # removed again in tearDown.
    self.artifact_location = tempfile.mkdtemp(prefix='sentence_transformers_')
    first_chunk = Chunk(
        content=Content(text="This is a test sentence."),
        id="1",
        metadata={
            "source": "test.txt", "language": "en"
        })
    second_chunk = Chunk(
        content=Content(text="Another example."),
        id="2",
        metadata={
            "source": "test.txt", "language": "en"
        })
    self.test_chunks = [first_chunk, second_chunk]

  def tearDown(self) -> None:
    shutil.rmtree(self.artifact_location)

  def test_embedding_pipeline(self):
    """Each input chunk comes back with a 384-long dense embedding."""
    # The zero vectors are placeholders: chunk_approximately_equals only
    # checks the embedding length and element type, not the values.
    embedding_size = 384
    id_and_text = (
        ("1", "This is a test sentence."),
        ("2", "Another example."),
    )
    expected = [
        Chunk(
            id=chunk_id,
            embedding=Embedding(dense_embedding=[0.0] * embedding_size),
            metadata={
                "source": "test.txt", "language": "en"
            },
            content=Content(text=text)) for chunk_id, text in id_and_text
    ]

    embedder = HuggingfaceTextEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2")

    with TestPipeline() as pipeline:
      source = pipeline | beam.Create(self.test_chunks)
      embeddings = source | MLTransform(
          write_artifact_location=self.artifact_location).with_transform(
              embedder)

      matcher = equal_to(expected, equals_fn=chunk_approximately_equals)
      assert_that(embeddings, matcher)
@pytest.mark.uses_transformers
@unittest.skipIf(
    not (SENTENCE_TRANSFORMERS_AVAILABLE and PIL_AVAILABLE),
    "sentence-transformers or PIL not available")
class HuggingfaceImageAdapterTest(unittest.TestCase):
  """Exercises input_fn / output_fn of the _create_hf_image_adapter adapter."""
  def test_image_adapter_missing_content(self):
    """An item without image content is rejected with a ValueError."""
    empty_item = EmbeddableItem(content=Content(), id="no_img")
    image_adapter = _create_hf_image_adapter()
    with self.assertRaisesRegex(ValueError, "Expected image content"):
      image_adapter.input_fn([empty_item])

  def test_image_adapter_bytes_input(self):
    """An item built from raw PNG bytes yields exactly one image."""
    image_adapter = _create_hf_image_adapter()
    bytes_item = EmbeddableItem.from_image(_create_png_bytes(), id='img1')
    loaded = image_adapter.input_fn([bytes_item])
    self.assertEqual(len(loaded), 1)

  def test_image_adapter_path_input(self):
    """An item built from a PNG file path yields exactly one image."""
    image_adapter = _create_hf_image_adapter()
    encoded_png = _create_png_bytes()

    scratch_dir = tempfile.mkdtemp()
    # Remove the scratch directory whether or not the assertions pass.
    self.addCleanup(shutil.rmtree, scratch_dir)

    png_file = os.path.join(scratch_dir, 'test.png')
    with open(png_file, 'wb') as handle:
      handle.write(encoded_png)

    path_item = EmbeddableItem.from_image(png_file, id='img2')
    loaded = image_adapter.input_fn([path_item])
    self.assertEqual(len(loaded), 1)

  def test_image_adapter_output(self):
    """output_fn attaches the supplied vector as the dense embedding."""
    image_adapter = _create_hf_image_adapter()
    fake_vector = [0.1, 0.2, 0.3]
    source_items = [EmbeddableItem.from_image(b'\x89PNG', id='img1')]

    embedded = image_adapter.output_fn(source_items, [fake_vector])

    self.assertEqual(len(embedded), 1)
    self.assertEqual(embedded[0].embedding.dense_embedding, [0.1, 0.2, 0.3])
@pytest.mark.uses_transformers
@unittest.skipIf(
    not SENTENCE_TRANSFORMERS_AVAILABLE or not PIL_AVAILABLE,
    "sentence-transformers or PIL not available")
class HuggingfaceImageEmbeddingsTest(unittest.TestCase):
  """Runs HuggingfaceImageEmbeddings through MLTransform on a single image.

  NOTE(review): results are matched with chunk_approximately_equals, which
  only accepts ``Chunk`` instances, while these tests build
  ``EmbeddableItem`` objects. This relies on EmbeddableItem passing that
  isinstance check -- confirm against apache_beam.ml.rag.types.
  """
  def setUp(self):
    # Per-test scratch directory; removed again in tearDown.
    self.artifact_location = tempfile.mkdtemp(prefix='hf_image_')

  def tearDown(self) -> None:
    shutil.rmtree(self.artifact_location)

  def test_image_embedding_pipeline(self):
    """An item holding raw PNG bytes is embedded into a 512-long vector."""
    png_bytes = _create_png_bytes()
    test_items = [
        EmbeddableItem.from_image(
            png_bytes, id="img1", metadata={"source": "test"}),
    ]
    # clip-ViT-B-32 produces 512-dimensional embeddings
    expected = [
        EmbeddableItem(
            id="img1",
            embedding=Embedding(dense_embedding=[0.0] * 512),
            metadata={"source": "test"},
            content=Content(image=png_bytes)),
    ]
    embedder = HuggingfaceImageEmbeddings(model_name="clip-ViT-B-32")
    with TestPipeline() as p:
      embeddings = (
          p
          | beam.Create(test_items)
          | MLTransform(write_artifact_location=(
              self.artifact_location)).with_transform(embedder))
      assert_that(
          embeddings, equal_to(expected, equals_fn=chunk_approximately_equals))

  def test_image_embedding_pipeline_from_path(self):
    """An item holding a PNG file path is embedded into a 512-long vector."""
    png_bytes = _create_png_bytes()
    # The image file is written to the setUp directory, while MLTransform
    # gets a separate artifact directory created below.
    img_path = os.path.join(self.artifact_location, 'test_img.png')
    with open(img_path, 'wb') as f:
      f.write(png_bytes)
    test_items = [
        EmbeddableItem.from_image(
            img_path, id="img1", metadata={"source": "test"}),
    ]
    expected = [
        EmbeddableItem(
            id="img1",
            embedding=Embedding(dense_embedding=[0.0] * 512),
            metadata={"source": "test"},
            content=Content(image=img_path)),
    ]
    artifact_location = tempfile.mkdtemp(prefix='hf_image_path_')
    # Register the cleanup immediately after creating the directory so it is
    # removed even if building the embedder or the pipeline raises.
    self.addCleanup(shutil.rmtree, artifact_location)
    embedder = HuggingfaceImageEmbeddings(model_name="clip-ViT-B-32")
    with TestPipeline() as p:
      embeddings = (
          p
          | beam.Create(test_items)
          | MLTransform(write_artifact_location=(
              artifact_location)).with_transform(embedder))
      assert_that(
          embeddings,
          equal_to(expected, equals_fn=chunk_approximately_equals))
def _create_png_bytes():
  """Return the PNG-encoded bytes of a 10x10 solid-colour RGB image."""
  # Imported locally so the module can still be imported when PIL is absent.
  from PIL import Image
  with io.BytesIO() as png_buffer:
    picture = Image.new('RGB', (10, 10), color=(128, 64, 32))
    picture.save(png_buffer, format='PNG')
    return png_buffer.getvalue()
# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
  unittest.main()