blob: a8470f4fac1d70b2596d614b26c05d269cfffd79 [file] [log] [blame]
# SPDX-License-Identifier: Apache-2.0
import json
import re
from nifiapi.flowfiletransform import FlowFileTransform, FlowFileTransformResult
from nifiapi.properties import ExpressionLanguageScope, PropertyDescriptor, StandardValidators, TimeUnit
# Name of the implicit prompt variable that is expanded to the full FlowFile content.
FLOWFILE_CONTENT = "flowfile_content"
# The literal token users place in a prompt to reference the content: "{flowfile_content}".
FLOWFILE_CONTENT_REFERENCE = f"{{{FLOWFILE_CONTENT}}}"
# Matches a single-brace variable reference such as {name}: a "{" that is not itself
# preceded by another "{" (so escaped "{{...}}" sequences are skipped), the shortest
# possible run of non-"{" characters, then the closing "}".
VAR_NAME_REGEX = r"(?<!{)\{([^{]*?)\}"
class PromptChatGPT(FlowFileTransform):
    """NiFi FlowFileTransform processor that submits a prompt to ChatGPT.

    The prompt may reference the FlowFile content via the literal
    ``{flowfile_content}`` and, when the content is JSON, via JSONPath
    expressions such as ``{$.page_content}``. The model's answer is written
    either to a named FlowFile attribute or to the FlowFile content.
    """

    class Java:
        implements = ["org.apache.nifi.python.processor.FlowFileTransform"]

    class ProcessorDetails:
        version = "2.0.0.dev0"
        description = "Submits a prompt to ChatGPT, writing the results either to a FlowFile attribute or to the contents of the FlowFile"
        tags = [
            "text",
            "chatgpt",
            "gpt",
            "machine learning",
            "ML",
            "artificial intelligence",
            "ai",
            "document",
            "langchain",
        ]
        dependencies = ["langchain==0.1.2", "openai==1.9.0", "jsonpath-ng"]

    MODEL = PropertyDescriptor(
        name="OpenAI Model Name",
        description="The name of the OpenAI Model to use in order to answer the prompt",
        default_value="gpt-3.5-turbo",
        expression_language_scope=ExpressionLanguageScope.FLOWFILE_ATTRIBUTES,
        validators=[StandardValidators.NON_EMPTY_VALIDATOR],
        required=True,
    )
    PROMPT = PropertyDescriptor(
        name="Prompt",
        description="""The prompt to issue to ChatGPT. This may use FlowFile attributes via Expression Language and may also reference the FlowFile content by using the literal
                    {flowfile_content} (including braces) in the prompt. If the FlowFile's content is JSON formatted, a reference may also include JSONPath Expressions
                    to reference specific fields in the FlowFile content, such as {$.page_content}""",
        validators=[StandardValidators.NON_EMPTY_VALIDATOR],
        expression_language_scope=ExpressionLanguageScope.FLOWFILE_ATTRIBUTES,
        required=True,
    )
    TEMPERATURE = PropertyDescriptor(
        name="Temperature",
        # Fixed: description previously contained a stray '"' character.
        description="""The Temperature parameter to submit to OpenAI. A lower value will result in more consistent answers while a higher value will result in a more creative answer.
                    The value must be between 0 and 2, inclusive.""",
        validators=[StandardValidators._standard_validators.createNonNegativeFloatingPointValidator(2.0)],
        expression_language_scope=ExpressionLanguageScope.FLOWFILE_ATTRIBUTES,
        required=True,
        default_value="1.0",
    )
    RESULT_ATTRIBUTE = PropertyDescriptor(
        name="Result Attribute",
        description="If specified, the result will be added to the attribute whose name is given. If not specified, the result will be written to the FlowFile's content",
        validators=[StandardValidators.NON_EMPTY_VALIDATOR],
        required=False,
    )
    API_KEY = PropertyDescriptor(
        name="API Key",
        description="The OpenAI API Key to use",
        validators=[StandardValidators.NON_EMPTY_VALIDATOR],
        required=True,
        sensitive=True,
    )
    TIMEOUT = PropertyDescriptor(
        name="Request Timeout",
        description="The amount of time to wait before timing out the request",
        validators=[StandardValidators.TIME_PERIOD_VALIDATOR],
        default_value="60 secs",
        required=True,
    )
    MAX_TOKENS = PropertyDescriptor(
        name="Max Tokens to Generate",
        description="The maximum number of tokens that ChatGPT should generate",
        validators=[StandardValidators.POSITIVE_INTEGER_VALIDATOR],
        required=False,
    )
    ORGANIZATION = PropertyDescriptor(
        name="OpenAI Organization ID",
        description="The OpenAI Organization ID",
        validators=[StandardValidators.NON_EMPTY_VALIDATOR],
        required=False,
    )
    API_BASE = PropertyDescriptor(
        name="API Base URL Path",
        description="The API Base URL to use for interacting with OpenAI. This should be populated only if using a proxy or an emulator.",
        validators=[StandardValidators.NON_EMPTY_VALIDATOR],
        required=False,
    )

    property_descriptors = [
        MODEL,
        PROMPT,
        TEMPERATURE,
        RESULT_ATTRIBUTE,
        API_KEY,
        TIMEOUT,
        MAX_TOKENS,
        ORGANIZATION,
        API_BASE,
    ]

    def __init__(self, **kwargs):
        # No per-instance state is needed; kwargs are accepted for framework compatibility.
        pass

    def getPropertyDescriptors(self):
        """Return the list of supported property descriptors (NiFi framework callback)."""
        return self.property_descriptors

    def transform(self, context, flowFile):
        """Evaluate the configured prompt against the FlowFile, invoke ChatGPT,
        and route the answer to an attribute or to the FlowFile content.

        :param context: the processor context, used to read property values
        :param flowFile: the incoming FlowFile whose attributes/content may be referenced
        :return: a FlowFileTransformResult routed to the "success" relationship
        :raises ValueError: if the prompt references an unsupported variable
        """
        # langchain is a processor-level dependency; import lazily so the module
        # can be loaded by the framework before dependencies are installed.
        from langchain import PromptTemplate
        from langchain.chains.llm import LLMChain
        from langchain.chat_models import ChatOpenAI

        prompt = context.getProperty(self.PROMPT).evaluateAttributeExpressions(flowFile).getValue()

        # We want to allow referencing FlowFile content using JSONPath Expressions.
        # To do that, we allow the same {variable} syntax as Langchain. But Langchain does not allow '$' characters
        # to exist in the variable names. So we need to replace those variables in the prompt with new variables, such as
        # jsonpath_var_0, jsonpath_var_1, etc. To do this, we will use a Regex to detect any variables that are referenced
        # and if it starts with a $ we will replace it with jsonpath_var_<index> and we will keep a mapping from that name to
        # the substituted variable name so that we can later determine what the JSONPath expression was.
        variable_references = list(set(re.findall(VAR_NAME_REGEX, prompt)))

        input_variables = []
        jsonpath_to_var_mapping = {}
        index = 0
        for ref in variable_references:
            if ref.startswith("$"):
                # JSONPath reference: alias it to a Langchain-safe variable name.
                var_name = "jsonpath_var_" + str(index)
                index += 1
                input_variables.append(var_name)
                jsonpath_to_var_mapping[ref] = var_name
                prompt = prompt.replace("{" + ref + "}", "{" + var_name + "}")
            elif ref == FLOWFILE_CONTENT:
                input_variables.append(ref)
            else:
                raise ValueError(
                    "Prompt contained an invalid variable reference: {"
                    + ref
                    + "}. Valid references are flowfile_content or any JSONPath expression."
                )

        temperature = context.getProperty(self.TEMPERATURE).evaluateAttributeExpressions(flowFile).asFloat()
        model_name = context.getProperty(self.MODEL).evaluateAttributeExpressions(flowFile).getValue()
        api_key = context.getProperty(self.API_KEY).getValue()
        timeout = context.getProperty(self.TIMEOUT).asTimePeriod(TimeUnit.SECONDS)
        max_tokens = context.getProperty(self.MAX_TOKENS).asInteger()
        organization = context.getProperty(self.ORGANIZATION).getValue()
        api_base = context.getProperty(self.API_BASE).getValue()

        # Build out our LLMChain
        llm = ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            openai_api_key=api_key,
            request_timeout=timeout,
            max_retries=0,
            max_tokens=max_tokens,
            openai_organization=organization,
            openai_api_base=api_base,
        )

        prompt_template = PromptTemplate(template=prompt, input_variables=input_variables)
        llm_chain = LLMChain(llm=llm, prompt=prompt_template)

        # Substitute in any JSON Path Expressions or references to {flowfile_content}.
        llm_args = {}
        json_content = None
        for var_name in variable_references:
            # If variable references {flowfile_content} substitute the content
            if var_name == FLOWFILE_CONTENT:
                llm_args[FLOWFILE_CONTENT] = flowFile.getContentsAsBytes().decode()

            if var_name.startswith("$"):
                # Load the FlowFile's contents into the json_content variable only once
                if json_content is None:
                    json_content = json.loads(flowFile.getContentsAsBytes().decode())

                # Import jsonpath_ng so that we can evaluate JSONPath against the FlowFile content.
                from jsonpath_ng import parse

                try:
                    jsonpath_expression = parse(var_name)
                    matches = jsonpath_expression.find(json_content)
                    variable_value = "\n".join([match.value for match in matches])
                except Exception:
                    # Fixed: was a bare `except:`, which also swallowed
                    # KeyboardInterrupt/SystemExit before re-raising.
                    self.logger.exception(f"Invalid JSONPath reference in prompt: {var_name}")
                    raise

                # Insert the resolved value into llm_args under its aliased name.
                resolved_var_name = jsonpath_to_var_mapping.get(var_name)
                llm_args[resolved_var_name] = variable_value

        # Fixed: the f-string previously read "Args: #{llm_args}", logging a stray '#'.
        self.logger.debug(f"Evaluating prompt\nPrompt: {prompt}\nArgs: {llm_args}")

        # Run the LLM Chain in order to prompt ChatGPT
        results = llm_chain(llm_args)

        # Create the output content or FlowFile attribute
        text = results["text"]
        attribute_name = context.getProperty(self.RESULT_ATTRIBUTE).getValue()
        if attribute_name is None:
            output_content = text
            output_attributes = None
        else:
            output_content = None
            output_attributes = {attribute_name: text}

        # Return the results
        return FlowFileTransformResult("success", contents=output_content, attributes=output_attributes)