blob: c371d6fd96b4f73a9c414d0a53b50b28daca3d74 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Integration tests for Spanner vector store writer."""
import logging
import os
import time
import unittest
import uuid
import pytest
import apache_beam as beam
from apache_beam.ml.rag.ingestion.spanner import SpannerVectorWriterConfig
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.rag.types import Content
from apache_beam.ml.rag.types import Embedding
from apache_beam.testing.test_pipeline import TestPipeline
# pylint: disable=wrong-import-order, wrong-import-position
try:
from google.cloud import spanner
except ImportError:
spanner = None
try:
from testcontainers.core.container import DockerContainer
except ImportError:
DockerContainer = None
# pylint: enable=wrong-import-order, wrong-import-position
def retry(fn, retries, err_msg, *args, **kwargs):
  """Call ``fn`` until it succeeds, making at most ``retries`` attempts.

  Consecutive attempts are separated by a fixed one-second pause. No pause
  follows the final failed attempt, since nothing is left to wait for.

  Args:
    fn: Callable invoked as ``fn(*args, **kwargs)``.
    retries: Maximum number of attempts.
    err_msg: Message that is logged and raised when every attempt fails.
    *args: Positional arguments forwarded to ``fn``.
    **kwargs: Keyword arguments forwarded to ``fn``.

  Returns:
    The value returned by the first successful call to ``fn``.

  Raises:
    RuntimeError: If no attempt succeeds (including when ``retries`` is not
      positive). The last exception raised by ``fn``, if any, is attached as
      the cause.
  """
  last_error = None
  for attempt in range(retries):
    try:
      return fn(*args, **kwargs)
    # Only ordinary errors are retried; KeyboardInterrupt and SystemExit
    # propagate immediately so the test run can still be aborted.
    except Exception as e:  # pylint: disable=broad-except
      last_error = e
      if attempt < retries - 1:
        time.sleep(1)
  logging.error(err_msg)
  raise RuntimeError(err_msg) from last_error
class SpannerEmulatorHelper:
  """Helper for managing Spanner emulator lifecycle.

  Constructing the helper starts an emulator container, exports its address
  through the ``SPANNER_EMULATOR_HOST`` environment variable and creates a
  Spanner instance in it. ``shutdown`` stops the container and puts the
  environment variable back the way it was found.
  """
  _EMULATOR_ENV_VAR = 'SPANNER_EMULATOR_HOST'

  def __init__(self, project_id: str, instance_id: str, table_name: str):
    """Start the emulator and create the instance.

    Args:
      project_id: Project id handed to the Spanner client.
      instance_id: Id of the instance created inside the emulator.
      table_name: Name of the table used by ``create_database`` and
        ``read_data``.

    Raises:
      RuntimeError: If the container cannot be started.
      Exception: Any error raised while connecting to the emulator or
        creating the instance; the container is stopped before re-raising.
    """
    self.project_id = project_id
    self.instance_id = instance_id
    self.table_name = table_name
    self.host = None
    # Remembered so shutdown() can undo the environment change made below.
    self._previous_emulator_host = os.environ.get(self._EMULATOR_ENV_VAR)

    # Start emulator
    self.emulator = DockerContainer(
        'gcr.io/cloud-spanner-emulator/emulator:latest').with_exposed_ports(
            9010, 9020)
    try:
      retry(self.emulator.start, 3, 'Could not start spanner emulator.')
      # Fixed pause after start-up; presumably gives the emulator time to
      # begin accepting connections -- TODO confirm whether polling is
      # possible instead.
      time.sleep(3)
      self.host = f'{self.emulator.get_container_host_ip()}:' \
                  f'{self.emulator.get_exposed_port(9010)}'
      os.environ[self._EMULATOR_ENV_VAR] = self.host

      # Create client and instance
      self.client = spanner.Client(project_id)
      self.instance = self.client.instance(instance_id)
      self.create_instance()
    except Exception:
      # A failing constructor means the caller never receives an object to
      # call shutdown() on, so the container has to be released here.
      self.shutdown()
      raise

  def create_instance(self):
    """Create Spanner instance in emulator, waiting up to 120 seconds."""
    self.instance.create().result(120)

  def create_database(self, database_id: str):
    """Create database with default vector table schema.

    The table has the columns ``id``, ``embedding`` (three FLOAT32 values),
    ``content`` and ``metadata``, with ``id`` as primary key.

    Args:
      database_id: Id of the database to create.
    """
    database = self.instance.database(
        database_id,
        ddl_statements=[
            f'''
          CREATE TABLE {self.table_name} (
              id STRING(1024) NOT NULL,
              embedding ARRAY<FLOAT32>(vector_length=>3),
              content STRING(MAX),
              metadata JSON
          ) PRIMARY KEY (id)'''
        ])
    database.create().result(120)

  def read_data(self, database_id: str):
    """Read all data from the table.

    Args:
      database_id: Id of the database to query.

    Returns:
      A list with every row of the table, ordered by ``id``.
    """
    database = self.instance.database(database_id)
    with database.snapshot() as snapshot:
      results = snapshot.execute_sql(
          f'SELECT * FROM {self.table_name} ORDER BY id')
      return list(results) if results else []

  def drop_database(self, database_id: str):
    """Drop the database."""
    database = self.instance.database(database_id)
    database.drop()

  def shutdown(self):
    """Stop the emulator and restore ``SPANNER_EMULATOR_HOST``.

    Failing to stop the container is logged rather than raised, so that
    teardown is best-effort.
    """
    if self.emulator:
      try:
        self.emulator.stop()
      except Exception:  # pylint: disable=broad-except
        logging.exception('Could not stop Spanner emulator.')

    # Only touch the variable if it still holds the value this helper put
    # there; otherwise later code would keep talking to a stopped emulator.
    if (self.host is not None and
        os.environ.get(self._EMULATOR_ENV_VAR) == self.host):
      if self._previous_emulator_host is None:
        del os.environ[self._EMULATOR_ENV_VAR]
      else:
        os.environ[self._EMULATOR_ENV_VAR] = self._previous_emulator_host

  def get_emulator_host(self) -> str:
    """Get the emulator host URL (``http://<host>:<port>``)."""
    return f'http://{self.host}'
@pytest.mark.uses_gcp_java_expansion_service
@unittest.skipUnless(
    os.environ.get('EXPANSION_JARS'),
    "EXPANSION_JARS environment var is not provided, "
    "indicating that jars have not been built")
@unittest.skipIf(spanner is None, 'GCP dependencies are not installed.')
@unittest.skipIf(
    DockerContainer is None, 'testcontainers package is not installed.')
class SpannerVectorWriterTest(unittest.TestCase):
  """Integration tests for Spanner vector writer.

  A single emulator and Spanner instance are shared by the whole class.
  Every test gets its own freshly created database, which is dropped again
  in ``tearDown``.
  """
  @classmethod
  def setUpClass(cls):
    """Set up Spanner emulator for all tests."""
    pipeline = TestPipeline(is_integration_test=True)
    runner_name = type(pipeline.runner).__name__
    if 'DataflowRunner' in runner_name:
      # unittest.SkipTest is honoured both by pytest and by the
      # unittest.main() entry point at the bottom of this file.
      raise unittest.SkipTest(
          "Spanner emulator not compatible with dataflow runner.")

    cls.project_id = 'test-project'
    cls.instance_id = 'test-instance'
    cls.table_name = 'embeddings'
    cls.spanner_helper = SpannerEmulatorHelper(
        cls.project_id, cls.instance_id, cls.table_name)

  @classmethod
  def tearDownClass(cls):
    """Tear down Spanner emulator."""
    cls.spanner_helper.shutdown()

  def setUp(self):
    """Create a unique database for each test."""
    # The generated id is cut to 30 characters.
    self.database_id = f'test_db_{uuid.uuid4().hex}'[:30]
    self.spanner_helper.create_database(self.database_id)

  def tearDown(self):
    """Drop the test database."""
    self.spanner_helper.drop_database(self.database_id)

  # ---------------------------------------------------------------- helpers

  def _recreate_database(self, create_table_ddl):
    """Replace this test's database with one built from the given DDL."""
    self.spanner_helper.drop_database(self.database_id)
    database = self.spanner_helper.instance.database(
        self.database_id, ddl_statements=[create_table_ddl])
    database.create().result(120)

  def _make_config(self, **overrides):
    """Build a writer config targeting this test's database on the emulator.

    Args:
      **overrides: Extra keyword arguments (e.g. ``column_specs``,
        ``write_mode``) passed through to ``SpannerVectorWriterConfig``.
    """
    return SpannerVectorWriterConfig(
        project_id=self.project_id,
        instance_id=self.instance_id,
        database_id=self.database_id,
        table_name=self.table_name,
        emulator_host=self.spanner_helper.get_emulator_host(),
        **overrides)

  def _write(self, chunks, config):
    """Run a pipeline that writes ``chunks`` using ``config``."""
    with TestPipeline() as p:
      p.not_use_test_runner_api = True
      _ = (p | beam.Create(chunks) | config.create_write_transform())

  def _query(self, sql):
    """Execute ``sql`` against this test's database and return all rows."""
    database = self.spanner_helper.instance.database(self.database_id)
    with database.snapshot() as snapshot:
      return list(snapshot.execute_sql(sql))

  # ------------------------------------------------------------------ tests

  def test_write_default_schema(self):
    """Test writing with default schema (id, embedding, content, metadata)."""
    chunks = [
        Chunk(
            id='doc1',
            embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
            content=Content(text='First document'),
            metadata={
                'source': 'test', 'page': 1
            }),
        Chunk(
            id='doc2',
            embedding=Embedding(dense_embedding=[4.0, 5.0, 6.0]),
            content=Content(text='Second document'),
            metadata={
                'source': 'test', 'page': 2
            }),
    ]

    # No column specs given: the config's default schema is used.
    self._write(chunks, self._make_config())

    results = self.spanner_helper.read_data(self.database_id)
    self.assertEqual(len(results), 2)

    # Rows come back ordered by id; columns follow the table definition.
    row1 = results[0]
    self.assertEqual(row1[0], 'doc1')  # id
    self.assertEqual(list(row1[1]), [1.0, 2.0, 3.0])  # embedding
    self.assertEqual(row1[2], 'First document')  # content
    metadata1 = row1[3]  # metadata is JSON
    self.assertEqual(metadata1['source'], 'test')
    self.assertEqual(metadata1['page'], 1)

    row2 = results[1]
    self.assertEqual(row2[0], 'doc2')
    self.assertEqual(list(row2[1]), [4.0, 5.0, 6.0])
    self.assertEqual(row2[2], 'Second document')

  def test_write_flattened_metadata(self):
    """Test writing with flattened metadata fields."""
    from apache_beam.ml.rag.ingestion.spanner import SpannerColumnSpecsBuilder

    # Table with dedicated columns for two of the metadata fields.
    self._recreate_database(
        f'''
          CREATE TABLE {self.table_name} (
              id STRING(1024) NOT NULL,
              embedding ARRAY<FLOAT32>(vector_length=>3),
              content STRING(MAX),
              source STRING(MAX),
              page_number INT64,
              metadata JSON
          ) PRIMARY KEY (id)''')

    chunks = [
        Chunk(
            id='doc1',
            embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
            content=Content(text='First document'),
            metadata={
                'source': 'book.pdf', 'page': 10, 'author': 'John'
            }),
        Chunk(
            id='doc2',
            embedding=Embedding(dense_embedding=[4.0, 5.0, 6.0]),
            content=Content(text='Second document'),
            metadata={
                'source': 'article.txt', 'page': 5, 'author': 'Jane'
            }),
    ]

    # 'source' and 'page' are lifted out of the metadata dict into their
    # own columns; the metadata column itself is still written.
    specs = (
        SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec().
        with_content_spec().add_metadata_field(
            'source', str, column_name='source').add_metadata_field(
                'page', int,
                column_name='page_number').with_metadata_spec().build())

    self._write(chunks, self._make_config(column_specs=specs))

    rows = self._query(
        f'SELECT id, embedding, content, source, page_number, metadata '
        f'FROM {self.table_name} ORDER BY id')
    self.assertEqual(len(rows), 2)

    first = rows[0]
    self.assertEqual(first[0], 'doc1')
    self.assertEqual(list(first[1]), [1.0, 2.0, 3.0])
    self.assertEqual(first[2], 'First document')
    self.assertEqual(first[3], 'book.pdf')  # flattened source
    self.assertEqual(first[4], 10)  # flattened page_number
    metadata1 = first[5]
    self.assertEqual(metadata1['author'], 'John')

  def test_write_minimal_schema(self):
    """Test writing with minimal schema (only id and embedding)."""
    from apache_beam.ml.rag.ingestion.spanner import SpannerColumnSpecsBuilder

    self._recreate_database(
        f'''
          CREATE TABLE {self.table_name} (
              id STRING(1024) NOT NULL,
              embedding ARRAY<FLOAT32>(vector_length=>3)
          ) PRIMARY KEY (id)''')

    # The chunks still carry content and metadata; the specs below simply
    # do not map them to any column.
    chunks = [
        Chunk(
            id='doc1',
            embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
            content=Content(text='First document'),
            metadata={'source': 'test'}),
        Chunk(
            id='doc2',
            embedding=Embedding(dense_embedding=[4.0, 5.0, 6.0]),
            content=Content(text='Second document'),
            metadata={'source': 'test'}),
    ]

    specs = (
        SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec().build(
        ))

    self._write(chunks, self._make_config(column_specs=specs))

    results = self.spanner_helper.read_data(self.database_id)
    self.assertEqual(len(results), 2)
    self.assertEqual(results[0][0], 'doc1')
    self.assertEqual(list(results[0][1]), [1.0, 2.0, 3.0])

  def test_write_with_converter(self):
    """Test writing with custom converter function."""
    from apache_beam.ml.rag.ingestion.spanner import SpannerColumnSpecsBuilder

    # A 3-4-0 vector has length 5, which makes the expected result exact.
    chunks = [
        Chunk(
            id='doc1',
            embedding=Embedding(dense_embedding=[3.0, 4.0, 0.0]),
            content=Content(text='First document'),
            metadata={'source': 'test'}),
    ]

    def normalize(vec):
      # Fall back to 1.0 so an all-zero vector does not divide by zero.
      norm = (sum(x**2 for x in vec)**0.5) or 1.0
      return [x / norm for x in vec]

    specs = (
        SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec(
            convert_fn=normalize).with_content_spec().with_metadata_spec().
        build())

    self._write(chunks, self._make_config(column_specs=specs))

    results = self.spanner_helper.read_data(self.database_id)
    self.assertEqual(len(results), 1)

    embedding = list(results[0][1])
    # Original was [3.0, 4.0, 0.0], normalized should be [0.6, 0.8, 0.0]
    self.assertAlmostEqual(embedding[0], 0.6, places=5)
    self.assertAlmostEqual(embedding[1], 0.8, places=5)
    self.assertAlmostEqual(embedding[2], 0.0, places=5)

    # Check norm is 1.0
    norm = sum(x**2 for x in embedding)**0.5
    self.assertAlmostEqual(norm, 1.0, places=5)

  def test_write_update_mode(self):
    """Test writing with UPDATE mode."""
    # First pipeline: insert the row that will be updated afterwards.
    chunks_insert = [
        Chunk(
            id='doc1',
            embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
            content=Content(text='Original content'),
            metadata={'version': 1}),
    ]
    self._write(chunks_insert, self._make_config(write_mode='INSERT'))

    # Second pipeline: same id, different payload, written in UPDATE mode.
    chunks_update = [
        Chunk(
            id='doc1',
            embedding=Embedding(dense_embedding=[4.0, 5.0, 6.0]),
            content=Content(text='Updated content'),
            metadata={'version': 2}),
    ]
    self._write(chunks_update, self._make_config(write_mode='UPDATE'))

    # Still exactly one row, now carrying the second payload.
    results = self.spanner_helper.read_data(self.database_id)
    self.assertEqual(len(results), 1)
    self.assertEqual(results[0][0], 'doc1')
    self.assertEqual(list(results[0][1]), [4.0, 5.0, 6.0])
    self.assertEqual(results[0][2], 'Updated content')
    metadata = results[0][3]
    self.assertEqual(metadata['version'], 2)

  def test_write_custom_column(self):
    """Test writing with custom computed column."""
    from apache_beam.ml.rag.ingestion.spanner import SpannerColumnSpecsBuilder

    self._recreate_database(
        f'''
          CREATE TABLE {self.table_name} (
              id STRING(1024) NOT NULL,
              embedding ARRAY<FLOAT32>(vector_length=>3),
              content STRING(MAX),
              word_count INT64,
              metadata JSON
          ) PRIMARY KEY (id)''')

    chunks = [
        Chunk(
            id='doc1',
            embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
            content=Content(text='Hello world test'),
            metadata={}),
        Chunk(
            id='doc2',
            embedding=Embedding(dense_embedding=[4.0, 5.0, 6.0]),
            content=Content(text='This is a longer test document'),
            metadata={}),
    ]

    # word_count is derived from each chunk's text at write time.
    specs = (
        SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec(
        ).with_content_spec().add_column(
            column_name='word_count',
            python_type=int,
            value_fn=lambda chunk: len(chunk.content.text.split())).
        with_metadata_spec().build())

    self._write(chunks, self._make_config(column_specs=specs))

    rows = self._query(
        f'SELECT id, word_count FROM {self.table_name} ORDER BY id')
    self.assertEqual(len(rows), 2)
    self.assertEqual(rows[0][1], 3)  # "Hello world test" = 3 words
    self.assertEqual(rows[1][1], 6)  # 6 words

  def test_write_with_timestamp(self):
    """Test writing with timestamp columns."""
    from apache_beam.ml.rag.ingestion.spanner import SpannerColumnSpecsBuilder

    self._recreate_database(
        f'''
          CREATE TABLE {self.table_name} (
              id STRING(1024) NOT NULL,
              embedding ARRAY<FLOAT32>(vector_length=>3),
              content STRING(MAX),
              created_at TIMESTAMP,
              metadata JSON
          ) PRIMARY KEY (id)''')

    timestamp_str = "2025-10-28T09:45:00.123456Z"
    chunks = [
        Chunk(
            id='doc1',
            embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
            content=Content(text='Document with timestamp'),
            metadata={'created_at': timestamp_str}),
    ]

    # The timestamp travels as a string metadata field and is written to
    # the TIMESTAMP column of the same name.
    specs = (
        SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec().
        with_content_spec().add_metadata_field(
            'created_at', str,
            column_name='created_at').with_metadata_spec().build())

    self._write(chunks, self._make_config(column_specs=specs))

    rows = self._query(f'SELECT id, created_at FROM {self.table_name}')
    self.assertEqual(len(rows), 1)
    self.assertEqual(rows[0][0], 'doc1')
    # Timestamp is returned as datetime object by Spanner client
    self.assertIsNotNone(rows[0][1])
if __name__ == '__main__':
  # Let INFO-level records through the root logger while the suite runs
  # as a script.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  unittest.main()