blob: abe9efc56cbfded0cd3865387633dcd4cedca0be [file]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from collections.abc import Callable
from dataclasses import dataclass
from dataclasses import field
from typing import Any
from typing import Optional
import grpc
from objsize import get_deep_size
try:
from qdrant_client import QdrantClient
from qdrant_client import models
from qdrant_client.common.client_exceptions import ResourceExhaustedResponse
from qdrant_client.http.exceptions import ResponseHandlingException
from qdrant_client.http.exceptions import UnexpectedResponse
except ImportError:
logging.warning("Qdrant client library is not installed.")
import apache_beam as beam
from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig
from apache_beam.ml.rag.types import EmbeddableItem
DEFAULT_WRITE_BATCH_SIZE = 1000
DEFAULT_MAX_BATCH_BYTE_SIZE = 4 << 20
@dataclass
class QdrantConnectionParameters:
"""Configuration parameters for connecting to Qdrant service.
Either `location`, `url`, `host`, or `path` must be provided to establish
a connection.
Args:
location:
If `str` - use it as a `url` parameter.
If `None` - use default values for `host` and `port`.
url: either host or str of "<scheme>//<host>:<port>/<prefix>".
Default: `None`
port: Port of the REST API interface. Default: 6333
grpc_port: Port of the gRPC interface. Default: 6334
prefer_grpc: If `true` - use gPRC interface whenever possible.
https: If `true` - use HTTPS(SSL) protocol. Default: `None`
api_key: API key for authentication in Qdrant Cloud. Default: `None`
prefix:
If not `None` - add `prefix` to the REST URL path.
Example: `service/v1` will result in
`http://localhost:6333/service/v1/{qdrant-endpoint}` for REST API.
Default: `None`
timeout:
Timeout for REST and gRPC API requests.
Default: 5 seconds for REST and unlimited for gRPC
host:
Host name of Qdrant service.
If url and host are None, set to 'localhost'.
Default: `None`
path: Persistence path for QdrantLocal. Default: `None`
**kwargs: Additional arguments passed directly into client initialization
"""
location: Optional[str] = None
url: Optional[str] = None
port: Optional[int] = 6333
grpc_port: int = 6334
prefer_grpc: bool = False
https: Optional[bool] = None
api_key: Optional[str] = None
prefix: Optional[str] = None
timeout: Optional[int] = None
host: Optional[str] = None
path: Optional[str] = None
kwargs: dict[str, Any] = field(default_factory=dict)
def __post_init__(self):
if not (self.location or self.url or self.host or self.path):
raise ValueError(
"One of location, url, host, or path must be provided for Qdrant")
@classmethod
def for_cloud(
cls,
url: str,
api_key: str,
*,
prefer_grpc: bool = False,
timeout: Optional[int] = None,
**kwargs: Any,
) -> "QdrantConnectionParameters":
"""Connect to Qdrant Cloud. Requires the cluster URL and an API key."""
return cls(
url=url,
api_key=api_key,
https=True,
prefer_grpc=prefer_grpc,
timeout=timeout,
kwargs=kwargs,
)
@classmethod
def for_host(
cls,
host: str,
port: int = 6333,
*,
grpc_port: int = 6334,
prefer_grpc: bool = False,
https: bool = False,
api_key: Optional[str] = None,
timeout: Optional[int] = None,
**kwargs: Any,
) -> "QdrantConnectionParameters":
"""Connect to a self-hosted Qdrant instance by host and port."""
return cls(
host=host,
port=port,
grpc_port=grpc_port,
prefer_grpc=prefer_grpc,
https=https,
api_key=api_key,
timeout=timeout,
kwargs=kwargs,
)
@classmethod
def for_url(
cls,
url: str,
*,
api_key: Optional[str] = None,
prefer_grpc: bool = False,
timeout: Optional[int] = None,
**kwargs: Any,
) -> "QdrantConnectionParameters":
"""Connect using a full URL like 'https://my-qdrant.example.com:6333'."""
return cls(
url=url,
api_key=api_key,
prefer_grpc=prefer_grpc,
timeout=timeout,
kwargs=kwargs)
@classmethod
def local(cls, path: str) -> "QdrantConnectionParameters":
"""Use an embedded Qdrant instance persisted to the given path."""
return cls(path=path)
@classmethod
def in_memory(cls) -> "QdrantConnectionParameters":
"""Use an embedded in-memory Qdrant instance. Useful for tests."""
return cls(location=":memory:")
@dataclass
class QdrantWriteConfig(VectorDatabaseWriteConfig):
"""Configuration for writing to Qdrant vector database.
This class defines the parameters needed to write data to a qdrant collection,
including collection targeting, batching behavior, and operation timeouts.
Args:
connection_params: QdrantConnectionParameters with connection settings.
collection_name: Name of the Qdrant collection to write to.
timeout: Optional timeout for write operations in seconds. Default is None.
batch_size: Number of points to write in each batch. Default is 1000.
kwargs: Additional keyword arguments to pass to the client's upsert method.
dense_embedding_key: name for the dense vector in the qdrant collection.
sparse_embedding_key: name for the sparse vector in the qdrant collection.
"""
connection_params: QdrantConnectionParameters
collection_name: str
timeout: Optional[int] = None
batch_size: int = DEFAULT_WRITE_BATCH_SIZE
max_batch_byte_size: int = DEFAULT_MAX_BATCH_BYTE_SIZE
kwargs: dict[str, Any] = field(default_factory=dict)
dense_embedding_key: str = "dense"
sparse_embedding_key: str = "sparse"
def __post_init__(self):
if not self.collection_name:
raise ValueError("Collection name must be provided")
if self.batch_size <= 0:
raise ValueError("Batch size must be a positive integer")
def create_write_transform(self) -> beam.PTransform[EmbeddableItem, Any]:
return _QdrantWriteTransform(self)
def create_converter(
self,
) -> Callable[[EmbeddableItem], "models.PointStruct"]:
def convert(item: EmbeddableItem) -> "models.PointStruct":
if item.dense_embedding is None and item.sparse_embedding is None:
raise ValueError(
"EmbeddableItem must have at least one embedding (dense or sparse)")
vector = {}
if item.dense_embedding is not None:
vector[self.dense_embedding_key] = item.dense_embedding
if item.sparse_embedding is not None:
sparse_indices, sparse_values = item.sparse_embedding
vector[self.sparse_embedding_key] = models.SparseVector(
indices=sparse_indices,
values=sparse_values,
)
id = (
int(item.id)
if isinstance(item.id, str) and item.id.isdigit() else item.id)
return models.PointStruct(
id=id,
vector=vector,
payload=item.metadata if item.metadata else None,
)
return convert
class _QdrantWriteTransform(beam.PTransform):
def __init__(self, config: QdrantWriteConfig):
self.config = config
def expand(self, input_or_inputs: beam.PCollection[EmbeddableItem]):
return (
input_or_inputs
| "Convert to Records" >> beam.Map(self.config.create_converter())
| beam.ParDo(_QdrantWriteFn(self.config)))
class _QdrantWriteFn(beam.DoFn):
def __init__(self, config: QdrantWriteConfig):
self.config = config
self._client: "Optional[QdrantClient]" = None
def start_bundle(self):
self._batch = []
self._batch_byte_size = 0
def process(self, element, *args, **kwargs):
element_byte_size = get_deep_size(element)
new_batch_byte_size = self._batch_byte_size + element_byte_size
is_batch_full = len(self._batch) >= self.config.batch_size
is_batch_too_large = new_batch_byte_size > self.config.max_batch_byte_size
if (is_batch_full or is_batch_too_large):
self._flush()
self._batch.append(element)
self._batch_byte_size += element_byte_size
def setup(self):
params = self.config.connection_params
self._client = QdrantClient(
location=params.location,
url=params.url,
port=params.port,
grpc_port=params.grpc_port,
prefer_grpc=params.prefer_grpc,
https=params.https,
api_key=params.api_key,
prefix=params.prefix,
timeout=params.timeout,
host=params.host,
path=params.path,
check_compatibility=False,
**params.kwargs,
)
def teardown(self):
if self._client:
try:
self._client.close()
finally:
self._client = None
def finish_bundle(self):
self._flush()
def _flush(self):
if not self._batch:
return
if not self._client:
raise RuntimeError("Qdrant client is not initialized")
max_retries = 3
attempt = 1
while True:
try:
self._client.upsert(
collection_name=self.config.collection_name,
points=self._batch,
timeout=self.config.timeout,
**self.config.kwargs,
)
break
except ResourceExhaustedResponse as e:
time.sleep(e.retry_after_s)
# don't count rate-limit against max_retries
continue
except (UnexpectedResponse, ResponseHandlingException,
grpc.RpcError) as e:
if attempt > max_retries:
raise
time.sleep(2**attempt)
attempt += 1
self._batch = []
self._batch_byte_size = 0
def display_data(self):
res = super().display_data()
res["collection"] = self.config.collection_name
res["batch_size"] = self.config.batch_size
res["max_batch_byte_size"] = self.config.max_batch_byte_size
return res