| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| #include "vec/functions/ai/ai_adapter.h" |
| |
| #include <curl/curl.h> |
| #include <gen_cpp/PaloInternalService_types.h> |
| #include <gmock/gmock-matchers.h> |
| #include <gtest/gtest.h> |
| #include <rapidjson/document.h> |
| |
| #include <string> |
| #include <vector> |
| |
| #include "common/status.h" |
| #include "vec/functions/ai/ai_classify.h" |
| #include "vec/functions/ai/ai_extract.h" |
| #include "vec/functions/ai/ai_sentiment.h" |
| #include "vec/functions/ai/ai_summarize.h" |
| |
| namespace doris::vectorized { |
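| // Test double that exposes the curl header list populated through |
| // set_authentication() so the tests can inspect the HTTP headers an adapter sets. |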
| class MockHttpClient : public HttpClient { |
| public: |
| curl_slist* get() { return this->_header_list; } |
| }; |
| |
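| // LocalAdapter against an Ollama /api/chat endpoint: verifies the JSON content-type |
| // header, the model/stream/think flags, the nested options.temperature and |
| // options.max_token values, and the system + user messages. |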
| TEST(AI_ADAPTER_TEST, local_adapter_request_chat_endpoint) { |
| LocalAdapter adapter; |
| TAIResource config; |
| config.model_name = "ollama"; |
| config.temperature = 0.7; |
| config.max_tokens = 128; |
| config.endpoint = "http://localhost:11434/api/chat"; |
| adapter.init(config); |
| |
| // header test |
| MockHttpClient mock_client; |
| Status auth_status = adapter.set_authentication(&mock_client); |
| ASSERT_TRUE(auth_status.ok()); |
| EXPECT_STREQ(mock_client.get()->data, "Content-Type: application/json"); |
| |
| std::vector<std::string> inputs = {"hello world"}; |
| std::string request_body; |
| Status st = |
| adapter.build_request_payload(inputs, FunctionAISummarize::system_prompt, request_body); |
| ASSERT_TRUE(st.ok()); |
| |
| // body test |
| rapidjson::Document doc; |
| doc.Parse(request_body.c_str()); |
| ASSERT_FALSE(doc.HasParseError()) << "JSON parse error"; |
| ASSERT_TRUE(doc.IsObject()) << "JSON is not an object"; |
| |
| // general flags |
| ASSERT_TRUE(doc.HasMember("model")) << "Missing model field"; |
| ASSERT_TRUE(doc["model"].IsString()) << "Model field is not a string"; |
| ASSERT_STREQ(doc["model"].GetString(), "ollama"); |
| ASSERT_TRUE(doc.HasMember("stream")) << "Missing stream field"; |
| ASSERT_TRUE(doc["stream"].IsBool()) << "Stream field is not a bool"; |
| ASSERT_FALSE(doc["stream"].GetBool()); |
| ASSERT_TRUE(doc.HasMember("think")) << "Missing think field"; |
| ASSERT_TRUE(doc["think"].IsBool()) << "Think field is not a bool"; |
| ASSERT_FALSE(doc["think"].GetBool()); |
| |
| // options (temperature + max_token) |
| ASSERT_FALSE(doc.HasMember("temperature")) << "Temperature should be nested in options"; |
| ASSERT_FALSE(doc.HasMember("max_tokens")) << "Max tokens should be nested in options"; |
| ASSERT_TRUE(doc.HasMember("options")) << "Missing options field"; |
| ASSERT_TRUE(doc["options"].IsObject()) << "Options is not an object"; |
| const auto& options = doc["options"]; |
| ASSERT_TRUE(options.HasMember("temperature")) << "Missing options.temperature field"; |
| ASSERT_TRUE(options["temperature"].IsNumber()) << "options.temperature is not a number"; |
| ASSERT_DOUBLE_EQ(options["temperature"].GetDouble(), 0.7); |
| ASSERT_TRUE(options.HasMember("max_token")) << "Missing options.max_token field"; |
| ASSERT_TRUE(options["max_token"].IsInt()) << "options.max_token is not an integer"; |
| ASSERT_EQ(options["max_token"].GetInt(), 128); |
| |
| // content |
| ASSERT_TRUE(doc.HasMember("messages")) << "Missing messages field"; |
| ASSERT_TRUE(doc["messages"].IsArray()) << "Messages is not an array"; |
| ASSERT_GE(doc["messages"].Size(), 2) << "Messages should contain system and user prompts"; |
| const auto& first_message = doc["messages"][0]; |
| ASSERT_TRUE(first_message.HasMember("role")) << "Message missing role field"; |
| ASSERT_TRUE(first_message["role"].IsString()) << "Role field is not a string"; |
| ASSERT_STREQ(first_message["role"].GetString(), "system"); |
| ASSERT_TRUE(first_message.HasMember("content")) << "Message missing content field"; |
| ASSERT_TRUE(first_message["content"].IsString()) << "Content field is not a string"; |
| ASSERT_STREQ(first_message["content"].GetString(), FunctionAISummarize::system_prompt); |
| |
| const auto& user_message = doc["messages"][doc["messages"].Size() - 1]; |
| ASSERT_TRUE(user_message.HasMember("role")) << "User message missing role field"; |
| ASSERT_TRUE(user_message["role"].IsString()) << "User role field is not a string"; |
| ASSERT_STREQ(user_message["role"].GetString(), "user"); |
| ASSERT_TRUE(user_message.HasMember("content")) << "User message missing content field"; |
| ASSERT_TRUE(user_message["content"].IsString()) << "User content field is not a string"; |
| ASSERT_STREQ(user_message["content"].GetString(), inputs[0].c_str()); |
| } |
| |
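| // LocalAdapter against an Ollama /api/generate endpoint: the prompt is sent via |
| // top-level "system" and "prompt" fields rather than a "messages" array. |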
| TEST(AI_ADAPTER_TEST, local_adapter_request_generate_endpoint) { |
| LocalAdapter adapter; |
| TAIResource config; |
| config.model_name = "ollama"; |
| config.temperature = 0.8; |
| config.max_tokens = 64; |
| config.endpoint = "http://localhost:11434/api/generate"; |
| adapter.init(config); |
| |
| std::vector<std::string> inputs = {"hello world"}; |
| std::string request_body; |
| Status st = |
| adapter.build_request_payload(inputs, FunctionAISummarize::system_prompt, request_body); |
| ASSERT_TRUE(st.ok()); |
| |
| rapidjson::Document doc; |
| doc.Parse(request_body.c_str()); |
| ASSERT_FALSE(doc.HasParseError()) << "JSON parse error"; |
| ASSERT_TRUE(doc.IsObject()) << "JSON is not an object"; |
| |
| ASSERT_TRUE(doc.HasMember("model")) << "Missing model field"; |
| ASSERT_STREQ(doc["model"].GetString(), "ollama"); |
| ASSERT_TRUE(doc.HasMember("system")) << "Missing system field"; |
| ASSERT_TRUE(doc["system"].IsString()) << "System field is not a string"; |
| ASSERT_STREQ(doc["system"].GetString(), FunctionAISummarize::system_prompt); |
| ASSERT_TRUE(doc.HasMember("prompt")) << "Missing prompt field"; |
| ASSERT_TRUE(doc["prompt"].IsString()) << "Prompt field is not a string"; |
| ASSERT_STREQ(doc["prompt"].GetString(), inputs[0].c_str()); |
| |
| ASSERT_FALSE(doc.HasMember("messages")) << "Generate endpoint should not include messages"; |
| |
| ASSERT_TRUE(doc.HasMember("options")) << "Missing options field"; |
| ASSERT_TRUE(doc["options"].IsObject()) << "Options is not an object"; |
| const auto& options = doc["options"]; |
| ASSERT_TRUE(options.HasMember("temperature")) << "Missing options.temperature field"; |
| ASSERT_DOUBLE_EQ(options["temperature"].GetDouble(), 0.8); |
| ASSERT_TRUE(options.HasMember("max_token")) << "Missing options.max_token field"; |
| ASSERT_EQ(options["max_token"].GetInt(), 64); |
| } |
| |
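| // LocalAdapter accepts several response shapes: OpenAI-style "choices", plain |
| // "text"/"content" fields, and Ollama "response"/"message" formats. |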
| TEST(AI_ADAPTER_TEST, local_adapter_parse_response) { |
| LocalAdapter adapter; |
| // OpenAI type |
| std::string resp = R"({"choices":[{"message":{"content":"hi"}}]})"; |
| std::vector<std::string> results; |
| Status st = adapter.parse_response(resp, results); |
| ASSERT_TRUE(st.ok()); |
| ASSERT_EQ(results.size(), 1); |
| ASSERT_EQ(results[0], "hi"); |
| |
| // Simple text type |
| resp = R"({"text":"simple result"})"; |
| results.clear(); |
| st = adapter.parse_response(resp, results); |
| ASSERT_TRUE(st.ok()); |
| ASSERT_EQ(results.size(), 1); |
| ASSERT_EQ(results[0], "simple result"); |
| |
| resp = R"({"content":"simple result"})"; |
| results.clear(); |
| st = adapter.parse_response(resp, results); |
| ASSERT_TRUE(st.ok()); |
| ASSERT_EQ(results.size(), 1); |
| ASSERT_EQ(results[0], "simple result"); |
| |
| // Ollama response type |
| resp = R"({"response":"ollama result"})"; |
| results.clear(); |
| st = adapter.parse_response(resp, results); |
| ASSERT_TRUE(st.ok()); |
| ASSERT_EQ(results.size(), 1); |
| ASSERT_EQ(results[0], "ollama result"); |
| |
| // Ollama chat message type |
| resp = R"({"message":{"content":"ollama chat"}})"; |
| results.clear(); |
| st = adapter.parse_response(resp, results); |
| ASSERT_TRUE(st.ok()); |
| ASSERT_EQ(results.size(), 1); |
| ASSERT_EQ(results[0], "ollama chat"); |
| } |
| |
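| // Default endpoint handling (here a /v1/completions URL): the payload is |
| // OpenAI-compatible, with top-level temperature, max_tokens, and messages. |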
| TEST(AI_ADAPTER_TEST, local_adapter_request_default_endpoint) { |
| LocalAdapter adapter; |
| TAIResource config; |
| config.model_name = "local-default"; |
| config.temperature = 0.3; |
| config.max_tokens = 42; |
| config.endpoint = "http://localhost:8000/v1/completions"; |
| adapter.init(config); |
| |
| std::vector<std::string> inputs = {"default prompt"}; |
| std::string request_body; |
| Status st = |
| adapter.build_request_payload(inputs, FunctionAISummarize::system_prompt, request_body); |
| ASSERT_TRUE(st.ok()) << st.to_string(); |
| |
| rapidjson::Document doc; |
| doc.Parse(request_body.c_str()); |
| ASSERT_FALSE(doc.HasParseError()) << "JSON parse error"; |
| ASSERT_TRUE(doc.IsObject()) << "JSON is not an object"; |
| |
| ASSERT_TRUE(doc.HasMember("model")); |
| ASSERT_STREQ(doc["model"].GetString(), "local-default"); |
| |
| ASSERT_TRUE(doc.HasMember("temperature")); |
| ASSERT_DOUBLE_EQ(doc["temperature"].GetDouble(), 0.3); |
| |
| ASSERT_TRUE(doc.HasMember("max_tokens")); |
| ASSERT_EQ(doc["max_tokens"].GetInt(), 42); |
| |
| ASSERT_TRUE(doc.HasMember("messages")); |
| ASSERT_TRUE(doc["messages"].IsArray()); |
| ASSERT_GE(doc["messages"].Size(), 2); |
| |
| const auto& system_msg = doc["messages"][0]; |
| ASSERT_TRUE(system_msg.HasMember("role")); |
| ASSERT_STREQ(system_msg["role"].GetString(), "system"); |
| ASSERT_TRUE(system_msg.HasMember("content")); |
| ASSERT_STREQ(system_msg["content"].GetString(), FunctionAISummarize::system_prompt); |
| |
| const auto& user_msg = doc["messages"][doc["messages"].Size() - 1]; |
| ASSERT_TRUE(user_msg.HasMember("role")); |
| ASSERT_STREQ(user_msg["role"].GetString(), "user"); |
| ASSERT_TRUE(user_msg.HasMember("content")); |
| ASSERT_STREQ(user_msg["content"].GetString(), inputs[0].c_str()); |
| } |
| |
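| // OpenAIAdapter chat-completions request: Bearer authorization header plus the |
| // standard model/temperature/max_tokens/messages payload. |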
| TEST(AI_ADAPTER_TEST, openai_adapter_completions_request) { |
| OpenAIAdapter adapter; |
| TAIResource config; |
| config.model_name = "gpt-3.5-turbo"; |
| config.temperature = 0.5; |
| config.max_tokens = 64; |
| config.api_key = "test_openai_key"; |
| adapter.init(config); |
| |
| // header |
| MockHttpClient mock_client; |
| Status auth_status = adapter.set_authentication(&mock_client); |
| ASSERT_TRUE(auth_status.ok()); |
| |
| EXPECT_STREQ(mock_client.get()->data, "Authorization: Bearer test_openai_key"); |
| EXPECT_STREQ(mock_client.get()->next->data, "Content-Type: application/json"); |
| |
| std::vector<std::string> inputs = {"hi openai"}; |
| std::string request_body; |
| Status st = |
| adapter.build_request_payload(inputs, FunctionAISentiment::system_prompt, request_body); |
| ASSERT_TRUE(st.ok()); |
| |
| // body |
| rapidjson::Document doc; |
| doc.Parse(request_body.c_str()); |
| ASSERT_FALSE(doc.HasParseError()) << "JSON parse error"; |
| ASSERT_TRUE(doc.IsObject()) << "JSON is not an object"; |
| |
| // model |
| ASSERT_TRUE(doc.HasMember("model")) << "Missing model field"; |
| ASSERT_TRUE(doc["model"].IsString()) << "Model field is not a string"; |
| ASSERT_STREQ(doc["model"].GetString(), "gpt-3.5-turbo"); |
| |
| // temperature |
| ASSERT_TRUE(doc.HasMember("temperature")) << "Missing temperature field"; |
| ASSERT_TRUE(doc["temperature"].IsNumber()) << "Temperature field is not a number"; |
| ASSERT_DOUBLE_EQ(doc["temperature"].GetDouble(), 0.5); |
| |
| // max token |
| ASSERT_TRUE(doc.HasMember("max_tokens")) << "Missing max_tokens field"; |
| ASSERT_TRUE(doc["max_tokens"].IsInt()) << "Max_tokens field is not an integer"; |
| ASSERT_EQ(doc["max_tokens"].GetInt(), 64); |
| // msg |
| ASSERT_TRUE(doc.HasMember("messages")) << "Missing messages field"; |
| ASSERT_TRUE(doc["messages"].IsArray()) << "Messages is not an array"; |
| ASSERT_GT(doc["messages"].Size(), 0) << "Messages array is empty"; |
| |
| // system_prompt |
| const auto& first_message = doc["messages"][0]; |
| ASSERT_TRUE(first_message.HasMember("role")) << "Message missing role field"; |
| ASSERT_TRUE(first_message["role"].IsString()) << "Role field is not a string"; |
| ASSERT_STREQ(first_message["role"].GetString(), "system"); |
| ASSERT_TRUE(first_message.HasMember("content")) << "Message missing content field"; |
| ASSERT_TRUE(first_message["content"].IsString()) << "Content field is not a string"; |
| ASSERT_STREQ(first_message["content"].GetString(), FunctionAISentiment::system_prompt); |
| |
| // The content of the last message |
| const auto& last_message = doc["messages"][doc["messages"].Size() - 1]; |
| ASSERT_TRUE(last_message.HasMember("content")) << "Message missing content field"; |
| ASSERT_TRUE(last_message["content"].IsString()) << "Content field is not a string"; |
| ASSERT_STREQ(last_message["content"].GetString(), inputs[0].c_str()); |
| } |
| |
| TEST(AI_ADAPTER_TEST, openai_adapter_completions_parse_response) { |
| OpenAIAdapter adapter; |
| std::string resp = R"({"choices":[{"message":{"content":"openai result"}}]})"; |
| std::vector<std::string> results; |
| Status st = adapter.parse_response(resp, results); |
| ASSERT_TRUE(st.ok()); |
| ASSERT_EQ(results.size(), 1); |
| ASSERT_EQ(results[0], "openai result"); |
| } |
| |
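| // OpenAIAdapter against the /v1/responses endpoint: the token limit is sent as |
| // "max_output_tokens" and the prompts are carried in an "input" array. |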
| TEST(AI_ADAPTER_TEST, openai_adapter_responses_request) { |
| OpenAIAdapter adapter; |
| TAIResource config; |
| config.model_name = "gpt-5"; |
| config.temperature = 0.5; |
| config.max_tokens = 64; |
| config.api_key = "test_openai_key"; |
| config.endpoint = "https://api.openai.com/v1/responses"; |
| adapter.init(config); |
| |
| // header |
| MockHttpClient mock_client; |
| Status auth_status = adapter.set_authentication(&mock_client); |
| ASSERT_TRUE(auth_status.ok()); |
| |
| EXPECT_STREQ(mock_client.get()->data, "Authorization: Bearer test_openai_key"); |
| EXPECT_STREQ(mock_client.get()->next->data, "Content-Type: application/json"); |
| |
| std::vector<std::string> inputs = {"hi openai"}; |
| std::string request_body; |
| Status st = |
| adapter.build_request_payload(inputs, FunctionAISentiment::system_prompt, request_body); |
| ASSERT_TRUE(st.ok()); |
| |
| // body |
| rapidjson::Document doc; |
| doc.Parse(request_body.c_str()); |
| ASSERT_FALSE(doc.HasParseError()) << "JSON parse error"; |
| ASSERT_TRUE(doc.IsObject()) << "JSON is not an object"; |
| |
| // model |
| ASSERT_TRUE(doc.HasMember("model")) << "Missing model field"; |
| ASSERT_TRUE(doc["model"].IsString()) << "Model field is not a string"; |
| ASSERT_STREQ(doc["model"].GetString(), "gpt-5"); |
| |
| // temperature |
| ASSERT_TRUE(doc.HasMember("temperature")) << "Missing temperature field"; |
| ASSERT_TRUE(doc["temperature"].IsNumber()) << "Temperature field is not a number"; |
| ASSERT_DOUBLE_EQ(doc["temperature"].GetDouble(), 0.5); |
| |
| // max tokens |
| ASSERT_TRUE(doc.HasMember("max_output_tokens")) << "Missing max_output_tokens field"; |
| ASSERT_TRUE(doc["max_output_tokens"].IsInt()) << "max_output_tokens field is not an integer"; |
| ASSERT_EQ(doc["max_output_tokens"].GetInt(), 64); |
| |
| // input |
| ASSERT_TRUE(doc.HasMember("input")) << "Missing input field"; |
| ASSERT_TRUE(doc["input"].IsArray()) << "Input is not an array"; |
| ASSERT_GT(doc["input"].Size(), 0) << "Input array is empty"; |
| |
| // system_prompt |
| const auto& input = doc["input"]; |
| ASSERT_TRUE(input[0].HasMember("role")) << request_body; |
| ASSERT_TRUE(input[0]["role"].IsString()) << "Role field is not a string"; |
| ASSERT_STREQ(input[0]["role"].GetString(), "system"); |
| ASSERT_TRUE(input[0].HasMember("content")) << request_body; |
| ASSERT_TRUE(input[0]["content"].IsString()) << "Content field is not a string"; |
| ASSERT_STREQ(input[0]["content"].GetString(), FunctionAISentiment::system_prompt); |
| |
| // input content |
| ASSERT_TRUE(input[1].HasMember("role")) << request_body; |
| ASSERT_TRUE(input[1]["role"].IsString()) << "Role field is not a string"; |
| ASSERT_STREQ(input[1]["role"].GetString(), "user"); |
| ASSERT_TRUE(input[1].HasMember("content")) << request_body; |
| ASSERT_TRUE(input[1]["content"].IsString()) << "Content field is not a string"; |
| ASSERT_STREQ(input[1]["content"].GetString(), inputs[0].c_str()); |
| } |
| |
| TEST(AI_ADAPTER_TEST, openai_adapter_responses_parse_response) { |
| OpenAIAdapter adapter; |
| std::string resp = R"({"output":[{"content":[{"text":"openai response result"}]}]})"; |
| std::vector<std::string> results; |
| Status st = adapter.parse_response(resp, results); |
| ASSERT_TRUE(st.ok()); |
| ASSERT_EQ(results.size(), 1); |
| ASSERT_EQ(results[0], "openai response result"); |
| } |
| |
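| // GeminiAdapter request: x-goog-api-key header, generationConfig with |
| // temperature/maxOutputTokens, a systemInstruction parts array, and the user text |
| // under contents[0].parts. |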
| TEST(AI_ADAPTER_TEST, gemini_adapter_request) { |
| GeminiAdapter adapter; |
| TAIResource config; |
| config.temperature = 0.2; |
| config.max_tokens = 32; |
| config.api_key = "test_gemini_key"; |
| adapter.init(config); |
| |
| // header test |
| MockHttpClient mock_client; |
| Status auth_status = adapter.set_authentication(&mock_client); |
| ASSERT_TRUE(auth_status.ok()); |
| |
| EXPECT_STREQ(mock_client.get()->data, "x-goog-api-key: test_gemini_key"); |
| EXPECT_STREQ(mock_client.get()->next->data, "Content-Type: application/json"); |
| |
| std::vector<std::string> inputs = {"hello gemini"}; |
| std::string request_body; |
| Status st = |
| adapter.build_request_payload(inputs, FunctionAIExtract::system_prompt, request_body); |
| ASSERT_TRUE(st.ok()); |
| |
| // body test |
| rapidjson::Document doc; |
| doc.Parse(request_body.c_str()); |
| ASSERT_FALSE(doc.HasParseError()) << "JSON parse error"; |
| ASSERT_TRUE(doc.IsObject()) << "JSON is not an object"; |
| |
| ASSERT_TRUE(doc.HasMember("generationConfig")); |
| const auto& gen_cfg = doc["generationConfig"]; |
| |
| // temperature |
| ASSERT_TRUE(gen_cfg.HasMember("temperature")); |
| ASSERT_TRUE(gen_cfg["temperature"].IsNumber()); |
| ASSERT_DOUBLE_EQ(gen_cfg["temperature"].GetDouble(), 0.2); |
| |
| // max tokens |
| ASSERT_TRUE(gen_cfg.HasMember("maxOutputTokens")) << "Missing maxOutputTokens field"; |
| ASSERT_TRUE(gen_cfg["maxOutputTokens"].IsInt()); |
| ASSERT_EQ(gen_cfg["maxOutputTokens"].GetInt(), 32); |
| |
| // system_prompt |
| ASSERT_TRUE(doc.HasMember("systemInstruction")) << "Missing systemInstruction field"; |
| ASSERT_TRUE(doc["systemInstruction"].IsObject()) << request_body; |
| ASSERT_TRUE(doc["systemInstruction"].HasMember("parts")) << request_body; |
| ASSERT_TRUE(doc["systemInstruction"]["parts"].IsArray()) << request_body; |
| ASSERT_GT(doc["systemInstruction"]["parts"].Size(), 0) << request_body; |
| |
| const auto& content_sys = doc["systemInstruction"]["parts"][0]; |
| ASSERT_TRUE(content_sys.HasMember("text")) << "parts missing text field"; |
| ASSERT_TRUE(content_sys["text"].IsString()) << "Text field is not a string"; |
| ASSERT_STREQ(content_sys["text"].GetString(), FunctionAIExtract::system_prompt); |
| |
| // content structure |
| ASSERT_TRUE(doc.HasMember("contents")) << "Missing contents field"; |
| ASSERT_TRUE(doc["contents"].IsArray()) << "Contents is not an array"; |
| ASSERT_GT(doc["contents"].Size(), 0) << "Contents array is empty"; |
| |
| // content |
| const auto& content = doc["contents"][0]; |
| ASSERT_TRUE(content.HasMember("parts")) << "Content missing parts field"; |
| ASSERT_TRUE(content["parts"].IsArray()) << "Parts is not an array"; |
| ASSERT_GT(content["parts"].Size(), 0) << "Parts array is empty"; |
| |
| const auto& part = content["parts"][0]; |
| ASSERT_TRUE(part.HasMember("text")) << "Part missing text field"; |
| ASSERT_TRUE(part["text"].IsString()) << "Text field is not a string"; |
| ASSERT_STREQ(part["text"].GetString(), "hello gemini"); |
| } |
| |
| TEST(AI_ADAPTER_TEST, gemini_adapter_parse_response) { |
| GeminiAdapter adapter; |
| std::string resp = R"({"candidates":[{"content":{"parts":[{"text":"gemini result"}]}}]})"; |
| std::vector<std::string> results; |
| Status st = adapter.parse_response(resp, results); |
| ASSERT_TRUE(st.ok()); |
| ASSERT_EQ(results.size(), 1); |
| ASSERT_EQ(results[0], "gemini result"); |
| } |
| |
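| // AnthropicAdapter request: x-api-key and anthropic-version headers, a top-level |
| // "system" prompt, and messages whose content may be either a string or an array |
| // of typed text parts. |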
| TEST(AI_ADAPTER_TEST, anthropic_adapter_request) { |
| AnthropicAdapter adapter; |
| TAIResource config; |
| config.model_name = "claude-3"; |
| config.temperature = 1.0; |
| config.max_tokens = 256; |
| config.api_key = "test_anthropic_key"; |
| config.anthropic_version = "2023-06-01"; |
| adapter.init(config); |
| |
| // header |
| MockHttpClient mock_client; |
| Status auth_status = adapter.set_authentication(&mock_client); |
| ASSERT_TRUE(auth_status.ok()); |
| |
| EXPECT_STREQ(mock_client.get()->data, "x-api-key: test_anthropic_key"); |
| EXPECT_STREQ(mock_client.get()->next->data, "anthropic-version: 2023-06-01"); |
| EXPECT_STREQ(mock_client.get()->next->next->data, "Content-Type: application/json"); |
| |
| std::vector<std::string> inputs = {"hi anthropic"}; |
| std::string request_body; |
| Status st = |
| adapter.build_request_payload(inputs, FunctionAIClassify::system_prompt, request_body); |
| ASSERT_TRUE(st.ok()); |
| |
| // body |
| rapidjson::Document doc; |
| doc.Parse(request_body.c_str()); |
| ASSERT_FALSE(doc.HasParseError()) << "JSON parse error"; |
| ASSERT_TRUE(doc.IsObject()) << "JSON is not an object"; |
| |
| // model |
| ASSERT_TRUE(doc.HasMember("model")) << "Missing model field"; |
| ASSERT_TRUE(doc["model"].IsString()) << "Model field is not a string"; |
| ASSERT_STREQ(doc["model"].GetString(), "claude-3"); |
| |
| // temperature |
| ASSERT_TRUE(doc.HasMember("temperature")) << "Missing temperature field"; |
| ASSERT_TRUE(doc["temperature"].IsNumber()) << "Temperature field is not a number"; |
| ASSERT_DOUBLE_EQ(doc["temperature"].GetDouble(), 1.0); |
| |
| // max token |
| ASSERT_TRUE(doc.HasMember("max_tokens")) << "Missing max_tokens field"; |
| ASSERT_TRUE(doc["max_tokens"].IsInt()) << "Max_tokens field is not an integer"; |
| ASSERT_EQ(doc["max_tokens"].GetInt(), 256); |
| |
| // system_prompt |
| ASSERT_TRUE(doc.HasMember("system")) << "Missing system field"; |
| ASSERT_TRUE(doc["system"].IsString()) << "System field is not a string"; |
| ASSERT_STREQ(doc["system"].GetString(), FunctionAIClassify::system_prompt); |
| |
| // message format |
| ASSERT_TRUE(doc.HasMember("messages")) << "Missing messages field"; |
| ASSERT_TRUE(doc["messages"].IsArray()) << "Messages is not an array"; |
| ASSERT_GT(doc["messages"].Size(), 0) << "Messages array is empty"; |
| |
| const auto& message = doc["messages"][0]; |
| ASSERT_TRUE(message.HasMember("role")) << "Message missing role field"; |
| ASSERT_TRUE(message.HasMember("content")) << "Message missing content field"; |
| |
| // content of the last message |
| if (message["content"].IsArray()) { |
| ASSERT_GT(message["content"].Size(), 0) << "Content array is empty"; |
| const auto& content = message["content"][0]; |
| ASSERT_TRUE(content.HasMember("type")) << "Content missing type field"; |
| ASSERT_TRUE(content.HasMember("text")) << "Content missing text field"; |
| ASSERT_STREQ(content["text"].GetString(), "hi anthropic"); |
| } else if (message["content"].IsString()) { |
| ASSERT_STREQ(message["content"].GetString(), "hi anthropic"); |
| } |
| } |
| |
| TEST(AI_ADAPTER_TEST, anthropic_adapter_parse_response) { |
| AnthropicAdapter adapter; |
| std::string resp = R"({"content":[{"type":"text","text":"anthropic result"}]})"; |
| std::vector<std::string> results; |
| Status st = adapter.parse_response(resp, results); |
| ASSERT_TRUE(st.ok()); |
| ASSERT_EQ(results.size(), 1); |
| ASSERT_EQ(results[0], "anthropic result"); |
| } |
| |
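| // Factory coverage: an unknown provider type yields nullptr, while every supported |
| // provider string produces an adapter. |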
| TEST(AI_ADAPTER_TEST, unsupported_provider_type) { |
| TAIResource config; |
| config.provider_type = "not_exist"; |
| auto adapter = doris::vectorized::AIAdapterFactory::create_adapter(config.provider_type); |
| ASSERT_EQ(adapter, nullptr); |
| } |
| |
| TEST(AI_ADAPTER_TEST, adapter_factory_all_types) { |
| std::vector<std::string> types = {"LOCAL", "OPENAI", "MOONSHOT", "DEEPSEEK", |
| "MINIMAX", "ZHIPU", "QWEN", "BAICHUAN", |
| "ANTHROPIC", "GEMINI", "VOYAGEAI", "MOCK"}; |
| for (const auto& type : types) { |
| auto adapter = doris::vectorized::AIAdapterFactory::create_adapter(type); |
| ASSERT_TRUE(adapter != nullptr) << "Adapter not found for type: " << type; |
| } |
| } |
| |
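| // Error paths: malformed JSON and unexpected field types must surface as non-OK |
| // statuses with descriptive error messages. |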
| TEST(AI_ADAPTER_TEST, local_adapter_parse_response_parse_error) { |
| LocalAdapter adapter; |
| std::string resp = "not a json"; |
| std::vector<std::string> results; |
| Status st = adapter.parse_response(resp, results); |
| ASSERT_FALSE(st.ok()); |
| } |
| |
| TEST(AI_ADAPTER_TEST, parse_response_wrong_type) { |
| LocalAdapter adapter; |
| // response field is not a string |
| std::string resp = R"({"response":123})"; |
| std::vector<std::string> results; |
| Status st = adapter.parse_response(resp, results); |
| ASSERT_FALSE(st.ok()); |
| EXPECT_THAT(st.to_string().c_str(), |
| ::testing::HasSubstr("Unsupported response format from local AI.")); |
| } |
| |
| TEST(AI_ADAPTER_TEST, openai_adapter_parse_response_choice_format_error) { |
| OpenAIAdapter adapter; |
| // message field missing |
| std::string resp = R"({"choices":[{}]})"; |
| std::vector<std::string> results; |
| Status st = adapter.parse_response(resp, results); |
| ASSERT_FALSE(st.ok()); |
| EXPECT_THAT(st.to_string().c_str(), ::testing::HasSubstr("Invalid choice format in response")); |
| |
| // content field is not a string |
| resp = R"({"choices":[{"message":{"content":123}}]})"; |
| results.clear(); |
| st = adapter.parse_response(resp, results); |
| ASSERT_FALSE(st.ok()); |
| EXPECT_THAT(st.to_string().c_str(), ::testing::HasSubstr("Invalid choice format in response")); |
| } |
| |
| TEST(AI_ADAPTER_TEST, openai_adapter_parse_response_parse_error) { |
| OpenAIAdapter adapter; |
| std::string resp = "not a json"; |
| std::vector<std::string> results; |
| Status st = adapter.parse_response(resp, results); |
| ASSERT_FALSE(st.ok()); |
| EXPECT_THAT(st.to_string().c_str(), ::testing::HasSubstr("Failed to parse")); |
| } |
| |
| TEST(AI_ADAPTER_TEST, openai_adapter_parse_response_choices_not_array) { |
| OpenAIAdapter adapter; |
| std::string resp = R"({"choices":123})"; |
| std::vector<std::string> results; |
| Status st = adapter.parse_response(resp, results); |
| ASSERT_FALSE(st.ok()); |
| EXPECT_THAT(st.to_string().c_str(), ::testing::HasSubstr("Invalid response format")); |
| } |
| |
| TEST(AI_ADAPTER_TEST, gemini_adapter_parse_response_parse_error) { |
| GeminiAdapter adapter; |
| std::string resp = "not a json"; |
| std::vector<std::string> results; |
| Status st = adapter.parse_response(resp, results); |
| ASSERT_FALSE(st.ok()); |
| EXPECT_THAT(st.to_string().c_str(), ::testing::HasSubstr("Failed to parse")); |
| } |
| |
| TEST(AI_ADAPTER_TEST, gemini_parse_response_missing_candidates) { |
| GeminiAdapter adapter; |
| std::string resp = R"({"foo":"bar"})"; |
| std::vector<std::string> results; |
| Status st = adapter.parse_response(resp, results); |
| ASSERT_FALSE(st.ok()); |
| EXPECT_THAT(st.to_string().c_str(), ::testing::HasSubstr("Invalid response format")); |
| } |
| |
| TEST(AI_ADAPTER_TEST, anthropic_adapter_parse_response_parse_error) { |
| AnthropicAdapter adapter; |
| std::string resp = "not a json"; |
| std::vector<std::string> results; |
| Status st = adapter.parse_response(resp, results); |
| ASSERT_FALSE(st.ok()); |
| EXPECT_THAT(st.to_string().c_str(), ::testing::HasSubstr("Failed to parse")); |
| } |
| |
| TEST(AI_ADAPTER_TEST, anthropic_adapter_parse_response_content_not_array) { |
| AnthropicAdapter adapter; |
| std::string resp = R"({"content":123})"; |
| std::vector<std::string> results; |
| Status st = adapter.parse_response(resp, results); |
| ASSERT_FALSE(st.ok()); |
| EXPECT_THAT(st.to_string().c_str(), ::testing::HasSubstr("Invalid response format")); |
| } |
| |
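| // VoyageAI does not support text generation, so both build_request_payload and |
| // parse_response are expected to return NOT_IMPLEMENTED_ERROR. |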
| TEST(AI_ADAPTER_TEST, voyage_adapter_chat_test) { |
| VoyageAIAdapter adapter; |
| TAIResource config; |
| config.api_key = "test_voyage_key"; |
| config.provider_type = "VoyageAI"; |
| |
| adapter.init(config); |
| MockHttpClient mock_client; |
| Status auth_status = adapter.set_authentication(&mock_client); |
| ASSERT_TRUE(auth_status.ok()); |
| EXPECT_STREQ(mock_client.get()->data, "Authorization: Bearer test_voyage_key"); |
| EXPECT_STREQ(mock_client.get()->next->data, "Content-Type: application/json"); |
| |
| std::vector<std::string> inputs = {"test_inputs"}; |
| std::string request_body; |
| Status st = adapter.build_request_payload(inputs, "test_system_prompt", request_body); |
| ASSERT_FALSE(st.ok()); |
| ASSERT_STREQ(st.to_string().c_str(), |
| "[NOT_IMPLEMENTED_ERROR]VoyageAI don't support text generation"); |
| |
| std::string response_body = "test_response_body"; |
| std::vector<std::string> result; |
| st = adapter.parse_response(response_body, result); |
| ASSERT_FALSE(st.ok()); |
| ASSERT_STREQ(st.to_string().c_str(), |
| "[NOT_IMPLEMENTED_ERROR]VoyageAI don't support text generation"); |
| } |
| |
| } // namespace doris::vectorized |