blob: e482cb6880cca43804e2822046ad46af704b1641 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// The functions in this file are specifically not cross-compiled to IR because there
// is no significant performance benefit to be gained.
#include <gutil/strings/util.h>
#include "exprs/ai-functions.inline.h"
using namespace impala_udf;
// --ai_endpoint: the endpoint used when none is supplied explicitly.
DEFINE_string(ai_endpoint, "https://api.openai.com/v1/chat/completions",
    "The default API endpoint for an external AI engine.");
// Flag validator for --ai_endpoint: a value is accepted only if it passes both
// AiFunctions::is_api_endpoint_valid() and AiFunctions::is_api_endpoint_supported().
// The flag-name argument is not used.
DEFINE_validator(ai_endpoint, [](const char* flag_name, const string& value) {
  if (!impala::AiFunctions::is_api_endpoint_valid(value)) return false;
  return impala::AiFunctions::is_api_endpoint_supported(value);
});
// --ai_model: model name used by default; defaults to "gpt-4".
DEFINE_string(ai_model, "gpt-4", "The default AI model used by an external AI engine.");
// --ai_api_key_jceks_secret: name of the jceks secret from which the API key is
// extracted. Empty by default.
DEFINE_string(ai_api_key_jceks_secret, "",
"The jceks secret key used for extracting the api key from configured keystores. "
"'hadoop.security.credential.provider.path' in core-site must be configured to "
"include the keystore storing the corresponding secret.");
// --ai_connection_timeout_s: connection timeout, in seconds, for requests to the
// external AI engine; defaults to 10.
DEFINE_int32(ai_connection_timeout_s, 10,
"(Advanced) The time in seconds for connection timed out when communicating with an "
"external AI engine");
// Tags the secret-name flag as 'sensitive'.
// NOTE(review): presumably this causes the value to be redacted wherever flags are
// displayed or logged - confirm against the flag-tag handling.
TAG_FLAG(ai_api_key_jceks_secret, sensitive);
namespace impala {
// static class members
//
// Error messages associated with the AI text-generation functions. Within this file
// only AI_GENERATE_TXT_JSON_PARSE_ERROR is referenced: it is both logged and returned
// by AiGenerateTextParseOpenAiResponse() when the response cannot be interpreted. The
// remaining messages are defined here for use by code outside this section.
const string AiFunctions::AI_GENERATE_TXT_JSON_PARSE_ERROR = "Invalid Json";
const string AiFunctions::AI_GENERATE_TXT_INVALID_PROTOCOL_ERROR =
"Invalid Protocol, use https";
const string AiFunctions::AI_GENERATE_TXT_UNSUPPORTED_ENDPOINT_ERROR =
"Unsupported Endpoint";
const string AiFunctions::AI_GENERATE_TXT_INVALID_PROMPT_ERROR =
"Invalid Prompt, cannot be null or empty";
const string AiFunctions::AI_GENERATE_TXT_MSG_OVERRIDE_FORBIDDEN_ERROR =
"Invalid override, 'messages' cannot be overriden";
const string AiFunctions::AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR =
"Invalid override, 'n' must be of integer type and have value 1";
// Storage for the API key; default-constructed (empty) here.
// NOTE(review): presumably filled in elsewhere from the secret named by
// --ai_api_key_jceks_secret - confirm where it is assigned.
string AiFunctions::ai_api_key_;
// Content-Type header line for JSON request bodies.
const char* AiFunctions::OPEN_AI_REQUEST_FIELD_CONTENT_TYPE_HEADER =
"Content-Type: application/json";
// other constants
//
// Note that the 'const char*' constants below are pointers, not arrays: sizeof() on
// them yields the size of a pointer rather than the length of the string.

// Null StringVal that AiGenerateTextDefault() passes for every argument other than the
// prompt.
static const StringVal NULL_STRINGVAL = StringVal::null();
// Prefix that is_api_endpoint_valid() requires an endpoint to start with.
static const char* AI_API_ENDPOINT_PREFIX = "https://";
// Substrings that is_api_endpoint_supported() searches for in an endpoint.
static const char* OPEN_AI_AZURE_ENDPOINT = "openai.azure.com";
static const char* OPEN_AI_PUBLIC_ENDPOINT = "api.openai.com";
// OPEN AI specific constants
// JSON member names read by AiGenerateTextParseOpenAiResponse(), which extracts
// choices[0].message.content from the response.
static const char* OPEN_AI_RESPONSE_FIELD_CHOICES = "choices";
static const char* OPEN_AI_RESPONSE_FIELD_MESSAGE = "message";
static const char* OPEN_AI_RESPONSE_FIELD_CONTENT = "content";
// Simple validation for an endpoint: returns true iff 'endpoint' starts with
// AI_API_ENDPOINT_PREFIX ("https://"), as decided by gutil's strncaseprefix()
// (case-insensitive prefix match). An endpoint shorter than the prefix is rejected.
bool AiFunctions::is_api_endpoint_valid(const std::string_view& endpoint) {
  // AI_API_ENDPOINT_PREFIX is a 'const char*', so sizeof() on it is the size of a
  // pointer. That equals the prefix length only by coincidence where pointers are
  // 8 bytes wide; with 4-byte pointers "http://" would be accepted. Measure the
  // actual string instead.
  const std::string_view prefix(AI_API_ENDPOINT_PREFIX);
  return strncaseprefix(endpoint.data(), endpoint.size(), prefix.data(), prefix.size())
      != nullptr;
}
// Only OpenAI endpoints are supported. Returns true iff gstrncasestr() finds either
// OPEN_AI_AZURE_ENDPOINT or OPEN_AI_PUBLIC_ENDPOINT within the first endpoint.size()
// bytes of 'endpoint'. This is a substring search over the whole endpoint string, not
// a comparison against the host component only.
bool AiFunctions::is_api_endpoint_supported(const std::string_view& endpoint) {
  const char* const url = endpoint.data();
  const auto url_len = endpoint.size();
  if (gstrncasestr(url, OPEN_AI_AZURE_ENDPOINT, url_len) != nullptr) return true;
  return gstrncasestr(url, OPEN_AI_PUBLIC_ENDPOINT, url_len) != nullptr;
}
std::string_view AiFunctions::AiGenerateTextParseOpenAiResponse(
const std::string_view& response) {
rapidjson::Document document;
document.Parse(response.data(), response.size());
// Check for parse errors
if (document.HasParseError()) {
LOG(WARNING) << AI_GENERATE_TXT_JSON_PARSE_ERROR << ": " << response;
return AI_GENERATE_TXT_JSON_PARSE_ERROR;
}
// Check if the "choices" array exists and is not empty
if (!document.HasMember(OPEN_AI_RESPONSE_FIELD_CHOICES)
|| !document[OPEN_AI_RESPONSE_FIELD_CHOICES].IsArray()
|| document[OPEN_AI_RESPONSE_FIELD_CHOICES].Empty()) {
LOG(WARNING) << AI_GENERATE_TXT_JSON_PARSE_ERROR << ": " << response;
return AI_GENERATE_TXT_JSON_PARSE_ERROR;
}
// Access the first element of the "choices" array
const rapidjson::Value& firstChoice = document[OPEN_AI_RESPONSE_FIELD_CHOICES][0];
// Check if the "message" object exists
if (!firstChoice.HasMember(OPEN_AI_RESPONSE_FIELD_MESSAGE)
|| !firstChoice[OPEN_AI_RESPONSE_FIELD_MESSAGE].IsObject()) {
LOG(WARNING) << AI_GENERATE_TXT_JSON_PARSE_ERROR << ": " << response;
return AI_GENERATE_TXT_JSON_PARSE_ERROR;
}
// Access the "content" field within "message"
const rapidjson::Value& message = firstChoice[OPEN_AI_RESPONSE_FIELD_MESSAGE];
if (!message.HasMember(OPEN_AI_RESPONSE_FIELD_CONTENT)
|| !message[OPEN_AI_RESPONSE_FIELD_CONTENT].IsString()) {
LOG(WARNING) << AI_GENERATE_TXT_JSON_PARSE_ERROR << ": " << response;
return AI_GENERATE_TXT_JSON_PARSE_ERROR;
}
const rapidjson::Value& result = message[OPEN_AI_RESPONSE_FIELD_CONTENT];
return std::string_view(result.GetString(), result.GetStringLength());
}
// Overload that takes the endpoint, model, API-key secret and params explicitly.
// Forwards every argument unchanged to AiGenerateTextInternal<false>, passing 'false'
// for its trailing flag, and returns that call's result.
StringVal AiFunctions::AiGenerateText(FunctionContext* ctx, const StringVal& endpoint,
    const StringVal& prompt, const StringVal& model,
    const StringVal& api_key_jceks_secret, const StringVal& params) {
  const StringVal generated = AiGenerateTextInternal<false>(
      ctx, endpoint, prompt, model, api_key_jceks_secret, params, false);
  return generated;
}
// Overload that takes only the prompt. Calls AiGenerateTextInternal<true> with
// NULL_STRINGVAL in place of the endpoint, model, API-key secret and params, and with
// 'false' for its trailing flag.
StringVal AiFunctions::AiGenerateTextDefault(
    FunctionContext* ctx, const StringVal& prompt) {
  // Single alias for the null StringVal passed for each omitted argument.
  const StringVal& unset = NULL_STRINGVAL;
  return AiGenerateTextInternal<true>(ctx, unset, prompt, unset, unset, unset, false);
}
} // namespace impala