#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import secrets
import time
import unittest

import hamcrest as hc
import pytest

import apache_beam as beam
from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper
from apache_beam.io.gcp.internal.clients import bigquery
from apache_beam.io.gcp.tests.bigquery_matcher import BigqueryFullResultMatcher
from apache_beam.ml.rag.ingestion.bigquery import BigQueryVectorWriterConfig
from apache_beam.ml.rag.ingestion.bigquery import SchemaConfig
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.rag.types import Content
from apache_beam.ml.rag.types import Embedding
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.transforms.periodicsequence import PeriodicImpulse


@pytest.mark.uses_gcp_java_expansion_service
@unittest.skipUnless(
    os.environ.get('EXPANSION_JARS'),
    "EXPANSION_JARS environment var is not provided, "
    "indicating that jars have not been built")
class BigQueryVectorWriterConfigTest(unittest.TestCase):
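  """Integration tests for BigQueryVectorWriterConfig.

  Each test writes Chunk embeddings to a temporary BigQuery dataset and
  verifies the resulting rows with BigqueryFullResultMatcher.
  """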
  BIG_QUERY_DATASET_ID = 'python_rag_bigquery_'

  def setUp(self):
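    """Create a test pipeline and a uniquely named temporary dataset."""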
    self.test_pipeline = TestPipeline(is_integration_test=True)
    self._runner = type(self.test_pipeline.runner).__name__
    self.project = self.test_pipeline.get_option('project')
    self.bigquery_client = BigQueryWrapper()
    self.dataset_id = '%s%d%s' % (
        self.BIG_QUERY_DATASET_ID, int(time.time()), secrets.token_hex(3))
    self.bigquery_client.get_or_create_dataset(self.project, self.dataset_id)
    _LOGGER = logging.getLogger(__name__)
    _LOGGER.info(
        "Created dataset %s in project %s", self.dataset_id, self.project)

  def tearDown(self):
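    """Delete the temporary dataset; cleanup failures are only logged."""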
    request = bigquery.BigqueryDatasetsDeleteRequest(
        projectId=self.project, datasetId=self.dataset_id, deleteContents=True)
    try:
      _LOGGER = logging.getLogger(__name__)
      _LOGGER.info(
          "Deleting dataset %s in project %s", self.dataset_id, self.project)
      self.bigquery_client.client.datasets.Delete(request)
    # Failing to delete a dataset should not cause a test failure.
    except Exception:
      _LOGGER = logging.getLogger(__name__)
      _LOGGER.debug(
          'Failed to clean up dataset %s in project %s',
          self.dataset_id,
          self.project)

  def test_default_schema(self):
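    """Batch write using the default schema and verify all columns."""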
    table_name = 'python_default_schema_table'
    table_id = '{}.{}.{}'.format(self.project, self.dataset_id, table_name)
    config = BigQueryVectorWriterConfig(write_config={'table': table_id})
    chunks = [
        Chunk(
            id="1",
            embedding=Embedding(dense_embedding=[0.1, 0.2]),
            content=Content(text="foo"),
            metadata={"a": "b"}),
        Chunk(
            id="2",
            embedding=Embedding(dense_embedding=[0.3, 0.4]),
            content=Content(text="bar"),
            metadata={"c": "d"})
    ]
    pipeline_verifiers = [
        BigqueryFullResultMatcher(
            project=self.project,
            query="SELECT id, content, embedding, metadata FROM %s" % table_id,
            data=[("1", "foo", [0.1, 0.2], [{
                "key": "a", "value": "b"
            }]), ("2", "bar", [0.3, 0.4], [{
                "key": "c", "value": "d"
            }])])
    ]
    args = self.test_pipeline.get_full_options_as_args(
        on_success_matcher=hc.all_of(*pipeline_verifiers))

    with beam.Pipeline(argv=args) as p:
      _ = (p | beam.Create(chunks) | config.create_write_transform())

  def test_default_schema_missing_embedding(self):
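    """Chunks without a dense embedding are rejected by the default schema."""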
    table_name = 'python_default_schema_table'
    table_id = '{}.{}.{}'.format(self.project, self.dataset_id, table_name)
    config = BigQueryVectorWriterConfig(write_config={'table': table_id})
    chunks = [
        Chunk(id="1", content=Content(text="foo"), metadata={"a": "b"}),
        Chunk(id="2", content=Content(text="bar"), metadata={"c": "d"})
    ]
    with self.assertRaisesRegex(Exception, "must contain dense embedding"):
      with beam.Pipeline() as p:
        _ = (p | beam.Create(chunks) | config.create_write_transform())

  def test_custom_schema(self):
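    """Write with a user-provided schema and chunk_to_dict_fn."""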
    table_name = 'python_custom_schema_table'
    table_id = '{}.{}.{}'.format(self.project, self.dataset_id, table_name)
    schema_config = SchemaConfig(
        schema={
            'fields': [{
                'name': 'id', 'type': 'STRING'
            }, {
                'name': 'embedding', 'type': 'FLOAT64', 'mode': 'REPEATED'
            }, {
                'name': 'source', 'type': 'STRING'
            }]
        },
        chunk_to_dict_fn=lambda chunk: {
            'id': chunk.id,
            'embedding': chunk.embedding.dense_embedding,
            'source': chunk.metadata.get('source')
        })
    config = BigQueryVectorWriterConfig(
        write_config={'table': table_id}, schema_config=schema_config)
    chunks = [
        Chunk(
            id="1",
            embedding=Embedding(dense_embedding=[0.1, 0.2]),
            content=Content(text="foo content"),
            metadata={"source": "foo"}),
        Chunk(
            id="2",
            embedding=Embedding(dense_embedding=[0.3, 0.4]),
            content=Content(text="bar content"),
            metadata={"source": "bar"})
    ]
    pipeline_verifiers = [
        BigqueryFullResultMatcher(
            project=self.project,
            query="SELECT id, embedding, source FROM %s" % table_id,
            data=[("1", [0.1, 0.2], "foo"), ("2", [0.3, 0.4], "bar")])
    ]
    args = self.test_pipeline.get_full_options_as_args(
        on_success_matcher=hc.all_of(*pipeline_verifiers))

    with beam.Pipeline(argv=args) as p:
      _ = (p | beam.Create(chunks) | config.create_write_transform())

  def test_streaming_default_schema(self):
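    """Streaming write driven by PeriodicImpulse; Dataflow runner only."""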
    self.skip_if_not_dataflow_runner()

    table_name = 'python_streaming_default_schema_table'
    table_id = '{}.{}.{}'.format(self.project, self.dataset_id, table_name)
    config = BigQueryVectorWriterConfig(write_config={'table': table_id})
    chunks = [
        Chunk(
            id="1",
            embedding=Embedding(dense_embedding=[0.1, 0.2]),
            content=Content(text="foo"),
            metadata={"a": "b"}),
        Chunk(
            id="2",
            embedding=Embedding(dense_embedding=[0.3, 0.4]),
            content=Content(text="bar"),
            metadata={"c": "d"}),
        Chunk(
            id="3",
            embedding=Embedding(dense_embedding=[0.5, 0.6]),
            content=Content(text="foo"),
            metadata={"e": "f"}),
        Chunk(
            id="4",
            embedding=Embedding(dense_embedding=[0.7, 0.8]),
            content=Content(text="bar"),
            metadata={"g": "h"})
    ]
    pipeline_verifiers = [
        BigqueryFullResultMatcher(
            project=self.project,
            query="SELECT id, content, embedding, metadata FROM %s" % table_id,
            data=[("1", "foo", [0.1, 0.2], [{
                "key": "a", "value": "b"
            }]), ("2", "bar", [0.3, 0.4], [{
                "key": "c", "value": "d"
            }]), ("3", "foo", [0.5, 0.6], [{
                "key": "e", "value": "f"
            }]), ("4", "bar", [0.7, 0.8], [{
                "key": "g", "value": "h"
            }])])
    ]
    args = self.test_pipeline.get_full_options_as_args(
        on_success_matcher=hc.all_of(*pipeline_verifiers),
        streaming=True,
        allow_unsafe_triggers=True)
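
    # PeriodicImpulse(0, 4, 1) fires four impulses (timestamps 0 through 3);
    # each fired value is used to pick one of the four chunks above.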
    with beam.Pipeline(argv=args) as p:
      _ = (
          p
          | PeriodicImpulse(0, 4, 1)
          | beam.Map(lambda t: chunks[t])
          | config.create_write_transform())

  def skip_if_not_dataflow_runner(self):
    """Skip the test unless it is running on the Dataflow runner."""
    if not self._runner or "dataflowrunner" not in self._runner.lower():
      self.skipTest(
          "Streaming writes with the exactly-once route require "
          "`beam:requirement:pardo:on_window_expiration:v1`, "
          "which is currently only supported by the Dataflow runner")


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()