blob: 70397550241f09cbb0e0efa276835fcb0739064f [file]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module implements enrichment classes for semantic search on an
OpenSearch vector database.

OpenSearch Enrichment Handler
-----------------------------
:class:`OpenSearchEnrichmentHandler` is an ``EnrichmentSourceHandler`` that
performs enrichment by running a k-NN search in the knowledge base (an
OpenSearch vector index) for documents similar to the user's query/prompt
embedding, and returns the search response containing those documents as a
``beam.Row`` object.

Example usage::

    opensearch_handler = OpenSearchEnrichmentHandler(
        opensearch_host='127.0.0.1',
        opensearch_port=9200,
        username='<user>',
        password='<password>')
    pipeline | Enrichment(opensearch_handler)

No backward compatibility guarantees. Everything in this module is experimental.
"""
import logging
from opensearchpy import OpenSearch
from typing import Optional
import os
import apache_beam as beam
from apache_beam.transforms.enrichment import EnrichmentSourceHandler
from apache_beam.transforms.enrichment_handlers.utils import ExceptionLevel
__all__ = [
    'OpenSearchEnrichmentHandler',
]

# Raise the root logger to INFO so this module's INFO-level messages are
# emitted (the default root level is WARNING).
# NOTE(review): this changes the level of the *root* logger for the whole
# process as an import side effect, which also makes unrelated libraries more
# verbose; consider configuring logging in the pipeline entry point instead.
logging.root.setLevel(logging.INFO)
logger = logging.getLogger(__name__)
class OpenSearchEnrichmentHandler(EnrichmentSourceHandler[beam.Row, beam.Row]):
    """A handler for :class:`apache_beam.transforms.enrichment.Enrichment`
    transform to interact with an OpenSearch vector DB.

    For each request, the value of its ``text`` field is used as the query
    vector of a k-NN search on ``vector_field`` in ``index_name``; the raw
    search response is returned as the enrichment data.
    """

    def __init__(
        self,
        opensearch_host: str,
        opensearch_port: int,
        username: Optional[str] = None,
        password: Optional[str] = None,
        index_name: str = "embeddings-index",
        vector_field: str = "text_vector",
        k: int = 1,
        size: int = 5,
        verify_certs: bool = False,
    ):
        """Args:
          opensearch_host (str): OpenSearch host to connect to.
          opensearch_port (int): OpenSearch port to connect to.
          username (Optional[str]): user name for HTTP basic auth. When not
            given (or empty), the ``OPENSEARCH_USERNAME`` environment variable
            is used.
          password (Optional[str]): password for HTTP basic auth. When not
            given (or empty), the ``OPENSEARCH_PASSWORD`` environment variable
            is used.
          index_name (str): name of the index to search.
          vector_field (str): vector field used to compute the similarity
            score.
          k (int): number of nearest neighbours requested by the k-NN query.
          size (int): maximum number of hits returned by the search request.
          verify_certs (bool): whether the client verifies TLS certificates.
            Defaults to ``False`` to keep the previous behaviour.

        Raises:
          ValueError: if no username or password is available from either the
            arguments or the environment.
        """
        self.opensearch_host = opensearch_host
        self.opensearch_port = opensearch_port
        # Fall back to the environment when credentials are not passed in.
        # (`or`, not `|`: bitwise OR on str/None raises TypeError.)
        self.username = username or os.getenv("OPENSEARCH_USERNAME")
        self.password = password or os.getenv("OPENSEARCH_PASSWORD")
        self.index_name = index_name
        self.vector_field = vector_field
        self.k = k
        self.size = size
        self.verify_certs = verify_certs
        # Created in __enter__ so the handler itself stays picklable.
        self.client = None
        if not self.username or not self.password:
            raise ValueError(
                "Username and password are needed for connecting to Opensearch cluster.")

    def __enter__(self):
        """Connect to OpenSearch with the configured credentials, once.

        Returns:
          This handler, per the context-manager protocol.
        """
        if self.client is None:
            self.client = OpenSearch(
                hosts=[f'{self.opensearch_host}:{self.opensearch_port}'],
                http_auth=(self.username, self.password),
                # Off by default (original behaviour); enable for TLS
                # clusters with trusted certificates.
                verify_certs=self.verify_certs)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the OpenSearch client, if one was opened.

        Returns None so exceptions raised inside the context propagate.
        """
        if self.client is not None:
            try:
                self.client.close()
            finally:
                self.client = None

    def __call__(self, request: beam.Row, *args, **kwargs):
        """Run a k-NN search for the request's embedding vector.

        Args:
          request: the input `beam.Row` to enrich. Its ``text`` field is used
            as the query vector.

        Returns:
          A tuple ``(beam.Row(text=<query vector>), beam.Row(docs=<response>))``
          where ``<response>`` is the value returned by ``client.search``.
        """
        if self.client is None:
            # Connect lazily when called outside a `with` block / setup().
            self.__enter__()

        query_vector = self._query_vector(request)
        query = {
            'size': self.size,
            'query': {
                'knn': {
                    self.vector_field: {
                        "vector": query_vector,
                        "k": self.k,
                    }
                }
            },
        }
        results = self.client.search(body=query, index=self.index_name)
        # Lazy %-formatting; the previous call passed an argument without a
        # placeholder, which makes the logging module report a format error.
        logger.info("Enrichment results: %s", results)
        return beam.Row(text=query_vector), beam.Row(docs=results)

    @staticmethod
    def _query_vector(request):
        """Return the ``text`` field of *request* (a ``beam.Row`` or a mapping)."""
        # beam.Row exposes fields as attributes and is not subscriptable;
        # keep mapping access for callers that pass dict-like requests.
        try:
            return request.text
        except AttributeError:
            return request['text']