blob: d5b70292a9c0c5b688af0f8e9d09889c3ce704cf [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "vec/functions/ai/embed.h"
#include <curl/curl.h>
#include <gen_cpp/PaloInternalService_types.h>
#include <gmock/gmock-matchers.h>
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include "testutil/column_helper.h"
#include "testutil/mock/mock_runtime_state.h"
#include "vec/functions/ai/ai_adapter.h"
namespace doris::vectorized {
// Test double for HttpClient that exposes the header list accumulated by the
// base class, so tests can inspect what an adapter's set_authentication() added.
// NOTE(review): relies on HttpClient::_header_list being accessible (protected)
// to subclasses — it is read directly below.
class MockHttpClient : public HttpClient {
public:
    // Returns the head of the curl header list (presumably nullptr when no header
    // has been added — tests guard against that before dereferencing).
    curl_slist* get() { return this->_header_list; }
};
// build_prompt for EMBED takes the text argument (column 1) as the prompt, unchanged.
TEST(EMBED_TEST, embed_function_build_test) {
    std::vector<std::string> text_values {"this is a test prompt"};
    std::vector<std::string> resource_values {"resource_name"};

    Block input_block;
    input_block.insert({ColumnHelper::create_column<DataTypeString>(resource_values),
                        std::make_shared<DataTypeString>(), "resource"});
    input_block.insert({ColumnHelper::create_column<DataTypeString>(text_values),
                        std::make_shared<DataTypeString>(), "text"});

    FunctionEmbed embed;
    ColumnNumbers argument_positions {0, 1};
    std::string built_prompt;
    Status build_status = embed.build_prompt(input_block, argument_positions, 0, built_prompt);

    ASSERT_TRUE(build_status.ok());
    ASSERT_EQ(built_prompt, "this is a test prompt");
}
// End-to-end EMBED execution using MockRuntimeState. One input row must produce a
// one-row Array(Nullable(Float32)) result whose elements are [0, 1, 2, 3, 4].
// NOTE(review): that expected vector comes from the mocked backend, which lives
// outside this file.
TEST(EMBED_TEST, embed_function_test) {
    auto runtime_state = std::make_unique<MockRuntimeState>();
    auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
    std::vector<std::string> resources = {"mock_resource"};
    std::vector<std::string> texts = {"test input"};
    auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
    auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
    Block block;
    block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
    block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
    // Empty slot that execute_impl fills with the result column (read back below).
    block.insert({nullptr, std::make_shared<DataTypeString>(), "result"});
    ColumnNumbers arguments = {0, 1};
    size_t result_idx = 2;
    auto embed_func = FunctionEmbed::create();
    Status exec_status =
            embed_func->execute_impl(ctx.get(), block, arguments, result_idx, texts.size());
    ASSERT_TRUE(exec_status.ok()) << exec_status.to_string();
    const auto& col_array =
            assert_cast<const ColumnArray&>(*block.get_by_position(result_idx).column);
    const auto& offsets = col_array.get_offsets();
    // One input row -> one array row whose end offset covers all 5 nested values.
    ASSERT_EQ(offsets.size(), 1U);
    ASSERT_EQ(offsets[0], 5U);
    const auto& nested_nullable_col = assert_cast<const ColumnNullable&>(col_array.get_data());
    const auto& nested_col =
            assert_cast<const ColumnFloat32&>(*nested_nullable_col.get_nested_column_ptr());
    ASSERT_EQ(nested_col.size(), 5U);
    for (size_t i = 0; i < 5; ++i) {
        ASSERT_FLOAT_EQ(nested_col.get_element(i), static_cast<float>(i));
    }
}
// LocalAdapter checks:
// - The first header is the JSON content type.
// - The request body is a JSON object with "model" and "input".
// - "input" may be a plain string or a one-element array.
TEST(EMBED_TEST, local_adapter_embedding_request) {
    TAIResource resource_config;
    resource_config.model_name = "local-embedding-model";
    resource_config.dimensions = 1536;
    LocalAdapter local_adapter;
    local_adapter.init(resource_config);

    // --- headers ---
    MockHttpClient client;
    Status header_status = local_adapter.set_authentication(&client);
    ASSERT_TRUE(header_status.ok());
    curl_slist* first_header = client.get();
    ASSERT_TRUE(first_header != nullptr);
    EXPECT_STREQ(first_header->data, "Content-Type: application/json");

    // --- body ---
    std::vector<std::string> sentences = {"test sentence for embedding"};
    std::string body;
    Status build_status = local_adapter.build_embedding_request(sentences, body);
    ASSERT_TRUE(build_status.ok());

    rapidjson::Document json;
    json.Parse(body.c_str());
    ASSERT_FALSE(json.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(json.IsObject()) << "JSON is not an object";

    ASSERT_TRUE(json.HasMember("model")) << "Missing model field";
    const auto& model = json["model"];
    ASSERT_TRUE(model.IsString()) << "Model field is not a string";
    ASSERT_STREQ(model.GetString(), "local-embedding-model");

    ASSERT_TRUE(json.HasMember("input")) << "Missing input field";
    const auto& input = json["input"];
    ASSERT_TRUE(input.IsString() || input.IsArray())
            << "Input field is not a string or array";
    if (!input.IsString()) {
        ASSERT_EQ(input.Size(), 1U);
        ASSERT_STREQ(input[0].GetString(), "test sentence for embedding");
    } else {
        ASSERT_STREQ(input.GetString(), "test sentence for embedding");
    }
}
// LocalAdapter must accept three response shapes:
// 1. OpenAI-style: {"data": [{"embedding": [...]}, ...]}
// 2. Single vector: {"embedding": [...]}
// 3. Batch of vectors: {"embeddings": [[...], ...]}
TEST(EMBED_TEST, local_adapter_parse_embedding_response) {
    LocalAdapter local_adapter;
    std::vector<std::vector<float>> parsed;

    // Format 1: OpenAI-compatible list of embedding objects.
    std::string openai_style = R"({
"object": "list",
"data": [
{
"object": "embedding",
"embedding": [0.1, 0.2, 0.3],
"index": 0
},
{
"object": "embedding",
"embedding": [0.4, 0.5],
"index": 1
}
],
"model": "mxbai-embed-large",
"usage": {
"prompt_tokens": 8,
"total_tokens": 8
}
})";
    Status status = local_adapter.parse_embedding_response(openai_style, parsed);
    ASSERT_TRUE(status.ok());
    const std::vector<std::vector<float>> expected_rows = {{0.1F, 0.2F, 0.3F}, {0.4F, 0.5F}};
    ASSERT_EQ(parsed.size(), expected_rows.size());
    for (size_t row = 0; row < expected_rows.size(); ++row) {
        ASSERT_EQ(parsed[row].size(), expected_rows[row].size());
    }
    for (size_t row = 0; row < expected_rows.size(); ++row) {
        for (size_t col = 0; col < expected_rows[row].size(); ++col) {
            ASSERT_FLOAT_EQ(parsed[row][col], expected_rows[row][col]);
        }
    }

    // Format 2: a single bare embedding vector.
    std::string single_vector = R"({
"embedding": [0.1, 0.2, 0.3]
})";
    parsed.clear();
    status = local_adapter.parse_embedding_response(single_vector, parsed);
    ASSERT_TRUE(status.ok()) << "Format 2 failed: " << status.to_string();
    ASSERT_EQ(parsed.size(), 1);
    ASSERT_EQ(parsed[0].size(), 3);
    ASSERT_FLOAT_EQ(parsed[0][0], 0.1F);
    ASSERT_FLOAT_EQ(parsed[0][2], 0.3F);

    // Format 3: an array of embedding vectors.
    std::string vector_batch = R"({
"embeddings": [[0.6, 0.7]]
})";
    parsed.clear();
    status = local_adapter.parse_embedding_response(vector_batch, parsed);
    ASSERT_TRUE(status.ok()) << "Format 3 failed: " << status.to_string();
    ASSERT_EQ(parsed.size(), 1);
    ASSERT_EQ(parsed[0].size(), 2);
    ASSERT_FLOAT_EQ(parsed[0][0], 0.6F);
    ASSERT_FLOAT_EQ(parsed[0][1], 0.7F);
}
// OpenAIAdapter checks:
// - Headers are a Bearer token followed by the JSON content type.
// - The body has "model" and an array "input".
// - "dimensions" is sent for "text-embedding-3" but not for "text-embedding-ada-002".
TEST(EMBED_TEST, openai_adapter_embedding_request) {
    OpenAIAdapter adapter;
    TAIResource config;
    config.model_name = "text-embedding-ada-002";
    config.api_key = "test_openai_key";
    config.dimensions = 1536;
    adapter.init(config);
    // header test
    MockHttpClient mock_client;
    Status auth_status = adapter.set_authentication(&mock_client);
    ASSERT_TRUE(auth_status.ok());
    // Guard each list node before dereferencing so a missing header fails the
    // test instead of crashing the test binary.
    curl_slist* headers = mock_client.get();
    ASSERT_TRUE(headers != nullptr);
    EXPECT_STREQ(headers->data, "Authorization: Bearer test_openai_key");
    ASSERT_TRUE(headers->next != nullptr);
    EXPECT_STREQ(headers->next->data, "Content-Type: application/json");
    std::vector<std::string> inputs = {"embed this text"};
    std::string request_body;
    Status st = adapter.build_embedding_request(inputs, request_body);
    ASSERT_TRUE(st.ok());
    // body test
    rapidjson::Document doc;
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
    // model name
    ASSERT_TRUE(doc.HasMember("model")) << "Missing model field";
    ASSERT_TRUE(doc["model"].IsString()) << "Model field is not a string";
    ASSERT_STREQ(doc["model"].GetString(), "text-embedding-ada-002");
    // input
    ASSERT_TRUE(doc.HasMember("input")) << "Missing input field";
    ASSERT_TRUE(doc["input"].IsArray()) << "Input field is not an array";
    ASSERT_EQ(doc["input"].Size(), 1U);
    ASSERT_STREQ(doc["input"][0].GetString(), "embed this text");
    // should not support dimensions param
    ASSERT_FALSE(doc.HasMember("dimensions"));
    // Re-init with a model that accepts "dimensions" and rebuild the body.
    config.model_name = "text-embedding-3";
    adapter.init(config);
    st = adapter.build_embedding_request(inputs, request_body);
    ASSERT_TRUE(st.ok());
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
    ASSERT_TRUE(doc.HasMember("dimensions")) << request_body;
    ASSERT_TRUE(doc["dimensions"].IsInt()) << "Dimensions is not an integer";
    ASSERT_EQ(doc["dimensions"].GetInt(), 1536);
}
// OpenAIAdapter extracts each data[i].embedding vector in order.
TEST(EMBED_TEST, openai_adapter_parse_embedding_response) {
    const std::string payload = R"({
"object": "list",
"data": [
{
"object": "embedding",
"embedding": [0.1, 0.2, 0.3],
"index": 0
},
{
"object": "embedding",
"embedding": [0.4, 0.5],
"index": 1
}
],
"model": "text-embedding-ada-002",
"usage": {
"prompt_tokens": 8,
"total_tokens": 8
}
})";
    OpenAIAdapter openai;
    std::vector<std::vector<float>> vectors;
    Status parse_status = openai.parse_embedding_response(payload, vectors);
    ASSERT_TRUE(parse_status.ok());

    const std::vector<std::vector<float>> expected = {{0.1F, 0.2F, 0.3F}, {0.4F, 0.5F}};
    ASSERT_EQ(vectors.size(), expected.size());
    for (size_t i = 0; i < expected.size(); ++i) {
        ASSERT_EQ(vectors[i].size(), expected[i].size());
    }
    for (size_t i = 0; i < expected.size(); ++i) {
        for (size_t j = 0; j < expected[i].size(); ++j) {
            ASSERT_FLOAT_EQ(vectors[i][j], expected[i][j]);
        }
    }
}
// ZhipuAdapter checks:
// - Headers are a Bearer token followed by the JSON content type.
// - "dimensions" is omitted for embedding-2.
// - "dimensions" is sent (and equals config.dimensions) for embedding-3.
TEST(EMBED_TEST, zhipu_embedding_request) {
    ZhipuAdapter adapter;
    TAIResource config;
    config.model_name = "embedding-2";
    config.api_key = "test_zhipu_key";
    config.dimensions = 1024;
    adapter.init(config);
    MockHttpClient mock_client;
    Status auth_status = adapter.set_authentication(&mock_client);
    ASSERT_TRUE(auth_status.ok());
    // Guard each list node before dereferencing so a missing header fails the
    // test instead of crashing the test binary.
    curl_slist* headers = mock_client.get();
    ASSERT_TRUE(headers != nullptr);
    EXPECT_STREQ(headers->data, "Authorization: Bearer test_zhipu_key");
    ASSERT_TRUE(headers->next != nullptr);
    EXPECT_STREQ(headers->next->data, "Content-Type: application/json");
    std::vector<std::string> inputs = {"embed this text"};
    std::string request_body;
    Status st = adapter.build_embedding_request(inputs, request_body);
    ASSERT_TRUE(st.ok());
    // body test
    rapidjson::Document doc;
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
    ASSERT_FALSE(doc.HasMember("dimensions")) << request_body;
    // Re-init with a model that accepts "dimensions" and rebuild the body.
    config.model_name = "embedding-3";
    adapter.init(config);
    st = adapter.build_embedding_request(inputs, request_body);
    ASSERT_TRUE(st.ok());
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
    ASSERT_TRUE(doc.HasMember("dimensions")) << request_body;
    ASSERT_TRUE(doc["dimensions"].IsInt()) << "Dimensions is not an integer";
    ASSERT_EQ(doc["dimensions"].GetInt(), config.dimensions);
}
// QwenAdapter checks:
// - Headers are a Bearer token followed by the JSON content type.
// - The singular "dimension" field is omitted for text-embedding-v2.
// - "dimension" is sent (and equals config.dimensions) for text-embedding-v4.
TEST(EMBED_TEST, qwen_embedding_request) {
    QwenAdapter adapter;
    TAIResource config;
    config.model_name = "text-embedding-v2";
    config.api_key = "test_qwen_key";
    config.dimensions = 1024;
    adapter.init(config);
    MockHttpClient mock_client;
    Status auth_status = adapter.set_authentication(&mock_client);
    ASSERT_TRUE(auth_status.ok());
    // Guard each list node before dereferencing so a missing header fails the
    // test instead of crashing the test binary.
    curl_slist* headers = mock_client.get();
    ASSERT_TRUE(headers != nullptr);
    EXPECT_STREQ(headers->data, "Authorization: Bearer test_qwen_key");
    ASSERT_TRUE(headers->next != nullptr);
    EXPECT_STREQ(headers->next->data, "Content-Type: application/json");
    std::vector<std::string> inputs = {"embed this text"};
    std::string request_body;
    Status st = adapter.build_embedding_request(inputs, request_body);
    ASSERT_TRUE(st.ok());
    // body test
    rapidjson::Document doc;
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
    ASSERT_FALSE(doc.HasMember("dimension")) << request_body;
    // Re-init with a model that accepts "dimension" and rebuild the body.
    // (Was misspelled "test-embedding-v4"; the sibling case uses "text-embedding-v2".)
    config.model_name = "text-embedding-v4";
    adapter.init(config);
    st = adapter.build_embedding_request(inputs, request_body);
    ASSERT_TRUE(st.ok());
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
    ASSERT_TRUE(doc.HasMember("dimension")) << request_body;
    ASSERT_TRUE(doc["dimension"].IsInt()) << "Dimension is not an integer";
    ASSERT_EQ(doc["dimension"].GetInt(), config.dimensions);
}
// GeminiAdapter checks:
// - Headers are an x-goog-api-key followed by the JSON content type.
// - The body carries content.parts[0].text.
// - "outputDimensionality" is omitted for embedding-001.
// - "outputDimensionality" is sent for gemini-embedding-001.
TEST(EMBED_TEST, gemini_adapter_embedding_request) {
    GeminiAdapter adapter;
    TAIResource config;
    config.model_name = "embedding-001";
    config.api_key = "test_gemini_key";
    config.dimensions = 768;
    adapter.init(config);
    // header test
    MockHttpClient mock_client;
    Status auth_status = adapter.set_authentication(&mock_client);
    ASSERT_TRUE(auth_status.ok());
    // Guard each list node before dereferencing so a missing header fails the
    // test instead of crashing the test binary.
    curl_slist* headers = mock_client.get();
    ASSERT_TRUE(headers != nullptr);
    EXPECT_STREQ(headers->data, "x-goog-api-key: test_gemini_key");
    ASSERT_TRUE(headers->next != nullptr);
    EXPECT_STREQ(headers->next->data, "Content-Type: application/json");
    std::vector<std::string> inputs = {"embed with gemini"};
    std::string request_body;
    Status st = adapter.build_embedding_request(inputs, request_body);
    ASSERT_TRUE(st.ok());
    // body test
    rapidjson::Document doc;
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
    ASSERT_TRUE(doc.HasMember("model")) << "Missing model field";
    ASSERT_TRUE(doc.HasMember("content")) << "Missing content field";
    ASSERT_TRUE(doc["content"].IsObject()) << request_body;
    auto& content = doc["content"];
    ASSERT_TRUE(content.HasMember("parts")) << request_body;
    ASSERT_TRUE(content["parts"].IsArray());
    // rapidjson asserts (aborts) on out-of-range index or HasMember on a
    // non-object, so check shape before indexing.
    ASSERT_FALSE(content["parts"].Empty()) << request_body;
    ASSERT_TRUE(content["parts"][0].IsObject()) << request_body;
    ASSERT_TRUE(content["parts"][0].HasMember("text")) << request_body;
    ASSERT_TRUE(content["parts"][0]["text"].IsString()) << request_body;
    ASSERT_STREQ(content["parts"][0]["text"].GetString(), "embed with gemini");
    // should not have dimension param;
    ASSERT_FALSE(doc.HasMember("outputDimensionality"));
    // Re-init with a model that accepts "outputDimensionality" and rebuild the body.
    config.model_name = "gemini-embedding-001";
    adapter.init(config);
    st = adapter.build_embedding_request(inputs, request_body);
    ASSERT_TRUE(st.ok());
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
    ASSERT_TRUE(doc.HasMember("outputDimensionality")) << request_body;
    ASSERT_TRUE(doc["outputDimensionality"].IsInt()) << request_body;
    ASSERT_EQ(doc["outputDimensionality"].GetInt(), 768) << request_body;
}
// GeminiAdapter reads a single vector from embedding.values.
TEST(EMBED_TEST, gemini_adapter_parse_embedding_response) {
    const std::string payload = R"({
"embedding": {
"values":[
0.1,
0.2,
0.3
]
}
})";
    GeminiAdapter gemini;
    std::vector<std::vector<float>> vectors;
    Status parse_status = gemini.parse_embedding_response(payload, vectors);
    ASSERT_TRUE(parse_status.ok()) << parse_status.to_string();
    ASSERT_EQ(vectors.size(), 1);

    const std::vector<float> expected = {0.1F, 0.2F, 0.3F};
    const auto& only_vector = vectors.front();
    ASSERT_EQ(only_vector.size(), expected.size());
    for (size_t k = 0; k < expected.size(); ++k) {
        ASSERT_FLOAT_EQ(only_vector[k], expected[k]);
    }
}
// VoyageAIAdapter checks:
// - Headers are a Bearer token followed by the JSON content type.
// - The body has "model" and an array "input".
// - "output_dimension" is omitted for voyage-multimodal-3.
// - "output_dimension" is sent for voyage-3.5.
TEST(EMBED_TEST, voyageai_adapter_embedding_request) {
    VoyageAIAdapter adapter;
    TAIResource config;
    config.model_name = "voyage-multimodal-3";
    config.api_key = "test_voyage_key";
    config.dimensions = 1024;
    adapter.init(config);
    // header test
    MockHttpClient mock_client;
    Status auth_status = adapter.set_authentication(&mock_client);
    ASSERT_TRUE(auth_status.ok());
    // Guard each list node before dereferencing so a missing header fails the
    // test instead of crashing the test binary.
    curl_slist* headers = mock_client.get();
    ASSERT_TRUE(headers != nullptr);
    EXPECT_STREQ(headers->data, "Authorization: Bearer test_voyage_key");
    ASSERT_TRUE(headers->next != nullptr);
    EXPECT_STREQ(headers->next->data, "Content-Type: application/json");
    std::vector<std::string> inputs = {"embed with voyage"};
    std::string request_body;
    Status st = adapter.build_embedding_request(inputs, request_body);
    ASSERT_TRUE(st.ok());
    // body test
    rapidjson::Document doc;
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
    // model name
    ASSERT_TRUE(doc.HasMember("model")) << "Missing model field";
    ASSERT_TRUE(doc["model"].IsString()) << "Model field is not a string";
    ASSERT_STREQ(doc["model"].GetString(), "voyage-multimodal-3");
    // input
    ASSERT_TRUE(doc.HasMember("input")) << "Missing input field";
    ASSERT_TRUE(doc["input"].IsArray()) << "Input field is not an array";
    ASSERT_EQ(doc["input"].Size(), 1U);
    ASSERT_STREQ(doc["input"][0].GetString(), "embed with voyage");
    // dimension parameter
    ASSERT_FALSE(doc.HasMember("output_dimension"));
    // Re-init with a model that accepts "output_dimension" and rebuild the body.
    config.model_name = "voyage-3.5";
    adapter.init(config);
    st = adapter.build_embedding_request(inputs, request_body);
    ASSERT_TRUE(st.ok());
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
    ASSERT_TRUE(doc.HasMember("model"));
    ASSERT_STREQ(doc["model"].GetString(), "voyage-3.5") << request_body;
    ASSERT_TRUE(doc.HasMember("output_dimension")) << request_body;
}
// VoyageAIAdapter extracts each data[i].embedding vector in order.
TEST(EMBED_TEST, voyageai_adapter_parse_embedding_response) {
    const std::string payload = R"({
"object": "list",
"data": [
{
"embedding": [0.1, 0.2, 0.3],
"index": 0
},
{
"embedding": [0.4, 0.5],
"index": 1
}
],
"model": "voyage-3.5",
"usage": {
"total_tokens": 10
}
})";
    VoyageAIAdapter voyage;
    std::vector<std::vector<float>> vectors;
    Status parse_status = voyage.parse_embedding_response(payload, vectors);
    ASSERT_TRUE(parse_status.ok());

    const std::vector<std::vector<float>> expected = {{0.1F, 0.2F, 0.3F}, {0.4F, 0.5F}};
    ASSERT_EQ(vectors.size(), expected.size());
    for (size_t i = 0; i < expected.size(); ++i) {
        ASSERT_EQ(vectors[i].size(), expected[i].size());
    }
    for (size_t i = 0; i < expected.size(); ++i) {
        for (size_t j = 0; j < expected[i].size(); ++j) {
            ASSERT_FLOAT_EQ(vectors[i][j], expected[i][j]);
        }
    }
}
// VoyageAIAdapter must reject malformed responses with a descriptive status.
TEST(EMBED_TEST, voyageai_adapter_parse_error_test) {
    using ::testing::HasSubstr;
    VoyageAIAdapter voyage;
    std::vector<std::vector<float>> vectors;

    // Case 1: the payload lacks its enclosing braces, so it is not valid JSON.
    std::string payload = R"(
"object": "list",
"data": [
{
"embedding": [0.1, 0.2, 0.3],
"index": 0
},
{
"embedding": [0.4, 0.5],
"index": 1
}
],
"model": "voyage-3.5",
"usage": {
"total_tokens": 10
}
)";
    Status parse_status = voyage.parse_embedding_response(payload, vectors);
    ASSERT_FALSE(parse_status.ok());
    std::string message = parse_status.to_string();
    ASSERT_THAT(message, HasSubstr("Failed to parse response"));

    // Case 2: `data` is an object instead of an array.
    payload = R"({
"object": "list",
"data": {
"embedding": [0.1, 0.2, 0.3],
"index": 0
},
"model": "voyage-3.5",
"usage": {
"total_tokens": 10
}
})";
    parse_status = voyage.parse_embedding_response(payload, vectors);
    ASSERT_FALSE(parse_status.ok());
    message = parse_status.to_string();
    ASSERT_THAT(message, HasSubstr("Invalid response format"));

    // Case 3: the data element has `embeddings` rather than the required `embedding`.
    payload = R"({
"object": "list",
"data": [
{
"embeddings": [0.1, 0.2, 0.3],
"index": 0
}
],
"model": "voyage-3.5",
"usage": {
"total_tokens": 10
}
})";
    parse_status = voyage.parse_embedding_response(payload, vectors);
    ASSERT_FALSE(parse_status.ok());
    message = parse_status.to_string();
    ASSERT_THAT(message, HasSubstr("Invalid response format"));
}
// DeepSeekAdapter checks:
// - Authentication headers are still set.
// - build_embedding_request fails; this test expects embedding to be unsupported.
TEST(EMBED_TEST, deepseek_adapter_embedding_request) {
    DeepSeekAdapter adapter;
    TAIResource config;
    config.model_name = "deepseek-embedding";
    config.api_key = "test_deepseek_key";
    adapter.init(config);
    // header test
    MockHttpClient mock_client;
    Status auth_status = adapter.set_authentication(&mock_client);
    ASSERT_TRUE(auth_status.ok());
    curl_slist* headers = mock_client.get();
    ASSERT_TRUE(headers != nullptr);
    EXPECT_STREQ(headers->data, "Authorization: Bearer test_deepseek_key");
    // Guard the second node too so a missing header fails instead of crashing.
    ASSERT_TRUE(headers->next != nullptr);
    EXPECT_STREQ(headers->next->data, "Content-Type: application/json");
    std::vector<std::string> inputs = {"embed with deepseek"};
    std::string request_body;
    Status st = adapter.build_embedding_request(inputs, request_body);
    ASSERT_FALSE(st.ok());
}
// DeepSeekAdapter must refuse to parse an embedding response, even a well-formed
// OpenAI-style one.
TEST(EMBED_TEST, deepseek_adapter_parse_embedding_response) {
    const std::string payload = R"({
"data": [
{
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": 0,
"object": "embedding"
}
],
"model": "deepseek-embedding",
"object": "list",
"usage": {
"prompt_tokens": 4,
"total_tokens": 4
}
})";
    std::vector<std::vector<float>> vectors;
    DeepSeekAdapter deepseek;
    Status parse_status = deepseek.parse_embedding_response(payload, vectors);
    ASSERT_FALSE(parse_status.ok());
}
// MoonShotAdapter checks:
// - Authentication headers are still set.
// - build_embedding_request fails; this test expects embedding to be unsupported.
TEST(EMBED_TEST, moonshot_adapter_embedding_request) {
    MoonShotAdapter adapter;
    TAIResource config;
    config.model_name = "moonshot-embedding";
    config.api_key = "test_moonshot_key";
    adapter.init(config);
    // header test
    MockHttpClient mock_client;
    Status auth_status = adapter.set_authentication(&mock_client);
    ASSERT_TRUE(auth_status.ok());
    curl_slist* headers = mock_client.get();
    ASSERT_TRUE(headers != nullptr);
    EXPECT_STREQ(headers->data, "Authorization: Bearer test_moonshot_key");
    // Guard the second node too so a missing header fails instead of crashing.
    ASSERT_TRUE(headers->next != nullptr);
    EXPECT_STREQ(headers->next->data, "Content-Type: application/json");
    std::vector<std::string> inputs = {"embed with moonshot"};
    std::string request_body;
    Status st = adapter.build_embedding_request(inputs, request_body);
    ASSERT_FALSE(st.ok());
}
// MoonShotAdapter must refuse to parse an embedding response, even a well-formed
// OpenAI-style one.
TEST(EMBED_TEST, moonshot_adapter_parse_embedding_response) {
    const std::string payload = R"({
"id": "embedding-123",
"object": "list",
"data": [
{
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": 0,
"object": "embedding"
}
],
"model": "moonshot-embedding",
"usage": {
"prompt_tokens": 3,
"total_tokens": 3
}
})";
    std::vector<std::vector<float>> vectors;
    MoonShotAdapter moonshot;
    Status parse_status = moonshot.parse_embedding_response(payload, vectors);
    ASSERT_FALSE(parse_status.ok());
}
// MinimaxAdapter checks:
// - Headers are a Bearer token followed by the JSON content type.
// - The body has "model", "type" set to "db", and an array "texts".
TEST(EMBED_TEST, minimax_adapter_embedding_request) {
    MinimaxAdapter adapter;
    TAIResource config;
    config.model_name = "minimax-embedding";
    config.api_key = "test_minimax_key";
    adapter.init(config);
    // header test
    MockHttpClient mock_client;
    Status auth_status = adapter.set_authentication(&mock_client);
    ASSERT_TRUE(auth_status.ok());
    curl_slist* headers = mock_client.get();
    ASSERT_TRUE(headers != nullptr);
    EXPECT_STREQ(headers->data, "Authorization: Bearer test_minimax_key");
    // Guard the second node too so a missing header fails instead of crashing.
    ASSERT_TRUE(headers->next != nullptr);
    EXPECT_STREQ(headers->next->data, "Content-Type: application/json");
    std::vector<std::string> inputs = {"embed with minimax"};
    std::string request_body;
    Status st = adapter.build_embedding_request(inputs, request_body);
    ASSERT_TRUE(st.ok());
    // body test
    rapidjson::Document doc;
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
    // model name
    ASSERT_TRUE(doc.HasMember("model")) << "Missing model field";
    ASSERT_TRUE(doc["model"].IsString()) << "Model field is not a string";
    ASSERT_STREQ(doc["model"].GetString(), "minimax-embedding");
    // type
    ASSERT_TRUE(doc.HasMember("type")) << "Missing type field";
    ASSERT_TRUE(doc["type"].IsString());
    ASSERT_STREQ(doc["type"].GetString(), "db");
    // input
    ASSERT_TRUE(doc.HasMember("texts")) << "Missing texts field";
    ASSERT_TRUE(doc["texts"].IsArray()) << "Texts field is not an array";
    ASSERT_EQ(doc["texts"].Size(), 1U);
    ASSERT_STREQ(doc["texts"][0].GetString(), "embed with minimax");
}
} // namespace doris::vectorized