blob: 2cfda9d715436c67eac4ee30c85aa5fd35e967bf [file]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "exprs/function/ai/ai_adapter.h"
#include <curl/curl.h>
#include <gen_cpp/PaloInternalService_types.h>
#include <gmock/gmock-matchers.h>
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include "common/status.h"
#include "exprs/function/ai/ai_classify.h"
#include "exprs/function/ai/ai_extract.h"
#include "exprs/function/ai/ai_sentiment.h"
#include "exprs/function/ai/ai_summarize.h"
namespace doris {
// Test double that exposes the curl header list accumulated by HttpClient so
// the headers written by an adapter's set_authentication() can be inspected.
class MockHttpClient : public HttpClient {
public:
    // Returns the head of the inherited curl header list (presumably nullptr
    // until the first header is added — callers should null-check before use).
    curl_slist* get() const { return this->_header_list; }
};
TEST(AI_ADAPTER_TEST, local_adapter_request_chat_endpoint) {
    // Ollama-style /api/chat endpoint: sampling options are nested under
    // "options" and the prompt is sent as a system + user "messages" array.
    LocalAdapter adapter;
    TAIResource config;
    config.model_name = "ollama";
    config.temperature = 0.7;
    config.max_tokens = 128;
    config.endpoint = "http://localhost:11434/api/chat";
    adapter.init(config);
    // header test: the first header must be the JSON content type. Guard the
    // dereference so a missing header fails the test instead of crashing it.
    MockHttpClient mock_client;
    Status auth_status = adapter.set_authentication(&mock_client);
    ASSERT_TRUE(auth_status.ok());
    ASSERT_NE(mock_client.get(), nullptr) << "No header was set";
    EXPECT_STREQ(mock_client.get()->data, "Content-Type: application/json");
    std::vector<std::string> inputs = {"hello world"};
    std::string request_body;
    Status st =
            adapter.build_request_payload(inputs, FunctionAISummarize::system_prompt, request_body);
    ASSERT_TRUE(st.ok());
    // body test
    rapidjson::Document doc;
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
    // general flags
    ASSERT_TRUE(doc.HasMember("model")) << "Missing model field";
    ASSERT_TRUE(doc["model"].IsString()) << "Model field is not a string";
    ASSERT_STREQ(doc["model"].GetString(), "ollama");
    ASSERT_TRUE(doc.HasMember("stream")) << "Missing stream field";
    ASSERT_TRUE(doc["stream"].IsBool()) << "Stream field is not a bool";
    ASSERT_FALSE(doc["stream"].GetBool());
    ASSERT_TRUE(doc.HasMember("think")) << "Missing think field";
    ASSERT_TRUE(doc["think"].IsBool()) << "Think field is not a bool";
    ASSERT_FALSE(doc["think"].GetBool());
    // options (temperature + max_token) must be nested, not top-level
    ASSERT_FALSE(doc.HasMember("temperature")) << "Temperature should be nested in options";
    ASSERT_FALSE(doc.HasMember("max_tokens")) << "Max tokens should be nested in options";
    ASSERT_TRUE(doc.HasMember("options")) << "Missing options field";
    ASSERT_TRUE(doc["options"].IsObject()) << "Options is not an object";
    const auto& options = doc["options"];
    ASSERT_TRUE(options.HasMember("temperature")) << "Missing options.temperature field";
    ASSERT_TRUE(options["temperature"].IsNumber()) << "options.temperature is not a number";
    ASSERT_DOUBLE_EQ(options["temperature"].GetDouble(), 0.7);
    ASSERT_TRUE(options.HasMember("max_token")) << "Missing options.max_token field";
    ASSERT_TRUE(options["max_token"].IsInt()) << "options.max_token is not an integer";
    ASSERT_EQ(options["max_token"].GetInt(), 128);
    // content: first message carries the system prompt, last one the user input
    ASSERT_TRUE(doc.HasMember("messages")) << "Missing messages field";
    ASSERT_TRUE(doc["messages"].IsArray()) << "Messages is not an array";
    ASSERT_GE(doc["messages"].Size(), 2) << "Messages should contain system and user prompts";
    const auto& first_message = doc["messages"][0];
    ASSERT_TRUE(first_message.HasMember("role")) << "Message missing role field";
    ASSERT_TRUE(first_message["role"].IsString()) << "Role field is not a string";
    ASSERT_STREQ(first_message["role"].GetString(), "system");
    ASSERT_TRUE(first_message.HasMember("content")) << "Message missing content field";
    ASSERT_TRUE(first_message["content"].IsString()) << "Content field is not a string";
    ASSERT_STREQ(first_message["content"].GetString(), FunctionAISummarize::system_prompt);
    const auto& user_message = doc["messages"][doc["messages"].Size() - 1];
    ASSERT_TRUE(user_message.HasMember("role")) << "User message missing role field";
    ASSERT_TRUE(user_message["role"].IsString()) << "User role field is not a string";
    ASSERT_STREQ(user_message["role"].GetString(), "user");
    ASSERT_TRUE(user_message.HasMember("content")) << "User message missing content field";
    ASSERT_TRUE(user_message["content"].IsString()) << "User content field is not a string";
    ASSERT_STREQ(user_message["content"].GetString(), inputs[0].c_str());
}
TEST(AI_ADAPTER_TEST, local_adapter_request_generate_endpoint) {
    // Ollama-style /api/generate endpoint: flat "system"/"prompt" strings
    // instead of a "messages" array; sampling options stay nested in "options".
    LocalAdapter adapter;
    TAIResource config;
    config.model_name = "ollama";
    config.temperature = 0.8;
    config.max_tokens = 64;
    config.endpoint = "http://localhost:11434/api/generate";
    adapter.init(config);
    std::vector<std::string> inputs = {"hello world"};
    std::string request_body;
    Status st =
            adapter.build_request_payload(inputs, FunctionAISummarize::system_prompt, request_body);
    ASSERT_TRUE(st.ok()) << st.to_string();
    rapidjson::Document doc;
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
    // Type-check each field before reading it: rapidjson's Get*() on a value of
    // the wrong type trips RAPIDJSON_ASSERT instead of producing a test failure.
    ASSERT_TRUE(doc.HasMember("model")) << "Missing model field";
    ASSERT_TRUE(doc["model"].IsString()) << "Model field is not a string";
    ASSERT_STREQ(doc["model"].GetString(), "ollama");
    ASSERT_TRUE(doc.HasMember("system")) << "Missing system field";
    ASSERT_TRUE(doc["system"].IsString()) << "System field is not a string";
    ASSERT_STREQ(doc["system"].GetString(), FunctionAISummarize::system_prompt);
    ASSERT_TRUE(doc.HasMember("prompt")) << "Missing prompt field";
    ASSERT_TRUE(doc["prompt"].IsString()) << "Prompt field is not a string";
    ASSERT_STREQ(doc["prompt"].GetString(), inputs[0].c_str());
    ASSERT_FALSE(doc.HasMember("messages")) << "Generate endpoint should not include messages";
    ASSERT_TRUE(doc.HasMember("options")) << "Missing options field";
    ASSERT_TRUE(doc["options"].IsObject()) << "Options is not an object";
    const auto& options = doc["options"];
    ASSERT_TRUE(options.HasMember("temperature")) << "Missing options.temperature field";
    ASSERT_TRUE(options["temperature"].IsNumber()) << "options.temperature is not a number";
    ASSERT_DOUBLE_EQ(options["temperature"].GetDouble(), 0.8);
    ASSERT_TRUE(options.HasMember("max_token")) << "Missing options.max_token field";
    ASSERT_TRUE(options["max_token"].IsInt()) << "options.max_token is not an integer";
    ASSERT_EQ(options["max_token"].GetInt(), 64);
}
TEST(AI_ADAPTER_TEST, local_adapter_parse_response) {
    // LocalAdapter accepts several response shapes; each case below must parse
    // successfully and yield exactly one extracted string.
    struct ResponseCase {
        const char* body;
        const char* expected;
    };
    const ResponseCase cases[] = {
            // OpenAI type
            {R"({"choices":[{"message":{"content":"hi"}}]})", "hi"},
            // Simple text type
            {R"({"text":"simple result"})", "simple result"},
            {R"({"content":"simple result"})", "simple result"},
            // Ollama response type
            {R"({"response":"ollama result"})", "ollama result"},
            // Ollama chat message type
            {R"({"message":{"content":"ollama chat"}})", "ollama chat"},
    };

    LocalAdapter adapter;
    for (const ResponseCase& tc : cases) {
        std::string payload = tc.body;
        std::vector<std::string> parsed;
        Status status = adapter.parse_response(payload, parsed);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(parsed.size(), 1);
        ASSERT_EQ(parsed[0], tc.expected);
    }
}
TEST(AI_ADAPTER_TEST, local_adapter_request_default_endpoint) {
    // A non-Ollama endpoint (here /v1/completions) gets the flat
    // temperature/max_tokens fields and a system + user "messages" array.
    LocalAdapter adapter;
    TAIResource config;
    config.model_name = "local-default";
    config.temperature = 0.3;
    config.max_tokens = 42;
    config.endpoint = "http://localhost:8000/v1/completions";
    adapter.init(config);
    std::vector<std::string> inputs = {"default prompt"};
    std::string request_body;
    Status st =
            adapter.build_request_payload(inputs, FunctionAISummarize::system_prompt, request_body);
    ASSERT_TRUE(st.ok()) << st.to_string();
    rapidjson::Document doc;
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
    // Type-check each field before reading it: rapidjson's Get*() on a value of
    // the wrong type trips RAPIDJSON_ASSERT instead of producing a test failure.
    ASSERT_TRUE(doc.HasMember("model"));
    ASSERT_TRUE(doc["model"].IsString());
    ASSERT_STREQ(doc["model"].GetString(), "local-default");
    ASSERT_TRUE(doc.HasMember("temperature"));
    ASSERT_TRUE(doc["temperature"].IsNumber());
    ASSERT_DOUBLE_EQ(doc["temperature"].GetDouble(), 0.3);
    ASSERT_TRUE(doc.HasMember("max_tokens"));
    ASSERT_TRUE(doc["max_tokens"].IsInt());
    ASSERT_EQ(doc["max_tokens"].GetInt(), 42);
    ASSERT_TRUE(doc.HasMember("messages"));
    ASSERT_TRUE(doc["messages"].IsArray());
    ASSERT_GE(doc["messages"].Size(), 2);
    const auto& system_msg = doc["messages"][0];
    ASSERT_TRUE(system_msg.HasMember("role"));
    ASSERT_TRUE(system_msg["role"].IsString());
    ASSERT_STREQ(system_msg["role"].GetString(), "system");
    ASSERT_TRUE(system_msg.HasMember("content"));
    ASSERT_TRUE(system_msg["content"].IsString());
    ASSERT_STREQ(system_msg["content"].GetString(), FunctionAISummarize::system_prompt);
    const auto& user_msg = doc["messages"][doc["messages"].Size() - 1];
    ASSERT_TRUE(user_msg.HasMember("role"));
    ASSERT_TRUE(user_msg["role"].IsString());
    ASSERT_STREQ(user_msg["role"].GetString(), "user");
    ASSERT_TRUE(user_msg.HasMember("content"));
    ASSERT_TRUE(user_msg["content"].IsString());
    ASSERT_STREQ(user_msg["content"].GetString(), inputs[0].c_str());
}
TEST(AI_ADAPTER_TEST, openai_adapter_completions_request) {
    // OpenAI request with no endpoint configured: bearer auth, flat
    // temperature/max_tokens and a system + user "messages" array.
    OpenAIAdapter adapter;
    TAIResource config;
    config.model_name = "gpt-3.5-turbo";
    config.temperature = 0.5;
    config.max_tokens = 64;
    config.api_key = "test_openai_key";
    adapter.init(config);
    // header: Authorization first, then Content-Type. Null-check each node before
    // dereferencing so a short list fails the test rather than crashing it.
    MockHttpClient mock_client;
    Status auth_status = adapter.set_authentication(&mock_client);
    ASSERT_TRUE(auth_status.ok());
    const curl_slist* headers = mock_client.get();
    ASSERT_NE(headers, nullptr) << "No header was set";
    EXPECT_STREQ(headers->data, "Authorization: Bearer test_openai_key");
    ASSERT_NE(headers->next, nullptr) << "Missing Content-Type header";
    EXPECT_STREQ(headers->next->data, "Content-Type: application/json");
    std::vector<std::string> inputs = {"hi openai"};
    std::string request_body;
    Status st =
            adapter.build_request_payload(inputs, FunctionAISentiment::system_prompt, request_body);
    ASSERT_TRUE(st.ok()) << st.to_string();
    // body
    rapidjson::Document doc;
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
    // model
    ASSERT_TRUE(doc.HasMember("model")) << "Missing model field";
    ASSERT_TRUE(doc["model"].IsString()) << "Model field is not a string";
    ASSERT_STREQ(doc["model"].GetString(), "gpt-3.5-turbo");
    // temperature
    ASSERT_TRUE(doc.HasMember("temperature")) << "Missing temperature field";
    ASSERT_TRUE(doc["temperature"].IsNumber()) << "Temperature field is not a number";
    ASSERT_DOUBLE_EQ(doc["temperature"].GetDouble(), 0.5);
    // max token
    ASSERT_TRUE(doc.HasMember("max_tokens")) << "Missing max_tokens field";
    ASSERT_TRUE(doc["max_tokens"].IsInt()) << "Max_tokens field is not an integer";
    ASSERT_EQ(doc["max_tokens"].GetInt(), 64);
    // msg: the first and last entries are checked as distinct system/user
    // messages, so at least two are required.
    ASSERT_TRUE(doc.HasMember("messages")) << "Missing messages field";
    ASSERT_TRUE(doc["messages"].IsArray()) << "Messages is not an array";
    ASSERT_GE(doc["messages"].Size(), 2) << "Messages should contain system and user prompts";
    // system_prompt
    const auto& first_message = doc["messages"][0];
    ASSERT_TRUE(first_message.HasMember("role")) << "Message missing role field";
    ASSERT_TRUE(first_message["role"].IsString()) << "Role field is not a string";
    ASSERT_STREQ(first_message["role"].GetString(), "system");
    ASSERT_TRUE(first_message.HasMember("content")) << "Message missing content field";
    ASSERT_TRUE(first_message["content"].IsString()) << "Content field is not a string";
    ASSERT_STREQ(first_message["content"].GetString(), FunctionAISentiment::system_prompt);
    // The last message carries the user input
    const auto& last_message = doc["messages"][doc["messages"].Size() - 1];
    ASSERT_TRUE(last_message.HasMember("role")) << "Message missing role field";
    ASSERT_TRUE(last_message["role"].IsString()) << "Role field is not a string";
    ASSERT_STREQ(last_message["role"].GetString(), "user");
    ASSERT_TRUE(last_message.HasMember("content")) << "Message missing content field";
    ASSERT_TRUE(last_message["content"].IsString()) << "Content field is not a string";
    ASSERT_STREQ(last_message["content"].GetString(), inputs[0].c_str());
}
TEST(AI_ADAPTER_TEST, openai_adapter_completions_parse_response) {
    // A chat-completions body yields the content of its single choice.
    OpenAIAdapter adapter;
    std::vector<std::string> parsed;
    std::string completion_body = R"({"choices":[{"message":{"content":"openai result"}}]})";
    Status status = adapter.parse_response(completion_body, parsed);
    ASSERT_TRUE(status.ok());
    ASSERT_EQ(parsed.size(), 1);
    ASSERT_EQ(parsed[0], "openai result");
}
TEST(AI_ADAPTER_TEST, openai_adapter_responses_request) {
    // OpenAI /v1/responses endpoint: "max_output_tokens" replaces "max_tokens"
    // and the prompt is sent as an "input" array of system + user entries.
    OpenAIAdapter adapter;
    TAIResource config;
    config.model_name = "gpt-5";
    config.temperature = 0.5;
    config.max_tokens = 64;
    config.api_key = "test_openai_key";
    config.endpoint = "https://api.openai.com/v1/responses";
    adapter.init(config);
    // header: null-check each node so a short list fails instead of crashing.
    MockHttpClient mock_client;
    Status auth_status = adapter.set_authentication(&mock_client);
    ASSERT_TRUE(auth_status.ok());
    const curl_slist* headers = mock_client.get();
    ASSERT_NE(headers, nullptr) << "No header was set";
    EXPECT_STREQ(headers->data, "Authorization: Bearer test_openai_key");
    ASSERT_NE(headers->next, nullptr) << "Missing Content-Type header";
    EXPECT_STREQ(headers->next->data, "Content-Type: application/json");
    std::vector<std::string> inputs = {"hi openai"};
    std::string request_body;
    Status st =
            adapter.build_request_payload(inputs, FunctionAISentiment::system_prompt, request_body);
    ASSERT_TRUE(st.ok()) << st.to_string();
    // body
    rapidjson::Document doc;
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
    // model
    ASSERT_TRUE(doc.HasMember("model")) << "Missing model field";
    ASSERT_TRUE(doc["model"].IsString()) << "Model field is not a string";
    ASSERT_STREQ(doc["model"].GetString(), "gpt-5");
    // temperature
    ASSERT_TRUE(doc.HasMember("temperature")) << "Missing temperature field";
    ASSERT_TRUE(doc["temperature"].IsNumber()) << "Temperature field is not a number";
    ASSERT_DOUBLE_EQ(doc["temperature"].GetDouble(), 0.5);
    // max tokens
    ASSERT_TRUE(doc.HasMember("max_output_tokens")) << "Missing max_output_tokens field";
    ASSERT_TRUE(doc["max_output_tokens"].IsInt()) << "max_output_tokens field is not an integer";
    ASSERT_EQ(doc["max_output_tokens"].GetInt(), 64);
    // input: entries [0] and [1] are both read below, so require at least two
    // (indexing past Size() is undefined in rapidjson).
    ASSERT_TRUE(doc.HasMember("input")) << "Missing input field";
    ASSERT_TRUE(doc["input"].IsArray()) << "Input is not an array";
    ASSERT_GE(doc["input"].Size(), 2) << "Input should contain system and user entries: "
                                      << request_body;
    // system_prompt
    const auto& input = doc["input"];
    ASSERT_TRUE(input[0].HasMember("role")) << request_body;
    ASSERT_TRUE(input[0]["role"].IsString()) << "Role field is not a string";
    ASSERT_STREQ(input[0]["role"].GetString(), "system");
    ASSERT_TRUE(input[0].HasMember("content")) << request_body;
    ASSERT_TRUE(input[0]["content"].IsString()) << "Content field is not a string";
    ASSERT_STREQ(input[0]["content"].GetString(), FunctionAISentiment::system_prompt);
    // input content
    ASSERT_TRUE(input[1].HasMember("role")) << request_body;
    ASSERT_TRUE(input[1]["role"].IsString()) << "Role field is not a string";
    ASSERT_STREQ(input[1]["role"].GetString(), "user");
    ASSERT_TRUE(input[1].HasMember("content")) << request_body;
    ASSERT_TRUE(input[1]["content"].IsString()) << "Content field is not a string";
    ASSERT_STREQ(input[1]["content"].GetString(), inputs[0].c_str());
}
TEST(AI_ADAPTER_TEST, openai_adapter_responses_parse_response) {
    // A /v1/responses body yields the text of its first output content item.
    OpenAIAdapter adapter;
    std::vector<std::string> parsed;
    std::string responses_body = R"({"output":[{"content":[{"text":"openai response result"}]}]})";
    Status status = adapter.parse_response(responses_body, parsed);
    ASSERT_TRUE(status.ok());
    ASSERT_EQ(parsed.size(), 1);
    ASSERT_EQ(parsed[0], "openai response result");
}
TEST(AI_ADAPTER_TEST, openai_adapter_parse_response_keeps_mask_literals) {
    // Bracketed mask-like literals in model output must be returned verbatim.
    OpenAIAdapter adapter;
    const std::vector<std::string> literals = {"[MSKED]", "[MASK]"};
    for (const std::string& literal : literals) {
        std::string body = R"({"choices":[{"message":{"content":")" + literal + R"("}}]})";
        std::vector<std::string> parsed;
        Status status = adapter.parse_response(body, parsed);
        ASSERT_TRUE(status.ok()) << status.to_string();
        ASSERT_EQ(parsed.size(), 1);
        ASSERT_EQ(parsed[0], literal);
    }
}
TEST(AI_ADAPTER_TEST, gemini_adapter_request) {
    // Gemini request: x-goog-api-key auth, sampling settings under
    // "generationConfig", system prompt under "systemInstruction.parts".
    GeminiAdapter adapter;
    TAIResource config;
    config.temperature = 0.2;
    config.max_tokens = 32;
    config.api_key = "test_gemini_key";
    adapter.init(config);
    // header test: null-check each node so a short list fails instead of crashing.
    MockHttpClient mock_client;
    Status auth_status = adapter.set_authentication(&mock_client);
    ASSERT_TRUE(auth_status.ok());
    const curl_slist* headers = mock_client.get();
    ASSERT_NE(headers, nullptr) << "No header was set";
    EXPECT_STREQ(headers->data, "x-goog-api-key: test_gemini_key");
    ASSERT_NE(headers->next, nullptr) << "Missing Content-Type header";
    EXPECT_STREQ(headers->next->data, "Content-Type: application/json");
    std::vector<std::string> inputs = {"hello gemini"};
    std::string request_body;
    Status st =
            adapter.build_request_payload(inputs, FunctionAIExtract::system_prompt, request_body);
    ASSERT_TRUE(st.ok()) << st.to_string();
    // body test
    rapidjson::Document doc;
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
    ASSERT_TRUE(doc.HasMember("generationConfig")) << "Missing generationConfig field";
    ASSERT_TRUE(doc["generationConfig"].IsObject()) << "generationConfig is not an object";
    const auto& gen_cfg = doc["generationConfig"];
    // temperature
    ASSERT_TRUE(gen_cfg.HasMember("temperature"));
    ASSERT_TRUE(gen_cfg["temperature"].IsNumber());
    ASSERT_DOUBLE_EQ(gen_cfg["temperature"].GetDouble(), 0.2);
    // max_token
    ASSERT_TRUE(gen_cfg.HasMember("maxOutputTokens")) << "Missing maxOutputTokens field";
    ASSERT_TRUE(gen_cfg["maxOutputTokens"].IsInt());
    ASSERT_EQ(gen_cfg["maxOutputTokens"].GetInt(), 32);
    // system_prompt
    ASSERT_TRUE(doc.HasMember("systemInstruction")) << "Missing systemInstruction field";
    ASSERT_TRUE(doc["systemInstruction"].IsObject()) << request_body;
    ASSERT_TRUE(doc["systemInstruction"].HasMember("parts")) << request_body;
    ASSERT_TRUE(doc["systemInstruction"]["parts"].IsArray()) << request_body;
    ASSERT_GT(doc["systemInstruction"]["parts"].Size(), 0) << request_body;
    const auto& content_sys = doc["systemInstruction"]["parts"][0];
    ASSERT_TRUE(content_sys.HasMember("text")) << "parts missing text field";
    ASSERT_TRUE(content_sys["text"].IsString()) << "Text field is not a string";
    ASSERT_STREQ(content_sys["text"].GetString(), FunctionAIExtract::system_prompt);
    // content structure
    ASSERT_TRUE(doc.HasMember("contents")) << "Missing contents field";
    ASSERT_TRUE(doc["contents"].IsArray()) << "Contents is not an array";
    ASSERT_GT(doc["contents"].Size(), 0) << "Contents array is empty";
    // content
    const auto& content = doc["contents"][0];
    ASSERT_TRUE(content.HasMember("parts")) << "Content missing parts field";
    ASSERT_TRUE(content["parts"].IsArray()) << "Parts is not an array";
    ASSERT_GT(content["parts"].Size(), 0) << "Parts array is empty";
    const auto& part = content["parts"][0];
    ASSERT_TRUE(part.HasMember("text")) << "Part missing text field";
    ASSERT_TRUE(part["text"].IsString()) << "Text field is not a string";
    ASSERT_STREQ(part["text"].GetString(), "hello gemini");
}
TEST(AI_ADAPTER_TEST, gemini_adapter_parse_response) {
    // A Gemini body yields the text of the first candidate's first part.
    GeminiAdapter adapter;
    std::vector<std::string> parsed;
    std::string candidates_body =
            R"({"candidates":[{"content":{"parts":[{"text":"gemini result"}]}}]})";
    Status status = adapter.parse_response(candidates_body, parsed);
    ASSERT_TRUE(status.ok());
    ASSERT_EQ(parsed.size(), 1);
    ASSERT_EQ(parsed[0], "gemini result");
}
TEST(AI_ADAPTER_TEST, anthropic_adapter_request) {
    // Anthropic request: x-api-key + anthropic-version headers, top-level
    // "system" prompt, and the user input in "messages".
    AnthropicAdapter adapter;
    TAIResource config;
    config.model_name = "claude-3";
    config.temperature = 1.0;
    config.max_tokens = 256;
    config.api_key = "test_anthropic_key";
    config.anthropic_version = "2023-06-01";
    adapter.init(config);
    // header: three entries expected; null-check each node before dereferencing
    // so a short list fails the test rather than crashing it.
    MockHttpClient mock_client;
    Status auth_status = adapter.set_authentication(&mock_client);
    ASSERT_TRUE(auth_status.ok());
    const curl_slist* headers = mock_client.get();
    ASSERT_NE(headers, nullptr) << "No header was set";
    EXPECT_STREQ(headers->data, "x-api-key: test_anthropic_key");
    ASSERT_NE(headers->next, nullptr) << "Missing anthropic-version header";
    EXPECT_STREQ(headers->next->data, "anthropic-version: 2023-06-01");
    ASSERT_NE(headers->next->next, nullptr) << "Missing Content-Type header";
    EXPECT_STREQ(headers->next->next->data, "Content-Type: application/json");
    std::vector<std::string> inputs = {"hi anthropic"};
    std::string request_body;
    Status st =
            adapter.build_request_payload(inputs, FunctionAIClassify::system_prompt, request_body);
    ASSERT_TRUE(st.ok()) << st.to_string();
    // body
    rapidjson::Document doc;
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
    // model
    ASSERT_TRUE(doc.HasMember("model")) << "Missing model field";
    ASSERT_TRUE(doc["model"].IsString()) << "Model field is not a string";
    ASSERT_STREQ(doc["model"].GetString(), "claude-3");
    // temperature
    ASSERT_TRUE(doc.HasMember("temperature")) << "Missing temperature field";
    ASSERT_TRUE(doc["temperature"].IsNumber()) << "Temperature field is not a number";
    ASSERT_DOUBLE_EQ(doc["temperature"].GetDouble(), 1.0);
    // max token
    ASSERT_TRUE(doc.HasMember("max_tokens")) << "Missing max_tokens field";
    ASSERT_TRUE(doc["max_tokens"].IsInt()) << "Max_tokens field is not an integer";
    ASSERT_EQ(doc["max_tokens"].GetInt(), 256);
    // system_prompt
    ASSERT_TRUE(doc.HasMember("system")) << "Missing system field";
    ASSERT_TRUE(doc["system"].IsString()) << "System field is not a string";
    ASSERT_STREQ(doc["system"].GetString(), FunctionAIClassify::system_prompt);
    // message Format
    ASSERT_TRUE(doc.HasMember("messages")) << "Missing messages field";
    ASSERT_TRUE(doc["messages"].IsArray()) << "Messages is not an array";
    ASSERT_GT(doc["messages"].Size(), 0) << "Messages array is empty";
    const auto& message = doc["messages"][0];
    ASSERT_TRUE(message.HasMember("role")) << "Message missing role field";
    ASSERT_TRUE(message["role"].IsString()) << "Role field is not a string";
    ASSERT_STREQ(message["role"].GetString(), "user");
    ASSERT_TRUE(message.HasMember("content")) << "Message missing content field";
    // Content of the first message: accept either a plain string or an array of
    // typed content blocks, but fail on any other shape instead of passing silently.
    const auto& message_content = message["content"];
    if (message_content.IsArray()) {
        ASSERT_GT(message_content.Size(), 0) << "Content array is empty";
        const auto& content = message_content[0];
        ASSERT_TRUE(content.HasMember("type")) << "Content missing type field";
        ASSERT_TRUE(content.HasMember("text")) << "Content missing text field";
        ASSERT_TRUE(content["text"].IsString()) << "Content text is not a string";
        ASSERT_STREQ(content["text"].GetString(), "hi anthropic");
    } else if (message_content.IsString()) {
        ASSERT_STREQ(message_content.GetString(), "hi anthropic");
    } else {
        FAIL() << "Message content is neither an array nor a string: " << request_body;
    }
}
TEST(AI_ADAPTER_TEST, anthropic_adapter_parse_response) {
    // An Anthropic body yields the text of its single "text" content block.
    AnthropicAdapter adapter;
    std::vector<std::string> parsed;
    std::string messages_body = R"({"content":[{"type":"text","text":"anthropic result"}]})";
    Status status = adapter.parse_response(messages_body, parsed);
    ASSERT_TRUE(status.ok());
    ASSERT_EQ(parsed.size(), 1);
    ASSERT_EQ(parsed[0], "anthropic result");
}
TEST(AI_ADAPTER_TEST, unsupported_provider_type) {
    // An unknown provider name must not produce an adapter.
    TAIResource resource;
    resource.provider_type = "not_exist";
    auto created = doris::AIAdapterFactory::create_adapter(resource.provider_type);
    ASSERT_EQ(created, nullptr);
}
TEST(AI_ADAPTER_TEST, adapter_factory_all_types) {
    // Every provider name listed here must be recognised by the factory.
    std::vector<std::string> provider_names = {
            "LOCAL", "OPENAI",   "MOONSHOT",  "DEEPSEEK", "MINIMAX",  "ZHIPU", "QWEN",
            "JINA",  "BAICHUAN", "ANTHROPIC", "GEMINI",   "VOYAGEAI", "MOCK"};
    for (size_t idx = 0; idx < provider_names.size(); ++idx) {
        const std::string& name = provider_names[idx];
        auto created = doris::AIAdapterFactory::create_adapter(name);
        ASSERT_TRUE(created != nullptr) << "Adapter not found for type: " << name;
    }
}
TEST(AI_ADAPTER_TEST, local_adapter_parse_response_parse_error) {
    // Non-JSON input must be reported as an error by LocalAdapter.
    LocalAdapter adapter;
    std::vector<std::string> parsed;
    std::string garbage = "not a json";
    Status status = adapter.parse_response(garbage, parsed);
    ASSERT_FALSE(status.ok());
}
TEST(AI_ADAPTER_TEST, parse_response_wrong_type) {
    // A numeric "response" member is not an accepted shape and must be rejected.
    LocalAdapter adapter;
    std::vector<std::string> parsed;
    std::string numeric_response = R"({"response":123})";
    Status status = adapter.parse_response(numeric_response, parsed);
    ASSERT_FALSE(status.ok());
    const std::string message = status.to_string();
    EXPECT_THAT(message, ::testing::HasSubstr("Unsupported response format from local AI."));
}
TEST(AI_ADAPTER_TEST, openai_adapter_parse_response_choice_format_error) {
    // Malformed entries in "choices" must be rejected with a choice-format error:
    // first a choice missing "message", then one whose content is not a string.
    OpenAIAdapter adapter;
    const std::vector<std::string> malformed_bodies = {
            R"({"choices":[{}]})",
            R"({"choices":[{"message":{"content":123}}]})",
    };
    for (const std::string& body : malformed_bodies) {
        std::string payload = body;
        std::vector<std::string> parsed;
        Status status = adapter.parse_response(payload, parsed);
        ASSERT_FALSE(status.ok());
        EXPECT_THAT(status.to_string(),
                    ::testing::HasSubstr("Invalid choice format in response"));
    }
}
TEST(AI_ADAPTER_TEST, openai_adapter_parse_response_parse_error) {
    // Non-JSON input must surface a parse failure.
    OpenAIAdapter adapter;
    std::vector<std::string> parsed;
    std::string garbage = "not a json";
    Status status = adapter.parse_response(garbage, parsed);
    ASSERT_FALSE(status.ok());
    EXPECT_THAT(status.to_string(), ::testing::HasSubstr("Failed to parse"));
}
TEST(AI_ADAPTER_TEST, openai_adapter_parse_response_choices_not_array) {
    // A scalar "choices" member is an invalid response format.
    OpenAIAdapter adapter;
    std::vector<std::string> parsed;
    std::string scalar_choices = R"({"choices":123})";
    Status status = adapter.parse_response(scalar_choices, parsed);
    ASSERT_FALSE(status.ok());
    EXPECT_THAT(status.to_string(), ::testing::HasSubstr("Invalid response format"));
}
TEST(AI_ADAPTER_TEST, gemini_adapter_parse_response_parse_error) {
    // Non-JSON input must surface a parse failure.
    GeminiAdapter adapter;
    std::vector<std::string> parsed;
    std::string garbage = "not a json";
    Status status = adapter.parse_response(garbage, parsed);
    ASSERT_FALSE(status.ok());
    EXPECT_THAT(status.to_string(), ::testing::HasSubstr("Failed to parse"));
}
TEST(AI_ADAPTER_TEST, gemini_parse_response_missing_candidates) {
    // A body without "candidates" is an invalid response format.
    GeminiAdapter adapter;
    std::vector<std::string> parsed;
    std::string no_candidates = R"({"foo":"bar"})";
    Status status = adapter.parse_response(no_candidates, parsed);
    ASSERT_FALSE(status.ok());
    EXPECT_THAT(status.to_string(), ::testing::HasSubstr("Invalid response format"));
}
TEST(AI_ADAPTER_TEST, anthropic_adapter_parse_response_parse_error) {
    // Non-JSON input must surface a parse failure.
    AnthropicAdapter adapter;
    std::vector<std::string> parsed;
    std::string garbage = "not a json";
    Status status = adapter.parse_response(garbage, parsed);
    ASSERT_FALSE(status.ok());
    EXPECT_THAT(status.to_string(), ::testing::HasSubstr("Failed to parse"));
}
TEST(AI_ADAPTER_TEST, anthropic_adapter_parse_response_content_not_array) {
    // A scalar "content" member is an invalid response format.
    AnthropicAdapter adapter;
    std::vector<std::string> parsed;
    std::string scalar_content = R"({"content":123})";
    Status status = adapter.parse_response(scalar_content, parsed);
    ASSERT_FALSE(status.ok());
    EXPECT_THAT(status.to_string(), ::testing::HasSubstr("Invalid response format"));
}
TEST(AI_ADAPTER_TEST, voyage_adapter_chat_test) {
    // VoyageAI sets bearer auth headers, but both building a text-generation
    // payload and parsing a text-generation response return NOT_IMPLEMENTED_ERROR.
    VoyageAIAdapter adapter;
    TAIResource config;
    config.api_key = "test_voyage_key";
    config.provider_type = "VoyageAI";
    adapter.init(config);
    // header: null-check each node so a short list fails instead of crashing.
    MockHttpClient mock_client;
    Status auth_status = adapter.set_authentication(&mock_client);
    ASSERT_TRUE(auth_status.ok());
    const curl_slist* headers = mock_client.get();
    ASSERT_NE(headers, nullptr) << "No header was set";
    EXPECT_STREQ(headers->data, "Authorization: Bearer test_voyage_key");
    ASSERT_NE(headers->next, nullptr) << "Missing Content-Type header";
    EXPECT_STREQ(headers->next->data, "Content-Type: application/json");
    std::vector<std::string> inputs = {"test_inputs"};
    std::string request_body;
    Status st = adapter.build_request_payload(inputs, "test_system_prompt", request_body);
    ASSERT_FALSE(st.ok());
    ASSERT_STREQ(st.to_string().c_str(),
                 "[NOT_IMPLEMENTED_ERROR]VoyageAI don't support text generation");
    std::string response_body = "test_response_body";
    std::vector<std::string> result;
    st = adapter.parse_response(response_body, result);
    ASSERT_FALSE(st.ok());
    ASSERT_STREQ(st.to_string().c_str(),
                 "[NOT_IMPLEMENTED_ERROR]VoyageAI don't support text generation");
}
TEST(AI_ADAPTER_TEST, qwen_multimodal_embedding_request_image) {
    // A single image URL becomes one {"image": url} entry under input.contents;
    // the configured dimensions are forwarded as parameters.dimension.
    TAIResource resource;
    resource.model_name = "tongyi-embedding-vision-plus";
    resource.dimensions = 1024;
    QwenAdapter adapter;
    adapter.init(resource);

    std::string payload;
    Status status = adapter.build_multimodal_embedding_request(
            {MultimodalType::IMAGE}, {"https://a/b/c.png"}, {}, payload);
    ASSERT_TRUE(status.ok()) << status.to_string();

    rapidjson::Document json;
    json.Parse(payload.c_str());
    ASSERT_FALSE(json.HasParseError()) << payload;
    ASSERT_TRUE(json.IsObject());

    ASSERT_TRUE(json.HasMember("model"));
    ASSERT_STREQ(json["model"].GetString(), "tongyi-embedding-vision-plus");

    ASSERT_TRUE(json.HasMember("input"));
    const auto& input_obj = json["input"];
    ASSERT_TRUE(input_obj.HasMember("contents"));
    const auto& entries = input_obj["contents"];
    ASSERT_TRUE(entries.IsArray());
    ASSERT_EQ(entries.Size(), 1);
    ASSERT_TRUE(entries[0].HasMember("image"));
    ASSERT_STREQ(entries[0]["image"].GetString(), "https://a/b/c.png");

    ASSERT_TRUE(json.HasMember("parameters"));
    const auto& params = json["parameters"];
    ASSERT_TRUE(params.IsObject());
    ASSERT_TRUE(params.HasMember("dimension"));
    ASSERT_EQ(params["dimension"].GetInt(), 1024);
}
TEST(AI_ADAPTER_TEST, qwen_multimodal_embedding_request_video) {
    // A single video URL becomes one {"video": url} entry under input.contents;
    // the configured dimensions are forwarded as parameters.dimension.
    TAIResource resource;
    resource.model_name = "tongyi-embedding-vision-plus";
    resource.dimensions = 1024;
    QwenAdapter adapter;
    adapter.init(resource);

    std::string payload;
    Status status = adapter.build_multimodal_embedding_request(
            {MultimodalType::VIDEO}, {"https://a/b/c.mp4"}, {}, payload);
    ASSERT_TRUE(status.ok()) << status.to_string();

    rapidjson::Document json;
    json.Parse(payload.c_str());
    ASSERT_FALSE(json.HasParseError()) << payload;
    ASSERT_TRUE(json.IsObject());

    ASSERT_TRUE(json.HasMember("model"));
    ASSERT_STREQ(json["model"].GetString(), "tongyi-embedding-vision-plus");

    ASSERT_TRUE(json.HasMember("input"));
    const auto& input_obj = json["input"];
    ASSERT_TRUE(input_obj.HasMember("contents"));
    const auto& entries = input_obj["contents"];
    ASSERT_TRUE(entries.IsArray());
    ASSERT_EQ(entries.Size(), 1);
    ASSERT_TRUE(entries[0].HasMember("video"));
    ASSERT_STREQ(entries[0]["video"].GetString(), "https://a/b/c.mp4");

    ASSERT_TRUE(json.HasMember("parameters"));
    const auto& params = json["parameters"];
    ASSERT_TRUE(params.IsObject());
    ASSERT_TRUE(params.HasMember("dimension"));
    ASSERT_EQ(params["dimension"].GetInt(), 1024);
}
TEST(AI_ADAPTER_TEST, qwen_multimodal_embedding_batch_request) {
    // Mixed media in one call produce one input.contents entry per URL, in order,
    // keyed by media type ("image" then "video").
    TAIResource resource;
    resource.model_name = "tongyi-embedding-vision-plus";
    resource.dimensions = 1024;
    QwenAdapter adapter;
    adapter.init(resource);

    std::vector<MultimodalType> kinds = {MultimodalType::IMAGE, MultimodalType::VIDEO};
    std::vector<std::string> urls = {"https://a/b/c.png", "https://a/b/c.mp4"};
    std::string payload;
    Status status = adapter.build_multimodal_embedding_request(kinds, urls, {}, payload);
    ASSERT_TRUE(status.ok()) << status.to_string();

    rapidjson::Document json;
    json.Parse(payload.c_str());
    ASSERT_FALSE(json.HasParseError()) << payload;
    ASSERT_TRUE(json.IsObject());

    ASSERT_TRUE(json.HasMember("input"));
    const auto& input_obj = json["input"];
    ASSERT_TRUE(input_obj.HasMember("contents"));
    const auto& entries = input_obj["contents"];
    ASSERT_TRUE(entries.IsArray());
    ASSERT_EQ(entries.Size(), 2);

    const auto& image_entry = entries[0];
    ASSERT_TRUE(image_entry.HasMember("image"));
    ASSERT_STREQ(image_entry["image"].GetString(), "https://a/b/c.png");

    const auto& video_entry = entries[1];
    ASSERT_TRUE(video_entry.HasMember("video"));
    ASSERT_STREQ(video_entry["video"].GetString(), "https://a/b/c.mp4");
}
TEST(AI_ADAPTER_TEST, qwen_multimodal_embedding_request_audio_not_supported) {
    // QWEN multimodal embedding rejects audio, and the error names the media type.
    TAIResource resource;
    resource.model_name = "tongyi-embedding-vision-plus";
    QwenAdapter adapter;
    adapter.init(resource);

    std::string payload;
    Status status = adapter.build_multimodal_embedding_request(
            {MultimodalType::AUDIO}, {"https://a/b/c.mp3"}, {}, payload);
    ASSERT_FALSE(status.ok());

    const std::string message = status.to_string();
    ASSERT_THAT(message, ::testing::HasSubstr("QWEN only supports image/video multimodal embed"));
    ASSERT_THAT(message, ::testing::HasSubstr("audio"));
}
TEST(AI_ADAPTER_TEST, voyage_multimodal_embedding_request) {
    // One video URL becomes a single "inputs" entry holding one typed
    // {"type":"video_url","video_url":...} item; no top-level dimension keys are
    // emitted even though dimensions is configured.
    TAIResource resource;
    resource.model_name = "voyage-multimodal-3.5";
    resource.dimensions = 2048;
    VoyageAIAdapter adapter;
    adapter.init(resource);

    std::string payload;
    Status status = adapter.build_multimodal_embedding_request(
            {MultimodalType::VIDEO}, {"https://a/b/c.mp4"}, {}, payload);
    ASSERT_TRUE(status.ok()) << status.to_string();

    rapidjson::Document json;
    json.Parse(payload.c_str());
    ASSERT_FALSE(json.HasParseError()) << payload;
    ASSERT_TRUE(json.IsObject());

    ASSERT_TRUE(json.HasMember("inputs"));
    const auto& entries = json["inputs"];
    ASSERT_TRUE(entries.IsArray());
    ASSERT_EQ(entries.Size(), 1);

    const auto& only_entry = entries[0];
    ASSERT_TRUE(only_entry.HasMember("content"));
    const auto& items = only_entry["content"];
    ASSERT_TRUE(items.IsArray());
    ASSERT_EQ(items.Size(), 1);

    const auto& item = items[0];
    ASSERT_TRUE(item.HasMember("type"));
    ASSERT_STREQ(item["type"].GetString(), "video_url");
    ASSERT_TRUE(item.HasMember("video_url"));
    ASSERT_STREQ(item["video_url"].GetString(), "https://a/b/c.mp4");

    ASSERT_FALSE(json.HasMember("dimensions"));
    ASSERT_FALSE(json.HasMember("output_dimension"));
}
// A batched Voyage multimodal request emits one "inputs" entry per media URL, in
// input order. Each entry's "content" starts with a typed part of the form
// {"type": "<kind>_url", "<kind>_url": url}.
TEST(AI_ADAPTER_TEST, voyage_multimodal_embedding_batch_request) {
    VoyageAIAdapter adapter;
    TAIResource config;
    config.model_name = "voyage-multimodal-3.5";
    adapter.init(config);
    std::vector<MultimodalType> media_types = {MultimodalType::IMAGE, MultimodalType::VIDEO};
    std::vector<std::string> media_urls = {"https://a/b/c.png", "https://a/b/c.mp4"};
    std::string request_body;
    Status st =
            adapter.build_multimodal_embedding_request(media_types, media_urls, {}, request_body);
    ASSERT_TRUE(st.ok()) << st.to_string();
    rapidjson::Document doc;
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << request_body;
    ASSERT_TRUE(doc.IsObject());
    ASSERT_TRUE(doc.HasMember("inputs"));
    const auto& request_inputs = doc["inputs"];
    ASSERT_TRUE(request_inputs.IsArray());
    ASSERT_EQ(request_inputs.Size(), 2);

    // Expected first content part of each entry, in the same order as media_urls.
    // The type value doubles as the key that holds the URL.
    struct ExpectedPart {
        const char* type;
        const char* url;
    };
    const ExpectedPart expected[] = {{"image_url", "https://a/b/c.png"},
                                     {"video_url", "https://a/b/c.mp4"}};
    for (rapidjson::SizeType i = 0; i < request_inputs.Size(); ++i) {
        SCOPED_TRACE("inputs[" + std::to_string(i) + "]");
        const auto& entry = request_inputs[i];
        // Verify presence and JSON type before operator[] / GetString().
        // rapidjson treats a missing member or a wrong type as a RAPIDJSON_ASSERT
        // violation, not a reportable test failure.
        ASSERT_TRUE(entry.IsObject());
        ASSERT_TRUE(entry.HasMember("content"));
        const auto& content = entry["content"];
        ASSERT_TRUE(content.IsArray());
        ASSERT_FALSE(content.Empty());
        const auto& part = content[0];
        ASSERT_TRUE(part.IsObject());
        ASSERT_TRUE(part.HasMember("type"));
        ASSERT_TRUE(part["type"].IsString());
        ASSERT_STREQ(part["type"].GetString(), expected[i].type);
        ASSERT_TRUE(part.HasMember(expected[i].type));
        ASSERT_TRUE(part[expected[i].type].IsString());
        ASSERT_STREQ(part[expected[i].type].GetString(), expected[i].url);
    }
}
// A single Jina multimodal request is expected to have this shape:
//   - "task" set to "text-matching";
//   - the media listed as a typed object in "input" (here {"image": url});
//   - the configured dimensions forwarded.
TEST(AI_ADAPTER_TEST, jina_multimodal_embedding_request) {
    JinaAdapter jina;
    TAIResource resource;
    resource.model_name = "jina-embeddings-v4";
    resource.dimensions = 512;
    jina.init(resource);

    std::string body;
    Status status = jina.build_multimodal_embedding_request({MultimodalType::IMAGE},
                                                            {"https://a/b/c.jpg"}, {}, body);
    ASSERT_TRUE(status.ok()) << status.to_string();

    rapidjson::Document parsed;
    parsed.Parse(body.c_str());
    ASSERT_FALSE(parsed.HasParseError()) << body;
    ASSERT_TRUE(parsed.IsObject());

    ASSERT_TRUE(parsed.HasMember("task"));
    ASSERT_STREQ(parsed["task"].GetString(), "text-matching");

    ASSERT_TRUE(parsed.HasMember("input"));
    const auto& items = parsed["input"];
    ASSERT_TRUE(items.IsArray());
    ASSERT_EQ(items.Size(), 1);
    const auto& first = items[0];
    ASSERT_TRUE(first.HasMember("image"));
    ASSERT_STREQ(first["image"].GetString(), "https://a/b/c.jpg");

    ASSERT_TRUE(parsed.HasMember("dimensions"));
    ASSERT_EQ(parsed["dimensions"].GetInt(), 512);
}
// A batched Jina request keeps one "input" element per media URL, in order.
// Each element is keyed by its media kind ("image" / "video").
TEST(AI_ADAPTER_TEST, jina_multimodal_embedding_batch_request) {
    JinaAdapter jina;
    TAIResource resource;
    resource.model_name = "jina-embeddings-v4";
    resource.dimensions = 512;
    jina.init(resource);

    const std::vector<MultimodalType> kinds = {MultimodalType::IMAGE, MultimodalType::VIDEO};
    const std::vector<std::string> urls = {"https://a/b/c.jpg", "https://a/b/c.mp4"};
    std::string body;
    Status status = jina.build_multimodal_embedding_request(kinds, urls, {}, body);
    ASSERT_TRUE(status.ok()) << status.to_string();

    rapidjson::Document parsed;
    parsed.Parse(body.c_str());
    ASSERT_FALSE(parsed.HasParseError()) << body;
    ASSERT_TRUE(parsed.IsObject());
    ASSERT_TRUE(parsed.HasMember("input"));

    const auto& items = parsed["input"];
    ASSERT_TRUE(items.IsArray());
    ASSERT_EQ(items.Size(), 2);

    const auto& image_item = items[0];
    ASSERT_TRUE(image_item.HasMember("image"));
    ASSERT_STREQ(image_item["image"].GetString(), "https://a/b/c.jpg");

    const auto& video_item = items[1];
    ASSERT_TRUE(video_item.HasMember("video"));
    ASSERT_STREQ(video_item["video"].GetString(), "https://a/b/c.mp4");
}
// The OpenAI adapter is expected to refuse multimodal embedding requests with a
// "does not support multimodal Embed" error.
TEST(AI_ADAPTER_TEST, multimodal_provider_support) {
    OpenAIAdapter openai;
    TAIResource resource;
    resource.provider_type = "OPENAI";
    openai.init(resource);

    std::string body;
    Status status = openai.build_multimodal_embedding_request({MultimodalType::IMAGE},
                                                              {"https://a/b/c.png"}, {}, body);
    ASSERT_FALSE(status.ok());
    ASSERT_THAT(status.to_string(), ::testing::HasSubstr("does not support multimodal Embed"));
}
// Each Gemini media kind exercised here (image, audio, video) is wrapped in a
// single embed request. That request carries:
//   - the "models/"-prefixed model name;
//   - the configured output dimensionality;
//   - one file_data part built from the caller-supplied MIME type and URI.
TEST(AI_ADAPTER_TEST, gemini_multimodal_embedding_request) {
    GeminiAdapter gemini_adapter;
    TAIResource gemini_config;
    gemini_config.provider_type = "GEMINI";
    gemini_config.model_name = "gemini-embedding-2-preview";
    gemini_config.dimensions = 768;
    gemini_adapter.init(gemini_config);
    struct GeminiMultimodalCase {
        MultimodalType media_type;
        const char* media_url;
        const char* mime_type;
    };
    const std::vector<GeminiMultimodalCase> test_cases = {
            {MultimodalType::IMAGE, "https://a/b/c.jpg", "image/jpeg"},
            {MultimodalType::IMAGE, "https://a/b/c.webp", "image/webp"},
            {MultimodalType::AUDIO, "https://a/b/c.wav", "audio/wav"},
            {MultimodalType::VIDEO, "https://a/b/c.webm", "video/webm"},
    };
    for (const auto& test_case : test_cases) {
        // Tag every assertion with the case; otherwise a failure inside the loop
        // does not say which media input broke.
        SCOPED_TRACE(test_case.media_url);
        std::string request_body;
        Status st = gemini_adapter.build_multimodal_embedding_request(
                {test_case.media_type}, {test_case.media_url}, {test_case.mime_type},
                request_body);
        ASSERT_TRUE(st.ok()) << st.to_string();
        rapidjson::Document doc;
        doc.Parse(request_body.c_str());
        ASSERT_FALSE(doc.HasParseError()) << request_body;
        ASSERT_TRUE(doc.IsObject());
        ASSERT_TRUE(doc.HasMember("requests"));
        ASSERT_TRUE(doc["requests"].IsArray());
        ASSERT_EQ(doc["requests"].Size(), 1);
        const auto& request = doc["requests"][0];
        ASSERT_TRUE(request.IsObject());
        ASSERT_TRUE(request.HasMember("model"));
        ASSERT_STREQ(request["model"].GetString(), "models/gemini-embedding-2-preview");
        ASSERT_TRUE(request.HasMember("outputDimensionality"));
        ASSERT_EQ(request["outputDimensionality"].GetInt(), 768);
        ASSERT_TRUE(request.HasMember("content"));
        ASSERT_TRUE(request["content"].IsObject());
        ASSERT_TRUE(request["content"].HasMember("parts"));
        const auto& parts = request["content"]["parts"];
        ASSERT_TRUE(parts.IsArray());
        ASSERT_EQ(parts.Size(), 1);
        ASSERT_TRUE(parts[0].IsObject());
        ASSERT_TRUE(parts[0].HasMember("file_data"));
        const auto& file_data = parts[0]["file_data"];
        ASSERT_TRUE(file_data.IsObject());
        // Check each key before reading it; rapidjson asserts on missing members
        // instead of reporting a test failure.
        ASSERT_TRUE(file_data.HasMember("mime_type"));
        ASSERT_STREQ(file_data["mime_type"].GetString(), test_case.mime_type);
        ASSERT_TRUE(file_data.HasMember("file_uri"));
        ASSERT_STREQ(file_data["file_uri"].GetString(), test_case.media_url);
    }
}
// A batched Gemini multimodal request emits one entry in "requests" per media URL,
// in input order. Every entry carries:
//   - the "models/"-prefixed model name;
//   - the configured output dimensionality;
//   - a single file_data part pairing the i-th MIME type with the i-th URL.
TEST(AI_ADAPTER_TEST, gemini_multimodal_embedding_batch_request) {
    GeminiAdapter adapter;
    TAIResource config;
    config.provider_type = "GEMINI";
    config.model_name = "gemini-embedding-2-preview";
    config.dimensions = 768;
    adapter.init(config);
    std::vector<MultimodalType> media_types = {MultimodalType::IMAGE, MultimodalType::AUDIO,
                                               MultimodalType::VIDEO};
    std::vector<std::string> media_urls = {"https://a/b/c.jpg", "https://a/b/c.wav",
                                           "https://a/b/c.webm"};
    std::vector<std::string> media_content_types = {"image/jpeg", "audio/wav", "video/webm"};
    std::string request_body;
    Status st = adapter.build_multimodal_embedding_request(media_types, media_urls,
                                                           media_content_types, request_body);
    ASSERT_TRUE(st.ok()) << st.to_string();
    rapidjson::Document doc;
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << request_body;
    ASSERT_TRUE(doc.IsObject());
    ASSERT_TRUE(doc.HasMember("requests"));
    const auto& requests = doc["requests"];
    ASSERT_TRUE(requests.IsArray());
    ASSERT_EQ(requests.Size(), media_urls.size());
    for (rapidjson::SizeType i = 0; i < requests.Size(); ++i) {
        SCOPED_TRACE("requests[" + std::to_string(i) + "]");
        const auto& request = requests[i];
        // Guard every lookup: rapidjson asserts (rather than failing the test) on a
        // missing member, a non-object/non-array, or GetString() on a non-string.
        ASSERT_TRUE(request.IsObject());
        ASSERT_TRUE(request.HasMember("model"));
        ASSERT_STREQ(request["model"].GetString(), "models/gemini-embedding-2-preview");
        ASSERT_TRUE(request.HasMember("outputDimensionality"));
        ASSERT_EQ(request["outputDimensionality"].GetInt(), 768);
        ASSERT_TRUE(request.HasMember("content"));
        ASSERT_TRUE(request["content"].IsObject());
        ASSERT_TRUE(request["content"].HasMember("parts"));
        const auto& parts = request["content"]["parts"];
        ASSERT_TRUE(parts.IsArray());
        ASSERT_EQ(parts.Size(), 1);
        ASSERT_TRUE(parts[0].IsObject());
        ASSERT_TRUE(parts[0].HasMember("file_data"));
        const auto& file_data = parts[0]["file_data"];
        ASSERT_TRUE(file_data.IsObject());
        ASSERT_TRUE(file_data.HasMember("mime_type"));
        ASSERT_STREQ(file_data["mime_type"].GetString(), media_content_types[i].c_str());
        ASSERT_TRUE(file_data.HasMember("file_uri"));
        ASSERT_STREQ(file_data["file_uri"].GetString(), media_urls[i].c_str());
    }
}
// A Gemini multimodal embed call with no media at all must be rejected.
TEST(AI_ADAPTER_TEST, gemini_multimodal_embedding_request_empty_inputs) {
    GeminiAdapter gemini;
    TAIResource resource;
    resource.provider_type = "GEMINI";
    resource.model_name = "gemini-embedding-2-preview";
    gemini.init(resource);

    std::string body;
    Status status = gemini.build_multimodal_embedding_request({}, {}, {}, body);
    ASSERT_FALSE(status.ok());
    ASSERT_THAT(status.to_string(),
                ::testing::HasSubstr("Gemini multimodal embed inputs can not be empty"));
}
// Two media types but only one URL: Gemini must refuse the request, and the
// error must report both counts.
TEST(AI_ADAPTER_TEST, gemini_multimodal_embedding_request_size_mismatch) {
    GeminiAdapter gemini;
    TAIResource resource;
    resource.provider_type = "GEMINI";
    resource.model_name = "gemini-embedding-2-preview";
    gemini.init(resource);

    std::string body;
    Status status = gemini.build_multimodal_embedding_request(
            {MultimodalType::IMAGE, MultimodalType::VIDEO}, {"https://a/b/c.png"}, {}, body);
    ASSERT_FALSE(status.ok());
    ASSERT_THAT(status.to_string(),
                ::testing::HasSubstr("Gemini multimodal embed input size mismatch, "
                                     "media_types=2, media_urls=1"));
}
// One media URL with an empty MIME-type list must fail. The error is expected to
// report both the content-type count and the URL count.
TEST(AI_ADAPTER_TEST, gemini_multimodal_embedding_content_type_size_mismatch) {
    GeminiAdapter gemini;
    TAIResource resource;
    resource.provider_type = "GEMINI";
    resource.model_name = "gemini-embedding-2-preview";
    gemini.init(resource);

    std::string body;
    Status status = gemini.build_multimodal_embedding_request({MultimodalType::IMAGE},
                                                              {"https://a/b/c.jpg"}, {}, body);
    ASSERT_FALSE(status.ok());
    ASSERT_THAT(status.to_string(),
                ::testing::HasSubstr("Gemini multimodal embed input size mismatch, "
                                     "media_content_types=0, media_urls=1"));
}
// A Gemini batch embedding response lists one {"values": [...]} object per input.
// Each must be returned as its own float vector, with per-row lengths preserved.
TEST(AI_ADAPTER_TEST, gemini_parse_batch_embedding_response) {
    GeminiAdapter gemini;
    std::string response = R"({
"embeddings": [
{"values": [0.1, 0.2, 0.3]},
{"values": [0.4, 0.5]}
]
})";

    std::vector<std::vector<float>> embeddings;
    Status status = gemini.parse_embedding_response(response, embeddings);
    ASSERT_TRUE(status.ok()) << status.to_string();
    ASSERT_EQ(embeddings.size(), 2);

    const auto& first = embeddings[0];
    const auto& second = embeddings[1];
    ASSERT_EQ(first.size(), 3);
    ASSERT_EQ(second.size(), 2);
    ASSERT_FLOAT_EQ(first[0], 0.1F);
    ASSERT_FLOAT_EQ(first[2], 0.3F);
    ASSERT_FLOAT_EQ(second[0], 0.4F);
    ASSERT_FLOAT_EQ(second[1], 0.5F);
}
} // namespace doris