| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| import logging |
| from abc import ABC |
| from abc import abstractmethod |
| from typing import Callable |
| from typing import List |
| from typing import NamedTuple |
| from typing import Optional |
| |
| import apache_beam as beam |
| from apache_beam.coders import registry |
| from apache_beam.coders.row_coder import RowCoder |
| from apache_beam.io.jdbc import WriteToJdbc |
| from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig |
| from apache_beam.ml.rag.ingestion.jdbc_common import ConnectionConfig |
| from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig |
| from apache_beam.ml.rag.ingestion.mysql_common import ColumnSpec |
| from apache_beam.ml.rag.ingestion.mysql_common import ColumnSpecsBuilder |
| from apache_beam.ml.rag.ingestion.mysql_common import ConflictResolution |
| from apache_beam.ml.rag.types import Chunk |
| |
| _LOGGER = logging.getLogger(__name__) |
| |
| |
class _ConflictResolutionStrategy(ABC):
  """Interface for objects that render the conflict part of an INSERT.

  Each concrete strategy returns the text that is appended to the generated
  MySQL ``INSERT`` statement to decide what happens on a duplicate key.
  """
  @abstractmethod
  def get_conflict_clause(self, all_columns: List[str]) -> str:
    """Return the MySQL conflict clause.

    Args:
      all_columns: Names of all columns written by the INSERT statement.

    Returns:
      The clause text; its content is defined by the concrete strategy.
    """
| |
| |
class _NoConflictStrategy(_ConflictResolutionStrategy):
  """Conflict strategy that contributes nothing to the INSERT statement."""
  def get_conflict_clause(self, all_columns: List[str]) -> str:
    """Return an empty clause; ``all_columns`` is not consulted."""
    no_clause = ""
    return no_clause
| |
| |
class _UpdateStrategy(_ConflictResolutionStrategy):
  """Strategy for the UPDATE action on conflict.

  Renders ``ON DUPLICATE KEY UPDATE col = VALUES(col), ...`` so that the
  listed columns take the values of the row being inserted.
  """
  def __init__(self, update_fields: Optional[List[str]] = None):
    """Store the columns to overwrite on conflict.

    Args:
      update_fields: Column names to update. When ``None`` or empty, every
        column passed to :meth:`get_conflict_clause` is updated.
    """
    self.update_fields = update_fields

  def get_conflict_clause(self, all_columns: List[str]) -> str:
    """Return the ``ON DUPLICATE KEY UPDATE`` clause.

    Args:
      all_columns: Names of all columns written by the INSERT statement;
        used as the fallback when no explicit update fields were given.

    Returns:
      The clause with one ``col = VALUES(col)`` assignment per column.

    Raises:
      ValueError: If neither ``update_fields`` nor ``all_columns`` contains
        a column, since the clause would be syntactically incomplete.
    """
    # Use provided fields or default to all columns.
    fields_to_update = self.update_fields or all_columns
    # Explicit check rather than ``assert`` so that the validation still
    # runs when Python is started with ``-O``.
    if not fields_to_update:
      raise ValueError(
          "At least one column is required to build an "
          "ON DUPLICATE KEY UPDATE clause.")

    assignments = ", ".join(
        f"{field} = VALUES({field})" for field in fields_to_update)
    return f"ON DUPLICATE KEY UPDATE {assignments}"
| |
| |
class _IgnoreStrategy(_ConflictResolutionStrategy):
  """Strategy for the IGNORE action on conflict.

  The rendered clause assigns the primary key column to itself.
  """
  def __init__(self, primary_key_field: str):
    """Remember the column used for the self-assignment.

    Args:
      primary_key_field: Name of the primary key column.
    """
    self.primary_key_field = primary_key_field

  def get_conflict_clause(self, all_columns: List[str]) -> str:
    """Return ``ON DUPLICATE KEY UPDATE <pk> = <pk>``.

    ``all_columns`` is accepted for interface compatibility and not used.
    """
    key = self.primary_key_field
    self_assignment = f"{key} = {key}"
    return "ON DUPLICATE KEY UPDATE " + self_assignment
| |
| |
def _create_conflict_strategy(
    conflict_resolution: Optional[ConflictResolution]
) -> _ConflictResolutionStrategy:
  """Select the strategy object matching a conflict resolution setting.

  Args:
    conflict_resolution: The requested conflict handling, or ``None`` when
      no conflict clause should be generated.

  Returns:
    ``_NoConflictStrategy`` for ``None``, ``_UpdateStrategy`` for the
    ``"UPDATE"`` action and ``_IgnoreStrategy`` for the ``"IGNORE"`` action.

  Raises:
    ValueError: If the action is not recognised, or if the ``"IGNORE"``
      action is requested without a ``primary_key_field``.
  """
  if conflict_resolution is None:
    return _NoConflictStrategy()
  if conflict_resolution.action == "UPDATE":
    return _UpdateStrategy(conflict_resolution.update_fields)
  if conflict_resolution.action == "IGNORE":
    # Explicit check rather than ``assert``: under ``python -O`` an assert
    # is removed and the clause would be rendered as "None = None".
    if conflict_resolution.primary_key_field is None:
      raise ValueError(
          "primary_key_field is required for the IGNORE conflict "
          "resolution action.")
    return _IgnoreStrategy(conflict_resolution.primary_key_field)
  raise ValueError(f"Unknown conflict resolution {conflict_resolution.action}")
| |
| |
class _MySQLQueryBuilder:
  """Builds the INSERT statement and record type for writing to MySQL.

  From a list of column specs this class derives a NamedTuple record type
  (one field per column), the parameterised INSERT statement, and a function
  that converts a Chunk into an instance of the record type.
  """
  def __init__(
      self,
      table_name: str,
      *,
      column_specs: List[ColumnSpec],
      conflict_resolution: Optional[ConflictResolution] = None):
    """Builds SQL queries for writing Chunks with Embeddings to MySQL.

    Creates the record type and registers ``RowCoder`` for it in the coder
    registry.

    Args:
      table_name: Target table name, inserted verbatim into the statement.
      column_specs: One spec per column to write; column names must be
        unique.
      conflict_resolution: Optional conflict handling; ``None`` produces an
        INSERT without a conflict clause.

    Raises:
      ValueError: If two column specs share the same column name, or if the
        conflict resolution is invalid.
    """
    self.table_name = table_name

    self.column_specs = column_specs
    self.conflict_resolution_strategy = _create_conflict_strategy(
        conflict_resolution)

    # Single pass over the names instead of calling list.count per name.
    seen = set()
    duplicates = set()
    for col in self.column_specs:
      if col.column_name in seen:
        duplicates.add(col.column_name)
      seen.add(col.column_name)
    if duplicates:
      raise ValueError(f"Duplicate column names found: {duplicates}")

    fields = [(col.column_name, col.python_type) for col in self.column_specs]
    # NamedTuple type names must be valid identifiers, but table names may
    # be schema-qualified or quoted (e.g. "db.table"). Replace every
    # character that is not an ASCII letter, digit or underscore; names that
    # are already identifiers are left unchanged.
    safe_table_name = "".join(
        ch if ch == "_" or (ch.isascii() and ch.isalnum()) else "_"
        for ch in table_name)
    type_name = f"VectorRecord_{safe_table_name}"
    self.record_type = NamedTuple(type_name, fields)  # type: ignore

    registry.register_coder(self.record_type, RowCoder)

  def build_insert(self) -> str:
    """Return the INSERT statement with one placeholder per column.

    The conflict clause of the configured strategy is appended to the
    statement, and the final text is logged at INFO level.
    """
    fields = [col.column_name for col in self.column_specs]
    placeholders = [col.placeholder for col in self.column_specs]

    # Build base query
    query = f"""
        INSERT INTO {self.table_name}
        ({', '.join(fields)})
        VALUES ({', '.join(placeholders)})
        """
    conflict_clause = self.conflict_resolution_strategy.get_conflict_clause(
        all_columns=fields)
    query += f" {conflict_clause}"

    _LOGGER.info("MySQL Query with placeholders %s", query)
    return query

  def create_converter(self) -> Callable[[Chunk], NamedTuple]:
    """Creates a function to convert Chunks to records.

    Returns:
      A callable that builds one ``record_type`` instance per Chunk, filling
      each field with the result of the matching column spec's ``value_fn``.
    """
    def convert(chunk: Chunk) -> self.record_type:  # type: ignore
      return self.record_type(
          **{col.column_name: col.value_fn(chunk)
             for col in self.column_specs})  # type: ignore

    return convert
| |
| |
class MySQLVectorWriterConfig(VectorDatabaseWriteConfig):
  """Write configuration for storing Chunks in a MySQL table over JDBC."""
  def __init__(
      self,
      connection_config: ConnectionConfig,
      table_name: str,
      *,
      # pylint: disable=dangerous-default-value
      write_config: WriteConfig = WriteConfig(),
      column_specs: Optional[List[ColumnSpec]] = ColumnSpecsBuilder.
      with_defaults().build(),
      conflict_resolution: Optional[ConflictResolution] = None):
    """Configuration for writing vectors to MySQL using jdbc.

    Supports flexible schema configuration through column specifications and
    conflict resolution strategies with MySQL-specific syntax.

    Args:
      connection_config:
        :class:`~apache_beam.ml.rag.ingestion.jdbc_common.ConnectionConfig`.
      table_name: Target table name.
      write_config: JdbcIO :class:`~.jdbc_common.WriteConfig` to control
        batch sizes, autosharding, etc.
      column_specs:
        Use :class:`~.mysql_common.ColumnSpecsBuilder` to configure how
        embeddings and metadata are written to the database
        schema. If None, the specs produced by
        ``ColumnSpecsBuilder.with_defaults()`` are used.
      conflict_resolution: Optional
        :class:`~.mysql_common.ConflictResolution`
        strategy for handling insert conflicts, rendered as an
        ``ON DUPLICATE KEY UPDATE`` clause.
        None by default, meaning no conflict clause is added to the INSERT
        statement.

    Examples:
      Simple case with default schema:

      >>> config = MySQLVectorWriterConfig(
      ...     connection_config=ConnectionConfig(...),
      ...     table_name='embeddings'
      ... )

      Custom schema with metadata fields and MySQL functions:

      >>> specs = (ColumnSpecsBuilder()
      ...         .with_id_spec(column_name="my_id_column")
      ...         .with_embedding_spec(
      ...             column_name="embedding_vec",
      ...             placeholder="string_to_vector(?)"
      ...         )
      ...         .add_metadata_field(field="source", column_name="src")
      ...         .add_metadata_field(
      ...             "timestamp",
      ...             column_name="created_at",
      ...             placeholder="STR_TO_DATE(?, '%Y-%m-%d %H:%i:%s')"
      ...         )
      ...         .build())

      Minimal schema (only ID + embedding written):

      >>> column_specs = (ColumnSpecsBuilder()
      ...     .with_id_spec()
      ...     .with_embedding_spec()
      ...     .build())

      Custom specs combined with conflict resolution:

      >>> config = MySQLVectorWriterConfig(
      ...     connection_config=ConnectionConfig(...),
      ...     table_name='embeddings',
      ...     column_specs=specs,
      ...     conflict_resolution=ConflictResolution(
      ...         on_conflict_fields=["id"],
      ...         action="UPDATE",
      ...         update_fields=["embedding", "content"]
      ...     )
      ... )

      Using MySQL JSON functions:

      >>> specs = (ColumnSpecsBuilder()
      ...     .with_id_spec()
      ...     .with_embedding_spec()
      ...     .with_metadata_spec(
      ...         column_name="metadata_json",
      ...         placeholder="CAST(? AS JSON)"
      ...     )
      ...     .build())
    """
    # Honour the documented "None means defaults" contract; without this the
    # query builder would fail while iterating over None.
    if column_specs is None:
      column_specs = ColumnSpecsBuilder.with_defaults().build()

    self.connection_config = connection_config
    self.write_config = write_config
    # NamedTuple is created and registered here during pipeline construction.
    # The specs are copied because the default list is built once, when the
    # signature is evaluated, and would otherwise be shared by every config.
    self.query_builder = _MySQLQueryBuilder(
        table_name,
        column_specs=list(column_specs),
        conflict_resolution=conflict_resolution)

  def create_write_transform(self) -> beam.PTransform:
    """Return the PTransform that writes Chunks using this configuration."""
    return _WriteToMySQLVectorDatabase(self)
| |
| |
class _WriteToMySQLVectorDatabase(beam.PTransform):
  """PTransform that converts Chunks to records and writes them via JDBC."""
  def __init__(self, config: MySQLVectorWriterConfig):
    """Keep the config and the parts of it that ``expand`` reads.

    Args:
      config: Writer configuration providing the query builder, connection
        settings and write settings.
    """
    self.config = config
    self.write_config = config.write_config
    self.connection_config = config.connection_config
    self.query_builder = config.query_builder

  def expand(self, pcoll: beam.PCollection[Chunk]):
    """Map each Chunk to a record and hand the records to ``WriteToJdbc``.

    Args:
      pcoll: PCollection of Chunks to write.

    Returns:
      The result of applying ``WriteToJdbc`` to the converted records.
    """
    builder = self.query_builder
    connection = self.connection_config
    write_options = self.write_config

    # Step 1: Chunk -> NamedTuple record matching the INSERT placeholders.
    records = (
        pcoll
        | "Convert to Records" >> beam.Map(builder.create_converter()))

    # Step 2: configure the JDBC sink with the generated INSERT statement.
    jdbc_sink = WriteToJdbc(
        table_name=builder.table_name,
        driver_class_name="com.mysql.cj.jdbc.Driver",
        jdbc_url=connection.jdbc_url,
        username=connection.username,
        password=connection.password,
        statement=builder.build_insert(),
        connection_properties=connection.connection_properties,
        connection_init_sqls=connection.connection_init_sqls,
        autosharding=write_options.autosharding,
        max_connections=write_options.max_connections,
        write_batch_size=write_options.write_batch_size,
        **connection.additional_jdbc_args)

    return records | "Write to MySQL" >> jdbc_sink