# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import argparse
import base64
import importlib
import inspect
import json
import sys
import os
import traceback
import logging
import time
import threading
import pickle
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Callable, Optional, Tuple, get_origin, Dict
from datetime import datetime
from enum import Enum
from pathlib import Path
from logging.handlers import RotatingFileHandler
import pandas as pd
import pyarrow as pa
from pyarrow import flight
class ServerState:
"""Global server state container."""
unix_socket_path: str = ""
PYTHON_SERVER_START_SUCCESS_MSG: str = "Start python server successfully"
@staticmethod
def setup_logging():
"""Setup logging configuration for the UDF server with rotation."""
doris_home = os.getenv("DORIS_HOME")
if not doris_home:
# Fallback to current directory if DORIS_HOME is not set
doris_home = os.getcwd()
log_dir = os.path.join(doris_home, "lib", "udf", "python")
os.makedirs(log_dir, exist_ok=True)
# Use shared log file with process ID in each log line
log_file = os.path.join(log_dir, "python_udf_output.log")
max_bytes = 128 * 1024 * 1024 # 128MB
backup_count = 5
# Use RotatingFileHandler to automatically manage log file size
file_handler = RotatingFileHandler(
log_file, maxBytes=max_bytes, backupCount=backup_count, encoding="utf-8"
)
# Include process ID in log format
file_handler.setFormatter(
logging.Formatter(
"[%(asctime)s] [PID:%(process)d] [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s"
)
)
logging.basicConfig(
level=logging.INFO,
handlers=[
file_handler,
logging.StreamHandler(sys.stderr), # Also log to stderr for debugging
],
)
logging.info(
"Logging initialized. Log file: %s (max_size=%dMB, backups=%d)",
log_file,
max_bytes // (1024 * 1024),
backup_count,
)
@staticmethod
def extract_base_unix_socket_path(unix_socket_uri: str) -> str:
"""
Extract the file system path from a gRPC Unix socket URI.
Args:
unix_socket_uri: URI in format 'grpc+unix:///path/to/socket'
Returns:
The file system path without the protocol prefix
"""
if unix_socket_uri.startswith("grpc+unix://"):
unix_socket_uri = unix_socket_uri[len("grpc+unix://") :]
return unix_socket_uri
@staticmethod
def remove_unix_socket(unix_socket_uri: str) -> None:
"""
Remove the Unix domain socket file if it exists.
Args:
unix_socket_uri: URI of the Unix socket to remove
"""
if unix_socket_uri is None:
return
base_unix_socket_path = ServerState.extract_base_unix_socket_path(
unix_socket_uri
)
if os.path.exists(base_unix_socket_path):
try:
os.unlink(base_unix_socket_path)
logging.info(
"Removed UNIX socket %s successfully", base_unix_socket_path
)
except OSError as e:
logging.error(
"Failed to remove UNIX socket %s: %s", base_unix_socket_path, e
)
else:
logging.warning("UNIX socket %s does not exist", base_unix_socket_path)
@staticmethod
def monitor_parent_exit():
"""
Monitor the parent process and exit gracefully if it dies.
This prevents orphaned UDF server processes.
"""
parent_pid = os.getppid()
if parent_pid == 1:
# Parent process is init, no need to monitor
logging.info("Parent process is init (PID 1), skipping parent monitoring")
return
logging.info("Started monitoring parent process (PID: %s)", parent_pid)
while True:
try:
# os.kill(pid, 0) only checks whether the process exists
# without sending an actual signal
os.kill(parent_pid, 0)
except OSError:
# Parent process died
ServerState.remove_unix_socket(ServerState.unix_socket_path)
logging.error(
"Parent process %s died, exiting UDF server, unix socket path: %s",
parent_pid,
ServerState.unix_socket_path,
)
os._exit(0)
# Check every 2 seconds
time.sleep(2)
ServerState.setup_logging()
monitor_thread = threading.Thread(target=ServerState.monitor_parent_exit, daemon=True)
monitor_thread.start()
@contextmanager
def temporary_sys_path(path: str):
"""
Context manager to temporarily add a path to sys.path.
Ensures the path is removed after use to avoid pollution.
Args:
path: Directory path to add to sys.path
Yields:
None
"""
path_added = False
if path not in sys.path:
sys.path.insert(0, path)
path_added = True
try:
yield
finally:
if path_added and path in sys.path:
sys.path.remove(path)
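# Illustrative use of temporary_sys_path (a sketch; "my_udfs" and "my_module"
# are hypothetical names, not part of this server):
#
#     with temporary_sys_path("/path/to/my_udfs"):
#         module = importlib.import_module("my_module")
#     # sys.path is restored afterwards, so later imports are unaffected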
class VectorType(Enum):
"""Enum representing supported vector types."""
LIST = "list"
PANDAS_SERIES = "pandas.Series"
ARROW_ARRAY = "pyarrow.Array"
@property
def python_type(self):
"""
Returns the Python type corresponding to this VectorType.
Returns:
The Python type class (list, pd.Series, or pa.Array)
"""
mapping = {
VectorType.LIST: list,
VectorType.PANDAS_SERIES: pd.Series,
VectorType.ARROW_ARRAY: pa.Array,
}
return mapping[self]
@staticmethod
def resolve_vector_type(param: inspect.Parameter):
"""
Resolves the param's type annotation to the corresponding VectorType enum.
Returns None if the type is unsupported or not a vector type.
"""
if (
param is None
or param.annotation is None
or param.annotation is inspect.Parameter.empty
):
return None
annotation = param.annotation
origin = get_origin(annotation)
raw_type = origin if origin is not None else annotation
if raw_type is list:
return VectorType.LIST
if raw_type is pd.Series:
return VectorType.PANDAS_SERIES
return None
class ClientType(Enum):
"""Enum representing Python client types."""
UDF = 0
UDAF = 1
UDTF = 2
UNKNOWN = 3
@property
def display_name(self) -> str:
"""Return a human-readable display name for the client type."""
return self.name
def __str__(self) -> str:
"""Return string representation of the client type."""
return self.name
class PythonUDFMeta:
"""Metadata container for a Python UDF."""
def __init__(
self,
name: str,
symbol: str,
location: str,
udf_load_type: int,
runtime_version: str,
always_nullable: bool,
inline_code: bytes,
input_types: pa.Schema,
output_type: pa.DataType,
client_type: int,
) -> None:
"""
Initialize Python UDF metadata.
Args:
name: UDF function name
symbol: Symbol to load (function name or module.function)
location: File path or directory containing the UDF
udf_load_type: 0 for inline code, 1 for module
runtime_version: Python runtime version requirement
always_nullable: Whether the UDF can return NULL values
inline_code: Base64-encoded inline Python code (if applicable)
input_types: PyArrow schema for input parameters
output_type: PyArrow data type for return value
client_type: 0 for UDF, 1 for UDAF, 2 for UDTF
"""
self.name = name
self.symbol = symbol
self.location = location
self.udf_load_type = udf_load_type
self.runtime_version = runtime_version
self.always_nullable = always_nullable
self.inline_code = inline_code
self.input_types = input_types
self.output_type = output_type
self.client_type = ClientType(client_type)
def is_udf(self) -> bool:
"""Check if this is a UDF (User-Defined Function)."""
return self.client_type == ClientType.UDF
def is_udaf(self) -> bool:
"""Check if this is a UDAF (User-Defined Aggregate Function)."""
return self.client_type == ClientType.UDAF
def is_udtf(self) -> bool:
"""Check if this is a UDTF (User-Defined Table Function)."""
return self.client_type == ClientType.UDTF
def __str__(self) -> str:
"""Returns a string representation of the UDF metadata."""
udf_load_type_str = "INLINE" if self.udf_load_type == 0 else "MODULE"
return (
f"PythonUDFMeta(name={self.name}, symbol={self.symbol}, "
f"location={self.location}, udf_load_type={udf_load_type_str}, runtime_version={self.runtime_version}, "
f"always_nullable={self.always_nullable}, client_type={self.client_type.name}, "
f"input_types={self.input_types}, output_type={self.output_type})"
)
class AdaptivePythonUDF:
"""
A wrapper around a UDF function that supports both scalar and vectorized execution modes.
The mode is determined by the type hints of the function parameters.
"""
def __init__(self, python_udf_meta: PythonUDFMeta, func: Callable) -> None:
"""
Initialize the adaptive UDF wrapper.
Args:
python_udf_meta: Metadata describing the UDF
func: The actual Python function to execute
"""
self.python_udf_meta = python_udf_meta
self._eval_func = func
def __str__(self) -> str:
"""Returns a string representation of the UDF wrapper."""
input_type_strs = [str(t) for t in self.python_udf_meta.input_types.types]
output_type_str = str(self.python_udf_meta.output_type)
eval_func_str = f"{self.python_udf_meta.name}({', '.join(input_type_strs)}) -> {output_type_str}"
return f"AdaptivePythonUDF(python_udf_meta: {self.python_udf_meta}, eval_func: {eval_func_str})"
def __call__(self, record_batch: pa.RecordBatch) -> pa.Array:
"""
Executes the UDF on the given record batch. Supports both scalar and vectorized modes.
:param record_batch: Input data with N columns, each of length num_rows
:return: Output array of length num_rows
"""
if record_batch.num_rows == 0:
return pa.array([], type=self._get_output_type())
if self._should_use_vectorized():
logging.info("Using vectorized mode for UDF: %s", self.python_udf_meta.name)
return self._vectorized_call(record_batch)
logging.info("Using scalar mode for UDF: %s", self.python_udf_meta.name)
return self._scalar_call(record_batch)
@staticmethod
def _cast_arrow_to_vector(arrow_array: pa.Array, vec_type: VectorType):
"""
Convert a pa.Array to an instance of the specified VectorType.
"""
if vec_type == VectorType.LIST:
return arrow_array.to_pylist()
elif vec_type == VectorType.PANDAS_SERIES:
return arrow_array.to_pandas()
else:
raise ValueError(f"Unsupported vector type: {vec_type}")
def _should_use_vectorized(self) -> bool:
"""
Determines whether to use vectorized mode based on parameter type annotations.
Returns True if any parameter is annotated as:
- list
- pd.Series
"""
try:
signature = inspect.signature(self._eval_func)
except ValueError:
# Cannot inspect built-in or C functions; default to scalar
return False
for param in signature.parameters.values():
if VectorType.resolve_vector_type(param):
return True
return False
def _convert_from_arrow_to_py(self, field):
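        """Convert a single Arrow scalar to a plain Python value (map values become dicts)."""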
if field is None:
return None
if pa.types.is_map(field.type):
# pyarrow.lib.MapScalar's as_py() returns a list of tuples, convert to dict
list_of_tuples = field.as_py()
return dict(list_of_tuples) if list_of_tuples is not None else None
return field.as_py()
def _scalar_call(self, record_batch: pa.RecordBatch) -> pa.Array:
"""
Applies the UDF in scalar mode: one row at a time.
Args:
record_batch: Input data batch
Returns:
Output array with results for each row
"""
columns = record_batch.columns
num_rows = record_batch.num_rows
result = []
for i in range(num_rows):
converted_args = [self._convert_from_arrow_to_py(col[i]) for col in columns]
try:
res = self._eval_func(*converted_args)
# Check if result is None when always_nullable is False
if res is None and not self.python_udf_meta.always_nullable:
                    raise RuntimeError(
                        f"the result of row {i} is null, but the return type is not nullable; "
                        f"please set always_nullable to true in the CREATE FUNCTION statement"
                    )
result.append(res)
except Exception as e:
logging.error(
"Error in scalar UDF execution at row %s: %s\nArgs: %s\nTraceback: %s",
i,
e,
converted_args,
traceback.format_exc(),
)
# Return None for failed rows if always_nullable is True
if self.python_udf_meta.always_nullable:
result.append(None)
else:
raise
return pa.array(result, type=self._get_output_type(), from_pandas=True)
def _vectorized_call(self, record_batch: pa.RecordBatch) -> pa.Array:
"""
Applies the UDF in vectorized mode: processes entire columns at once.
Args:
record_batch: Input data batch
Returns:
Output array with results
"""
column_args = record_batch.columns
logging.info("Vectorized call with %s columns", len(column_args))
sig = inspect.signature(self._eval_func)
params = list(sig.parameters.values())
if len(column_args) != len(params):
raise ValueError(f"UDF expects {len(params)} args, got {len(column_args)}")
converted_args = []
for param, arrow_col in zip(params, column_args):
vec_type = VectorType.resolve_vector_type(param)
if vec_type is None:
# For scalar types (int, float, str, etc.), extract the first value
# instead of converting to list
pylist = arrow_col.to_pylist()
if len(pylist) > 0:
converted = pylist[0]
logging.info(
"Converted %s to scalar (first value): %s",
param.name,
type(converted).__name__,
)
else:
converted = None
logging.info(
"Converted %s to scalar (None, empty column)", param.name
)
else:
converted = self._cast_arrow_to_vector(arrow_col, vec_type)
logging.info("Converted %s: %s", param.name, vec_type)
converted_args.append(converted)
try:
result = self._eval_func(*converted_args)
except Exception as e:
logging.error(
"Error in vectorized UDF: %s\nTraceback: %s", e, traceback.format_exc()
)
raise RuntimeError(f"Error in vectorized UDF: {e}") from e
# Convert result to PyArrow Array
result_array = None
if isinstance(result, pa.Array):
result_array = result
elif isinstance(result, pa.ChunkedArray):
# Combine chunks into a single array
result_array = pa.concat_arrays(result.chunks)
elif isinstance(result, pd.Series):
result_array = pa.array(result, type=self._get_output_type())
elif isinstance(result, list):
result_array = pa.array(
result, type=self._get_output_type(), from_pandas=True
)
else:
# Scalar result - broadcast to all rows
out_type = self._get_output_type()
logging.warning(
"UDF returned scalar value, broadcasting to %s rows",
record_batch.num_rows,
)
result_array = pa.array([result] * record_batch.num_rows, type=out_type)
# Check for None values when always_nullable is False
if not self.python_udf_meta.always_nullable:
null_count = result_array.null_count
if null_count > 0:
# Find the first null index for error message
for i, value in enumerate(result_array):
                    if not value.is_valid:
                        raise RuntimeError(
                            f"the result of row {i} is null, but the return type is not nullable; "
                            f"please set always_nullable to true in the CREATE FUNCTION statement"
                        )
return result_array
def _get_output_type(self) -> pa.DataType:
"""
Returns the expected output type for the UDF.
Returns:
PyArrow DataType for the output
"""
return self.python_udf_meta.output_type or pa.null()
class UDFLoader(ABC):
"""Abstract base class for loading UDFs from different sources."""
def __init__(self, python_udf_meta: PythonUDFMeta) -> None:
"""
Initialize the UDF loader.
Args:
python_udf_meta: Metadata describing the UDF to load
"""
self.python_udf_meta = python_udf_meta
@abstractmethod
def load(self) -> AdaptivePythonUDF:
"""Load the UDF and return an AdaptivePythonUDF wrapper."""
raise NotImplementedError("Subclasses must implement load().")
class InlineUDFLoader(UDFLoader):
"""Loads a UDF defined directly in inline code."""
def load(self) -> AdaptivePythonUDF:
"""
Load and execute inline Python code to extract the UDF function.
Returns:
AdaptivePythonUDF wrapper around the loaded function
Raises:
RuntimeError: If code execution fails
ValueError: If the function is not found or not callable
"""
symbol = self.python_udf_meta.symbol
inline_code = self.python_udf_meta.inline_code.decode("utf-8")
        env: Dict[str, Any] = {}
try:
# Execute the code in a clean environment
# pylint: disable=exec-used
# Note: exec() is necessary here for dynamic UDF loading from inline code
exec(inline_code, env) # nosec B102
except Exception as e:
logging.error(
"Failed to exec inline code: %s\nTraceback: %s",
e,
traceback.format_exc(),
)
raise RuntimeError(f"Failed to exec inline code: {e}") from e
func = env.get(symbol)
if func is None:
available_funcs = [
k for k, v in env.items() if callable(v) and not k.startswith("_")
]
logging.error(
"Function '%s' not found in inline code. Available functions: %s",
symbol,
available_funcs,
)
raise ValueError(f"Function '{symbol}' not found in inline code.")
if not callable(func):
logging.error(
"'%s' exists but is not callable (type: %s)", symbol, type(func)
)
raise ValueError(f"'{symbol}' is not a callable function.")
return AdaptivePythonUDF(self.python_udf_meta, func)
class ModuleUDFLoader(UDFLoader):
"""Loads a UDF from a Python module file (.py)."""
def load(self) -> AdaptivePythonUDF:
"""
Loads a UDF from a Python module file.
Returns:
AdaptivePythonUDF instance wrapping the loaded function
Raises:
ValueError: If module file not found
TypeError: If symbol is not callable
"""
symbol = self.python_udf_meta.symbol # [package_name.]module_name.function_name
location = self.python_udf_meta.location # /path/to/module_name[.py]
if not os.path.exists(location):
raise ValueError(f"Module file not found: {location}")
package_name, module_name, func_name = self.parse_symbol(symbol)
func = self.load_udf_from_module(location, package_name, module_name, func_name)
if not callable(func):
raise TypeError(
f"'{symbol}' exists but is not callable (type: {type(func).__name__})"
)
logging.info(
"Successfully loaded function '%s' from module: %s", symbol, location
)
return AdaptivePythonUDF(self.python_udf_meta, func)
def parse_symbol(self, symbol: str):
"""
Parse symbol into (package_name, module_name, func_name)
Supported formats:
- "module.func" → (None, module, func)
- "package.module.func" → (package, "module", func)
"""
if not symbol or "." not in symbol:
raise ValueError(
f"Invalid symbol format: '{symbol}'. "
"Expected 'module.function' or 'package.module.function'"
)
parts = symbol.split(".")
if len(parts) == 2:
# module.func → Single-file mode
module_name, func_name = parts
package_name = None
if not module_name or not module_name.strip():
raise ValueError(f"Module name is empty in symbol: '{symbol}'")
if not func_name or not func_name.strip():
raise ValueError(f"Function name is empty in symbol: '{symbol}'")
elif len(parts) > 2:
package_name = parts[0]
module_name = ".".join(parts[1:-1])
func_name = parts[-1]
if not package_name or not package_name.strip():
raise ValueError(f"Package name is empty in symbol: '{symbol}'")
if not module_name or not module_name.strip():
raise ValueError(f"Module name is empty in symbol: '{symbol}'")
if not func_name or not func_name.strip():
raise ValueError(f"Function name is empty in symbol: '{symbol}'")
else:
raise ValueError(f"Invalid symbol format: '{symbol}'")
return package_name, module_name, func_name
def _validate_location(self, location: str) -> None:
"""Validate that the location is a valid directory."""
if not os.path.isdir(location):
raise ValueError(f"Location is not a directory: {location}")
def _get_or_import_module(self, location: str, full_module_name: str) -> Any:
"""Get module from cache or import it."""
if full_module_name in sys.modules:
logging.warning(
"Module '%s' already loaded, using cached version", full_module_name
)
return sys.modules[full_module_name]
with temporary_sys_path(location):
return importlib.import_module(full_module_name)
def _extract_function(
self, module: Any, func_name: str, module_name: str
) -> Callable:
"""Extract and validate function from module."""
func = getattr(module, func_name, None)
if func is None:
raise AttributeError(
f"Function '{func_name}' not found in module '{module_name}'"
)
if not callable(func):
raise TypeError(f"'{func_name}' is not callable")
return func
def _load_single_file_udf(
self, location: str, module_name: str, func_name: str
) -> Callable:
"""Load UDF from a single Python file."""
py_file = os.path.join(location, f"{module_name}.py")
if not os.path.isfile(py_file):
raise ImportError(f"Python file not found: {py_file}")
try:
udf_module = self._get_or_import_module(location, module_name)
return self._extract_function(udf_module, func_name, module_name)
except (ImportError, AttributeError, TypeError) as e:
raise ImportError(
f"Failed to load single-file UDF '{module_name}.{func_name}': {e}"
) from e
except Exception as e:
logging.error(
"Unexpected error loading UDF: %s\n%s", e, traceback.format_exc()
)
raise
def _ensure_package_init(self, package_path: str, package_name: str) -> None:
"""Ensure __init__.py exists in the package directory."""
init_path = os.path.join(package_path, "__init__.py")
if not os.path.exists(init_path):
logging.warning(
"__init__.py not found in package '%s', attempting to create it",
package_name,
)
try:
with open(init_path, "w", encoding="utf-8") as f:
f.write(
"# Auto-generated by UDF loader to make directory a Python package\n"
)
logging.info("Created __init__.py in %s", package_path)
except OSError as e:
raise ImportError(
f"Cannot create __init__.py in package '{package_name}': {e}"
) from e
def _build_full_module_name(self, package_name: str, module_name: str) -> str:
"""Build the full module name for package mode."""
if module_name == "__init__":
return package_name
return f"{package_name}.{module_name}"
def _load_package_udf(
self, location: str, package_name: str, module_name: str, func_name: str
) -> Callable:
"""Load UDF from a Python package."""
package_path = os.path.join(location, package_name)
if not os.path.isdir(package_path):
raise ImportError(f"Package '{package_name}' not found in '{location}'")
self._ensure_package_init(package_path, package_name)
try:
full_module_name = self._build_full_module_name(package_name, module_name)
udf_module = self._get_or_import_module(location, full_module_name)
return self._extract_function(udf_module, func_name, full_module_name)
except (ImportError, AttributeError, TypeError) as e:
raise ImportError(
f"Failed to load packaged UDF '{package_name}.{module_name}.{func_name}': {e}"
) from e
except Exception as e:
logging.error(
"Unexpected error loading packaged UDF: %s\n%s",
e,
traceback.format_exc(),
)
raise
def load_udf_from_module(
self,
location: str,
package_name: Optional[str],
module_name: str,
func_name: str,
) -> Callable:
"""
Load a UDF from a Python module, supporting both:
1. Single-file mode: package_name=None, module_name="your_file"
2. Package mode: package_name="your_pkg", module_name="submodule" or "__init__"
Args:
location:
- In package mode: parent directory of the package
- In single-file mode: directory containing the .py file
package_name:
- If None or empty: treat as single-file mode
- Else: standard package name
module_name:
- In package mode: submodule name (e.g., "main") or "__init__"
- In single-file mode: filename without .py (e.g., "udf_script")
func_name: name of the function to load
Returns:
The callable UDF function.
"""
self._validate_location(location)
if not package_name or package_name.strip() == "":
return self._load_single_file_udf(location, module_name, func_name)
else:
return self._load_package_udf(
location, package_name, module_name, func_name
)
class UDFLoaderFactory:
"""Factory to select the appropriate loader based on UDF location."""
@staticmethod
def get_loader(python_udf_meta: PythonUDFMeta) -> UDFLoader:
"""
Factory method to create the appropriate UDF loader based on metadata.
Args:
python_udf_meta: UDF metadata containing load type and location
Returns:
Appropriate UDFLoader instance (InlineUDFLoader or ModuleUDFLoader)
Raises:
ValueError: If UDF load type or location is unsupported
"""
location = python_udf_meta.location
udf_load_type = python_udf_meta.udf_load_type # 0: inline, 1: module
if udf_load_type == 0:
return InlineUDFLoader(python_udf_meta)
elif udf_load_type == 1:
if UDFLoaderFactory.check_module(location):
return ModuleUDFLoader(python_udf_meta)
else:
raise ValueError(f"Unsupported UDF location: {location}")
else:
raise ValueError(f"Unsupported UDF load type: {udf_load_type}")
@staticmethod
def check_module(location: str) -> bool:
"""
Checks if a location is a valid Python module or package.
A valid module is either:
- A .py file, or
- A directory containing __init__.py (i.e., a package).
Raises:
ValueError: If the location does not exist or contains no Python module.
Returns:
True if valid.
"""
if not os.path.exists(location):
raise ValueError(f"Module not found: {location}")
if os.path.isfile(location):
if location.endswith(".py"):
return True
else:
raise ValueError(f"File is not a Python module (.py): {location}")
if os.path.isdir(location):
if UDFLoaderFactory.has_python_file_recursive(location):
return True
else:
raise ValueError(
f"Directory contains no Python (.py) files: {location}"
)
raise ValueError(f"Invalid module location (not file or directory): {location}")
@staticmethod
def has_python_file_recursive(location: str) -> bool:
"""
Recursively checks if a directory contains any Python (.py) files.
Args:
location: Directory path to search
Returns:
True if at least one .py file is found, False otherwise
"""
path = Path(location)
if not path.is_dir():
return False
return any(path.rglob("*.py"))
class UDAFClassLoader:
"""
Utility class for loading UDAF classes from various sources.
This class is responsible for loading UDAF classes from:
- Inline code (embedded in SQL)
- Module files (imported from filesystem)
"""
@staticmethod
def load_udaf_class(python_udf_meta: PythonUDFMeta) -> type:
"""
Load the UDAF class from metadata.
Args:
python_udf_meta: UDAF metadata
Returns:
The UDAF class
Raises:
RuntimeError: If inline code execution fails
ValueError: If class is not found or invalid
"""
loader = UDFLoaderFactory.get_loader(python_udf_meta)
# For UDAF, we need the class, not an instance
if isinstance(loader, InlineUDFLoader):
return UDAFClassLoader.load_from_inline(python_udf_meta)
elif isinstance(loader, ModuleUDFLoader):
return UDAFClassLoader.load_from_module(python_udf_meta, loader)
else:
raise ValueError(f"Unsupported loader type: {type(loader)}")
@staticmethod
def load_from_inline(python_udf_meta: PythonUDFMeta) -> type:
"""
Load UDAF class from inline code.
Args:
python_udf_meta: UDAF metadata with inline code
Returns:
The UDAF class
"""
symbol = python_udf_meta.symbol
inline_code = python_udf_meta.inline_code.decode("utf-8")
        env: Dict[str, Any] = {}
try:
exec(inline_code, env) # nosec B102
except Exception as e:
raise RuntimeError(f"Failed to exec inline code: {e}") from e
udaf_class = env.get(symbol)
if udaf_class is None:
raise ValueError(f"UDAF class '{symbol}' not found in inline code")
if not inspect.isclass(udaf_class):
raise ValueError(f"'{symbol}' is not a class (type: {type(udaf_class)})")
UDAFClassLoader.validate_udaf_class(udaf_class)
return udaf_class
@staticmethod
def load_from_module(
python_udf_meta: PythonUDFMeta, loader: ModuleUDFLoader
) -> type:
"""
Load UDAF class from module file.
Args:
python_udf_meta: UDAF metadata with module location
loader: Module loader instance
Returns:
The UDAF class
"""
symbol = python_udf_meta.symbol
location = python_udf_meta.location
package_name, module_name, class_name = loader.parse_symbol(symbol)
udaf_class = loader.load_udf_from_module(
location, package_name, module_name, class_name
)
if not inspect.isclass(udaf_class):
raise ValueError(f"'{symbol}' is not a class (type: {type(udaf_class)})")
UDAFClassLoader.validate_udaf_class(udaf_class)
return udaf_class
@staticmethod
def validate_udaf_class(udaf_class: type):
"""
Validate that the UDAF class implements required methods.
Args:
udaf_class: The class to validate
Raises:
ValueError: If class doesn't implement required methods or properties
"""
required_methods = ["__init__", "accumulate", "merge", "finish"]
for method in required_methods:
if not hasattr(udaf_class, method):
raise ValueError(
f"UDAF class must implement '{method}' method. "
f"Missing in {udaf_class.__name__}"
)
# Check for aggregate_state property
if not hasattr(udaf_class, "aggregate_state"):
raise ValueError(
f"UDAF class must have 'aggregate_state' property. "
f"Missing in {udaf_class.__name__}"
)
# Verify it's actually a property
try:
attr = inspect.getattr_static(udaf_class, "aggregate_state")
if not isinstance(attr, property):
raise ValueError(
f"'aggregate_state' must be a @property in {udaf_class.__name__}"
)
except AttributeError:
raise ValueError(
f"UDAF class must have 'aggregate_state' property. "
f"Missing in {udaf_class.__name__}"
)
class UDAFStateManager:
"""
Manages UDAF aggregate states for Python UDAF execution.
This class maintains a mapping from place_id to UDAF instances,
following the Snowflake UDAF pattern:
- __init__(): Initialize state
- aggregate_state: Property returning serializable state
- accumulate(*args): Add input values
- merge(other_state): Merge two states
- finish(): Return final result
"""
def __init__(self):
"""Initialize the state manager."""
self.states: Dict[int, Any] = {} # place_id -> UDAF instance
self.udaf_class = None # UDAF class to instantiate
self.lock = threading.Lock() # Thread-safe state access
def set_udaf_class(self, udaf_class: type):
"""
Set the UDAF class to use for creating instances.
Args:
udaf_class: The UDAF class
Note:
Validation is performed by UDAFClassLoader before calling this method.
"""
self.udaf_class = udaf_class
def create_state(self, place_id: int) -> None:
"""
Create a new UDAF state for the given place_id.
Args:
place_id: Unique identifier for this aggregate state
Raises:
RuntimeError: If state already exists for this place_id or UDAF class not set
"""
with self.lock:
if place_id in self.states:
# This should never happen
error_msg = (
f"State for place_id {place_id} already exists. "
f"CREATE should only be called once per place_id. "
f"This indicates a bug in C++ side or state management."
)
logging.error(error_msg)
raise RuntimeError(error_msg)
if self.udaf_class is None:
raise RuntimeError("UDAF class not set. Call set_udaf_class() first.")
try:
self.states[place_id] = self.udaf_class()
except Exception as e:
logging.error("Failed to create UDAF state: %s", e)
raise RuntimeError(f"Failed to create UDAF state: {e}") from e
def get_state(self, place_id: int) -> Any:
"""
Get the UDAF state for the given place_id.
Args:
place_id: Unique identifier for the aggregate state
Returns:
The UDAF instance
"""
with self.lock:
if place_id not in self.states:
raise KeyError(f"State for place_id {place_id} not found")
return self.states[place_id]
def accumulate(self, place_id: int, *args) -> None:
"""
Accumulate input values into the aggregate state.
Args:
place_id: Unique identifier for the aggregate state
*args: Input values to accumulate
"""
state = self.get_state(place_id)
try:
state.accumulate(*args)
except Exception as e:
logging.error("Error in accumulate for place_id %s: %s", place_id, e)
raise RuntimeError(f"Error in accumulate: {e}") from e
def serialize(self, place_id: int) -> bytes:
"""
Serialize the aggregate state to bytes.
Args:
place_id: Unique identifier for the aggregate state
Returns:
Serialized state as bytes (using pickle)
Note:
If the state doesn't exist, creates an empty state and serializes it.
This can happen in distributed scenarios where a node receives no data.
"""
with self.lock:
if place_id not in self.states:
# State doesn't exist - create empty state and serialize it
logging.warning(
"SERIALIZE: State for place_id %s not found, creating empty state",
place_id,
)
if self.udaf_class is None:
raise RuntimeError(
"UDAF class not set. Call set_udaf_class() first before serialize operation."
)
try:
self.states[place_id] = self.udaf_class()
except Exception as e:
logging.error("Failed to create UDAF state for serialize: %s", e)
raise RuntimeError(
f"Failed to create UDAF state for serialize: {e}"
) from e
state = self.states[place_id]
try:
aggregate_state = state.aggregate_state
serialized = pickle.dumps(aggregate_state)
logging.info(
"SERIALIZE: place_id=%s state=%s bytes=%d",
place_id,
aggregate_state,
len(serialized),
)
return serialized
except Exception as e:
logging.error("Error serializing state for place_id %s: %s", place_id, e)
raise RuntimeError(f"Error serializing state: {e}") from e
def merge(self, place_id: int, other_state_bytes: bytes) -> None:
"""
Merge another serialized state into this state.
Args:
place_id: Unique identifier for the aggregate state
other_state_bytes: Serialized state to merge (pickle bytes)
Note:
If the state doesn't exist, it will be created first.
This handles both normal merge and deserialize_and_merge scenarios.
"""
# Deserialize the incoming state
try:
other_state = pickle.loads(other_state_bytes)
logging.info(
"MERGE: Deserialized state for place_id %s: %s", place_id, other_state
)
except Exception as e:
logging.error("Error deserializing state bytes: %s", e)
raise RuntimeError(f"Error deserializing state: {e}") from e
with self.lock:
if place_id not in self.states:
# Create state if it doesn't exist
# This can happen in shuffle scenarios where deserialize_and_merge
# is called before any local aggregation
logging.info(
"MERGE: Creating new state for place_id %s (state doesn't exist)",
place_id,
)
if self.udaf_class is None:
raise RuntimeError(
"UDAF class not set. Call set_udaf_class() first before merge operation."
)
try:
self.states[place_id] = self.udaf_class()
logging.info(
"MERGE: Created new state, initial value: %s",
self.states[place_id].aggregate_state,
)
except Exception as e:
logging.error("Failed to create UDAF state for merge: %s", e)
raise RuntimeError(
f"Failed to create UDAF state for merge: {e}"
) from e
state = self.states[place_id]
before_merge = state.aggregate_state
# Merge the deserialized state
try:
state.merge(other_state)
after_merge = state.aggregate_state
logging.info(
"MERGE: place_id=%s before=%s + other=%s → after=%s",
place_id,
before_merge,
other_state,
after_merge,
)
except Exception as e:
logging.error("Error in merge for place_id %s: %s", place_id, e)
raise RuntimeError(f"Error in merge: {e}") from e
def finalize(self, place_id: int) -> Any:
"""
Get the final result from the aggregate state.
Args:
place_id: Unique identifier for the aggregate state
Returns:
Final aggregation result
Note:
If the state doesn't exist, creates an empty state and returns its finish() result.
This can happen in distributed scenarios where a node receives no data for aggregation.
"""
with self.lock:
if place_id not in self.states:
# State doesn't exist - create empty state and finalize it
logging.warning(
"FINALIZE: State for place_id %s not found, creating empty state",
place_id,
)
if self.udaf_class is None:
raise RuntimeError(
"UDAF class not set. Call set_udaf_class() first before finalize operation."
)
try:
self.states[place_id] = self.udaf_class()
except Exception as e:
logging.error("Failed to create UDAF state for finalize: %s", e)
raise RuntimeError(
f"Failed to create UDAF state for finalize: {e}"
) from e
state = self.states[place_id]
try:
result = state.finish()
logging.info("FINALIZE: place_id=%s result=%s", place_id, result)
return result
except Exception as e:
logging.error("Error finalizing state for place_id %s: %s", place_id, e)
raise RuntimeError(f"Error finalizing state: {e}") from e
def reset(self, place_id: int) -> None:
"""
Reset the aggregate state (for window functions).
Args:
place_id: Unique identifier for the aggregate state
Raises:
RuntimeError: If state does not exist for this place_id or UDAF class not set
"""
with self.lock:
if place_id not in self.states:
error_msg = (
f"Attempted to reset non-existent state for place_id {place_id}. "
f"RESET should only be called on existing states. "
f"This indicates a bug in state lifecycle management."
)
logging.error(error_msg)
raise RuntimeError(error_msg)
if self.udaf_class is None:
raise RuntimeError("UDAF class not set. Call set_udaf_class() first.")
try:
self.states[place_id] = self.udaf_class()
except Exception as e:
logging.error("Error resetting state for place_id %s: %s", place_id, e)
raise RuntimeError(f"Error resetting state: {e}") from e
def destroy(self, place_id: int) -> None:
"""
Destroy the aggregate state and free resources.
Args:
place_id: Unique identifier for the aggregate state
Raises:
RuntimeError: If state does not exist for this place_id
"""
with self.lock:
if place_id not in self.states:
error_msg = (
f"Attempted to destroy non-existent state for place_id {place_id}. "
f"State was either never created or already destroyed. "
f"This indicates a bug in state lifecycle management."
)
logging.error(error_msg)
raise RuntimeError(error_msg)
del self.states[place_id]
def clear_all(self) -> None:
"""Clear all states (for cleanup)."""
with self.lock:
count = len(self.states)
self.states.clear()
logging.info("Cleared all %d UDAF states", count)
class FlightServer(flight.FlightServerBase):
"""Arrow Flight server for executing Python UDFs, UDAFs, and UDTFs."""
def __init__(self, location: str):
"""
Initialize the Flight server.
Args:
location: Unix socket path for the server
"""
super().__init__(location)
# Use a dictionary to maintain separate state managers for each UDAF function
# Key: function signature (name + input_types), Value: UDAFStateManager instance
self.udaf_state_managers: Dict[str, UDAFStateManager] = {}
self.udaf_managers_lock = threading.Lock()
logging.info("Flight server initialized at: %s", location)
def _get_udaf_state_manager(
self, python_udaf_meta: PythonUDFMeta
) -> UDAFStateManager:
"""
Get or create a state manager for the given UDAF function.
Each UDAF function gets its own independent state manager.
Args:
python_udaf_meta: Metadata for the UDAF function
Returns:
UDAFStateManager instance for this specific UDAF
"""
# Create a unique key based on function name and argument types
type_names = [str(field.type) for field in python_udaf_meta.input_types]
func_key = f"{python_udaf_meta.name}({','.join(type_names)})"
with self.udaf_managers_lock:
if func_key not in self.udaf_state_managers:
manager = UDAFStateManager()
# Load and set the UDAF class for this manager using UDAFClassLoader
udaf_class = UDAFClassLoader.load_udaf_class(python_udaf_meta)
manager.set_udaf_class(udaf_class)
self.udaf_state_managers[func_key] = manager
logging.info(
"Created new state manager for UDAF: %s with class: %s",
func_key,
udaf_class.__name__,
)
return self.udaf_state_managers[func_key]
@staticmethod
def parse_python_udf_meta(
descriptor: flight.FlightDescriptor,
) -> Optional[PythonUDFMeta]:
"""
Parses UDF/UDAF/UDTF metadata from a command descriptor.
Returns:
            PythonUDFMeta containing the function metadata, or None if the
            descriptor is not a command descriptor or the output schema is invalid
"""
if descriptor.descriptor_type != flight.DescriptorType.CMD:
logging.error("Invalid descriptor type: %s", descriptor.descriptor_type)
return None
cmd_json = json.loads(descriptor.command)
name = cmd_json["name"]
symbol = cmd_json["symbol"]
location = cmd_json["location"]
udf_load_type = cmd_json["udf_load_type"]
runtime_version = cmd_json["runtime_version"]
always_nullable = cmd_json["always_nullable"]
# client_type: 0: UDF, 1: UDAF, 2: UDTF
client_type = cmd_json["client_type"]
inline_code = base64.b64decode(cmd_json["inline_code"])
input_binary = base64.b64decode(cmd_json["input_types"])
output_binary = base64.b64decode(cmd_json["return_type"])
input_schema = pa.ipc.read_schema(pa.BufferReader(input_binary))
output_schema = pa.ipc.read_schema(pa.BufferReader(output_binary))
if len(output_schema) != 1:
logging.error(
"Output schema must have exactly one field: %s", output_schema
)
return None
output_type = output_schema.field(0).type
python_udf_meta = PythonUDFMeta(
name=name,
symbol=symbol,
location=location,
udf_load_type=udf_load_type,
runtime_version=runtime_version,
always_nullable=always_nullable,
inline_code=inline_code,
input_types=input_schema,
output_type=output_type,
client_type=client_type,
)
return python_udf_meta
@staticmethod
def check_schema(
record_batch: pa.RecordBatch, expected_schema: pa.Schema
) -> Tuple[bool, str]:
"""
Validates that the input RecordBatch schema matches the expected schema.
Checks that field count and types match, but field names can differ.
:return: (result, error_message)
"""
actual = record_batch.schema
expected = expected_schema
# Check field count
if len(actual) != len(expected):
return (
False,
f"Schema length mismatch, got {len(actual)} fields, expected {len(expected)} fields",
)
# Check each field type (ignore field names)
for i, (actual_field, expected_field) in enumerate(zip(actual, expected)):
if not actual_field.type.equals(expected_field.type):
return False, (
f"Type mismatch at field index {i}, "
f"got {actual_field.type}, expected {expected_field.type}"
)
return True, ""
def _create_unified_response(self, result_batch: pa.RecordBatch) -> pa.RecordBatch:
"""
Convert operation-specific result to unified response schema.
Unified Response Schema: [result_data: binary]
- Serializes the result RecordBatch to Arrow IPC format
"""
# Serialize result_batch to binary
sink = pa.BufferOutputStream()
writer = pa.ipc.new_stream(sink, result_batch.schema)
writer.write_batch(result_batch)
writer.close()
result_buffer = sink.getvalue()
# Convert pyarrow.Buffer to bytes
result_binary = result_buffer.to_pybytes()
# Create unified response with single binary column
return pa.RecordBatch.from_arrays(
[pa.array([result_binary], type=pa.binary())], ["result_data"]
)
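    # How a caller could unwrap the unified response above (a sketch, not part of
    # the server; assumes `response_batch` is one batch read from the stream):
    #
    #     result_binary = response_batch.column(0)[0].as_py()
    #     inner_reader = pa.ipc.open_stream(result_binary)
    #     operation_result = inner_reader.read_next_batch()  # operation-specific schema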
def _handle_udaf_create(
self, place_id: int, state_manager: UDAFStateManager
) -> pa.RecordBatch:
"""Handle UDAF CREATE operation.
Returns: [success: bool]
"""
try:
state_manager.create_state(place_id)
success = True
except Exception as e:
logging.error("CREATE operation failed for place_id=%s: %s", place_id, e)
success = False
return pa.RecordBatch.from_arrays(
[pa.array([success], type=pa.bool_())], ["success"]
)
def _handle_udaf_accumulate(
self,
place_id: int,
metadata_binary: bytes,
data_binary: bytes,
state_manager: UDAFStateManager,
) -> pa.RecordBatch:
"""Handle UDAF ACCUMULATE operation.
Deserializes metadata and data from binary buffers.
Metadata contains: [is_single_place: bool, row_start: int64, row_end: int64, place_offset: int64]
Data contains: input columns (+ optional places array for GROUP BY)
Returns: [rows_processed: int64] (0 if failed)
"""
rows_processed = 0
try:
# Validate inputs
if metadata_binary is None or data_binary is None:
raise ValueError(
f"ACCUMULATE requires both metadata and data, got metadata={metadata_binary is not None}, data={data_binary is not None}"
)
logging.info(
"ACCUMULATE: Starting deserialization, metadata size=%d, data size=%d",
len(metadata_binary),
len(data_binary),
)
# Deserialize metadata
metadata_reader = pa.ipc.open_stream(metadata_binary)
metadata_batch = metadata_reader.read_next_batch()
is_single_place = metadata_batch.column(0)[0].as_py()
row_start = metadata_batch.column(1)[0].as_py()
row_end = metadata_batch.column(2)[0].as_py()
# place_offset = metadata_batch.column(3)[0].as_py() # Not used currently
# Deserialize data (input columns + optional places)
data_reader = pa.ipc.open_stream(data_binary)
data_batch = data_reader.read_next_batch()
logging.info(
"ACCUMULATE: place_id=%s, is_single_place=%s, row_start=%s, row_end=%s, data_rows=%s, data_cols=%s",
place_id,
is_single_place,
row_start,
row_end,
data_batch.num_rows,
data_batch.num_columns,
)
# Check if there's a places column
has_places = (
data_batch.schema.field(data_batch.num_columns - 1).name == "places"
)
num_input_cols = (
data_batch.num_columns - 1 if has_places else data_batch.num_columns
)
logging.info(
"ACCUMULATE: has_places=%s, num_input_cols=%s",
has_places,
num_input_cols,
)
# Calculate loop range
loop_start = row_start
loop_end = min(row_end, data_batch.num_rows)
logging.info(
"ACCUMULATE: Loop range [%d, %d), will process %d rows",
loop_start,
loop_end,
loop_end - loop_start,
)
if is_single_place:
# Single place: accumulate all rows to the same state
for i in range(loop_start, loop_end):
args = [
data_batch.column(j)[i].as_py() for j in range(num_input_cols)
]
state_manager.accumulate(place_id, *args)
rows_processed += 1
else:
# Multiple places: get place_ids from the last column
places_col = data_batch.column(data_batch.num_columns - 1)
logging.info(
"ACCUMULATE: Multiple places mode, places column data: %s",
[places_col[i].as_py() for i in range(data_batch.num_rows)],
)
for i in range(loop_start, loop_end):
row_place_id = places_col[i].as_py()
args = [
data_batch.column(j)[i].as_py() for j in range(num_input_cols)
]
logging.info(
"ACCUMULATE: Processing row %d, row_place_id=%s, args=%s",
i,
row_place_id,
args,
)
state_manager.accumulate(row_place_id, *args)
rows_processed += 1
logging.info(
"ACCUMULATE: Completed successfully, rows_processed=%d", rows_processed
)
except Exception as e:
logging.error(
"ACCUMULATE operation failed at row %d: %s\nTraceback: %s",
rows_processed,
e,
traceback.format_exc(),
)
return pa.RecordBatch.from_arrays(
[pa.array([rows_processed], type=pa.int64())], ["rows_processed"]
)
def _handle_udaf_serialize(
self, place_id: int, state_manager: UDAFStateManager
) -> pa.RecordBatch:
"""Handle UDAF SERIALIZE operation.
Returns: [serialized_state: binary] (empty if failed)
"""
try:
serialized = state_manager.serialize(place_id)
except Exception as e:
logging.error("SERIALIZE operation failed for place_id=%s: %s", place_id, e)
serialized = b"" # Return empty binary on failure
return pa.RecordBatch.from_arrays(
[pa.array([serialized], type=pa.binary())], ["serialized_state"]
)
def _handle_udaf_merge(
self, place_id: int, data_binary: bytes, state_manager: UDAFStateManager
) -> pa.RecordBatch:
"""Handle UDAF MERGE operation.
data_binary contains the serialized state to merge.
Returns: [success: bool]
"""
try:
# Validate input
if data_binary is None:
raise ValueError(f"MERGE requires data_binary, got None")
state_manager.merge(place_id, data_binary)
success = True
except Exception as e:
logging.error("MERGE operation failed for place_id=%s: %s", place_id, e)
success = False
return pa.RecordBatch.from_arrays(
[pa.array([success], type=pa.bool_())], ["success"]
)
def _handle_udaf_finalize(
self, place_id: int, output_type: pa.DataType, state_manager: UDAFStateManager
) -> pa.RecordBatch:
"""Handle UDAF FINALIZE operation.
Returns: [result: output_type] (null if failed)
"""
try:
result = state_manager.finalize(place_id)
except Exception as e:
logging.error("FINALIZE operation failed for place_id=%s: %s", place_id, e)
result = None # Return null on failure
return pa.RecordBatch.from_arrays(
[pa.array([result], type=output_type)], ["result"]
)
def _handle_udaf_reset(
self, place_id: int, state_manager: UDAFStateManager
) -> pa.RecordBatch:
"""Handle UDAF RESET operation.
Returns: [success: bool]
"""
try:
state_manager.reset(place_id)
success = True
except Exception as e:
logging.error("RESET operation failed for place_id=%s: %s", place_id, e)
success = False
return pa.RecordBatch.from_arrays(
[pa.array([success], type=pa.bool_())], ["success"]
)
def _handle_udaf_destroy(
self, place_id: int, state_manager: UDAFStateManager
) -> pa.RecordBatch:
"""Handle UDAF DESTROY operation.
Returns: [success: bool]
"""
try:
state_manager.destroy(place_id)
success = True
except Exception as e:
logging.error("DESTROY operation failed for place_id=%s: %s", place_id, e)
success = False
return pa.RecordBatch.from_arrays(
[pa.array([success], type=pa.bool_())], ["success"]
)
def _handle_exchange_udf(
self,
python_udf_meta: PythonUDFMeta,
reader: flight.MetadataRecordBatchReader,
writer: flight.MetadataRecordBatchWriter,
) -> None:
"""Handle bidirectional streaming for UDF execution."""
loader = UDFLoaderFactory.get_loader(python_udf_meta)
udf = loader.load()
logging.info("Loaded UDF: %s", udf)
started = False
for chunk in reader:
if not chunk.data:
logging.info("Empty chunk received, skipping")
continue
check_schema_result, error_msg = self.check_schema(
chunk.data, python_udf_meta.input_types
)
if not check_schema_result:
logging.error("Schema mismatch: %s", error_msg)
raise ValueError(f"Schema mismatch: {error_msg}")
result_array = udf(chunk.data)
if not python_udf_meta.output_type.equals(result_array.type):
logging.error(
"Output type mismatch: got %s, expected %s",
result_array.type,
python_udf_meta.output_type,
)
raise ValueError(
f"Output type mismatch: got {result_array.type}, expected {python_udf_meta.output_type}"
)
result_batch = pa.RecordBatch.from_arrays([result_array], ["result"])
if not started:
writer.begin(result_batch.schema)
started = True
writer.write_batch(result_batch)
def _handle_exchange_udaf(
self,
python_udaf_meta: PythonUDFMeta,
reader: flight.MetadataRecordBatchReader,
writer: flight.MetadataRecordBatchWriter,
) -> None:
"""
Handle bidirectional streaming for UDAF execution.
Request Schema (unified for ALL operations):
- operation: int8 - UDAFOperationType enum value
- place_id: int64 - Unique identifier for the aggregate state
- metadata: binary - Serialized metadata (operation-specific)
- data: binary - Serialized data (operation-specific)
Response Schema (unified for ALL operations):
- result_data: binary - Serialized result RecordBatch in Arrow IPC format
"""
# Get or create state manager for this specific UDAF function
state_manager = self._get_udaf_state_manager(python_udaf_meta)
# Define unified response schema (used for all responses)
unified_response_schema = pa.schema([pa.field("result_data", pa.binary())])
started = False
for chunk in reader:
if not chunk.data or chunk.data.num_rows == 0:
logging.warning("Empty chunk received, skipping")
continue
batch = chunk.data
# Validate unified request schema
if batch.num_columns != 4:
raise ValueError(
f"Expected 4 columns in unified schema, got {batch.num_columns}"
)
# Extract metadata from unified schema
operation_type = UDAFOperationType(batch.column(0)[0].as_py())
place_id = batch.column(1)[0].as_py()
# Check if metadata/data columns are null
metadata_col = batch.column(2)
data_col = batch.column(3)
metadata_binary = (
metadata_col[0].as_py() if metadata_col[0].is_valid else None
)
data_binary = data_col[0].as_py() if data_col[0].is_valid else None
logging.info(
"Processing UDAF operation: %s, place_id: %s",
operation_type.name,
place_id,
)
# Handle different operations - get operation-specific result
if operation_type == UDAFOperationType.CREATE:
result_batch = self._handle_udaf_create(place_id, state_manager)
elif operation_type == UDAFOperationType.ACCUMULATE:
result_batch = self._handle_udaf_accumulate(
place_id, metadata_binary, data_binary, state_manager
)
elif operation_type == UDAFOperationType.SERIALIZE:
result_batch = self._handle_udaf_serialize(place_id, state_manager)
elif operation_type == UDAFOperationType.MERGE:
result_batch = self._handle_udaf_merge(
place_id, data_binary, state_manager
)
elif operation_type == UDAFOperationType.FINALIZE:
result_batch = self._handle_udaf_finalize(
place_id, python_udaf_meta.output_type, state_manager
)
elif operation_type == UDAFOperationType.RESET:
result_batch = self._handle_udaf_reset(place_id, state_manager)
elif operation_type == UDAFOperationType.DESTROY:
result_batch = self._handle_udaf_destroy(place_id, state_manager)
else:
raise ValueError(f"Unsupported operation type: {operation_type}")
# Convert to unified response format
unified_response = self._create_unified_response(result_batch)
# Write result back - begin only once with unified response schema
if not started:
writer.begin(unified_response_schema)
started = True
logging.info("Initialized UDAF response stream with unified schema")
writer.write_batch(unified_response)
def _handle_exchange_udtf(
self,
python_udtf_meta: PythonUDFMeta,
reader: flight.MetadataRecordBatchReader,
writer: flight.MetadataRecordBatchWriter,
) -> None:
"""
Handle bidirectional streaming for UDTF execution.
Protocol (ListArray-based):
- Input: RecordBatch with input columns
- Output: RecordBatch with a single ListArray column
* ListArray automatically manages offsets internally
* Each list element contains the outputs for one input row
Example:
Input: 3 rows
UDTF yields: Row 0 -> 5 outputs, Row 1 -> 2 outputs, Row 2 -> 3 outputs
Output: ListArray with 3 elements (one per input row)
- Element 0: List of 5 structs
- Element 1: List of 2 structs
- Element 2: List of 3 structs
"""
loader = UDFLoaderFactory.get_loader(python_udtf_meta)
adaptive_udtf = loader.load()
udtf_func = adaptive_udtf._eval_func
started = False
for chunk in reader:
if not chunk.data:
logging.info("Empty chunk received, skipping")
continue
input_batch = chunk.data
# Validate input schema
check_schema_result, error_msg = self.check_schema(
input_batch, python_udtf_meta.input_types
)
if not check_schema_result:
logging.error("Schema mismatch: %s", error_msg)
raise ValueError(f"Schema mismatch: {error_msg}")
# Process all input rows and build ListArray
try:
response_batch = self._process_udtf_with_list_array(
udtf_func, input_batch, python_udtf_meta.output_type
)
# Send the response batch
if not started:
writer.begin(response_batch.schema)
started = True
writer.write_batch(response_batch)
except Exception as e:
logging.error(
"Error in UDTF execution: %s\nTraceback: %s",
e,
traceback.format_exc(),
)
raise RuntimeError(f"Error in UDTF execution: {e}") from e
def _process_udtf_with_list_array(
self,
udtf_func: Callable,
input_batch: pa.RecordBatch,
expected_output_type: pa.DataType,
) -> pa.RecordBatch:
"""
Process UDTF function on all input rows and generate a ListArray.
Args:
udtf_func: The UDTF function to call
input_batch: Input RecordBatch with N rows
expected_output_type: Expected Arrow type for output data
Returns:
RecordBatch with a single ListArray column where each element
is a list of outputs for the corresponding input row
"""
all_results = [] # List of lists: one list per input row
# Check if output is single-field or multi-field
# For single-field output, we allow yielding scalar values directly
is_single_field = not pa.types.is_struct(expected_output_type)
# Process each input row
for row_idx in range(input_batch.num_rows):
# Extract row as tuple of arguments
row_args = tuple(
input_batch.column(col_idx)[row_idx].as_py()
for col_idx in range(input_batch.num_columns)
)
# Call UDTF function - it can yield tuples or scalar values (for single-field output)
result = udtf_func(*row_args)
# Collect output rows for this input row
row_outputs = []
if inspect.isgenerator(result):
for output_value in result:
if is_single_field:
# Single-field output: accept both scalar and tuple
if isinstance(output_value, tuple):
# User provided tuple (e.g., (value,)) - extract scalar
if len(output_value) != 1:
raise ValueError(
f"Single-field UDTF should yield 1-tuples or scalars, got {len(output_value)}-tuple"
)
row_outputs.append(output_value[0]) # Extract scalar from tuple
else:
# User provided scalar - use directly
row_outputs.append(output_value)
else:
# Multi-field output: must be tuple
if not isinstance(output_value, tuple):
raise ValueError(
f"Multi-field UDTF must yield tuples, got {type(output_value)}"
)
row_outputs.append(output_value)
elif result is not None:
# Function returned a single value instead of yielding
if is_single_field:
# Single-field: accept scalar or tuple
if isinstance(result, tuple):
if len(result) != 1:
raise ValueError(
f"Single-field UDTF should return 1-tuple or scalar, got {len(result)}-tuple"
)
row_outputs.append(result[0]) # Extract scalar from tuple
else:
row_outputs.append(result)
else:
# Multi-field: must be tuple
if not isinstance(result, tuple):
raise ValueError(
f"Multi-field UDTF must return tuples, got {type(result)}"
)
row_outputs.append(result)
all_results.append(row_outputs)
try:
list_array = pa.array(all_results, type=pa.list_(expected_output_type))
except Exception as e:
logging.error(
"Failed to create ListArray: %s, element_type: %s", e, expected_output_type
)
raise RuntimeError(f"Failed to create ListArray: {e}") from e
# Create RecordBatch with single ListArray column
schema = pa.schema([pa.field("results", pa.list_(expected_output_type))])
response_batch = pa.RecordBatch.from_arrays([list_array], schema=schema)
return response_batch
def do_exchange(
self,
context: flight.ServerCallContext,
descriptor: flight.FlightDescriptor,
reader: flight.MetadataRecordBatchReader,
writer: flight.MetadataRecordBatchWriter,
) -> None:
"""
Handle bidirectional streaming for UDF, UDAF, and UDTF execution.
Determines operation type (UDF vs UDAF vs UDTF) from descriptor metadata.
"""
python_udf_meta = self.parse_python_udf_meta(descriptor)
if not python_udf_meta:
raise ValueError("Invalid or missing metadata in descriptor")
if python_udf_meta.is_udf():
self._handle_exchange_udf(python_udf_meta, reader, writer)
elif python_udf_meta.is_udaf():
self._handle_exchange_udaf(python_udf_meta, reader, writer)
elif python_udf_meta.is_udtf():
self._handle_exchange_udtf(python_udf_meta, reader, writer)
else:
raise ValueError(f"Unsupported client type: {python_udf_meta.client_type}")
class UDAFOperationType(Enum):
"""Enum representing UDAF operation types."""
CREATE = 0
ACCUMULATE = 1
SERIALIZE = 2
MERGE = 3
FINALIZE = 4
RESET = 5
DESTROY = 6
def check_unix_socket_path(unix_socket_path: str) -> bool:
"""Validates the Unix domain socket path format."""
if not unix_socket_path:
logging.error("Unix socket path is empty")
return False
if not unix_socket_path.startswith("grpc+unix://"):
raise ValueError("gRPC UDS URL must start with 'grpc+unix://'")
socket_path = unix_socket_path[len("grpc+unix://") :].strip()
if not socket_path:
logging.error("Extracted socket path is empty")
return False
return True
def main(unix_socket_path: str) -> None:
"""
Main entry point for the Python UDF/UDAF/UDTF server.
The server handles UDF, UDAF, and UDTF operations dynamically.
Operation type is determined from metadata in each request.
Args:
unix_socket_path: Base path for the Unix domain socket
Raises:
SystemExit: If socket path is invalid or server fails to start
"""
try:
if not check_unix_socket_path(unix_socket_path):
print(f"ERROR: Invalid socket path: {unix_socket_path}", flush=True)
sys.exit(1)
current_pid = os.getpid()
ServerState.unix_socket_path = f"{unix_socket_path}_{current_pid}.sock"
# Start unified server that handles UDF, UDAF, and UDTF
server = FlightServer(ServerState.unix_socket_path)
print(ServerState.PYTHON_SERVER_START_SUCCESS_MSG, flush=True)
logging.info(
"##### PYTHON UDF/UDAF/UDTF SERVER STARTED AT %s #####", datetime.now()
)
server.wait()
except Exception as e:
print(
f"ERROR: Failed to start Python server: {type(e).__name__}: {e}",
flush=True,
)
tb_lines = traceback.format_exception(type(e), e, e.__traceback__)
if len(tb_lines) > 1:
print(f"DETAIL: {tb_lines[-2].strip()}", flush=True)
sys.exit(1)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Run an Arrow Flight UDF/UDAF/UDTF server over Unix socket. "
"The server handles UDF, UDAF, and UDTF operations dynamically."
)
parser.add_argument(
"unix_socket_path",
type=str,
help="Path to the Unix socket (e.g., grpc+unix:///path/to/socket)",
)
args = parser.parse_args()
main(args.unix_socket_path)