blob: 4a5ada007d93b74806bab87d0e76dfb6ce9461a3 [file] [log] [blame]
# SPDX-License-Identifier: Apache-2.0
from nifiapi.properties import ExpressionLanguageScope, PropertyDependency, PropertyDescriptor, StandardValidators
# Allowable values for the "Connection Strategy" property.
LOCAL_DISK, REMOTE_SERVER = "Local Disk", "Remote Chroma Server"

# Allowable values for the "Authentication Strategy" property.
TOKEN, BASIC_AUTH, NONE = "Token Authentication", "Basic Authentication", "None"

# Allowable values for the "Transport Protocol" property.
HTTP, HTTPS = "http", "https"
# --- Connection properties -------------------------------------------------

# Selects between a local on-disk Chroma store and a remote Chroma server.
CONNECTION_STRATEGY = PropertyDescriptor(
    name="Connection Strategy",
    description="Specifies how to connect to the Chroma server",
    required=True,
    default_value=REMOTE_SERVER,
    allowable_values=[LOCAL_DISK, REMOTE_SERVER],
)

# Depends on CONNECTION_STRATEGY == LOCAL_DISK.
DIRECTORY = PropertyDescriptor(
    name="Directory",
    description="The Directory that Chroma should use to persist data",
    required=True,
    default_value="./chroma",
    validators=[StandardValidators.NON_EMPTY_VALIDATOR],
    dependencies=[PropertyDependency(CONNECTION_STRATEGY, LOCAL_DISK)],
)

# The remaining connection properties all depend on CONNECTION_STRATEGY == REMOTE_SERVER.
HOSTNAME = PropertyDescriptor(
    name="Hostname",
    description="The hostname to connect to in order to communicate with Chroma",
    required=True,
    default_value="localhost",
    validators=[StandardValidators.NON_EMPTY_VALIDATOR],
    dependencies=[PropertyDependency(CONNECTION_STRATEGY, REMOTE_SERVER)],
)

PORT = PropertyDescriptor(
    name="Port",
    description="The port that the Chroma server is listening on",
    required=True,
    default_value="8000",
    validators=[StandardValidators.PORT_VALIDATOR],
    dependencies=[PropertyDependency(CONNECTION_STRATEGY, REMOTE_SERVER)],
)

TRANSPORT_PROTOCOL = PropertyDescriptor(
    name="Transport Protocol",
    description="Specifies whether connections should be made over http or https",
    required=True,
    default_value=HTTPS,
    allowable_values=[HTTP, HTTPS],
    dependencies=[PropertyDependency(CONNECTION_STRATEGY, REMOTE_SERVER)],
)
# --- Authentication properties ---------------------------------------------

# Depends on CONNECTION_STRATEGY == REMOTE_SERVER; the credential properties
# below in turn depend on the value chosen here.
AUTH_STRATEGY = PropertyDescriptor(
    name="Authentication Strategy",
    description="Specifies how to authenticate to Chroma server",
    required=True,
    default_value=TOKEN,
    allowable_values=[TOKEN, BASIC_AUTH, NONE],
    dependencies=[PropertyDependency(CONNECTION_STRATEGY, REMOTE_SERVER)],
)

# Depends on AUTH_STRATEGY == TOKEN; marked sensitive.
AUTH_TOKEN = PropertyDescriptor(
    name="Authentication Token",
    description="The token to use for authenticating to Chroma server",
    required=True,
    sensitive=True,
    validators=[StandardValidators.NON_EMPTY_VALIDATOR],
    dependencies=[PropertyDependency(AUTH_STRATEGY, TOKEN)],
)

# USERNAME and PASSWORD both depend on AUTH_STRATEGY == BASIC_AUTH; only the password is sensitive.
USERNAME = PropertyDescriptor(
    name="Username",
    description="The username to use for authenticating to Chroma server",
    required=True,
    validators=[StandardValidators.NON_EMPTY_VALIDATOR],
    dependencies=[PropertyDependency(AUTH_STRATEGY, BASIC_AUTH)],
)

PASSWORD = PropertyDescriptor(
    name="Password",
    description="The password to use for authenticating to Chroma server",
    required=True,
    sensitive=True,
    validators=[StandardValidators.NON_EMPTY_VALIDATOR],
    dependencies=[PropertyDependency(AUTH_STRATEGY, BASIC_AUTH)],
)
# --- Collection ------------------------------------------------------------

# Name of the Chroma collection; Expression Language is evaluated against
# FlowFile attributes, so the value may vary per FlowFile.
COLLECTION_NAME = PropertyDescriptor(
    name="Collection Name",
    description="The name of the Chroma Collection",
    required=True,
    default_value="nifi",
    expression_language_scope=ExpressionLanguageScope.FLOWFILE_ATTRIBUTES,
    validators=[StandardValidators.NON_EMPTY_VALIDATOR],
)
# Every descriptor declared in this module, in declaration order.
# NOTE(review): presumably consumed by the Chroma processors as their shared
# property list — confirm against the callers before reordering.
PROPERTIES = [
    # connection
    CONNECTION_STRATEGY, DIRECTORY, HOSTNAME, PORT, TRANSPORT_PROTOCOL,
    # authentication
    AUTH_STRATEGY, AUTH_TOKEN, USERNAME, PASSWORD,
    # collection
    COLLECTION_NAME,
]
def _resolve_auth(context):
    """Return the ``(auth_provider, credentials)`` pair for the configured Authentication Strategy.

    :param context: the NiFi context to read the authentication properties from.
    :return: the dotted class path of the chromadb client auth provider together with the
        credential string it is given (the token for Token Authentication,
        ``"username:password"`` for Basic Authentication), or ``(None, None)`` for any other
        strategy value (i.e. "None").
    """
    auth_strategy = context.getProperty(AUTH_STRATEGY).getValue()

    # NOTE(review): these provider paths match chromadb 0.4.x; later chromadb releases moved the
    # client auth providers to other modules — confirm they match the pinned chromadb version.
    if auth_strategy == TOKEN:
        token = context.getProperty(AUTH_TOKEN).getValue()
        return "chromadb.auth.token.TokenAuthClientProvider", token

    if auth_strategy == BASIC_AUTH:
        username = context.getProperty(USERNAME).getValue()
        password = context.getProperty(PASSWORD).getValue()
        return "chromadb.auth.basic.BasicAuthClientProvider", username + ":" + password

    return None, None


def create_client(context):
    """Create a Chroma client configured from the properties declared in this module.

    :param context: the NiFi context whose property values select and configure the client.
    :return: a ``chromadb.PersistentClient`` persisting to the configured Directory when the
        Connection Strategy is Local Disk; otherwise a ``chromadb.HttpClient`` for the configured
        Hostname and Port, using SSL when the Transport Protocol is https and authenticating
        according to the Authentication Strategy.
    """
    # Imported inside the function rather than at module level, so importing this module does
    # not by itself import chromadb.
    import chromadb
    from chromadb import Settings

    if context.getProperty(CONNECTION_STRATEGY).getValue() == LOCAL_DISK:
        directory = context.getProperty(DIRECTORY).getValue()
        return chromadb.PersistentClient(path=directory)

    hostname = context.getProperty(HOSTNAME).getValue()
    port = context.getProperty(PORT).asInteger()
    ssl = context.getProperty(TRANSPORT_PROTOCOL).getValue() == HTTPS

    auth_provider, credentials = _resolve_auth(context)
    settings = Settings(
        chroma_client_auth_provider=auth_provider,
        chroma_client_auth_credentials=credentials,
    )

    # Bind every argument by keyword: HttpClient takes a run of optional parameters, and passing
    # them positionally would silently shift values if a chromadb release reorders or inserts one.
    return chromadb.HttpClient(host=hostname, port=port, ssl=ssl, headers={}, settings=settings)