blob: 3d41b0b11b114df25b8515de0894c16da9359b3b [file]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Callable
from typing import NamedTuple
from typing import Optional
from typing import Union
import apache_beam as beam
from apache_beam.coders import registry
from apache_beam.coders.row_coder import RowCoder
from apache_beam.io.jdbc import WriteToJdbc
from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig
from apache_beam.ml.rag.ingestion.jdbc_common import ConnectionConfig
from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig
from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpec
from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpecsBuilder
from apache_beam.ml.rag.ingestion.postgres_common import ConflictResolution
from apache_beam.ml.rag.types import EmbeddableItem
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)
# Type alias: either a single ColumnSpec or a dict of str keys to ColumnSpec.
MetadataSpec = Union[ColumnSpec, dict[str, ColumnSpec]]
class _PostgresQueryBuilder:
  """Builds the INSERT statement and record type for a Postgres write.

  On construction this validates the column specs, creates a NamedTuple type
  with one field per column, and registers a RowCoder for that type.
  """
  def __init__(
      self,
      table_name: str,
      *,
      column_specs: list[ColumnSpec],
      conflict_resolution: Optional[ConflictResolution] = None):
    """Builds SQL queries for writing EmbeddableItems to Postgres.

    Args:
      table_name: Target table name. It is inserted verbatim into the
        generated SQL; it may be schema-qualified (e.g. ``public.items``).
      column_specs: Column specifications; each one supplies the column
        name, Python type, SQL placeholder and value-extraction function.
      conflict_resolution: Optional conflict handling. When set, its
        ``maybe_set_default_update_fields`` is called with all column names.

    Raises:
      ValueError: If two column specs share the same column name.
    """
    self.table_name = table_name
    self.column_specs = column_specs
    self.conflict_resolution = conflict_resolution

    # Validate no duplicate column names (single pass instead of
    # list.count() per element).
    seen = set()
    duplicates = set()
    for col in self.column_specs:
      if col.column_name in seen:
        duplicates.add(col.column_name)
      else:
        seen.add(col.column_name)
    if duplicates:
      raise ValueError(f"Duplicate column names found: {duplicates}")

    # Create NamedTuple type. The type name must be a valid Python
    # identifier, so it is derived from a sanitized form of the table name.
    fields = [(col.column_name, col.python_type) for col in self.column_specs]
    type_name = self._record_type_name(table_name)
    self.record_type = NamedTuple(type_name, fields)  # type: ignore

    # Register coder
    registry.register_coder(self.record_type, RowCoder)

    # Set default update fields to all non-conflict fields if update fields are
    # not specified
    if self.conflict_resolution:
      self.conflict_resolution.maybe_set_default_update_fields(
          [col.column_name for col in self.column_specs if col.column_name])

  @staticmethod
  def _record_type_name(table_name: str) -> str:
    """Returns a valid identifier to use as the record type's name.

    Characters that cannot appear in a Python identifier (for example the
    ``.`` of a schema-qualified name or quoting characters) are replaced with
    ``_``. Table names that are already valid identifiers are unchanged, so
    the result is ``VectorRecord_<table_name>`` in that case.
    """
    sanitized = "".join(
        ch if f"_{ch}".isidentifier() else "_" for ch in table_name)
    return f"VectorRecord_{sanitized}"

  def build_insert(self) -> str:
    """Build INSERT query with proper type casting.

    Returns:
      The INSERT statement with one placeholder per column spec, followed by
      the conflict clause when a conflict resolution is configured.
    """
    # Get column names and placeholders
    fields = [col.column_name for col in self.column_specs]
    placeholders = [col.placeholder for col in self.column_specs]

    # Build base query
    query = (
        f"\nINSERT INTO {self.table_name}\n"
        f"({', '.join(fields)})\n"
        f"VALUES ({', '.join(placeholders)})\n")

    # Add conflict handling if configured
    if self.conflict_resolution:
      query += f" {self.conflict_resolution.get_conflict_clause()}"

    _LOGGER.info("Query with placeholders %s", query)
    return query

  def create_converter(self) -> Callable[[EmbeddableItem], NamedTuple]:
    """Creates a function to convert EmbeddableItems to records.

    Returns:
      A callable that builds one ``record_type`` instance from an item by
      applying every column spec's ``value_fn`` to it.
    """
    def convert(chunk: EmbeddableItem) -> self.record_type:  # type: ignore
      return self.record_type(
          **{col.column_name: col.value_fn(chunk)
             for col in self.column_specs})  # type: ignore

    return convert
class PostgresVectorWriterConfig(VectorDatabaseWriteConfig):
  def __init__(
      self,
      connection_config: ConnectionConfig,
      table_name: str,
      *,
      # pylint: disable=dangerous-default-value
      write_config: WriteConfig = WriteConfig(),
      column_specs: Optional[list[ColumnSpec]] = (
          ColumnSpecsBuilder.with_defaults().build()),
      conflict_resolution: Optional[ConflictResolution] = ConflictResolution(
          on_conflict_fields=[], action='IGNORE')):
    """Configuration for writing vectors to Postgres using jdbc.

    Supports flexible schema configuration through column specifications and
    conflict resolution strategies.

    Args:
      connection_config:
        :class:`~apache_beam.ml.rag.ingestion.jdbc_common.ConnectionConfig`.
      table_name: Target table name.
      write_config: JdbcIO :class:`~.jdbc_common.WriteConfig` to control
        batch sizes, autosharding, etc.
      column_specs:
        Use :class:`~.postgres_common.ColumnSpecsBuilder` to configure how
        embeddings and metadata are written to a database
        schema. If None, uses default EmbeddableItem schema.
      conflict_resolution: Optional
        :class:`~.postgres_common.ConflictResolution`
        strategy for handling insert conflicts. ON CONFLICT DO NOTHING by
        default.

    Examples:
      Simple case with default schema:

      >>> config = PostgresVectorWriterConfig(
      ...     connection_config=ConnectionConfig(...),
      ...     table_name='embeddings'
      ... )

      Custom schema with metadata fields:

      >>> specs = (ColumnSpecsBuilder()
      ...         .with_id_spec(column_name="my_id_column")
      ...         .with_embedding_spec(column_name="embedding_vec")
      ...         .add_metadata_field(field="source", column_name="src")
      ...         .add_metadata_field(
      ...             "timestamp",
      ...             column_name="created_at",
      ...             sql_typecast="::timestamp"
      ...         )
      ...         .build())
      >>> config = PostgresVectorWriterConfig(
      ...     connection_config=ConnectionConfig(...),
      ...     table_name='embeddings',
      ...     column_specs=specs
      ... )

      Minimal schema (only ID + embedding written):

      >>> minimal_specs = (ColumnSpecsBuilder()
      ...     .with_id_spec()
      ...     .with_embedding_spec()
      ...     .build())
      >>> config = PostgresVectorWriterConfig(
      ...     connection_config=ConnectionConfig(...),
      ...     table_name='embeddings',
      ...     column_specs=minimal_specs
      ... )
    """
    # Honor the documented contract: an explicit None selects the default
    # schema rather than failing later when the specs are iterated.
    if column_specs is None:
      column_specs = ColumnSpecsBuilder.with_defaults().build()

    self.connection_config = connection_config
    self.write_config = write_config
    # NamedTuple is created and registered here during pipeline construction
    # NOTE(review): the default conflict_resolution instance is shared across
    # calls and is handed to the query builder, which calls
    # maybe_set_default_update_fields on it -- confirm that this is a no-op
    # for action='IGNORE'.
    self.query_builder = _PostgresQueryBuilder(
        table_name,
        column_specs=column_specs,
        conflict_resolution=conflict_resolution)

  def create_write_transform(self) -> beam.PTransform:
    """Returns the PTransform that performs the write for this config."""
    return _WriteToPostgresVectorDatabase(self)
class _WriteToPostgresVectorDatabase(beam.PTransform):
  """PTransform that converts items to records and writes them via JDBC."""
  def __init__(self, config: PostgresVectorWriterConfig):
    """Keeps the config and its connection, write and query-builder parts."""
    self.config = config
    self.connection_config = config.connection_config
    self.write_config = config.write_config
    self.query_builder = config.query_builder

  def expand(self, pcoll: beam.PCollection[EmbeddableItem]):
    """Maps each item to a record, then writes the records with WriteToJdbc."""
    builder = self.query_builder
    conn = self.connection_config
    write_opts = self.write_config

    # Step 1: turn every EmbeddableItem into the builder's record type.
    to_record = builder.create_converter()
    records = pcoll | "Convert to Records" >> beam.Map(to_record)

    # Step 2: write the records using the generated INSERT statement.
    jdbc_sink = WriteToJdbc(
        table_name=builder.table_name,
        driver_class_name="org.postgresql.Driver",
        jdbc_url=conn.jdbc_url,
        username=conn.username,
        password=conn.password,
        statement=builder.build_insert(),
        connection_properties=conn.connection_properties,
        connection_init_sqls=conn.connection_init_sqls,
        autosharding=write_opts.autosharding,
        max_connections=write_opts.max_connections,
        write_batch_size=write_opts.write_batch_size,
        **conn.additional_jdbc_args)
    return records | "Write to Postgres" >> jdbc_sink