IMPALA-13004: Fix heap-use-after-free error in ExprTest AiFunctionsTest
AiGenerateTextParseOpenAiResponse() previously returned a
std::string_view pointing at string data owned by a local
rapidjson::Document. That Document is destroyed when the function
returns, so the std::string_view handed back to the caller
referenced memory that was no longer valid, leading to a
heap-use-after-free error.
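This patch fixes the issue by changing the function to return a
std::string instead of a std::string_view. Returning a std::string
copies the data out of the rapidjson::Document before it is
destroyed, so the returned string owns its memory and no longer
depends on the lifetime of the rapidjson::Document.
The patch also moves the 'overrides' Document in
ai-functions.inline.h out of the 'if' block, because 'payload'
might reference data owned by 'overrides' and is used after that
block ends.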
Tests:
Reran the ASAN build; it passed.
Change-Id: I3bb9dcf9d72cce7ad37d5bc25821cf6ee55a8ab5
Reviewed-on: http://gerrit.cloudera.org:8080/21315
Reviewed-by: Impala Public Jenkins <impala-public-jenkins@cloudera.com>
Tested-by: Impala Public Jenkins <impala-public-jenkins@cloudera.com>
diff --git a/be/src/exprs/ai-functions-ir.cc b/be/src/exprs/ai-functions-ir.cc
index e482cb6..6def1a0 100644
--- a/be/src/exprs/ai-functions-ir.cc
+++ b/be/src/exprs/ai-functions-ir.cc
@@ -85,7 +85,7 @@
gstrncasestr(endpoint.data(), OPEN_AI_PUBLIC_ENDPOINT, endpoint.size()) != nullptr);
}
-std::string_view AiFunctions::AiGenerateTextParseOpenAiResponse(
+string AiFunctions::AiGenerateTextParseOpenAiResponse(
const std::string_view& response) {
rapidjson::Document document;
document.Parse(response.data(), response.size());
@@ -120,8 +120,7 @@
return AI_GENERATE_TXT_JSON_PARSE_ERROR;
}
- const rapidjson::Value& result = message[OPEN_AI_RESPONSE_FIELD_CONTENT];
- return std::string_view(result.GetString(), result.GetStringLength());
+ return message[OPEN_AI_RESPONSE_FIELD_CONTENT].GetString();
}
StringVal AiFunctions::AiGenerateText(FunctionContext* ctx, const StringVal& endpoint,
diff --git a/be/src/exprs/ai-functions.h b/be/src/exprs/ai-functions.h
index c1d2e63..0e6396b 100644
--- a/be/src/exprs/ai-functions.h
+++ b/be/src/exprs/ai-functions.h
@@ -64,7 +64,7 @@
const StringVal& api_key_jceks_secret, const StringVal& params, const bool dry_run);
/// Internal helper function for parsing OPEN AI's API response. Input parameter is the
/// json representation of the OPEN AI's API response.
- static std::string_view AiGenerateTextParseOpenAiResponse(
+ static std::string AiGenerateTextParseOpenAiResponse(
const std::string_view& reponse);
friend class ExprTest_AiFunctionsTest_Test;
diff --git a/be/src/exprs/ai-functions.inline.h b/be/src/exprs/ai-functions.inline.h
index 7f7bcfd..9f143e2 100644
--- a/be/src/exprs/ai-functions.inline.h
+++ b/be/src/exprs/ai-functions.inline.h
@@ -103,9 +103,12 @@
payload_allocator);
message_array.PushBack(message, payload_allocator);
payload.AddMember("messages", message_array, payload_allocator);
- // Override additional params
+ // Override additional params.
+ // Caution: 'payload' might reference data owned by 'overrides'.
+ // To ensure valid access, place 'overrides' outside the 'if'
+ // statement before using 'payload'.
+ Document overrides;
if (!fastpath && params.ptr != nullptr && params.len != 0) {
- Document overrides;
overrides.Parse(reinterpret_cast<char*>(params.ptr), params.len);
if (overrides.HasParseError()) {
LOG(WARNING) << AI_GENERATE_TXT_JSON_PARSE_ERROR << ": error code "
@@ -172,7 +175,7 @@
ctx, reinterpret_cast<const uint8_t*>(msg.c_str()), msg.size());
}
// Parse the JSON response string
- std::string_view response = AiGenerateTextParseOpenAiResponse(
+ std::string response = AiGenerateTextParseOpenAiResponse(
std::string_view(reinterpret_cast<char*>(resp.data()), resp.size()));
VLOG(2) << "AI Generate Text: \nresponse: " << response;
StringVal result(ctx, response.length());
diff --git a/be/src/exprs/expr-test.cc b/be/src/exprs/expr-test.cc
index a056150..ec326ed 100644
--- a/be/src/exprs/expr-test.cc
+++ b/be/src/exprs/expr-test.cc
@@ -11336,12 +11336,12 @@
<< "\"total_tokens\": 73"
<< "},"
<< "\"system_fingerprint\": null}";
- std::string_view res = AiFunctions::AiGenerateTextParseOpenAiResponse(response.str());
+ std::string res = AiFunctions::AiGenerateTextParseOpenAiResponse(response.str());
string from_null("(\'\\\\0\')");
string to_null("(\'\\0\')");
size_t pos = content.find(from_null);
content.replace(pos, from_null.length(), to_null);
- EXPECT_EQ(string(res), content);
+ EXPECT_EQ(res, content);
// resource cleanup
pool.FreeAll();