// blob: 7f7bcfd924e40223f0e4593cb84c6602a67f067f [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <rapidjson/document.h>
#include <rapidjson/error/en.h>
#include <rapidjson/stringbuffer.h>
#include <rapidjson/writer.h>
#include "common/compiler-util.h"
#include "exprs/ai-functions.h"
#include "kudu/util/curl_util.h"
#include "kudu/util/faststring.h"
#include "kudu/util/flag_tags.h"
#include "kudu/util/monotime.h"
#include "kudu/util/status.h"
#include "runtime/exec-env.h"
#include "service/frontend.h"
using namespace rapidjson;
using namespace impala_udf;
DECLARE_string(ai_endpoint);
DECLARE_string(ai_model);
DECLARE_string(ai_api_key_jceks_secret);
DECLARE_int32(ai_connection_timeout_s);
namespace impala {
// Builds a chat-completion style POST request (a "model" plus a single "user" message
// holding 'prompt') and, unless 'dry_run' is set, sends it and returns the text that
// AiGenerateTextParseOpenAiResponse() extracts from the response body.
//
// Template parameter:
//   fastpath - when true, 'endpoint', 'model', 'api_key_jceks_secret' and 'params' are
//     ignored and the --ai_endpoint / --ai_model flags and 'ai_api_key_' are used. When
//     false, every non-empty argument overrides the corresponding default.
//
// Arguments:
//   ctx - UDF context used to allocate the returned StringVal.
//   endpoint - optional endpoint URL; must pass is_api_endpoint_valid() and
//     is_api_endpoint_supported().
//   prompt - text sent as the content of the single user message; must be non-empty.
//   model - optional model name placed in the "model" field of the payload.
//   api_key_jceks_secret - optional key store entry name that is resolved to the API key
//     through the frontend.
//   params - optional JSON object whose members are merged into the payload. "messages"
//     cannot be overridden and "n", if given, must be the integer 1.
//   dry_run - when true nothing is sent; the endpoint, the headers and the payload are
//     returned instead, separated by newlines.
//
// Failures are not raised through 'ctx': the function returns a StringVal holding the
// error text (one of the AI_GENERATE_TXT_* messages, the key store status message or the
// curl status). StringVal::null() is returned only if the result buffer cannot be
// allocated.
template <bool fastpath>
StringVal AiFunctions::AiGenerateTextInternal(FunctionContext* ctx,
    const StringVal& endpoint, const StringVal& prompt, const StringVal& model,
    const StringVal& api_key_jceks_secret, const StringVal& params, const bool dry_run) {
  std::string_view endpoint_sv(FLAGS_ai_endpoint);
  // endpoint validation
  if (!fastpath && endpoint.ptr != nullptr && endpoint.len != 0) {
    endpoint_sv = std::string_view(reinterpret_cast<char*>(endpoint.ptr), endpoint.len);
    // Simple validation for endpoint. It should start with https://
    if (!is_api_endpoint_valid(endpoint_sv)) {
      LOG(ERROR) << "AI Generate Text: \ninvalid protocol: " << endpoint_sv;
      return StringVal(AI_GENERATE_TXT_INVALID_PROTOCOL_ERROR.c_str());
    }
    // Only OpenAI endpoints are supported.
    if (!is_api_endpoint_supported(endpoint_sv)) {
      LOG(ERROR) << "AI Generate Text: \nunsupported endpoint: " << endpoint_sv;
      return StringVal(AI_GENERATE_TXT_UNSUPPORTED_ENDPOINT_ERROR.c_str());
    }
  }

  // Generate the header for the POST request
  vector<string> headers;
  headers.emplace_back(OPEN_AI_REQUEST_FIELD_CONTENT_TYPE_HEADER);
  if (!fastpath && api_key_jceks_secret.ptr != nullptr && api_key_jceks_secret.len != 0) {
    string api_key;
    string api_key_secret(
        reinterpret_cast<char*>(api_key_jceks_secret.ptr), api_key_jceks_secret.len);
    Status status = ExecEnv::GetInstance()->frontend()->GetSecretFromKeyStore(
        api_key_secret, &api_key);
    if (!status.ok()) {
      return StringVal::CopyFrom(ctx,
          reinterpret_cast<const uint8_t*>(status.msg().msg().c_str()),
          status.msg().msg().length());
    }
    headers.emplace_back("Authorization: Bearer " + api_key);
  } else {
    headers.emplace_back("Authorization: Bearer " + ai_api_key_);
  }

  // Generate the payload for the POST request. The string members below are added by
  // reference (StringRef), so they rely on the input buffers and the flag outliving
  // 'payload', which they do for the duration of this call.
  Document payload;
  payload.SetObject();
  Document::AllocatorType& payload_allocator = payload.GetAllocator();
  if (!fastpath && model.ptr != nullptr && model.len != 0) {
    payload.AddMember("model",
        rapidjson::StringRef(reinterpret_cast<char*>(model.ptr), model.len),
        payload_allocator);
  } else {
    payload.AddMember("model",
        rapidjson::StringRef(FLAGS_ai_model.c_str(), FLAGS_ai_model.length()),
        payload_allocator);
  }
  if (prompt.ptr == nullptr || prompt.len == 0) {
    return StringVal(AI_GENERATE_TXT_INVALID_PROMPT_ERROR.c_str());
  }
  Value message_array(rapidjson::kArrayType);
  Value message(rapidjson::kObjectType);
  message.AddMember("role", "user", payload_allocator);
  message.AddMember("content",
      rapidjson::StringRef(reinterpret_cast<char*>(prompt.ptr), prompt.len),
      payload_allocator);
  message_array.PushBack(message, payload_allocator);
  payload.AddMember("messages", message_array, payload_allocator);

  // Override additional params
  if (!fastpath && params.ptr != nullptr && params.len != 0) {
    Document overrides;
    overrides.Parse(reinterpret_cast<char*>(params.ptr), params.len);
    if (overrides.HasParseError()) {
      LOG(WARNING) << AI_GENERATE_TXT_JSON_PARSE_ERROR << ": error code "
                   << overrides.GetParseError() << ", offset input "
                   << overrides.GetErrorOffset();
      return StringVal(AI_GENERATE_TXT_JSON_PARSE_ERROR.c_str());
    }
    // Well-formed JSON is not necessarily an object (e.g. "[1]" or "3"). GetObject() is
    // only valid on an object, so reject everything else up front.
    if (!overrides.IsObject()) {
      LOG(WARNING) << AI_GENERATE_TXT_JSON_PARSE_ERROR
                   << ": additional params must be a JSON object";
      return StringVal(AI_GENERATE_TXT_JSON_PARSE_ERROR.c_str());
    }
    for (const auto& m : overrides.GetObject()) {
      if (m.name == "messages") {
        LOG(WARNING)
            << AI_GENERATE_TXT_JSON_PARSE_ERROR
            << ": 'messages' is constructed from 'prompt', cannot be overridden";
        return StringVal(AI_GENERATE_TXT_MSG_OVERRIDE_FORBIDDEN_ERROR.c_str());
      }
      // Validate 'n' on every occurrence. A repeated key would otherwise reach the
      // "already present" path below and replace a valid value without being checked.
      if ((m.name == "n") && !(m.value.IsInt() && m.value.GetInt() == 1)) {
        LOG(WARNING)
            << AI_GENERATE_TXT_JSON_PARSE_ERROR
            << ": 'n' must be of integer type and have value 1";
        return StringVal(AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR.c_str());
      }
      // Deep-copy into the payload's allocator. 'overrides' and its memory pool are
      // destroyed at the end of this scope, before the payload is serialized, so moving
      // names or values out of it would leave 'payload' with dangling pointers.
      auto existing = payload.FindMember(m.name);
      if (existing != payload.MemberEnd()) {
        existing->value.CopyFrom(m.value, payload_allocator);
      } else {
        Value name_copy(m.name, payload_allocator);
        Value value_copy(m.value, payload_allocator);
        payload.AddMember(name_copy, value_copy, payload_allocator);
      }
    }
  }

  // Convert payload into string for POST request
  StringBuffer buffer;
  Writer<StringBuffer> writer(buffer);
  payload.Accept(writer);
  string payload_str(buffer.GetString(), buffer.GetSize());
  VLOG(2) << "AI Generate Text: \nendpoint: " << endpoint_sv
          << " \npayload: " << payload_str;
  if (UNLIKELY(dry_run)) {
    std::stringstream post_request;
    post_request << endpoint_sv;
    for (const auto& header : headers) {
      post_request << "\n" << header;
    }
    post_request << "\n" << payload_str;
    // Materialize the text once; every str() call returns a fresh copy.
    const string request_text = post_request.str();
    return StringVal::CopyFrom(ctx,
        reinterpret_cast<const uint8_t*>(request_text.data()), request_text.length());
  }

  // Send request to external AI API endpoint
  kudu::EasyCurl curl;
  curl.set_timeout(kudu::MonoDelta::FromSeconds(FLAGS_ai_connection_timeout_s));
  curl.set_fail_on_http_error(true);
  kudu::faststring resp;
  kudu::Status status;
  if (fastpath) {
    DCHECK_EQ(std::string_view(FLAGS_ai_endpoint), endpoint_sv);
    status = curl.PostToURL(FLAGS_ai_endpoint, payload_str, &resp, headers);
  } else {
    std::string endpoint_str{endpoint_sv};
    status = curl.PostToURL(endpoint_str, payload_str, &resp, headers);
  }
  VLOG(2) << "AI Generate Text: \noriginal response: " << resp.ToString();
  if (!status.ok()) {
    string msg = status.ToString();
    return StringVal::CopyFrom(
        ctx, reinterpret_cast<const uint8_t*>(msg.c_str()), msg.size());
  }

  // Parse the JSON response string
  std::string_view response = AiGenerateTextParseOpenAiResponse(
      std::string_view(reinterpret_cast<char*>(resp.data()), resp.size()));
  VLOG(2) << "AI Generate Text: \nresponse: " << response;
  StringVal result(ctx, response.length());
  if (UNLIKELY(result.is_null)) return StringVal::null();
  memcpy(result.ptr, response.data(), response.length());
  return result;
}
} // namespace impala