blob: 2074697614252a6c5cd44a161817b07673db8a80 [file]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "exprs/function/ai/embed.h"
#include <curl/curl.h>
#include <fmt/format.h>
#include <gen_cpp/PaloInternalService_types.h>
#include <gmock/gmock-matchers.h>
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include "core/data_type/data_type_jsonb.h"
#include "core/data_type/data_type_number.h"
#include "core/value/jsonb_value.h"
#include "exprs/function/ai/ai_adapter.h"
#include "io/fs/obj_storage_client.h"
#include "testutil/column_helper.h"
#include "testutil/mock/mock_runtime_state.h"
namespace doris {
// Test-only HttpClient subclass that lets tests read the curl header list
// populated when an adapter's set_authentication() is called on it.
class MockHttpClient : public HttpClient {
public:
    // Returns the head of the header list inherited from HttpClient.
    curl_slist* get() { return this->_header_list; }
};
// ObjStorageClient test double. Every storage operation is a stub that returns
// either a default-constructed response or ObjectStorageResponse::OK().
// generate_presigned_url() additionally records its arguments and returns a
// deterministic fake URL so tests can inspect what was requested.
class MockEmbedObjStorageClient : public io::ObjStorageClient {
public:
    // ---- write path stubs ----
    io::ObjectStorageUploadResponse create_multipart_upload(
            const io::ObjectStoragePathOptions& /*opts*/) override {
        return {};
    }

    io::ObjectStorageResponse put_object(const io::ObjectStoragePathOptions& /*opts*/,
                                         std::string_view /*stream*/) override {
        return io::ObjectStorageResponse::OK();
    }

    io::ObjectStorageUploadResponse upload_part(const io::ObjectStoragePathOptions& /*opts*/,
                                                std::string_view /*stream*/,
                                                int /*part_num*/) override {
        return {};
    }

    io::ObjectStorageResponse complete_multipart_upload(
            const io::ObjectStoragePathOptions& /*opts*/,
            const std::vector<io::ObjectCompleteMultiPart>& /*completed_parts*/) override {
        return io::ObjectStorageResponse::OK();
    }

    // ---- read path stubs ----
    io::ObjectStorageHeadResponse head_object(
            const io::ObjectStoragePathOptions& /*opts*/) override {
        return {};
    }

    io::ObjectStorageResponse get_object(const io::ObjectStoragePathOptions& /*opts*/,
                                         void* /*buffer*/, size_t /*offset*/,
                                         size_t /*bytes_read*/,
                                         size_t* /*size_return*/) override {
        return io::ObjectStorageResponse::OK();
    }

    io::ObjectStorageResponse list_objects(const io::ObjectStoragePathOptions& /*opts*/,
                                           std::vector<io::FileInfo>* /*files*/) override {
        return io::ObjectStorageResponse::OK();
    }

    // ---- delete path stubs ----
    io::ObjectStorageResponse delete_objects(const io::ObjectStoragePathOptions& /*opts*/,
                                             std::vector<std::string> /*objs*/) override {
        return io::ObjectStorageResponse::OK();
    }

    io::ObjectStorageResponse delete_object(
            const io::ObjectStoragePathOptions& /*opts*/) override {
        return io::ObjectStorageResponse::OK();
    }

    io::ObjectStorageResponse delete_objects_recursively(
            const io::ObjectStoragePathOptions& /*opts*/) override {
        return io::ObjectStorageResponse::OK();
    }

    // Remembers the most recent call's arguments in the last_* members and
    // returns "mock-s3://<bucket>/<key>?ttl=<expiration_secs>".
    std::string generate_presigned_url(const io::ObjectStoragePathOptions& opts,
                                       int64_t expiration_secs,
                                       const S3ClientConf& conf) override {
        std::string presigned_url =
                fmt::format("mock-s3://{}/{}?ttl={}", opts.bucket, opts.key, expiration_secs);
        last_conf = conf;
        last_expiration_secs = expiration_secs;
        last_opts = opts;
        return presigned_url;
    }

    // Arguments captured by the latest generate_presigned_url() call.
    io::ObjectStoragePathOptions last_opts;
    int64_t last_expiration_secs = 0;
    S3ClientConf last_conf;
};
// MockAdapter variant that records how many media items each multimodal
// embedding request carries, so tests can verify how rows were batched.
class CountingMultimodalMockAdapter : public MockAdapter {
public:
    // Checks that the three parallel input vectors have the same length,
    // appends that length to batch_sizes and emits an empty JSON object body.
    Status build_multimodal_embedding_request(const std::vector<MultimodalType>& media_types,
                                              const std::vector<std::string>& media_urls,
                                              const std::vector<std::string>& media_content_types,
                                              std::string& request_body) const override {
        const size_t items_in_batch = media_urls.size();
        EXPECT_EQ(media_types.size(), items_in_batch);
        EXPECT_EQ(media_content_types.size(), items_in_batch);
        batch_sizes.emplace_back(items_in_batch);
        request_body.assign("{}");
        return Status::OK();
    }

    // One entry per request built; mutable because the override is const.
    mutable std::vector<size_t> batch_sizes;
};
// MockAdapter variant that records how many text inputs each embedding request
// carries, so tests can verify how rows were batched.
class CountingTextMockAdapter : public MockAdapter {
public:
    // Appends the number of inputs to batch_sizes and emits an empty JSON
    // object as the request body.
    Status build_embedding_request(const std::vector<std::string>& inputs,
                                   std::string& request_body) const override {
        const size_t inputs_in_batch = inputs.size();
        batch_sizes.emplace_back(inputs_in_batch);
        request_body.assign("{}");
        return Status::OK();
    }

    // One entry per request built; mutable because the override is const.
    mutable std::vector<size_t> batch_sizes;
};
// Builds a ColumnString holding one JSONB-encoded value per entry of
// json_rows. A row that fails to parse is reported as a non-fatal test
// failure; whatever the parser left in the value is still inserted.
static ColumnString::MutablePtr create_jsonb_column(const std::vector<std::string>& json_rows) {
    auto jsonb_column = ColumnString::create();
    for (size_t row = 0; row < json_rows.size(); ++row) {
        JsonBinaryValue parsed;
        Status parse_status = parsed.from_json_string(json_rows[row]);
        EXPECT_TRUE(parse_status.ok()) << parse_status.to_string();
        jsonb_column->insert_data(parsed.value(), parsed.size());
    }
    return jsonb_column;
}
// Verifies that col_array holds row_count arrays of exactly five floats each,
// with values 0, 1, 2, 3, 4 in every row. Uses fatal gtest assertions, which
// return from this helper only (not from the calling test).
static void assert_mock_embedding_column(const ColumnArray& col_array, size_t row_count) {
    // Length of the embedding expected in every row.
    constexpr size_t kEmbeddingDim = 5;

    const auto& row_ends = col_array.get_offsets();
    ASSERT_EQ(row_ends.size(), row_count);

    const auto& nullable_values = assert_cast<const ColumnNullable&>(col_array.get_data());
    const auto& values =
            assert_cast<const ColumnFloat32&>(*nullable_values.get_nested_column_ptr());
    ASSERT_EQ(values.size(), row_count * kEmbeddingDim);

    for (size_t row = 0; row < row_count; ++row) {
        const size_t row_begin = row * kEmbeddingDim;
        // Offsets are cumulative end positions, so row r ends at (r + 1) * dim.
        ASSERT_EQ(row_ends[row], row_begin + kEmbeddingDim);
        for (size_t dim = 0; dim < kEmbeddingDim; ++dim) {
            ASSERT_FLOAT_EQ(values.get_element(row_begin + dim), static_cast<float>(dim));
        }
    }
}
// build_prompt() for row 0 must produce exactly the text stored in the second
// argument column.
TEST(EMBED_TEST, embed_function_build_test) {
    std::vector<std::string> resource_names = {"resource_name"};
    std::vector<std::string> prompt_texts = {"this is a test prompt"};

    // Column 0 carries the resource name, column 1 the text to embed.
    Block block;
    block.insert({ColumnHelper::create_column<DataTypeString>(resource_names),
                  std::make_shared<DataTypeString>(), "resource"});
    block.insert({ColumnHelper::create_column<DataTypeString>(prompt_texts),
                  std::make_shared<DataTypeString>(), "text"});
    ColumnNumbers arguments = {0, 1};

    FunctionEmbed function;
    std::string built_prompt;
    Status build_status = function.build_prompt(block, arguments, 0, built_prompt);
    ASSERT_TRUE(build_status.ok());
    ASSERT_EQ(built_prompt, "this is a test prompt");
}
// Single-row text EMBED through execute_impl() with the mock runtime state:
// the result column must contain one mock embedding (five floats, 0..4).
TEST(EMBED_TEST, embed_function_test) {
    auto runtime_state = std::make_unique<MockRuntimeState>();
    auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});

    std::vector<std::string> resources = {"mock_resource"};
    std::vector<std::string> texts = {"test input"};
    auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
    auto col_text = ColumnHelper::create_column<DataTypeString>(texts);

    // Columns 0 and 1 are the arguments; column 2 receives the result.
    Block block;
    block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
    block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
    block.insert(
            {nullptr,
             std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat32>())),
             "result"});
    ColumnNumbers arguments = {0, 1};
    size_t result_idx = 2;

    auto embed_func = FunctionEmbed::create();
    Status exec_status =
            embed_func->execute_impl(ctx.get(), block, arguments, result_idx, texts.size());
    ASSERT_TRUE(exec_status.ok()) << exec_status.to_string();

    // Shared helper checks offsets, nested size and the 0..4 values per row.
    const auto& col_array =
            assert_cast<const ColumnArray&>(*block.get_by_position(result_idx).column);
    assert_mock_embedding_column(col_array, texts.size());
}
// Two text rows through execute_impl(): each row must receive its own mock
// embedding in the result column.
TEST(EMBED_TEST, embed_function_text_multi_rows) {
    auto mock_state = std::make_unique<MockRuntimeState>();
    auto fn_ctx = FunctionContext::create_context(mock_state.get(), {}, {});

    std::vector<std::string> resource_names(2, "mock_resource");
    std::vector<std::string> input_texts = {"test input 1", "test input 2"};

    Block block;
    block.insert({ColumnHelper::create_column<DataTypeString>(resource_names),
                  std::make_shared<DataTypeString>(), "resource"});
    block.insert({ColumnHelper::create_column<DataTypeString>(input_texts),
                  std::make_shared<DataTypeString>(), "text"});
    // Result slot: array of nullable float32, filled in by the function.
    auto embedding_type =
            std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat32>()));
    block.insert({nullptr, embedding_type, "result"});

    ColumnNumbers arguments = {0, 1};
    size_t result_pos = 2;
    auto embed_fn = FunctionEmbed::create();
    Status run_status =
            embed_fn->execute_impl(fn_ctx.get(), block, arguments, result_pos, input_texts.size());
    ASSERT_TRUE(run_status.ok()) << run_status.to_string();

    const auto& embeddings =
            assert_cast<const ColumnArray&>(*block.get_by_position(result_pos).column);
    assert_mock_embedding_column(embeddings, input_texts.size());
}
// Multimodal EMBED over https URIs with the presigned-URL TTL option set to 0:
// an image, a video and an audio row must each get a mock embedding.
TEST(EMBED_TEST, embed_function_multimodal_direct_url) {
    TQueryOptions query_options = create_fake_query_options();
    query_options.__set_file_presigned_url_ttl_seconds(0);
    auto query_ctx = MockQueryContext::create(TUniqueId(), ExecEnv::GetInstance(), query_options);
    query_ctx->set_mock_ai_resource();
    TQueryGlobals query_globals;
    RuntimeState runtime_state(TUniqueId(), 0, query_options, query_globals, nullptr,
                               query_ctx.get());
    auto fn_ctx = FunctionContext::create_context(&runtime_state, {}, {});

    std::vector<std::string> resource_names(3, "mock_resource");
    std::vector<std::string> file_json_rows = {
            R"({"content_type":"image/png","uri":"https://example.com/a.png"})",
            R"({"content_type":"video/mp4","uri":"https://example.com/b.mp4"})",
            R"({"content_type":"audio/mpeg","uri":"https://example.com/c.mp3"})"};

    Block block;
    block.insert({ColumnHelper::create_column<DataTypeString>(resource_names),
                  std::make_shared<DataTypeString>(), "resource"});
    block.insert({create_jsonb_column(file_json_rows), std::make_shared<DataTypeJsonb>(), "file"});
    auto embedding_type =
            std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat32>()));
    block.insert({nullptr, embedding_type, "result"});

    ColumnNumbers arguments = {0, 1};
    size_t result_pos = 2;
    auto embed_fn = FunctionEmbed::create();
    Status run_status = embed_fn->execute_impl(fn_ctx.get(), block, arguments, result_pos,
                                               file_json_rows.size());
    ASSERT_TRUE(run_status.ok()) << run_status.to_string();

    const auto& embeddings =
            assert_cast<const ColumnArray&>(*block.get_by_position(result_pos).column);
    assert_mock_embedding_column(embeddings, file_json_rows.size());
}
// With the mock runtime state, three multimodal rows must reach the adapter as
// a single request of size 3.
TEST(EMBED_TEST, embed_function_multimodal_batch_request) {
    auto mock_state = std::make_unique<MockRuntimeState>();
    auto fn_ctx = FunctionContext::create_context(mock_state.get(), {}, {});

    std::vector<std::string> resource_names(3, "mock_resource");
    std::vector<std::string> file_json_rows = {
            R"({"content_type":"image/png","uri":"https://example.com/a.png"})",
            R"({"content_type":"video/mp4","uri":"https://example.com/b.mp4"})",
            R"({"content_type":"audio/mpeg","uri":"https://example.com/c.mp3"})"};

    Block block;
    block.insert({ColumnHelper::create_column<DataTypeString>(resource_names),
                  std::make_shared<DataTypeString>(), "resource"});
    block.insert({create_jsonb_column(file_json_rows), std::make_shared<DataTypeJsonb>(), "file"});
    auto embedding_type =
            std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat32>()));
    block.insert({nullptr, embedding_type, "result"});

    // Inject a counting adapter; keep the derived pointer to read batch_sizes.
    TAIResource config;
    config.provider_type = "MOCK";
    auto counting_adapter = std::make_shared<CountingMultimodalMockAdapter>();
    std::shared_ptr<AIAdapter> adapter = counting_adapter;
    adapter->init(config);

    ColumnNumbers arguments = {0, 1};
    size_t result_pos = 2;
    FunctionEmbed embed_fn;
    Status run_status = embed_fn.execute_with_adapter(fn_ctx.get(), block, arguments, result_pos,
                                                      file_json_rows.size(), config, adapter);
    ASSERT_TRUE(run_status.ok()) << run_status.to_string();
    EXPECT_THAT(counting_adapter->batch_sizes, ::testing::ElementsAre(3));

    const auto& embeddings =
            assert_cast<const ColumnArray&>(*block.get_by_position(result_pos).column);
    assert_mock_embedding_column(embeddings, file_json_rows.size());
}
// With embed_max_batch_size set to 2, three multimodal rows must be sent to
// the adapter as two requests of sizes 2 and 1.
TEST(EMBED_TEST, embed_function_multimodal_batch_split_by_session_variable) {
    TQueryOptions query_options = create_fake_query_options();
    query_options.__set_embed_max_batch_size(2);
    auto query_ctx = MockQueryContext::create(TUniqueId(), ExecEnv::GetInstance(), query_options);
    TQueryGlobals query_globals;
    RuntimeState runtime_state(TUniqueId(), 0, query_options, query_globals, nullptr,
                               query_ctx.get());
    auto fn_ctx = FunctionContext::create_context(&runtime_state, {}, {});

    std::vector<std::string> resource_names(3, "mock_resource");
    std::vector<std::string> file_json_rows = {
            R"({"content_type":"image/png","uri":"https://example.com/a.png"})",
            R"({"content_type":"image/png","uri":"https://example.com/b.png"})",
            R"({"content_type":"image/png","uri":"https://example.com/c.png"})"};

    Block block;
    block.insert({ColumnHelper::create_column<DataTypeString>(resource_names),
                  std::make_shared<DataTypeString>(), "resource"});
    block.insert({create_jsonb_column(file_json_rows), std::make_shared<DataTypeJsonb>(), "file"});
    auto embedding_type =
            std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat32>()));
    block.insert({nullptr, embedding_type, "result"});

    // Inject a counting adapter; keep the derived pointer to read batch_sizes.
    TAIResource config;
    config.provider_type = "MOCK";
    auto counting_adapter = std::make_shared<CountingMultimodalMockAdapter>();
    std::shared_ptr<AIAdapter> adapter = counting_adapter;
    adapter->init(config);

    ColumnNumbers arguments = {0, 1};
    size_t result_pos = 2;
    FunctionEmbed embed_fn;
    Status run_status = embed_fn.execute_with_adapter(fn_ctx.get(), block, arguments, result_pos,
                                                      file_json_rows.size(), config, adapter);
    ASSERT_TRUE(run_status.ok()) << run_status.to_string();
    EXPECT_THAT(counting_adapter->batch_sizes, ::testing::ElementsAre(2, 1));

    const auto& embeddings =
            assert_cast<const ColumnArray&>(*block.get_by_position(result_pos).column);
    assert_mock_embedding_column(embeddings, file_json_rows.size());
}
// With embed_max_batch_size set to 2, three text rows must be sent to the
// adapter as two requests of sizes 2 and 1.
TEST(EMBED_TEST, embed_function_text_batch_split_by_session_variable) {
    TQueryOptions query_options = create_fake_query_options();
    query_options.__set_embed_max_batch_size(2);
    auto query_ctx = MockQueryContext::create(TUniqueId(), ExecEnv::GetInstance(), query_options);
    TQueryGlobals query_globals;
    RuntimeState runtime_state(TUniqueId(), 0, query_options, query_globals, nullptr,
                               query_ctx.get());
    auto fn_ctx = FunctionContext::create_context(&runtime_state, {}, {});

    std::vector<std::string> resource_names(3, "mock_resource");
    std::vector<std::string> input_texts = {"text-a", "text-b", "text-c"};

    Block block;
    block.insert({ColumnHelper::create_column<DataTypeString>(resource_names),
                  std::make_shared<DataTypeString>(), "resource"});
    block.insert({ColumnHelper::create_column<DataTypeString>(input_texts),
                  std::make_shared<DataTypeString>(), "text"});
    auto embedding_type =
            std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat32>()));
    block.insert({nullptr, embedding_type, "result"});

    // Inject a counting adapter; keep the derived pointer to read batch_sizes.
    TAIResource config;
    config.provider_type = "MOCK";
    auto counting_adapter = std::make_shared<CountingTextMockAdapter>();
    std::shared_ptr<AIAdapter> adapter = counting_adapter;
    adapter->init(config);

    ColumnNumbers arguments = {0, 1};
    size_t result_pos = 2;
    FunctionEmbed embed_fn;
    Status run_status = embed_fn.execute_with_adapter(fn_ctx.get(), block, arguments, result_pos,
                                                      input_texts.size(), config, adapter);
    ASSERT_TRUE(run_status.ok()) << run_status.to_string();
    EXPECT_THAT(counting_adapter->batch_sizes, ::testing::ElementsAre(2, 1));

    const auto& embeddings =
            assert_cast<const ColumnArray&>(*block.get_by_position(result_pos).column);
    assert_mock_embedding_column(embeddings, input_texts.size());
}
// Multimodal EMBED over an s3:// URI: the function must ask the object storage
// client for a presigned URL, passing the bucket/key parsed from the URI, the
// TTL from the query options, and the credentials from the JSON row.
TEST(EMBED_TEST, embed_function_multimodal_s3_presigned_url) {
    TQueryOptions query_options = create_fake_query_options();
    query_options.__set_file_presigned_url_ttl_seconds(123);
    auto query_ctx = MockQueryContext::create(TUniqueId(), ExecEnv::GetInstance(), query_options);
    query_ctx->set_mock_ai_resource();
    TQueryGlobals query_globals;
    RuntimeState runtime_state(TUniqueId(), 0, query_options, query_globals, nullptr,
                               query_ctx.get());
    auto ctx = FunctionContext::create_context(&runtime_state, {}, {});

    auto mock_client = std::make_shared<MockEmbedObjStorageClient>();
    // The client creator is process-wide state. Clear it on every exit path
    // (including an exception escaping execute_impl) so the mock cannot leak
    // into tests that run afterwards.
    struct ClientCreatorReset {
        ~ClientCreatorReset() { S3ClientFactory::instance().clear_client_creator_for_test(); }
    } client_creator_reset;
    S3ClientFactory::instance().set_client_creator_for_test(
            [mock_client](const S3ClientConf&) { return mock_client; });

    std::vector<std::string> resources = {"mock_resource"};
    std::vector<std::string> file_json_rows = {R"({
"content_type":"image/png",
"uri":"s3://test-bucket/path/to/image.png",
"endpoint":"cos.ap-beijing.myqcloud.com",
"region":"ap-beijing",
"ak":"test-ak",
"sk":"test-sk",
"role_arn":"test-role",
"external_id":"test-external-id"
})"};
    auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
    auto col_file = create_jsonb_column(file_json_rows);

    Block block;
    block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
    block.insert({std::move(col_file), std::make_shared<DataTypeJsonb>(), "file"});
    block.insert(
            {nullptr,
             std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat32>())),
             "result"});
    ColumnNumbers arguments = {0, 1};
    size_t result_idx = 2;

    auto embed_func = FunctionEmbed::create();
    Status exec_status = embed_func->execute_impl(ctx.get(), block, arguments, result_idx,
                                                  file_json_rows.size());
    ASSERT_TRUE(exec_status.ok()) << exec_status.to_string();

    const auto& col_array =
            assert_cast<const ColumnArray&>(*block.get_by_position(result_idx).column);
    assert_mock_embedding_column(col_array, file_json_rows.size());

    // Arguments the mock recorded from generate_presigned_url().
    ASSERT_EQ(mock_client->last_opts.bucket, "test-bucket");
    ASSERT_EQ(mock_client->last_opts.key, "path/to/image.png");
    ASSERT_EQ(mock_client->last_expiration_secs, 123);
    ASSERT_EQ(mock_client->last_conf.endpoint, "cos.ap-beijing.myqcloud.com");
    ASSERT_EQ(mock_client->last_conf.region, "ap-beijing");
    ASSERT_EQ(mock_client->last_conf.ak, "test-ak");
    ASSERT_EQ(mock_client->last_conf.sk, "test-sk");
    ASSERT_EQ(mock_client->last_conf.role_arn, "test-role");
    ASSERT_EQ(mock_client->last_conf.external_id, "test-external-id");
}
// An s3:// file row without an "endpoint" field must be rejected with an error
// that names the missing field.
TEST(EMBED_TEST, embed_function_multimodal_s3_missing_endpoint) {
    auto mock_state = std::make_unique<MockRuntimeState>();
    auto fn_ctx = FunctionContext::create_context(mock_state.get(), {}, {});

    std::vector<std::string> resource_names = {"mock_resource"};
    std::vector<std::string> file_json_rows = {R"({
"content_type":"image/png",
"uri":"s3://test-bucket/path/to/image.png",
"region":"ap-beijing"
})"};

    Block block;
    block.insert({ColumnHelper::create_column<DataTypeString>(resource_names),
                  std::make_shared<DataTypeString>(), "resource"});
    block.insert({create_jsonb_column(file_json_rows), std::make_shared<DataTypeJsonb>(), "file"});
    auto embedding_type =
            std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat32>()));
    block.insert({nullptr, embedding_type, "result"});

    ColumnNumbers arguments = {0, 1};
    size_t result_pos = 2;
    auto embed_fn = FunctionEmbed::create();
    Status run_status = embed_fn->execute_impl(fn_ctx.get(), block, arguments, result_pos,
                                               file_json_rows.size());
    ASSERT_FALSE(run_status.ok());
    ASSERT_THAT(run_status.to_string(), ::testing::HasSubstr("field 'endpoint' is required"));
}
// An s3:// file row without a "region" field must be rejected with an error
// that names the missing field.
TEST(EMBED_TEST, embed_function_multimodal_s3_missing_region) {
    auto mock_state = std::make_unique<MockRuntimeState>();
    auto fn_ctx = FunctionContext::create_context(mock_state.get(), {}, {});

    std::vector<std::string> resource_names = {"mock_resource"};
    std::vector<std::string> file_json_rows = {R"({
"content_type":"image/png",
"uri":"s3://test-bucket/path/to/image.png",
"endpoint":"cos.ap-beijing.myqcloud.com"
})"};

    Block block;
    block.insert({ColumnHelper::create_column<DataTypeString>(resource_names),
                  std::make_shared<DataTypeString>(), "resource"});
    block.insert({create_jsonb_column(file_json_rows), std::make_shared<DataTypeJsonb>(), "file"});
    auto embedding_type =
            std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat32>()));
    block.insert({nullptr, embedding_type, "result"});

    ColumnNumbers arguments = {0, 1};
    size_t result_pos = 2;
    auto embed_fn = FunctionEmbed::create();
    Status run_status = embed_fn->execute_impl(fn_ctx.get(), block, arguments, result_pos,
                                               file_json_rows.size());
    ASSERT_FALSE(run_status.ok());
    ASSERT_THAT(run_status.to_string(), ::testing::HasSubstr("field 'region' is required"));
}
// Calling EMBED with only the resource argument must fail with the
// "expects 2 arguments" error.
TEST(EMBED_TEST, embed_function_wrong_argument_count) {
    auto mock_state = std::make_unique<MockRuntimeState>();
    auto fn_ctx = FunctionContext::create_context(mock_state.get(), {}, {});

    std::vector<std::string> resource_names = {"mock_resource"};

    // Only one argument column, followed directly by the result slot.
    Block block;
    block.insert({ColumnHelper::create_column<DataTypeString>(resource_names),
                  std::make_shared<DataTypeString>(), "resource"});
    auto embedding_type =
            std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat32>()));
    block.insert({nullptr, embedding_type, "result"});

    ColumnNumbers arguments = {0};
    size_t result_pos = 1;
    auto embed_fn = FunctionEmbed::create();
    Status run_status = embed_fn->execute_impl(fn_ctx.get(), block, arguments, result_pos,
                                               resource_names.size());
    ASSERT_FALSE(run_status.ok());
    ASSERT_THAT(run_status.to_string(),
                ::testing::HasSubstr("Function EMBED expects 2 arguments"));
}
// Passing an Int32 column as the second argument must fail with the
// "STRING or JSON" type error.
TEST(EMBED_TEST, embed_function_invalid_input_type) {
    auto mock_state = std::make_unique<MockRuntimeState>();
    auto fn_ctx = FunctionContext::create_context(mock_state.get(), {}, {});

    std::vector<std::string> resource_names = {"mock_resource"};
    std::vector<Int32> id_values = {1};

    Block block;
    block.insert({ColumnHelper::create_column<DataTypeString>(resource_names),
                  std::make_shared<DataTypeString>(), "resource"});
    block.insert({ColumnHelper::create_column<DataTypeInt32>(id_values),
                  std::make_shared<DataTypeInt32>(), "id"});
    auto embedding_type =
            std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat32>()));
    block.insert({nullptr, embedding_type, "result"});

    ColumnNumbers arguments = {0, 1};
    size_t result_pos = 2;
    auto embed_fn = FunctionEmbed::create();
    Status run_status = embed_fn->execute_impl(fn_ctx.get(), block, arguments, result_pos,
                                               resource_names.size());
    ASSERT_FALSE(run_status.ok());
    ASSERT_THAT(run_status.to_string(),
                ::testing::HasSubstr(
                        "Function EMBED expects the second argument to be STRING or JSON"));
}
// A file row that has a "uri" but no "content_type" must be rejected with an
// error that names the missing field.
TEST(EMBED_TEST, embed_function_missing_required_json_field) {
    auto mock_state = std::make_unique<MockRuntimeState>();
    auto fn_ctx = FunctionContext::create_context(mock_state.get(), {}, {});

    std::vector<std::string> resource_names = {"mock_resource"};
    std::vector<std::string> file_json_rows = {R"({"uri":"https://example.com/a.png"})"};

    Block block;
    block.insert({ColumnHelper::create_column<DataTypeString>(resource_names),
                  std::make_shared<DataTypeString>(), "resource"});
    block.insert({create_jsonb_column(file_json_rows), std::make_shared<DataTypeJsonb>(), "file"});
    auto embedding_type =
            std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat32>()));
    block.insert({nullptr, embedding_type, "result"});

    ColumnNumbers arguments = {0, 1};
    size_t result_pos = 2;
    auto embed_fn = FunctionEmbed::create();
    Status run_status = embed_fn->execute_impl(fn_ctx.get(), block, arguments, result_pos,
                                               file_json_rows.size());
    ASSERT_FALSE(run_status.ok());
    ASSERT_THAT(run_status.to_string(), ::testing::HasSubstr("field 'content_type' is required"));
}
// A file row whose content_type is "text/plain" must be rejected with the
// "Unsupported content_type" error.
TEST(EMBED_TEST, embed_function_unsupported_content_type) {
    auto mock_state = std::make_unique<MockRuntimeState>();
    auto fn_ctx = FunctionContext::create_context(mock_state.get(), {}, {});

    std::vector<std::string> resource_names = {"mock_resource"};
    std::vector<std::string> file_json_rows = {
            R"({"content_type":"text/plain","uri":"https://example.com/a.txt"})"};

    Block block;
    block.insert({ColumnHelper::create_column<DataTypeString>(resource_names),
                  std::make_shared<DataTypeString>(), "resource"});
    block.insert({create_jsonb_column(file_json_rows), std::make_shared<DataTypeJsonb>(), "file"});
    auto embedding_type =
            std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat32>()));
    block.insert({nullptr, embedding_type, "result"});

    ColumnNumbers arguments = {0, 1};
    size_t result_pos = 2;
    auto embed_fn = FunctionEmbed::create();
    Status run_status = embed_fn->execute_impl(fn_ctx.get(), block, arguments, result_pos,
                                               file_json_rows.size());
    ASSERT_FALSE(run_status.ok());
    ASSERT_THAT(run_status.to_string(),
                ::testing::HasSubstr("Unsupported content_type for EMBED"));
}
// LocalAdapter: the first header must be the JSON content type, and the
// request body must carry the model name plus the input text, either as a
// plain string or as a one-element array.
TEST(EMBED_TEST, local_adapter_embedding_request) {
    TAIResource config;
    config.model_name = "local-embedding-model";
    config.dimensions = 1536;
    LocalAdapter adapter;
    adapter.init(config);

    // --- headers ---
    MockHttpClient http_client;
    Status auth_status = adapter.set_authentication(&http_client);
    ASSERT_TRUE(auth_status.ok());
    curl_slist* header_list = http_client.get();
    ASSERT_TRUE(header_list != nullptr);
    EXPECT_STREQ(header_list->data, "Content-Type: application/json");

    // --- body ---
    std::vector<std::string> inputs = {"test sentence for embedding"};
    std::string request_body;
    Status build_status = adapter.build_embedding_request(inputs, request_body);
    ASSERT_TRUE(build_status.ok());

    rapidjson::Document doc;
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";

    // model name
    ASSERT_TRUE(doc.HasMember("model")) << "Missing model field";
    const auto& model_field = doc["model"];
    ASSERT_TRUE(model_field.IsString()) << "Model field is not a string";
    ASSERT_STREQ(model_field.GetString(), "local-embedding-model");

    // input: both the string form and the array form are accepted
    ASSERT_TRUE(doc.HasMember("input")) << "Missing input field";
    const auto& input_field = doc["input"];
    const bool input_is_string = input_field.IsString();
    ASSERT_TRUE(input_is_string || input_field.IsArray())
            << "Input field is not a string or array";
    if (!input_is_string) {
        ASSERT_EQ(input_field.Size(), 1);
        ASSERT_STREQ(input_field[0].GetString(), "test sentence for embedding");
    } else {
        ASSERT_STREQ(input_field.GetString(), "test sentence for embedding");
    }
}
// LocalAdapter must parse three response layouts: a "data" list of embedding
// objects, a single top-level "embedding" array, and an "embeddings" array of
// arrays.
TEST(EMBED_TEST, local_adapter_parse_embedding_response) {
    LocalAdapter adapter;

    // Layout 1: {"data": [{"embedding": [...]}, ...]}
    {
        std::string data_list_resp = R"({
"object": "list",
"data": [
{
"object": "embedding",
"embedding": [0.1, 0.2, 0.3],
"index": 0
},
{
"object": "embedding",
"embedding": [0.4, 0.5],
"index": 1
}
],
"model": "mxbai-embed-large",
"usage": {
"prompt_tokens": 8,
"total_tokens": 8
}
})";
        std::vector<std::vector<float>> embeddings;
        Status parse_status = adapter.parse_embedding_response(data_list_resp, embeddings);
        ASSERT_TRUE(parse_status.ok());
        ASSERT_EQ(embeddings.size(), 2);
        const auto& first = embeddings[0];
        const auto& second = embeddings[1];
        ASSERT_EQ(first.size(), 3);
        ASSERT_EQ(second.size(), 2);
        ASSERT_FLOAT_EQ(first[0], 0.1F);
        ASSERT_FLOAT_EQ(first[1], 0.2F);
        ASSERT_FLOAT_EQ(first[2], 0.3F);
        ASSERT_FLOAT_EQ(second[0], 0.4F);
        ASSERT_FLOAT_EQ(second[1], 0.5F);
    }

    // Layout 2: {"embedding": [...]}
    {
        std::string single_resp = R"({
"embedding": [0.1, 0.2, 0.3]
})";
        std::vector<std::vector<float>> embeddings;
        Status parse_status = adapter.parse_embedding_response(single_resp, embeddings);
        ASSERT_TRUE(parse_status.ok()) << "Format 2 failed: " << parse_status.to_string();
        ASSERT_EQ(embeddings.size(), 1);
        const auto& only = embeddings[0];
        ASSERT_EQ(only.size(), 3);
        ASSERT_FLOAT_EQ(only[0], 0.1F);
        ASSERT_FLOAT_EQ(only[2], 0.3F);
    }

    // Layout 3: {"embeddings": [[...]]}
    {
        std::string nested_resp = R"({
"embeddings": [[0.6, 0.7]]
})";
        std::vector<std::vector<float>> embeddings;
        Status parse_status = adapter.parse_embedding_response(nested_resp, embeddings);
        ASSERT_TRUE(parse_status.ok()) << "Format 3 failed: " << parse_status.to_string();
        ASSERT_EQ(embeddings.size(), 1);
        const auto& only = embeddings[0];
        ASSERT_EQ(only.size(), 2);
        ASSERT_FLOAT_EQ(only[0], 0.6F);
        ASSERT_FLOAT_EQ(only[1], 0.7F);
    }
}
// OpenAIAdapter: checks the authentication headers and the embedding request
// body. "dimensions" must be absent for model "text-embedding-ada-002" and
// present (equal to the configured value) for model "text-embedding-3".
TEST(EMBED_TEST, openai_adapter_embedding_request) {
    OpenAIAdapter adapter;
    TAIResource config;
    config.model_name = "text-embedding-ada-002";
    config.api_key = "test_openai_key";
    config.dimensions = 1536;
    adapter.init(config);

    // header test
    MockHttpClient mock_client;
    Status auth_status = adapter.set_authentication(&mock_client);
    ASSERT_TRUE(auth_status.ok()) << auth_status.to_string();
    // Guard each list node before dereferencing it so a missing header fails
    // the test instead of crashing the test binary.
    curl_slist* headers = mock_client.get();
    ASSERT_TRUE(headers != nullptr) << "No request headers were set";
    EXPECT_STREQ(headers->data, "Authorization: Bearer test_openai_key");
    ASSERT_TRUE(headers->next != nullptr) << "Content-Type header is missing";
    EXPECT_STREQ(headers->next->data, "Content-Type: application/json");

    std::vector<std::string> inputs = {"embed this text"};
    std::string request_body;
    Status st = adapter.build_embedding_request(inputs, request_body);
    ASSERT_TRUE(st.ok()) << st.to_string();

    // body test
    rapidjson::Document doc;
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";

    // model name
    ASSERT_TRUE(doc.HasMember("model")) << "Missing model field";
    ASSERT_TRUE(doc["model"].IsString()) << "Model field is not a string";
    ASSERT_STREQ(doc["model"].GetString(), "text-embedding-ada-002");

    // input
    ASSERT_TRUE(doc.HasMember("input")) << "Missing input field";
    ASSERT_TRUE(doc["input"].IsArray()) << "Input field is not an array";
    ASSERT_EQ(doc["input"].Size(), 1U);
    ASSERT_TRUE(doc["input"][0].IsString()) << "Input element is not a string";
    ASSERT_STREQ(doc["input"][0].GetString(), "embed this text");

    // should not support dimensions param
    ASSERT_FALSE(doc.HasMember("dimensions"));

    // Re-initialise with a model for which the dimensions param is emitted.
    config.model_name = "text-embedding-3";
    adapter.init(config);
    st = adapter.build_embedding_request(inputs, request_body);
    ASSERT_TRUE(st.ok()) << st.to_string();
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
    ASSERT_TRUE(doc.HasMember("dimensions")) << request_body;
    ASSERT_TRUE(doc["dimensions"].IsInt()) << "Dimensions is not an integer";
    ASSERT_EQ(doc["dimensions"].GetInt(), 1536);
}
// OpenAIAdapter must turn a "data" list of embedding objects into one float
// vector per entry, preserving order and per-entry length.
TEST(EMBED_TEST, openai_adapter_parse_embedding_response) {
    std::string response_json = R"({
"object": "list",
"data": [
{
"object": "embedding",
"embedding": [0.1, 0.2, 0.3],
"index": 0
},
{
"object": "embedding",
"embedding": [0.4, 0.5],
"index": 1
}
],
"model": "text-embedding-ada-002",
"usage": {
"prompt_tokens": 8,
"total_tokens": 8
}
})";

    OpenAIAdapter adapter;
    std::vector<std::vector<float>> embeddings;
    Status parse_status = adapter.parse_embedding_response(response_json, embeddings);
    ASSERT_TRUE(parse_status.ok());

    ASSERT_EQ(embeddings.size(), 2);
    const auto& first = embeddings[0];
    const auto& second = embeddings[1];
    ASSERT_EQ(first.size(), 3);
    ASSERT_EQ(second.size(), 2);
    ASSERT_FLOAT_EQ(first[0], 0.1F);
    ASSERT_FLOAT_EQ(first[1], 0.2F);
    ASSERT_FLOAT_EQ(first[2], 0.3F);
    ASSERT_FLOAT_EQ(second[0], 0.4F);
    ASSERT_FLOAT_EQ(second[1], 0.5F);
}
// ZhipuAdapter: checks the authentication headers and that "dimensions" is
// absent from the request for model "embedding-2" but present (equal to the
// configured value) for model "embedding-3".
TEST(EMBED_TEST, zhipu_embedding_request) {
    ZhipuAdapter adapter;
    TAIResource config;
    config.model_name = "embedding-2";
    config.api_key = "test_zhipu_key";
    config.dimensions = 1024;
    adapter.init(config);

    // header test
    MockHttpClient mock_client;
    Status auth_status = adapter.set_authentication(&mock_client);
    ASSERT_TRUE(auth_status.ok()) << auth_status.to_string();
    // Guard each list node before dereferencing it so a missing header fails
    // the test instead of crashing the test binary.
    curl_slist* headers = mock_client.get();
    ASSERT_TRUE(headers != nullptr) << "No request headers were set";
    EXPECT_STREQ(headers->data, "Authorization: Bearer test_zhipu_key");
    ASSERT_TRUE(headers->next != nullptr) << "Content-Type header is missing";
    EXPECT_STREQ(headers->next->data, "Content-Type: application/json");

    std::vector<std::string> inputs = {"embed this text"};
    std::string request_body;
    Status st = adapter.build_embedding_request(inputs, request_body);
    ASSERT_TRUE(st.ok()) << st.to_string();

    // body test
    rapidjson::Document doc;
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
    ASSERT_FALSE(doc.HasMember("dimensions")) << request_body;

    // Re-initialise with a model for which the dimensions param is emitted.
    config.model_name = "embedding-3";
    adapter.init(config);
    st = adapter.build_embedding_request(inputs, request_body);
    ASSERT_TRUE(st.ok()) << st.to_string();
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
    ASSERT_TRUE(doc.HasMember("dimensions")) << request_body;
    ASSERT_TRUE(doc["dimensions"].IsInt()) << "Dimensions is not an integer";
    ASSERT_EQ(doc["dimensions"].GetInt(), config.dimensions);
}
// QwenAdapter: checks the authentication headers and that "dimension"
// (singular) is absent from the request for model "text-embedding-v2" but
// present (equal to the configured value) for the second model name.
TEST(EMBED_TEST, qwen_embedding_request) {
    QwenAdapter adapter;
    TAIResource config;
    config.model_name = "text-embedding-v2";
    config.api_key = "test_qwen_key";
    config.dimensions = 1024;
    adapter.init(config);

    // header test
    MockHttpClient mock_client;
    Status auth_status = adapter.set_authentication(&mock_client);
    ASSERT_TRUE(auth_status.ok()) << auth_status.to_string();
    // Guard each list node before dereferencing it so a missing header fails
    // the test instead of crashing the test binary.
    curl_slist* headers = mock_client.get();
    ASSERT_TRUE(headers != nullptr) << "No request headers were set";
    EXPECT_STREQ(headers->data, "Authorization: Bearer test_qwen_key");
    ASSERT_TRUE(headers->next != nullptr) << "Content-Type header is missing";
    EXPECT_STREQ(headers->next->data, "Content-Type: application/json");

    std::vector<std::string> inputs = {"embed this text"};
    std::string request_body;
    Status st = adapter.build_embedding_request(inputs, request_body);
    ASSERT_TRUE(st.ok()) << st.to_string();

    // body test
    rapidjson::Document doc;
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
    ASSERT_FALSE(doc.HasMember("dimension")) << request_body;

    // NOTE(review): "test-embedding-v4" looks like a typo for
    // "text-embedding-v4"; left unchanged because the adapter's matching rule
    // is not visible here — confirm against QwenAdapter.
    config.model_name = "test-embedding-v4";
    adapter.init(config);
    st = adapter.build_embedding_request(inputs, request_body);
    ASSERT_TRUE(st.ok()) << st.to_string();
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
    ASSERT_TRUE(doc.HasMember("dimension")) << request_body;
    ASSERT_TRUE(doc["dimension"].IsInt()) << "Dimension is not an integer";
    ASSERT_EQ(doc["dimension"].GetInt(), config.dimensions);
}
// Verifies GeminiAdapter request construction for a two-input batch: the
// authentication headers, the per-input entries of the `requests` array
// (model name and content/parts/text), and that `outputDimensionality` is
// absent for model "embedding-001" but set to 768 on every entry for model
// "gemini-embedding-001".
TEST(EMBED_TEST, gemini_adapter_embedding_request) {
    GeminiAdapter adapter;
    TAIResource config;
    config.model_name = "embedding-001";
    config.api_key = "test_gemini_key";
    config.dimensions = 768;
    adapter.init(config);

    // header test
    MockHttpClient mock_client;
    Status auth_status = adapter.set_authentication(&mock_client);
    ASSERT_TRUE(auth_status.ok()) << auth_status.to_string();
    // Guard each list node before dereferencing it so that a missing header is
    // reported as a test failure instead of crashing the test binary.
    curl_slist* headers = mock_client.get();
    ASSERT_TRUE(headers != nullptr);
    EXPECT_STREQ(headers->data, "x-goog-api-key: test_gemini_key");
    ASSERT_TRUE(headers->next != nullptr);
    EXPECT_STREQ(headers->next->data, "Content-Type: application/json");

    std::vector<std::string> inputs = {"embed with gemini", "embed batch with gemini"};
    std::string request_body;
    Status st = adapter.build_embedding_request(inputs, request_body);
    ASSERT_TRUE(st.ok()) << st.to_string();

    // body test
    rapidjson::Document doc;
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
    ASSERT_TRUE(doc.HasMember("requests")) << "Missing requests field";
    ASSERT_TRUE(doc["requests"].IsArray()) << request_body;
    ASSERT_EQ(doc["requests"].Size(), 2);

    // first batch entry
    const auto& request0 = doc["requests"][0];
    ASSERT_TRUE(request0.HasMember("model")) << request_body;
    ASSERT_STREQ(request0["model"].GetString(), "models/embedding-001");
    ASSERT_TRUE(request0.HasMember("content")) << request_body;
    ASSERT_TRUE(request0["content"].IsObject()) << request_body;
    ASSERT_TRUE(request0["content"].HasMember("parts")) << request_body;
    ASSERT_TRUE(request0["content"]["parts"].IsArray()) << request_body;
    ASSERT_EQ(request0["content"]["parts"].Size(), 1);
    ASSERT_TRUE(request0["content"]["parts"][0].HasMember("text")) << request_body;
    ASSERT_STREQ(request0["content"]["parts"][0]["text"].GetString(), "embed with gemini");
    ASSERT_FALSE(request0.HasMember("outputDimensionality"));

    // second batch entry
    const auto& request1 = doc["requests"][1];
    ASSERT_TRUE(request1.HasMember("model")) << request_body;
    ASSERT_STREQ(request1["model"].GetString(), "models/embedding-001");
    ASSERT_TRUE(request1.HasMember("content")) << request_body;
    ASSERT_TRUE(request1["content"].IsObject()) << request_body;
    ASSERT_TRUE(request1["content"].HasMember("parts")) << request_body;
    ASSERT_TRUE(request1["content"]["parts"].IsArray()) << request_body;
    ASSERT_EQ(request1["content"]["parts"].Size(), 1);
    ASSERT_TRUE(request1["content"]["parts"][0].HasMember("text")) << request_body;
    ASSERT_STREQ(request1["content"]["parts"][0]["text"].GetString(), "embed batch with gemini");

    // Re-initialize with "gemini-embedding-001": every entry must now carry
    // outputDimensionality. request0/request1 are not used past this point
    // because the document is re-parsed.
    config.model_name = "gemini-embedding-001";
    adapter.init(config);
    st = adapter.build_embedding_request(inputs, request_body);
    ASSERT_TRUE(st.ok()) << st.to_string();
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
    ASSERT_TRUE(doc.HasMember("requests")) << request_body;
    ASSERT_TRUE(doc["requests"].IsArray()) << request_body;
    ASSERT_EQ(doc["requests"].Size(), 2);
    ASSERT_TRUE(doc["requests"][0].HasMember("outputDimensionality")) << request_body;
    ASSERT_EQ(doc["requests"][0]["outputDimensionality"].GetInt(), 768) << request_body;
    ASSERT_TRUE(doc["requests"][1].HasMember("outputDimensionality")) << request_body;
    ASSERT_EQ(doc["requests"][1]["outputDimensionality"].GetInt(), 768) << request_body;
}
// Exercises GeminiAdapter::parse_embedding_response with the two response
// layouts used below: a single top-level "embedding" object, then an
// "embeddings" array holding two vectors of different lengths.
TEST(EMBED_TEST, gemini_adapter_parse_embedding_response) {
    GeminiAdapter adapter;
    std::vector<std::vector<float>> vectors;

    // Layout 1: one vector under "embedding".
    std::string single_payload = R"({
"embedding": {
"values":[
0.1,
0.2,
0.3
]
}
})";
    Status parse_status = adapter.parse_embedding_response(single_payload, vectors);
    ASSERT_TRUE(parse_status.ok()) << parse_status.to_string();
    ASSERT_EQ(vectors.size(), 1);
    ASSERT_EQ(vectors[0].size(), 3);
    ASSERT_FLOAT_EQ(vectors[0][0], 0.1F);
    ASSERT_FLOAT_EQ(vectors[0][1], 0.2F);
    ASSERT_FLOAT_EQ(vectors[0][2], 0.3F);

    // Layout 2: several vectors under "embeddings".
    std::string batch_payload = R"({
"embeddings": [
{
"values":[
1.1,
1.2
]
},
{
"values":[
2.1,
2.2,
2.3
]
}
]
})";
    // Start from an empty output vector for the second parse.
    vectors.clear();
    parse_status = adapter.parse_embedding_response(batch_payload, vectors);
    ASSERT_TRUE(parse_status.ok()) << parse_status.to_string();
    ASSERT_EQ(vectors.size(), 2);
    ASSERT_EQ(vectors[0].size(), 2);
    ASSERT_EQ(vectors[1].size(), 3);
    ASSERT_FLOAT_EQ(vectors[0][0], 1.1F);
    ASSERT_FLOAT_EQ(vectors[0][1], 1.2F);
    ASSERT_FLOAT_EQ(vectors[1][0], 2.1F);
    ASSERT_FLOAT_EQ(vectors[1][1], 2.2F);
    ASSERT_FLOAT_EQ(vectors[1][2], 2.3F);
}
// Verifies VoyageAIAdapter request construction: the authentication headers,
// the `model` and `input` fields of the body, and that `output_dimension` is
// absent for model "voyage-multimodal-3" but present for model "voyage-3.5".
TEST(EMBED_TEST, voyageai_adapter_embedding_request) {
    VoyageAIAdapter adapter;
    TAIResource config;
    config.model_name = "voyage-multimodal-3";
    config.api_key = "test_voyage_key";
    config.dimensions = 1024;
    adapter.init(config);

    // header test
    MockHttpClient mock_client;
    Status auth_status = adapter.set_authentication(&mock_client);
    ASSERT_TRUE(auth_status.ok()) << auth_status.to_string();
    // Guard each list node before dereferencing it so that a missing header is
    // reported as a test failure instead of crashing the test binary.
    curl_slist* headers = mock_client.get();
    ASSERT_TRUE(headers != nullptr);
    EXPECT_STREQ(headers->data, "Authorization: Bearer test_voyage_key");
    ASSERT_TRUE(headers->next != nullptr);
    EXPECT_STREQ(headers->next->data, "Content-Type: application/json");

    std::vector<std::string> inputs = {"embed with voyage"};
    std::string request_body;
    Status st = adapter.build_embedding_request(inputs, request_body);
    ASSERT_TRUE(st.ok()) << st.to_string();

    // body test
    rapidjson::Document doc;
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
    // model name
    ASSERT_TRUE(doc.HasMember("model")) << "Missing model field";
    ASSERT_TRUE(doc["model"].IsString()) << "Model field is not a string";
    ASSERT_STREQ(doc["model"].GetString(), "voyage-multimodal-3");
    // input
    ASSERT_TRUE(doc.HasMember("input")) << "Missing input field";
    ASSERT_TRUE(doc["input"].IsArray()) << "Input field is not an array";
    ASSERT_EQ(doc["input"].Size(), 1);
    ASSERT_STREQ(doc["input"][0].GetString(), "embed with voyage");
    // dimension parameter
    ASSERT_FALSE(doc.HasMember("output_dimension"));

    // Re-initialize with "voyage-3.5": the dimension parameter is expected now.
    config.model_name = "voyage-3.5";
    adapter.init(config);
    st = adapter.build_embedding_request(inputs, request_body);
    ASSERT_TRUE(st.ok()) << st.to_string();
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
    ASSERT_TRUE(doc.HasMember("model"));
    ASSERT_STREQ(doc["model"].GetString(), "voyage-3.5") << request_body;
    ASSERT_TRUE(doc.HasMember("output_dimension")) << request_body;
}
// Exercises VoyageAIAdapter::parse_embedding_response on a well-formed payload
// whose "data" array holds two embeddings of different lengths, and compares
// the parsed vectors element by element against the expected values.
TEST(EMBED_TEST, voyageai_adapter_parse_embedding_response) {
    VoyageAIAdapter adapter;
    std::string payload = R"({
"object": "list",
"data": [
{
"embedding": [0.1, 0.2, 0.3],
"index": 0
},
{
"embedding": [0.4, 0.5],
"index": 1
}
],
"model": "voyage-3.5",
"usage": {
"total_tokens": 10
}
})";
    // Expected output, one row per entry of "data".
    const std::vector<std::vector<float>> expected = {{0.1F, 0.2F, 0.3F}, {0.4F, 0.5F}};

    std::vector<std::vector<float>> parsed;
    Status parse_status = adapter.parse_embedding_response(payload, parsed);
    ASSERT_TRUE(parse_status.ok());

    // Shape first (row count, then each row length), values afterwards.
    ASSERT_EQ(parsed.size(), expected.size());
    for (size_t row = 0; row < expected.size(); ++row) {
        ASSERT_EQ(parsed[row].size(), expected[row].size());
    }
    for (size_t row = 0; row < expected.size(); ++row) {
        for (size_t col = 0; col < expected[row].size(); ++col) {
            ASSERT_FLOAT_EQ(parsed[row][col], expected[row][col]);
        }
    }
}
// Feeds VoyageAIAdapter::parse_embedding_response three malformed payloads and
// checks that each one yields a non-OK status carrying the expected message
// fragment.
TEST(EMBED_TEST, voyageai_adapter_parse_error_test) {
    VoyageAIAdapter adapter;
    // The same output vector is passed to every call, without clearing.
    std::vector<std::vector<float>> parsed;

    // Case 1: doc is not an object (the enclosing braces are missing).
    std::string not_an_object = R"(
"object": "list",
"data": [
{
"embedding": [0.1, 0.2, 0.3],
"index": 0
},
{
"embedding": [0.4, 0.5],
"index": 1
}
],
"model": "voyage-3.5",
"usage": {
"total_tokens": 10
}
)";
    Status status = adapter.parse_embedding_response(not_an_object, parsed);
    ASSERT_FALSE(status.ok());
    const std::string not_an_object_msg = status.to_string();
    ASSERT_THAT(not_an_object_msg.c_str(), ::testing::HasSubstr("Failed to parse response"));

    // Case 2: `data` is not an array.
    std::string data_not_array = R"({
"object": "list",
"data": {
"embedding": [0.1, 0.2, 0.3],
"index": 0
},
"model": "voyage-3.5",
"usage": {
"total_tokens": 10
}
})";
    status = adapter.parse_embedding_response(data_not_array, parsed);
    ASSERT_FALSE(status.ok());
    const std::string data_not_array_msg = status.to_string();
    ASSERT_THAT(data_not_array_msg.c_str(), ::testing::HasSubstr("Invalid response format"));

    // Case 3: member `embedding` is missing (the entry uses "embeddings" instead).
    std::string missing_embedding = R"({
"object": "list",
"data": [
{
"embeddings": [0.1, 0.2, 0.3],
"index": 0
}
],
"model": "voyage-3.5",
"usage": {
"total_tokens": 10
}
})";
    status = adapter.parse_embedding_response(missing_embedding, parsed);
    ASSERT_FALSE(status.ok());
    const std::string missing_embedding_msg = status.to_string();
    ASSERT_THAT(missing_embedding_msg.c_str(), ::testing::HasSubstr("Invalid response format"));
}
// Verifies that DeepSeekAdapter sets the authentication headers but rejects
// building an embedding request, returning a status whose message lists the
// supported providers.
TEST(EMBED_TEST, deepseek_adapter_embedding_request) {
    DeepSeekAdapter adapter;
    TAIResource config;
    config.model_name = "deepseek-embedding";
    config.api_key = "test_deepseek_key";
    adapter.init(config);

    // header test
    MockHttpClient mock_client;
    Status auth_status = adapter.set_authentication(&mock_client);
    ASSERT_TRUE(auth_status.ok()) << auth_status.to_string();
    curl_slist* headers = mock_client.get();
    ASSERT_TRUE(headers != nullptr);
    EXPECT_STREQ(headers->data, "Authorization: Bearer test_deepseek_key");
    // The second node is dereferenced below, so it needs its own guard: without
    // it a single-header list would crash the test binary instead of failing.
    ASSERT_TRUE(headers->next != nullptr);
    EXPECT_STREQ(headers->next->data, "Content-Type: application/json");

    // Building the request is expected to fail for this provider.
    std::vector<std::string> inputs = {"embed with deepseek"};
    std::string request_body;
    Status st = adapter.build_embedding_request(inputs, request_body);
    ASSERT_FALSE(st.ok());
    ASSERT_THAT(st.to_string(),
                ::testing::HasSubstr("Currently supported providers are OpenAI, Gemini, "
                                     "Voyage, Jina, Qwen, and Minimax"));
}
// Checks that DeepSeekAdapter::parse_embedding_response returns a non-OK status
// for the payload below, with a message listing the supported providers.
TEST(EMBED_TEST, deepseek_adapter_parse_embedding_response) {
    DeepSeekAdapter adapter;
    std::vector<std::vector<float>> parsed;
    std::string payload = R"({
"data": [
{
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": 0,
"object": "embedding"
}
],
"model": "deepseek-embedding",
"object": "list",
"usage": {
"prompt_tokens": 4,
"total_tokens": 4
}
})";

    Status parse_status = adapter.parse_embedding_response(payload, parsed);
    ASSERT_FALSE(parse_status.ok());

    // Fragment the returned status message must contain.
    const std::string expected_hint =
            "Currently supported providers are OpenAI, Gemini, "
            "Voyage, Jina, Qwen, and Minimax";
    ASSERT_THAT(parse_status.to_string(), ::testing::HasSubstr(expected_hint));
}
// Verifies that MoonShotAdapter sets the authentication headers but rejects
// building an embedding request, returning a status whose message lists the
// supported providers.
TEST(EMBED_TEST, moonshot_adapter_embedding_request) {
    MoonShotAdapter adapter;
    TAIResource config;
    config.model_name = "moonshot-embedding";
    config.api_key = "test_moonshot_key";
    adapter.init(config);

    // header test
    MockHttpClient mock_client;
    Status auth_status = adapter.set_authentication(&mock_client);
    ASSERT_TRUE(auth_status.ok()) << auth_status.to_string();
    curl_slist* headers = mock_client.get();
    ASSERT_TRUE(headers != nullptr);
    EXPECT_STREQ(headers->data, "Authorization: Bearer test_moonshot_key");
    // The second node is dereferenced below, so it needs its own guard: without
    // it a single-header list would crash the test binary instead of failing.
    ASSERT_TRUE(headers->next != nullptr);
    EXPECT_STREQ(headers->next->data, "Content-Type: application/json");

    // Building the request is expected to fail for this provider.
    std::vector<std::string> inputs = {"embed with moonshot"};
    std::string request_body;
    Status st = adapter.build_embedding_request(inputs, request_body);
    ASSERT_FALSE(st.ok());
    ASSERT_THAT(st.to_string(),
                ::testing::HasSubstr("Currently supported providers are OpenAI, Gemini, "
                                     "Voyage, Jina, Qwen, and Minimax"));
}
// Checks that MoonShotAdapter::parse_embedding_response returns a non-OK status
// for the payload below, with a message listing the supported providers.
TEST(EMBED_TEST, moonshot_adapter_parse_embedding_response) {
    MoonShotAdapter adapter;
    std::vector<std::vector<float>> parsed;
    std::string payload = R"({
"id": "embedding-123",
"object": "list",
"data": [
{
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": 0,
"object": "embedding"
}
],
"model": "moonshot-embedding",
"usage": {
"prompt_tokens": 3,
"total_tokens": 3
}
})";

    Status parse_status = adapter.parse_embedding_response(payload, parsed);
    ASSERT_FALSE(parse_status.ok());

    // Fragment the returned status message must contain.
    const std::string expected_hint =
            "Currently supported providers are OpenAI, Gemini, "
            "Voyage, Jina, Qwen, and Minimax";
    ASSERT_THAT(parse_status.to_string(), ::testing::HasSubstr(expected_hint));
}
// Verifies MinimaxAdapter request construction: the authentication headers and
// the `model`, `type` ("db") and `texts` fields of the request body.
TEST(EMBED_TEST, minimax_adapter_embedding_request) {
    MinimaxAdapter adapter;
    TAIResource config;
    config.model_name = "minimax-embedding";
    config.api_key = "test_minimax_key";
    adapter.init(config);

    // header test
    MockHttpClient mock_client;
    Status auth_status = adapter.set_authentication(&mock_client);
    ASSERT_TRUE(auth_status.ok()) << auth_status.to_string();
    curl_slist* headers = mock_client.get();
    ASSERT_TRUE(headers != nullptr);
    EXPECT_STREQ(headers->data, "Authorization: Bearer test_minimax_key");
    // The second node is dereferenced below, so it needs its own guard: without
    // it a single-header list would crash the test binary instead of failing.
    ASSERT_TRUE(headers->next != nullptr);
    EXPECT_STREQ(headers->next->data, "Content-Type: application/json");

    std::vector<std::string> inputs = {"embed with minimax"};
    std::string request_body;
    Status st = adapter.build_embedding_request(inputs, request_body);
    ASSERT_TRUE(st.ok()) << st.to_string();

    // body test
    rapidjson::Document doc;
    doc.Parse(request_body.c_str());
    ASSERT_FALSE(doc.HasParseError()) << "JSON parse error";
    ASSERT_TRUE(doc.IsObject()) << "JSON is not an object";
    // model name
    ASSERT_TRUE(doc.HasMember("model")) << "Missing model field";
    ASSERT_TRUE(doc["model"].IsString()) << "Model field is not a string";
    ASSERT_STREQ(doc["model"].GetString(), "minimax-embedding");
    // type
    ASSERT_TRUE(doc.HasMember("type")) << "Missing type field";
    ASSERT_TRUE(doc["type"].IsString());
    ASSERT_STREQ(doc["type"].GetString(), "db");
    // input
    ASSERT_TRUE(doc.HasMember("texts")) << "Missing texts field";
    ASSERT_TRUE(doc["texts"].IsArray()) << "Texts field is not an array";
    ASSERT_EQ(doc["texts"].Size(), 1);
    ASSERT_STREQ(doc["texts"][0].GetString(), "embed with minimax");
}
} // namespace doris