blob: b83aa26c51a85700cc755e7da8c4661b3e40ef33 [file]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <gen_cpp/PaloInternalService_types.h>
#include <rapidjson/rapidjson.h>

#include <algorithm>
#include <cctype>
#include <initializer_list>
#include <iterator>
#include <memory>
#include <string>
#include <string_view>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "common/status.h"
#include "core/string_buffer.hpp"
#include "rapidjson/document.h"
#include "rapidjson/stringbuffer.h"
#include "rapidjson/writer.h"
#include "service/http/http_client.h"
#include "service/http/http_headers.h"
#include "util/security.h"
namespace doris {
// Serializable mirror of TAIResource. serialize() and deserialize() must visit the fields
// in exactly the same order.
struct AIResource {
    AIResource() = default;
    AIResource(const TAIResource& tai)
            : endpoint(tai.endpoint),
              provider_type(tai.provider_type),
              model_name(tai.model_name),
              api_key(tai.api_key),
              temperature(tai.temperature),
              max_tokens(tai.max_tokens),
              max_retries(tai.max_retries),
              retry_delay_second(tai.retry_delay_second),
              anthropic_version(tai.anthropic_version),
              dimensions(tai.dimensions) {}

    std::string endpoint;
    std::string provider_type;
    std::string model_name;
    std::string api_key;
    // Numeric fields are initialized so that a default-constructed AIResource never holds
    // indeterminate values. -1 is the "not set" sentinel the adapters test for temperature,
    // max_tokens and dimensions.
    double temperature = -1;
    int64_t max_tokens = -1;
    int32_t max_retries = 0;
    int32_t retry_delay_second = 0;
    std::string anthropic_version;
    int32_t dimensions = -1;

    // Writes every field in declaration order.
    void serialize(BufferWritable& buf) const {
        buf.write_binary(endpoint);
        buf.write_binary(provider_type);
        buf.write_binary(model_name);
        buf.write_binary(api_key);
        buf.write_binary(temperature);
        buf.write_binary(max_tokens);
        buf.write_binary(max_retries);
        buf.write_binary(retry_delay_second);
        buf.write_binary(anthropic_version);
        buf.write_binary(dimensions);
    }

    // Reads every field in the same order serialize() wrote them.
    void deserialize(BufferReadable& buf) {
        buf.read_binary(endpoint);
        buf.read_binary(provider_type);
        buf.read_binary(model_name);
        buf.read_binary(api_key);
        buf.read_binary(temperature);
        buf.read_binary(max_tokens);
        buf.read_binary(max_retries);
        buf.read_binary(retry_delay_second);
        buf.read_binary(anthropic_version);
        buf.read_binary(dimensions);
    }
};
enum class MultimodalType { IMAGE, VIDEO, AUDIO };

// Returns the lowercase media name used in provider payloads and error messages
// ("image", "video", "audio"); any value outside the enumerators maps to "unknown".
inline const char* multimodal_type_to_string(MultimodalType type) {
    const char* label = "unknown";
    switch (type) {
    case MultimodalType::IMAGE:
        label = "image";
        break;
    case MultimodalType::VIDEO:
        label = "video";
        break;
    case MultimodalType::AUDIO:
        label = "audio";
        break;
    }
    return label;
}
class AIAdapter {
public:
virtual ~AIAdapter() = default;
// Set authentication headers for the HTTP client
virtual Status set_authentication(HttpClient* client) const = 0;
virtual void init(const TAIResource& config) { _config = config; }
virtual void init(const AIResource& config) {
_config.endpoint = config.endpoint;
_config.provider_type = config.provider_type;
_config.model_name = config.model_name;
_config.api_key = config.api_key;
_config.temperature = config.temperature;
_config.max_tokens = config.max_tokens;
_config.max_retries = config.max_retries;
_config.retry_delay_second = config.retry_delay_second;
_config.anthropic_version = config.anthropic_version;
}
// Build request payload based on input text strings
virtual Status build_request_payload(const std::vector<std::string>& inputs,
const char* const system_prompt,
std::string& request_body) const {
return Status::NotSupported("{} don't support text generation", _config.provider_type);
}
// Parse response from AI service and extract generated text results
virtual Status parse_response(const std::string& response_body,
std::vector<std::string>& results) const {
return Status::NotSupported("{} don't support text generation", _config.provider_type);
}
virtual Status build_embedding_request(const std::vector<std::string>& inputs,
std::string& request_body) const {
return embed_not_supported_status();
}
virtual Status build_multimodal_embedding_request(
const std::vector<MultimodalType>& /*media_types*/,
const std::vector<std::string>& /*media_urls*/,
const std::vector<std::string>& /*media_content_types*/,
std::string& /*request_body*/) const {
return Status::NotSupported("{} does not support multimodal Embed feature.",
_config.provider_type);
}
virtual Status parse_embedding_response(const std::string& response_body,
std::vector<std::vector<float>>& results) const {
return embed_not_supported_status();
}
protected:
TAIResource _config;
Status embed_not_supported_status() const {
return Status::NotSupported(
"{} does not support the Embed feature. Currently supported providers are "
"OpenAI, Gemini, Voyage, Jina, Qwen, and Minimax.",
_config.provider_type);
}
// Appends one provider-parsed text result to `results`.
// The adapter has already parsed the provider's outer response envelope before calling here.
// Example:
// provider response -> choices[0].message.content = "[\"1\",\"0\",\"1\"]"
// this helper -> appends "1", "0", "1" into `results`
static Status append_parsed_text_result(std::string_view text,
std::vector<std::string>& results) {
size_t begin = 0;
size_t end = text.size();
while (begin < end && std::isspace(static_cast<unsigned char>(text[begin]))) {
++begin;
}
while (begin < end && std::isspace(static_cast<unsigned char>(text[end - 1]))) {
--end;
}
if (begin < end && text[begin] == '[' && text[end - 1] == ']') {
rapidjson::Document doc;
doc.Parse(text.data() + begin, end - begin);
if (!doc.HasParseError() && doc.IsArray()) {
for (rapidjson::SizeType i = 0; i < doc.Size(); ++i) {
if (!doc[i].IsString()) {
return Status::InternalError(
"Invalid batch result format, array element {} is not a string", i);
}
results.emplace_back(doc[i].GetString(), doc[i].GetStringLength());
}
return Status::OK();
}
}
results.emplace_back(text.data(), text.size());
return Status::OK();
}
// return true if the model support dimension parameter
virtual bool supports_dimension_param(const std::string& model_name) const { return false; }
// Different providers may have different dimension parameter names.
virtual std::string get_dimension_param_name() const { return "dimensions"; }
virtual void add_dimension_params(rapidjson::Value& doc,
rapidjson::Document::AllocatorType& allocator) const {
if (_config.dimensions != -1 && supports_dimension_param(_config.model_name)) {
std::string param_name = get_dimension_param_name();
rapidjson::Value name(param_name.c_str(), allocator);
doc.AddMember(name, _config.dimensions, allocator);
}
}
// Validates common multimodal embedding request invariants shared by providers.
Status validate_multimodal_embedding_inputs(
std::string_view provider_name, const std::vector<MultimodalType>& media_types,
const std::vector<std::string>& media_urls,
std::initializer_list<MultimodalType> supported_types) const {
if (media_urls.empty()) {
return Status::InvalidArgument("{} multimodal embed inputs can not be empty",
provider_name);
}
if (media_types.size() != media_urls.size()) {
return Status::InvalidArgument(
"{} multimodal embed input size mismatch, media_types={}, media_urls={}",
provider_name, media_types.size(), media_urls.size());
}
for (MultimodalType media_type : media_types) {
bool supported = false;
for (MultimodalType supported_type : supported_types) {
if (media_type == supported_type) {
supported = true;
break;
}
}
if (!supported) [[unlikely]] {
return Status::InvalidArgument(
"{} only supports {} multimodal embed, got {}", provider_name,
supported_multimodal_types_to_string(supported_types),
multimodal_type_to_string(media_type));
}
}
return Status::OK();
}
static std::string supported_multimodal_types_to_string(
std::initializer_list<MultimodalType> supported_types) {
std::string result;
for (MultimodalType type : supported_types) {
if (!result.empty()) {
result += "/";
}
result += multimodal_type_to_string(type);
}
return result;
}
};
// Most LLM providers' embedding formats follow the VoyageAI format.
// The adapters below inherit from VoyageAIAdapter (directly or indirectly) to reuse its embedding logic.
class VoyageAIAdapter : public AIAdapter {
public:
    Status set_authentication(HttpClient* client) const override {
        client->set_header(HttpHeaders::AUTHORIZATION, "Bearer " + _config.api_key);
        client->set_content_type("application/json");
        return Status::OK();
    }

    Status build_embedding_request(const std::vector<std::string>& inputs,
                                   std::string& request_body) const override {
        rapidjson::Document doc;
        doc.SetObject();
        auto& allocator = doc.GetAllocator();
        /*{
            "model": "xxx",
            "input": [
              "xxx",
              "xxx",
              ...
            ],
            "output_dimensions": 512
        }*/
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
        add_dimension_params(doc, allocator);
        rapidjson::Value input(rapidjson::kArrayType);
        for (const auto& msg : inputs) {
            input.PushBack(rapidjson::Value(msg.c_str(), allocator), allocator);
        }
        doc.AddMember("input", input, allocator);
        rapidjson::StringBuffer buffer;
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
        doc.Accept(writer);
        request_body = buffer.GetString();
        return Status::OK();
    }

    Status build_multimodal_embedding_request(
            const std::vector<MultimodalType>& media_types,
            const std::vector<std::string>& media_urls,
            const std::vector<std::string>& /*media_content_types*/,
            std::string& request_body) const override {
        RETURN_IF_ERROR(validate_multimodal_embedding_inputs(
                "VoyageAI", media_types, media_urls,
                {MultimodalType::IMAGE, MultimodalType::VIDEO}));
        if (_config.dimensions != -1) {
            LOG(WARNING) << "VoyageAI multimodal embedding currently ignores dimensions parameter, "
                         << "model=" << _config.model_name << ", dimensions=" << _config.dimensions;
        }
        rapidjson::Document doc;
        doc.SetObject();
        auto& allocator = doc.GetAllocator();
        /*{
            "inputs": [
              {
                "content": [
                  {"type": "image_url", "image_url": "<url>"}
                ]
              },
              {
                "content": [
                  {"type": "video_url", "video_url": "<url>"}
                ]
              }
            ],
            "model": "voyage-multimodal-3.5"
        }*/
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
        rapidjson::Value request_inputs(rapidjson::kArrayType);
        for (size_t i = 0; i < media_urls.size(); ++i) {
            rapidjson::Value input(rapidjson::kObjectType);
            rapidjson::Value content(rapidjson::kArrayType);
            rapidjson::Value media_item(rapidjson::kObjectType);
            if (media_types[i] == MultimodalType::IMAGE) {
                media_item.AddMember("type", "image_url", allocator);
                media_item.AddMember("image_url",
                                     rapidjson::Value(media_urls[i].c_str(), allocator), allocator);
            } else {
                media_item.AddMember("type", "video_url", allocator);
                media_item.AddMember("video_url",
                                     rapidjson::Value(media_urls[i].c_str(), allocator), allocator);
            }
            content.PushBack(media_item, allocator);
            input.AddMember("content", content, allocator);
            request_inputs.PushBack(input, allocator);
        }
        doc.AddMember("inputs", request_inputs, allocator);
        rapidjson::StringBuffer buffer;
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
        doc.Accept(writer);
        request_body = buffer.GetString();
        return Status::OK();
    }

    // Parses {"data": [{"embedding": [...]}, ...]} and appends one vector per data item.
    // Every data item must be an object with a numeric "embedding" array; anything else is
    // reported as InternalError (rapidjson's HasMember/operator[]/GetFloat assert on values of
    // the wrong type, so the types are checked before access).
    Status parse_embedding_response(const std::string& response_body,
                                    std::vector<std::vector<float>>& results) const override {
        rapidjson::Document doc;
        doc.Parse(response_body.c_str());
        if (doc.HasParseError() || !doc.IsObject()) {
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
                                         response_body);
        }
        if (!doc.HasMember("data") || !doc["data"].IsArray()) {
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
                                         response_body);
        }
        /*{
            "data":[
              {
                "object": "embedding",
                "embedding": [...], <- only need this
                "index": 0
              },
              {
                "object": "embedding",
                "embedding": [...],
                "index": 1
              }, ...
            ],
            "model"....
        }*/
        const auto& data = doc["data"];
        results.reserve(results.size() + data.Size());
        for (rapidjson::SizeType i = 0; i < data.Size(); i++) {
            const auto& item = data[i];
            if (!item.IsObject() || !item.HasMember("embedding") || !item["embedding"].IsArray()) {
                return Status::InternalError("Invalid {} response format: {}",
                                             _config.provider_type, response_body);
            }
            const auto& values = item["embedding"];
            std::vector<float> embedding;
            embedding.reserve(values.Size());
            for (auto it = values.Begin(); it != values.End(); ++it) {
                if (!it->IsNumber()) {
                    return Status::InternalError("Invalid {} response format: {}",
                                                 _config.provider_type, response_body);
                }
                embedding.push_back(it->GetFloat());
            }
            results.push_back(std::move(embedding));
        }
        return Status::OK();
    }

protected:
    bool supports_dimension_param(const std::string& model_name) const override {
        static const std::unordered_set<std::string> no_dimension_models = {
                "voyage-law-2", "voyage-2", "voyage-code-2", "voyage-finance-2",
                "voyage-multimodal-3"};
        return !no_dimension_models.contains(model_name);
    }

    std::string get_dimension_param_name() const override { return "output_dimension"; }
};
// Local AI adapter for locally hosted models (Ollama, LLaMA, etc.)
class LocalAdapter : public AIAdapter {
public:
    // Local deployments typically don't need authentication
    Status set_authentication(HttpClient* client) const override {
        client->set_content_type("application/json");
        return Status::OK();
    }

    // Endpoints ending in "chat" or "generate" get the Ollama-native schema; every other
    // endpoint gets the OpenAI-compatible chat schema.
    Status build_request_payload(const std::vector<std::string>& inputs,
                                 const char* const system_prompt,
                                 std::string& request_body) const override {
        rapidjson::Document doc;
        doc.SetObject();
        auto& allocator = doc.GetAllocator();

        const std::string& end_point = _config.endpoint;
        if (end_point.ends_with("chat") || end_point.ends_with("generate")) {
            RETURN_IF_ERROR(
                    build_ollama_request(doc, allocator, inputs, system_prompt, request_body));
        } else {
            RETURN_IF_ERROR(
                    build_default_request(doc, allocator, inputs, system_prompt, request_body));
        }

        rapidjson::StringBuffer buffer;
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
        doc.Accept(writer);
        request_body = buffer.GetString();
        return Status::OK();
    }

    Status parse_response(const std::string& response_body,
                          std::vector<std::string>& results) const override {
        rapidjson::Document doc;
        doc.Parse(response_body.c_str());

        if (doc.HasParseError() || !doc.IsObject()) {
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
                                         response_body);
        }

        // Handle various response formats from local LLMs
        // Format 1: OpenAI-compatible format with choices/message/content
        if (doc.HasMember("choices") && doc["choices"].IsArray()) {
            const auto& choices = doc["choices"];
            results.reserve(results.size() + choices.Size());

            for (rapidjson::SizeType i = 0; i < choices.Size(); i++) {
                const auto& choice = choices[i];
                // HasMember()/operator[] are only valid on objects. Non-object choices are
                // skipped, like choices that carry neither known field.
                if (!choice.IsObject()) {
                    continue;
                }
                if (choice.HasMember("message") && choice["message"].IsObject() &&
                    choice["message"].HasMember("content") &&
                    choice["message"]["content"].IsString()) {
                    RETURN_IF_ERROR(append_parsed_text_result(
                            choice["message"]["content"].GetString(), results));
                } else if (choice.HasMember("text") && choice["text"].IsString()) {
                    // Some local LLMs use a simpler format
                    RETURN_IF_ERROR(append_parsed_text_result(choice["text"].GetString(), results));
                }
            }
        } else if (doc.HasMember("text") && doc["text"].IsString()) {
            // Format 2: Simple response with just "text" or "content" field
            RETURN_IF_ERROR(append_parsed_text_result(doc["text"].GetString(), results));
        } else if (doc.HasMember("content") && doc["content"].IsString()) {
            RETURN_IF_ERROR(append_parsed_text_result(doc["content"].GetString(), results));
        } else if (doc.HasMember("response") && doc["response"].IsString()) {
            // Format 3: Response field (Ollama `generate` format)
            RETURN_IF_ERROR(append_parsed_text_result(doc["response"].GetString(), results));
        } else if (doc.HasMember("message") && doc["message"].IsObject() &&
                   doc["message"].HasMember("content") && doc["message"]["content"].IsString()) {
            // Format 4: message/content field (Ollama `chat` format)
            RETURN_IF_ERROR(
                    append_parsed_text_result(doc["message"]["content"].GetString(), results));
        } else {
            return Status::NotSupported("Unsupported response format from local AI.");
        }

        return Status::OK();
    }

    Status build_embedding_request(const std::vector<std::string>& inputs,
                                   std::string& request_body) const override {
        rapidjson::Document doc;
        doc.SetObject();
        auto& allocator = doc.GetAllocator();

        if (!_config.model_name.empty()) {
            doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator),
                          allocator);
        }

        add_dimension_params(doc, allocator);

        rapidjson::Value input(rapidjson::kArrayType);
        for (const auto& msg : inputs) {
            input.PushBack(rapidjson::Value(msg.c_str(), allocator), allocator);
        }
        doc.AddMember("input", input, allocator);

        rapidjson::StringBuffer buffer;
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
        doc.Accept(writer);
        request_body = buffer.GetString();

        return Status::OK();
    }

    Status build_multimodal_embedding_request(
            const std::vector<MultimodalType>& /*media_types*/,
            const std::vector<std::string>& /*media_urls*/,
            const std::vector<std::string>& /*media_content_types*/,
            std::string& /*request_body*/) const override {
        return Status::NotSupported("{} does not support multimodal Embed feature.",
                                    _config.provider_type);
    }

    // Accepts three response shapes:
    //   {"data": [{"embedding": [...]}, ...]}   (OpenAI-compatible)
    //   {"embeddings": [[...], ...]}            (one vector per input)
    //   {"embedding": [...]}                    (single vector)
    // Each vector must be an array of numbers; otherwise InternalError is returned.
    Status parse_embedding_response(const std::string& response_body,
                                    std::vector<std::vector<float>>& results) const override {
        rapidjson::Document doc;
        doc.Parse(response_body.c_str());

        if (doc.HasParseError() || !doc.IsObject()) {
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
                                         response_body);
        }

        if (doc.HasMember("data") && doc["data"].IsArray()) {
            // "data":["object":"embedding", "embedding":[0.1, 0.2...], "index":0]
            const auto& data = doc["data"];
            results.reserve(results.size() + data.Size());
            for (rapidjson::SizeType i = 0; i < data.Size(); i++) {
                const auto& item = data[i];
                if (!item.IsObject() || !item.HasMember("embedding")) {
                    return Status::InternalError("Invalid {} response format",
                                                 _config.provider_type);
                }
                RETURN_IF_ERROR(append_embedding(item["embedding"], results));
            }
        } else if (doc.HasMember("embeddings") && doc["embeddings"].IsArray()) {
            // "embeddings":[[0.1, 0.2, ...]]
            const auto& embeddings = doc["embeddings"];
            results.reserve(results.size() + embeddings.Size());
            for (rapidjson::SizeType i = 0; i < embeddings.Size(); i++) {
                RETURN_IF_ERROR(append_embedding(embeddings[i], results));
            }
        } else if (doc.HasMember("embedding") && doc["embedding"].IsArray()) {
            // "embedding":[0.1, 0.2, ...]
            RETURN_IF_ERROR(append_embedding(doc["embedding"], results));
        } else {
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
                                         response_body);
        }

        return Status::OK();
    }

private:
    // Converts one JSON array of numbers into a float vector appended to `results`.
    // `results` is left unchanged when `embedding` is not a numeric array.
    Status append_embedding(const rapidjson::Value& embedding,
                            std::vector<std::vector<float>>& results) const {
        if (!embedding.IsArray()) {
            return Status::InternalError("Invalid {} response format", _config.provider_type);
        }
        std::vector<float> values;
        values.reserve(embedding.Size());
        for (auto it = embedding.Begin(); it != embedding.End(); ++it) {
            if (!it->IsNumber()) {
                return Status::InternalError("Invalid {} response format",
                                             _config.provider_type);
            }
            values.push_back(it->GetFloat());
        }
        results.push_back(std::move(values));
        return Status::OK();
    }

    Status build_ollama_request(rapidjson::Document& doc,
                                rapidjson::Document::AllocatorType& allocator,
                                const std::vector<std::string>& inputs,
                                const char* const system_prompt, std::string& request_body) const {
        /*
        for endpoints end_with `/chat` like 'http://localhost:11434/api/chat':
        {
            "model": <model_name>,
            "stream": false,
            "think": false,
            "options": {
                "temperature": <temperature>,
                "num_predict": <max_tokens>
            },
            "messages": [
                {"role": "system", "content": <system_prompt>},
                {"role": "user", "content": <user_prompt>}
            ]
        }

        for endpoints end_with `/generate` like 'http://localhost:11434/api/generate':
        {
            "model": <model_name>,
            "stream": false,
            "think": false,
            "options": {
                "temperature": <temperature>,
                "num_predict": <max_tokens>
            },
            "system": <system_prompt>,
            "prompt": <user_prompt>
        }
        */

        // For Ollama, only the prompt section ("system" + "prompt" or "role" + "content") is
        // affected by the endpoint; the rest remains identical.
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
        doc.AddMember("stream", false, allocator);
        doc.AddMember("think", false, allocator);

        // option section
        rapidjson::Value options(rapidjson::kObjectType);
        if (_config.temperature != -1) {
            options.AddMember("temperature", _config.temperature, allocator);
        }
        if (_config.max_tokens != -1) {
            // Ollama's option for the maximum number of generated tokens is "num_predict";
            // an unknown key such as "max_token" would be ignored by the server.
            options.AddMember("num_predict", _config.max_tokens, allocator);
        }
        doc.AddMember("options", options, allocator);

        // prompt section
        if (_config.endpoint.ends_with("chat")) {
            rapidjson::Value messages(rapidjson::kArrayType);
            if (system_prompt && *system_prompt) {
                rapidjson::Value sys_msg(rapidjson::kObjectType);
                sys_msg.AddMember("role", "system", allocator);
                sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator);
                messages.PushBack(sys_msg, allocator);
            }
            for (const auto& input : inputs) {
                rapidjson::Value message(rapidjson::kObjectType);
                message.AddMember("role", "user", allocator);
                message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator);
                messages.PushBack(message, allocator);
            }
            doc.AddMember("messages", messages, allocator);
        } else {
            // `/generate` takes a single prompt string; only the first input is sent.
            if (inputs.empty()) {
                return Status::InvalidArgument(
                        "{} generate request requires at least one input prompt",
                        _config.provider_type);
            }
            if (system_prompt && *system_prompt) {
                doc.AddMember("system", rapidjson::Value(system_prompt, allocator), allocator);
            }
            doc.AddMember("prompt", rapidjson::Value(inputs[0].c_str(), allocator), allocator);
        }

        return Status::OK();
    }

    Status build_default_request(rapidjson::Document& doc,
                                 rapidjson::Document::AllocatorType& allocator,
                                 const std::vector<std::string>& inputs,
                                 const char* const system_prompt, std::string& request_body) const {
        /*
        Default format(OpenAI-compatible):
        {
            "model": <model_name>,
            "temperature": <temperature>,
            "max_tokens": <max_tokens>,
            "messages": [
                {"role": "system", "content": <system_prompt>},
                {"role": "user", "content": <user_prompt>}
            ]
        }
        */

        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);

        // If 'temperature' and 'max_tokens' are set, add them to the request body.
        if (_config.temperature != -1) {
            doc.AddMember("temperature", _config.temperature, allocator);
        }
        if (_config.max_tokens != -1) {
            doc.AddMember("max_tokens", _config.max_tokens, allocator);
        }

        rapidjson::Value messages(rapidjson::kArrayType);
        if (system_prompt && *system_prompt) {
            rapidjson::Value sys_msg(rapidjson::kObjectType);
            sys_msg.AddMember("role", "system", allocator);
            sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator);
            messages.PushBack(sys_msg, allocator);
        }
        for (const auto& input : inputs) {
            rapidjson::Value message(rapidjson::kObjectType);
            message.AddMember("role", "user", allocator);
            message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator);
            messages.PushBack(message, allocator);
        }
        doc.AddMember("messages", messages, allocator);
        return Status::OK();
    }
};
// The OpenAI API format can be reused with some compatible AIs.
class OpenAIAdapter : public VoyageAIAdapter {
public:
    Status set_authentication(HttpClient* client) const override {
        client->set_header(HttpHeaders::AUTHORIZATION, "Bearer " + _config.api_key);
        client->set_content_type("application/json");
        return Status::OK();
    }

    // Builds a text-generation request. Endpoints ending in "responses" use the Responses API
    // shape; every other endpoint uses the Chat Completions shape. The two differ only in the
    // token-limit key and the name of the message array:
    //
    // responses:
    //   {"model": "gpt-4.1-mini", "temperature": 0.7, "max_output_tokens": 150,
    //    "input": [{"role": "system", "content": "..."}, {"role": "user", "content": "xxx"}]}
    // chat completions:
    //   {"model": "gpt-4", "temperature": x, "max_tokens": x,
    //    "messages": [{"role": "system", "content": "..."}, {"role": "user", "content": "xxx"}]}
    Status build_request_payload(const std::vector<std::string>& inputs,
                                 const char* const system_prompt,
                                 std::string& request_body) const override {
        const bool use_responses_api = _config.endpoint.ends_with("responses");

        rapidjson::Document doc;
        doc.SetObject();
        auto& allocator = doc.GetAllocator();

        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);

        // If 'temperature' and 'max_tokens' are set, add them to the request body.
        if (_config.temperature != -1) {
            doc.AddMember("temperature", _config.temperature, allocator);
        }
        if (_config.max_tokens != -1) {
            doc.AddMember(rapidjson::StringRef(use_responses_api ? "max_output_tokens"
                                                                 : "max_tokens"),
                          _config.max_tokens, allocator);
        }

        rapidjson::Value messages(rapidjson::kArrayType);
        if (system_prompt && *system_prompt) {
            rapidjson::Value sys_msg(rapidjson::kObjectType);
            sys_msg.AddMember("role", "system", allocator);
            sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator);
            messages.PushBack(sys_msg, allocator);
        }
        for (const auto& input : inputs) {
            rapidjson::Value message(rapidjson::kObjectType);
            message.AddMember("role", "user", allocator);
            message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator);
            messages.PushBack(message, allocator);
        }
        doc.AddMember(rapidjson::StringRef(use_responses_api ? "input" : "messages"), messages,
                      allocator);

        rapidjson::StringBuffer buffer;
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
        doc.Accept(writer);
        request_body = buffer.GetString();

        return Status::OK();
    }

    Status parse_response(const std::string& response_body,
                          std::vector<std::string>& results) const override {
        rapidjson::Document doc;
        doc.Parse(response_body.c_str());

        if (doc.HasParseError() || !doc.IsObject()) {
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
                                         response_body);
        }

        if (doc.HasMember("output") && doc["output"].IsArray()) {
            /// for responses endpoint
            /*{
              "output": [
                {
                  "id": "msg_123",
                  "type": "message",
                  "role": "assistant",
                  "content": [
                    {
                      "type": "text",
                      "text": "result text here"   <- result
                    }
                  ]
                }
              ]
            }*/
            const auto& output = doc["output"];
            results.reserve(results.size() + output.Size());
            size_t message_count = 0;

            for (rapidjson::SizeType i = 0; i < output.Size(); i++) {
                const auto& item = output[i];
                if (!item.IsObject()) {
                    return Status::InternalError("Invalid output format in {} response: {}",
                                                 _config.provider_type, response_body);
                }
                // Output items whose "type" is not "message" (e.g. "reasoning" items emitted by
                // reasoning models) carry no answer text and are skipped.
                if (item.HasMember("type") && item["type"].IsString() &&
                    item["type"] != "message") {
                    continue;
                }
                if (!item.HasMember("content") || !item["content"].IsArray() ||
                    item["content"].Empty() || !item["content"][0].IsObject() ||
                    !item["content"][0].HasMember("text") ||
                    !item["content"][0]["text"].IsString()) {
                    return Status::InternalError("Invalid output format in {} response: {}",
                                                 _config.provider_type, response_body);
                }

                RETURN_IF_ERROR(append_parsed_text_result(item["content"][0]["text"].GetString(),
                                                          results));
                ++message_count;
            }
            if (message_count == 0) {
                return Status::InternalError("Invalid output format in {} response: {}",
                                             _config.provider_type, response_body);
            }
        } else if (doc.HasMember("choices") && doc["choices"].IsArray()) {
            /// for completions endpoint
            /*{
              "object": "chat.completion",
              "model": "gpt-4",
              "choices": [
                {
                  ...
                  "message": {
                    "role": "assistant",
                    "content": "xxx"      <- result
                  },
                  ...
                }
              ],
              ...
            }*/
            const auto& choices = doc["choices"];
            results.reserve(results.size() + choices.Size());

            for (rapidjson::SizeType i = 0; i < choices.Size(); i++) {
                const auto& choice = choices[i];
                if (!choice.IsObject() || !choice.HasMember("message") ||
                    !choice["message"].IsObject() || !choice["message"].HasMember("content") ||
                    !choice["message"]["content"].IsString()) {
                    return Status::InternalError("Invalid choice format in {} response: {}",
                                                 _config.provider_type, response_body);
                }

                RETURN_IF_ERROR(append_parsed_text_result(
                        choice["message"]["content"].GetString(), results));
            }
        } else {
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
                                         response_body);
        }

        return Status::OK();
    }

    Status build_multimodal_embedding_request(
            const std::vector<MultimodalType>& /*media_types*/,
            const std::vector<std::string>& /*media_urls*/,
            const std::vector<std::string>& /*media_content_types*/,
            std::string& /*request_body*/) const override {
        return Status::NotSupported("{} does not support multimodal Embed feature.",
                                    _config.provider_type);
    }

protected:
    bool supports_dimension_param(const std::string& model_name) const override {
        return !(model_name == "text-embedding-ada-002");
    }

    std::string get_dimension_param_name() const override { return "dimensions"; }
};
// DeepSeek reuses the OpenAI text-generation protocol unchanged. This adapter does not
// implement embeddings: both embedding hooks report the shared "not supported" status.
class DeepSeekAdapter : public OpenAIAdapter {
public:
    Status build_embedding_request(const std::vector<std::string>& /*inputs*/,
                                   std::string& /*request_body*/) const override {
        return embed_not_supported_status();
    }

    Status parse_embedding_response(const std::string& /*response_body*/,
                                    std::vector<std::vector<float>>& /*results*/) const override {
        return embed_not_supported_status();
    }
};
// MoonShot reuses the OpenAI text-generation protocol unchanged. This adapter does not
// implement embeddings: both embedding hooks report the shared "not supported" status.
class MoonShotAdapter : public OpenAIAdapter {
public:
    Status build_embedding_request(const std::vector<std::string>& /*inputs*/,
                                   std::string& /*request_body*/) const override {
        return embed_not_supported_status();
    }

    Status parse_embedding_response(const std::string& /*response_body*/,
                                    std::vector<std::vector<float>>& /*results*/) const override {
        return embed_not_supported_status();
    }
};
class MinimaxAdapter : public OpenAIAdapter {
public:
    // Builds a MiniMax embedding request:
    // {
    //   "model": "embo-01",
    //   "texts": ["xxx", "xxx", ...],
    //   "type": "db"
    // }
    Status build_embedding_request(const std::vector<std::string>& inputs,
                                   std::string& request_body) const override {
        rapidjson::Document doc;
        doc.SetObject();
        auto& allocator = doc.GetAllocator();

        rapidjson::Value texts(rapidjson::kArrayType);
        for (const auto& input : inputs) {
            texts.PushBack(rapidjson::Value(input.c_str(), allocator), allocator);
        }
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
        doc.AddMember("texts", texts, allocator);
        doc.AddMember("type", rapidjson::Value("db", allocator), allocator);

        rapidjson::StringBuffer buffer;
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
        doc.Accept(writer);
        request_body = buffer.GetString();
        return Status::OK();
    }

    // MiniMax's native embedding response lists vectors as {"vectors": [[...], [...]], ...}.
    // A response without a "vectors" array (e.g. an OpenAI-style {"data": [...]} body, or an
    // error body) is delegated to the inherited parser, which reports it with the full body.
    Status parse_embedding_response(const std::string& response_body,
                                    std::vector<std::vector<float>>& results) const override {
        rapidjson::Document doc;
        doc.Parse(response_body.c_str());
        if (doc.HasParseError() || !doc.IsObject()) {
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
                                         response_body);
        }
        if (!doc.HasMember("vectors") || !doc["vectors"].IsArray()) {
            return OpenAIAdapter::parse_embedding_response(response_body, results);
        }

        const auto& vectors = doc["vectors"];
        results.reserve(results.size() + vectors.Size());
        for (rapidjson::SizeType i = 0; i < vectors.Size(); i++) {
            const auto& values = vectors[i];
            if (!values.IsArray()) {
                return Status::InternalError("Invalid {} response format: {}",
                                             _config.provider_type, response_body);
            }
            std::vector<float> embedding;
            embedding.reserve(values.Size());
            for (auto it = values.Begin(); it != values.End(); ++it) {
                if (!it->IsNumber()) {
                    return Status::InternalError("Invalid {} response format: {}",
                                                 _config.provider_type, response_body);
                }
                embedding.push_back(it->GetFloat());
            }
            results.push_back(std::move(embedding));
        }
        return Status::OK();
    }
};
// Zhipu reuses every OpenAIAdapter behaviour; only the rule for which models accept a
// dimension parameter differs.
class ZhipuAdapter : public OpenAIAdapter {
protected:
    // Every model except "embedding-2" accepts the dimension parameter.
    bool supports_dimension_param(const std::string& model_name) const override {
        return model_name != "embedding-2";
    }
};
class QwenAdapter : public OpenAIAdapter {
public:
Status build_multimodal_embedding_request(
const std::vector<MultimodalType>& media_types,
const std::vector<std::string>& media_urls,
const std::vector<std::string>& /*media_content_types*/,
std::string& request_body) const override {
RETURN_IF_ERROR(validate_multimodal_embedding_inputs(
"QWEN", media_types, media_urls, {MultimodalType::IMAGE, MultimodalType::VIDEO}));
rapidjson::Document doc;
doc.SetObject();
auto& allocator = doc.GetAllocator();
/*{
"model": "tongyi-embedding-vision-plus",
"input": {
"contents": [
{"image": "<url>"},
{"video": "<url>"}
]
}
"parameters": {
"dimension": 512
}
}*/
doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
rapidjson::Value input(rapidjson::kObjectType);
rapidjson::Value contents(rapidjson::kArrayType);
for (size_t i = 0; i < media_urls.size(); ++i) {
rapidjson::Value media_item(rapidjson::kObjectType);
if (media_types[i] == MultimodalType::IMAGE) {
media_item.AddMember("image", rapidjson::Value(media_urls[i].c_str(), allocator),
allocator);
} else {
media_item.AddMember("video", rapidjson::Value(media_urls[i].c_str(), allocator),
allocator);
}
contents.PushBack(media_item, allocator);
}
input.AddMember("contents", contents, allocator);
doc.AddMember("input", input, allocator);
if (_config.dimensions != -1 && supports_dimension_param(_config.model_name)) {
rapidjson::Value parameters(rapidjson::kObjectType);
std::string param_name = get_dimension_param_name();
rapidjson::Value dimension_name(param_name.c_str(), allocator);
parameters.AddMember(dimension_name, _config.dimensions, allocator);
doc.AddMember("parameters", parameters, allocator);
}
rapidjson::StringBuffer buffer;
rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
doc.Accept(writer);
request_body = buffer.GetString();
return Status::OK();
}
Status parse_embedding_response(const std::string& response_body,
                                std::vector<std::vector<float>>& results) const override {
    rapidjson::Document doc;
    doc.Parse(response_body.c_str());
    if (doc.HasParseError() || !doc.IsObject()) [[unlikely]] {
        return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
                                     response_body);
    }
    // Qwen multimodal embedding usually returns:
    // {
    //   "output": {
    //     "embeddings": [
    //       {"index":0, "embedding":[...], "type":"image|video|text"},
    //       ...
    //     ]
    //   }
    // }
    //
    // In text-only or compatibility endpoints, Qwen may also return OpenAI-style
    // "data":[{"embedding":[...]}]. For compatibility we first parse native
    // output.embeddings and then fall back to the OpenAIAdapter parser.
    //
    // Vectors are appended in array order; the per-item "index" field is not consulted.
    if (doc.HasMember("output") && doc["output"].IsObject() &&
        doc["output"].HasMember("embeddings") && doc["output"]["embeddings"].IsArray()) {
        const auto& embeddings = doc["output"]["embeddings"];
        results.reserve(results.size() + embeddings.Size());
        for (rapidjson::SizeType i = 0; i < embeddings.Size(); i++) {
            const auto& item = embeddings[i];
            // RapidJSON's HasMember() requires an object and GetFloat() requires a number
            // (both assert otherwise), so each level is type-checked before it is read.
            if (!item.IsObject() || !item.HasMember("embedding") ||
                !item["embedding"].IsArray()) {
                return Status::InternalError("Invalid {} response format: {}",
                                             _config.provider_type, response_body);
            }
            const auto& values = item["embedding"];
            // Build the vector locally so a malformed element never leaves a partial entry.
            std::vector<float> embedding;
            embedding.reserve(values.Size());
            for (rapidjson::SizeType j = 0; j < values.Size(); j++) {
                if (!values[j].IsNumber()) {
                    return Status::InternalError("Invalid {} response format: {}",
                                                 _config.provider_type, response_body);
                }
                embedding.push_back(values[j].GetFloat());
            }
            results.push_back(std::move(embedding));
        }
        return Status::OK();
    }
    return OpenAIAdapter::parse_embedding_response(response_body, results);
}
protected:
bool supports_dimension_param(const std::string& model_name) const override {
    // Models listed here never get a dimension parameter in the request; every other
    // model name is allowed to carry one.
    static const std::unordered_set<std::string> kModelsWithoutDimension {
            "text-embedding-v1", "text-embedding-v2", "text2vec", "m3e-base", "m3e-small"};
    const bool listed = kModelsWithoutDimension.find(model_name) != kModelsWithoutDimension.end();
    return !listed;
}
std::string get_dimension_param_name() const override {
    // JSON key used for the requested output dimension in Qwen requests.
    return std::string {"dimension"};
}
};
// Jina adapter: reuses VoyageAIAdapter for everything except multimodal embedding requests,
// which send each media URL as an {"image": url} or {"video": url} object.
class JinaAdapter : public VoyageAIAdapter {
public:
    // Serializes a multimodal embedding request into `request_body`.
    // Only IMAGE and VIDEO media are accepted (checked by validate_multimodal_embedding_inputs);
    // media content types are not sent.
    Status build_multimodal_embedding_request(
            const std::vector<MultimodalType>& media_types,
            const std::vector<std::string>& media_urls,
            const std::vector<std::string>& /*media_content_types*/,
            std::string& request_body) const override {
        RETURN_IF_ERROR(validate_multimodal_embedding_inputs(
                "JINA", media_types, media_urls, {MultimodalType::IMAGE, MultimodalType::VIDEO}));
        /*{
          "model": "jina-embeddings-v4",
          "task": "text-matching",
          "input": [
            {"image": "<url>"},
            {"video": "<url>"}
          ]
        }*/
        rapidjson::Document request;
        request.SetObject();
        auto& alloc = request.GetAllocator();

        // Assemble the media array up front; members are still attached to the document in
        // the order model, task, [dimensions], input.
        rapidjson::Value items(rapidjson::kArrayType);
        for (size_t idx = 0; idx < media_urls.size(); ++idx) {
            const char* key = (media_types[idx] == MultimodalType::IMAGE) ? "image" : "video";
            rapidjson::Value url(media_urls[idx].c_str(), alloc);
            rapidjson::Value item(rapidjson::kObjectType);
            item.AddMember(rapidjson::StringRef(key), url, alloc);
            items.PushBack(item, alloc);
        }

        request.AddMember("model", rapidjson::Value(_config.model_name.c_str(), alloc), alloc);
        request.AddMember("task", "text-matching", alloc);
        const bool send_dimensions =
                _config.dimensions != -1 && supports_dimension_param(_config.model_name);
        if (send_dimensions) {
            request.AddMember("dimensions", _config.dimensions, alloc);
        }
        request.AddMember("input", items, alloc);

        rapidjson::StringBuffer out;
        rapidjson::Writer<rapidjson::StringBuffer> json_writer(out);
        request.Accept(json_writer);
        request_body = out.GetString();
        return Status::OK();
    }
};
// Baichuan reuses OpenAIAdapter's request/response handling unchanged, except that the
// embedding dimension parameter is never sent.
class BaichuanAdapter : public OpenAIAdapter {
protected:
    bool supports_dimension_param(const std::string& /*model_name*/) const override {
        // Independent of the model: no Baichuan model receives a dimension parameter.
        constexpr bool kSupportsDimension = false;
        return kSupportsDimension;
    }
};
// Gemini's embedding format is different from VoyageAI, so it requires a separate adapter
class GeminiAdapter : public AIAdapter {
public:
Status set_authentication(HttpClient* client) const override {
client->set_header("x-goog-api-key", _config.api_key);
client->set_content_type("application/json");
return Status::OK();
}
Status build_request_payload(const std::vector<std::string>& inputs,
const char* const system_prompt,
std::string& request_body) const override {
rapidjson::Document doc;
doc.SetObject();
auto& allocator = doc.GetAllocator();
/*{
"systemInstruction": {
"parts": [
{
"text": "system_prompt here"
}
]
}
],
"contents": [
{
"parts": [
{
"text": "xxx"
}
]
}
],
"generationConfig": {
"temperature": 0.7,
"maxOutputTokens": 1024
}
}*/
if (system_prompt && *system_prompt) {
rapidjson::Value system_instruction(rapidjson::kObjectType);
rapidjson::Value parts(rapidjson::kArrayType);
rapidjson::Value part(rapidjson::kObjectType);
part.AddMember("text", rapidjson::Value(system_prompt, allocator), allocator);
parts.PushBack(part, allocator);
// system_instruction.PushBack(content, allocator);
system_instruction.AddMember("parts", parts, allocator);
doc.AddMember("systemInstruction", system_instruction, allocator);
}
rapidjson::Value contents(rapidjson::kArrayType);
for (const auto& input : inputs) {
rapidjson::Value content(rapidjson::kObjectType);
rapidjson::Value parts(rapidjson::kArrayType);
rapidjson::Value part(rapidjson::kObjectType);
part.AddMember("text", rapidjson::Value(input.c_str(), allocator), allocator);
parts.PushBack(part, allocator);
content.AddMember("parts", parts, allocator);
contents.PushBack(content, allocator);
}
doc.AddMember("contents", contents, allocator);
// If 'temperature' and 'max_tokens' are set, add them to the request body.
rapidjson::Value generationConfig(rapidjson::kObjectType);
if (_config.temperature != -1) {
generationConfig.AddMember("temperature", _config.temperature, allocator);
}
if (_config.max_tokens != -1) {
generationConfig.AddMember("maxOutputTokens", _config.max_tokens, allocator);
}
doc.AddMember("generationConfig", generationConfig, allocator);
rapidjson::StringBuffer buffer;
rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
doc.Accept(writer);
request_body = buffer.GetString();
return Status::OK();
}
Status parse_response(const std::string& response_body,
std::vector<std::string>& results) const override {
rapidjson::Document doc;
doc.Parse(response_body.c_str());
if (doc.HasParseError() || !doc.IsObject()) {
return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
response_body);
}
if (!doc.HasMember("candidates") || !doc["candidates"].IsArray()) {
return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
response_body);
}
/*{
"candidates":[
{
"content": {
"parts": [
{
"text": "xxx"
}
]
}
}
]
}*/
const auto& candidates = doc["candidates"];
results.reserve(candidates.Size());
for (rapidjson::SizeType i = 0; i < candidates.Size(); i++) {
if (!candidates[i].HasMember("content") ||
!candidates[i]["content"].HasMember("parts") ||
!candidates[i]["content"]["parts"].IsArray() ||
candidates[i]["content"]["parts"].Empty() ||
!candidates[i]["content"]["parts"][0].HasMember("text") ||
!candidates[i]["content"]["parts"][0]["text"].IsString()) {
return Status::InternalError("Invalid candidate format in {} response",
_config.provider_type);
}
RETURN_IF_ERROR(append_parsed_text_result(
candidates[i]["content"]["parts"][0]["text"].GetString(), results));
}
return Status::OK();
}
Status build_embedding_request(const std::vector<std::string>& inputs,
std::string& request_body) const override {
rapidjson::Document doc;
doc.SetObject();
auto& allocator = doc.GetAllocator();
/*{
"requests": [
{
"model": "models/gemini-embedding-001",
"content": {
"parts": [
{
"text": "xxx"
}
]
},
"outputDimensionality": 1024
},
{
"model": "models/gemini-embedding-001",
"content": {
"parts": [
{
"text": "yyy"
}
]
},
"outputDimensionality": 1024
}
]
}*/
// gemini requires the model format as `models/{model}`
std::string model_name = _config.model_name;
if (!model_name.starts_with("models/")) {
model_name = "models/" + model_name;
}
rapidjson::Value requests(rapidjson::kArrayType);
for (const auto& input : inputs) {
rapidjson::Value request(rapidjson::kObjectType);
request.AddMember("model", rapidjson::Value(model_name.c_str(), allocator), allocator);
add_dimension_params(request, allocator);
rapidjson::Value content(rapidjson::kObjectType);
rapidjson::Value parts(rapidjson::kArrayType);
rapidjson::Value part(rapidjson::kObjectType);
part.AddMember("text", rapidjson::Value(input.c_str(), allocator), allocator);
parts.PushBack(part, allocator);
content.AddMember("parts", parts, allocator);
request.AddMember("content", content, allocator);
requests.PushBack(request, allocator);
}
doc.AddMember("requests", requests, allocator);
rapidjson::StringBuffer buffer;
rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
doc.Accept(writer);
request_body = buffer.GetString();
return Status::OK();
}
Status build_multimodal_embedding_request(const std::vector<MultimodalType>& media_types,
const std::vector<std::string>& media_urls,
const std::vector<std::string>& media_content_types,
std::string& request_body) const override {
RETURN_IF_ERROR(validate_multimodal_embedding_inputs(
"Gemini", media_types, media_urls,
{MultimodalType::IMAGE, MultimodalType::AUDIO, MultimodalType::VIDEO}));
if (media_content_types.size() != media_urls.size()) {
return Status::InvalidArgument(
"Gemini multimodal embed input size mismatch, media_content_types={}, "
"media_urls={}",
media_content_types.size(), media_urls.size());
}
rapidjson::Document doc;
doc.SetObject();
auto& allocator = doc.GetAllocator();
/*{
"requests": [
{
"model": "models/gemini-embedding-2-preview",
"content": {
"parts": [
{"file_data": {"mime_type": "<original content_type>", "file_uri": "<url>"}}
]
},
"outputDimensionality": 768
},
{
"model": "models/gemini-embedding-2-preview",
"content": {
"parts": [
{"file_data": {"mime_type": "<original content_type>", "file_uri": "<url>"}}
]
},
"outputDimensionality": 768
}
]
}*/
std::string model_name = _config.model_name;
if (!model_name.starts_with("models/")) {
model_name = "models/" + model_name;
}
rapidjson::Value requests(rapidjson::kArrayType);
for (size_t i = 0; i < media_urls.size(); ++i) {
rapidjson::Value request(rapidjson::kObjectType);
request.AddMember("model", rapidjson::Value(model_name.c_str(), allocator), allocator);
add_dimension_params(request, allocator);
rapidjson::Value content(rapidjson::kObjectType);
rapidjson::Value parts(rapidjson::kArrayType);
rapidjson::Value part(rapidjson::kObjectType);
rapidjson::Value file_data(rapidjson::kObjectType);
file_data.AddMember("mime_type",
rapidjson::Value(media_content_types[i].c_str(), allocator),
allocator);
file_data.AddMember("file_uri", rapidjson::Value(media_urls[i].c_str(), allocator),
allocator);
part.AddMember("file_data", file_data, allocator);
parts.PushBack(part, allocator);
content.AddMember("parts", parts, allocator);
request.AddMember("content", content, allocator);
requests.PushBack(request, allocator);
}
doc.AddMember("requests", requests, allocator);
rapidjson::StringBuffer buffer;
rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
doc.Accept(writer);
request_body = buffer.GetString();
return Status::OK();
}
Status parse_embedding_response(const std::string& response_body,
std::vector<std::vector<float>>& results) const override {
rapidjson::Document doc;
doc.Parse(response_body.c_str());
if (doc.HasParseError() || !doc.IsObject()) {
return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
response_body);
}
if (doc.HasMember("embeddings") && doc["embeddings"].IsArray()) {
/*{
"embeddings": [
{"values": [0.1, 0.2, 0.3]},
{"values": [0.4, 0.5, 0.6]}
]
}*/
const auto& embeddings = doc["embeddings"];
results.reserve(embeddings.Size());
for (rapidjson::SizeType i = 0; i < embeddings.Size(); i++) {
if (!embeddings[i].HasMember("values") || !embeddings[i]["values"].IsArray()) {
return Status::InternalError("Invalid {} response format: {}",
_config.provider_type, response_body);
}
std::transform(embeddings[i]["values"].Begin(), embeddings[i]["values"].End(),
std::back_inserter(results.emplace_back()),
[](const auto& val) { return val.GetFloat(); });
}
return Status::OK();
}
if (!doc.HasMember("embedding") || !doc["embedding"].IsObject()) {
return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
response_body);
}
/*{
"embedding":{
"values": [0.1, 0.2, 0.3]
}
}*/
const auto& embedding = doc["embedding"];
if (!embedding.HasMember("values") || !embedding["values"].IsArray()) {
return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
response_body);
}
std::transform(embedding["values"].Begin(), embedding["values"].End(),
std::back_inserter(results.emplace_back()),
[](const auto& val) { return val.GetFloat(); });
return Status::OK();
}
protected:
bool supports_dimension_param(const std::string& model_name) const override {
static const std::unordered_set<std::string> no_dimension_models = {"models/embedding-001",
"embedding-001"};
return !no_dimension_models.contains(model_name);
}
std::string get_dimension_param_name() const override { return "outputDimensionality"; }
};
// Anthropic adapter: uses the `x-api-key` / `anthropic-version` headers and a messages-style
// request; embedding handling is inherited from VoyageAIAdapter.
class AnthropicAdapter : public VoyageAIAdapter {
public:
    // Sets the API key, the configured `anthropic-version` header and the JSON content type.
    Status set_authentication(HttpClient* client) const override {
        client->set_header("x-api-key", _config.api_key);
        client->set_header("anthropic-version", _config.anthropic_version);
        client->set_content_type("application/json");
        return Status::OK();
    }

    // Builds the request body. Each element of `inputs` becomes a separate "user" message.
    // max_tokens falls back to 2048 when not configured (-1).
    Status build_request_payload(const std::vector<std::string>& inputs,
                                 const char* const system_prompt,
                                 std::string& request_body) const override {
        rapidjson::Document doc;
        doc.SetObject();
        auto& allocator = doc.GetAllocator();
        /*{
          "model": "claude-opus-4-1-20250805",
          "max_tokens": 1024,
          "system": "system_prompt here",
          "messages": [
            {"role": "user", "content": "xxx"}
          ],
          "temperature": 0.7
        }*/
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
        // If 'temperature' and 'max_tokens' are set, add them to the request body.
        if (_config.temperature != -1) {
            doc.AddMember("temperature", _config.temperature, allocator);
        }
        if (_config.max_tokens != -1) {
            doc.AddMember("max_tokens", _config.max_tokens, allocator);
        } else {
            // Keep the default value, Anthropic requires this parameter
            doc.AddMember("max_tokens", 2048, allocator);
        }
        if (system_prompt && *system_prompt) {
            doc.AddMember("system", rapidjson::Value(system_prompt, allocator), allocator);
        }
        rapidjson::Value messages(rapidjson::kArrayType);
        for (const auto& input : inputs) {
            rapidjson::Value message(rapidjson::kObjectType);
            message.AddMember("role", "user", allocator);
            message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator);
            messages.PushBack(message, allocator);
        }
        doc.AddMember("messages", messages, allocator);
        rapidjson::StringBuffer buffer;
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
        doc.Accept(writer);
        request_body = buffer.GetString();
        return Status::OK();
    }

    // Joins the "text" of every content block whose "type" is "text" with '\n' and appends
    // the result as a single entry. Blocks of other types, or malformed blocks, are skipped.
    Status parse_response(const std::string& response_body,
                          std::vector<std::string>& results) const override {
        rapidjson::Document doc;
        doc.Parse(response_body.c_str());
        if (doc.HasParseError() || !doc.IsObject()) {
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
                                         response_body);
        }
        if (!doc.HasMember("content") || !doc["content"].IsArray()) {
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
                                         response_body);
        }
        /*{
          "content": [
            {
              "text": "xxx",
              "type": "text"
            }
          ]
        }*/
        const auto& content = doc["content"];
        results.reserve(1);
        std::string result;
        for (rapidjson::SizeType i = 0; i < content.Size(); i++) {
            const auto& block = content[i];
            // IsObject() must be checked first: RapidJSON's HasMember() asserts on non-objects.
            if (!block.IsObject() || !block.HasMember("type") || !block["type"].IsString() ||
                !block.HasMember("text") || !block["text"].IsString()) {
                continue;
            }
            if (std::string(block["type"].GetString()) == "text") {
                if (!result.empty()) {
                    result += "\n";
                }
                result += block["text"].GetString();
            }
        }
        return append_parsed_text_result(result, results);
    }
};
// Mock adapter used only for UT to bypass real HTTP calls and return deterministic data.
class MockAdapter : public AIAdapter {
public:
    // Adds no headers.
    Status set_authentication(HttpClient* /*client*/) const override { return Status::OK(); }

    // Leaves `request_body` untouched.
    Status build_request_payload(const std::vector<std::string>& /*inputs*/,
                                 const char* const /*system_prompt*/,
                                 std::string& /*request_body*/) const override {
        return Status::OK();
    }

    // Forwards the raw response body unchanged to append_parsed_text_result.
    Status parse_response(const std::string& response_body,
                          std::vector<std::string>& results) const override {
        return append_parsed_text_result(response_body, results);
    }

    // Leaves `request_body` untouched.
    Status build_embedding_request(const std::vector<std::string>& /*inputs*/,
                                   std::string& /*request_body*/) const override {
        return Status::OK();
    }

    // Leaves `request_body` untouched.
    Status build_multimodal_embedding_request(
            const std::vector<MultimodalType>& /*media_types*/,
            const std::vector<std::string>& /*media_urls*/,
            const std::vector<std::string>& /*media_content_types*/,
            std::string& /*request_body*/) const override {
        return Status::OK();
    }

    // Expects {"embedding": [<numbers>...]} and appends it as a single vector.
    // A non-numeric element yields InternalError and leaves `results` unchanged.
    Status parse_embedding_response(const std::string& response_body,
                                    std::vector<std::vector<float>>& results) const override {
        rapidjson::Document doc;
        doc.Parse(response_body.c_str());
        if (doc.HasParseError() || !doc.IsObject()) {
            return Status::InternalError("Failed to parse embedding response");
        }
        if (!doc.HasMember("embedding") || !doc["embedding"].IsArray()) {
            return Status::InternalError("Invalid embedding response format");
        }
        const auto& values = doc["embedding"];
        std::vector<float> embedding;
        embedding.reserve(values.Size());
        for (rapidjson::SizeType i = 0; i < values.Size(); ++i) {
            // GetFloat() asserts on non-numbers in RapidJSON, so reject them explicitly.
            if (!values[i].IsNumber()) {
                return Status::InternalError("Invalid embedding response format");
            }
            embedding.push_back(values[i].GetFloat());
        }
        results.push_back(std::move(embedding));
        return Status::OK();
    }
};
class AIAdapterFactory {
public:
static std::shared_ptr<AIAdapter> create_adapter(const std::string& provider_type) {
static const std::unordered_map<std::string, std::function<std::shared_ptr<AIAdapter>()>>
adapters = {{"LOCAL", []() { return std::make_shared<LocalAdapter>(); }},
{"OPENAI", []() { return std::make_shared<OpenAIAdapter>(); }},
{"MOONSHOT", []() { return std::make_shared<MoonShotAdapter>(); }},
{"DEEPSEEK", []() { return std::make_shared<DeepSeekAdapter>(); }},
{"MINIMAX", []() { return std::make_shared<MinimaxAdapter>(); }},
{"ZHIPU", []() { return std::make_shared<ZhipuAdapter>(); }},
{"QWEN", []() { return std::make_shared<QwenAdapter>(); }},
{"JINA", []() { return std::make_shared<JinaAdapter>(); }},
{"BAICHUAN", []() { return std::make_shared<BaichuanAdapter>(); }},
{"ANTHROPIC", []() { return std::make_shared<AnthropicAdapter>(); }},
{"GEMINI", []() { return std::make_shared<GeminiAdapter>(); }},
{"VOYAGEAI", []() { return std::make_shared<VoyageAIAdapter>(); }},
{"MOCK", []() { return std::make_shared<MockAdapter>(); }}};
auto it = adapters.find(provider_type);
return (it != adapters.end()) ? it->second() : nullptr;
}
};
} // namespace doris