| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| #include "exprs/function/ai/ai_adapter.h" |
| |
| #include <curl/curl.h> |
| #include <gen_cpp/PaloInternalService_types.h> |
| #include <gmock/gmock-matchers.h> |
| #include <gtest/gtest.h> |
| |
| #include <string> |
| #include <vector> |
| |
| #include "common/status.h" |
| #include "exprs/function/ai/ai_classify.h" |
| #include "exprs/function/ai/ai_extract.h" |
| #include "exprs/function/ai/ai_sentiment.h" |
| #include "exprs/function/ai/ai_summarize.h" |
| |
| namespace doris { |
| class MockHttpClient : public HttpClient { |
| public: |
| curl_slist* get() { return this->_header_list; } |
| |
| private: |
| std::unordered_map<std::string, std::string> _headers; |
| std::string _content_type; |
| }; |
| |
| TEST(AI_ADAPTER_TEST, local_adapter_request_chat_endpoint) { |
| LocalAdapter adapter; |
| TAIResource config; |
| config.model_name = "ollama"; |
| config.temperature = 0.7; |
| config.max_tokens = 128; |
| config.endpoint = "http://localhost:11434/api/chat"; |
| adapter.init(config); |
| |
| // header test |
| MockHttpClient mock_client; |
| Status auth_status = adapter.set_authentication(&mock_client); |
| ASSERT_TRUE(auth_status.ok()); |
| EXPECT_STREQ(mock_client.get()->data, "Content-Type: application/json"); |
| |
| std::vector<std::string> inputs = {"hello world"}; |
| std::string request_body; |
| Status st = |
| adapter.build_request_payload(inputs, FunctionAISummarize::system_prompt, request_body); |
| ASSERT_TRUE(st.ok()); |
| |
| // body test |
| rapidjson::Document doc; |
| doc.Parse(request_body.c_str()); |
| ASSERT_FALSE(doc.HasParseError()) << "JSON parse error"; |
| ASSERT_TRUE(doc.IsObject()) << "JSON is not an object"; |
| |
| // general flags |
| ASSERT_TRUE(doc.HasMember("model")) << "Missing model field"; |
| ASSERT_TRUE(doc["model"].IsString()) << "Model field is not a string"; |
| ASSERT_STREQ(doc["model"].GetString(), "ollama"); |
| ASSERT_TRUE(doc.HasMember("stream")) << "Missing stream field"; |
| ASSERT_TRUE(doc["stream"].IsBool()) << "Stream field is not a bool"; |
| ASSERT_FALSE(doc["stream"].GetBool()); |
| ASSERT_TRUE(doc.HasMember("think")) << "Missing think field"; |
| ASSERT_TRUE(doc["think"].IsBool()) << "Think field is not a bool"; |
| ASSERT_FALSE(doc["think"].GetBool()); |
| |
| // options (temperature + max_token) |
| ASSERT_FALSE(doc.HasMember("temperature")) << "Temperature should be nested in options"; |
| ASSERT_FALSE(doc.HasMember("max_tokens")) << "Max tokens should be nested in options"; |
| ASSERT_TRUE(doc.HasMember("options")) << "Missing options field"; |
| ASSERT_TRUE(doc["options"].IsObject()) << "Options is not an object"; |
| const auto& options = doc["options"]; |
| ASSERT_TRUE(options.HasMember("temperature")) << "Missing options.temperature field"; |
| ASSERT_TRUE(options["temperature"].IsNumber()) << "options.temperature is not a number"; |
| ASSERT_DOUBLE_EQ(options["temperature"].GetDouble(), 0.7); |
| ASSERT_TRUE(options.HasMember("max_token")) << "Missing options.max_token field"; |
| ASSERT_TRUE(options["max_token"].IsInt()) << "options.max_token is not an integer"; |
| ASSERT_EQ(options["max_token"].GetInt(), 128); |
| |
| // content |
| ASSERT_TRUE(doc.HasMember("messages")) << "Missing messages field"; |
| ASSERT_TRUE(doc["messages"].IsArray()) << "Messages is not an array"; |
| ASSERT_GE(doc["messages"].Size(), 2) << "Messages should contain system and user prompts"; |
| const auto& first_message = doc["messages"][0]; |
| ASSERT_TRUE(first_message.HasMember("role")) << "Message missing role field"; |
| ASSERT_TRUE(first_message["role"].IsString()) << "Role field is not a string"; |
| ASSERT_STREQ(first_message["role"].GetString(), "system"); |
| ASSERT_TRUE(first_message.HasMember("content")) << "Message missing content field"; |
| ASSERT_TRUE(first_message["content"].IsString()) << "Content field is not a string"; |
| ASSERT_STREQ(first_message["content"].GetString(), FunctionAISummarize::system_prompt); |
| |
| const auto& user_message = doc["messages"][doc["messages"].Size() - 1]; |
| ASSERT_TRUE(user_message.HasMember("role")) << "User message missing role field"; |
| ASSERT_TRUE(user_message["role"].IsString()) << "User role field is not a string"; |
| ASSERT_STREQ(user_message["role"].GetString(), "user"); |
| ASSERT_TRUE(user_message.HasMember("content")) << "User message missing content field"; |
| ASSERT_TRUE(user_message["content"].IsString()) << "User content field is not a string"; |
| ASSERT_STREQ(user_message["content"].GetString(), inputs[0].c_str()); |
| } |
| |
| TEST(AI_ADAPTER_TEST, local_adapter_request_generate_endpoint) { |
| LocalAdapter adapter; |
| TAIResource config; |
| config.model_name = "ollama"; |
| config.temperature = 0.8; |
| config.max_tokens = 64; |
| config.endpoint = "http://localhost:11434/api/generate"; |
| adapter.init(config); |
| |
| std::vector<std::string> inputs = {"hello world"}; |
| std::string request_body; |
| Status st = |
| adapter.build_request_payload(inputs, FunctionAISummarize::system_prompt, request_body); |
| ASSERT_TRUE(st.ok()); |
| |
| rapidjson::Document doc; |
| doc.Parse(request_body.c_str()); |
| ASSERT_FALSE(doc.HasParseError()) << "JSON parse error"; |
| ASSERT_TRUE(doc.IsObject()) << "JSON is not an object"; |
| |
| ASSERT_TRUE(doc.HasMember("model")) << "Missing model field"; |
| ASSERT_STREQ(doc["model"].GetString(), "ollama"); |
| ASSERT_TRUE(doc.HasMember("system")) << "Missing system field"; |
| ASSERT_TRUE(doc["system"].IsString()) << "System field is not a string"; |
| ASSERT_STREQ(doc["system"].GetString(), FunctionAISummarize::system_prompt); |
| ASSERT_TRUE(doc.HasMember("prompt")) << "Missing prompt field"; |
| ASSERT_TRUE(doc["prompt"].IsString()) << "Prompt field is not a string"; |
| ASSERT_STREQ(doc["prompt"].GetString(), inputs[0].c_str()); |
| |
| ASSERT_FALSE(doc.HasMember("messages")) << "Generate endpoint should not include messages"; |
| |
| ASSERT_TRUE(doc.HasMember("options")) << "Missing options field"; |
| ASSERT_TRUE(doc["options"].IsObject()) << "Options is not an object"; |
| const auto& options = doc["options"]; |
| ASSERT_TRUE(options.HasMember("temperature")) << "Missing options.temperature field"; |
| ASSERT_DOUBLE_EQ(options["temperature"].GetDouble(), 0.8); |
| ASSERT_TRUE(options.HasMember("max_token")) << "Missing options.max_token field"; |
| ASSERT_EQ(options["max_token"].GetInt(), 64); |
| } |
| |
| TEST(AI_ADAPTER_TEST, local_adapter_parse_response) { |
| LocalAdapter adapter; |
| // OpenAI type |
| std::string resp = R"({"choices":[{"message":{"content":"hi"}}]})"; |
| std::vector<std::string> results; |
| Status st = adapter.parse_response(resp, results); |
| ASSERT_TRUE(st.ok()); |
| ASSERT_EQ(results.size(), 1); |
| ASSERT_EQ(results[0], "hi"); |
| |
| // Simple text type |
| resp = R"({"text":"simple result"})"; |
| results.clear(); |
| st = adapter.parse_response(resp, results); |
| ASSERT_TRUE(st.ok()); |
| ASSERT_EQ(results.size(), 1); |
| ASSERT_EQ(results[0], "simple result"); |
| |
| resp = R"({"content":"simple result"})"; |
| results.clear(); |
| st = adapter.parse_response(resp, results); |
| ASSERT_TRUE(st.ok()); |
| ASSERT_EQ(results.size(), 1); |
| ASSERT_EQ(results[0], "simple result"); |
| |
| // Ollama response type |
| resp = R"({"response":"ollama result"})"; |
| results.clear(); |
| st = adapter.parse_response(resp, results); |
| ASSERT_TRUE(st.ok()); |
| ASSERT_EQ(results.size(), 1); |
| ASSERT_EQ(results[0], "ollama result"); |
| |
| // Ollama chat message type |
| resp = R"({"message":{"content":"ollama chat"}})"; |
| results.clear(); |
| st = adapter.parse_response(resp, results); |
| ASSERT_TRUE(st.ok()); |
| ASSERT_EQ(results.size(), 1); |
| ASSERT_EQ(results[0], "ollama chat"); |
| } |
| |
| TEST(AI_ADAPTER_TEST, local_adapter_request_default_endpoint) { |
| LocalAdapter adapter; |
| TAIResource config; |
| config.model_name = "local-default"; |
| config.temperature = 0.3; |
| config.max_tokens = 42; |
| config.endpoint = "http://localhost:8000/v1/completions"; |
| adapter.init(config); |
| |
| std::vector<std::string> inputs = {"default prompt"}; |
| std::string request_body; |
| Status st = |
| adapter.build_request_payload(inputs, FunctionAISummarize::system_prompt, request_body); |
| ASSERT_TRUE(st.ok()) << st.to_string(); |
| |
| rapidjson::Document doc; |
| doc.Parse(request_body.c_str()); |
| ASSERT_FALSE(doc.HasParseError()) << "JSON parse error"; |
| ASSERT_TRUE(doc.IsObject()) << "JSON is not an object"; |
| |
| ASSERT_TRUE(doc.HasMember("model")); |
| ASSERT_STREQ(doc["model"].GetString(), "local-default"); |
| |
| ASSERT_TRUE(doc.HasMember("temperature")); |
| ASSERT_DOUBLE_EQ(doc["temperature"].GetDouble(), 0.3); |
| |
| ASSERT_TRUE(doc.HasMember("max_tokens")); |
| ASSERT_EQ(doc["max_tokens"].GetInt(), 42); |
| |
| ASSERT_TRUE(doc.HasMember("messages")); |
| ASSERT_TRUE(doc["messages"].IsArray()); |
| ASSERT_GE(doc["messages"].Size(), 2); |
| |
| const auto& system_msg = doc["messages"][0]; |
| ASSERT_TRUE(system_msg.HasMember("role")); |
| ASSERT_STREQ(system_msg["role"].GetString(), "system"); |
| ASSERT_TRUE(system_msg.HasMember("content")); |
| ASSERT_STREQ(system_msg["content"].GetString(), FunctionAISummarize::system_prompt); |
| |
| const auto& user_msg = doc["messages"][doc["messages"].Size() - 1]; |
| ASSERT_TRUE(user_msg.HasMember("role")); |
| ASSERT_STREQ(user_msg["role"].GetString(), "user"); |
| ASSERT_TRUE(user_msg.HasMember("content")); |
| ASSERT_STREQ(user_msg["content"].GetString(), inputs[0].c_str()); |
| } |
| |
| TEST(AI_ADAPTER_TEST, openai_adapter_completions_request) { |
| OpenAIAdapter adapter; |
| TAIResource config; |
| config.model_name = "gpt-3.5-turbo"; |
| config.temperature = 0.5; |
| config.max_tokens = 64; |
| config.api_key = "test_openai_key"; |
| adapter.init(config); |
| |
| // header |
| MockHttpClient mock_client; |
| Status auth_status = adapter.set_authentication(&mock_client); |
| ASSERT_TRUE(auth_status.ok()); |
| |
| EXPECT_STREQ(mock_client.get()->data, "Authorization: Bearer test_openai_key"); |
| EXPECT_STREQ(mock_client.get()->next->data, "Content-Type: application/json"); |
| |
| std::vector<std::string> inputs = {"hi openai"}; |
| std::string request_body; |
| Status st = |
| adapter.build_request_payload(inputs, FunctionAISentiment::system_prompt, request_body); |
| ASSERT_TRUE(st.ok()); |
| |
| // body |
| rapidjson::Document doc; |
| doc.Parse(request_body.c_str()); |
| ASSERT_FALSE(doc.HasParseError()) << "JSON parse error"; |
| ASSERT_TRUE(doc.IsObject()) << "JSON is not an object"; |
| |
| // model |
| ASSERT_TRUE(doc.HasMember("model")) << "Missing model field"; |
| ASSERT_TRUE(doc["model"].IsString()) << "Model field is not a string"; |
| ASSERT_STREQ(doc["model"].GetString(), "gpt-3.5-turbo"); |
| |
| // temperature |
| ASSERT_TRUE(doc.HasMember("temperature")) << "Missing temperature field"; |
| ASSERT_TRUE(doc["temperature"].IsNumber()) << "Temperature field is not a number"; |
| ASSERT_DOUBLE_EQ(doc["temperature"].GetDouble(), 0.5); |
| |
| // max token |
| ASSERT_TRUE(doc.HasMember("max_tokens")) << "Missing max_tokens field"; |
| ASSERT_TRUE(doc["max_tokens"].IsInt()) << "Max_tokens field is not an integer"; |
| ASSERT_EQ(doc["max_tokens"].GetInt(), 64); |
| // msg |
| ASSERT_TRUE(doc.HasMember("messages")) << "Missing messages field"; |
| ASSERT_TRUE(doc["messages"].IsArray()) << "Messages is not an array"; |
| ASSERT_GT(doc["messages"].Size(), 0) << "Messages array is empty"; |
| |
| // system_prompt |
| const auto& first_message = doc["messages"][0]; |
| ASSERT_TRUE(first_message.HasMember("role")) << "Message missing role field"; |
| ASSERT_TRUE(first_message["role"].IsString()) << "Role field is not a string"; |
| ASSERT_STREQ(first_message["role"].GetString(), "system"); |
| ASSERT_STREQ(first_message["content"].GetString(), FunctionAISentiment::system_prompt); |
| |
| // The content of the last message |
| const auto& last_message = doc["messages"][doc["messages"].Size() - 1]; |
| ASSERT_TRUE(last_message.HasMember("content")) << "Message missing content field"; |
| ASSERT_TRUE(last_message["content"].IsString()) << "Content field is not a string"; |
| ASSERT_STREQ(last_message["content"].GetString(), inputs[0].c_str()); |
| } |
| |
| TEST(AI_ADAPTER_TEST, openai_adapter_completions_parse_response) { |
| OpenAIAdapter adapter; |
| std::string resp = R"({"choices":[{"message":{"content":"openai result"}}]})"; |
| std::vector<std::string> results; |
| Status st = adapter.parse_response(resp, results); |
| ASSERT_TRUE(st.ok()); |
| ASSERT_EQ(results.size(), 1); |
| ASSERT_EQ(results[0], "openai result"); |
| } |
| |
| TEST(AI_ADAPTER_TEST, openai_adatper_responses_request) { |
| OpenAIAdapter adapter; |
| TAIResource config; |
| config.model_name = "gpt-5"; |
| config.temperature = 0.5; |
| config.max_tokens = 64; |
| config.api_key = "test_openai_key"; |
| config.endpoint = "https://api.openai.com/v1/responses"; |
| adapter.init(config); |
| |
| // header |
| MockHttpClient mock_client; |
| Status auth_status = adapter.set_authentication(&mock_client); |
| ASSERT_TRUE(auth_status.ok()); |
| |
| EXPECT_STREQ(mock_client.get()->data, "Authorization: Bearer test_openai_key"); |
| EXPECT_STREQ(mock_client.get()->next->data, "Content-Type: application/json"); |
| |
| std::vector<std::string> inputs = {"hi openai"}; |
| std::string request_body; |
| Status st = |
| adapter.build_request_payload(inputs, FunctionAISentiment::system_prompt, request_body); |
| ASSERT_TRUE(st.ok()); |
| |
| // body |
| rapidjson::Document doc; |
| doc.Parse(request_body.c_str()); |
| ASSERT_FALSE(doc.HasParseError()) << "JSON parse error"; |
| ASSERT_TRUE(doc.IsObject()) << "JSON is not an object"; |
| |
| // model |
| ASSERT_TRUE(doc.HasMember("model")) << "Missing model field"; |
| ASSERT_TRUE(doc["model"].IsString()) << "Model field is not a string"; |
| ASSERT_STREQ(doc["model"].GetString(), "gpt-5"); |
| |
| // temperature |
| ASSERT_TRUE(doc.HasMember("temperature")) << "Missing temperature field"; |
| ASSERT_TRUE(doc["temperature"].IsNumber()) << "Temperature field is not a number"; |
| ASSERT_DOUBLE_EQ(doc["temperature"].GetDouble(), 0.5); |
| |
| // max tokens |
| ASSERT_TRUE(doc.HasMember("max_output_tokens")) << "Missing max_output_tokens field"; |
| ASSERT_TRUE(doc["max_output_tokens"].IsInt()) << "max_output_tokens field is not an integer"; |
| ASSERT_EQ(doc["max_output_tokens"].GetInt(), 64); |
| |
| // input |
| ASSERT_TRUE(doc.HasMember("input")) << "Missing input field"; |
| ASSERT_TRUE(doc["input"].IsArray()) << "Input is not an array"; |
| ASSERT_GT(doc["input"].Size(), 0) << "Input array is empty"; |
| |
| // system_prompt |
| const auto& input = doc["input"]; |
| ASSERT_TRUE(input[0].HasMember("role")) << request_body; |
| ASSERT_TRUE(input[0]["role"].IsString()) << "Role field is not a string"; |
| ASSERT_STREQ(input[0]["role"].GetString(), "system"); |
| ASSERT_TRUE(input[0].HasMember("content")) << request_body; |
| ASSERT_TRUE(input[0]["content"].IsString()) << "Content field is not a string"; |
| ASSERT_STREQ(input[0]["content"].GetString(), FunctionAISentiment::system_prompt); |
| |
| // input content |
| ASSERT_TRUE(input[1].HasMember("role")) << request_body; |
| ASSERT_TRUE(input[1]["role"].IsString()) << "Role field is not a string"; |
| ASSERT_STREQ(input[1]["role"].GetString(), "user"); |
| ASSERT_TRUE(input[1].HasMember("content")) << request_body; |
| ASSERT_TRUE(input[1]["content"].IsString()) << "Content field is not a string"; |
| ASSERT_STREQ(input[1]["content"].GetString(), inputs[0].c_str()); |
| } |
| |
| TEST(AI_ADAPTER_TEST, openai_adapter_responses_parse_response) { |
| OpenAIAdapter adapter; |
| std::string resp = R"({"output":[{"content":[{"text":"openai response result"}]}]})"; |
| std::vector<std::string> results; |
| Status st = adapter.parse_response(resp, results); |
| ASSERT_TRUE(st.ok()); |
| ASSERT_EQ(results.size(), 1); |
| ASSERT_EQ(results[0], "openai response result"); |
| } |
| |
| TEST(AI_ADAPTER_TEST, openai_adapter_parse_response_keeps_mask_literals) { |
| OpenAIAdapter adapter; |
| std::string resp = R"({"choices":[{"message":{"content":"[MSKED]"}}]})"; |
| std::vector<std::string> results; |
| Status st = adapter.parse_response(resp, results); |
| ASSERT_TRUE(st.ok()) << st.to_string(); |
| ASSERT_EQ(results.size(), 1); |
| ASSERT_EQ(results[0], "[MSKED]"); |
| |
| resp = R"({"choices":[{"message":{"content":"[MASK]"}}]})"; |
| results.clear(); |
| st = adapter.parse_response(resp, results); |
| ASSERT_TRUE(st.ok()) << st.to_string(); |
| ASSERT_EQ(results.size(), 1); |
| ASSERT_EQ(results[0], "[MASK]"); |
| } |
| |
| TEST(AI_ADAPTER_TEST, gemini_adapter_request) { |
| GeminiAdapter adapter; |
| TAIResource config; |
| config.temperature = 0.2; |
| config.max_tokens = 32; |
| config.api_key = "test_gemini_key"; |
| adapter.init(config); |
| |
| // header test |
| MockHttpClient mock_client; |
| Status auth_status = adapter.set_authentication(&mock_client); |
| ASSERT_TRUE(auth_status.ok()); |
| |
| EXPECT_STREQ(mock_client.get()->data, "x-goog-api-key: test_gemini_key"); |
| EXPECT_STREQ(mock_client.get()->next->data, "Content-Type: application/json"); |
| |
| std::vector<std::string> inputs = {"hello gemini"}; |
| std::string request_body; |
| Status st = |
| adapter.build_request_payload(inputs, FunctionAIExtract::system_prompt, request_body); |
| ASSERT_TRUE(st.ok()); |
| |
| // body test |
| rapidjson::Document doc; |
| doc.Parse(request_body.c_str()); |
| ASSERT_FALSE(doc.HasParseError()) << "JSON parse error"; |
| ASSERT_TRUE(doc.IsObject()) << "JSON is not an object"; |
| |
| ASSERT_TRUE(doc.HasMember("generationConfig")); |
| const auto& gen_cfg = doc["generationConfig"]; |
| |
| //temperature |
| ASSERT_TRUE(gen_cfg.HasMember("temperature")); |
| ASSERT_TRUE(gen_cfg["temperature"].IsNumber()); |
| ASSERT_DOUBLE_EQ(gen_cfg["temperature"].GetDouble(), 0.2); |
| |
| //max_token |
| ASSERT_TRUE(gen_cfg.HasMember("maxOutputTokens")) << "Missing maxOutputTokens field"; |
| ASSERT_TRUE(gen_cfg["maxOutputTokens"].IsInt()); |
| ASSERT_EQ(gen_cfg["maxOutputTokens"].GetInt(), 32); |
| |
| // system_prompt |
| ASSERT_TRUE(doc.HasMember("systemInstruction")) << "Missing system field"; |
| ASSERT_TRUE(doc["systemInstruction"].IsObject()) << request_body; |
| ASSERT_TRUE(doc["systemInstruction"].HasMember("parts")) << request_body; |
| ASSERT_TRUE(doc["systemInstruction"]["parts"].IsArray()) << request_body; |
| ASSERT_GT(doc["systemInstruction"]["parts"].Size(), 0) << request_body; |
| |
| const auto& content_sys = doc["systemInstruction"]["parts"][0]; |
| ASSERT_TRUE(content_sys.HasMember("text")) << "parts missing text field"; |
| ASSERT_TRUE(content_sys["text"].IsString()) << "Text field is not a string"; |
| ASSERT_STREQ(content_sys["text"].GetString(), FunctionAIExtract::system_prompt); |
| |
| // content structure |
| ASSERT_TRUE(doc.HasMember("contents")) << "Missing contents field"; |
| ASSERT_TRUE(doc["contents"].IsArray()) << "Contents is not an array"; |
| ASSERT_GT(doc["contents"].Size(), 0) << "Contents array is empty"; |
| |
| // content |
| const auto& content = doc["contents"][0]; |
| ASSERT_TRUE(content.HasMember("parts")) << "Content missing parts field"; |
| ASSERT_TRUE(content["parts"].IsArray()) << "Parts is not an array"; |
| ASSERT_GT(content["parts"].Size(), 0) << "Parts array is empty"; |
| |
| const auto& part = content["parts"][0]; |
| ASSERT_TRUE(part.HasMember("text")) << "Part missing text field"; |
| ASSERT_TRUE(part["text"].IsString()) << "Text field is not a string"; |
| ASSERT_STREQ(part["text"].GetString(), "hello gemini"); |
| } |
| |
| TEST(AI_ADAPTER_TEST, gemini_adapter_parse_response) { |
| GeminiAdapter adapter; |
| std::string resp = R"({"candidates":[{"content":{"parts":[{"text":"gemini result"}]}}]})"; |
| std::vector<std::string> results; |
| Status st = adapter.parse_response(resp, results); |
| ASSERT_TRUE(st.ok()); |
| ASSERT_EQ(results.size(), 1); |
| ASSERT_EQ(results[0], "gemini result"); |
| } |
| |
| TEST(AI_ADAPTER_TEST, anthropic_adapter_request) { |
| AnthropicAdapter adapter; |
| TAIResource config; |
| config.model_name = "claude-3"; |
| config.temperature = 1.0; |
| config.max_tokens = 256; |
| config.api_key = "test_anthropic_key"; |
| config.anthropic_version = "2023-06-01"; |
| adapter.init(config); |
| |
| // header |
| MockHttpClient mock_client; |
| Status auth_status = adapter.set_authentication(&mock_client); |
| ASSERT_TRUE(auth_status.ok()); |
| |
| EXPECT_STREQ(mock_client.get()->data, "x-api-key: test_anthropic_key"); |
| EXPECT_STREQ(mock_client.get()->next->data, "anthropic-version: 2023-06-01"); |
| EXPECT_STREQ(mock_client.get()->next->next->data, "Content-Type: application/json"); |
| |
| std::vector<std::string> inputs = {"hi anthropic"}; |
| std::string request_body; |
| Status st = |
| adapter.build_request_payload(inputs, FunctionAIClassify::system_prompt, request_body); |
| ASSERT_TRUE(st.ok()); |
| |
| // body |
| rapidjson::Document doc; |
| doc.Parse(request_body.c_str()); |
| ASSERT_FALSE(doc.HasParseError()) << "JSON parse error"; |
| ASSERT_TRUE(doc.IsObject()) << "JSON is not an object"; |
| |
| // model |
| ASSERT_TRUE(doc.HasMember("model")) << "Missing model field"; |
| ASSERT_TRUE(doc["model"].IsString()) << "Model field is not a string"; |
| ASSERT_STREQ(doc["model"].GetString(), "claude-3"); |
| |
| // temperature |
| ASSERT_TRUE(doc.HasMember("temperature")) << "Missing temperature field"; |
| ASSERT_TRUE(doc["temperature"].IsNumber()) << "Temperature field is not a number"; |
| ASSERT_DOUBLE_EQ(doc["temperature"].GetDouble(), 1.0); |
| |
| // max token |
| ASSERT_TRUE(doc.HasMember("max_tokens")) << "Missing max_tokens field"; |
| ASSERT_TRUE(doc["max_tokens"].IsInt()) << "Max_tokens field is not an integer"; |
| ASSERT_EQ(doc["max_tokens"].GetInt(), 256); |
| |
| // system_prompt |
| ASSERT_TRUE(doc.HasMember("system")) << "Missing system field"; |
| ASSERT_TRUE(doc["system"].IsString()) << "System field is not a string"; |
| ASSERT_STREQ(doc["system"].GetString(), FunctionAIClassify::system_prompt); |
| |
| // message Format |
| ASSERT_TRUE(doc.HasMember("messages")) << "Missing messages field"; |
| ASSERT_TRUE(doc["messages"].IsArray()) << "Messages is not an array"; |
| ASSERT_GT(doc["messages"].Size(), 0) << "Messages array is empty"; |
| |
| const auto& message = doc["messages"][0]; |
| ASSERT_TRUE(message.HasMember("role")) << "Message missing role field"; |
| ASSERT_TRUE(message.HasMember("content")) << "Message missing content field"; |
| |
| // content of the last message |
| if (message["content"].IsArray()) { |
| ASSERT_GT(message["content"].Size(), 0) << "Content array is empty"; |
| const auto& content = message["content"][0]; |
| ASSERT_TRUE(content.HasMember("type")) << "Content missing type field"; |
| ASSERT_TRUE(content.HasMember("text")) << "Content missing text field"; |
| ASSERT_STREQ(content["text"].GetString(), "hi anthropic"); |
| } else if (message["content"].IsString()) { |
| ASSERT_STREQ(message["content"].GetString(), "hi anthropic"); |
| } |
| } |
| |
| TEST(AI_ADAPTER_TEST, anthropic_adapter_parse_response) { |
| AnthropicAdapter adapter; |
| std::string resp = R"({"content":[{"type":"text","text":"anthropic result"}]})"; |
| std::vector<std::string> results; |
| Status st = adapter.parse_response(resp, results); |
| ASSERT_TRUE(st.ok()); |
| ASSERT_EQ(results.size(), 1); |
| ASSERT_EQ(results[0], "anthropic result"); |
| } |
| |
| TEST(AI_ADAPTER_TEST, unsupported_provider_type) { |
| TAIResource config; |
| config.provider_type = "not_exist"; |
| auto adapter = doris::AIAdapterFactory::create_adapter(config.provider_type); |
| ASSERT_EQ(adapter, nullptr); |
| } |
| |
| TEST(AI_ADAPTER_TEST, adapter_factory_all_types) { |
| std::vector<std::string> types = {"LOCAL", "OPENAI", "MOONSHOT", "DEEPSEEK", "MINIMAX", |
| "ZHIPU", "QWEN", "JINA", "BAICHUAN", "ANTHROPIC", |
| "GEMINI", "VOYAGEAI", "MOCK"}; |
| for (const auto& type : types) { |
| auto adapter = doris::AIAdapterFactory::create_adapter(type); |
| ASSERT_TRUE(adapter != nullptr) << "Adapter not found for type: " << type; |
| } |
| } |
| |
| TEST(AI_ADAPTER_TEST, local_adapter_parse_response_parse_error) { |
| LocalAdapter adapter; |
| std::string resp = "not a json"; |
| std::vector<std::string> results; |
| Status st = adapter.parse_response(resp, results); |
| ASSERT_FALSE(st.ok()); |
| } |
| |
| TEST(AI_ADAPTER_TEST, parse_response_wrong_type) { |
| LocalAdapter adapter; |
| // response field is not a string |
| std::string resp = R"({"response":123})"; |
| std::vector<std::string> results; |
| Status st = adapter.parse_response(resp, results); |
| ASSERT_FALSE(st.ok()); |
| EXPECT_THAT(st.to_string().c_str(), |
| ::testing::HasSubstr("Unsupported response format from local AI.")); |
| } |
| |
| TEST(AI_ADAPTER_TEST, openai_adapter_parse_response_choice_format_error) { |
| OpenAIAdapter adapter; |
| // message field missing |
| std::string resp = R"({"choices":[{}]})"; |
| std::vector<std::string> results; |
| Status st = adapter.parse_response(resp, results); |
| ASSERT_FALSE(st.ok()); |
| EXPECT_THAT(st.to_string().c_str(), ::testing::HasSubstr("Invalid choice format in response")); |
| |
| // content field is not a string |
| resp = R"({"choices":[{"message":{"content":123}}]})"; |
| results.clear(); |
| st = adapter.parse_response(resp, results); |
| ASSERT_FALSE(st.ok()); |
| EXPECT_THAT(st.to_string().c_str(), ::testing::HasSubstr("Invalid choice format in response")); |
| } |
| |
| TEST(AI_ADAPTER_TEST, openai_adapter_parse_response_parse_error) { |
| OpenAIAdapter adapter; |
| std::string resp = "not a json"; |
| std::vector<std::string> results; |
| Status st = adapter.parse_response(resp, results); |
| ASSERT_FALSE(st.ok()); |
| EXPECT_THAT(st.to_string().c_str(), ::testing::HasSubstr("Failed to parse")); |
| } |
| |
| TEST(AI_ADAPTER_TEST, openai_adapter_parse_response_choices_not_array) { |
| OpenAIAdapter adapter; |
| std::string resp = R"({"choices":123})"; |
| std::vector<std::string> results; |
| Status st = adapter.parse_response(resp, results); |
| ASSERT_FALSE(st.ok()); |
| EXPECT_THAT(st.to_string().c_str(), ::testing::HasSubstr("Invalid response format")); |
| } |
| |
| TEST(AI_ADAPTER_TEST, gemini_adapter_parse_response_parse_error) { |
| GeminiAdapter adapter; |
| std::string resp = "not a json"; |
| std::vector<std::string> results; |
| Status st = adapter.parse_response(resp, results); |
| ASSERT_FALSE(st.ok()); |
| EXPECT_THAT(st.to_string().c_str(), ::testing::HasSubstr("Failed to parse")); |
| } |
| |
| TEST(AI_ADAPTER_TEST, gemini_parse_response_missing_candidates) { |
| GeminiAdapter adapter; |
| std::string resp = R"({"foo":"bar"})"; |
| std::vector<std::string> results; |
| Status st = adapter.parse_response(resp, results); |
| ASSERT_FALSE(st.ok()); |
| EXPECT_THAT(st.to_string().c_str(), ::testing::HasSubstr("Invalid response format")); |
| } |
| |
| TEST(AI_ADAPTER_TEST, anthropic_adapter_parse_response_parse_error) { |
| AnthropicAdapter adapter; |
| std::string resp = "not a json"; |
| std::vector<std::string> results; |
| Status st = adapter.parse_response(resp, results); |
| ASSERT_FALSE(st.ok()); |
| EXPECT_THAT(st.to_string().c_str(), ::testing::HasSubstr("Failed to parse")); |
| } |
| |
| TEST(AI_ADAPTER_TEST, anthropic_adapter_parse_response_content_not_array) { |
| AnthropicAdapter adapter; |
| std::string resp = R"({"content":123})"; |
| std::vector<std::string> results; |
| Status st = adapter.parse_response(resp, results); |
| ASSERT_FALSE(st.ok()); |
| EXPECT_THAT(st.to_string().c_str(), ::testing::HasSubstr("Invalid response format")); |
| } |
| |
| TEST(AI_ADAPTER_TEST, voyage_adapter_chat_test) { |
| VoyageAIAdapter adapter; |
| TAIResource config; |
| config.api_key = "test_voyage_key"; |
| config.provider_type = "VoyageAI"; |
| |
| adapter.init(config); |
| MockHttpClient mock_client; |
| Status auth_status = adapter.set_authentication(&mock_client); |
| ASSERT_TRUE(auth_status.ok()); |
| EXPECT_STREQ(mock_client.get()->data, "Authorization: Bearer test_voyage_key"); |
| EXPECT_STREQ(mock_client.get()->next->data, "Content-Type: application/json"); |
| |
| std::vector<std::string> inputs = {"test_inputs"}; |
| std::string request_body; |
| Status st = adapter.build_request_payload(inputs, "test_system_prompt", request_body); |
| ASSERT_FALSE(st.ok()); |
| ASSERT_STREQ(st.to_string().c_str(), |
| "[NOT_IMPLEMENTED_ERROR]VoyageAI don't support text generation"); |
| |
| std::string response_body = "test_response_body"; |
| std::vector<std::string> result; |
| st = adapter.parse_response(response_body, result); |
| ASSERT_FALSE(st.ok()); |
| ASSERT_STREQ(st.to_string().c_str(), |
| "[NOT_IMPLEMENTED_ERROR]VoyageAI don't support text generation"); |
| } |
| |
| TEST(AI_ADAPTER_TEST, qwen_multimodal_embedding_request_image) { |
| QwenAdapter adapter; |
| TAIResource config; |
| config.model_name = "tongyi-embedding-vision-plus"; |
| config.dimensions = 1024; |
| adapter.init(config); |
| |
| std::string request_body; |
| Status st = adapter.build_multimodal_embedding_request({MultimodalType::IMAGE}, |
| {"https://a/b/c.png"}, {}, request_body); |
| ASSERT_TRUE(st.ok()) << st.to_string(); |
| |
| rapidjson::Document doc; |
| doc.Parse(request_body.c_str()); |
| ASSERT_FALSE(doc.HasParseError()) << request_body; |
| ASSERT_TRUE(doc.IsObject()); |
| ASSERT_TRUE(doc.HasMember("model")); |
| ASSERT_STREQ(doc["model"].GetString(), "tongyi-embedding-vision-plus"); |
| ASSERT_TRUE(doc.HasMember("input")); |
| ASSERT_TRUE(doc["input"].HasMember("contents")); |
| const auto& contents = doc["input"]["contents"]; |
| ASSERT_TRUE(contents.IsArray()); |
| ASSERT_EQ(contents.Size(), 1); |
| ASSERT_TRUE(contents[0].HasMember("image")); |
| ASSERT_STREQ(contents[0]["image"].GetString(), "https://a/b/c.png"); |
| ASSERT_TRUE(doc.HasMember("parameters")); |
| ASSERT_TRUE(doc["parameters"].IsObject()); |
| ASSERT_TRUE(doc["parameters"].HasMember("dimension")); |
| ASSERT_EQ(doc["parameters"]["dimension"].GetInt(), 1024); |
| } |
| |
| TEST(AI_ADAPTER_TEST, qwen_multimodal_embedding_request_video) { |
| QwenAdapter adapter; |
| TAIResource config; |
| config.model_name = "tongyi-embedding-vision-plus"; |
| config.dimensions = 1024; |
| adapter.init(config); |
| |
| std::string request_body; |
| Status st = adapter.build_multimodal_embedding_request({MultimodalType::VIDEO}, |
| {"https://a/b/c.mp4"}, {}, request_body); |
| ASSERT_TRUE(st.ok()) << st.to_string(); |
| |
| rapidjson::Document doc; |
| doc.Parse(request_body.c_str()); |
| ASSERT_FALSE(doc.HasParseError()) << request_body; |
| ASSERT_TRUE(doc.IsObject()); |
| ASSERT_TRUE(doc.HasMember("model")); |
| ASSERT_STREQ(doc["model"].GetString(), "tongyi-embedding-vision-plus"); |
| ASSERT_TRUE(doc.HasMember("input")); |
| ASSERT_TRUE(doc["input"].HasMember("contents")); |
| const auto& contents = doc["input"]["contents"]; |
| ASSERT_TRUE(contents.IsArray()); |
| ASSERT_EQ(contents.Size(), 1); |
| ASSERT_TRUE(contents[0].HasMember("video")); |
| ASSERT_STREQ(contents[0]["video"].GetString(), "https://a/b/c.mp4"); |
| ASSERT_TRUE(doc.HasMember("parameters")); |
| ASSERT_TRUE(doc["parameters"].IsObject()); |
| ASSERT_TRUE(doc["parameters"].HasMember("dimension")); |
| ASSERT_EQ(doc["parameters"]["dimension"].GetInt(), 1024); |
| } |
| |
| TEST(AI_ADAPTER_TEST, qwen_multimodal_embedding_batch_request) { |
| QwenAdapter adapter; |
| TAIResource config; |
| config.model_name = "tongyi-embedding-vision-plus"; |
| config.dimensions = 1024; |
| adapter.init(config); |
| |
| std::vector<MultimodalType> media_types = {MultimodalType::IMAGE, MultimodalType::VIDEO}; |
| std::vector<std::string> media_urls = {"https://a/b/c.png", "https://a/b/c.mp4"}; |
| std::string request_body; |
| Status st = |
| adapter.build_multimodal_embedding_request(media_types, media_urls, {}, request_body); |
| ASSERT_TRUE(st.ok()) << st.to_string(); |
| |
| rapidjson::Document doc; |
| doc.Parse(request_body.c_str()); |
| ASSERT_FALSE(doc.HasParseError()) << request_body; |
| ASSERT_TRUE(doc.IsObject()); |
| ASSERT_TRUE(doc.HasMember("input")); |
| ASSERT_TRUE(doc["input"].HasMember("contents")); |
| const auto& contents = doc["input"]["contents"]; |
| ASSERT_TRUE(contents.IsArray()); |
| ASSERT_EQ(contents.Size(), 2); |
| ASSERT_TRUE(contents[0].HasMember("image")); |
| ASSERT_STREQ(contents[0]["image"].GetString(), "https://a/b/c.png"); |
| ASSERT_TRUE(contents[1].HasMember("video")); |
| ASSERT_STREQ(contents[1]["video"].GetString(), "https://a/b/c.mp4"); |
| } |
| |
| TEST(AI_ADAPTER_TEST, qwen_multimodal_embedding_request_audio_not_supported) { |
| QwenAdapter adapter; |
| TAIResource config; |
| config.model_name = "tongyi-embedding-vision-plus"; |
| adapter.init(config); |
| |
| std::string request_body; |
| Status st = adapter.build_multimodal_embedding_request({MultimodalType::AUDIO}, |
| {"https://a/b/c.mp3"}, {}, request_body); |
| ASSERT_FALSE(st.ok()); |
| ASSERT_THAT(st.to_string(), |
| ::testing::HasSubstr("QWEN only supports image/video multimodal embed")); |
| ASSERT_THAT(st.to_string(), ::testing::HasSubstr("audio")); |
| } |
| |
| TEST(AI_ADAPTER_TEST, voyage_multimodal_embedding_request) { |
| VoyageAIAdapter adapter; |
| TAIResource config; |
| config.model_name = "voyage-multimodal-3.5"; |
| config.dimensions = 2048; |
| adapter.init(config); |
| |
| std::string request_body; |
| Status st = adapter.build_multimodal_embedding_request({MultimodalType::VIDEO}, |
| {"https://a/b/c.mp4"}, {}, request_body); |
| ASSERT_TRUE(st.ok()) << st.to_string(); |
| |
| rapidjson::Document doc; |
| doc.Parse(request_body.c_str()); |
| ASSERT_FALSE(doc.HasParseError()) << request_body; |
| ASSERT_TRUE(doc.IsObject()); |
| ASSERT_TRUE(doc.HasMember("inputs")); |
| const auto& inputs = doc["inputs"]; |
| ASSERT_TRUE(inputs.IsArray()); |
| ASSERT_EQ(inputs.Size(), 1); |
| ASSERT_TRUE(inputs[0].HasMember("content")); |
| const auto& content = inputs[0]["content"]; |
| ASSERT_TRUE(content.IsArray()); |
| ASSERT_EQ(content.Size(), 1); |
| ASSERT_TRUE(content[0].HasMember("type")); |
| ASSERT_STREQ(content[0]["type"].GetString(), "video_url"); |
| ASSERT_TRUE(content[0].HasMember("video_url")); |
| ASSERT_STREQ(content[0]["video_url"].GetString(), "https://a/b/c.mp4"); |
| ASSERT_FALSE(doc.HasMember("dimensions")); |
| ASSERT_FALSE(doc.HasMember("output_dimension")); |
| } |
| |
| TEST(AI_ADAPTER_TEST, voyage_multimodal_embedding_batch_request) { |
| VoyageAIAdapter adapter; |
| TAIResource config; |
| config.model_name = "voyage-multimodal-3.5"; |
| adapter.init(config); |
| |
| std::vector<MultimodalType> media_types = {MultimodalType::IMAGE, MultimodalType::VIDEO}; |
| std::vector<std::string> media_urls = {"https://a/b/c.png", "https://a/b/c.mp4"}; |
| std::string request_body; |
| Status st = |
| adapter.build_multimodal_embedding_request(media_types, media_urls, {}, request_body); |
| ASSERT_TRUE(st.ok()) << st.to_string(); |
| |
| rapidjson::Document doc; |
| doc.Parse(request_body.c_str()); |
| ASSERT_FALSE(doc.HasParseError()) << request_body; |
| ASSERT_TRUE(doc.IsObject()); |
| ASSERT_TRUE(doc.HasMember("inputs")); |
| const auto& request_inputs = doc["inputs"]; |
| ASSERT_TRUE(request_inputs.IsArray()); |
| ASSERT_EQ(request_inputs.Size(), 2); |
| ASSERT_TRUE(request_inputs[0]["content"].IsArray()); |
| ASSERT_STREQ(request_inputs[0]["content"][0]["type"].GetString(), "image_url"); |
| ASSERT_STREQ(request_inputs[0]["content"][0]["image_url"].GetString(), "https://a/b/c.png"); |
| ASSERT_TRUE(request_inputs[1]["content"].IsArray()); |
| ASSERT_STREQ(request_inputs[1]["content"][0]["type"].GetString(), "video_url"); |
| ASSERT_STREQ(request_inputs[1]["content"][0]["video_url"].GetString(), "https://a/b/c.mp4"); |
| } |
| |
| TEST(AI_ADAPTER_TEST, jina_multimodal_embedding_request) { |
| JinaAdapter adapter; |
| TAIResource config; |
| config.model_name = "jina-embeddings-v4"; |
| config.dimensions = 512; |
| adapter.init(config); |
| |
| std::string request_body; |
| Status st = adapter.build_multimodal_embedding_request({MultimodalType::IMAGE}, |
| {"https://a/b/c.jpg"}, {}, request_body); |
| ASSERT_TRUE(st.ok()) << st.to_string(); |
| |
| rapidjson::Document doc; |
| doc.Parse(request_body.c_str()); |
| ASSERT_FALSE(doc.HasParseError()) << request_body; |
| ASSERT_TRUE(doc.IsObject()); |
| ASSERT_TRUE(doc.HasMember("task")); |
| ASSERT_STREQ(doc["task"].GetString(), "text-matching"); |
| ASSERT_TRUE(doc.HasMember("input")); |
| const auto& input = doc["input"]; |
| ASSERT_TRUE(input.IsArray()); |
| ASSERT_EQ(input.Size(), 1); |
| ASSERT_TRUE(input[0].HasMember("image")); |
| ASSERT_STREQ(input[0]["image"].GetString(), "https://a/b/c.jpg"); |
| ASSERT_TRUE(doc.HasMember("dimensions")); |
| ASSERT_EQ(doc["dimensions"].GetInt(), 512); |
| } |
| |
| TEST(AI_ADAPTER_TEST, jina_multimodal_embedding_batch_request) { |
| JinaAdapter adapter; |
| TAIResource config; |
| config.model_name = "jina-embeddings-v4"; |
| config.dimensions = 512; |
| adapter.init(config); |
| |
| std::vector<MultimodalType> media_types = {MultimodalType::IMAGE, MultimodalType::VIDEO}; |
| std::vector<std::string> media_urls = {"https://a/b/c.jpg", "https://a/b/c.mp4"}; |
| std::string request_body; |
| Status st = |
| adapter.build_multimodal_embedding_request(media_types, media_urls, {}, request_body); |
| ASSERT_TRUE(st.ok()) << st.to_string(); |
| |
| rapidjson::Document doc; |
| doc.Parse(request_body.c_str()); |
| ASSERT_FALSE(doc.HasParseError()) << request_body; |
| ASSERT_TRUE(doc.IsObject()); |
| ASSERT_TRUE(doc.HasMember("input")); |
| const auto& input = doc["input"]; |
| ASSERT_TRUE(input.IsArray()); |
| ASSERT_EQ(input.Size(), 2); |
| ASSERT_TRUE(input[0].HasMember("image")); |
| ASSERT_STREQ(input[0]["image"].GetString(), "https://a/b/c.jpg"); |
| ASSERT_TRUE(input[1].HasMember("video")); |
| ASSERT_STREQ(input[1]["video"].GetString(), "https://a/b/c.mp4"); |
| } |
| |
| TEST(AI_ADAPTER_TEST, multimodal_provider_support) { |
| OpenAIAdapter openai_adapter; |
| TAIResource openai_config; |
| openai_config.provider_type = "OPENAI"; |
| openai_adapter.init(openai_config); |
| |
| std::string request_body; |
| Status st = openai_adapter.build_multimodal_embedding_request( |
| {MultimodalType::IMAGE}, {"https://a/b/c.png"}, {}, request_body); |
| ASSERT_FALSE(st.ok()); |
| ASSERT_THAT(st.to_string(), ::testing::HasSubstr("does not support multimodal Embed")); |
| } |
| |
| TEST(AI_ADAPTER_TEST, gemini_multimodal_embedding_request) { |
| GeminiAdapter gemini_adapter; |
| TAIResource gemini_config; |
| gemini_config.provider_type = "GEMINI"; |
| gemini_config.model_name = "gemini-embedding-2-preview"; |
| gemini_config.dimensions = 768; |
| gemini_adapter.init(gemini_config); |
| |
| struct GeminiMultimodalCase { |
| MultimodalType media_type; |
| const char* media_url; |
| const char* mime_type; |
| }; |
| const std::vector<GeminiMultimodalCase> test_cases = { |
| {MultimodalType::IMAGE, "https://a/b/c.jpg", "image/jpeg"}, |
| {MultimodalType::IMAGE, "https://a/b/c.webp", "image/webp"}, |
| {MultimodalType::AUDIO, "https://a/b/c.wav", "audio/wav"}, |
| {MultimodalType::VIDEO, "https://a/b/c.webm", "video/webm"}, |
| }; |
| |
| for (const auto& test_case : test_cases) { |
| std::string request_body; |
| Status st = gemini_adapter.build_multimodal_embedding_request( |
| {test_case.media_type}, {test_case.media_url}, {test_case.mime_type}, request_body); |
| ASSERT_TRUE(st.ok()) << st.to_string(); |
| |
| rapidjson::Document doc; |
| doc.Parse(request_body.c_str()); |
| ASSERT_FALSE(doc.HasParseError()) << request_body; |
| ASSERT_TRUE(doc.IsObject()); |
| ASSERT_TRUE(doc.HasMember("requests")); |
| ASSERT_TRUE(doc["requests"].IsArray()); |
| ASSERT_EQ(doc["requests"].Size(), 1); |
| const auto& request = doc["requests"][0]; |
| ASSERT_TRUE(request.HasMember("model")); |
| ASSERT_STREQ(request["model"].GetString(), "models/gemini-embedding-2-preview"); |
| ASSERT_TRUE(request.HasMember("outputDimensionality")); |
| ASSERT_EQ(request["outputDimensionality"].GetInt(), 768); |
| ASSERT_TRUE(request.HasMember("content")); |
| ASSERT_TRUE(request["content"].HasMember("parts")); |
| ASSERT_TRUE(request["content"]["parts"].IsArray()); |
| ASSERT_EQ(request["content"]["parts"].Size(), 1); |
| ASSERT_TRUE(request["content"]["parts"][0].HasMember("file_data")); |
| ASSERT_TRUE(request["content"]["parts"][0]["file_data"].IsObject()); |
| ASSERT_STREQ(request["content"]["parts"][0]["file_data"]["mime_type"].GetString(), |
| test_case.mime_type); |
| ASSERT_STREQ(request["content"]["parts"][0]["file_data"]["file_uri"].GetString(), |
| test_case.media_url); |
| } |
| } |
| |
| TEST(AI_ADAPTER_TEST, gemini_multimodal_embedding_batch_request) { |
| GeminiAdapter adapter; |
| TAIResource config; |
| config.provider_type = "GEMINI"; |
| config.model_name = "gemini-embedding-2-preview"; |
| config.dimensions = 768; |
| adapter.init(config); |
| |
| std::vector<MultimodalType> media_types = {MultimodalType::IMAGE, MultimodalType::AUDIO, |
| MultimodalType::VIDEO}; |
| std::vector<std::string> media_urls = {"https://a/b/c.jpg", "https://a/b/c.wav", |
| "https://a/b/c.webm"}; |
| std::vector<std::string> media_content_types = {"image/jpeg", "audio/wav", "video/webm"}; |
| std::string request_body; |
| Status st = adapter.build_multimodal_embedding_request(media_types, media_urls, |
| media_content_types, request_body); |
| ASSERT_TRUE(st.ok()) << st.to_string(); |
| |
| rapidjson::Document doc; |
| doc.Parse(request_body.c_str()); |
| ASSERT_FALSE(doc.HasParseError()) << request_body; |
| ASSERT_TRUE(doc.IsObject()); |
| ASSERT_TRUE(doc.HasMember("requests")); |
| const auto& requests = doc["requests"]; |
| ASSERT_TRUE(requests.IsArray()); |
| ASSERT_EQ(requests.Size(), 3); |
| |
| ASSERT_STREQ(requests[0]["model"].GetString(), "models/gemini-embedding-2-preview"); |
| ASSERT_EQ(requests[0]["outputDimensionality"].GetInt(), 768); |
| ASSERT_STREQ(requests[0]["content"]["parts"][0]["file_data"]["mime_type"].GetString(), |
| "image/jpeg"); |
| ASSERT_STREQ(requests[0]["content"]["parts"][0]["file_data"]["file_uri"].GetString(), |
| "https://a/b/c.jpg"); |
| |
| ASSERT_STREQ(requests[1]["content"]["parts"][0]["file_data"]["mime_type"].GetString(), |
| "audio/wav"); |
| ASSERT_STREQ(requests[1]["content"]["parts"][0]["file_data"]["file_uri"].GetString(), |
| "https://a/b/c.wav"); |
| |
| ASSERT_STREQ(requests[2]["content"]["parts"][0]["file_data"]["mime_type"].GetString(), |
| "video/webm"); |
| ASSERT_STREQ(requests[2]["content"]["parts"][0]["file_data"]["file_uri"].GetString(), |
| "https://a/b/c.webm"); |
| } |
| |
| TEST(AI_ADAPTER_TEST, gemini_multimodal_embedding_request_empty_inputs) { |
| GeminiAdapter adapter; |
| TAIResource config; |
| config.provider_type = "GEMINI"; |
| config.model_name = "gemini-embedding-2-preview"; |
| adapter.init(config); |
| |
| std::string request_body; |
| Status st = adapter.build_multimodal_embedding_request({}, {}, {}, request_body); |
| ASSERT_FALSE(st.ok()); |
| ASSERT_THAT(st.to_string(), |
| ::testing::HasSubstr("Gemini multimodal embed inputs can not be empty")); |
| } |
| |
| TEST(AI_ADAPTER_TEST, gemini_multimodal_embedding_request_size_mismatch) { |
| GeminiAdapter adapter; |
| TAIResource config; |
| config.provider_type = "GEMINI"; |
| config.model_name = "gemini-embedding-2-preview"; |
| adapter.init(config); |
| |
| std::string request_body; |
| Status st = adapter.build_multimodal_embedding_request( |
| {MultimodalType::IMAGE, MultimodalType::VIDEO}, {"https://a/b/c.png"}, {}, |
| request_body); |
| ASSERT_FALSE(st.ok()); |
| ASSERT_THAT( |
| st.to_string(), |
| ::testing::HasSubstr( |
| "Gemini multimodal embed input size mismatch, media_types=2, media_urls=1")); |
| } |
| |
| TEST(AI_ADAPTER_TEST, gemini_multimodal_embedding_content_type_size_mismatch) { |
| GeminiAdapter adapter; |
| TAIResource config; |
| config.provider_type = "GEMINI"; |
| config.model_name = "gemini-embedding-2-preview"; |
| adapter.init(config); |
| |
| std::string request_body; |
| Status st = adapter.build_multimodal_embedding_request({MultimodalType::IMAGE}, |
| {"https://a/b/c.jpg"}, {}, request_body); |
| ASSERT_FALSE(st.ok()); |
| ASSERT_THAT(st.to_string(), ::testing::HasSubstr("Gemini multimodal embed input size mismatch, " |
| "media_content_types=0, media_urls=1")); |
| } |
| |
| TEST(AI_ADAPTER_TEST, gemini_parse_batch_embedding_response) { |
| GeminiAdapter adapter; |
| std::string resp = R"({ |
| "embeddings": [ |
| {"values": [0.1, 0.2, 0.3]}, |
| {"values": [0.4, 0.5]} |
| ] |
| })"; |
| |
| std::vector<std::vector<float>> results; |
| Status st = adapter.parse_embedding_response(resp, results); |
| ASSERT_TRUE(st.ok()) << st.to_string(); |
| ASSERT_EQ(results.size(), 2); |
| ASSERT_EQ(results[0].size(), 3); |
| ASSERT_EQ(results[1].size(), 2); |
| ASSERT_FLOAT_EQ(results[0][0], 0.1F); |
| ASSERT_FLOAT_EQ(results[0][2], 0.3F); |
| ASSERT_FLOAT_EQ(results[1][0], 0.4F); |
| ASSERT_FLOAT_EQ(results[1][1], 0.5F); |
| } |
| |
| } // namespace doris |