blob: e10a6a54d1e5085563f27341cd60978ee881e8ba [file]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <gen_cpp/FrontendService.h>
#include <gen_cpp/PaloInternalService_types.h>
#include <algorithm>
#include <cctype>
#include <cstdlib>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>
#include "common/config.h"
#include "common/status.h"
#include "core/column/column_array.h"
#include "core/column/column_const.h"
#include "core/column/column_nullable.h"
#include "core/cow.h"
#include "core/data_type/data_type_array.h"
#include "core/data_type/data_type_number.h"
#include "core/data_type/define_primitive_type.h"
#include "core/data_type/primitive_type.h"
#include "exprs/function/ai/ai_adapter.h"
#include "exprs/function/function.h"
#include "runtime/query_context.h"
#include "runtime/runtime_state.h"
#include "service/http/http_client.h"
#include "util/threadpool.h"
namespace doris {
#include "common/compile_check_begin.h"
// Base class for AI-based functions
template <typename Derived>
class AIFunction : public IFunction {
public:
// Returns the function name, read from the `name` member of the concrete Derived class.
std::string get_name() const override {
    const auto& self = assert_cast<const Derived&>(*this);
    return self.name;
}
// Returns `Derived::number_of_arguments`.
// When the user omits the first argument (`resource_name`), FE appends it to the argument
// list from the session variable, so the value reported here must be the maximum number of
// arguments the function can accept.
size_t get_number_of_arguments() const override {
    const auto& self = assert_cast<const Derived&>(*this);
    return self.number_of_arguments;
}
// Always reports the function as blockable.
// NOTE(review): presumably because every row triggers a remote AI request — confirm against
// the IFunction contract.
bool is_blockable() const override {
    return true;
}
// Default prompt builder: copies the raw bytes of the text argument (the column referenced by
// `arguments[1]`) at row `row_num` into `prompt`.
//
// @param block      block holding the argument columns
// @param arguments  positions of the function arguments inside `block`
// @param row_num    row whose text is used as the prompt
// @param prompt     output; overwritten with the text of that row
// @return always Status::OK()
virtual Status build_prompt(const Block& block, const ColumnNumbers& arguments, size_t row_num,
                            std::string& prompt) const {
    const ColumnWithTypeAndName& text_entry = block.get_by_position(arguments[1]);
    const StringRef raw_text = text_entry.column->get_data_at(row_num);
    prompt.assign(raw_text.data, raw_text.size);
    return Status::OK();
}
// Evaluates the AI function for every row of `block` and installs the produced column at
// position `result`.
//
// Steps:
//   1. Resolve the AI resource and its adapter once through `_init_from_resource`.
//   2. For each row, build the prompt through `Derived::build_prompt` and issue one request.
//   3. Convert the reply according to the return type declared by `Derived`:
//        - TYPE_ARRAY   : the float vector is appended as one array row (AI_EMBED)
//        - TYPE_STRING  : the reply is stored as-is
//        - TYPE_BOOLEAN : the trimmed reply must be exactly "1" or "0" (AI_FILTER)
//        - TYPE_FLOAT   : the trimmed reply is parsed with std::stof (AI_SIMILARITY)
//
// In BE_TEST builds the reply for the boolean/float cases is replaced by the value of the
// AI_TEST_RESULT environment variable, or by "0" / "0.0" when the variable is unset.
//
// @param context           function context, forwarded to resource lookup and requests
// @param block             block holding the argument columns; receives the result column
// @param arguments         positions of the function arguments inside `block`
// @param result            position in `block` that receives the result column
// @param input_rows_count  number of rows to evaluate
// @return the first non-OK status produced by resource lookup, prompt building or the
//         request; RuntimeError when a boolean/float reply cannot be parsed; InternalError
//         for an unsupported return type. The result column is installed only after every
//         row succeeded.
Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
                    uint32_t result, size_t input_rows_count) const override {
    const auto& self = assert_cast<const Derived&>(*this);

    DataTypePtr return_type_impl = self.get_return_type_impl(DataTypes());
    MutableColumnPtr col_result = return_type_impl->create_column();
    // The return type is the same for every row, so query it once instead of per row.
    const auto result_type = return_type_impl->get_primitive_type();

    TAIResource config;
    std::shared_ptr<AIAdapter> adapter;
    if (Status status = self._init_from_resource(context, block, arguments, config, adapter);
        !status.ok()) {
        return status;
    }

    for (size_t i = 0; i < input_rows_count; ++i) {
        // Build AI prompt text
        std::string prompt;
        RETURN_IF_ERROR(self.build_prompt(block, arguments, i, prompt));

        // Execute a single AI request and get the result
        if (result_type == PrimitiveType::TYPE_ARRAY) {
            // Array(Float) for AI_EMBED
            std::vector<float> float_result;
            RETURN_IF_ERROR(
                    execute_single_request(prompt, float_result, config, adapter, context));

            auto& col_array = assert_cast<ColumnArray&>(*col_result);
            auto& offsets = col_array.get_offsets();
            auto& nested_nullable_col = assert_cast<ColumnNullable&>(col_array.get_data());
            auto& nested_col =
                    assert_cast<ColumnFloat32&>(*(nested_nullable_col.get_nested_column_ptr()));

            nested_col.reserve(nested_col.size() + float_result.size());
            // Offset of the new array row = elements stored so far + elements appended now.
            size_t current_offset = nested_col.size();
            nested_col.insert_many_raw_data(reinterpret_cast<const char*>(float_result.data()),
                                            float_result.size());
            offsets.push_back(current_offset + float_result.size());

            // Every embedding element is non-null.
            auto& null_map = nested_nullable_col.get_null_map_column();
            null_map.insert_many_vals(0, float_result.size());
        } else {
            std::string string_result;
            RETURN_IF_ERROR(
                    execute_single_request(prompt, string_result, config, adapter, context));

            switch (result_type) {
            case PrimitiveType::TYPE_STRING: { // string
                assert_cast<ColumnString&>(*col_result)
                        .insert_data(string_result.data(), string_result.size());
                break;
            }
            case PrimitiveType::TYPE_BOOLEAN: { // boolean for AI_FILTER
#ifdef BE_TEST
                const char* test_result = std::getenv("AI_TEST_RESULT");
                if (test_result != nullptr) {
                    string_result = test_result;
                } else {
                    string_result = "0";
                }
#endif
                trim_string(string_result);
                if (string_result != "1" && string_result != "0") {
                    return Status::RuntimeError("Failed to parse boolean value: " +
                                                string_result);
                }
                assert_cast<ColumnUInt8&>(*col_result)
                        .insert_value(static_cast<UInt8>(string_result == "1"));
                break;
            }
            case PrimitiveType::TYPE_FLOAT: { // float for AI_SIMILARITY
#ifdef BE_TEST
                const char* test_result = std::getenv("AI_TEST_RESULT");
                if (test_result != nullptr) {
                    string_result = test_result;
                } else {
                    string_result = "0.0";
                }
#endif
                trim_string(string_result);
                // Guard only the parse: an exception thrown while inserting into the column
                // must not be reported as a float parsing failure.
                float float_value = 0.0F;
                try {
                    float_value = std::stof(string_result);
                } catch (...) {
                    return Status::RuntimeError("Failed to parse float value: " +
                                                string_result);
                }
                assert_cast<ColumnFloat32&>(*col_result).insert_value(float_value);
                break;
            }
            default:
                return Status::InternalError("Unsupported ReturnType for AIFunction");
            }
        }
    }

    block.replace_by_position(result, std::move(col_result));
    return Status::OK();
}
protected:
// The `v1/completions` endpoint does not support `system_prompt`. To keep the request
// structure clear and the AI results stable, an endpoint that ends with `v1/completions` is
// rewritten in place to end with `v1/chat/completions`. Any other endpoint is left unchanged.
static void normalize_endpoint(TAIResource& config) {
    static constexpr std::string_view legacy_suffix = "v1/completions";
    static constexpr std::string_view chat_suffix = "v1/chat/completions";

    std::string& endpoint = config.endpoint;
    if (!endpoint.ends_with(legacy_suffix)) {
        return;
    }
    // Swap only the trailing suffix; everything before it is kept.
    endpoint.replace(endpoint.size() - legacy_suffix.size(), legacy_suffix.size(), chat_suffix);
}
private:
// Removes leading and trailing whitespace (as classified by std::isspace, which includes
// newlines) from `str` in place. Interior whitespace is kept; a string made only of
// whitespace becomes empty.
static void trim_string(std::string& str) {
    // Take `unsigned char` so std::isspace never receives a negative value.
    const auto is_not_space = [](unsigned char ch) { return !std::isspace(ch); };

    // Drop the trailing run first, then the leading run.
    const auto tail = std::find_if(str.rbegin(), str.rend(), is_not_space).base();
    str.erase(tail, str.end());

    const auto head = std::find_if(str.begin(), str.end(), is_not_space);
    str.erase(str.begin(), head);
}
// Resolves the AI resource named by the first argument and prepares the matching adapter.
//
// The resource name is read from row 0 of the first argument column only, which is why the
// AI resource must be a literal. On success `config` holds a copy of the resource (with its
// endpoint normalized) and `adapter` holds an adapter initialized with that config.
//
// @return InternalError   when the query context carries no AI resource map;
//         InvalidArgument when the resource name is unknown or no adapter exists for the
//                         resource's provider type;
//         OK otherwise.
Status _init_from_resource(FunctionContext* context, const Block& block,
                           const ColumnNumbers& arguments, TAIResource& config,
                           std::shared_ptr<AIAdapter>& adapter) const {
    // 1. Initialize config: look the resource up by name.
    const ColumnWithTypeAndName& name_entry = block.get_by_position(arguments[0]);
    const StringRef name_ref = name_entry.column->get_data_at(0);
    const std::string resource_name(name_ref.data, name_ref.size);

    const std::shared_ptr<std::map<std::string, TAIResource>>& known_resources =
            context->state()->get_query_ctx()->get_ai_resources();
    if (known_resources == nullptr) {
        return Status::InternalError("AI resources metadata missing in QueryContext");
    }

    const auto entry = known_resources->find(resource_name);
    if (entry == known_resources->end()) {
        return Status::InvalidArgument("AI resource not found: " + resource_name);
    }
    config = entry->second;
    normalize_endpoint(config);

    // 2. Create an adapter based on provider_type.
    adapter = AIAdapterFactory::create_adapter(config.provider_type);
    if (adapter == nullptr) {
        return Status::InvalidArgument("Unsupported AI provider type: " + config.provider_type);
    }
    adapter->init(config);
    return Status::OK();
}
// Performs one HTTP POST of `request_body` to `config.endpoint` and stores the reply body in
// `response`.
//
// The client timeout is set to the query's remaining time (seconds converted to
// milliseconds). Authentication is applied through the adapter only when an API key is
// configured.
//
// @return the client initialization error, TimedOut when the query has no time left,
//         the authentication error, or the status of the POST itself.
Status do_send_request(HttpClient* client, const std::string& request_body,
                       std::string& response, const TAIResource& config,
                       std::shared_ptr<AIAdapter>& adapter, FunctionContext* context) const {
    RETURN_IF_ERROR(client->init(config.endpoint));

    QueryContext* owning_query = context->state()->get_query_ctx();
    const int64_t seconds_left = owning_query->get_remaining_query_time_seconds();
    if (seconds_left <= 0) {
        return Status::TimedOut("Query timeout exceeded before AI request");
    }
    client->set_timeout_ms(seconds_left * 1000);

    const bool has_api_key = !config.api_key.empty();
    if (has_api_key) {
        RETURN_IF_ERROR(adapter->set_authentication(client));
    }

    return client->execute_post_request(request_body, &response);
}
// Sends the request through HttpClient::execute_with_retry, using the resource's
// `max_retries` and `retry_delay_second` settings, so transient failures can be retried.
// Each attempt is a single call to `do_send_request`.
Status send_request_to_llm(const std::string& request_body, std::string& response,
                           const TAIResource& config, std::shared_ptr<AIAdapter>& adapter,
                           FunctionContext* context) const {
    // Everything captured by reference is owned by the caller of this function.
    auto single_attempt = [this, &request_body, &response, &config, &adapter,
                           context](HttpClient* client) -> Status {
        return this->do_send_request(client, request_body, response, config, adapter, context);
    };
    return HttpClient::execute_with_retry(config.max_retries, config.retry_delay_second,
                                          single_attempt);
}
// Runs one text request: builds the payload from `input` and `Derived::system_prompt`,
// obtains the raw reply, parses it through the adapter and moves the first parsed entry into
// `result`.
//
// When `config.provider_type` is "MOCK" (unit-test path) no request is sent; the raw reply is
// a fixed prefix followed by `input`.
//
// @return the payload-building, request or parsing error; InternalError when parsing yields
//         no entries; OK otherwise.
Status execute_single_request(const std::string& input, std::string& result,
                              const TAIResource& config, std::shared_ptr<AIAdapter>& adapter,
                              FunctionContext* context) const {
    std::vector<std::string> inputs {input};
    std::string request_body;
    RETURN_IF_ERROR(adapter->build_request_payload(
            inputs, assert_cast<const Derived&>(*this).system_prompt, request_body));

    std::string response;
    const bool is_mock = config.provider_type == "MOCK";
    if (!is_mock) {
        RETURN_IF_ERROR(send_request_to_llm(request_body, response, config, adapter, context));
    } else {
        // Mock path for UT
        response = "this is a mock response. " + input;
    }

    std::vector<std::string> parsed;
    RETURN_IF_ERROR(adapter->parse_response(response, parsed));
    if (parsed.empty()) {
        return Status::InternalError("AI returned empty result");
    }
    result = std::move(parsed.front());
    return Status::OK();
}
// Runs one embedding request: builds the embedding payload from `input`, obtains the raw
// reply, parses it through the adapter and moves the first parsed vector into `result`.
//
// When `config.provider_type` is "MOCK" (unit-test path) no request is sent; a fixed JSON
// reply is used instead.
//
// @return the payload-building, request or parsing error; InternalError when parsing yields
//         no vectors; OK otherwise.
Status execute_single_request(const std::string& input, std::vector<float>& result,
                              const TAIResource& config, std::shared_ptr<AIAdapter>& adapter,
                              FunctionContext* context) const {
    std::vector<std::string> inputs {input};
    std::string request_body;
    RETURN_IF_ERROR(adapter->build_embedding_request(inputs, request_body));

    std::string response;
    const bool is_mock = config.provider_type == "MOCK";
    if (!is_mock) {
        RETURN_IF_ERROR(send_request_to_llm(request_body, response, config, adapter, context));
    } else {
        // Mock path for UT
        response = "{\"embedding\": [0, 1, 2, 3, 4]}";
    }

    std::vector<std::vector<float>> parsed;
    RETURN_IF_ERROR(adapter->parse_embedding_response(response, parsed));
    if (parsed.empty()) {
        return Status::InternalError("AI returned empty result");
    }
    result = std::move(parsed.front());
    return Status::OK();
}
};
#include "common/compile_check_end.h"
} // namespace doris