blob: 2367e4b94595404576d8fe3986d20f96d125bfba [file]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <glog/logging.h>
#include <rapidjson/document.h>
#include <string_view>
#include "core/data_type/data_type_nullable.h"
#include "core/data_type/primitive_type.h"
#include "exprs/function/ai/ai_functions.h"
#include "util/jsonb_utils.h"
#include "util/s3_uri.h"
#include "util/s3_util.h"
namespace doris {
class FunctionEmbed : public AIFunction<FunctionEmbed> {
public:
static constexpr auto name = "embed";
static constexpr size_t number_of_arguments = 2;
static constexpr auto system_prompt = "";
DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
return std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat32>()));
}
Status execute_with_adapter(FunctionContext* context, Block& block,
const ColumnNumbers& arguments, uint32_t result,
size_t input_rows_count, const TAIResource& config,
std::shared_ptr<AIAdapter>& adapter) const {
if (arguments.size() != 2) {
return Status::InvalidArgument("Function EMBED expects 2 arguments, but got {}",
arguments.size());
}
PrimitiveType input_type =
remove_nullable(block.get_by_position(arguments[1]).type)->get_primitive_type();
if (input_type == PrimitiveType::TYPE_JSONB) {
return _execute_multimodal_embed(context, block, arguments, result, input_rows_count,
config, adapter);
}
if (input_type == PrimitiveType::TYPE_STRING || input_type == PrimitiveType::TYPE_VARCHAR ||
input_type == PrimitiveType::TYPE_CHAR) {
return _execute_text_embed(context, block, arguments, result, input_rows_count, config,
adapter);
}
return Status::InvalidArgument(
"Function EMBED expects the second argument to be STRING or JSON, but got type {}",
block.get_by_position(arguments[1]).type->get_name());
}
static FunctionPtr create() { return std::make_shared<FunctionEmbed>(); }
private:
static int32_t _get_embed_max_batch_size(FunctionContext* context) {
QueryContext* query_ctx = context->state()->get_query_ctx();
DORIS_CHECK(query_ctx != nullptr);
return query_ctx->query_options().embed_max_batch_size;
}
Status _execute_text_embed(FunctionContext* context, Block& block,
const ColumnNumbers& arguments, uint32_t result,
size_t input_rows_count, const TAIResource& config,
std::shared_ptr<AIAdapter>& adapter) const {
auto col_result = ColumnArray::create(
ColumnNullable::create(ColumnFloat32::create(), ColumnUInt8::create()));
std::vector<std::string> batch_prompts;
size_t current_batch_size = 0;
const int32_t max_batch_size = _get_embed_max_batch_size(context);
const size_t max_context_window_size =
static_cast<size_t>(get_ai_context_window_size(context));
for (size_t i = 0; i < input_rows_count; ++i) {
std::string prompt;
RETURN_IF_ERROR(build_prompt(block, arguments, i, prompt));
const size_t prompt_size = prompt.size();
if (prompt_size > max_context_window_size) {
// flush history batch
RETURN_IF_ERROR(_flush_text_embedding_batch(batch_prompts, *col_result, config,
adapter, context));
current_batch_size = 0;
batch_prompts.emplace_back(std::move(prompt));
RETURN_IF_ERROR(_flush_text_embedding_batch(batch_prompts, *col_result, config,
adapter, context));
continue;
}
if (!batch_prompts.empty() &&
(current_batch_size + prompt_size > max_context_window_size ||
batch_prompts.size() >= static_cast<size_t>(max_batch_size))) {
RETURN_IF_ERROR(_flush_text_embedding_batch(batch_prompts, *col_result, config,
adapter, context));
current_batch_size = 0;
}
batch_prompts.emplace_back(std::move(prompt));
current_batch_size += prompt_size;
}
RETURN_IF_ERROR(
_flush_text_embedding_batch(batch_prompts, *col_result, config, adapter, context));
block.replace_by_position(result, std::move(col_result));
return Status::OK();
}
Status _execute_multimodal_embed(FunctionContext* context, Block& block,
const ColumnNumbers& arguments, uint32_t result,
size_t input_rows_count, const TAIResource& config,
std::shared_ptr<AIAdapter>& adapter) const {
auto col_result = ColumnArray::create(
ColumnNullable::create(ColumnFloat32::create(), ColumnUInt8::create()));
std::vector<MultimodalType> batch_media_types;
std::vector<std::string> batch_media_content_types;
std::vector<std::string> batch_media_urls;
int64_t ttl_seconds = 3600;
QueryContext* query_ctx = context->state()->get_query_ctx();
if (query_ctx && query_ctx->query_options().__isset.file_presigned_url_ttl_seconds) {
ttl_seconds = query_ctx->query_options().file_presigned_url_ttl_seconds;
if (ttl_seconds <= 0) {
ttl_seconds = 3600;
}
}
const int32_t max_batch_size = _get_embed_max_batch_size(context);
const ColumnWithTypeAndName& file_column = block.get_by_position(arguments[1]);
for (size_t i = 0; i < input_rows_count; ++i) {
rapidjson::Document file_input;
RETURN_IF_ERROR(_parse_file_input(file_column, i, file_input));
std::string content_type;
MultimodalType media_type;
RETURN_IF_ERROR(_infer_media_type(file_input, content_type, media_type));
std::string media_url;
RETURN_IF_ERROR(_resolve_media_url(file_input, ttl_seconds, media_url));
if (!batch_media_urls.empty() &&
batch_media_urls.size() >= static_cast<size_t>(max_batch_size)) {
RETURN_IF_ERROR(_flush_multimodal_embedding_batch(
batch_media_types, batch_media_content_types, batch_media_urls, *col_result,
config, adapter, context));
}
batch_media_types.emplace_back(media_type);
batch_media_content_types.emplace_back(std::move(content_type));
batch_media_urls.emplace_back(std::move(media_url));
}
RETURN_IF_ERROR(_flush_multimodal_embedding_batch(
batch_media_types, batch_media_content_types, batch_media_urls, *col_result, config,
adapter, context));
block.replace_by_position(result, std::move(col_result));
return Status::OK();
}
// EMBED-private helper.
// Sends one embedding request with a prebuilt request body and validates returned row count.
Status _execute_prebuilt_embedding_request(const std::string& request_body,
std::vector<std::vector<float>>& results,
size_t expected_size, const TAIResource& config,
std::shared_ptr<AIAdapter>& adapter,
FunctionContext* context) const {
std::string response;
#ifdef BE_TEST
if (config.provider_type == "MOCK") {
results.clear();
results.reserve(expected_size);
for (size_t i = 0; i < expected_size; ++i) {
results.emplace_back(std::initializer_list<float> {0, 1, 2, 3, 4});
}
return Status::OK();
}
#endif
RETURN_IF_ERROR(
this->send_request_to_llm(request_body, response, config, adapter, context));
RETURN_IF_ERROR(adapter->parse_embedding_response(response, results));
if (results.empty()) {
return Status::InternalError("AI returned empty result");
}
if (results.size() != expected_size) [[unlikely]] {
return Status::InternalError(
"AI embedding returned {} results, but {} inputs were sent", results.size(),
expected_size);
}
return Status::OK();
}
// EMBED-private helper.
// Flushes one accumulated text embedding batch into the output array column.
Status _flush_text_embedding_batch(std::vector<std::string>& batch_prompts,
ColumnArray& col_result, const TAIResource& config,
std::shared_ptr<AIAdapter>& adapter,
FunctionContext* context) const {
if (batch_prompts.empty()) {
return Status::OK();
}
std::string request_body;
RETURN_IF_ERROR(adapter->build_embedding_request(batch_prompts, request_body));
std::vector<std::vector<float>> batch_results;
RETURN_IF_ERROR(_execute_prebuilt_embedding_request(
request_body, batch_results, batch_prompts.size(), config, adapter, context));
for (const auto& batch_result : batch_results) {
_insert_embedding_result(col_result, batch_result);
}
batch_prompts.clear();
return Status::OK();
}
// EMBED-private helper.
// Flushes one accumulated multimodal embedding batch into the output array column.
Status _flush_multimodal_embedding_batch(std::vector<MultimodalType>& batch_media_types,
std::vector<std::string>& batch_media_content_types,
std::vector<std::string>& batch_media_urls,
ColumnArray& col_result, const TAIResource& config,
std::shared_ptr<AIAdapter>& adapter,
FunctionContext* context) const {
if (batch_media_urls.empty()) {
return Status::OK();
}
std::string request_body;
RETURN_IF_ERROR(adapter->build_multimodal_embedding_request(
batch_media_types, batch_media_urls, batch_media_content_types, request_body));
std::vector<std::vector<float>> batch_results;
RETURN_IF_ERROR(_execute_prebuilt_embedding_request(
request_body, batch_results, batch_media_urls.size(), config, adapter, context));
for (const auto& batch_result : batch_results) {
_insert_embedding_result(col_result, batch_result);
}
batch_media_types.clear();
batch_media_content_types.clear();
batch_media_urls.clear();
return Status::OK();
}
static void _insert_embedding_result(ColumnArray& col_array,
const std::vector<float>& float_result) {
auto& offsets = col_array.get_offsets();
auto& nested_nullable_col = assert_cast<ColumnNullable&>(col_array.get_data());
auto& nested_col =
assert_cast<ColumnFloat32&>(*(nested_nullable_col.get_nested_column_ptr()));
nested_col.reserve(nested_col.size() + float_result.size());
size_t current_offset = nested_col.size();
nested_col.insert_many_raw_data(reinterpret_cast<const char*>(float_result.data()),
float_result.size());
offsets.push_back(current_offset + float_result.size());
auto& null_map = nested_nullable_col.get_null_map_column();
null_map.insert_many_vals(0, float_result.size());
}
static bool _starts_with_ignore_case(std::string_view s, std::string_view prefix) {
if (s.size() < prefix.size()) {
return false;
}
return std::equal(prefix.begin(), prefix.end(), s.begin(), [](char a, char b) {
return std::tolower(static_cast<unsigned char>(a)) ==
std::tolower(static_cast<unsigned char>(b));
});
}
static Status _infer_media_type(const rapidjson::Value& file_input, std::string& content_type,
MultimodalType& media_type) {
RETURN_IF_ERROR(_get_required_string_field(file_input, "content_type", content_type));
if (_starts_with_ignore_case(content_type, "image/")) {
media_type = MultimodalType::IMAGE;
return Status::OK();
} else if (_starts_with_ignore_case(content_type, "video/")) {
media_type = MultimodalType::VIDEO;
return Status::OK();
} else if (_starts_with_ignore_case(content_type, "audio/")) {
media_type = MultimodalType::AUDIO;
return Status::OK();
}
return Status::InvalidArgument("Unsupported content_type for EMBED: {}", content_type);
}
// Parse the FILE-like JSONB argument into a JSON object for downstream field reads.
static Status _parse_file_input(const ColumnWithTypeAndName& file_column, size_t row_num,
rapidjson::Document& file_input) {
std::string file_json =
JsonbToJson::jsonb_to_json_string(file_column.column->get_data_at(row_num).data,
file_column.column->get_data_at(row_num).size);
file_input.Parse(file_json.c_str());
DORIS_CHECK(!file_input.HasParseError() && file_input.IsObject());
return Status::OK();
}
// TODO(lzq): After support FILE type, We should use the interface provided by FILE to get the fields
// replacing this function
static Status _get_required_string_field(const rapidjson::Value& obj, const char* field_name,
std::string& value) {
auto iter = obj.FindMember(field_name);
if (iter == obj.MemberEnd() || !iter->value.IsString()) {
return Status::InvalidArgument(
"EMBED file json field '{}' is required and must be a string", field_name);
}
value = iter->value.GetString();
if (value.empty()) {
return Status::InvalidArgument("EMBED file json field '{}' can not be empty",
field_name);
}
return Status::OK();
}
static Status init_s3_client_conf_from_json(const rapidjson::Value& file_input,
S3ClientConf& s3_client_conf) {
std::string endpoint;
RETURN_IF_ERROR(_get_required_string_field(file_input, "endpoint", endpoint));
std::string region;
RETURN_IF_ERROR(_get_required_string_field(file_input, "region", region));
auto get_optional_string_field = [&](const char* field_name, std::string& value) {
auto iter = file_input.FindMember(field_name);
if (iter == file_input.MemberEnd() || iter->value.IsNull()) {
return;
}
DORIS_CHECK(iter->value.IsString());
value = iter->value.GetString();
};
get_optional_string_field("ak", s3_client_conf.ak);
get_optional_string_field("sk", s3_client_conf.sk);
get_optional_string_field("role_arn", s3_client_conf.role_arn);
get_optional_string_field("external_id", s3_client_conf.external_id);
s3_client_conf.endpoint = endpoint;
s3_client_conf.region = region;
return Status::OK();
}
Status _resolve_media_url(const rapidjson::Value& file_input, int64_t ttl_seconds,
std::string& media_url) const {
std::string uri;
RETURN_IF_ERROR(_get_required_string_field(file_input, "uri", uri));
// If it's a direct http/https URL, use it as-is
if (_starts_with_ignore_case(uri, "http://") || _starts_with_ignore_case(uri, "https://")) {
media_url = uri;
return Status::OK();
}
S3ClientConf s3_client_conf;
RETURN_IF_ERROR(init_s3_client_conf_from_json(file_input, s3_client_conf));
auto s3_client = S3ClientFactory::instance().create(s3_client_conf);
if (s3_client == nullptr) {
return Status::InternalError("Failed to create S3 client for EMBED file input");
}
S3URI s3_uri(uri);
RETURN_IF_ERROR(s3_uri.parse());
std::string bucket = s3_uri.get_bucket();
std::string key = s3_uri.get_key();
DORIS_CHECK(!bucket.empty() && !key.empty());
media_url = s3_client->generate_presigned_url({.bucket = bucket, .key = key}, ttl_seconds,
s3_client_conf);
return Status::OK();
}
};
}; // namespace doris