blob: 1a479c8e7dbad0a8e74373ae3ebdf8d0544188b9 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for apache_beam.ml.rag.embeddings.vertex_ai."""
import os
import shutil
import struct
import tempfile
import unittest
import zlib
import apache_beam as beam
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.rag.types import Content
from apache_beam.ml.rag.types import EmbeddableItem
from apache_beam.ml.rag.types import Embedding
from apache_beam.ml.transforms.base import MLTransform
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
# pylint: disable=ungrouped-imports
# Optional-dependency guard: the test classes below are skipped when any of
# these imports fail.
try:
  # Imported only to probe that the vertexai package is installed (hence the
  # unused-import suppression).
  import vertexai  # pylint: disable=unused-import
  from apache_beam.ml.rag.embeddings.vertex_ai import VertexAIImageEmbeddings
  from apache_beam.ml.rag.embeddings.vertex_ai import VertexAITextEmbeddings
  from apache_beam.ml.rag.embeddings.vertex_ai import _create_image_adapter
  VERTEX_AI_AVAILABLE = True
except ImportError:
  # Read by the unittest.skipIf decorators on the test classes in this file.
  VERTEX_AI_AVAILABLE = False
def chunk_approximately_equals(expected, actual):
  """Check that two chunks match, ignoring the embedding values themselves.

  The expected chunks in this file carry placeholder zero vectors, so only
  the shape of the dense embedding is verified: both embeddings must have
  the same length and every value in ``actual`` must be a ``float``. The
  ``id``, ``metadata`` and ``content`` fields must compare equal.

  Args:
    expected: Reference ``Chunk``; only the length of its dense embedding is
      used, not the values.
    actual: ``Chunk`` produced by the pipeline under test.

  Returns:
    True if both arguments are ``Chunk`` instances that agree on id,
    metadata, content and dense-embedding length, and every actual embedding
    value is a float. False otherwise, including when either chunk has no
    dense embedding.
  """
  if not isinstance(expected, Chunk) or not isinstance(actual, Chunk):
    return False

  # getattr with a default also covers ``embedding is None``.
  expected_dense = getattr(expected.embedding, 'dense_embedding', None)
  actual_dense = getattr(actual.embedding, 'dense_embedding', None)
  # Treat a missing dense embedding as a mismatch rather than letting the
  # comparison below raise AttributeError/TypeError.
  if expected_dense is None or actual_dense is None:
    return False

  return (
      expected.id == actual.id and expected.metadata == actual.metadata and
      expected.content == actual.content and
      len(expected_dense) == len(actual_dense) and
      all(isinstance(x, float) for x in actual_dense))
@unittest.skipIf(
    not VERTEX_AI_AVAILABLE, "Vertex AI dependencies not available")
class VertexAITextEmbeddingsTest(unittest.TestCase):
  """Pipeline test for VertexAITextEmbeddings applied through MLTransform."""
  def setUp(self):
    # Scratch directory handed to MLTransform as write_artifact_location;
    # removed again in tearDown.
    self.artifact_location = tempfile.mkdtemp(prefix='vertex_ai_')
    self.test_chunks = [
        Chunk(
            content=Content(text="This is a test sentence."),
            id="1",
            metadata={
                "source": "test.txt", "language": "en"
            }),
        Chunk(
            content=Content(text="Another example."),
            id="2",
            metadata={
                "source": "test.txt", "language": "en"
            })
    ]

  def tearDown(self) -> None:
    shutil.rmtree(self.artifact_location)

  def test_embedding_pipeline(self):
    """Embeds the fixture chunks and checks ids, metadata, content, size."""
    # This test expects 768-dimensional vectors from the model requested
    # below. The zero values are placeholders: chunk_approximately_equals
    # compares only the embedding length and element type.
    embedding_dim = 768
    expected = [
        Chunk(
            id="1",
            embedding=Embedding(dense_embedding=[0.0] * embedding_dim),
            metadata={
                "source": "test.txt", "language": "en"
            },
            content=Content(text="This is a test sentence.")),
        Chunk(
            id="2",
            embedding=Embedding(dense_embedding=[0.0] * embedding_dim),
            metadata={
                "source": "test.txt", "language": "en"
            },
            content=Content(text="Another example."))
    ]
    embedder = VertexAITextEmbeddings(model_name="text-embedding-005")

    # assert_that must be registered inside the ``with`` block; the pipeline
    # runs when the block exits.
    with TestPipeline() as p:
      embeddings = (
          p
          | beam.Create(self.test_chunks)
          | MLTransform(write_artifact_location=self.artifact_location).
          with_transform(embedder))
      assert_that(
          embeddings, equal_to(expected, equals_fn=chunk_approximately_equals))
@unittest.skipIf(
    not VERTEX_AI_AVAILABLE, "Vertex AI dependencies not available")
class VertexAIImageAdapterTest(unittest.TestCase):
  """Unit tests for the adapter returned by _create_image_adapter."""
  def setUp(self):
    # Every test gets its own freshly created adapter.
    self.adapter = _create_image_adapter()

  def test_image_adapter_missing_content(self):
    """An item whose Content has no image is rejected with ValueError."""
    empty_item = EmbeddableItem(content=Content(), id="no_img")
    with self.assertRaisesRegex(ValueError, "Expected image content"):
      self.adapter.input_fn([empty_item])

  def test_image_adapter_str_input(self):
    """A string image reference yields one adapter input."""
    uri_item = EmbeddableItem.from_image('gs://bucket/img.jpg', id='img1')
    converted = self.adapter.input_fn([uri_item])
    self.assertEqual(len(converted), 1)

  def test_image_adapter_bytes_input(self):
    """Raw PNG bytes yield one adapter input."""
    bytes_item = EmbeddableItem.from_image(_create_png_bytes(), id='img2')
    converted = self.adapter.input_fn([bytes_item])
    self.assertEqual(len(converted), 1)

  def test_image_adapter_output(self):
    """output_fn attaches the given vector as the item's dense embedding."""
    source_items = [
        EmbeddableItem.from_image('gs://bucket/img.jpg', id='img1'),
    ]
    vectors = [[0.1, 0.2, 0.3]]
    embedded = self.adapter.output_fn(source_items, vectors)
    self.assertEqual(len(embedded), 1)
    self.assertEqual(embedded[0].embedding.dense_embedding, [0.1, 0.2, 0.3])
@unittest.skipIf(
    not VERTEX_AI_AVAILABLE, "Vertex AI dependencies not available")
class VertexAIImageEmbeddingsTest(unittest.TestCase):
  """Pipeline tests for VertexAIImageEmbeddings applied via MLTransform."""
  def setUp(self):
    # Scratch directory: used as the MLTransform artifact location by the
    # bytes test and as the home of the image file by the path test.
    self.artifact_location = tempfile.mkdtemp(prefix='vertex_ai_img_')

  def tearDown(self) -> None:
    shutil.rmtree(self.artifact_location)

  def _assert_pipeline_output(self, items, expected, artifact_location):
    """Embed ``items`` with multimodalembedding@001 and compare to expected.

    Args:
      items: EmbeddableItems fed into the pipeline.
      expected: Items the pipeline output is matched against using
        chunk_approximately_equals.
      artifact_location: Directory passed to MLTransform as
        write_artifact_location.
    """
    embedder = VertexAIImageEmbeddings(model_name="multimodalembedding@001")
    # assert_that must be registered inside the ``with`` block; the pipeline
    # runs when the block exits.
    with TestPipeline() as p:
      embeddings = (
          p
          | beam.Create(items)
          | MLTransform(write_artifact_location=(
              artifact_location)).with_transform(embedder))
      assert_that(
          embeddings, equal_to(expected, equals_fn=chunk_approximately_equals))

  def test_image_embedding_pipeline(self):
    """Embeds an image supplied as in-memory PNG bytes."""
    png_bytes = _create_png_bytes()
    test_items = [
        EmbeddableItem.from_image(
            png_bytes, id="img1", metadata={"source": "test"}),
    ]
    # Default dimension for multimodal is 1408
    expected = [
        EmbeddableItem(
            id="img1",
            embedding=Embedding(dense_embedding=[0.0] * 1408),
            metadata={"source": "test"},
            content=Content(image=png_bytes)),
    ]
    self._assert_pipeline_output(test_items, expected, self.artifact_location)

  def test_image_embedding_pipeline_from_path(self):
    """Embeds an image supplied as a path to a PNG file on local disk."""
    png_bytes = _create_png_bytes()
    img_path = os.path.join(self.artifact_location, 'test_img.png')
    with open(img_path, 'wb') as f:
      f.write(png_bytes)
    test_items = [
        EmbeddableItem.from_image(
            img_path, id="img1", metadata={"source": "test"}),
    ]
    expected = [
        EmbeddableItem(
            id="img1",
            embedding=Embedding(dense_embedding=[0.0] * 1408),
            metadata={"source": "test"},
            content=Content(image=img_path)),
    ]
    # The image file lives in self.artifact_location, so MLTransform gets a
    # separate artifact directory here. The context manager removes it even
    # if building the embedder or running the pipeline raises.
    with tempfile.TemporaryDirectory(
        prefix='vertex_ai_img_path_') as artifact_location:
      self._assert_pipeline_output(test_items, expected, artifact_location)
def _create_png_bytes():
  """Build the bytes of a minimal valid 1x1 RGB PNG without needing PIL."""
  def _chunk(tag, payload):
    # Chunk layout: big-endian payload length, 4-byte type, payload, then a
    # CRC-32 computed over type + payload.
    checksum = zlib.crc32(tag + payload)
    return b''.join((
        struct.pack('>I', len(payload)),
        tag,
        payload,
        struct.pack('>I', checksum),
    ))

  # IHDR: width=1, height=1, bit depth 8, colour type 2 (RGB), then
  # compression, filter and interlace methods all 0.
  header = struct.pack('>IIBBBBB', 1, 1, 8, 2, 0, 0, 0)
  # Single scanline: filter byte 0 followed by one all-zero RGB pixel.
  scanline = zlib.compress(b'\x00\x00\x00\x00')
  return b''.join((
      b'\x89PNG\r\n\x1a\n',
      _chunk(b'IHDR', header),
      _chunk(b'IDAT', scanline),
      _chunk(b'IEND', b''),
  ))
# Run the tests in this module when the file is executed directly.
if __name__ == '__main__':
  unittest.main()