| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import argparse |
| import base64 |
| import gc |
| import importlib |
| import inspect |
| import ipaddress |
| import json |
| import sys |
| import os |
| import traceback |
| import logging |
| import time |
| import threading |
| import pickle |
| from abc import ABC, abstractmethod |
| from contextlib import contextmanager |
| from typing import Any, Callable, Optional, Tuple, get_origin, Dict |
| from datetime import datetime |
| from enum import Enum |
| from pathlib import Path |
| from logging.handlers import RotatingFileHandler |
| |
| import pandas as pd |
| import pyarrow as pa |
| from pyarrow import flight |
| |
| |
| class ServerState: |
| """Global server state container.""" |
| |
| unix_socket_path: str = "" |
| |
| PYTHON_SERVER_START_SUCCESS_MSG: str = "Start python server successfully" |
| |
| @staticmethod |
| def setup_logging(): |
| """Setup logging configuration for the UDF server with rotation.""" |
| |
| doris_home = os.getenv("DORIS_HOME") |
| if not doris_home: |
| # Fallback to current directory if DORIS_HOME is not set |
| doris_home = os.getcwd() |
| |
| log_dir = os.path.join(doris_home, "log") |
| os.makedirs(log_dir, exist_ok=True) |
| |
| # Use shared log file with process ID in each log line |
| log_file = os.path.join(log_dir, "python_udf_output.log") |
| max_bytes = 128 * 1024 * 1024 # 128MB |
| backup_count = 5 |
| |
| # Use RotatingFileHandler to automatically manage log file size |
| file_handler = RotatingFileHandler( |
| log_file, maxBytes=max_bytes, backupCount=backup_count, encoding="utf-8" |
| ) |
| |
| # Include process ID in log format |
| file_handler.setFormatter( |
| logging.Formatter( |
| "[%(asctime)s] [PID:%(process)d] [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s" |
| ) |
| ) |
| |
| logging.basicConfig( |
| level=logging.INFO, |
| handlers=[file_handler], |
| ) |
| logging.info( |
| "Logging initialized. Log file: %s (max_size=%dMB, backups=%d)", |
| log_file, |
| max_bytes // (1024 * 1024), |
| backup_count, |
| ) |
| |
| @staticmethod |
| def extract_base_unix_socket_path(unix_socket_uri: str) -> str: |
| """ |
| Extract the file system path from a gRPC Unix socket URI. |
| |
| Args: |
| unix_socket_uri: URI in format 'grpc+unix:///path/to/socket' |
| |
| Returns: |
| The file system path without the protocol prefix |
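| |
| Example (illustrative; the socket path is hypothetical): |
| >>> ServerState.extract_base_unix_socket_path("grpc+unix:///tmp/doris_udf.sock") |
| '/tmp/doris_udf.sock' |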
| """ |
| if unix_socket_uri.startswith("grpc+unix://"): |
| unix_socket_uri = unix_socket_uri[len("grpc+unix://") :] |
| return unix_socket_uri |
| |
| @staticmethod |
| def remove_unix_socket(unix_socket_uri: str) -> None: |
| """ |
| Remove the Unix domain socket file if it exists. |
| |
| Args: |
| unix_socket_uri: URI of the Unix socket to remove |
| """ |
| if unix_socket_uri is None: |
| return |
| base_unix_socket_path = ServerState.extract_base_unix_socket_path( |
| unix_socket_uri |
| ) |
| if os.path.exists(base_unix_socket_path): |
| try: |
| os.unlink(base_unix_socket_path) |
| logging.info( |
| "Removed UNIX socket %s successfully", base_unix_socket_path |
| ) |
| except OSError as e: |
| logging.error( |
| "Failed to remove UNIX socket %s: %s", base_unix_socket_path, e |
| ) |
| else: |
| logging.warning("UNIX socket %s does not exist", base_unix_socket_path) |
| |
| @staticmethod |
| def monitor_parent_exit(): |
| """ |
| Monitor the parent process and exit gracefully if it dies. |
| This prevents orphaned UDF server processes. |
| """ |
| parent_pid = os.getppid() |
| if parent_pid == 1: |
| # Parent process is init, no need to monitor |
| logging.info("Parent process is init (PID 1), skipping parent monitoring") |
| return |
| |
| logging.info("Started monitoring parent process (PID: %s)", parent_pid) |
| |
| while True: |
| try: |
| # os.kill(pid, 0) only checks whether the process exists |
| # without sending an actual signal |
| os.kill(parent_pid, 0) |
| except OSError: |
| # Parent process died |
| ServerState.remove_unix_socket(ServerState.unix_socket_path) |
| logging.error( |
| "Parent process %s died, exiting UDF server, unix socket path: %s", |
| parent_pid, |
| ServerState.unix_socket_path, |
| ) |
| os._exit(0) |
| # Check every 2 seconds |
| time.sleep(2) |
| |
| |
| ServerState.setup_logging() |
| monitor_thread = threading.Thread(target=ServerState.monitor_parent_exit, daemon=True) |
| monitor_thread.start() |
| |
| |
| @contextmanager |
| def temporary_sys_path(path: str): |
| """ |
| Context manager to temporarily add a path to sys.path. |
| Ensures the path is removed after use to avoid pollution. |
| |
| Args: |
| path: Directory path to add to sys.path |
| |
| Yields: |
| None |
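| |
| Example (illustrative; the path and module name are hypothetical): |
| >>> with temporary_sys_path("/tmp/udf_modules"): |
| ... mod = importlib.import_module("my_udf_module") # doctest: +SKIP |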
| """ |
| path_added = False |
| if path not in sys.path: |
| sys.path.insert(0, path) |
| path_added = True |
| |
| try: |
| yield |
| finally: |
| if path_added and path in sys.path: |
| sys.path.remove(path) |
| |
| |
| def int32_to_uint32(value: int) -> int: |
| """ |
| Convert a signed int32 to unsigned uint32 representation. |
| |
| This is used for IPv4 addresses where Arrow uses int32 but the value |
| should be interpreted as uint32. |
| |
| Args: |
| value: Signed 32-bit integer (can be negative) |
| |
| Returns: |
| Unsigned 32-bit integer (0 to 4294967295) |
| |
| Example: |
| >>> int32_to_uint32(-1062731519) |
| 3232235777 # 192.168.1.1 |
| """ |
| return value & 0xFFFFFFFF |
| |
| |
| def uint32_to_int32(value: int) -> int: |
| """ |
| Convert an unsigned uint32 to signed int32 representation. |
| |
| This is used when returning IPv4 addresses to Arrow which expects int32. |
| |
| Args: |
| value: Unsigned 32-bit integer (0 to 4294967295) |
| |
| Returns: |
| Signed 32-bit integer (-2147483648 to 2147483647) |
| |
| Example: |
| >>> uint32_to_int32(3232235777) |
| -1062731519 # 192.168.1.1 as signed int32 |
| """ |
| if value > 0x7FFFFFFF: # 2147483647 |
| return value - 0x100000000 # 2 ** 32 |
| return value |
| |
| |
| def convert_arrow_field_to_python(field, column_metadata=None): |
| """ |
| Convert an Arrow scalar field to a Python value, with special handling for Doris types such as IP types. |
| |
| This function checks the column metadata for Doris IP types and automatically |
| converts them to Python ipaddress objects: |
| - IPv4 (Arrow int32) -> ipaddress.IPv4Address |
| - IPv6 (Arrow utf8) -> ipaddress.IPv6Address |
| - ... |
| |
| Args: |
| field: Arrow scalar field value |
| column_metadata: Optional Arrow field metadata dict containing type information |
| |
| Returns: |
| Converted Python value (with automatic IP address conversion if metadata present) |
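| |
| Example (illustrative; metadata normally comes from the Arrow field): |
| >>> convert_arrow_field_to_python( |
| ... pa.scalar(-1062731519, type=pa.int32()), {b'doris_type': b'IPV4'} |
| ... ) |
| IPv4Address('192.168.1.1') |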
| """ |
| if field is None: |
| return None |
| |
| if pa.types.is_map(field.type): |
| # pyarrow.lib.MapScalar's as_py() returns a list of tuples, convert to dict |
| list_of_tuples = field.as_py() |
| return dict(list_of_tuples) if list_of_tuples is not None else None |
| |
| # Check if we should apply special IP type conversion based on metadata |
| if column_metadata: |
| # Arrow metadata keys can be either bytes or str depending on how they were created |
| doris_type = column_metadata.get(b'doris_type') or column_metadata.get('doris_type') |
| |
| # Handle Doris IPv4 type (Arrow int32 -> ipaddress.IPv4Address) |
| if doris_type in (b'IPV4', 'IPV4'): |
| if pa.types.is_int32(field.type): |
| value = field.as_py() |
| if value is not None: |
| try: |
| return ipaddress.IPv4Address(int32_to_uint32(value)) |
| except (ValueError, TypeError) as e: |
| logging.warning( |
| "Failed to convert int32 %s to IPv4Address: %s", value, e |
| ) |
| return value |
| return None |
| # Handle Doris IPv6 type (Arrow utf8 -> ipaddress.IPv6Address) |
| elif doris_type in (b'IPV6', 'IPV6'): |
| if pa.types.is_string(field.type) or pa.types.is_large_string(field.type): |
| value = field.as_py() |
| if value is not None: |
| try: |
| return ipaddress.IPv6Address(value) |
| except (ValueError, TypeError) as e: |
| logging.warning( |
| "Failed to convert string '%s' to IPv6Address: %s", value, e |
| ) |
| return value |
| return None |
| |
| return field.as_py() |
| |
| |
| def convert_python_to_arrow_value(value, output_type=None): |
| """ |
| Convert Python value back to Arrow-compatible value. |
| |
| This function handles the reverse conversion of IP addresses: |
| - ipaddress.IPv4Address -> int (with uint32 to int32 conversion) |
| - ipaddress.IPv6Address -> str (for Arrow utf8) |
| |
| Type Safety: |
| For IPv4/IPv6 return types, MUST return ipaddress objects. |
| Returning raw integers or strings will raise TypeError. |
| |
| Args: |
| value: Python value to convert (can be single value or iterable) |
| output_type: Optional Arrow DataType with metadata |
| |
| Returns: |
| Arrow-compatible value |
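| |
| Example (illustrative): |
| >>> convert_python_to_arrow_value(ipaddress.IPv4Address("192.168.1.1")) |
| -1062731519 |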
| """ |
| if value is None: |
| return None |
| |
| is_ipv4_output = False |
| is_ipv6_output = False |
| |
| if output_type is not None and hasattr(output_type, 'metadata') and output_type.metadata: |
| # Arrow metadata keys can be either bytes or str depending on how they were created |
| doris_type = output_type.metadata.get(b'doris_type') or output_type.metadata.get('doris_type') |
| if doris_type in (b'IPV4', 'IPV4'): |
| is_ipv4_output = True |
| elif doris_type in (b'IPV6', 'IPV6'): |
| is_ipv6_output = True |
| |
| # Convert IPv4Address back to int |
| if isinstance(value, ipaddress.IPv4Address): |
| return uint32_to_int32(int(value)) |
| |
| # Convert IPv6Address back to str |
| if isinstance(value, ipaddress.IPv6Address): |
| return str(value) |
| |
| # IPv4 output must return IPv4Address objects |
| if is_ipv4_output and isinstance(value, int): |
| raise TypeError( |
| f"IPv4 UDF must return ipaddress.IPv4Address object, got int ({value}). " |
| f"Use: return ipaddress.IPv4Address({value})" |
| ) |
| |
| # IPv6 output must return IPv6Address objects |
| if is_ipv6_output and isinstance(value, str): |
| raise TypeError( |
| f"IPv6 UDF must return ipaddress.IPv6Address object, got str ('{value}'). " |
| f"Use: return ipaddress.IPv6Address('{value}')" |
| ) |
| |
| # Handle list of values (but not tuples that might be struct data) |
| if isinstance(value, list): |
| # For list types, recursively convert elements |
| if output_type and pa.types.is_list(output_type): |
| element_type = output_type.value_type |
| return [convert_python_to_arrow_value(v, element_type) for v in value] |
| else: |
| # No type info, just recurse without type |
| return [convert_python_to_arrow_value(v, None) for v in value] |
| |
| # Handle tuple values (could be struct data) |
| if isinstance(value, tuple): |
| # For struct types, convert each field with its corresponding type |
| if output_type and pa.types.is_struct(output_type): |
| if len(value) != len(output_type): |
| raise ValueError( |
| f"Struct has {len(output_type)} fields but tuple has {len(value)} elements" |
| ) |
| # Convert each tuple element with its corresponding field type |
| return tuple( |
| convert_python_to_arrow_value(v, output_type[i].type) |
| for i, v in enumerate(value) |
| ) |
| else: |
| # Not a struct type, treat as regular tuple and recurse without type |
| return tuple(convert_python_to_arrow_value(v, None) for v in value) |
| |
| if isinstance(value, dict): |
| # For map types, convert keys and values recursively |
| if output_type and pa.types.is_map(output_type): |
| key_type = output_type.key_type |
| item_type = output_type.item_type |
| # Convert dict to list of tuples (PyArrow Map format) |
| converted_items = [ |
| (convert_python_to_arrow_value(k, key_type), |
| convert_python_to_arrow_value(v, item_type)) |
| for k, v in value.items() |
| ] |
| return converted_items |
| else: |
| # No type info, just recurse without type |
| return [(convert_python_to_arrow_value(k, None), |
| convert_python_to_arrow_value(v, None)) |
| for k, v in value.items()] |
| |
| if isinstance(value, pd.Series): |
| return value.apply(lambda v: convert_python_to_arrow_value(v, output_type)) |
| |
| return value |
| |
| |
| class VectorType(Enum): |
| """Enum representing supported vector types.""" |
| |
| LIST = "list" |
| PANDAS_SERIES = "pandas.Series" |
| ARROW_ARRAY = "pyarrow.Array" |
| |
| @property |
| def python_type(self): |
| """ |
| Returns the Python type corresponding to this VectorType. |
| |
| Returns: |
| The Python type class (list, pd.Series, or pa.Array) |
| """ |
| mapping = { |
| VectorType.LIST: list, |
| VectorType.PANDAS_SERIES: pd.Series, |
| VectorType.ARROW_ARRAY: pa.Array, |
| } |
| return mapping[self] |
| |
| @staticmethod |
| def resolve_vector_type(param: inspect.Parameter): |
| """ |
| Resolves the param's type annotation to the corresponding VectorType enum. |
| Returns None if the type is unsupported or not a vector type. |
| """ |
| if ( |
| param is None |
| or param.annotation is None |
| or param.annotation is inspect.Parameter.empty |
| ): |
| return None |
| |
| annotation = param.annotation |
| origin = get_origin(annotation) |
| raw_type = origin if origin is not None else annotation |
| |
| if raw_type is list: |
| return VectorType.LIST |
| if raw_type is pd.Series: |
| return VectorType.PANDAS_SERIES |
| |
| return None |
| |
| |
| class ClientType(Enum): |
| """Enum representing Python client types.""" |
| |
| UDF = 0 |
| UDAF = 1 |
| UDTF = 2 |
| UNKNOWN = 3 |
| |
| def __str__(self) -> str: |
| """Return string representation of the client type.""" |
| return self.name |
| |
| |
| class PythonUDFMeta: |
| """Metadata container for a Python UDF.""" |
| |
| def __init__( |
| self, |
| name: str, |
| symbol: str, |
| location: str, |
| udf_load_type: int, |
| runtime_version: str, |
| always_nullable: bool, |
| inline_code: bytes, |
| input_types: pa.Schema, |
| output_type: pa.DataType, |
| client_type: int, |
| ) -> None: |
| """ |
| Initialize Python UDF metadata. |
| |
| Args: |
| name: UDF function name |
| symbol: Symbol to load (function name or module.function) |
| location: File path or directory containing the UDF |
| udf_load_type: 0 for inline code, 1 for module |
| runtime_version: Python runtime version requirement |
| always_nullable: Whether the UDF can return NULL values |
| inline_code: Base64-encoded inline Python code (if applicable) |
| input_types: PyArrow schema for input parameters |
| output_type: PyArrow data type for return value |
| client_type: 0 for UDF, 1 for UDAF, 2 for UDTF |
| """ |
| self.name = name |
| self.symbol = symbol |
| self.location = location |
| self.udf_load_type = udf_load_type |
| self.runtime_version = runtime_version |
| self.always_nullable = always_nullable |
| self.inline_code = inline_code |
| self.input_types = input_types |
| self.output_type = output_type |
| self.client_type = ClientType(client_type) |
| |
| def is_udf(self) -> bool: |
| """Check if this is a UDF (User-Defined Function).""" |
| return self.client_type == ClientType.UDF |
| |
| def is_udaf(self) -> bool: |
| """Check if this is a UDAF (User-Defined Aggregate Function).""" |
| return self.client_type == ClientType.UDAF |
| |
| def is_udtf(self) -> bool: |
| """Check if this is a UDTF (User-Defined Table Function).""" |
| return self.client_type == ClientType.UDTF |
| |
| def __str__(self) -> str: |
| """Returns a string representation of the UDF metadata.""" |
| udf_load_type_str = "INLINE" if self.udf_load_type == 0 else "MODULE" |
| return ( |
| f"PythonUDFMeta(name={self.name}, symbol={self.symbol}, " |
| f"location={self.location}, udf_load_type={udf_load_type_str}, runtime_version={self.runtime_version}, " |
| f"always_nullable={self.always_nullable}, client_type={self.client_type.name}, " |
| f"input_types={self.input_types}, output_type={self.output_type})" |
| ) |
| |
| |
| class AdaptivePythonUDF: |
| """ |
| A wrapper around a UDF function that supports both scalar and vectorized execution modes. |
| The mode is determined by the type hints of the function parameters. |
| """ |
| |
| def __init__(self, python_udf_meta: PythonUDFMeta, func: Callable) -> None: |
| """ |
| Initialize the adaptive UDF wrapper. |
| |
| Args: |
| python_udf_meta: Metadata describing the UDF |
| func: The actual Python function to execute |
| """ |
| self.python_udf_meta = python_udf_meta |
| self._eval_func = func |
| |
| def __str__(self) -> str: |
| """Returns a string representation of the UDF wrapper.""" |
| input_type_strs = [str(t) for t in self.python_udf_meta.input_types.types] |
| output_type_str = str(self.python_udf_meta.output_type) |
| eval_func_str = f"{self.python_udf_meta.name}({', '.join(input_type_strs)}) -> {output_type_str}" |
| return f"AdaptivePythonUDF(python_udf_meta: {self.python_udf_meta}, eval_func: {eval_func_str})" |
| |
| def __call__(self, record_batch: pa.RecordBatch) -> pa.Array: |
| """ |
| Executes the UDF on the given record batch. Supports both scalar and vectorized modes. |
| |
| Args: |
| record_batch: Input data with N columns, each of length num_rows |
| |
| Returns: |
| Output array of length num_rows |
| """ |
| if record_batch.num_rows == 0: |
| return pa.array([], type=self._get_output_type()) |
| |
| if self._should_use_vectorized(): |
| return self._vectorized_call(record_batch) |
| |
| return self._scalar_call(record_batch) |
| |
| @staticmethod |
| def _cast_arrow_to_vector(arrow_array: pa.Array, vec_type: VectorType): |
| """ |
| Convert a pa.Array to an instance of the specified VectorType. |
| """ |
| if vec_type == VectorType.LIST: |
| return arrow_array.to_pylist() |
| elif vec_type == VectorType.PANDAS_SERIES: |
| return arrow_array.to_pandas() |
| else: |
| raise ValueError(f"Unsupported vector type: {vec_type}") |
| |
| def _should_use_vectorized(self) -> bool: |
| """ |
| Determines whether to use vectorized mode based on parameter type annotations. |
| Returns True if any parameter is annotated as: |
| - list |
| - pd.Series |
| """ |
| try: |
| signature = inspect.signature(self._eval_func) |
| except ValueError: |
| # Cannot inspect built-in or C functions; default to scalar |
| return False |
| |
| for param in signature.parameters.values(): |
| if VectorType.resolve_vector_type(param): |
| return True |
| |
| return False |
| |
| def _scalar_call(self, record_batch: pa.RecordBatch) -> pa.Array: |
| """ |
| Applies the UDF in scalar mode: one row at a time. |
| |
| Args: |
| record_batch: Input data batch |
| |
| Returns: |
| Output array with results for each row |
| """ |
| columns = record_batch.columns |
| num_rows = record_batch.num_rows |
| result = [] |
| |
| column_metadata = [ |
| record_batch.schema.field(col_idx).metadata |
| for col_idx in range(len(columns)) |
| ] |
| |
| for i in range(num_rows): |
| converted_args = [ |
| convert_arrow_field_to_python(col[i], meta) |
| for col, meta in zip(columns, column_metadata) |
| ] |
| |
| try: |
| res = self._eval_func(*converted_args) |
| # Check if result is None when always_nullable is False |
| if res is None and not self.python_udf_meta.always_nullable: |
| raise RuntimeError( |
| f"the result of row {i} is null, but the return type is not nullable, " |
| f"please check the always_nullable property in create function statement, " |
| f"it should be true" |
| ) |
| result.append(convert_python_to_arrow_value(res, self.python_udf_meta.output_type)) |
| except Exception as e: |
| logging.error( |
| "Error in scalar UDF execution at row %s: %s\nArgs: %s\nTraceback: %s", |
| i, |
| e, |
| converted_args, |
| traceback.format_exc(), |
| ) |
| # Return None for failed rows if always_nullable is True |
| if self.python_udf_meta.always_nullable: |
| result.append(None) |
| else: |
| raise |
| |
| return pa.array(result, type=self._get_output_type()) |
| |
| def _vectorized_call(self, record_batch: pa.RecordBatch) -> pa.Array: |
| """ |
| Applies the UDF in vectorized mode: processes entire columns at once. |
| |
| Args: |
| record_batch: Input data batch |
| |
| Returns: |
| Output array with results |
| """ |
| column_args = record_batch.columns |
| logging.info("Vectorized call with %s columns", len(column_args)) |
| |
| sig = inspect.signature(self._eval_func) |
| params = list(sig.parameters.values()) |
| |
| if len(column_args) != len(params): |
| raise ValueError(f"UDF expects {len(params)} args, got {len(column_args)}") |
| |
| converted_args = [] |
| for param, arrow_col in zip(params, column_args): |
| vec_type = VectorType.resolve_vector_type(param) |
| |
| if vec_type is None: |
| # For scalar types (int, float, str, etc.), extract the first value |
| # instead of converting to list |
| pylist = arrow_col.to_pylist() |
| if len(pylist) > 0: |
| converted = pylist[0] |
| logging.info( |
| "Converted %s to scalar (first value): %s", |
| param.name, |
| type(converted).__name__, |
| ) |
| else: |
| converted = None |
| logging.info( |
| "Converted %s to scalar (None, empty column)", param.name |
| ) |
| else: |
| converted = self._cast_arrow_to_vector(arrow_col, vec_type) |
| logging.info("Converted %s: %s", param.name, vec_type) |
| |
| converted_args.append(converted) |
| |
| try: |
| result = self._eval_func(*converted_args) |
| except Exception as e: |
| logging.error( |
| "Error in vectorized UDF: %s\nTraceback: %s", e, traceback.format_exc() |
| ) |
| raise RuntimeError(f"Error in vectorized UDF: {e}") from e |
| |
| result = convert_python_to_arrow_value(result, self.python_udf_meta.output_type) |
| |
| # Convert result to PyArrow Array |
| result_array = None |
| if isinstance(result, pd.Series): |
| result_array = pa.array(result, type=self._get_output_type()) |
| elif isinstance(result, list): |
| result_array = pa.array(result, type=self._get_output_type()) |
| else: |
| # Scalar result - broadcast to all rows |
| out_type = self._get_output_type() |
| logging.warning( |
| "UDF returned scalar value, broadcasting to %s rows", |
| record_batch.num_rows, |
| ) |
| result_array = pa.array([result] * record_batch.num_rows, type=out_type) |
| |
| # Check for None values when always_nullable is False |
| if not self.python_udf_meta.always_nullable: |
| null_count = result_array.null_count |
| if null_count > 0: |
| # Find the first null index for error message |
| for i, value in enumerate(result_array): |
| if not value.is_valid: |
| raise RuntimeError( |
| f"The result of row {i} is null, but the return type is not nullable. " |
| "Please set the always_nullable property to true in the CREATE FUNCTION statement." |
| ) |
| |
| return result_array |
| |
| def _get_output_type(self) -> pa.DataType: |
| """ |
| Returns the expected output type for the UDF. |
| |
| Returns: |
| PyArrow DataType for the output |
| """ |
| return self.python_udf_meta.output_type or pa.null() |
| |
| |
| class UDFLoader(ABC): |
| """Abstract base class for loading UDFs from different sources.""" |
| |
| def __init__(self, python_udf_meta: PythonUDFMeta) -> None: |
| """ |
| Initialize the UDF loader. |
| |
| Args: |
| python_udf_meta: Metadata describing the UDF to load |
| """ |
| self.python_udf_meta = python_udf_meta |
| |
| @abstractmethod |
| def load(self) -> AdaptivePythonUDF: |
| """Load the UDF and return an AdaptivePythonUDF wrapper.""" |
| raise NotImplementedError("Subclasses must implement load().") |
| |
| |
| class InlineUDFLoader(UDFLoader): |
| """Loads a UDF defined directly in inline code.""" |
| |
| def load(self) -> AdaptivePythonUDF: |
| """ |
| Load and execute inline Python code to extract the UDF function. |
| |
| Returns: |
| AdaptivePythonUDF wrapper around the loaded function |
| |
| Raises: |
| RuntimeError: If code execution fails |
| ValueError: If the function is not found or not callable |
| """ |
| symbol = self.python_udf_meta.symbol |
| inline_code = self.python_udf_meta.inline_code.decode("utf-8") |
| env: dict[str, Any] = {} |
| |
| try: |
| # Execute the code in a clean environment |
| # pylint: disable=exec-used |
| # Note: exec() is necessary here for dynamic UDF loading from inline code |
| exec(inline_code, env) # nosec B102 |
| except Exception as e: |
| logging.error( |
| "Failed to exec inline code: %s\nTraceback: %s", |
| e, |
| traceback.format_exc(), |
| ) |
| raise RuntimeError(f"Failed to exec inline code: {e}") from e |
| |
| func = env.get(symbol) |
| if func is None: |
| available_funcs = [ |
| k for k, v in env.items() if callable(v) and not k.startswith("_") |
| ] |
| logging.error( |
| "Function '%s' not found in inline code. Available functions: %s", |
| symbol, |
| available_funcs, |
| ) |
| raise ValueError(f"Function '{symbol}' not found in inline code.") |
| |
| if not callable(func): |
| logging.error( |
| "'%s' exists but is not callable (type: %s)", symbol, type(func) |
| ) |
| raise ValueError(f"'{symbol}' is not a callable function.") |
| |
| return AdaptivePythonUDF(self.python_udf_meta, func) |
| |
| |
| class ModuleUDFLoader(UDFLoader): |
| """Loads a UDF from a Python module file (.py).""" |
| |
| # Class-level lock dictionary for thread-safe module imports |
| # Using RLock allows the same thread to acquire the lock multiple times |
| # Key: (location, module_name) tuple to avoid conflicts between different locations |
| _import_locks: Dict[Tuple[str, str], threading.RLock] = {} |
| _import_locks_lock = threading.Lock() |
| _module_cache: Dict[Tuple[str, str], Any] = {} |
| _module_cache_lock = threading.Lock() |
| |
| @classmethod |
| def _get_import_lock(cls, location: str, module_name: str) -> threading.RLock: |
| """ |
| Get or create a reentrant lock for the given location and module name. |
| |
| Uses double-checked locking pattern for optimal performance: |
| - Fast path: return existing lock without acquiring global lock |
| - Slow path: create new lock under global lock protection |
| |
| Args: |
| location: The directory path where the module is located |
| module_name: The full module name to import |
| """ |
| cache_key = (location, module_name) |
| |
| # Fast path: check without lock (read-only, safe for most cases) |
| if cache_key in cls._import_locks: |
| return cls._import_locks[cache_key] |
| |
| # Slow path: create lock under protection |
| with cls._import_locks_lock: |
| # Double-check: another thread might have created it while we waited |
| if cache_key not in cls._import_locks: |
| cls._import_locks[cache_key] = threading.RLock() |
| return cls._import_locks[cache_key] |
| |
| def load(self) -> AdaptivePythonUDF: |
| """ |
| Loads a UDF from a Python module file. |
| |
| Returns: |
| AdaptivePythonUDF instance wrapping the loaded function |
| |
| Raises: |
| ValueError: If module file not found |
| TypeError: If symbol is not callable |
| """ |
| symbol = self.python_udf_meta.symbol # [package_name.]module_name.function_name |
| location = self.python_udf_meta.location # /path/to/module_name[.py] |
| |
| if not os.path.exists(location): |
| raise ValueError(f"Module file not found: {location}") |
| |
| package_name, module_name, func_name = self.parse_symbol(symbol) |
| func = self.load_udf_from_module(location, package_name, module_name, func_name) |
| |
| if not callable(func): |
| raise TypeError( |
| f"'{symbol}' exists but is not callable (type: {type(func).__name__})" |
| ) |
| |
| return AdaptivePythonUDF(self.python_udf_meta, func) |
| |
| def parse_symbol(self, symbol: str): |
| """ |
| Parse symbol into (package_name, module_name, func_name) |
| |
| Supported formats: |
| - "module.func" → (None, module, func) |
| - "package.module.func" → (package, module, func) |
| """ |
| if not symbol or "." not in symbol: |
| raise ValueError( |
| f"Invalid symbol format: '{symbol}'. " |
| "Expected 'module.function' or 'package.module.function'" |
| ) |
| |
| parts = symbol.split(".") |
| if len(parts) == 2: |
| # module.func → Single-file mode |
| module_name, func_name = parts |
| package_name = None |
| if not module_name or not module_name.strip(): |
| raise ValueError(f"Module name is empty in symbol: '{symbol}'") |
| if not func_name or not func_name.strip(): |
| raise ValueError(f"Function name is empty in symbol: '{symbol}'") |
| elif len(parts) > 2: |
| package_name = parts[0] |
| module_name = ".".join(parts[1:-1]) |
| func_name = parts[-1] |
| if not package_name or not package_name.strip(): |
| raise ValueError(f"Package name is empty in symbol: '{symbol}'") |
| if not module_name or not module_name.strip(): |
| raise ValueError(f"Module name is empty in symbol: '{symbol}'") |
| if not func_name or not func_name.strip(): |
| raise ValueError(f"Function name is empty in symbol: '{symbol}'") |
| else: |
| raise ValueError(f"Invalid symbol format: '{symbol}'") |
| |
| return package_name, module_name, func_name |
| |
| def _get_or_import_module(self, location: str, full_module_name: str) -> Any: |
| """ |
| Get module from cache or import it (thread-safe). |
| |
| Uses a location-aware cache to prevent conflicts when different locations |
| have modules with the same name. |
| """ |
| cache_key = (location, full_module_name) |
| |
| # Use a per-(location, module) lock to prevent race conditions during import |
| import_lock = ModuleUDFLoader._get_import_lock(location, full_module_name) |
| |
| with import_lock: |
| # Fast path: check location-aware cache first |
| if cache_key in ModuleUDFLoader._module_cache: |
| cached_module = ModuleUDFLoader._module_cache[cache_key] |
| if cached_module is not None and ( |
| hasattr(cached_module, "__file__") |
| or hasattr(cached_module, "__path__") |
| ): |
| return cached_module |
| else: |
| del ModuleUDFLoader._module_cache[cache_key] |
| |
| # Before importing, clear any existing module with the same name in sys.modules |
| # that might have been loaded from a different location |
| if full_module_name in sys.modules: |
| existing_module = sys.modules[full_module_name] |
| existing_file = getattr(existing_module, "__file__", None) |
| # Check if the existing module is from a different location |
| if existing_file and not existing_file.startswith(location): |
| del sys.modules[full_module_name] |
| |
| with temporary_sys_path(location): |
| try: |
| module = importlib.import_module(full_module_name) |
| # Store in location-aware cache |
| ModuleUDFLoader._module_cache[cache_key] = module |
| return module |
| except Exception: |
| # Clean up any partially-imported modules |
| if full_module_name in sys.modules: |
| del sys.modules[full_module_name] |
| if cache_key in ModuleUDFLoader._module_cache: |
| del ModuleUDFLoader._module_cache[cache_key] |
| raise |
| |
| def _extract_function( |
| self, module: Any, func_name: str, module_name: str |
| ) -> Callable: |
| """Extract and validate function from module.""" |
| func = getattr(module, func_name, None) |
| if func is None: |
| # Diagnostic info: log module details to understand why function is missing |
| module_attrs = dir(module) |
| module_file = getattr(module, "__file__", "N/A") |
| module_dict_keys = ( |
| list(module.__dict__.keys()) if hasattr(module, "__dict__") else [] |
| ) |
| |
| logging.error( |
| "Function '%s' not found in module '%s'. " |
| "Module file: %s, " |
| "Public attributes: %s, " |
| "All dict keys: %s", |
| func_name, |
| module_name, |
| module_file, |
| [a for a in module_attrs if not a.startswith("_")][:20], |
| module_dict_keys[:20], |
| ) |
| |
| # Check if module has import errors stored |
| if hasattr(module, "__import_error__"): |
| logging.error( |
| "Module '%s' has stored import error: %s", |
| module_name, |
| module.__import_error__, |
| ) |
| |
| raise AttributeError( |
| f"Function '{func_name}' not found in module '{module_name}'" |
| ) |
| if not callable(func): |
| raise TypeError(f"'{func_name}' is not callable") |
| return func |
| |
| def _load_single_file_udf( |
| self, location: str, module_name: str, func_name: str |
| ) -> Callable: |
| """Load UDF from a single Python file.""" |
| py_file = os.path.join(location, f"{module_name}.py") |
| if not os.path.isfile(py_file): |
| raise ImportError(f"Python file not found: {py_file}") |
| |
| try: |
| udf_module = self._get_or_import_module(location, module_name) |
| return self._extract_function(udf_module, func_name, module_name) |
| except (ImportError, AttributeError, TypeError) as e: |
| raise ImportError( |
| f"Failed to load single-file UDF '{module_name}.{func_name}': {e}" |
| ) from e |
| except Exception as e: |
| logging.error( |
| "Unexpected error loading UDF: %s\n%s", e, traceback.format_exc() |
| ) |
| raise |
| |
| def _ensure_package_init(self, package_path: str, package_name: str) -> None: |
| """Ensure __init__.py exists in the package directory.""" |
| init_path = os.path.join(package_path, "__init__.py") |
| if not os.path.exists(init_path): |
| logging.warning( |
| "__init__.py not found in package '%s', attempting to create it", |
| package_name, |
| ) |
| try: |
| with open(init_path, "w", encoding="utf-8") as f: |
| f.write( |
| "# Auto-generated by UDF loader to make directory a Python package\n" |
| ) |
| except OSError as e: |
| raise ImportError( |
| f"Cannot create __init__.py in package '{package_name}': {e}" |
| ) from e |
| |
| def _build_full_module_name(self, package_name: str, module_name: str) -> str: |
| """Build the full module name for package mode.""" |
| if module_name == "__init__": |
| return package_name |
| return f"{package_name}.{module_name}" |
| |
| def _load_package_udf( |
| self, location: str, package_name: str, module_name: str, func_name: str |
| ) -> Callable: |
| """Load UDF from a Python package.""" |
| package_path = os.path.join(location, package_name) |
| if not os.path.isdir(package_path): |
| raise ImportError(f"Package '{package_name}' not found in '{location}'") |
| |
| self._ensure_package_init(package_path, package_name) |
| |
| try: |
| full_module_name = self._build_full_module_name(package_name, module_name) |
| udf_module = self._get_or_import_module(location, full_module_name) |
| return self._extract_function(udf_module, func_name, full_module_name) |
| except (ImportError, AttributeError, TypeError) as e: |
| raise ImportError( |
| f"Failed to load packaged UDF '{package_name}.{module_name}.{func_name}': {e}" |
| ) from e |
| except Exception as e: |
| logging.error( |
| "Unexpected error loading packaged UDF: %s\n%s", |
| e, |
| traceback.format_exc(), |
| ) |
| raise |
| |
| def load_udf_from_module( |
| self, |
| location: str, |
| package_name: Optional[str], |
| module_name: str, |
| func_name: str, |
| ) -> Callable: |
| """ |
| Load a UDF from a Python module, supporting both: |
| 1. Single-file mode: package_name=None, module_name="your_file" |
| 2. Package mode: package_name="your_pkg", module_name="submodule" or "__init__" |
| |
| Args: |
| location: |
| - In package mode: parent directory of the package |
| - In single-file mode: directory containing the .py file |
| package_name: |
| - If None or empty: treat as single-file mode |
| - Else: standard package name |
| module_name: |
| - In package mode: submodule name (e.g., "main") or "__init__" |
| - In single-file mode: filename without .py (e.g., "udf_script") |
| func_name: name of the function to load |
| |
| Returns: |
| The callable UDF function. |
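| |
| Example layouts (illustrative; paths and names are hypothetical): |
| Single-file mode: location="/path/to/udfs", symbol="my_udf.add_one" |
| /path/to/udfs/my_udf.py (defines add_one) |
| Package mode: location="/path/to/udfs", symbol="my_pkg.main.add_one" |
| /path/to/udfs/my_pkg/__init__.py |
| /path/to/udfs/my_pkg/main.py (defines add_one) |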
| """ |
| if not os.path.isdir(location): |
| raise ValueError(f"Location is not a directory: {location}") |
| |
| if not package_name or package_name.strip() == "": |
| return self._load_single_file_udf(location, module_name, func_name) |
| else: |
| return self._load_package_udf( |
| location, package_name, module_name, func_name |
| ) |
| |
| |
| class UDFLoaderFactory: |
| """Factory to select the appropriate loader based on UDF location.""" |
| |
| @staticmethod |
| def get_loader(python_udf_meta: PythonUDFMeta) -> UDFLoader: |
| """ |
| Factory method to create the appropriate UDF loader based on metadata. |
| |
| Args: |
| python_udf_meta: UDF metadata containing load type and location |
| |
| Returns: |
| Appropriate UDFLoader instance (InlineUDFLoader or ModuleUDFLoader) |
| |
| Raises: |
| ValueError: If UDF load type or location is unsupported |
| """ |
| location = python_udf_meta.location |
| udf_load_type = python_udf_meta.udf_load_type # 0: inline, 1: module |
| |
| if udf_load_type == 0: |
| return InlineUDFLoader(python_udf_meta) |
| elif udf_load_type == 1: |
| if UDFLoaderFactory.check_module(location): |
| return ModuleUDFLoader(python_udf_meta) |
| else: |
| raise ValueError(f"Unsupported UDF location: {location}") |
| else: |
| raise ValueError(f"Unsupported UDF load type: {udf_load_type}") |
| |
| @staticmethod |
| def check_module(location: str) -> bool: |
| """ |
| Checks if a location is a valid Python module or package. |
| |
| A valid module is either: |
| - A .py file, or |
| - A directory containing at least one .py file (searched recursively). |
| |
| Raises: |
| ValueError: If the location does not exist or contains no Python module. |
| |
| Returns: |
| True if valid. |
| """ |
| if not os.path.exists(location): |
| raise ValueError(f"Module not found: {location}") |
| |
| if os.path.isfile(location): |
| if location.endswith(".py"): |
| return True |
| else: |
| raise ValueError(f"File is not a Python module (.py): {location}") |
| |
| if os.path.isdir(location): |
| if UDFLoaderFactory.has_python_file_recursive(location): |
| return True |
| else: |
| raise ValueError( |
| f"Directory contains no Python (.py) files: {location}" |
| ) |
| |
| raise ValueError(f"Invalid module location (not file or directory): {location}") |
| |
| @staticmethod |
| def has_python_file_recursive(location: str) -> bool: |
| """ |
| Recursively checks if a directory contains any Python (.py) files. |
| |
| Args: |
| location: Directory path to search |
| |
| Returns: |
| True if at least one .py file is found, False otherwise |
| """ |
| path = Path(location) |
| if not path.is_dir(): |
| return False |
| return any(path.rglob("*.py")) |
| |
| |
| class UDAFClassLoader: |
| """ |
| Utility class for loading UDAF classes from various sources. |
| |
| This class is responsible for loading UDAF classes from: |
| - Inline code (embedded in SQL) |
| - Module files (imported from filesystem) |
| """ |
| |
| @staticmethod |
| def load_udaf_class(python_udf_meta: PythonUDFMeta) -> type: |
| """ |
| Load the UDAF class from metadata. |
| |
| Args: |
| python_udf_meta: UDAF metadata |
| |
| Returns: |
| The UDAF class |
| |
| Raises: |
| RuntimeError: If inline code execution fails |
| ValueError: If class is not found or invalid |
| """ |
| loader = UDFLoaderFactory.get_loader(python_udf_meta) |
| |
| # For UDAF, we need the class, not an instance |
| if isinstance(loader, InlineUDFLoader): |
| return UDAFClassLoader.load_from_inline(python_udf_meta) |
| elif isinstance(loader, ModuleUDFLoader): |
| return UDAFClassLoader.load_from_module(python_udf_meta, loader) |
| else: |
| raise ValueError(f"Unsupported loader type: {type(loader)}") |
| |
| @staticmethod |
| def load_from_inline(python_udf_meta: PythonUDFMeta) -> type: |
| """ |
| Load UDAF class from inline code. |
| |
| Args: |
| python_udf_meta: UDAF metadata with inline code |
| |
| Returns: |
| The UDAF class |
| """ |
| symbol = python_udf_meta.symbol |
| inline_code = python_udf_meta.inline_code.decode("utf-8") |
| env: dict[str, Any] = {} |
| |
| try: |
| exec(inline_code, env) # nosec B102 |
| except Exception as e: |
| raise RuntimeError(f"Failed to exec inline code: {e}") from e |
| |
| udaf_class = env.get(symbol) |
| if udaf_class is None: |
| raise ValueError(f"UDAF class '{symbol}' not found in inline code") |
| |
| if not inspect.isclass(udaf_class): |
| raise ValueError(f"'{symbol}' is not a class (type: {type(udaf_class)})") |
| |
| UDAFClassLoader.validate_udaf_class(udaf_class) |
| return udaf_class |
| |
| @staticmethod |
| def load_from_module( |
| python_udf_meta: PythonUDFMeta, loader: ModuleUDFLoader |
| ) -> type: |
| """ |
| Load UDAF class from module file. |
| |
| Args: |
| python_udf_meta: UDAF metadata with module location |
| loader: Module loader instance |
| |
| Returns: |
| The UDAF class |
| """ |
| symbol = python_udf_meta.symbol |
| location = python_udf_meta.location |
| |
| package_name, module_name, class_name = loader.parse_symbol(symbol) |
| udaf_class = loader.load_udf_from_module( |
| location, package_name, module_name, class_name |
| ) |
| |
| if not inspect.isclass(udaf_class): |
| raise ValueError(f"'{symbol}' is not a class (type: {type(udaf_class)})") |
| |
| UDAFClassLoader.validate_udaf_class(udaf_class) |
| return udaf_class |
| |
| @staticmethod |
| def validate_udaf_class(udaf_class: type): |
| """ |
| Validate that the UDAF class implements required methods. |
| |
| Args: |
| udaf_class: The class to validate |
| |
| Raises: |
| ValueError: If class doesn't implement required methods or properties |
| """ |
| required_methods = ["__init__", "accumulate", "merge", "finish"] |
| for method in required_methods: |
| if not hasattr(udaf_class, method): |
| raise ValueError( |
| f"UDAF class must implement '{method}' method. " |
| f"Missing in {udaf_class.__name__}" |
| ) |
| |
| # Check for aggregate_state property |
| if not hasattr(udaf_class, "aggregate_state"): |
| raise ValueError( |
| f"UDAF class must have 'aggregate_state' property. " |
| f"Missing in {udaf_class.__name__}" |
| ) |
| |
| # Verify it's actually a property |
| try: |
| attr = inspect.getattr_static(udaf_class, "aggregate_state") |
| if not isinstance(attr, property): |
| raise ValueError( |
| f"'aggregate_state' must be a @property in {udaf_class.__name__}" |
| ) |
| except AttributeError as e: |
| raise ValueError( |
| f"UDAF class must have 'aggregate_state' property. " |
| f"Missing in {udaf_class.__name__}" |
| ) from e |
| |
| |
| class UDAFStateManager: |
| """ |
| Manages UDAF aggregate states for Python UDAF execution. |
| |
| This class maintains a mapping from place_id to UDAF instances, |
| following the Snowflake UDAF pattern: |
| - __init__(): Initialize state |
| - aggregate_state: Property returning serializable state |
| - accumulate(*args): Add input values |
| - merge(other_state): Merge two states |
| - finish(): Return final result |
| """ |
| |
| def __init__(self): |
| """Initialize the state manager.""" |
| self.states: Dict[int, Any] = {} # place_id -> UDAF instance |
| self.udaf_class = None # UDAF class to instantiate |
| self._destroy_counter = 0 # Track number of destroys since last GC |
| self._gc_threshold = 100 # Trigger GC every N destroys |
| |
| def set_udaf_class(self, udaf_class: type): |
| """ |
| Set the UDAF class to use for creating instances. |
| |
| Args: |
| udaf_class: The UDAF class |
| |
| Note: |
| Validation is performed by UDAFClassLoader before calling this method. |
| """ |
| self.udaf_class = udaf_class |
| |
| def create_state(self, place_id: int) -> None: |
| """ |
| Create a new UDAF state for the given place_id. |
| |
| Args: |
| place_id: Unique identifier for this aggregate state (globally unique) |
| |
| Note: |
| This method assumes C++ layer guarantees no concurrent access to the same place_id. |
| """ |
| try: |
| self.states[place_id] = self.udaf_class() |
| except Exception as e: |
| logging.error( |
| "Failed to create UDAF state for place_id=%s: %s\nUDAF class: %s\nTraceback: %s", |
| place_id, |
| e, |
| self.udaf_class.__name__ if self.udaf_class else "None", |
| traceback.format_exc(), |
| ) |
| raise RuntimeError(f"Failed to create UDAF state: {e}") from e |
| |
| def get_state(self, place_id: int) -> Any: |
| """ |
| Get the UDAF state for the given place_id. |
| |
| Args: |
| place_id: Unique identifier for the aggregate state |
| |
| Returns: |
| The UDAF instance |
| """ |
| return self.states[place_id] |
| |
| def accumulate(self, place_id: int, *args) -> None: |
| """ |
| Accumulate input values into the aggregate state. |
| |
| Args: |
| place_id: Unique identifier for the aggregate state |
| *args: Input values to accumulate |
| """ |
| state = self.states[place_id] |
| try: |
| state.accumulate(*args) |
| except Exception as e: |
| logging.error( |
| "Error in accumulate for place_id %s: %s", |
| place_id, |
| e, |
| ) |
| raise RuntimeError(f"Error in accumulate: {e}") from e |
| |
| def serialize(self, place_id: int) -> bytes: |
| """ |
| Serialize the aggregate state to bytes. |
| |
| Args: |
| place_id: Unique identifier for the aggregate state |
| |
| Returns: |
| Serialized state as bytes (using pickle) |
| """ |
| state = self.states[place_id] |
| try: |
| aggregate_state = state.aggregate_state |
| serialized = pickle.dumps(aggregate_state) |
| return serialized |
| except Exception as e: |
| logging.error( |
| "Error serializing state for place_id %s: %s", |
| place_id, |
| e, |
| ) |
| raise RuntimeError(f"Error serializing state: {e}") from e |
| |
| def merge(self, place_id: int, other_state_bytes: bytes) -> None: |
| """ |
| Merge another serialized state into this state. |
| |
| Args: |
| place_id: Unique identifier for the aggregate state |
| other_state_bytes: Serialized state to merge (pickle bytes) |
| """ |
| try: |
| other_state = pickle.loads(other_state_bytes) |
| except Exception as e: |
| logging.error("Error deserializing state bytes: %s", e) |
| raise RuntimeError(f"Error deserializing state: {e}") from e |
| |
| state = self.states[place_id] |
| |
| try: |
| state.merge(other_state) |
| except Exception as e: |
| logging.error( |
| "Error in merge for place_id %s: %s", |
| place_id, |
| e, |
| ) |
| raise RuntimeError(f"Error in merge: {e}") from e |
| |
| def finalize(self, place_id: int) -> Any: |
| """ |
| Get the final result from the aggregate state. |
| |
| Args: |
| place_id: Unique identifier for the aggregate state |
| |
| Returns: |
| Final aggregation result |
| """ |
| state = self.states[place_id] |
| try: |
| result = state.finish() |
| return result |
| except Exception as e: |
| logging.error( |
| "Error finalizing state for place_id %s: %s", |
| place_id, |
| e, |
| ) |
| raise RuntimeError(f"Error finalizing state: {e}") from e |
| |
| def reset(self, place_id: int) -> None: |
| """ |
| Reset the aggregate state (for window functions). |
| |
| Args: |
| place_id: Unique identifier for the aggregate state |
| |
| Raises: |
| RuntimeError: If state does not exist for this place_id or UDAF class not set |
| """ |
| try: |
| self.states[place_id] = self.udaf_class() |
| except Exception as e: |
| logging.error( |
| "Error resetting state for place_id %s: %s", |
| place_id, |
| e, |
| ) |
| raise RuntimeError(f"Error resetting state: {e}") from e |
| |
| def destroy(self, place_id: int) -> None: |
| """ |
| Destroy the aggregate state and free resources. |
| |
| Args: |
| place_id: Unique identifier for the aggregate state |
| """ |
| if place_id not in self.states: |
| return |
| |
| del self.states[place_id] |
| |
| self._destroy_counter += 1 |
| # Trigger GC periodically based on destroy count |
| if self._destroy_counter >= self._gc_threshold: |
| remaining = len(self.states) |
| |
| # All states destroyed - force full cleanup |
| if remaining == 0: |
| gc.collect() |
| logging.debug( |
| "[UDAF GC] Full cleanup: all states destroyed, GC triggered" |
| ) |
| # Otherwise enough destroys have accumulated - trigger periodic GC |
| else: |
| gc.collect() |
| logging.debug( |
| "[UDAF GC] Periodic GC triggered after %d destroys, %d states remaining", |
| self._destroy_counter, |
| remaining, |
| ) |
| |
| self._destroy_counter = 0 |
| |
| |
| class FlightServer(flight.FlightServerBase): |
| """Arrow Flight server for executing Python UDFs, UDAFs, and UDTFs.""" |
| |
| def __init__(self, location: str): |
| """ |
| Initialize the Flight server. |
| |
| Args: |
| location: Unix socket path for the server |
| """ |
| super().__init__(location) |
| # Use a dictionary to maintain separate state managers for each UDAF function |
| # Key: function signature (name + input_types), Value: UDAFStateManager instance |
| self.udaf_state_managers: Dict[str, UDAFStateManager] = {} |
| self.udaf_managers_lock = threading.Lock() |
| |
| def _get_udaf_state_manager( |
| self, python_udaf_meta: PythonUDFMeta |
| ) -> UDAFStateManager: |
| """ |
| Get or create a state manager for the given UDAF function. |
| Each UDAF function gets its own independent state manager. |
| |
| Args: |
| python_udaf_meta: Metadata for the UDAF function |
| |
| Returns: |
| UDAFStateManager instance for this specific UDAF |
| """ |
| # Create a unique key based on function name and argument types |
| type_names = [str(field.type) for field in python_udaf_meta.input_types] |
| func_key = f"{python_udaf_meta.name}({','.join(type_names)})" |
| |
| with self.udaf_managers_lock: |
| if func_key not in self.udaf_state_managers: |
| manager = UDAFStateManager() |
| # Load and set the UDAF class for this manager using UDAFClassLoader |
| udaf_class = UDAFClassLoader.load_udaf_class(python_udaf_meta) |
| manager.set_udaf_class(udaf_class) |
| self.udaf_state_managers[func_key] = manager |
| |
| return self.udaf_state_managers[func_key] |
| |
| @staticmethod |
| def parse_python_udf_meta( |
| descriptor: flight.FlightDescriptor, |
| ) -> Optional[PythonUDFMeta]: |
| """ |
| Parses UDF/UDAF/UDTF metadata from a command descriptor. |
| |
| Returns: |
| PythonUDFMeta object containing the function metadata |
| """ |
| |
| if descriptor.descriptor_type != flight.DescriptorType.CMD: |
| logging.error("Invalid descriptor type: %s", descriptor.descriptor_type) |
| return None |
| |
| cmd_json = json.loads(descriptor.command) |
| name = cmd_json["name"] |
| symbol = cmd_json["symbol"] |
| location = cmd_json["location"] |
| udf_load_type = cmd_json["udf_load_type"] |
| runtime_version = cmd_json["runtime_version"] |
| always_nullable = cmd_json["always_nullable"] |
| # client_type: 0: UDF, 1: UDAF, 2: UDTF |
| client_type = cmd_json["client_type"] |
| |
| inline_code = base64.b64decode(cmd_json["inline_code"]) |
| input_binary = base64.b64decode(cmd_json["input_types"]) |
| output_binary = base64.b64decode(cmd_json["return_type"]) |
| |
| input_schema = pa.ipc.read_schema(pa.BufferReader(input_binary)) |
| output_schema = pa.ipc.read_schema(pa.BufferReader(output_binary)) |
| |
| if len(output_schema) != 1: |
| logging.error( |
| "Output schema must have exactly one field: %s", output_schema |
| ) |
| return None |
| |
| output_type = output_schema.field(0).type |
| |
| python_udf_meta = PythonUDFMeta( |
| name=name, |
| symbol=symbol, |
| location=location, |
| udf_load_type=udf_load_type, |
| runtime_version=runtime_version, |
| always_nullable=always_nullable, |
| inline_code=inline_code, |
| input_types=input_schema, |
| output_type=output_type, |
| client_type=client_type, |
| ) |
| |
| return python_udf_meta |
| |
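| # Illustrative shape of the descriptor command JSON parsed above (all |
| # values are made up; "inline_code", "input_types" and "return_type" carry |
| # base64-encoded bytes, the latter two being Arrow IPC schemas): |
| # |
| # { |
| #     "name": "my_add", |
| #     "symbol": "my_add", |
| #     "location": "", |
| #     "udf_load_type": 0, |
| #     "runtime_version": "3.10", |
| #     "always_nullable": true, |
| #     "client_type": 0, |
| #     "inline_code": "<base64 UDF source>", |
| #     "input_types": "<base64 Arrow IPC schema>", |
| #     "return_type": "<base64 Arrow IPC schema>" |
| # } |
| |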
| @staticmethod |
| def check_schema( |
| record_batch: pa.RecordBatch, expected_schema: pa.Schema |
| ) -> Tuple[bool, str]: |
| """ |
| Validates that the input RecordBatch schema matches the expected schema. |
| Checks that field count and types match, but field names can differ. |
| |
| Returns: |
| Tuple of (result, error_message) |
| """ |
| actual = record_batch.schema |
| expected = expected_schema |
| |
| # Check field count |
| if len(actual) != len(expected): |
| return ( |
| False, |
| f"Schema length mismatch, got {len(actual)} fields, expected {len(expected)} fields", |
| ) |
| |
| # Check each field type (ignore field names) |
| for i, (actual_field, expected_field) in enumerate(zip(actual, expected)): |
| if not actual_field.type.equals(expected_field.type): |
| return False, ( |
| f"Type mismatch at field index {i}, " |
| f"got {actual_field.type}, expected {expected_field.type}" |
| ) |
| |
| return True, "" |
| |
| def _create_unified_response( |
| self, success: bool, rows_processed: int, data: bytes |
| ) -> pa.RecordBatch: |
| """ |
| Create unified UDAF response batch. |
| |
| Schema: [success: bool, rows_processed: int64, serialized_data: binary] |
| """ |
| return pa.RecordBatch.from_arrays( |
| [ |
| pa.array([success], type=pa.bool_()), |
| pa.array([rows_processed], type=pa.int64()), |
| pa.array([data], type=pa.binary()), |
| ], |
| schema=pa.schema( |
| [ |
| pa.field("success", pa.bool_()), |
| pa.field("rows_processed", pa.int64()), |
| pa.field("serialized_data", pa.binary()), |
| ] |
| ), |
| ) |
| |
| def _handle_udaf_create( |
| self, place_id: int, state_manager: UDAFStateManager |
| ) -> pa.RecordBatch: |
| """Handle UDAF CREATE operation. |
| |
| Returns: [success: bool] |
| """ |
| try: |
| state_manager.create_state(place_id) |
| success = True |
| except Exception as e: |
| logging.error( |
| "CREATE operation failed for place_id=%s: %s", |
| place_id, |
| e, |
| ) |
| success = False |
| |
| return pa.RecordBatch.from_arrays( |
| [pa.array([success], type=pa.bool_())], ["success"] |
| ) |
| |
| def _handle_udaf_accumulate( |
| self, |
| place_id: int, |
| is_single_place: bool, |
| row_start: int, |
| row_end: int, |
| data_batch: pa.RecordBatch, |
| state_manager: UDAFStateManager, |
| ) -> pa.RecordBatch: |
| """ |
| Handle UDAF ACCUMULATE operation with optimized metadata from app_metadata. |
| |
| Args: |
| place_id: Primary place identifier |
| is_single_place: If True, single aggregation; if False, GROUP BY aggregation |
| row_start: Start row index in data batch |
| row_end: End row index in data batch (exclusive) |
| data_batch: Input data RecordBatch (argument columns + optional places column) |
| state_manager: UDAF state manager instance |
| |
| Returns: [rows_processed: int64] (0 if failed) |
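| |
| Illustrative GROUP BY layout (values are made up): for sum(x) over three |
| rows, the data batch carries the argument column x = [1, 2, 3] plus a |
| trailing "places" column [10, 10, 11]; rows 0-1 accumulate into state 10 |
| and row 2 into state 11. |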
| """ |
| if data_batch is None: |
| raise ValueError("ACCUMULATE requires data_batch, got None") |
| |
| rows_processed = 0 |
| |
| try: |
| has_places = ( |
| data_batch.schema.field(data_batch.num_columns - 1).name == "places" |
| ) |
| num_input_cols = ( |
| data_batch.num_columns - 1 if has_places else data_batch.num_columns |
| ) |
| loop_start = row_start |
| loop_end = min(row_end, data_batch.num_rows) |
| |
| if is_single_place: |
| if place_id not in state_manager.states: |
| raise KeyError(f"State for place_id {place_id} not found") |
| state = state_manager.states[place_id] |
| |
| # Extract row range using Arrow slicing (zero-copy) |
| sliced_batch = data_batch.slice(loop_start, loop_end - loop_start) |
| columns = [sliced_batch.column(j) for j in range(num_input_cols)] |
| |
| column_metadata = [ |
| data_batch.schema.field(j).metadata |
| for j in range(num_input_cols) |
| ] |
| |
| for i in range(sliced_batch.num_rows): |
| try: |
| # Apply IP conversion based on metadata |
| row_args = tuple( |
| convert_arrow_field_to_python(col[i], meta) |
| for col, meta in zip(columns, column_metadata) |
| ) |
| state.accumulate(*row_args) |
| rows_processed += 1 |
| except Exception as e: |
| logging.error( |
| "Error in accumulate for place_id %s at row %d: %s", |
| place_id, |
| loop_start + i, |
| e, |
| ) |
| raise RuntimeError(f"Error in accumulate: {e}") from e |
| |
| del columns |
| del sliced_batch |
| else: |
| # Multiple places (GROUP BY): iterate row by row |
| places_col = data_batch.column(data_batch.num_columns - 1) |
| num_rows = data_batch.num_rows |
| data_columns = [data_batch.column(j) for j in range(num_input_cols)] |
| |
| column_metadata = [ |
| data_batch.schema.field(j).metadata |
| for j in range(num_input_cols) |
| ] |
| |
| # Process each row directly from Arrow arrays (single pass) |
| for i in range(num_rows): |
| try: |
| place_id = places_col[i].as_py() |
| state = state_manager.states[place_id] |
| row_args = tuple( |
| convert_arrow_field_to_python(col[i], meta) |
| for col, meta in zip(data_columns, column_metadata) |
| ) |
| state.accumulate(*row_args) |
| rows_processed += 1 |
| except KeyError: |
| logging.error( |
| "State not found for place_id=%s at row %d. " |
| "CREATE must be called before ACCUMULATE.", |
| place_id, |
| i, |
| ) |
| raise |
| except Exception as e: |
| logging.error( |
| "Error in accumulate for place_id %s at row %d: %s", |
| place_id, |
| i, |
| e, |
| ) |
| raise RuntimeError(f"Error in accumulate: {e}") from e |
| |
| del data_columns |
| del places_col |
| del data_batch |
| |
| except Exception as e: |
| logging.error( |
| "ACCUMULATE operation failed at row %d: %s\nTraceback: %s", |
| rows_processed, |
| e, |
| traceback.format_exc(), |
| ) |
| raise |
| |
| return pa.RecordBatch.from_arrays( |
| [pa.array([rows_processed], type=pa.int64())], ["rows_processed"] |
| ) |
| |
| def _handle_udaf_serialize( |
| self, place_id: int, state_manager: UDAFStateManager |
| ) -> pa.RecordBatch: |
| """Handle UDAF SERIALIZE operation. |
| |
| Returns: [serialized_state: binary] (empty if failed) |
| """ |
| try: |
| serialized = state_manager.serialize(place_id) |
| except Exception as e: |
| logging.error( |
| "SERIALIZE operation failed for place_id=%s: %s", |
| place_id, |
| e, |
| ) |
| serialized = b"" |
| |
| return pa.RecordBatch.from_arrays( |
| [pa.array([serialized], type=pa.binary())], ["serialized_state"] |
| ) |
| |
| def _handle_udaf_merge( |
| self, |
| place_id: int, |
| data_binary: bytes, |
| state_manager: UDAFStateManager, |
| ) -> pa.RecordBatch: |
| """Handle UDAF MERGE operation. |
| |
| data_binary contains the serialized state to merge. |
| Returns: [success: bool] |
| """ |
| if data_binary is None: |
            raise ValueError("MERGE requires data_binary, got None")
| |
| try: |
| state_manager.merge(place_id, data_binary) |
| success = True |
| except Exception as e: |
| logging.error( |
| "MERGE operation failed for place_id=%s: %s", |
| place_id, |
| e, |
| ) |
| success = False |
| |
| return pa.RecordBatch.from_arrays( |
| [pa.array([success], type=pa.bool_())], ["success"] |
| ) |
| |
| def _handle_udaf_finalize( |
| self, |
| place_id: int, |
| output_type: pa.DataType, |
| state_manager: UDAFStateManager, |
| ) -> pa.RecordBatch: |
| """Handle UDAF FINALIZE operation. |
| |
| Returns: [result: output_type] (null if failed) |
| """ |
| try: |
| result = convert_python_to_arrow_value(state_manager.finalize(place_id), output_type) |
| except Exception as e: |
| logging.error( |
| "FINALIZE operation failed for place_id=%s: %s", |
| place_id, |
| e, |
| ) |
| result = None |
| |
| return pa.RecordBatch.from_arrays( |
| [pa.array([result], type=output_type)], ["result"] |
| ) |
| |
| def _handle_udaf_reset( |
| self, place_id: int, state_manager: UDAFStateManager |
| ) -> pa.RecordBatch: |
| """Handle UDAF RESET operation. |
| |
| Returns: [success: bool] |
| """ |
| try: |
| state_manager.reset(place_id) |
| success = True |
| except Exception as e: |
| logging.error( |
| "RESET operation failed for place_id=%s: %s", |
| place_id, |
| e, |
| ) |
| success = False |
| |
| return pa.RecordBatch.from_arrays( |
| [pa.array([success], type=pa.bool_())], ["success"] |
| ) |
| |
| def _handle_udaf_destroy( |
| self, place_ids: list, state_manager: UDAFStateManager |
| ) -> bool: |
| """Handle UDAF DESTROY operation for one or more place_ids. |
| |
| Args: |
| place_ids: List of place_ids to destroy (can be single element) |
| state_manager: UDAF state manager |
| |
| Returns: |
| bool: True if all destroys succeeded, False if any failed |
| """ |
| num_ids = len(place_ids) |
| success_count = 0 |
| failed_count = 0 |
| |
| for place_id in place_ids: |
| try: |
| state_manager.destroy(place_id) |
| success_count += 1 |
| except Exception as e: |
| logging.error( |
| "Failed to destroy place_id=%s: %s", |
| place_id, |
| e, |
| ) |
| failed_count += 1 |
| |
| if failed_count > 0: |
| if num_ids > 1: |
| logging.warning( |
| "[UDAF Memory] Destroy completed with %d succeeded, %d failed", |
| success_count, |
| failed_count, |
| ) |
| return False |
| |
| return True |
| |
| def _handle_exchange_udf( |
| self, |
| python_udf_meta: PythonUDFMeta, |
| reader: flight.MetadataRecordBatchReader, |
| writer: flight.MetadataRecordBatchWriter, |
| ) -> None: |
| """Handle bidirectional streaming for UDF execution.""" |
| loader = UDFLoaderFactory.get_loader(python_udf_meta) |
| udf = loader.load() |
| logging.info("Loaded UDF: %s", udf) |
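
        # Illustrative shape of the loaded callable (hypothetical body; the
        # real UDF comes from the user's module). It maps an input RecordBatch
        # to a single pa.Array of the declared output type, e.g.:
        #   import pyarrow.compute as pc
        #   def my_udf(batch: pa.RecordBatch) -> pa.Array:
        #       return pc.add(batch.column(0), 1)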
| |
| started = False |
| for chunk in reader: |
| if not chunk.data: |
| logging.info("Empty chunk received, skipping") |
| continue |
| |
| check_schema_result, error_msg = self.check_schema( |
| chunk.data, python_udf_meta.input_types |
| ) |
| if not check_schema_result: |
| logging.error("Schema mismatch: %s", error_msg) |
| raise ValueError(f"Schema mismatch: {error_msg}") |
| |
| result_array = udf(chunk.data) |
| |
| if not python_udf_meta.output_type.equals(result_array.type): |
| logging.error( |
| "Output type mismatch: got %s, expected %s", |
| result_array.type, |
| python_udf_meta.output_type, |
| ) |
| raise ValueError( |
| f"Output type mismatch: got {result_array.type}, expected {python_udf_meta.output_type}" |
| ) |
| |
| result_batch = pa.RecordBatch.from_arrays([result_array], ["result"]) |
| if not started: |
| try: |
| writer.begin(result_batch.schema) |
| started = True |
| except Exception as e: |
| logging.error( |
| "Failed to begin UDF writer stream (client may have disconnected): %s", |
| e, |
| ) |
| return |
| |
| try: |
| writer.write_batch(result_batch) |
| except Exception as e: |
| logging.error( |
| "Failed to write UDF response batch (client may have disconnected): %s", |
| e, |
| ) |
| return |
| |
| def _handle_exchange_udaf( |
| self, |
| python_udaf_meta: PythonUDFMeta, |
| reader: flight.MetadataRecordBatchReader, |
| writer: flight.MetadataRecordBatchWriter, |
| ) -> None: |
| """ |
| Handle bidirectional streaming for UDAF execution. |
| |
| Protocol (optimized with direct RecordBatch transmission): |
| - app_metadata: 30-byte binary structure containing: |
| * meta_version: uint32 (4 bytes) - Metadata version (currently 1) |
| * operation: uint8 (1 byte) - UDAFOperationType enum |
| * is_single_place: uint8 (1 byte) - Boolean (ACCUMULATE only) |
| * place_id: int64 (8 bytes) - Aggregate state identifier (globally unique) |
| * row_start: int64 (8 bytes) - Start row index (ACCUMULATE only) |
| * row_end: int64 (8 bytes) - End row index (ACCUMULATE only) |
| |
| - RecordBatch data: [argument_types..., places: int64, binary_data: binary] |
| * Schema is function-specific: created from argument_types + places + binary_data columns |
| * Different operations fill different columns: |
| - ACCUMULATE (single-place): data columns are filled, places is NULL, binary_data is NULL |
| - ACCUMULATE (multi-place): data columns are filled, places contains place IDs, binary_data is NULL |
| - MERGE: data columns are NULL, places is NULL, binary_data contains serialized state |
| - Other operations (CREATE/SERIALIZE/FINALIZE/RESET/DESTROY): all columns are NULL |
| * places column: indicates which place each row belongs to in GROUP BY scenarios |
| * This eliminates extra serialization/deserialization for ACCUMULATE operations |
| |
| Response: Unified schema [success: bool, rows_processed: int64, serialized_data: binary] |
| - Different operations use different fields: |
| * CREATE/MERGE/RESET/DESTROY: use success only |
| * ACCUMULATE: use success + rows_processed (number of rows processed) |
| * SERIALIZE: use success + serialized_data (serialized_state) |
| * FINALIZE: use success + serialized_data (serialized result) |
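
        Example (illustrative): a test client could pack the 30-byte
        app_metadata for a single-place ACCUMULATE over rows [0, 1024)
        of place 42 as:

            import struct
            meta = struct.pack(
                "<IBBqqq",   # little-endian, no padding: 4+1+1+8+8+8 = 30 bytes
                1,           # meta_version
                1,           # operation = UDAFOperationType.ACCUMULATE
                1,           # is_single_place
                42,          # place_id
                0,           # row_start
                1024,        # row_end
            )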
| """ |
| |
| # Get or create state manager for this specific UDAF function |
| state_manager = self._get_udaf_state_manager(python_udaf_meta) |
| started = False |
| |
| # Define unified response schema (consistent with C++ kUnifiedUDAFResponseSchema) |
| unified_schema = pa.schema( |
| [ |
| pa.field("success", pa.bool_()), |
| pa.field("rows_processed", pa.int64()), |
| pa.field("serialized_data", pa.binary()), |
| ] |
| ) |
| |
| for chunk in reader: |
| if not chunk.data or chunk.data.num_rows == 0: |
| logging.warning("Empty chunk received, skipping") |
| continue |
| |
| batch = chunk.data |
| app_metadata = chunk.app_metadata |
| |
| # Validate app_metadata |
| if not app_metadata or len(app_metadata) != 30: |
| raise ValueError( |
| f"Invalid app_metadata: expected 30 bytes, got {len(app_metadata) if app_metadata else 0}" |
| ) |
| |
| # Parse fixed-size binary metadata (30 bytes total) |
| # Layout: meta_version(4) + operation(1) + is_single_place(1) + place_id(8) + row_start(8) + row_end(8) |
| metadata_bytes = app_metadata.to_pybytes() |
| |
| # Validate metadata version |
| meta_version = int.from_bytes(metadata_bytes[0:4], "little", signed=False) |
| if meta_version != 1: |
| raise ValueError( |
| f"Unsupported metadata version: {meta_version}. Expected version 1. " |
| "Please upgrade the Python server or downgrade the C++ client." |
| ) |
| operation_type = UDAFOperationType(metadata_bytes[4]) |
| is_single_place = metadata_bytes[5] == 1 |
| place_id = int.from_bytes(metadata_bytes[6:14], "little", signed=True) |
| row_start = int.from_bytes(metadata_bytes[14:22], "little", signed=True) |
| row_end = int.from_bytes(metadata_bytes[22:30], "little", signed=True) |
| |
| # Extract data from batch |
| # RPC schema: [argument_types..., places: int64, binary_data: binary] |
| # - Second-to-last column is places (int64) |
| # - Last column is binary_data (binary) |
| # - ACCUMULATE (single-place): data columns filled, places is NULL, binary_data is NULL |
| # - ACCUMULATE (multi-place): data columns filled, places contains place IDs, binary_data is NULL |
| # - MERGE: data columns are NULL, places is NULL, binary_data is filled |
| # - Other operations: all columns are NULL |
| |
| if batch.num_columns < 1: |
| raise ValueError(f"Expected at least 1 column, got {batch.num_columns}") |
| |
| # Last column is binary_data |
| binary_col = batch.column(batch.num_columns - 1) |
| binary_data = binary_col[0].as_py() if binary_col[0].is_valid else None |
| |
| # Handle different operations and convert to unified format |
| try: |
| if operation_type == UDAFOperationType.CREATE: |
| result_batch = self._handle_udaf_create(place_id, state_manager) |
| success = result_batch.column(0)[0].as_py() |
| result_batch = self._create_unified_response( |
| success=success, rows_processed=0, data=b"" |
| ) |
| elif operation_type == UDAFOperationType.ACCUMULATE: |
| num_data_cols = batch.num_columns - 1 |
| data_batch = pa.RecordBatch.from_arrays( |
| [batch.column(i) for i in range(num_data_cols)], |
| schema=pa.schema( |
| [batch.schema.field(i) for i in range(num_data_cols)] |
| ), |
| ) |
| result_batch_accumulate = self._handle_udaf_accumulate( |
| place_id, |
| is_single_place, |
| row_start, |
| row_end, |
| data_batch, |
| state_manager, |
| ) |
| rows_processed = result_batch_accumulate.column(0)[0].as_py() |
| result_batch = self._create_unified_response( |
| success=(rows_processed > 0), |
| rows_processed=rows_processed, |
| data=b"", |
| ) |
| elif operation_type == UDAFOperationType.SERIALIZE: |
| result_batch_serialize = self._handle_udaf_serialize( |
| place_id, state_manager |
| ) |
| serialized = result_batch_serialize.column(0)[0].as_py() |
| result_batch = self._create_unified_response( |
                        success=bool(serialized),
| rows_processed=0, |
| data=serialized if serialized else b"", |
| ) |
| elif operation_type == UDAFOperationType.MERGE: |
| # For MERGE: binary_data contains the serialized state |
| result_batch_merge = self._handle_udaf_merge( |
| place_id, binary_data, state_manager |
| ) |
| success = result_batch_merge.column(0)[0].as_py() |
| result_batch = self._create_unified_response( |
| success=success, rows_processed=0, data=b"" |
| ) |
| elif operation_type == UDAFOperationType.FINALIZE: |
| result_batch_finalize = self._handle_udaf_finalize( |
| place_id, python_udaf_meta.output_type, state_manager |
| ) |
| # Serialize the result to binary (including NULL results) |
| # NULL is a valid aggregation result, not an error |
| sink = pa.BufferOutputStream() |
| ipc_writer = pa.ipc.new_stream(sink, result_batch_finalize.schema) |
| ipc_writer.write_batch(result_batch_finalize) |
| ipc_writer.close() |
| result_data = sink.getvalue().to_pybytes() |
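                    # Illustrative: the consumer can recover the single-row
                    # result from result_data with, roughly:
                    #   pa.ipc.open_stream(result_data).read_next_batch().column(0)[0]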
| result_batch = self._create_unified_response( |
| success=True, |
| rows_processed=0, |
| data=result_data, |
| ) |
| elif operation_type == UDAFOperationType.RESET: |
| result_batch_reset = self._handle_udaf_reset( |
| place_id, state_manager |
| ) |
| success = result_batch_reset.column(0)[0].as_py() |
| result_batch = self._create_unified_response( |
| success=success, rows_processed=0, data=b"" |
| ) |
| elif operation_type == UDAFOperationType.DESTROY: |
| if row_end > 1: |
| # Batch destroy mode - binary_data contains serialized place_ids |
| if binary_data is None: |
| raise ValueError("DESTROY_BATCH: binary_data is None") |
| data_reader = pa.ipc.open_stream(binary_data) |
| data_batch = data_reader.read_next_batch() |
| if data_batch.num_columns != 1: |
| raise ValueError( |
| f"DESTROY_BATCH: Expected 1 column (place_ids), got {data_batch.num_columns}" |
| ) |
| place_ids_array = data_batch.column(0) |
| place_ids = [ |
| place_ids_array[i].as_py() |
| for i in range(len(place_ids_array)) |
| ] |
| else: |
| # Single destroy mode |
| place_ids = [place_id] |
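
                    # Illustrative sketch of how the batch-destroy payload read
                    # above could be built (column name is hypothetical; only
                    # column 0 is consumed):
                    #   sink = pa.BufferOutputStream()
                    #   w = pa.ipc.new_stream(sink, pa.schema([("place_ids", pa.int64())]))
                    #   w.write_batch(pa.record_batch([pa.array([1, 2, 3], pa.int64())], ["place_ids"]))
                    #   w.close()
                    #   binary_data = sink.getvalue().to_pybytes()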
| |
| success = self._handle_udaf_destroy(place_ids, state_manager) |
| result_batch = self._create_unified_response( |
| success=success, rows_processed=0, data=b"" |
| ) |
| else: |
| raise ValueError(f"Unsupported operation type: {operation_type}") |
| except Exception as e: |
| logging.error( |
| "Operation %s failed for place_id=%s: %s\nTraceback: %s", |
| operation_type, |
| place_id, |
| e, |
| traceback.format_exc(), |
| ) |
| result_batch = self._create_unified_response( |
| success=False, rows_processed=0, data=b"" |
| ) |
| |
| # Begin stream with unified schema on first call |
| if not started: |
| try: |
| writer.begin(unified_schema) |
| started = True |
| except Exception as e: |
| logging.error( |
| "Failed to begin writer stream (client may have disconnected): %s", |
| e, |
| ) |
| # Client disconnected, stop processing |
| return |
| |
| try: |
| writer.write_batch(result_batch) |
| except Exception as e: |
| logging.error( |
| "Failed to write response batch (client may have disconnected): %s", |
| e, |
| ) |
| # Client disconnected, stop processing |
| return |
| |
| del result_batch |
| |
| def _handle_exchange_udtf( |
| self, |
| python_udtf_meta: PythonUDFMeta, |
| reader: flight.MetadataRecordBatchReader, |
| writer: flight.MetadataRecordBatchWriter, |
| ) -> None: |
| """ |
| Handle bidirectional streaming for UDTF execution. |
| |
| Protocol (ListArray-based): |
| - Input: RecordBatch with input columns |
| - Output: RecordBatch with a single ListArray column |
| * ListArray automatically manages offsets internally |
| * Each list element contains the outputs for one input row |
| |
| Example: |
| Input: 3 rows |
| UDTF yields: Row 0 -> 5 outputs, Row 1 -> 2 outputs, Row 2 -> 3 outputs |
| Output: ListArray with 3 elements (one per input row) |
| - Element 0: List of 5 structs |
| - Element 1: List of 2 structs |
| - Element 2: List of 3 structs |
| """ |
| loader = UDFLoaderFactory.get_loader(python_udtf_meta) |
| adaptive_udtf = loader.load() |
| udtf_func = adaptive_udtf._eval_func |
| started = False |
| |
| for chunk in reader: |
| if not chunk.data: |
| logging.info("Empty chunk received, skipping") |
| continue |
| |
| input_batch = chunk.data |
| |
| # Validate input schema |
| check_schema_result, error_msg = self.check_schema( |
| input_batch, python_udtf_meta.input_types |
| ) |
| if not check_schema_result: |
| logging.error("Schema mismatch: %s", error_msg) |
| raise ValueError(f"Schema mismatch: {error_msg}") |
| |
| # Process all input rows and build ListArray |
| try: |
| response_batch = self._process_udtf_with_list_array( |
| udtf_func, input_batch, python_udtf_meta.output_type |
| ) |
| |
| # Send the response batch |
| if not started: |
| try: |
| writer.begin(response_batch.schema) |
| started = True |
| except Exception as e: |
| logging.error( |
| "Failed to begin UDTF writer stream (client may have disconnected): %s", |
| e, |
| ) |
| return |
| |
| try: |
| writer.write_batch(response_batch) |
| except Exception as e: |
| logging.error( |
| "Failed to write UDTF response batch (client may have disconnected): %s", |
| e, |
| ) |
| return |
| |
| except Exception as e: |
| logging.error( |
| "Error in UDTF execution: %s\nTraceback: %s", |
| e, |
| traceback.format_exc(), |
| ) |
| raise RuntimeError(f"Error in UDTF execution: {e}") from e |
| |
| def _process_udtf_with_list_array( |
| self, |
| udtf_func: Callable, |
| input_batch: pa.RecordBatch, |
| expected_output_type: pa.DataType, |
| ) -> pa.RecordBatch: |
| """ |
| Process UDTF function on all input rows and generate a ListArray. |
| |
| Args: |
| udtf_func: The UDTF function to call |
| input_batch: Input RecordBatch with N rows |
| expected_output_type: Expected Arrow type for output data |
| |
| Returns: |
| RecordBatch with a single ListArray column where each element |
| is a list of outputs for the corresponding input row |
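
        Example (illustrative; split_words and split_with_len are hypothetical
        UDTF bodies):

            def split_words(text):               # single-field output type
                for word in (text or "").split():
                    yield word                   # bare scalars are accepted

            def split_with_len(text):            # struct (multi-field) output
                for word in (text or "").split():
                    yield (word, len(word))      # must yield tuples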
| """ |
| all_results = [] # List of lists: one list per input row |
| |
| # Check if output is single-field or multi-field |
| # For single-field output, we allow yielding scalar values directly |
| is_single_field = not pa.types.is_struct(expected_output_type) |
| |
| column_metadata = [ |
| input_batch.schema.field(col_idx).metadata |
| for col_idx in range(input_batch.num_columns) |
| ] |
| |
| # Process each input row |
| for row_idx in range(input_batch.num_rows): |
| # Extract row as tuple of arguments with IP conversion |
| row_args = tuple( |
| convert_arrow_field_to_python( |
| input_batch.column(col_idx)[row_idx], |
| column_metadata[col_idx] |
| ) |
| for col_idx in range(input_batch.num_columns) |
| ) |
| |
| # Call UDTF function - it can yield tuples or scalar values (for single-field output) |
| result = udtf_func(*row_args) |
| |
| # Collect output rows for this input row |
| row_outputs = [] |
| if inspect.isgenerator(result): |
| for output_value in result: |
| if is_single_field: |
| # Single-field output: accept both scalar and tuple |
| if isinstance(output_value, tuple): |
| # User provided tuple (e.g., (value,)) - extract scalar |
| if len(output_value) != 1: |
| raise ValueError( |
| f"Single-field UDTF should yield 1-tuples or scalars, got {len(output_value)}-tuple" |
| ) |
| row_outputs.append( |
| output_value[0] |
| ) # Extract scalar from tuple |
| else: |
| # User provided scalar - use directly |
| row_outputs.append(output_value) |
| else: |
| # Multi-field output: must be tuple |
| if not isinstance(output_value, tuple): |
| raise ValueError( |
| f"Multi-field UDTF must yield tuples, got {type(output_value)}" |
| ) |
| row_outputs.append(output_value) |
| elif result is not None: |
| # Function returned a single value instead of yielding |
| if is_single_field: |
| # Single-field: accept scalar or tuple |
| if isinstance(result, tuple): |
| if len(result) != 1: |
| raise ValueError( |
| f"Single-field UDTF should return 1-tuple or scalar, got {len(result)}-tuple" |
| ) |
| row_outputs.append(result[0]) # Extract scalar from tuple |
| else: |
| row_outputs.append(result) |
| else: |
| # Multi-field: must be tuple |
| if not isinstance(result, tuple): |
| raise ValueError( |
| f"Multi-field UDTF must return tuples, got {type(result)}" |
| ) |
| row_outputs.append(result) |
| |
| all_results.append(row_outputs) |
| |
| all_results = convert_python_to_arrow_value(all_results, expected_output_type) |
| |
| try: |
| list_array = pa.array(all_results, type=pa.list_(expected_output_type)) |
| except Exception as e: |
| logging.error( |
| "Failed to create ListArray: %s, element_type: %s", |
| e, |
| expected_output_type, |
| ) |
| raise RuntimeError(f"Failed to create ListArray: {e}") from e |
| |
| # Create RecordBatch with single ListArray column |
| schema = pa.schema([pa.field("results", pa.list_(expected_output_type))]) |
| response_batch = pa.RecordBatch.from_arrays([list_array], schema=schema) |
| |
| return response_batch |
| |
| def do_exchange( |
| self, |
| context: flight.ServerCallContext, |
| descriptor: flight.FlightDescriptor, |
| reader: flight.MetadataRecordBatchReader, |
| writer: flight.MetadataRecordBatchWriter, |
| ) -> None: |
| """ |
| Handle bidirectional streaming for UDF, UDAF, and UDTF execution. |
| |
| Determines operation type (UDF vs UDAF vs UDTF) from descriptor metadata. |
| """ |
| python_udf_meta = self.parse_python_udf_meta(descriptor) |
| if not python_udf_meta: |
| raise ValueError("Invalid or missing metadata in descriptor") |
| |
| if python_udf_meta.is_udf(): |
| self._handle_exchange_udf(python_udf_meta, reader, writer) |
| elif python_udf_meta.is_udaf(): |
| self._handle_exchange_udaf(python_udf_meta, reader, writer) |
| elif python_udf_meta.is_udtf(): |
| self._handle_exchange_udtf(python_udf_meta, reader, writer) |
| else: |
| raise ValueError(f"Unsupported client type: {python_udf_meta.client_type}") |
| |
| def do_action( |
| self, |
| context: flight.ServerCallContext, |
| action: flight.Action, |
| ): |
| """ |
| Handle Flight actions for cache management. |
| |
| Supported actions: |
| - "clear_module_cache": Clear Python module cache for a specific location |
| Body: JSON with "location" field (the UDF cache directory path) |
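
        Example (illustrative request/response; the module name shown is
        hypothetical):

            action body : {"location": "/path/to/udf_cache_dir"}
            result body : {"success": true,
                           "cleared_modules": ["my_udf_module"],
                           "location": "/path/to/udf_cache_dir"}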
| """ |
| action_type = action.type |
| |
| if action_type == "clear_module_cache": |
| yield from self._handle_clear_module_cache(action.body.to_pybytes()) |
| else: |
| raise flight.FlightUnavailableError(f"Unknown action: {action_type}") |
| |
| def _handle_clear_module_cache(self, body: bytes): |
| """ |
| Clear Python module cache for a specific UDF location. |
| |
| This removes modules from sys.modules that were loaded from the specified |
| location, allowing fresh imports when a new UDF with the same module name |
| is created. |
| """ |
| try: |
| params = json.loads(body.decode("utf-8")) |
| location = params.get("location", "") |
| |
| if not location: |
| yield flight.Result(b'{"success": false, "error": "empty location"}') |
| return |
| |
| cleared_modules = self._clear_modules_from_location(location) |
| |
| result = { |
| "success": True, |
| "cleared_modules": cleared_modules, |
| "location": location, |
| } |
| yield flight.Result(json.dumps(result).encode("utf-8")) |
| |
| except Exception as e: |
| logging.error("clear_module_cache failed: %s", e) |
| yield flight.Result(json.dumps({ |
| "success": False, |
| "error": str(e) |
| }).encode("utf-8")) |
| |
| def _clear_modules_from_location(self, location: str) -> list: |
| """ |
| Clear module cache for the given location. |
| |
| Acquires per-module import locks to ensure no concurrent import is |
| in progress for the modules being cleared, preventing race conditions |
| where sys.modules entries are removed mid-import. |
| |
| Returns list of cleared module names. |
| """ |
| cleared = [] |
| |
| with ModuleUDFLoader._module_cache_lock: |
| keys_to_remove = [ |
| key for key in ModuleUDFLoader._module_cache |
| if key[0] == location |
| ] |
| |
| # For each module, acquire its import lock before clearing. |
| # This ensures no concurrent _get_or_import_module is in progress |
| # for this (location, module_name) pair. |
| for key in keys_to_remove: |
| loc, module_name = key |
| import_lock = ModuleUDFLoader._get_import_lock(loc, module_name) |
| |
| with import_lock: |
| with ModuleUDFLoader._module_cache_lock: |
| if key in ModuleUDFLoader._module_cache: |
| del ModuleUDFLoader._module_cache[key] |
| |
| modules_to_remove = [ |
| name for name, mod in sys.modules.items() |
| if name == module_name or name.startswith(module_name + ".") |
| or ( |
| hasattr(mod, "__file__") and mod.__file__ is not None |
| and mod.__file__.startswith(location) |
| ) |
| ] |
| for mod_name in modules_to_remove: |
| del sys.modules[mod_name] |
| if mod_name not in cleared: |
| cleared.append(mod_name) |
| |
| if module_name not in cleared: |
| cleared.append(module_name) |
| |
| return cleared |
| |
| |
| class UDAFOperationType(Enum): |
| """Enum representing UDAF operation types.""" |
| |
| CREATE = 0 |
| ACCUMULATE = 1 |
| SERIALIZE = 2 |
| MERGE = 3 |
| FINALIZE = 4 |
| RESET = 5 |
| DESTROY = 6 |
| |
| |
def check_unix_socket_path(unix_socket_path: str) -> bool:
    """Validate the Unix domain socket URI.

    Returns False if the path or the extracted socket path is empty;
    raises ValueError if the 'grpc+unix://' prefix is missing.
    """
| if not unix_socket_path: |
| logging.error("Unix socket path is empty") |
| return False |
| |
| if not unix_socket_path.startswith("grpc+unix://"): |
| raise ValueError("gRPC UDS URL must start with 'grpc+unix://'") |
| |
| socket_path = unix_socket_path[len("grpc+unix://") :].strip() |
| if not socket_path: |
| logging.error("Extracted socket path is empty") |
| return False |
| |
| return True |
| |
| |
| def main(unix_socket_path: str) -> None: |
| """ |
| Main entry point for the Python UDF/UDAF/UDTF server. |
| |
| The server handles UDF, UDAF, and UDTF operations dynamically. |
| Operation type is determined from metadata in each request. |
| |
| Args: |
| unix_socket_path: Base path for the Unix domain socket |
| |
| Raises: |
| SystemExit: If socket path is invalid or server fails to start |
| """ |
| try: |
| if not check_unix_socket_path(unix_socket_path): |
| print(f"ERROR: Invalid socket path: {unix_socket_path}", flush=True) |
| sys.exit(1) |
| |
| current_pid = os.getpid() |
| ServerState.unix_socket_path = f"{unix_socket_path}_{current_pid}.sock" |
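        # e.g. "grpc+unix:///tmp/doris_udf" -> "grpc+unix:///tmp/doris_udf_12345.sock"
        # (the PID suffix shown is illustrative)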
| |
| # Start unified server that handles UDF, UDAF, and UDTF |
| server = FlightServer(ServerState.unix_socket_path) |
| |
| print(ServerState.PYTHON_SERVER_START_SUCCESS_MSG, flush=True) |
| logging.info( |
| "##### PYTHON UDF/UDAF/UDTF SERVER STARTED AT %s #####", datetime.now() |
| ) |
| server.wait() |
| |
| except Exception as e: |
| print( |
| f"ERROR: Failed to start Python server: {type(e).__name__}: {e}", |
| flush=True, |
| ) |
| tb_lines = traceback.format_exception(type(e), e, e.__traceback__) |
| if len(tb_lines) > 1: |
| print(f"DETAIL: {tb_lines[-2].strip()}", flush=True) |
| sys.exit(1) |
| |
| |
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser( |
        description="Run an Arrow Flight UDF/UDAF/UDTF server over a Unix domain socket. "
| "The server handles UDF, UDAF, and UDTF operations dynamically." |
| ) |
| parser.add_argument( |
| "unix_socket_path", |
| type=str, |
| help="Path to the Unix socket (e.g., grpc+unix:///path/to/socket)", |
| ) |
| args = parser.parse_args() |
| main(args.unix_socket_path) |