blob: ede72e9763399025f7d73c7b8451f31b2792c63a [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "vec/aggregate_functions/aggregate_function_python_udaf.h"
#include <arrow/array.h>
#include <arrow/memory_pool.h>
#include <arrow/record_batch.h>
#include <arrow/type.h>
#include <fmt/format.h>
#include "common/exception.h"
#include "common/logging.h"
#include "runtime/user_function_cache.h"
#include "udf/python/python_env.h"
#include "udf/python/python_udf_runtime.h"
#include "udf/python/python_udf_server.h"
#include "util/arrow/block_convertor.h"
#include "util/arrow/row_batch.h"
#include "util/timezone_utils.h"
#include "vec/columns/column_nullable.h"
#include "vec/columns/column_vector.h"
#include "vec/core/block.h"
#include "vec/data_types/data_type_nullable.h"
namespace doris::vectorized {
Status AggregatePythonUDAFData::create(int64_t place) {
DCHECK(client != nullptr) << "Client must be set before calling create";
RETURN_IF_ERROR(client->create(place));
return Status::OK();
}
Status AggregatePythonUDAFData::add(int64_t place_id, const IColumn** columns,
int64_t row_num_start, int64_t row_num_end,
const DataTypes& argument_types) {
DCHECK(client != nullptr) << "Client must be set before calling add";
Block input_block;
for (size_t i = 0; i < argument_types.size(); ++i) {
input_block.insert(
ColumnWithTypeAndName(columns[i]->get_ptr(), argument_types[i], std::to_string(i)));
}
std::shared_ptr<arrow::Schema> schema;
RETURN_IF_ERROR(
get_arrow_schema_from_block(input_block, &schema, TimezoneUtils::default_time_zone));
std::shared_ptr<arrow::RecordBatch> batch;
cctz::time_zone timezone_obj;
TimezoneUtils::find_cctz_time_zone(TimezoneUtils::default_time_zone, timezone_obj);
RETURN_IF_ERROR(convert_to_arrow_batch(input_block, schema, arrow::default_memory_pool(),
&batch, timezone_obj));
RETURN_IF_ERROR(
client->accumulate(place_id, true, *batch, row_num_start, row_num_end, nullptr, 0));
return Status::OK();
}
Status AggregatePythonUDAFData::merge(const AggregatePythonUDAFData& rhs, int64_t place) {
DCHECK(client != nullptr) << "Client must be set before calling merge";
// Get serialized state from rhs (already stored in serialize_data by read())
auto serialized_state = arrow::Buffer::Wrap(
reinterpret_cast<const uint8_t*>(rhs.serialize_data.data()), rhs.serialize_data.size());
RETURN_IF_ERROR(client->merge(place, serialized_state));
return Status::OK();
}
Status AggregatePythonUDAFData::write(BufferWritable& buf, int64_t place) const {
DCHECK(client != nullptr) << "Client must be set before calling write";
// Serialize state from Python server
std::shared_ptr<arrow::Buffer> serialized_state;
RETURN_IF_ERROR(client->serialize(place, &serialized_state));
const char* data = reinterpret_cast<const char*>(serialized_state->data());
size_t size = serialized_state->size();
buf.write_binary(StringRef {data, size});
return Status::OK();
}
void AggregatePythonUDAFData::read(BufferReadable& buf) {
// Read serialized state from buffer into serialize_data
// This will be used later by merge() in deserialize_and_merge()
buf.read_binary(serialize_data);
}
Status AggregatePythonUDAFData::reset(int64_t place) {
DCHECK(client != nullptr) << "Client must be set before calling reset";
RETURN_IF_ERROR(client->reset(place));
return Status::OK();
}
Status AggregatePythonUDAFData::destroy(int64_t place) {
DCHECK(client != nullptr) << "Client must be set before calling destroy";
RETURN_IF_ERROR(client->destroy(place));
return Status::OK();
}
Status AggregatePythonUDAFData::get(IColumn& to, const DataTypePtr& result_type,
int64_t place) const {
DCHECK(client != nullptr) << "Client must be set before calling get";
// Get final result from Python server
std::shared_ptr<arrow::RecordBatch> result;
RETURN_IF_ERROR(client->finalize(place, &result));
// Convert Arrow RecordBatch to Block
Block result_block;
DataTypes types = {result_type};
cctz::time_zone timezone_obj;
TimezoneUtils::find_cctz_time_zone(TimezoneUtils::default_time_zone, timezone_obj);
RETURN_IF_ERROR(convert_from_arrow_batch(result, types, &result_block, timezone_obj));
// Insert the result value into output column
if (result_block.rows() != 1) {
return Status::InternalError("Expected 1 row in result block, got {}", result_block.rows());
}
auto& result_column = result_block.get_by_position(0).column;
to.insert_from(*result_column, 0);
return Status::OK();
}
Status AggregatePythonUDAF::open() {
// Build function metadata from TFunction
_func_meta.id = _fn.id;
_func_meta.name = _fn.name.function_name;
// For UDAF, symbol is in aggregate_fn
if (_fn.__isset.aggregate_fn && _fn.aggregate_fn.__isset.symbol) {
_func_meta.symbol = _fn.aggregate_fn.symbol;
} else {
return Status::InvalidArgument("Python UDAF symbol is not set");
}
// Determine load type (inline code or module)
if (!_fn.function_code.empty()) {
_func_meta.type = PythonUDFLoadType::INLINE;
_func_meta.location = "inline";
_func_meta.inline_code = _fn.function_code;
} else if (!_fn.hdfs_location.empty()) {
_func_meta.type = PythonUDFLoadType::MODULE;
_func_meta.location = _fn.hdfs_location;
_func_meta.checksum = _fn.checksum;
} else {
_func_meta.type = PythonUDFLoadType::UNKNOWN;
_func_meta.location = "unknown";
}
_func_meta.input_types = argument_types;
_func_meta.return_type = _return_type;
_func_meta.client_type = PythonClientType::UDAF;
// Get Python version
if (_fn.__isset.runtime_version && !_fn.runtime_version.empty()) {
RETURN_IF_ERROR(PythonVersionManager::instance().get_version(_fn.runtime_version,
&_python_version));
} else {
return Status::InvalidArgument("Python UDAF runtime version is not set");
}
_func_meta.runtime_version = _python_version.full_version;
RETURN_IF_ERROR(_func_meta.check());
_func_meta.always_nullable = _return_type->is_nullable();
LOG(INFO) << fmt::format("Creating Python UDAF: {}, runtime_version: {}, func_meta: {}",
_fn.name.function_name, _python_version.to_string(),
_func_meta.to_string());
if (_func_meta.type == PythonUDFLoadType::MODULE) {
RETURN_IF_ERROR(UserFunctionCache::instance()->get_pypath(
_func_meta.id, _func_meta.location, _func_meta.checksum, &_func_meta.location));
}
LOG(INFO) << "Initialized Python UDAF metadata for function: " << _func_meta.name;
return Status::OK();
}
void AggregatePythonUDAF::create(AggregateDataPtr __restrict place) const {
// Initialize the data structure
new (place) Data();
DCHECK(reinterpret_cast<Data*>(place)) << "Place must not be null";
// Get or initialize shared client (thread-safe)
PythonUDAFClient* client = _get_shared_client();
if (UNLIKELY(!client)) {
this->data(place).~Data();
throw doris::Exception(ErrorCode::INTERNAL_ERROR,
"Failed to get shared Python UDAF client");
}
// Set client pointer in data (all places share the same client)
this->data(place).set_client(client);
// Initialize UDAF state in Python server using shared client
int64_t place_id = reinterpret_cast<int64_t>(place);
Status st = this->data(place).create(place_id);
if (UNLIKELY(!st.ok())) {
this->data(place).~Data();
throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
}
}
void AggregatePythonUDAF::destroy(AggregateDataPtr __restrict place) const noexcept {
int64_t place_id = reinterpret_cast<int64_t>(place);
// Destroy state in Python server (client is stored in data)
if (this->data(place).client) {
Status st = this->data(place).destroy(place_id);
if (UNLIKELY(!st.ok())) {
LOG(WARNING) << "Failed to destroy Python UDAF state: " << st.to_string();
}
}
// Destroy C++ data structure
this->data(place).~Data();
}
void AggregatePythonUDAF::add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num, Arena&) const {
int64_t place_id = reinterpret_cast<int64_t>(place);
Status st = this->data(place).add(place_id, columns, row_num, row_num + 1, argument_types);
if (UNLIKELY(!st.ok())) {
throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
}
}
void AggregatePythonUDAF::add_batch(size_t batch_size, AggregateDataPtr* places,
size_t place_offset, const IColumn** columns, Arena&,
bool /*agg_many*/) const {
// With shared client optimization, all places use the same Python process
// We still need to add each row individually because they go to different place_ids
for (size_t i = 0; i < batch_size; ++i) {
int64_t place_id = reinterpret_cast<int64_t>(places[i] + place_offset);
Status st = this->data(places[i] + place_offset)
.add(place_id, columns, i, i + 1, argument_types);
if (UNLIKELY(!st.ok())) {
throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
}
}
}
void AggregatePythonUDAF::add_batch_single_place(size_t batch_size, AggregateDataPtr place,
const IColumn** columns, Arena&) const {
// Add batch to single place using shared client
int64_t place_id = reinterpret_cast<int64_t>(place);
Status st = this->data(place).add(place_id, columns, 0, batch_size, argument_types);
if (UNLIKELY(!st.ok())) {
throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
}
}
void AggregatePythonUDAF::add_range_single_place(int64_t partition_start, int64_t partition_end,
int64_t frame_start, int64_t frame_end,
AggregateDataPtr place, const IColumn** columns,
Arena& arena, UInt8* current_window_empty,
UInt8* current_window_has_inited) const {
// Calculate actual frame range
frame_start = std::max<int64_t>(frame_start, partition_start);
frame_end = std::min<int64_t>(frame_end, partition_end);
if (frame_start >= frame_end) {
if (!*current_window_has_inited) {
*current_window_empty = true;
}
return;
}
// Add range to single place using shared client
int64_t place_id = reinterpret_cast<int64_t>(place);
Status st = this->data(place).add(place_id, columns, frame_start, frame_end, argument_types);
if (UNLIKELY(!st.ok())) {
throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
}
*current_window_empty = false;
*current_window_has_inited = true;
}
void AggregatePythonUDAF::reset(AggregateDataPtr place) const {
int64_t place_id = reinterpret_cast<int64_t>(place);
Status st = this->data(place).reset(place_id);
if (UNLIKELY(!st.ok())) {
throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
}
}
void AggregatePythonUDAF::merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
Arena&) const {
int64_t place_id = reinterpret_cast<int64_t>(place);
Status st = this->data(place).merge(this->data(rhs), place_id);
if (UNLIKELY(!st.ok())) {
throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
}
}
void AggregatePythonUDAF::serialize(ConstAggregateDataPtr __restrict place,
BufferWritable& buf) const {
int64_t place_id = reinterpret_cast<int64_t>(place);
Status st = this->data(place).write(buf, place_id);
if (UNLIKELY(!st.ok())) {
throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
}
}
void AggregatePythonUDAF::deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
Arena&) const {
this->data(place).read(buf);
}
void AggregatePythonUDAF::insert_result_into(ConstAggregateDataPtr __restrict place,
IColumn& to) const {
// Get final result and insert into column using shared client
int64_t place_id = reinterpret_cast<int64_t>(place);
Status st = this->data(place).get(to, _return_type, place_id);
if (UNLIKELY(!st.ok())) {
throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
}
}
Status AggregatePythonUDAF::_init_shared_client() const {
if (_shared_client) {
return Status::OK();
}
RETURN_IF_ERROR(PythonUDFServerManager::instance().get_client(_func_meta, _python_version,
&_shared_client));
LOG(INFO) << "Initialized shared Python UDAF client for function: " << _func_meta.name;
return Status::OK();
}
PythonUDAFClient* AggregatePythonUDAF::_get_shared_client() const {
// Double-checked locking for thread-safe lazy initialization
if (!_client_initialized) {
std::lock_guard<std::mutex> lock(_client_init_mutex);
if (!_client_initialized) {
Status st = _init_shared_client();
if (!st.ok()) {
LOG(ERROR) << "Failed to initialize shared Python UDAF client: " << st.to_string();
return nullptr;
}
_client_initialized = true;
}
}
return _shared_client.get();
}
} // namespace doris::vectorized