blob: 170463c1f23291a6e382d2896363208775f850f5 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <arrow/status.h>
#include "udf/python/python_client.h"
namespace doris {
class PythonUDAFClient;
using PythonUDAFClientPtr = std::shared_ptr<PythonUDAFClient>;
/**
* Python UDAF Client
*
* Implements Snowflake-style UDAF pattern with the following methods:
* - __init__(): Initialize aggregate state
* - aggregate_state: Property that returns internal state
* - accumulate(input): Add new input to aggregate state
* - merge(other_state): Combine two intermediate states
* - finish(): Generate final result from aggregate state
*
* Communication protocol with Python server:
* 1. CREATE: Initialize UDAF class instance and get initial state
* 2. ACCUMULATE: Send input data batch and get updated states
* 3. SERIALIZE: Get serialized state for shuffle/merge
* 4. MERGE: Combine serialized states
* 5. FINALIZE: Get final result from state
* 6. RESET: Reset state to initial value
* 7. DESTROY: Clean up resources
*/
class PythonUDAFClient : public PythonClient {
public:
// UDAF operation types
enum class UDAFOperation : uint8_t {
CREATE = 0, // Create new aggregate state
ACCUMULATE = 1, // Add input rows to state
SERIALIZE = 2, // Serialize state for shuffle
MERGE = 3, // Merge two states
FINALIZE = 4, // Get final result
RESET = 5, // Reset state
DESTROY = 6 // Destroy state
};
PythonUDAFClient() = default;
~PythonUDAFClient() override = default;
static Status create(const PythonUDFMeta& func_meta, ProcessPtr process,
PythonUDAFClientPtr* client);
/**
* Create aggregate state for a place
* @param place_id Unique identifier for the aggregate state
* @return Status
*/
Status create(int64_t place_id);
/**
* Accumulate input data into aggregate state
* @param place_id Aggregate state identifier
* @param is_single_place Whether all rows go to single place
* @param input Input data batch
* @param row_start Start row index
* @param row_end End row index (exclusive)
* @param places Array of place pointers (for GROUP BY)
* @param place_offset Offset within each place
* @return Status
*/
Status accumulate(int64_t place_id, bool is_single_place, const arrow::RecordBatch& input,
int64_t row_start, int64_t row_end, const int64_t* places = nullptr,
int64_t place_offset = 0);
/**
* Serialize aggregate state for shuffle/merge
* @param place_id Aggregate state identifier
* @param serialized_state Output serialized state
* @return Status
*/
Status serialize(int64_t place_id, std::shared_ptr<arrow::Buffer>* serialized_state);
/**
* Merge another serialized state into current state
* @param place_id Target aggregate state identifier
* @param serialized_state Serialized state to merge
* @return Status
*/
Status merge(int64_t place_id, const std::shared_ptr<arrow::Buffer>& serialized_state);
/**
* Get final result from aggregate state
* @param place_id Aggregate state identifier
* @param output Output result
* @return Status
*/
Status finalize(int64_t place_id, std::shared_ptr<arrow::RecordBatch>* output);
/**
* Reset aggregate state to initial value
* @param place_id Aggregate state identifier
* @return Status
*/
Status reset(int64_t place_id);
/**
* Destroy aggregate state and free resources
* @param place_id Aggregate state identifier
* @return Status
*/
Status destroy(int64_t place_id);
/**
* Destroy all aggregate states
* @return Status
*/
Status destroy_all();
/**
* Close client connection and cleanup
* Overrides base class to destroy all states first
* @return Status
*/
Status close();
static std::string print_operation(UDAFOperation op);
private:
DISALLOW_COPY_AND_ASSIGN(PythonUDAFClient);
/**
* Helper to execute a UDAF operation (CREATE, RESET, DESTROY, etc.)
* This consolidates the common pattern:
* 1. Check if process is alive
* 2. Create unified batch
* 3. Send operation
* 4. Validate boolean response (if validate_response is true)
*
* Template parameters are compile-time constants for better optimization
* @tparam operation The UDAF operation to execute
* @tparam validate_response If true, validates response as boolean and checks success
* @param place_id Aggregate state identifier
* @param metadata Optional metadata buffer
* @param data Optional data buffer
* @param output Output RecordBatch (can be nullptr if not needed)
* @return Status
*/
template <UDAFOperation operation, bool validate_response>
Status _execute_operation(int64_t place_id, const std::shared_ptr<arrow::Buffer>& metadata,
const std::shared_ptr<arrow::Buffer>& data,
std::shared_ptr<arrow::RecordBatch>* output);
/**
* Send operation request to Python server
* @param input Input data batch
* @param output Optional output data
* @return Status
*/
Status _send_operation(const arrow::RecordBatch* input,
std::shared_ptr<arrow::RecordBatch>* output);
// Track created states for cleanup
std::unordered_set<int64_t> _created_states;
// Thread safety: protect concurrent RPC calls
// Arrow Flight client (gRPC-based) is not fully thread-safe,
// so we need to serialize all operations through this mutex
mutable std::mutex _operation_mutex;
};
} // namespace doris