blob: 7a1ecdc0c9a7feb4e7a94c78c4ddc89b1a033388 [file]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <arpa/inet.h>
#include <gen_cpp/PaloInternalService_types.h>
#include <gtest/gtest.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>
#include <string>
#include <thread>
#include "core/block/block.h"
#include "core/column/column_array.h"
#include "core/column/column_string.h"
#include "core/column/column_vector.h"
#include "core/data_type/data_type_array.h"
#include "exprs/function/ai/ai_adapter.h"
#include "exprs/function/ai/ai_classify.h"
#include "exprs/function/ai/ai_extract.h"
#include "exprs/function/ai/ai_filter.h"
#include "exprs/function/ai/ai_fix_grammar.h"
#include "exprs/function/ai/ai_generate.h"
#include "exprs/function/ai/ai_mask.h"
#include "exprs/function/ai/ai_sentiment.h"
#include "exprs/function/ai/ai_similarity.h"
#include "exprs/function/ai/ai_summarize.h"
#include "exprs/function/ai/ai_translate.h"
#include "exprs/function/ai/embed.h"
#include "testutil/column_helper.h"
#include "testutil/mock/mock_runtime_state.h"
namespace doris {
class FunctionAITransportTestHelper : public FunctionAISentiment {
public:
using FunctionAISentiment::do_send_request;
};
class FunctionAIFilterBatchTestHelper : public AIFunction<FunctionAIFilterBatchTestHelper> {
public:
friend class AIFunction<FunctionAIFilterBatchTestHelper>;
static constexpr auto name = "ai_filter";
static constexpr auto system_prompt = FunctionAIFilter::system_prompt;
static constexpr size_t number_of_arguments = FunctionAIFilter::number_of_arguments;
using AIFunction<FunctionAIFilterBatchTestHelper>::execute_batch_request;
DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
return std::make_shared<DataTypeBool>();
}
MutableColumnPtr create_result_column() const { return ColumnUInt8::create(); }
Status append_batch_results(const std::vector<std::string>& batch_results,
IColumn& col_result) const {
auto& bool_col = assert_cast<ColumnUInt8&>(col_result);
for (const auto& batch_result : batch_results) {
std::string_view trimmed = doris::trim(batch_result);
if (trimmed != "1" && trimmed != "0") {
return Status::RuntimeError("Failed to parse boolean value: " +
std::string(trimmed));
}
bool_col.insert_value(static_cast<UInt8>(trimmed == "1"));
}
return Status::OK();
}
};
class OneShotHttpServer {
public:
OneShotHttpServer(int status_code, std::string response_body)
: _status_code(status_code), _response_body(std::move(response_body)) {
_listen_fd = socket(AF_INET, SOCK_STREAM, 0);
DCHECK_GE(_listen_fd, 0);
int reuse_addr = 1;
int ret = setsockopt(_listen_fd, SOL_SOCKET, SO_REUSEADDR, &reuse_addr, sizeof(reuse_addr));
DCHECK_EQ(ret, 0);
sockaddr_in addr {};
addr.sin_family = AF_INET;
addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
addr.sin_port = 0;
ret = bind(_listen_fd, reinterpret_cast<sockaddr*>(&addr), sizeof(addr));
DCHECK_EQ(ret, 0);
ret = listen(_listen_fd, 1);
DCHECK_EQ(ret, 0);
socklen_t addr_len = sizeof(addr);
ret = getsockname(_listen_fd, reinterpret_cast<sockaddr*>(&addr), &addr_len);
DCHECK_EQ(ret, 0);
_port = ntohs(addr.sin_port);
_thread = std::thread([this] { _serve_once(); });
}
~OneShotHttpServer() {
if (_thread.joinable()) {
_thread.join();
}
if (_listen_fd >= 0) {
close(_listen_fd);
}
}
std::string endpoint() const { return "http://127.0.0.1:" + std::to_string(_port); }
std::string join_and_get_request() {
if (_thread.joinable()) {
_thread.join();
}
return _request;
}
private:
void _serve_once() {
int client_fd = accept(_listen_fd, nullptr, nullptr);
DCHECK_GE(client_fd, 0);
char buffer[4096];
while (true) {
ssize_t read_bytes = recv(client_fd, buffer, sizeof(buffer), 0);
if (read_bytes <= 0) {
break;
}
_request.append(buffer, read_bytes);
if (_request.find("\r\n\r\n") != std::string::npos) {
auto header_end = _request.find("\r\n\r\n");
size_t body_length = 0;
size_t length_pos = _request.find("Content-Length:");
if (length_pos != std::string::npos) {
size_t value_begin = length_pos + sizeof("Content-Length:") - 1;
size_t value_end = _request.find("\r\n", value_begin);
std::string length_str = _request.substr(value_begin, value_end - value_begin);
body_length = std::stoul(length_str);
}
if (_request.size() >= header_end + 4 + body_length) {
break;
}
}
}
std::string status_text = _status_code == 200 ? "OK" : "Internal Server Error";
std::string response = "HTTP/1.1 " + std::to_string(_status_code) + " " + status_text +
"\r\nContent-Type: application/json\r\nContent-Length: " +
std::to_string(_response_body.size()) +
"\r\nConnection: close\r\n\r\n" + _response_body;
size_t sent_bytes = 0;
while (sent_bytes < response.size()) {
ssize_t sent =
send(client_fd, response.data() + sent_bytes, response.size() - sent_bytes, 0);
DCHECK_GT(sent, 0);
sent_bytes += sent;
}
shutdown(client_fd, SHUT_RDWR);
close(client_fd);
}
int _listen_fd = -1;
uint16_t _port = 0;
int _status_code;
std::string _response_body;
std::string _request;
std::thread _thread;
};
TEST(AIFunctionTest, AISummarizeTest) {
FunctionAISummarize function;
std::vector<std::string> resources = {"resource_name"};
std::vector<std::string> texts = {"This is a test document that needs to be summarized."};
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
Block block;
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
ColumnNumbers arguments = {0, 1};
std::string prompt;
Status status = function.build_prompt(block, arguments, 0, prompt);
ASSERT_TRUE(status.ok());
ASSERT_EQ(prompt, "This is a test document that needs to be summarized.");
}
TEST(AIFunctionTest, AISentimentTest) {
FunctionAISentiment function;
std::vector<std::string> resources = {"resource_name"};
std::vector<std::string> texts = {"I really enjoyed the doris community!"};
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
Block block;
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
ColumnNumbers arguments = {0, 1};
std::string prompt;
Status status = function.build_prompt(block, arguments, 0, prompt);
ASSERT_TRUE(status.ok());
ASSERT_EQ(prompt, "I really enjoyed the doris community!");
}
TEST(AIFunctionTest, AIMaskTest) {
FunctionAIMask function;
std::vector<std::string> resources = {"resource_name"};
std::vector<std::string> texts = {
"My email is rarity@example.com and my phone is 123-456-7890"};
std::vector<std::vector<std::string>> labels = {{"email", "phone"}};
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
auto nested_column = ColumnString::create();
auto offsets_column = ColumnOffset64::create();
IColumn::Offset offset = 0;
for (const auto& row : labels) {
for (const auto& value : row) {
nested_column->insert_data(value.data(), value.size());
}
offset += row.size();
offsets_column->insert_value(offset);
}
auto array_column = ColumnArray::create(std::move(nested_column), std::move(offsets_column));
Block block;
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
block.insert({std::move(array_column),
std::make_shared<DataTypeArray>(std::make_shared<DataTypeString>()), "labels"});
ColumnNumbers arguments = {0, 1, 2};
std::string prompt;
Status status = function.build_prompt(block, arguments, 0, prompt);
ASSERT_TRUE(status.ok());
ASSERT_EQ(prompt,
"Labels: [\"email\", \"phone\"]\n"
"Text: My email is rarity@example.com and my phone is 123-456-7890");
}
TEST(AIFunctionTest, AIGenerateTest) {
FunctionAIGenerate function;
std::vector<std::string> resources = {"resource_name"};
std::vector<std::string> texts = {"Write a poem about spring"};
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
Block block;
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
ColumnNumbers arguments = {0, 1};
std::string prompt;
Status status = function.build_prompt(block, arguments, 0, prompt);
ASSERT_TRUE(status.ok());
ASSERT_EQ(prompt, "Write a poem about spring");
}
TEST(AIFunctionTest, AIFixGrammarTest) {
FunctionAIFixGrammar function;
std::vector<std::string> resources = {"resource_name"};
std::vector<std::string> texts = {"She don't like apples"};
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
Block block;
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
ColumnNumbers arguments = {0, 1};
std::string prompt;
Status status = function.build_prompt(block, arguments, 0, prompt);
ASSERT_TRUE(status.ok());
ASSERT_EQ(prompt, "She don't like apples");
}
TEST(AIFunctionTest, AIExtractTest) {
FunctionAIExtract function;
std::vector<std::string> resources = {"resource_name"};
std::vector<std::string> texts = {"John Smith was born on January 15, 1980"};
std::vector<std::vector<std::string>> labels = {{"name", "birthdate"}};
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
auto nested_column = ColumnString::create();
auto offsets_column = ColumnOffset64::create();
IColumn::Offset offset = 0;
for (const auto& row : labels) {
for (const auto& value : row) {
nested_column->insert_data(value.data(), value.size());
}
offset += row.size();
offsets_column->insert_value(offset);
}
auto array_column = ColumnArray::create(std::move(nested_column), std::move(offsets_column));
Block block;
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
block.insert({std::move(array_column),
std::make_shared<DataTypeArray>(std::make_shared<DataTypeString>()), "labels"});
ColumnNumbers arguments = {0, 1, 2};
std::string prompt;
Status status = function.build_prompt(block, arguments, 0, prompt);
ASSERT_TRUE(status.ok());
ASSERT_EQ(prompt,
"Labels: [\"name\", \"birthdate\"]\n"
"Text: John Smith was born on January 15, 1980");
}
TEST(AIFunctionTest, AIClassifyTest) {
FunctionAIClassify function;
std::vector<std::string> resources = {"resource_name"};
std::vector<std::string> texts = {"This product exceeded my expectations"};
std::vector<std::vector<std::string>> labels = {{"positive", "negative", "neutral"}};
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
auto nested_column = ColumnString::create();
auto offsets_column = ColumnOffset64::create();
IColumn::Offset offset = 0;
for (const auto& row : labels) {
for (const auto& value : row) {
nested_column->insert_data(value.data(), value.size());
}
offset += row.size();
offsets_column->insert_value(offset);
}
auto array_column = ColumnArray::create(std::move(nested_column), std::move(offsets_column));
Block block;
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
block.insert({std::move(array_column),
std::make_shared<DataTypeArray>(std::make_shared<DataTypeString>()), "labels"});
ColumnNumbers arguments = {0, 1, 2};
std::string prompt;
Status status = function.build_prompt(block, arguments, 0, prompt);
ASSERT_TRUE(status.ok());
ASSERT_EQ(prompt,
"Labels: [\"positive\", \"negative\", \"neutral\"]\n"
"Text: This product exceeded my expectations");
}
TEST(AIFunctionTest, AITranslateTest) {
FunctionAITranslate function;
std::vector<std::string> resources = {"resource_name"};
std::vector<std::string> texts = {"Hello world"};
std::vector<std::string> target_langs = {"Spanish"};
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
auto col_target_lang = ColumnHelper::create_column<DataTypeString>(target_langs);
Block block;
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
block.insert({std::move(col_target_lang), std::make_shared<DataTypeString>(), "target_lang"});
ColumnNumbers arguments = {0, 1, 2};
std::string prompt;
Status status = function.build_prompt(block, arguments, 0, prompt);
ASSERT_TRUE(status.ok());
ASSERT_EQ(prompt,
"Translate the following text to Spanish.\n"
"Text: Hello world");
}
TEST(AIFunctionTest, AISimilarityTest) {
FunctionAISimilarity function;
std::vector<std::string> resources = {"resource_name"};
std::vector<std::string> text1 = {"I like this dish"};
std::vector<std::string> text2 = {"This dish is very good"};
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
auto col_text1 = ColumnHelper::create_column<DataTypeString>(text1);
auto col_text2 = ColumnHelper::create_column<DataTypeString>(text2);
Block block;
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
block.insert({std::move(col_text1), std::make_shared<DataTypeString>(), "text1"});
block.insert({std::move(col_text2), std::make_shared<DataTypeString>(), "text2"});
ColumnNumbers arguments = {0, 1, 2};
std::string prompt;
Status status = function.build_prompt(block, arguments, 0, prompt);
ASSERT_TRUE(status.ok());
ASSERT_EQ(prompt, "Text 1: I like this dish\nText 2: This dish is very good");
}
TEST(AIFunctionTest, AISimilarityExecuteTest) {
auto runtime_state = std::make_unique<MockRuntimeState>();
auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
setenv("AI_TEST_RESULT", R"(["0.5"])", 1);
std::vector<std::string> resources = {"mock_resource"};
std::vector<std::string> text1 = {"I like this dish"};
std::vector<std::string> text2 = {"This dish is very good"};
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
auto col_text1 = ColumnHelper::create_column<DataTypeString>(text1);
auto col_text2 = ColumnHelper::create_column<DataTypeString>(text2);
Block block;
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
block.insert({std::move(col_text1), std::make_shared<DataTypeString>(), "text1"});
block.insert({std::move(col_text2), std::make_shared<DataTypeString>(), "text2"});
block.insert({nullptr, std::make_shared<DataTypeFloat32>(), "result"});
ColumnNumbers arguments = {0, 1, 2};
size_t result_idx = 3;
auto similarity_func = FunctionAISimilarity::create();
Status exec_status =
similarity_func->execute_impl(ctx.get(), block, arguments, result_idx, text1.size());
ASSERT_TRUE(exec_status.ok());
unsetenv("AI_TEST_RESULT");
}
TEST(AIFunctionTest, AISimilarityTrimWhitespace) {
auto runtime_state = std::make_unique<MockRuntimeState>();
auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
std::vector<std::pair<std::string, float>> test_cases = {
{R"(["0.5"])", 0.5f},
{R"(["1.0"])", 1.0f},
{R"(["0.0"])", 0.0f},
{" " + std::string(R"(["0.5"])"), 0.5f},
{std::string(R"(["0.5"])") + " ", 0.5f},
{" " + std::string(R"(["0.5"])") + " ", 0.5f},
{"\n" + std::string(R"(["0.8"])"), 0.8f},
{std::string(R"(["0.3"])") + "\n", 0.3f},
{"\n" + std::string(R"(["0.7"])") + "\n", 0.7f},
{"\t" + std::string(R"(["0.2"])") + "\t", 0.2f},
{" \n\t" + std::string(R"(["0.9"])") + " \n\t", 0.9f},
{" " + std::string(R"(["0.1"])") + " ", 0.1f},
{"\r\n" + std::string(R"(["0.6"])") + "\r\n", 0.6f}};
for (const auto& test_case : test_cases) {
setenv("AI_TEST_RESULT", test_case.first.c_str(), 1);
std::vector<std::string> resources = {"mock_resource"};
std::vector<std::string> text1 = {"Test text 1"};
std::vector<std::string> text2 = {"Test text 2"};
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
auto col_text1 = ColumnHelper::create_column<DataTypeString>(text1);
auto col_text2 = ColumnHelper::create_column<DataTypeString>(text2);
Block block;
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
block.insert({std::move(col_text1), std::make_shared<DataTypeString>(), "text1"});
block.insert({std::move(col_text2), std::make_shared<DataTypeString>(), "text2"});
block.insert({nullptr, std::make_shared<DataTypeFloat32>(), "result"});
ColumnNumbers arguments = {0, 1, 2};
size_t result_idx = 3;
auto similarity_func = FunctionAISimilarity::create();
Status exec_status = similarity_func->execute_impl(ctx.get(), block, arguments, result_idx,
text1.size());
ASSERT_TRUE(exec_status.ok()) << "Failed for test case: '" << test_case.first << "'";
const auto& res_col =
assert_cast<const ColumnFloat32&>(*block.get_by_position(result_idx).column);
float val = res_col.get_data()[0];
ASSERT_FLOAT_EQ(val, test_case.second)
<< "Failed for test case: '" << test_case.first
<< "', expected: " << test_case.second << ", got: " << val;
}
unsetenv("AI_TEST_RESULT");
}
TEST(AIFunctionTest, AISimilarityInvalidValue) {
auto runtime_state = std::make_unique<MockRuntimeState>();
auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
std::vector<std::string> invalid_cases = {"abc", "1.2x", "", " ", "\n\t", "1,2"};
for (const auto& invalid_value : invalid_cases) {
setenv("AI_TEST_RESULT", invalid_value.c_str(), 1);
std::vector<std::string> resources = {"mock_resource"};
std::vector<std::string> text1 = {"Test text 1"};
std::vector<std::string> text2 = {"Test text 2"};
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
auto col_text1 = ColumnHelper::create_column<DataTypeString>(text1);
auto col_text2 = ColumnHelper::create_column<DataTypeString>(text2);
Block block;
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
block.insert({std::move(col_text1), std::make_shared<DataTypeString>(), "text1"});
block.insert({std::move(col_text2), std::make_shared<DataTypeString>(), "text2"});
block.insert({nullptr, std::make_shared<DataTypeFloat32>(), "result"});
ColumnNumbers arguments = {0, 1, 2};
size_t result_idx = 3;
auto similarity_func = FunctionAISimilarity::create();
Status exec_status = similarity_func->execute_impl(ctx.get(), block, arguments, result_idx,
text1.size());
ASSERT_FALSE(exec_status.ok())
<< "Should have failed for invalid value: '" << invalid_value << "'";
ASSERT_NE(exec_status.to_string().find("Failed to parse float value"), std::string::npos);
}
unsetenv("AI_TEST_RESULT");
}
TEST(AIFunctionTest, AISimilarityBatchExecuteTest) {
auto runtime_state = std::make_unique<MockRuntimeState>();
auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
setenv("AI_TEST_RESULT", R"(["0.5","1.0","0.0"])", 1);
std::vector<std::string> resources = {"mock_resource"};
std::vector<std::string> text1 = {"a", "b", "c"};
std::vector<std::string> text2 = {"d", "e", "f"};
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
auto col_text1 = ColumnHelper::create_column<DataTypeString>(text1);
auto col_text2 = ColumnHelper::create_column<DataTypeString>(text2);
Block block;
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
block.insert({std::move(col_text1), std::make_shared<DataTypeString>(), "text1"});
block.insert({std::move(col_text2), std::make_shared<DataTypeString>(), "text2"});
block.insert({nullptr, std::make_shared<DataTypeFloat32>(), "result"});
ColumnNumbers arguments = {0, 1, 2};
size_t result_idx = 3;
auto similarity_func = FunctionAISimilarity::create();
Status exec_status =
similarity_func->execute_impl(ctx.get(), block, arguments, result_idx, text1.size());
ASSERT_TRUE(exec_status.ok());
const auto& res_col =
assert_cast<const ColumnFloat32&>(*block.get_by_position(result_idx).column);
ASSERT_EQ(res_col.size(), 3);
EXPECT_FLOAT_EQ(res_col.get_data()[0], 0.5f);
EXPECT_FLOAT_EQ(res_col.get_data()[1], 1.0f);
EXPECT_FLOAT_EQ(res_col.get_data()[2], 0.0f);
unsetenv("AI_TEST_RESULT");
}
TEST(AIFunctionTest, AIFilterTest) {
FunctionAIFilter function;
std::vector<std::string> resources = {"resource_name"};
std::vector<std::string> texts = {"This is a valid sentence."};
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
Block block;
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
ColumnNumbers arguments = {0, 1};
std::string prompt;
Status status = function.build_prompt(block, arguments, 0, prompt);
ASSERT_TRUE(status.ok());
ASSERT_EQ(prompt, "This is a valid sentence.");
}
TEST(AIFunctionTest, AIFilterExecuteTest) {
auto runtime_state = std::make_unique<MockRuntimeState>();
auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
setenv("AI_TEST_RESULT", R"(["0"])", 1);
std::vector<std::string> resources = {"mock_resource"};
std::vector<std::string> texts = {"This is a valid sentence."};
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
Block block;
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
block.insert({nullptr, std::make_shared<DataTypeBool>(), "result"});
ColumnNumbers arguments = {0, 1};
size_t result_idx = 2;
auto filter_func = FunctionAIFilter::create();
Status exec_status =
filter_func->execute_impl(ctx.get(), block, arguments, result_idx, texts.size());
ASSERT_TRUE(exec_status.ok());
const auto& res_col =
assert_cast<const ColumnUInt8&>(*block.get_by_position(result_idx).column);
UInt8 val = res_col.get_data()[0];
ASSERT_TRUE(val == 0);
unsetenv("AI_TEST_RESULT");
}
TEST(AIFunctionTest, AIFilterExecuteMultipleRows) {
auto runtime_state = std::make_unique<MockRuntimeState>();
auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
setenv("AI_TEST_RESULT", R"(["1","1"])", 1);
std::vector<std::string> resources = {"mock_resource", "mock_resource"};
std::vector<std::string> texts = {"This is valid.", "This is also valid."};
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
Block block;
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
block.insert({nullptr, std::make_shared<DataTypeBool>(), "result"});
ColumnNumbers arguments = {0, 1};
size_t result_idx = 2;
auto filter_func = FunctionAIFilter::create();
Status exec_status =
filter_func->execute_impl(ctx.get(), block, arguments, result_idx, texts.size());
unsetenv("AI_TEST_RESULT");
ASSERT_TRUE(exec_status.ok()) << exec_status.to_string();
const auto& res_col =
assert_cast<const ColumnUInt8&>(*block.get_by_position(result_idx).column);
ASSERT_EQ(res_col.size(), 2);
ASSERT_EQ(res_col.get_data()[0], 1);
ASSERT_EQ(res_col.get_data()[1], 1);
}
TEST(AIFunctionTest, AIFilterTrimWhitespace) {
auto runtime_state = std::make_unique<MockRuntimeState>();
auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
std::vector<std::pair<std::string, UInt8>> test_cases = {
{R"(["0"])", 0},
{R"(["1"])", 1},
{" " + std::string(R"(["0"])"), 0},
{std::string(R"(["0"])") + " ", 0},
{" " + std::string(R"(["0"])") + " ", 0},
{"\n" + std::string(R"(["0"])"), 0},
{std::string(R"(["0"])") + "\n", 0},
{"\n" + std::string(R"(["0"])") + "\n", 0},
{"\t" + std::string(R"(["1"])") + "\t", 1},
{" \n\t" + std::string(R"(["1"])") + " \n\t", 1},
{" " + std::string(R"(["1"])") + " ", 1},
{"\r\n" + std::string(R"(["0"])") + "\r\n", 0}};
for (const auto& test_case : test_cases) {
setenv("AI_TEST_RESULT", test_case.first.c_str(), 1);
std::vector<std::string> resources = {"mock_resource"};
std::vector<std::string> texts = {"Test input"};
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
Block block;
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
block.insert({nullptr, std::make_shared<DataTypeBool>(), "result"});
ColumnNumbers arguments = {0, 1};
size_t result_idx = 2;
auto filter_func = FunctionAIFilter::create();
Status exec_status =
filter_func->execute_impl(ctx.get(), block, arguments, result_idx, texts.size());
ASSERT_TRUE(exec_status.ok()) << "Failed for test case: '" << test_case.first << "'";
const auto& res_col =
assert_cast<const ColumnUInt8&>(*block.get_by_position(result_idx).column);
UInt8 val = res_col.get_data()[0];
ASSERT_EQ(val, test_case.second)
<< "Failed for test case: '" << test_case.first
<< "', expected: " << (int)test_case.second << ", got: " << (int)val;
}
unsetenv("AI_TEST_RESULT");
}
TEST(AIFunctionTest, AIFilterInvalidValue) {
auto runtime_state = std::make_unique<MockRuntimeState>();
auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
std::vector<std::string> invalid_cases = {
R"(["2"])", R"(["maybe"])", R"(["ok"])", R"([""])", R"(["01"])",
R"(["0.5"])", R"(["sure"])", R"(["truee"])", R"(["falsee"])", R"(["yess"])",
R"(["noo"])", R"(["true"])", R"(["false"])", R"(["yes"])", R"(["no"])",
R"(["TRUE"])", R"(["FALSE"])", R"(["YES"])", "[\"NO\"]"};
for (const auto& invalid_value : invalid_cases) {
setenv("AI_TEST_RESULT", invalid_value.c_str(), 1);
std::vector<std::string> resources = {"mock_resource"};
std::vector<std::string> texts = {"Test input"};
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
Block block;
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
block.insert({nullptr, std::make_shared<DataTypeBool>(), "result"});
ColumnNumbers arguments = {0, 1};
size_t result_idx = 2;
auto filter_func = FunctionAIFilter::create();
Status exec_status =
filter_func->execute_impl(ctx.get(), block, arguments, result_idx, texts.size());
ASSERT_FALSE(exec_status.ok())
<< "Should have failed for invalid value: '" << invalid_value << "'";
ASSERT_TRUE(exec_status.to_string().find("Failed to parse boolean value") !=
std::string::npos)
<< "Error message should mention boolean parsing for value: '" << invalid_value
<< "'";
}
unsetenv("AI_TEST_RESULT");
}
TEST(AIFunctionTest, AIFilterBatchExecuteTest) {
auto runtime_state = std::make_unique<MockRuntimeState>();
auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
setenv("AI_TEST_RESULT", R"(["1","0","1"])", 1);
std::vector<std::string> resources = {"mock_resource"};
std::vector<std::string> texts = {"valid text", "invalid text", "valid again"};
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
Block block;
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
block.insert({nullptr, std::make_shared<DataTypeBool>(), "result"});
ColumnNumbers arguments = {0, 1};
size_t result_idx = 2;
auto filter_func = FunctionAIFilter::create();
Status exec_status =
filter_func->execute_impl(ctx.get(), block, arguments, result_idx, texts.size());
ASSERT_TRUE(exec_status.ok());
const auto& res_col =
assert_cast<const ColumnUInt8&>(*block.get_by_position(result_idx).column);
ASSERT_EQ(res_col.size(), 3);
EXPECT_EQ(res_col.get_data()[0], 1);
EXPECT_EQ(res_col.get_data()[1], 0);
EXPECT_EQ(res_col.get_data()[2], 1);
unsetenv("AI_TEST_RESULT");
}
TEST(AIFunctionTest, AIFilterBatchLengthMismatch) {
auto runtime_state = std::make_unique<MockRuntimeState>();
auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
setenv("AI_TEST_RESULT", R"(["1","0"])", 1);
std::vector<std::string> resources = {"mock_resource"};
std::vector<std::string> texts = {"row1", "row2", "row3"};
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
Block block;
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
block.insert({nullptr, std::make_shared<DataTypeBool>(), "result"});
ColumnNumbers arguments = {0, 1};
size_t result_idx = 2;
auto filter_func = FunctionAIFilter::create();
Status exec_status =
filter_func->execute_impl(ctx.get(), block, arguments, result_idx, texts.size());
ASSERT_FALSE(exec_status.ok());
ASSERT_TRUE(exec_status.to_string().find("expected 3 items but got 2") != std::string::npos);
unsetenv("AI_TEST_RESULT");
}
TEST(AIFunctionTest, AIFilterBatchInvalidJson) {
auto runtime_state = std::make_unique<MockRuntimeState>();
auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
std::vector<std::string> invalid_cases = {"1,0", "{}", "", " "};
for (const auto& invalid_value : invalid_cases) {
setenv("AI_TEST_RESULT", invalid_value.c_str(), 1);
std::vector<std::string> resources = {"mock_resource"};
std::vector<std::string> texts = {"row1", "row2"};
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
Block block;
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
block.insert({nullptr, std::make_shared<DataTypeBool>(), "result"});
ColumnNumbers arguments = {0, 1};
size_t result_idx = 2;
auto filter_func = FunctionAIFilter::create();
Status exec_status =
filter_func->execute_impl(ctx.get(), block, arguments, result_idx, texts.size());
ASSERT_FALSE(exec_status.ok())
<< "Should have failed for invalid batch json: '" << invalid_value << "'";
ASSERT_TRUE(
exec_status.to_string().find("Invalid batch result format") != std::string::npos ||
exec_status.to_string().find("expected 2 items but got 1") != std::string::npos);
}
unsetenv("AI_TEST_RESULT");
}
TEST(AIFunctionTest, AIFilterBatchInvalidElement) {
auto runtime_state = std::make_unique<MockRuntimeState>();
auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
std::vector<std::string> invalid_cases = {R"(["1","2"])", R"(["1",0])", R"(["yes","no"])"};
for (const auto& invalid_value : invalid_cases) {
setenv("AI_TEST_RESULT", invalid_value.c_str(), 1);
std::vector<std::string> resources = {"mock_resource"};
std::vector<std::string> texts = {"row1", "row2"};
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
Block block;
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
block.insert({nullptr, std::make_shared<DataTypeBool>(), "result"});
ColumnNumbers arguments = {0, 1};
size_t result_idx = 2;
auto filter_func = FunctionAIFilter::create();
Status exec_status =
filter_func->execute_impl(ctx.get(), block, arguments, result_idx, texts.size());
ASSERT_FALSE(exec_status.ok())
<< "Should have failed for invalid batch element: '" << invalid_value << "'";
ASSERT_TRUE(exec_status.to_string().find("Failed to parse boolean value") !=
std::string::npos ||
exec_status.to_string().find("Invalid batch result format") !=
std::string::npos);
}
unsetenv("AI_TEST_RESULT");
}
TEST(AIFunctionTest, AIFilterBatchSplitByWindow) {
TQueryOptions query_options = create_fake_query_options();
query_options.__set_ai_context_window_size(128 * 1024);
auto query_ctx = MockQueryContext::create(TUniqueId(), ExecEnv::GetInstance(), query_options);
query_ctx->set_mock_ai_resource();
TQueryGlobals query_globals;
auto runtime_state = std::make_unique<MockRuntimeState>(
TUniqueId(), 0, query_options, query_globals, nullptr, query_ctx.get());
auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
setenv("AI_TEST_RESULT", R"(["1"])", 1);
std::vector<std::string> resources = {"mock_resource"};
std::vector<std::string> texts = {std::string(70 * 1024, 'a'), std::string(70 * 1024, 'b'),
std::string(70 * 1024, 'c')};
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
Block block;
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
block.insert({nullptr, std::make_shared<DataTypeBool>(), "result"});
ColumnNumbers arguments = {0, 1};
size_t result_idx = 2;
auto filter_func = FunctionAIFilter::create();
Status exec_status =
filter_func->execute_impl(ctx.get(), block, arguments, result_idx, texts.size());
ASSERT_TRUE(exec_status.ok());
const auto& res_col =
assert_cast<const ColumnUInt8&>(*block.get_by_position(result_idx).column);
ASSERT_EQ(res_col.size(), 3);
EXPECT_EQ(res_col.get_data()[0], 1);
EXPECT_EQ(res_col.get_data()[1], 1);
EXPECT_EQ(res_col.get_data()[2], 1);
unsetenv("AI_TEST_RESULT");
}
TEST(AIFunctionTest, AIFilterSingleRowExceedsBatchWindow) {
TQueryOptions query_options = create_fake_query_options();
query_options.__set_ai_context_window_size(128 * 1024);
auto query_ctx = MockQueryContext::create(TUniqueId(), ExecEnv::GetInstance(), query_options);
query_ctx->set_mock_ai_resource();
TQueryGlobals query_globals;
auto runtime_state = std::make_unique<MockRuntimeState>(
TUniqueId(), 0, query_options, query_globals, nullptr, query_ctx.get());
auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
std::vector<std::string> resources = {"mock_resource"};
std::vector<std::string> texts = {std::string(130 * 1024, 'x')};
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
Block block;
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
block.insert({nullptr, std::make_shared<DataTypeBool>(), "result"});
ColumnNumbers arguments = {0, 1};
size_t result_idx = 2;
auto filter_func = FunctionAIFilter::create();
// Even if a single row exceeds the batch window, it should be sent as a standalone request.
setenv("AI_TEST_RESULT", "[\"1\"]", 1);
Status exec_status =
filter_func->execute_impl(ctx.get(), block, arguments, result_idx, texts.size());
ASSERT_TRUE(exec_status.ok());
const auto& res_col =
assert_cast<const ColumnUInt8&>(*block.get_by_position(result_idx).column);
ASSERT_EQ(res_col.size(), 1);
EXPECT_EQ(res_col.get_data()[0], 1);
unsetenv("AI_TEST_RESULT");
}
TEST(AIFunctionTest, AIFilterOversizedRowFlushesHistoryBatchFirst) {
TQueryOptions query_options = create_fake_query_options();
query_options.__set_ai_context_window_size(128 * 1024);
auto query_ctx = MockQueryContext::create(TUniqueId(), ExecEnv::GetInstance(), query_options);
query_ctx->set_mock_ai_resource();
TQueryGlobals query_globals;
auto runtime_state = std::make_unique<MockRuntimeState>(
TUniqueId(), 0, query_options, query_globals, nullptr, query_ctx.get());
auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
std::vector<std::string> resources = {"mock_resource", "mock_resource"};
std::vector<std::string> texts = {"small row", std::string(130 * 1024, 'x')};
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
Block block;
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
block.insert({nullptr, std::make_shared<DataTypeBool>(), "result"});
ColumnNumbers arguments = {0, 1};
size_t result_idx = 2;
auto filter_func = FunctionAIFilter::create();
setenv("AI_TEST_RESULT", R"(["1"])", 1);
Status exec_status =
filter_func->execute_impl(ctx.get(), block, arguments, result_idx, texts.size());
ASSERT_TRUE(exec_status.ok()) << exec_status.to_string();
const auto& res_col =
assert_cast<const ColumnUInt8&>(*block.get_by_position(result_idx).column);
ASSERT_EQ(res_col.size(), 2);
EXPECT_EQ(res_col.get_data()[0], 1);
EXPECT_EQ(res_col.get_data()[1], 1);
unsetenv("AI_TEST_RESULT");
}
TEST(AIFunctionTest, AIFilterUsesAiContextWindowSizeSessionVariable) {
TQueryOptions query_options = create_fake_query_options();
query_options.__set_ai_context_window_size(16);
auto query_ctx = MockQueryContext::create(TUniqueId(), ExecEnv::GetInstance(), query_options);
query_ctx->set_mock_ai_resource();
TQueryGlobals query_globals;
auto runtime_state = std::make_unique<MockRuntimeState>(
TUniqueId(), 0, query_options, query_globals, nullptr, query_ctx.get());
auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
setenv("AI_TEST_RESULT", R"(["1"])", 1);
std::vector<std::string> resources = {"mock_resource"};
std::vector<std::string> texts = {"12345678901234567890", "abcdefghijabcdefghij"};
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
Block block;
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
block.insert({nullptr, std::make_shared<DataTypeBool>(), "result"});
ColumnNumbers arguments = {0, 1};
size_t result_idx = 2;
auto filter_func = FunctionAIFilter::create();
Status exec_status =
filter_func->execute_impl(ctx.get(), block, arguments, result_idx, texts.size());
ASSERT_TRUE(exec_status.ok()) << exec_status.to_string();
const auto& res_col =
assert_cast<const ColumnUInt8&>(*block.get_by_position(result_idx).column);
ASSERT_EQ(res_col.size(), 2);
EXPECT_EQ(res_col.get_data()[0], 1);
EXPECT_EQ(res_col.get_data()[1], 1);
unsetenv("AI_TEST_RESULT");
}
TEST(AIFunctionTest, ResourceNotFound) {
auto runtime_state = std::make_unique<MockRuntimeState>();
auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
std::vector<std::string> resources = {"not_exist_resource"};
std::vector<std::string> texts = {"test"};
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
Block block;
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
block.insert({nullptr, std::make_shared<DataTypeString>(), "result"});
ColumnNumbers arguments = {0, 1};
size_t result_idx = 2;
ASSERT_DEATH(
{
auto sentiment_func = FunctionAISentiment::create();
Status exec_status = sentiment_func->execute_impl(ctx.get(), block, arguments,
result_idx, texts.size());
static_cast<void>(exec_status);
},
"it != ai_resources->end");
}
TEST(AIFunctionTest, MockResourceSendRequest) {
auto runtime_state = std::make_unique<MockRuntimeState>();
auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
std::vector<std::string> resources = {"mock_resource"};
std::vector<std::string> texts = {"test input"};
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
Block block;
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
block.insert({nullptr, std::make_shared<DataTypeString>(), "result"});
ColumnNumbers arguments = {0, 1};
size_t result_idx = 2;
auto sentiment_func = FunctionAISentiment::create();
Status exec_status =
sentiment_func->execute_impl(ctx.get(), block, arguments, result_idx, texts.size());
ASSERT_TRUE(exec_status.ok()) << exec_status.to_string();
const auto& res_col =
assert_cast<const ColumnString&>(*block.get_by_position(result_idx).column);
StringRef ref = res_col.get_data_at(0);
std::string val(ref.data, ref.size);
ASSERT_EQ(val, "this is a mock response. test input");
}
TEST(AIFunctionTest, AIStringFunctionBatchExecuteTest) {
auto runtime_state = std::make_unique<MockRuntimeState>();
auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
setenv("AI_TEST_RESULT", R"(["positive","negative","neutral"])", 1);
std::vector<std::string> resources = {"mock_resource"};
std::vector<std::string> texts = {"great", "bad", "okay"};
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
Block block;
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
block.insert({nullptr, std::make_shared<DataTypeString>(), "result"});
ColumnNumbers arguments = {0, 1};
size_t result_idx = 2;
auto sentiment_func = FunctionAISentiment::create();
Status exec_status =
sentiment_func->execute_impl(ctx.get(), block, arguments, result_idx, texts.size());
ASSERT_TRUE(exec_status.ok());
const auto& res_col =
assert_cast<const ColumnString&>(*block.get_by_position(result_idx).column);
ASSERT_EQ(res_col.size(), 3);
EXPECT_EQ(res_col.get_data_at(0).to_string(), "positive");
EXPECT_EQ(res_col.get_data_at(1).to_string(), "negative");
EXPECT_EQ(res_col.get_data_at(2).to_string(), "neutral");
unsetenv("AI_TEST_RESULT");
}
TEST(AIFunctionTest, MissingAIResourcesMetadataTest) {
auto query_ctx = MockQueryContext::create();
TQueryOptions query_options;
TQueryGlobals query_globals;
RuntimeState runtime_state(TUniqueId(), 0, query_options, query_globals, nullptr,
query_ctx.get());
auto ctx = FunctionContext::create_context(&runtime_state, {}, {});
std::vector<std::string> resources = {"resource_name"};
std::vector<std::string> texts = {"test"};
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
Block block;
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
block.insert({nullptr, std::make_shared<DataTypeString>(), "result"});
ColumnNumbers arguments = {0, 1};
size_t result_idx = 2;
ASSERT_DEATH(
{
auto sentiment_func = FunctionAISentiment::create();
Status exec_status = sentiment_func->execute_impl(ctx.get(), block, arguments,
result_idx, texts.size());
static_cast<void>(exec_status);
},
"ai_resources");
}
TEST(AIFunctionTest, ReturnTypeTest) {
FunctionAIClassify func_classify;
DataTypes args;
DataTypePtr ret_type = func_classify.get_return_type_impl(args);
ASSERT_TRUE(ret_type != nullptr);
ASSERT_EQ(ret_type->get_family_name(), "String");
FunctionAIExtract func_extract;
ret_type = func_extract.get_return_type_impl(args);
ASSERT_TRUE(ret_type != nullptr);
ASSERT_EQ(ret_type->get_family_name(), "String");
FunctionAIFilter func_filter;
ret_type = func_filter.get_return_type_impl(args);
ASSERT_TRUE(ret_type != nullptr);
ASSERT_EQ(ret_type->get_family_name(), "BOOL");
FunctionAIFixGrammar func_fix_grammar;
ret_type = func_fix_grammar.get_return_type_impl(args);
ASSERT_TRUE(ret_type != nullptr);
ASSERT_EQ(ret_type->get_family_name(), "String");
FunctionAIGenerate func_generate;
ret_type = func_generate.get_return_type_impl(args);
ASSERT_TRUE(ret_type != nullptr);
ASSERT_EQ(ret_type->get_family_name(), "String");
FunctionAIMask func_mask;
ret_type = func_mask.get_return_type_impl(args);
ASSERT_TRUE(ret_type != nullptr);
ASSERT_EQ(ret_type->get_family_name(), "String");
FunctionAISentiment func_sentiment;
ret_type = func_sentiment.get_return_type_impl(args);
ASSERT_TRUE(ret_type != nullptr);
ASSERT_EQ(ret_type->get_family_name(), "String");
FunctionAISimilarity func_similarity;
ret_type = func_similarity.get_return_type_impl(args);
ASSERT_TRUE(ret_type != nullptr);
ASSERT_EQ(ret_type->get_family_name(), "FLOAT");
FunctionAISummarize func_summarize;
ret_type = func_summarize.get_return_type_impl(args);
ASSERT_TRUE(ret_type != nullptr);
ASSERT_EQ(ret_type->get_family_name(), "String");
FunctionAITranslate func_translate;
ret_type = func_translate.get_return_type_impl(args);
ASSERT_TRUE(ret_type != nullptr);
ASSERT_EQ(ret_type->get_family_name(), "String");
FunctionEmbed func_embed;
ret_type = func_embed.get_return_type_impl(args);
ASSERT_TRUE(ret_type != nullptr);
ASSERT_EQ(ret_type->get_family_name(), "Array");
}
class FunctionAISentimentTestHelper : public FunctionAISentiment {
public:
using FunctionAISentiment::normalize_endpoint;
};
class FunctionEmbedTestHelper : public FunctionEmbed {
public:
using FunctionEmbed::normalize_endpoint;
};
TEST(AIFunctionTest, NormalizeLegacyCompletionsEndpoint) {
TAIResource resource;
resource.endpoint = "https://api.openai.com/v1/completions";
FunctionAISentimentTestHelper::normalize_endpoint(resource);
ASSERT_EQ(resource.endpoint, "https://api.openai.com/v1/chat/completions");
}
TEST(AIFunctionTest, NormalizeEndpointNoopForOtherPaths) {
TAIResource resource;
resource.endpoint = "https://api.openai.com/v1/chat/completions";
FunctionAISentimentTestHelper::normalize_endpoint(resource);
ASSERT_EQ(resource.endpoint, "https://api.openai.com/v1/chat/completions");
resource.endpoint = "https://localhost/v1/responses";
FunctionAISentimentTestHelper::normalize_endpoint(resource);
ASSERT_EQ(resource.endpoint, "https://localhost/v1/responses");
}
TEST(AIFunctionTest, NormalizeGeminiGenerateEndpointFromBaseVersion) {
TAIResource resource;
resource.provider_type = "gemini";
resource.model_name = "gemini-pro";
resource.endpoint = "https://generativelanguage.googleapis.com/v1beta";
FunctionAISentimentTestHelper::normalize_endpoint(resource);
ASSERT_EQ(resource.endpoint,
"https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent");
}
TEST(AIFunctionTest, NormalizeGeminiEmbedEndpointFromBaseVersion) {
TAIResource resource;
resource.provider_type = "GEMINI";
resource.model_name = "gemini-embedding-2-preview";
resource.endpoint = "https://generativelanguage.googleapis.com/v1beta";
FunctionEmbedTestHelper::normalize_endpoint(resource);
ASSERT_EQ(resource.endpoint,
"https://generativelanguage.googleapis.com/v1beta/models/"
"gemini-embedding-2-preview:batchEmbedContents");
}
TEST(AIFunctionTest, NormalizeGeminiEndpointNoopForNonBasePath) {
TAIResource resource;
resource.provider_type = "gemini";
resource.model_name = "gemini-pro";
resource.endpoint =
"https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent";
FunctionAISentimentTestHelper::normalize_endpoint(resource);
ASSERT_EQ(resource.endpoint,
"https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent");
}
TEST(AIFunctionTest, NormalizeGeminiEmbedLegacySingleEndpointToBatchEndpoint) {
TAIResource resource;
resource.provider_type = "gemini";
resource.model_name = "gemini-embedding-2-preview";
resource.endpoint =
"https://generativelanguage.googleapis.com/v1beta/models/"
"gemini-embedding-2-preview:embedContent";
FunctionEmbedTestHelper::normalize_endpoint(resource);
ASSERT_EQ(resource.endpoint,
"https://generativelanguage.googleapis.com/v1beta/models/"
"gemini-embedding-2-preview:batchEmbedContents");
}
TEST(AIFunctionTest, ExecuteBatchRequestSuccess) {
TQueryOptions query_options = create_fake_query_options();
query_options.__set_query_timeout(5);
auto query_ctx = MockQueryContext::create(TUniqueId(), ExecEnv::GetInstance(), query_options);
TQueryGlobals query_globals;
RuntimeState runtime_state(TUniqueId(), 0, query_options, query_globals, nullptr,
query_ctx.get());
auto ctx = FunctionContext::create_context(&runtime_state, {}, {});
OneShotHttpServer server(200, R"({"choices":[{"message":{"content":"[\"1\",\"0\"]"}}]})");
TAIResource config;
config.endpoint = server.endpoint();
config.provider_type = "OPENAI";
config.model_name = "test-model";
config.api_key = "secret";
config.max_retries = 1;
std::shared_ptr<AIAdapter> adapter = std::make_shared<OpenAIAdapter>();
adapter->init(config);
FunctionAIFilterBatchTestHelper helper;
std::vector<std::string> results;
Status st = helper.execute_batch_request({"first row", "second row"}, results, config, adapter,
ctx.get());
ASSERT_TRUE(st.ok()) << st.to_string();
ASSERT_EQ(results.size(), 2);
EXPECT_EQ(results[0], "1");
EXPECT_EQ(results[1], "0");
std::string request = server.join_and_get_request();
ASSERT_NE(request.find("Authorization: Bearer secret"), std::string::npos);
ASSERT_NE(
request.find(
R"([{\"idx\":0,\"input\":\"first row\"},{\"idx\":1,\"input\":\"second row\"}])"),
std::string::npos);
}
TEST(AIFunctionTest, ExecuteBatchRequestResultSizeMismatch) {
TQueryOptions query_options = create_fake_query_options();
query_options.__set_query_timeout(5);
auto query_ctx = MockQueryContext::create(TUniqueId(), ExecEnv::GetInstance(), query_options);
TQueryGlobals query_globals;
RuntimeState runtime_state(TUniqueId(), 0, query_options, query_globals, nullptr,
query_ctx.get());
auto ctx = FunctionContext::create_context(&runtime_state, {}, {});
OneShotHttpServer server(200, R"({"choices":[{"message":{"content":"[\"1\"]"}}]})");
TAIResource config;
config.endpoint = server.endpoint();
config.provider_type = "OPENAI";
config.model_name = "test-model";
config.api_key = "secret";
config.max_retries = 1;
std::shared_ptr<AIAdapter> adapter = std::make_shared<OpenAIAdapter>();
adapter->init(config);
FunctionAIFilterBatchTestHelper helper;
std::vector<std::string> results;
Status st = helper.execute_batch_request({"first row", "second row"}, results, config, adapter,
ctx.get());
ASSERT_FALSE(st.ok());
ASSERT_NE(st.to_string().find(
"Failed to parse ai_filter batch result, expected 2 items but got 1"),
std::string::npos);
}
TEST(AIFunctionTest, DoSendRequestTransportError) {
TQueryOptions query_options = create_fake_query_options();
query_options.__set_query_timeout(5);
auto query_ctx = MockQueryContext::create(TUniqueId(), ExecEnv::GetInstance(), query_options);
TQueryGlobals query_globals;
RuntimeState runtime_state(TUniqueId(), 0, query_options, query_globals, nullptr,
query_ctx.get());
auto ctx = FunctionContext::create_context(&runtime_state, {}, {});
TAIResource config;
config.endpoint = "http://127.0.0.1:1";
config.provider_type = "OPENAI";
config.model_name = "test-model";
config.api_key = "secret";
std::shared_ptr<AIAdapter> adapter = std::make_shared<OpenAIAdapter>();
adapter->init(config);
HttpClient client;
std::string response;
FunctionAITransportTestHelper helper;
Status st = helper.do_send_request(&client, "{}", response, config, adapter, ctx.get());
ASSERT_FALSE(st.ok());
ASSERT_EQ(response, "");
}
TEST(AIFunctionTest, DoSendRequestNon200) {
TQueryOptions query_options = create_fake_query_options();
query_options.__set_query_timeout(5);
auto query_ctx = MockQueryContext::create(TUniqueId(), ExecEnv::GetInstance(), query_options);
TQueryGlobals query_globals;
RuntimeState runtime_state(TUniqueId(), 0, query_options, query_globals, nullptr,
query_ctx.get());
auto ctx = FunctionContext::create_context(&runtime_state, {}, {});
OneShotHttpServer server(500, R"({"error":"bad request"})");
TAIResource config;
config.endpoint = server.endpoint();
config.provider_type = "OPENAI";
config.model_name = "test-model";
config.api_key = "secret";
std::shared_ptr<AIAdapter> adapter = std::make_shared<OpenAIAdapter>();
adapter->init(config);
HttpClient client;
std::string response;
FunctionAITransportTestHelper helper;
Status st = helper.do_send_request(&client, R"({"message":"hello"})", response, config, adapter,
ctx.get());
ASSERT_FALSE(st.ok());
ASSERT_NE(st.to_string().find("http status code is not 200"), std::string::npos);
ASSERT_EQ(response, R"({"error":"bad request"})");
std::string request = server.join_and_get_request();
ASSERT_NE(request.find("Authorization: Bearer secret"), std::string::npos);
ASSERT_NE(request.find("Content-Type: application/json"), std::string::npos);
}
TEST(AIFunctionTest, DoSendRequestSuccess) {
TQueryOptions query_options = create_fake_query_options();
query_options.__set_query_timeout(5);
auto query_ctx = MockQueryContext::create(TUniqueId(), ExecEnv::GetInstance(), query_options);
TQueryGlobals query_globals;
RuntimeState runtime_state(TUniqueId(), 0, query_options, query_globals, nullptr,
query_ctx.get());
auto ctx = FunctionContext::create_context(&runtime_state, {}, {});
OneShotHttpServer server(200, R"({"ok":true})");
TAIResource config;
config.endpoint = server.endpoint();
config.provider_type = "OPENAI";
config.model_name = "test-model";
config.api_key = "secret";
std::shared_ptr<AIAdapter> adapter = std::make_shared<OpenAIAdapter>();
adapter->init(config);
HttpClient client;
std::string response;
FunctionAITransportTestHelper helper;
Status st = helper.do_send_request(&client, R"({"message":"hello"})", response, config, adapter,
ctx.get());
ASSERT_TRUE(st.ok()) << st.to_string();
ASSERT_EQ(response, R"({"ok":true})");
std::string request = server.join_and_get_request();
ASSERT_NE(request.find("POST "), std::string::npos);
ASSERT_NE(request.find(R"({"message":"hello"})"), std::string::npos);
}
} // namespace doris