blob: c73aba5f42e46b7f6692a1a5f55be0c2639a1479 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import dataclass
from dataclasses import field
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from pymilvus import MilvusClient
from pymilvus.exceptions import MilvusException
import apache_beam as beam
from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig
from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig
from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpec
from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpecsBuilder
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.rag.utils import DEFAULT_WRITE_BATCH_SIZE
from apache_beam.ml.rag.utils import MilvusConnectionParameters
from apache_beam.ml.rag.utils import MilvusHelpers
from apache_beam.ml.rag.utils import retry_with_backoff
from apache_beam.ml.rag.utils import unpack_dataclass_with_kwargs
from apache_beam.transforms import DoFn
_LOGGER = logging.getLogger(__name__)
@dataclass
class MilvusWriteConfig:
"""Configuration parameters for writing data to Milvus collections.
This class defines the parameters needed to write data to a Milvus collection,
including collection targeting, batching behavior, and operation timeouts.
Args:
collection_name: Name of the target Milvus collection to write data to.
Must be a non-empty string.
partition_name: Name of the specific partition within the collection to
write to. If empty, writes to the default partition.
timeout: Maximum time in seconds to wait for write operations to complete.
If None, uses the client's default timeout.
write_config: Configuration for write operations including batch size and
other write-specific settings.
kwargs: Additional keyword arguments for write operations. Enables forward
compatibility with future Milvus client parameters.
"""
collection_name: str
partition_name: str = ""
timeout: Optional[float] = None
write_config: WriteConfig = field(default_factory=WriteConfig)
kwargs: Dict[str, Any] = field(default_factory=dict)
def __post_init__(self):
if not self.collection_name:
raise ValueError("Collection name must be provided")
@property
def write_batch_size(self):
"""Returns the batch size for write operations.
Returns:
The configured batch size, or DEFAULT_WRITE_BATCH_SIZE if not specified.
"""
return self.write_config.write_batch_size or DEFAULT_WRITE_BATCH_SIZE
@dataclass
class MilvusVectorWriterConfig(VectorDatabaseWriteConfig):
"""Configuration for writing vector data to Milvus collections.
This class extends VectorDatabaseWriteConfig to provide Milvus-specific
configuration for ingesting vector embeddings and associated metadata.
It defines how Apache Beam chunks are converted to Milvus records and
handles the write operation parameters.
The configuration includes connection parameters, write settings, and
column specifications that determine how chunk data is mapped to Milvus
fields.
Args:
connection_params: Configuration for connecting to the Milvus server,
including URI, credentials, and connection options.
write_config: Configuration for write operations including collection name,
partition, batch size, and timeouts.
column_specs: List of column specifications defining how chunk fields are
mapped to Milvus collection fields. Defaults to standard RAG fields
(id, embedding, sparse_embedding, content, metadata).
Example:
config = MilvusVectorWriterConfig(
connection_params=MilvusConnectionParameters(
uri="http://localhost:19530"),
write_config=MilvusWriteConfig(collection_name="my_collection"),
column_specs=MilvusVectorWriterConfig.default_column_specs())
"""
connection_params: MilvusConnectionParameters
write_config: MilvusWriteConfig
column_specs: List[ColumnSpec] = field(
default_factory=lambda: MilvusVectorWriterConfig.default_column_specs())
def create_converter(self) -> Callable[[Chunk], Dict[str, Any]]:
"""Creates a function to convert Apache Beam Chunks to Milvus records.
Returns:
A function that takes a Chunk and returns a dictionary representing
a Milvus record with fields mapped according to column_specs.
"""
def convert(chunk: Chunk) -> Dict[str, Any]:
result = {}
for col in self.column_specs:
result[col.column_name] = col.value_fn(chunk)
return result
return convert
def create_write_transform(self) -> beam.PTransform:
"""Creates the Apache Beam transform for writing to Milvus.
Returns:
A PTransform that can be applied to a PCollection of Chunks to write
them to the configured Milvus collection.
"""
return _WriteToMilvusVectorDatabase(self)
@staticmethod
def default_column_specs() -> List[ColumnSpec]:
"""Returns default column specifications for RAG use cases.
Creates column mappings for standard RAG fields: id, dense embedding,
sparse embedding, content text, and metadata. These specifications
define how Chunk fields are converted to Milvus-compatible formats.
Returns:
List of ColumnSpec objects defining the default field mappings.
"""
column_specs = ColumnSpecsBuilder()
return column_specs\
.with_id_spec()\
.with_embedding_spec(convert_fn=lambda values: list(values))\
.with_sparse_embedding_spec(conv_fn=MilvusHelpers.sparse_embedding)\
.with_content_spec()\
.with_metadata_spec(convert_fn=lambda values: dict(values))\
.build()
class _WriteToMilvusVectorDatabase(beam.PTransform):
"""Apache Beam PTransform for writing vector data to Milvus.
This transform handles the conversion of Apache Beam Chunks to Milvus records
and coordinates the write operations. It applies the configured converter
function and uses a DoFn for batched writes to optimize performance.
Args:
config: MilvusVectorWriterConfig containing all necessary parameters for
the write operation.
"""
def __init__(self, config: MilvusVectorWriterConfig):
self.config = config
def expand(self, pcoll: beam.PCollection[Chunk]):
"""Expands the PTransform to convert chunks and write to Milvus.
Args:
pcoll: PCollection of Chunk objects to write to Milvus.
Returns:
PCollection of dictionaries representing the records written to Milvus.
"""
return (
pcoll
| "Convert to Records" >> beam.Map(self.config.create_converter())
| beam.ParDo(
_WriteMilvusFn(
self.config.connection_params, self.config.write_config)))
class _WriteMilvusFn(DoFn):
"""DoFn that handles batched writes to Milvus.
This DoFn accumulates records in batches and flushes them to Milvus when
the batch size is reached or when the bundle finishes. This approach
optimizes performance by reducing the number of individual write operations.
Args:
connection_params: Configuration for connecting to the Milvus server.
write_config: Configuration for write operations including batch size
and collection details.
"""
def __init__(
self,
connection_params: MilvusConnectionParameters,
write_config: MilvusWriteConfig):
self._connection_params = connection_params
self._write_config = write_config
self.batch = []
def process(self, element, *args, **kwargs):
"""Processes individual records, batching them for efficient writes.
Args:
element: A dictionary representing a Milvus record to write.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Yields:
The original element after adding it to the batch.
"""
_ = args, kwargs # Unused parameters
self.batch.append(element)
if len(self.batch) >= self._write_config.write_batch_size:
self._flush()
yield element
def finish_bundle(self):
"""Called when a bundle finishes processing.
Flushes any remaining records in the batch to ensure all data is written.
"""
self._flush()
def _flush(self):
"""Flushes the current batch of records to Milvus.
Creates a MilvusSink connection and writes all batched records,
then clears the batch for the next set of records.
"""
if len(self.batch) == 0:
return
with _MilvusSink(self._connection_params, self._write_config) as sink:
sink.write(self.batch)
self.batch = []
def display_data(self):
"""Returns display data for monitoring and debugging.
Returns:
Dictionary containing database, collection, and batch size information
for display in the Apache Beam monitoring UI.
"""
res = super().display_data()
res["database"] = self._connection_params.db_name
res["collection"] = self._write_config.collection_name
res["batch_size"] = self._write_config.write_batch_size
return res
class _MilvusSink:
"""Low-level sink for writing data directly to Milvus.
This class handles the direct interaction with the Milvus client for
upsert operations. It manages the connection lifecycle and provides
context manager support for proper resource cleanup.
Args:
connection_params: Configuration for connecting to the Milvus server.
write_config: Configuration for write operations including collection
and partition targeting.
"""
def __init__(
self,
connection_params: MilvusConnectionParameters,
write_config: MilvusWriteConfig):
self._connection_params = connection_params
self._write_config = write_config
self._client = None
def write(self, documents):
"""Writes a batch of documents to the Milvus collection.
Performs an upsert operation to insert new documents or update existing
ones based on primary key. After the upsert, flushes the collection to
ensure data persistence.
Args:
documents: List of dictionaries representing Milvus records to write.
Each dictionary should contain fields matching the collection schema.
"""
self._client = MilvusClient(
**unpack_dataclass_with_kwargs(self._connection_params))
resp = self._client.upsert(
collection_name=self._write_config.collection_name,
partition_name=self._write_config.partition_name,
data=documents,
timeout=self._write_config.timeout,
**self._write_config.kwargs)
_LOGGER.debug(
"Upserted into Milvus: upsert_count=%d, cost=%d",
resp.get("upsert_count", 0),
resp.get("cost", 0))
def __enter__(self):
"""Enters the context manager and establishes Milvus connection.
Returns:
Self, enabling use in 'with' statements.
"""
if not self._client:
connection_params = unpack_dataclass_with_kwargs(self._connection_params)
# Extract retry parameters from connection_params.
max_retries = connection_params.pop('max_retries', 3)
retry_delay = connection_params.pop('retry_delay', 1.0)
retry_backoff_factor = connection_params.pop('retry_backoff_factor', 2.0)
def create_client():
return MilvusClient(**connection_params)
self._client = retry_with_backoff(
create_client,
max_retries=max_retries,
retry_delay=retry_delay,
retry_backoff_factor=retry_backoff_factor,
operation_name="Milvus connection",
exception_types=(MilvusException, ))
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Exits the context manager and closes the Milvus connection.
Args:
exc_type: Exception type if an exception was raised.
exc_val: Exception value if an exception was raised.
exc_tb: Exception traceback if an exception was raised.
"""
_ = exc_type, exc_val, exc_tb # Unused parameters
if self._client:
self._client.close()