blob: 3d28ccf7238d3863b42d66ebd2e84092672f236f [file] [log] [blame]
# SPDX-License-Identifier: Apache-2.0
import json
import ChromaUtils
import EmbeddingUtils
import QueryUtils
from nifiapi.flowfiletransform import FlowFileTransform, FlowFileTransformResult
from nifiapi.properties import ExpressionLanguageScope, PropertyDescriptor, StandardValidators
class QueryChroma(FlowFileTransform):
    """NiFi processor that runs a similarity query against a Chroma collection.

    The query text is embedded with the configured embedding function and only the resulting vector
    is sent to Chroma. The returned matches are serialized by ``QueryUtils`` into the outgoing
    FlowFile content.
    """

    class Java:
        implements = ["org.apache.nifi.python.processor.FlowFileTransform"]

    class ProcessorDetails:
        version = "2.0.0.dev0"
        description = "Queries a Chroma Vector Database in order to gather a specified number of documents that are most closely related to the given query."
        tags = [
            "chroma",
            "vector",
            "vectordb",
            "embeddings",
            "enrich",
            "enrichment",
            "ai",
            "artificial intelligence",
            "ml",
            "machine learning",
            "text",
            "LLM",
        ]

    QUERY = PropertyDescriptor(
        name="Query",
        description="""The query to issue to the Chroma VectorDB. The query is always converted into embeddings using the configured embedding function, and the embedding is
then sent to Chroma. The text itself is not sent to Chroma.""",
        required=True,
        validators=[StandardValidators.NON_EMPTY_VALIDATOR],
        expression_language_scope=ExpressionLanguageScope.FLOWFILE_ATTRIBUTES,
    )
    NUMBER_OF_RESULTS = PropertyDescriptor(
        name="Number of Results",
        description="The number of results to return from Chroma",
        required=True,
        validators=[StandardValidators.POSITIVE_INTEGER_VALIDATOR],
        default_value="10",
        expression_language_scope=ExpressionLanguageScope.FLOWFILE_ATTRIBUTES,
    )
    METADATA_FILTER = PropertyDescriptor(
        name="Metadata Filter",
        description="""A JSON representation of a Metadata Filter that can be applied against the Chroma documents in order to narrow down the documents that can be returned.
For example: { "metadata_field": "some_value" }""",
        validators=[StandardValidators.NON_EMPTY_VALIDATOR],
        expression_language_scope=ExpressionLanguageScope.FLOWFILE_ATTRIBUTES,
        required=False,
    )
    DOCUMENT_FILTER = PropertyDescriptor(
        name="Document Filter",
        description="""A JSON representation of a Document Filter that can be applied against the Chroma documents' text in order to narrow down the documents that can be returned.
For example: { "$contains": "search_string" }""",
        validators=[StandardValidators.NON_EMPTY_VALIDATOR],
        expression_language_scope=ExpressionLanguageScope.FLOWFILE_ATTRIBUTES,
        required=False,
    )

    # Per-schedule state: populated by onScheduled() and read by transform().
    client = None
    embedding_function = None
    # Read in onScheduled() but not consulted by transform(); presumably honored by QueryUtils,
    # which is built from the same context -- TODO confirm.
    include_ids = None
    include_metadatas = None
    include_documents = None
    include_distances = None
    include_embeddings = None
    results_field = None
    query_utils = None

    # Chroma connection properties, every EmbeddingUtils property except EMBEDDING_MODEL,
    # then this processor's own query/output properties.
    property_descriptors = (
        list(ChromaUtils.PROPERTIES)
        + [prop for prop in EmbeddingUtils.PROPERTIES if prop != EmbeddingUtils.EMBEDDING_MODEL]
        + [
            QUERY,
            NUMBER_OF_RESULTS,
            QueryUtils.OUTPUT_STRATEGY,
            QueryUtils.RESULTS_FIELD,
            METADATA_FILTER,
            DOCUMENT_FILTER,
            QueryUtils.INCLUDE_IDS,
            QueryUtils.INCLUDE_METADATAS,
            QueryUtils.INCLUDE_DOCUMENTS,
            QueryUtils.INCLUDE_DISTANCES,
            QueryUtils.INCLUDE_EMBEDDINGS,
        ]
    )

    def __init__(self, **kwargs):
        pass

    def getPropertyDescriptors(self):
        """Return the property descriptors this processor exposes."""
        return self.property_descriptors

    def onScheduled(self, context):
        """Build the Chroma client and embedding function, and cache processor-level settings.

        :param context: the NiFi ProcessContext supplied when the processor is scheduled
        """
        self.client = ChromaUtils.create_client(context)
        self.embedding_function = EmbeddingUtils.create_embedding_function(context)
        self.include_ids = context.getProperty(QueryUtils.INCLUDE_IDS).asBoolean()
        self.include_metadatas = context.getProperty(QueryUtils.INCLUDE_METADATAS).asBoolean()
        self.include_documents = context.getProperty(QueryUtils.INCLUDE_DOCUMENTS).asBoolean()
        self.include_distances = context.getProperty(QueryUtils.INCLUDE_DISTANCES).asBoolean()
        self.include_embeddings = context.getProperty(QueryUtils.INCLUDE_EMBEDDINGS).asBoolean()
        self.results_field = context.getProperty(QueryUtils.RESULTS_FIELD).getValue()
        self.query_utils = QueryUtils.QueryUtils(context)

    @staticmethod
    def _parse_json_filter(raw_value, property_name):
        """Parse an optional JSON filter property value.

        :param raw_value: the evaluated property value, possibly None
        :param property_name: display name of the property, used in the error message
        :return: the decoded JSON value, or None when the value is None, empty or only whitespace
        :raises ValueError: if the value is non-blank but not valid JSON
        """
        # Expression Language can evaluate to an empty string even though the property itself is
        # non-empty; treat that as "no filter" instead of failing inside json.loads("").
        if raw_value is None or not raw_value.strip():
            return None
        try:
            return json.loads(raw_value)
        except json.JSONDecodeError as error:
            raise ValueError(f"The {property_name} property is not valid JSON: {raw_value!r}") from error

    def _included_fields(self):
        """Return the Chroma ``include`` list for the enabled Include * settings.

        "ids" is never added: the results' "ids" entry is read unconditionally in transform().
        """
        included_fields = []
        if self.include_distances:
            included_fields.append("distances")
        if self.include_documents:
            included_fields.append("documents")
        if self.include_embeddings:
            included_fields.append("embeddings")
        if self.include_metadatas:
            included_fields.append("metadatas")
        return included_fields

    @staticmethod
    def _first_result(query_results, key, included):
        """Return ``query_results[key][0]``, or None if the field was not requested or is absent.

        Only one query embedding is sent, so each per-field result holds a single inner entry at
        index 0. ``is None`` is used rather than truthiness so that array-valued results (e.g.
        embeddings) are not evaluated for truth.
        """
        if not included:
            return None
        values = query_results.get(key)
        if values is None:
            return None
        return values[0]

    def transform(self, context, flowfile):
        """Query Chroma for the documents most closely related to the FlowFile's query.

        The collection name, query text, result count and optional filters are evaluated against
        the FlowFile's attributes. The query text is embedded with the configured embedding function
        and only that embedding is sent to Chroma.

        :param context: the NiFi ProcessContext used to read property values
        :param flowfile: the incoming FlowFile whose attributes drive Expression Language evaluation
        :return: a FlowFileTransformResult routed to "success", whose content is produced by
                 QueryUtils.create_json and whose "mime.type" attribute is set accordingly
        :raises ValueError: if Metadata Filter or Document Filter is non-blank but not valid JSON
        """
        # Validate the cheap, local inputs first so malformed filters fail before any call to
        # Chroma or to the (possibly remote) embedding function.
        where = self._parse_json_filter(
            context.getProperty(self.METADATA_FILTER).evaluateAttributeExpressions(flowfile).getValue(),
            "Metadata Filter",
        )
        where_document = self._parse_json_filter(
            context.getProperty(self.DOCUMENT_FILTER).evaluateAttributeExpressions(flowfile).getValue(),
            "Document Filter",
        )
        n_results = context.getProperty(self.NUMBER_OF_RESULTS).evaluateAttributeExpressions(flowfile).asInteger()

        collection_name = (
            context.getProperty(ChromaUtils.COLLECTION_NAME).evaluateAttributeExpressions(flowfile).getValue()
        )
        collection = self.client.get_collection(name=collection_name, embedding_function=self.embedding_function)

        query_text = context.getProperty(self.QUERY).evaluateAttributeExpressions(flowfile).getValue()
        # Embed locally; only the vector, not the raw query text, is sent to Chroma.
        query_embeddings = self.embedding_function([query_text])

        query_results = collection.query(
            query_embeddings=query_embeddings,
            n_results=n_results,
            include=self._included_fields(),
            where_document=where_document,
            where=where,
        )

        # A single query embedding was sent, so each field's results are at index 0.
        ids = query_results["ids"][0]
        distances = self._first_result(query_results, "distances", self.include_distances)
        metadatas = self._first_result(query_results, "metadatas", self.include_metadatas)
        documents = self._first_result(query_results, "documents", self.include_documents)
        result_embeddings = self._first_result(query_results, "embeddings", self.include_embeddings)

        (output_contents, mime_type) = self.query_utils.create_json(
            flowfile, documents, metadatas, result_embeddings, distances, ids
        )

        attributes = {"mime.type": mime_type}
        return FlowFileTransformResult(relationship="success", contents=output_contents, attributes=attributes)