blob: 078c34a39ea9671152bbcbd3af883493e7d5cc81 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <arrow/status.h>

#include <cstddef>
#include <cstdint>
#include <memory>
#include <mutex>
#include <optional>

#include "udf/python/python_client.h"
namespace doris {
class PythonUDAFClient;
using PythonUDAFClientPtr = std::shared_ptr<PythonUDAFClient>;
// Fixed-size (30 bytes) binary metadata structure for UDAF operations (Request)
struct __attribute__((packed)) UDAFMetadata {
uint32_t meta_version; // 4 bytes: metadata version (current version = 1)
uint8_t operation; // 1 byte: UDAFOperation enum
uint8_t is_single_place; // 1 byte: boolean (0 or 1, ACCUMULATE only)
int64_t place_id; // 8 bytes: aggregate state identifier (globally unique)
int64_t row_start; // 8 bytes: start row index (ACCUMULATE only)
int64_t row_end; // 8 bytes: end row index (exclusive, ACCUMULATE only)
};
static_assert(sizeof(UDAFMetadata) == 30, "UDAFMetadata size must be 30 bytes");
// Current metadata version constant
constexpr uint32_t UDAF_METADATA_VERSION = 1;
/**
* Python UDAF Client
*
* Implements Snowflake-style UDAF pattern with the following methods:
* - __init__(): Initialize aggregate state
* - aggregate_state: Property that returns internal state
* - accumulate(input): Add new input to aggregate state
* - merge(other_state): Combine two intermediate states
* - finish(): Generate final result from aggregate state
*
* Communication protocol with Python server:
* 1. CREATE: Initialize UDAF class instance and get initial state
* 2. ACCUMULATE: Send input data batch and get updated states
* 3. SERIALIZE: Get serialized state for shuffle/merge
* 4. MERGE: Combine serialized states
* 5. FINALIZE: Get final result from state
* 6. RESET: Reset state to initial value
* 7. DESTROY: Clean up resources
*/
class PythonUDAFClient : public PythonClient {
public:
// UDAF operation types
enum class UDAFOperation : uint8_t {
CREATE = 0, // Create new aggregate state
ACCUMULATE = 1, // Add input rows to state
SERIALIZE = 2, // Serialize state for shuffle
MERGE = 3, // Merge two states
FINALIZE = 4, // Get final result
RESET = 5, // Reset state
DESTROY = 6 // Destroy state
};
PythonUDAFClient() = default;
~PythonUDAFClient() override {
// Clean up all remaining states on destruction
auto st = close();
if (!st.ok()) {
LOG(WARNING) << "Failed to close PythonUDAFClient in destructor: " << st.to_string();
}
}
static Status create(const PythonUDFMeta& func_meta, ProcessPtr process,
const std::shared_ptr<arrow::Schema>& data_schema,
PythonUDAFClientPtr* client);
/**
* Initialize UDAF client with data schema
* Overrides base class to set _schema before initialization
* @param func_meta Function metadata
* @param process Python process handle
* @param data_schema Arrow schema for UDAF data
* @return Status
*/
Status init(const PythonUDFMeta& func_meta, ProcessPtr process,
const std::shared_ptr<arrow::Schema>& data_schema);
/**
* Create aggregate state for a place
* @param place_id Unique identifier for the aggregate state
* @return Status
*/
Status create(int64_t place_id);
/**
* Accumulate input data into aggregate state
*
* For single-place mode (is_single_place=true):
* - input RecordBatch contains only data columns
* - All rows are accumulated to the same place_id
*
* For multi-place mode (is_single_place=false):
* - input RecordBatch MUST contain a "places" column (int64) as the last column
* - The "places" column indicates which place each row belongs to
* - place_id parameter is ignored (set to 0 by convention)
*
* @param place_id Aggregate state identifier (used only in single-place mode)
* @param is_single_place Whether all rows go to single place
* @param input Input data batch (must contain "places" column if is_single_place=false)
* @param row_start Start row index
* @param row_end End row index (exclusive)
* @return Status
*/
Status accumulate(int64_t place_id, bool is_single_place, const arrow::RecordBatch& input,
int64_t row_start, int64_t row_end);
/**
* Serialize aggregate state for shuffle/merge
* @param place_id Aggregate state identifier
* @param serialized_state Output serialized state
* @return Status
*/
Status serialize(int64_t place_id, std::shared_ptr<arrow::Buffer>* serialized_state);
/**
* Merge another serialized state into current state
* @param place_id Target aggregate state identifier
* @param serialized_state Serialized state to merge
* @return Status
*/
Status merge(int64_t place_id, const std::shared_ptr<arrow::Buffer>& serialized_state);
/**
* Get final result from aggregate state
* @param place_id Aggregate state identifier
* @param output Output result
* @return Status
*/
Status finalize(int64_t place_id, std::shared_ptr<arrow::RecordBatch>* output);
/**
* Reset aggregate state to initial value
* @param place_id Aggregate state identifier
* @return Status
*/
Status reset(int64_t place_id);
/**
* Destroy aggregate state and free resources
* @param place_id Aggregate state identifier
* @return Status
*/
Status destroy(int64_t place_id);
/**
* Close client connection and cleanup
* Overrides base class to destroy the tracked place first
* @return Status
*/
Status close();
private:
DISALLOW_COPY_AND_ASSIGN(PythonUDAFClient);
/**
* Send RecordBatch request to Python server with app_metadata
* @param metadata UDAFMetadata structure (will be sent as app_metadata)
* @param request_batch Request RecordBatch (contains data columns + binary_data column)
* @param response_batch Output RecordBatch
* @return Status
*/
Status _send_request(const UDAFMetadata& metadata,
const std::shared_ptr<arrow::RecordBatch>& request_batch,
std::shared_ptr<arrow::RecordBatch>* response_batch);
/**
* Create request batch with data columns (for ACCUMULATE)
* Appends NULL binary_data column to input data batch
*/
Status _create_data_request_batch(const arrow::RecordBatch& input_data,
std::shared_ptr<arrow::RecordBatch>* out);
/**
* Create request batch with binary data (for MERGE)
* Creates NULL data columns + binary_data column
*/
Status _create_binary_request_batch(const std::shared_ptr<arrow::Buffer>& binary_data,
std::shared_ptr<arrow::RecordBatch>* out);
/**
* Get or create empty request batch (for CREATE/SERIALIZE/FINALIZE/RESET/DESTROY)
* All columns are NULL. Cached after first creation for reuse.
*/
Status _get_empty_request_batch(std::shared_ptr<arrow::RecordBatch>* out);
// Arrow Flight schema: [argument_types..., places: int64, binary_data: binary]
std::shared_ptr<arrow::Schema> _schema;
std::shared_ptr<arrow::RecordBatch> _empty_request_batch;
// Track created state for cleanup
std::optional<int64_t> _created_place_id;
// Thread safety: protect gRPC stream operations
// CRITICAL: gRPC ClientReaderWriter does NOT support concurrent Write() calls
// Even within same thread, multiple pipeline tasks may trigger concurrent operations
// (e.g., normal accumulate() + cleanup destroy() during task finalization)
mutable std::mutex _operation_mutex;
};
} // namespace doris