// blob: b81388a51eab755dd8732cc7fe62dcc48cf81a9a [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include "common/status.h"
#include "udf/python/python_env.h"
#include "udf/python/python_udaf_client.h"
#include "udf/python/python_udf_meta.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/columns/column.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type.h"
namespace doris::vectorized {
#include "common/compile_check_begin.h"
/**
 * Aggregate state data for Python UDAF
 *
 * Python UDAF state is managed remotely (in Python server).
 * We cache serialized state for shuffle/merge operations (similar to Java UDAF).
 *
 * Most operations take a `place` / `place_id` integer; presumably this is the
 * key under which the Python server stores the state -- confirm in the .cpp.
 */
struct AggregatePythonUDAFData {
    // Cached serialized state (see the struct comment above).
    std::string serialize_data;
    // Client handle of type PythonUDAFClientPtr; it is copied as-is when this
    // struct is copied. NOTE(review): presumably a shared pointer, so copies
    // would share one client -- confirm against python_udaf_client.h.
    PythonUDAFClientPtr client;

    AggregatePythonUDAFData() = default;
    // Copy operations are defaulted on purpose: they are plain member-wise
    // copies, and letting the compiler generate them means a member added
    // later cannot be forgotten in a hand-maintained initializer list.
    AggregatePythonUDAFData(const AggregatePythonUDAFData& other) = default;
    AggregatePythonUDAFData& operator=(const AggregatePythonUDAFData& other) = default;
    ~AggregatePythonUDAFData() = default;

    // Creates the state identified by `place`; returns a Status describing the outcome.
    Status create(int64_t place);

    // Feeds rows of `columns` in the range [row_num_start, row_num_end) into the
    // state `place_id`. NOTE(review): whether row_num_end is exclusive is not
    // visible in this header -- confirm in the .cpp.
    Status add(int64_t place_id, const IColumn** columns, int64_t row_num_start,
               int64_t row_num_end, const DataTypes& argument_types);

    // Batched variant of add() that receives the per-row `places` array together
    // with `place_offset`. NOTE(review): the exact meaning of num_rows versus
    // start/end is defined by the implementation -- confirm in the .cpp.
    Status add_batch(AggregateDataPtr* places, size_t place_offset, size_t num_rows,
                     const IColumn** columns, const DataTypes& argument_types, size_t start,
                     size_t end);

    // Merges the state held by `rhs` into the state identified by `place`.
    Status merge(const AggregatePythonUDAFData& rhs, int64_t place);

    // Writes the state identified by `place` into `buf`.
    Status write(BufferWritable& buf, int64_t place) const;

    // Reads state from `buf`. Unlike write(), this returns void, so it has no
    // Status channel to report a failure to the caller.
    void read(BufferReadable& buf);

    // Resets the state identified by `place`.
    Status reset(int64_t place);

    // Destroys the state identified by `place`.
    Status destroy(int64_t place);

    // Produces the result for `place` and stores it in `to`, using `result_type`.
    Status get(IColumn& to, const DataTypePtr& result_type, int64_t place) const;
};
/**
 * Aggregate function implemented by a Python UDAF.
 *
 * The Python side follows the Snowflake-style UDAF pattern:
 * - __init__(): initialize the aggregate state
 * - aggregate_state: property returning the serializable state
 * - accumulate(*args): add input to the state
 * - merge(other_state): combine two states
 * - finish(): produce the final result
 *
 * Communication with the Python server goes through PythonUDAFClient using
 * Arrow Flight.
 *
 * Note: only IAggregateFunctionDataHelper is a public base. VarargsExpression
 * and NullableAggregateFunction carry no access specifier and are therefore
 * private bases of this class.
 */
class AggregatePythonUDAF final
        : public IAggregateFunctionDataHelper<AggregatePythonUDAFData, AggregatePythonUDAF>,
          VarargsExpression,
          NullableAggregateFunction {
public:
    ENABLE_FACTORY_CREATOR(AggregatePythonUDAF);

    /**
     * Stores the function descriptor and the return type; the argument types
     * are forwarded to the IAggregateFunctionDataHelper base.
     */
    AggregatePythonUDAF(const TFunction& fn, const DataTypes& argument_types_,
                        const DataTypePtr& return_type)
            : IAggregateFunctionDataHelper(argument_types_),
              _fn(fn),
              _return_type(return_type) {}

    ~AggregatePythonUDAF() override = default;

    /**
     * Factory: builds a shared AggregatePythonUDAF from the same arguments the
     * constructor takes.
     */
    static AggregateFunctionPtr create(const TFunction& fn, const DataTypes& argument_types_,
                                       const DataTypePtr& return_type) {
        return std::make_shared<AggregatePythonUDAF>(fn, argument_types_, return_type);
    }

    /** Name of the function, taken from the function descriptor. */
    String get_name() const override {
        return _fn.name.function_name;
    }

    /** Return type supplied at construction. */
    DataTypePtr get_return_type() const override {
        return _return_type;
    }

    /** Initializes the function metadata. */
    Status open();

    /** Creates the aggregate state in the Python server. */
    void create(AggregateDataPtr __restrict place) const override;

    /** Destroys the aggregate state in the Python server. */
    void destroy(AggregateDataPtr __restrict place) const noexcept override;

    /** Adds a single row (row_num) to the aggregate state. */
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
             Arena&) const override;

    /** Adds a batch of rows to multiple aggregate states (GROUP BY). */
    void add_batch(size_t batch_size, AggregateDataPtr* places, size_t place_offset,
                   const IColumn** columns, Arena&, bool /*agg_many*/) const override;

    /** Adds a batch of rows to a single aggregate state (no GROUP BY). */
    void add_batch_single_place(size_t batch_size, AggregateDataPtr place,
                                const IColumn** columns, Arena&) const override;

    /** Adds a range of rows to a single place (used for window functions). */
    void add_range_single_place(int64_t partition_start, int64_t partition_end,
                                int64_t frame_start, int64_t frame_end, AggregateDataPtr place,
                                const IColumn** columns, Arena& arena,
                                UInt8* current_window_empty,
                                UInt8* current_window_has_inited) const override;

    /** Resets the aggregate state to its initial value. */
    void reset(AggregateDataPtr place) const override;

    /** Merges the state `rhs` into `place`. */
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
               Arena&) const override;

    /** Serializes the aggregate state for shuffle. */
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override;

    /** Deserializes an aggregate state received from shuffle. */
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                     Arena&) const override;

    /** Computes the final result and inserts it into the output column `to`. */
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override;

private:
    // Function descriptor copied at construction; get_name() reads its name.
    TFunction _fn;
    // Result type copied at construction; returned by get_return_type().
    DataTypePtr _return_type;
    // NOTE(review): presumably filled in by open() -- confirm in the .cpp.
    PythonUDFMeta _func_meta;
    // NOTE(review): not set by the constructor; presumably resolved in open().
    PythonVersion _python_version;

    // Arrow Flight schema: [argument_types..., places: int64, binary_data: binary]
    // Shared by all UDAF RPC operations:
    // - the places column is always present (NULL in single-place mode, actual
    //   place_id values in GROUP BY mode)
    // - the binary_data column carries serialized data for MERGE operations
    //   (NULL for ACCUMULATE)
    // Both members are mutable so that const member functions can set them.
    mutable std::shared_ptr<arrow::Schema> _schema;
    // NOTE(review): presumably guards one-time construction of _schema via
    // std::call_once -- confirm in the .cpp.
    mutable std::once_flag _schema_init_flag;
};
} // namespace doris::vectorized
#include "common/compile_check_end.h"