| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
#include <cstring>
#include <sstream>

#include <gutil/strings/util.h>
| #include <rapidjson/document.h> |
| #include <rapidjson/error/en.h> |
| #include <rapidjson/stringbuffer.h> |
| #include <rapidjson/writer.h> |
| |
| #include "common/compiler-util.h" |
| #include "exprs/ai-functions.h" |
| #include "kudu/util/curl_util.h" |
| #include "kudu/util/faststring.h" |
| #include "kudu/util/flag_tags.h" |
| #include "kudu/util/monotime.h" |
| #include "kudu/util/status.h" |
| #include "runtime/exec-env.h" |
| #include "service/frontend.h" |
| |
| using namespace rapidjson; |
| using namespace impala_udf; |
| |
| DEFINE_string(ai_endpoint, "https://api.openai.com/v1/chat/completions", |
| "The default API endpoint for an external AI engine."); |
| DEFINE_validator(ai_endpoint, [](const char* name, const string& endpoint) { |
| return impala::AiFunctions::is_api_endpoint_valid(endpoint); |
| }); |
| |
| DEFINE_string(ai_model, "gpt-4", "The default AI model used by an external AI engine."); |
| |
DEFINE_string(ai_api_key_jceks_secret, "",
    "The JCEKS secret key used for extracting the API key from configured keystores. "
    "'hadoop.security.credential.provider.path' in core-site.xml must be configured to "
    "include the keystore storing the corresponding secret.");
| |
DEFINE_string(ai_additional_platforms, "",
    "A comma-separated list of additional platforms that Impala is allowed to access "
    "via the AI API, formatted as 'site1,site2'.");
| |
DEFINE_int32(ai_connection_timeout_s, 10,
    "(Advanced) The connection timeout in seconds when communicating with an "
    "external AI engine.");
| TAG_FLAG(ai_api_key_jceks_secret, sensitive); |
| |
| namespace impala { |
| |
| // Set an error message in the context, causing the query to fail. |
| #define SET_ERROR(ctx, status_str, prefix) \ |
| do { \ |
| (ctx)->SetError((prefix + status_str).c_str()); \ |
| } while (false) |
| |
// Check the status; if it failed, set the error in the context and return a null
// StringVal.
| #define RETURN_STRINGVAL_IF_ERROR(ctx, stmt) \ |
| do { \ |
| const ::impala::Status& _status = (stmt); \ |
| if (UNLIKELY(!_status.ok())) { \ |
| SET_ERROR(ctx, _status.msg().msg(), AI_GENERATE_TXT_COMMON_ERROR_PREFIX); \ |
| return StringVal::null(); \ |
| } \ |
| } while (false) |
| |
// Constants for the Impala AI function options.
| static const char* IMPALA_AI_API_STANDARD_FIELD = "api_standard"; |
| static const char* IMPALA_AI_CREDENTIAL_TYPE_FIELD = "credential_type"; |
| static const char* IMPALA_AI_PAYLOAD_FIELD = "payload"; |
| static const char* IMPALA_AI_API_STANDARD_OPENAI = "openai"; |
| static const char* IMPALA_AI_CREDENTIAL_TYPE_PLAIN = "plain"; |
| static const char* IMPALA_AI_CREDENTIAL_TYPE_JCEKS = "jceks"; |
| static const int MAX_CUSTOM_PAYLOAD_LENGTH = 5 * 1024 * 1024; // 5MB |
| |
| static const size_t IMPALA_AI_API_STANDARD_OPENAI_LEN = |
| std::strlen(IMPALA_AI_API_STANDARD_OPENAI); |
| static const size_t IMPALA_AI_CREDENTIAL_TYPE_PLAIN_LEN = |
| std::strlen(IMPALA_AI_CREDENTIAL_TYPE_PLAIN); |
| static const size_t IMPALA_AI_CREDENTIAL_TYPE_JCEKS_LEN = |
| std::strlen(IMPALA_AI_CREDENTIAL_TYPE_JCEKS); |
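
// Illustrative example of the 'impala_options' JSON described by the constants above
// (all fields are optional; the values shown are hypothetical):
//   {"api_standard": "openai", "credential_type": "jceks",
//    "payload": "{\"model\": \"gpt-4\", \"messages\": [...]}"}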
| |
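// Builds the request authorization header for the given platform by prepending the
// platform-specific prefix (e.g. AiFunctions::OPEN_AI_REQUEST_AUTH_HEADER) to
// 'api_key'. Returns an error for unsupported platform/API-standard combinations.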
| template <AiFunctions::AI_PLATFORM platform> |
| Status getAuthorizationHeader(string& authHeader, const std::string_view& api_key, |
| const AiFunctions::AiFunctionsOptions& ai_options) { |
| const char* header_prefix = nullptr; |
| switch (platform) { |
| case AiFunctions::AI_PLATFORM::OPEN_AI: |
| header_prefix = AiFunctions::OPEN_AI_REQUEST_AUTH_HEADER; |
| break; |
| case AiFunctions::AI_PLATFORM::AZURE_OPEN_AI: |
| header_prefix = AiFunctions::AZURE_OPEN_AI_REQUEST_AUTH_HEADER; |
| break; |
| case AiFunctions::AI_PLATFORM::GENERAL: |
      // For the general platform, only the OPEN_AI API standard is supported for now.
| if (ai_options.api_standard == AiFunctions::API_STANDARD::OPEN_AI) { |
| header_prefix = AiFunctions::OPEN_AI_REQUEST_AUTH_HEADER; |
| break; |
| } |
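      // Intentional fallthrough: any other API standard on the general platform is
      // rejected as unsupported below.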
| default: |
      DCHECK(false) << "AiGenerateTextInternal should only be called for supported "
                       "platforms and API standards";
| return Status(AiFunctions::AI_GENERATE_TXT_UNSUPPORTED_ENDPOINT_ERROR); |
| } |
| DCHECK(header_prefix != nullptr); |
| authHeader = header_prefix; |
| authHeader.append(api_key); |
| return Status::OK(); |
| } |
| |
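// Parses the 'options' JSON ("api_standard", "credential_type", "payload") into
// 'result', leaving the defaults in place for any missing field. Throws
// std::runtime_error on a JSON parse failure or when the custom payload exceeds
// MAX_CUSTOM_PAYLOAD_LENGTH.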
| static void ParseImpalaOptions(const StringVal& options, Document& document, |
| AiFunctions::AiFunctionsOptions& result) { |
| // If options is NULL or empty, return with defaults. |
| if (options.is_null || options.len == 0) return; |
| |
| if (document.Parse(reinterpret_cast<const char*>(options.ptr), options.len) |
| .HasParseError()) { |
| std::stringstream ss; |
| ss << "Error parsing impala options: " |
| << string(reinterpret_cast<const char*>(options.ptr), options.len) |
       << ", error code: " << document.GetParseError() << ", input offset "
       << document.GetErrorOffset();
| throw std::runtime_error(ss.str()); |
| } |
| // Check for "api_standard" field. |
| if (document.HasMember(IMPALA_AI_API_STANDARD_FIELD) |
| && document[IMPALA_AI_API_STANDARD_FIELD].IsString()) { |
| const char* api_standard_value = document[IMPALA_AI_API_STANDARD_FIELD].GetString(); |
| if (gstrncasestr(IMPALA_AI_API_STANDARD_OPENAI, api_standard_value, |
| IMPALA_AI_API_STANDARD_OPENAI_LEN) != nullptr) { |
| result.api_standard = AiFunctions::API_STANDARD::OPEN_AI; |
| } else { |
| result.api_standard = AiFunctions::API_STANDARD::UNSUPPORTED; |
| } |
| } |
| |
| // Check for "credential_type" field. |
| if (document.HasMember(IMPALA_AI_CREDENTIAL_TYPE_FIELD) |
| && document[IMPALA_AI_CREDENTIAL_TYPE_FIELD].IsString()) { |
| const char* credential_type_value = |
| document[IMPALA_AI_CREDENTIAL_TYPE_FIELD].GetString(); |
| if (gstrncasestr(IMPALA_AI_CREDENTIAL_TYPE_PLAIN, credential_type_value, |
| IMPALA_AI_CREDENTIAL_TYPE_PLAIN_LEN) != nullptr) { |
| result.credential_type = AiFunctions::CREDENTIAL_TYPE::PLAIN; |
| } else if (gstrncasestr(IMPALA_AI_CREDENTIAL_TYPE_JCEKS, credential_type_value, |
| IMPALA_AI_CREDENTIAL_TYPE_JCEKS_LEN) != nullptr) { |
| result.credential_type = AiFunctions::CREDENTIAL_TYPE::JCEKS; |
| } |
| } |
| |
| // Check for "payload" field. |
| if (document.HasMember(IMPALA_AI_PAYLOAD_FIELD) |
| && document[IMPALA_AI_PAYLOAD_FIELD].IsString()) { |
| const char* payload_value = document[IMPALA_AI_PAYLOAD_FIELD].GetString(); |
| result.ai_custom_payload = std::string_view(payload_value); |
    // Reject payloads that exceed the maximum allowed custom payload length.
| if (result.ai_custom_payload.length() > MAX_CUSTOM_PAYLOAD_LENGTH) { |
| std::stringstream ss; |
| ss << "Error: custom payload can't be longer than " << MAX_CUSTOM_PAYLOAD_LENGTH |
| << " bytes. Current length: " << result.ai_custom_payload.length(); |
| result.ai_custom_payload = std::string_view(); |
| throw std::runtime_error(ss.str()); |
| } |
| } |
| } |
| |
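// Builds the request headers and JSON payload for the external AI endpoint, applying
// the per-call credential, model, platform parameters and Impala options when not on
// the fastpath, then POSTs the request and returns the parsed OpenAI-style response.
// When 'dry_run' is true, the assembled request is returned instead of being sent.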
| template <bool fastpath, AiFunctions::AI_PLATFORM platform> |
| StringVal AiFunctions::AiGenerateTextInternal(FunctionContext* ctx, |
| const std::string_view& endpoint_sv, const StringVal& prompt, const StringVal& model, |
| const StringVal& auth_credential, const StringVal& platform_params, |
| const StringVal& impala_options, const bool dry_run) { |
  // Generate the headers for the POST request.
| vector<string> headers; |
| headers.emplace_back(OPEN_AI_REQUEST_FIELD_CONTENT_TYPE_HEADER); |
| string authHeader; |
| AiFunctions::AiFunctionsOptions ai_options; |
| Document impala_options_document; |
| |
| if (!fastpath) { |
| try { |
| ParseImpalaOptions(impala_options, impala_options_document, ai_options); |
| } catch (const std::runtime_error& e) { |
| std::stringstream ss; |
| ss << AI_GENERATE_TXT_JSON_PARSE_ERROR << ": " << e.what(); |
| LOG(WARNING) << ss.str(); |
| const Status err_status(ss.str()); |
| RETURN_STRINGVAL_IF_ERROR(ctx, err_status); |
| } |
| } |
| |
| if (!fastpath && auth_credential.ptr != nullptr && auth_credential.len != 0) { |
| if (ai_options.credential_type == CREDENTIAL_TYPE::PLAIN) { |
| // Use the credential as a plain text token. |
| std::string_view token( |
| reinterpret_cast<char*>(auth_credential.ptr), auth_credential.len); |
| RETURN_STRINGVAL_IF_ERROR( |
| ctx, getAuthorizationHeader<platform>(authHeader, token, ai_options)); |
| } else { |
| DCHECK(ai_options.credential_type == CREDENTIAL_TYPE::JCEKS); |
      // Use the credential as a JCEKS secret and fetch the API key.
| string api_key; |
| string api_key_secret( |
| reinterpret_cast<char*>(auth_credential.ptr), auth_credential.len); |
| RETURN_STRINGVAL_IF_ERROR(ctx, |
| ExecEnv::GetInstance()->frontend()->GetSecretFromKeyStore( |
| api_key_secret, &api_key)); |
| RETURN_STRINGVAL_IF_ERROR( |
| ctx, getAuthorizationHeader<platform>(authHeader, api_key, ai_options)); |
| } |
| } else { |
| RETURN_STRINGVAL_IF_ERROR( |
| ctx, getAuthorizationHeader<platform>(authHeader, ai_api_key_, ai_options)); |
| } |
| headers.emplace_back(authHeader); |
| |
| string payload_str; |
| if (!fastpath && !ai_options.ai_custom_payload.empty()) { |
| payload_str = |
| string(ai_options.ai_custom_payload.data(), ai_options.ai_custom_payload.size()); |
| } else { |
| // Generate the payload for the POST request |
| Document payload; |
| payload.SetObject(); |
| Document::AllocatorType& payload_allocator = payload.GetAllocator(); |
    // The Azure OpenAI endpoint doesn't expect the model as a separate param since
    // it's embedded in the endpoint. The 'deployment_name' below maps to a model:
    // https://<resource_name>.openai.azure.com/openai/deployments/<deployment_name>/..
| if (platform != AI_PLATFORM::AZURE_OPEN_AI) { |
| if (!fastpath && model.ptr != nullptr && model.len != 0) { |
| payload.AddMember("model", |
| rapidjson::StringRef(reinterpret_cast<char*>(model.ptr), model.len), |
| payload_allocator); |
| } else { |
| payload.AddMember("model", |
| rapidjson::StringRef(FLAGS_ai_model.c_str(), FLAGS_ai_model.length()), |
| payload_allocator); |
| } |
| } |
| Value message_array(rapidjson::kArrayType); |
| Value message(rapidjson::kObjectType); |
| message.AddMember("role", "user", payload_allocator); |
| if (prompt.ptr == nullptr || prompt.len == 0) { |
| // Return a string with the invalid prompt error message instead of failing |
| // the query, as the issue may be with the row rather than the configuration |
| // or query. This behavior might be reconsidered later. |
| return StringVal(AI_GENERATE_TXT_INVALID_PROMPT_ERROR.c_str()); |
| } |
| message.AddMember("content", |
| rapidjson::StringRef(reinterpret_cast<char*>(prompt.ptr), prompt.len), |
| payload_allocator); |
| message_array.PushBack(message, payload_allocator); |
| payload.AddMember("messages", message_array, payload_allocator); |
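    // At this point the generated payload has the following shape (illustrative):
    //   {"model": "<model>", "messages": [{"role": "user", "content": "<prompt>"}]}
    // with the "model" member omitted for Azure OpenAI endpoints.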
    // Apply any additional platform parameter overrides.
    // Caution: 'payload' may reference data owned by 'overrides', so 'overrides' must
    // be declared outside the 'if' statement and outlive every subsequent use of
    // 'payload'.
| Document overrides; |
| if (!fastpath && platform_params.ptr != nullptr && platform_params.len != 0) { |
| overrides.Parse(reinterpret_cast<char*>(platform_params.ptr), platform_params.len); |
| if (overrides.HasParseError()) { |
| std::stringstream ss; |
| ss << AI_GENERATE_TXT_JSON_PARSE_ERROR << ": error code " |
           << overrides.GetParseError() << ", input offset "
           << overrides.GetErrorOffset();
| LOG(WARNING) << ss.str(); |
| const Status err_status(ss.str()); |
| RETURN_STRINGVAL_IF_ERROR(ctx, err_status); |
| } |
| for (auto& m : overrides.GetObject()) { |
| if (payload.HasMember(m.name.GetString())) { |
| if (m.name == "messages") { |
| const string error_msg = AI_GENERATE_TXT_MSG_OVERRIDE_FORBIDDEN_ERROR |
| + ": 'messages' is constructed from 'prompt', cannot be overridden"; |
| LOG(WARNING) << error_msg; |
| const Status err_status(error_msg); |
| RETURN_STRINGVAL_IF_ERROR(ctx, err_status); |
| } else { |
| payload[m.name.GetString()] = m.value; |
| } |
| } else { |
| if ((m.name == "n") && !(m.value.IsInt() && m.value.GetInt() == 1)) { |
| const string error_msg = AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR |
| + ": 'n' must be of integer type and have value 1"; |
| LOG(WARNING) << error_msg; |
| const Status err_status(error_msg); |
| RETURN_STRINGVAL_IF_ERROR(ctx, err_status); |
| } |
| payload.AddMember(m.name, m.value, payload_allocator); |
| } |
| } |
| } |
| // Convert payload into string for POST request |
| StringBuffer buffer; |
| Writer<StringBuffer> writer(buffer); |
| payload.Accept(writer); |
| payload_str = string(buffer.GetString(), buffer.GetSize()); |
| } |
| DCHECK(!payload_str.empty()); |
| VLOG(2) << "AI Generate Text: \nendpoint: " << endpoint_sv |
| << " \npayload: " << payload_str; |
| if (UNLIKELY(dry_run)) { |
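    // Dry run: return the assembled request (endpoint, headers and payload,
    // newline-separated) instead of contacting the AI service.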
| std::stringstream post_request; |
| post_request << endpoint_sv; |
| for (auto& header : headers) { |
| post_request << "\n" << header; |
| } |
| post_request << "\n" << payload_str; |
| return StringVal::CopyFrom(ctx, |
| reinterpret_cast<const uint8_t*>(post_request.str().data()), |
| post_request.str().length()); |
| } |
| // Send request to external AI API endpoint |
| kudu::EasyCurl curl; |
| curl.set_timeout(kudu::MonoDelta::FromSeconds(FLAGS_ai_connection_timeout_s)); |
| curl.set_fail_on_http_error(true); |
| kudu::faststring resp; |
| kudu::Status status; |
| if (fastpath) { |
| DCHECK_EQ(std::string_view(FLAGS_ai_endpoint), endpoint_sv); |
| status = curl.PostToURL(FLAGS_ai_endpoint, payload_str, &resp, headers); |
| } else { |
| std::string endpoint_str{endpoint_sv}; |
| status = curl.PostToURL(endpoint_str, payload_str, &resp, headers); |
| } |
| VLOG(2) << "AI Generate Text: \noriginal response: " << resp.ToString(); |
| if (UNLIKELY(!status.ok())) { |
| SET_ERROR(ctx, status.ToString(), AI_GENERATE_TXT_COMMON_ERROR_PREFIX); |
| return StringVal::null(); |
| } |
| // Parse the JSON response string |
| std::string response = AiGenerateTextParseOpenAiResponse( |
| std::string_view(reinterpret_cast<char*>(resp.data()), resp.size())); |
| VLOG(2) << "AI Generate Text: \nresponse: " << response; |
| StringVal result(ctx, response.length()); |
| if (UNLIKELY(result.is_null)) return StringVal::null(); |
| memcpy(result.ptr, response.data(), response.length()); |
| return result; |
| } |
| |
| // Template instantiations for getAuthorizationHeader function. |
| #define INSTANTIATE_AI_AUTH_HEADER(PLATFORM) \ |
| template Status getAuthorizationHeader<AiFunctions::AI_PLATFORM::PLATFORM>( \ |
| string&, const std::string_view&, const AiFunctions::AiFunctionsOptions&); |
| |
| INSTANTIATE_AI_AUTH_HEADER(UNSUPPORTED) |
| INSTANTIATE_AI_AUTH_HEADER(OPEN_AI) |
| INSTANTIATE_AI_AUTH_HEADER(AZURE_OPEN_AI) |
| INSTANTIATE_AI_AUTH_HEADER(GENERAL) |
| |
| #undef INSTANTIATE_AI_AUTH_HEADER |
| |
| // Template instantiations for AiGenerateTextInternal function. |
| #define INSTANTIATE_AI_GENERATE_TEXT(FASTPATH, PLATFORM) \ |
| template StringVal AiFunctions::AiGenerateTextInternal< \ |
| FASTPATH, AiFunctions::AI_PLATFORM::PLATFORM>( \ |
| FunctionContext*, const std::string_view&, const StringVal&, const StringVal&, \ |
| const StringVal&, const StringVal&, const StringVal&, const bool); |
| |
| #define INSTANTIATE_AI_GENERATE_TEXT_FOR_PLATFORM(PLATFORM) \ |
| INSTANTIATE_AI_GENERATE_TEXT(true, PLATFORM) \ |
| INSTANTIATE_AI_GENERATE_TEXT(false, PLATFORM) |
| |
| INSTANTIATE_AI_GENERATE_TEXT_FOR_PLATFORM(UNSUPPORTED) |
| INSTANTIATE_AI_GENERATE_TEXT_FOR_PLATFORM(OPEN_AI) |
| INSTANTIATE_AI_GENERATE_TEXT_FOR_PLATFORM(AZURE_OPEN_AI) |
| INSTANTIATE_AI_GENERATE_TEXT_FOR_PLATFORM(GENERAL) |
| |
| #undef INSTANTIATE_AI_GENERATE_TEXT |
| #undef INSTANTIATE_AI_GENERATE_TEXT_FOR_PLATFORM |
| |
| } // namespace impala |