# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import argparse
import base64
import importlib
import inspect
import json
import logging
import os
import sys
import threading
import time
import traceback
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Callable, Optional, Tuple, get_origin
from datetime import datetime
from enum import Enum
from pathlib import Path
import pandas as pd
import pyarrow as pa
from pyarrow import flight
class ServerState:
"""Global server state container."""
unix_socket_path: str = ""
@staticmethod
def setup_logging():
"""Setup logging configuration for the UDF server."""
doris_home = os.getenv("DORIS_HOME")
if not doris_home:
# Fallback to current directory if DORIS_HOME is not set
doris_home = os.getcwd()
log_dir = os.path.join(doris_home, "lib", "udf", "python")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, "python_udf_output.log")
logging.basicConfig(
level=logging.INFO,
format="[%(asctime)s] [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
handlers=[
logging.FileHandler(log_file, mode="a", encoding="utf-8"),
logging.StreamHandler(sys.stderr), # Also log to stderr for debugging
],
)
logging.info("Logging initialized. Log file: %s", log_file)
@staticmethod
def extract_base_unix_socket_path(unix_socket_uri: str) -> str:
"""
Extract the file system path from a gRPC Unix socket URI.
Args:
unix_socket_uri: URI in format 'grpc+unix:///path/to/socket'
Returns:
The file system path without the protocol prefix
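        Example (illustrative path):
            >>> ServerState.extract_base_unix_socket_path("grpc+unix:///tmp/udf.sock")
            '/tmp/udf.sock'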
"""
if unix_socket_uri.startswith("grpc+unix://"):
unix_socket_uri = unix_socket_uri[len("grpc+unix://") :]
return unix_socket_uri
@staticmethod
def remove_unix_socket(unix_socket_uri: str) -> None:
"""
Remove the Unix domain socket file if it exists.
Args:
unix_socket_uri: URI of the Unix socket to remove
"""
if unix_socket_uri is None:
return
base_unix_socket_path = ServerState.extract_base_unix_socket_path(
unix_socket_uri
)
if os.path.exists(base_unix_socket_path):
try:
os.unlink(base_unix_socket_path)
logging.info(
"Removed UNIX socket %s successfully", base_unix_socket_path
)
except OSError as e:
logging.error(
"Failed to remove UNIX socket %s: %s", base_unix_socket_path, e
)
else:
logging.warning("UNIX socket %s does not exist", base_unix_socket_path)
@staticmethod
def monitor_parent_exit():
"""
Monitor the parent process and exit gracefully if it dies.
This prevents orphaned UDF server processes.
"""
parent_pid = os.getppid()
if parent_pid == 1:
# Parent process is init, no need to monitor
logging.info("Parent process is init (PID 1), skipping parent monitoring")
return
logging.info("Started monitoring parent process (PID: %s)", parent_pid)
while True:
try:
# os.kill(pid, 0) only checks whether the process exists
# without sending an actual signal
os.kill(parent_pid, 0)
except OSError:
# Parent process died
ServerState.remove_unix_socket(ServerState.unix_socket_path)
logging.error(
"Parent process %s died, exiting UDF server, unix socket path: %s",
parent_pid,
ServerState.unix_socket_path,
)
os._exit(0)
# Check every 2 seconds
time.sleep(2)
ServerState.setup_logging()
monitor_thread = threading.Thread(target=ServerState.monitor_parent_exit, daemon=True)
monitor_thread.start()
@contextmanager
def temporary_sys_path(path: str):
"""
Context manager to temporarily add a path to sys.path.
Ensures the path is removed after use to avoid pollution.
Args:
path: Directory path to add to sys.path
Yields:
None
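    Example (hypothetical path and module name):
        with temporary_sys_path("/opt/doris/udfs"):
            mod = importlib.import_module("my_udf")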
"""
path_added = False
if path not in sys.path:
sys.path.insert(0, path)
path_added = True
logging.debug("Temporarily added to sys.path: %s", path)
try:
yield
finally:
if path_added and path in sys.path:
sys.path.remove(path)
logging.debug("Removed from sys.path: %s", path)
class VectorType(Enum):
"""Enum representing supported vector types."""
LIST = "list"
PANDAS_SERIES = "pandas.Series"
ARROW_ARRAY = "pyarrow.Array"
@property
def python_type(self):
"""
Returns the Python type corresponding to this VectorType.
Returns:
The Python type class (list, pd.Series, or pa.Array)
"""
mapping = {
VectorType.LIST: list,
VectorType.PANDAS_SERIES: pd.Series,
VectorType.ARROW_ARRAY: pa.Array,
}
return mapping[self]
@staticmethod
def resolve_vector_type(param: inspect.Parameter):
"""
Resolves the param's type annotation to the corresponding VectorType enum.
Returns None if the type is unsupported or not a vector type.
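        Example:
            >>> p = inspect.Parameter(
            ...     "x", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=pd.Series
            ... )
            >>> VectorType.resolve_vector_type(p)
            <VectorType.PANDAS_SERIES: 'pandas.Series'>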
"""
if (
param is None
or param.annotation is None
or param.annotation is inspect.Parameter.empty
):
return None
annotation = param.annotation
origin = get_origin(annotation)
raw_type = origin if origin is not None else annotation
        if raw_type is list:
            return VectorType.LIST
        if raw_type is pd.Series:
            return VectorType.PANDAS_SERIES
        if raw_type is pa.Array:
            return VectorType.ARROW_ARRAY
        return None
class PythonUDFMeta:
"""Metadata container for a Python UDF."""
def __init__(
self,
name: str,
symbol: str,
location: str,
udf_load_type: int,
runtime_version: str,
always_nullable: bool,
inline_code: bytes,
input_types: pa.Schema,
output_type: pa.DataType,
) -> None:
"""
Initialize Python UDF metadata.
Args:
name: UDF function name
symbol: Symbol to load (function name or module.function)
location: File path or directory containing the UDF
udf_load_type: 0 for inline code, 1 for module
runtime_version: Python runtime version requirement
always_nullable: Whether the UDF can return NULL values
inline_code: Base64-encoded inline Python code (if applicable)
input_types: PyArrow schema for input parameters
output_type: PyArrow data type for return value
"""
self.name = name
self.symbol = symbol
self.location = location
self.udf_load_type = udf_load_type
self.runtime_version = runtime_version
self.always_nullable = always_nullable
self.inline_code = inline_code
self.input_types = input_types
self.output_type = output_type
def __str__(self) -> str:
"""Returns a string representation of the UDF metadata."""
udf_load_type_str = "INLINE" if self.udf_load_type == 0 else "MODULE"
return (
f"PythonUDFMeta(name={self.name}, symbol={self.symbol}, "
f"location={self.location}, udf_load_type={udf_load_type_str}, runtime_version={self.runtime_version}, "
f"always_nullable={self.always_nullable}, inline_code={self.inline_code}, "
f"input_types={self.input_types}, output_type={self.output_type})"
)
class AdaptivePythonUDF:
"""
A wrapper around a UDF function that supports both scalar and vectorized execution modes.
The mode is determined by the type hints of the function parameters.
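    Example (illustrative user UDFs):
        def add_scalar(a: int, b: int) -> int:  # no vector hints -> scalar mode
            return a + b
        def add_vector(a: pd.Series, b: pd.Series) -> pd.Series:  # -> vectorized mode
            return a + b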
"""
def __init__(self, python_udf_meta: PythonUDFMeta, func: Callable) -> None:
"""
Initialize the adaptive UDF wrapper.
Args:
python_udf_meta: Metadata describing the UDF
func: The actual Python function to execute
"""
self.python_udf_meta = python_udf_meta
self._eval_func = func
def __str__(self) -> str:
"""Returns a string representation of the UDF wrapper."""
input_type_strs = [str(t) for t in self.python_udf_meta.input_types.types]
output_type_str = str(self.python_udf_meta.output_type)
eval_func_str = f"{self.python_udf_meta.name}({', '.join(input_type_strs)}) -> {output_type_str}"
return f"AdaptivePythonUDF(python_udf_meta: {self.python_udf_meta}, eval_func: {eval_func_str})"
def __call__(self, record_batch: pa.RecordBatch) -> pa.Array:
"""
Executes the UDF on the given record batch. Supports both scalar and vectorized modes.
        Args:
            record_batch: Input data with N columns, each of length num_rows
        Returns:
            Output array of length num_rows
"""
if record_batch.num_rows == 0:
return pa.array([], type=self._get_output_type())
if self._should_use_vectorized():
logging.info("Using vectorized mode for UDF: %s", self.python_udf_meta.name)
return self._vectorized_call(record_batch)
logging.info("Using scalar mode for UDF: %s", self.python_udf_meta.name)
return self._scalar_call(record_batch)
@staticmethod
def _cast_arrow_to_vector(arrow_array: pa.Array, vec_type: VectorType):
"""
Convert a pa.Array to an instance of the specified VectorType.
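        Example:
            >>> AdaptivePythonUDF._cast_arrow_to_vector(pa.array([1, 2]), VectorType.LIST)
            [1, 2]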
"""
        if vec_type == VectorType.LIST:
            return arrow_array.to_pylist()
        elif vec_type == VectorType.PANDAS_SERIES:
            return arrow_array.to_pandas()
        elif vec_type == VectorType.ARROW_ARRAY:
            return arrow_array
        else:
            raise ValueError(f"Unsupported vector type: {vec_type}")
def _should_use_vectorized(self) -> bool:
"""
Determines whether to use vectorized mode based on parameter type annotations.
        Returns True if any parameter is annotated as:
        - list
        - pd.Series
        - pa.Array
"""
try:
signature = inspect.signature(self._eval_func)
        except (ValueError, TypeError):
# Cannot inspect built-in or C functions; default to scalar
return False
for param in signature.parameters.values():
if VectorType.resolve_vector_type(param):
return True
return False
    def _convert_from_arrow_to_py(self, field):
        """Convert a single pa.Scalar element to its native Python value."""
if field is None:
return None
if pa.types.is_map(field.type):
# pyarrow.lib.MapScalar's as_py() returns a list of tuples, convert to dict
list_of_tuples = field.as_py()
return dict(list_of_tuples) if list_of_tuples is not None else None
return field.as_py()
def _scalar_call(self, record_batch: pa.RecordBatch) -> pa.Array:
"""
Applies the UDF in scalar mode: one row at a time.
Args:
record_batch: Input data batch
Returns:
Output array with results for each row
"""
columns = record_batch.columns
num_rows = record_batch.num_rows
result = []
for i in range(num_rows):
converted_args = [self._convert_from_arrow_to_py(col[i]) for col in columns]
try:
res = self._eval_func(*converted_args)
# Check if result is None when always_nullable is False
if res is None and not self.python_udf_meta.always_nullable:
raise RuntimeError(
f"the result of row {i} is null, but the return type is not nullable, "
f"please check the always_nullable property in create function statement, "
f"it should be true"
)
result.append(res)
except Exception as e:
logging.error(
"Error in scalar UDF execution at row %s: %s\nArgs: %s\nTraceback: %s",
i,
e,
converted_args,
traceback.format_exc(),
)
# Return None for failed rows if always_nullable is True
if self.python_udf_meta.always_nullable:
result.append(None)
else:
raise
return pa.array(result, type=self._get_output_type(), from_pandas=True)
def _vectorized_call(self, record_batch: pa.RecordBatch) -> pa.Array:
"""
Applies the UDF in vectorized mode: processes entire columns at once.
Args:
record_batch: Input data batch
Returns:
Output array with results
"""
column_args = record_batch.columns
logging.info("Vectorized call with %s columns", len(column_args))
sig = inspect.signature(self._eval_func)
params = list(sig.parameters.values())
if len(column_args) != len(params):
raise ValueError(f"UDF expects {len(params)} args, got {len(column_args)}")
converted_args = []
for param, arrow_col in zip(params, column_args):
vec_type = VectorType.resolve_vector_type(param)
if vec_type is None:
# For scalar types (int, float, str, etc.), extract the first value
# instead of converting to list
pylist = arrow_col.to_pylist()
if len(pylist) > 0:
converted = pylist[0]
logging.info(
"Converted %s to scalar (first value): %s",
param.name,
type(converted).__name__,
)
else:
converted = None
logging.info(
"Converted %s to scalar (None, empty column)", param.name
)
else:
converted = self._cast_arrow_to_vector(arrow_col, vec_type)
logging.info("Converted %s: %s", param.name, vec_type)
converted_args.append(converted)
try:
result = self._eval_func(*converted_args)
except Exception as e:
logging.error(
"Error in vectorized UDF: %s\nTraceback: %s", e, traceback.format_exc()
)
raise RuntimeError(f"Error in vectorized UDF: {e}") from e
# Convert result to PyArrow Array
result_array = None
if isinstance(result, pa.Array):
result_array = result
elif isinstance(result, pa.ChunkedArray):
# Combine chunks into a single array
result_array = pa.concat_arrays(result.chunks)
elif isinstance(result, pd.Series):
result_array = pa.array(result, type=self._get_output_type())
elif isinstance(result, list):
result_array = pa.array(
result, type=self._get_output_type(), from_pandas=True
)
else:
# Scalar result - broadcast to all rows
out_type = self._get_output_type()
logging.warning(
"UDF returned scalar value, broadcasting to %s rows",
record_batch.num_rows,
)
result_array = pa.array([result] * record_batch.num_rows, type=out_type)
# Check for None values when always_nullable is False
if not self.python_udf_meta.always_nullable:
null_count = result_array.null_count
if null_count > 0:
# Find the first null index for error message
                for i, value in enumerate(result_array):
                    if not value.is_valid:
                        raise RuntimeError(
                            f"the result of row {i} is null, but the return type is not nullable; "
                            f"check the always_nullable property in the CREATE FUNCTION statement, "
                            f"it should be true"
)
return result_array
def _get_output_type(self) -> pa.DataType:
"""
Returns the expected output type for the UDF.
Returns:
PyArrow DataType for the output
"""
return self.python_udf_meta.output_type or pa.null()
class UDFLoader(ABC):
"""Abstract base class for loading UDFs from different sources."""
def __init__(self, python_udf_meta: PythonUDFMeta) -> None:
"""
Initialize the UDF loader.
Args:
python_udf_meta: Metadata describing the UDF to load
"""
self.python_udf_meta = python_udf_meta
@abstractmethod
def load(self) -> AdaptivePythonUDF:
"""Load the UDF and return an AdaptivePythonUDF wrapper."""
raise NotImplementedError("Subclasses must implement load().")
class InlineUDFLoader(UDFLoader):
"""Loads a UDF defined directly in inline code."""
def load(self) -> AdaptivePythonUDF:
"""
Load and execute inline Python code to extract the UDF function.
Returns:
AdaptivePythonUDF wrapper around the loaded function
Raises:
RuntimeError: If code execution fails
ValueError: If the function is not found or not callable
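        Example (illustrative metadata):
            If inline_code decodes to "def add(a, b): return a + b" and symbol
            is "add", load() returns an AdaptivePythonUDF wrapping add().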
"""
symbol = self.python_udf_meta.symbol
inline_code = self.python_udf_meta.inline_code.decode("utf-8")
env: dict[str, Any] = {}
logging.info("Loading inline code for function '%s'", symbol)
logging.debug("Inline code:\n%s", inline_code)
try:
# Execute the code in a clean environment
# pylint: disable=exec-used
# Note: exec() is necessary here for dynamic UDF loading from inline code
exec(inline_code, env) # nosec B102
except Exception as e:
logging.error(
"Failed to exec inline code: %s\nTraceback: %s",
e,
traceback.format_exc(),
)
raise RuntimeError(f"Failed to exec inline code: {e}") from e
func = env.get(symbol)
if func is None:
available_funcs = [
k for k, v in env.items() if callable(v) and not k.startswith("_")
]
logging.error(
"Function '%s' not found in inline code. Available functions: %s",
symbol,
available_funcs,
)
raise ValueError(f"Function '{symbol}' not found in inline code.")
if not callable(func):
logging.error(
"'%s' exists but is not callable (type: %s)", symbol, type(func)
)
raise ValueError(f"'{symbol}' is not a callable function.")
logging.info("Successfully loaded function '%s' from inline code", symbol)
return AdaptivePythonUDF(self.python_udf_meta, func)
class ModuleUDFLoader(UDFLoader):
"""Loads a UDF from a Python module file (.py)."""
def load(self) -> AdaptivePythonUDF:
"""
Loads a UDF from a Python module file.
Returns:
AdaptivePythonUDF instance wrapping the loaded function
Raises:
ValueError: If module file not found
TypeError: If symbol is not callable
"""
symbol = self.python_udf_meta.symbol # [package_name.]module_name.function_name
location = self.python_udf_meta.location # /path/to/module_name[.py]
if not os.path.exists(location):
raise ValueError(f"Module file not found: {location}")
package_name, module_name, func_name = self.parse_symbol(symbol)
func = self.load_udf_from_module(location, package_name, module_name, func_name)
if not callable(func):
raise TypeError(
f"'{symbol}' exists but is not callable (type: {type(func).__name__})"
)
logging.info(
"Successfully loaded function '%s' from module: %s", symbol, location
)
return AdaptivePythonUDF(self.python_udf_meta, func)
def parse_symbol(self, symbol: str):
"""
Parse symbol into (package_name, module_name, func_name)
        Supported formats:
        - "module.func" → (None, "module", "func")
        - "package.module.func" → ("package", "module", "func")
        - "package.sub.module.func" → ("package", "sub.module", "func")
"""
if not symbol or "." not in symbol:
raise ValueError(
f"Invalid symbol format: '{symbol}'. "
"Expected 'module.function' or 'package.module.function'"
)
parts = symbol.split(".")
if len(parts) == 2:
# module.func → Single-file mode
module_name, func_name = parts
package_name = None
if not module_name or not module_name.strip():
raise ValueError(f"Module name is empty in symbol: '{symbol}'")
if not func_name or not func_name.strip():
raise ValueError(f"Function name is empty in symbol: '{symbol}'")
elif len(parts) > 2:
package_name = parts[0]
module_name = ".".join(parts[1:-1])
func_name = parts[-1]
if not package_name or not package_name.strip():
raise ValueError(f"Package name is empty in symbol: '{symbol}'")
if not module_name or not module_name.strip():
raise ValueError(f"Module name is empty in symbol: '{symbol}'")
if not func_name or not func_name.strip():
raise ValueError(f"Function name is empty in symbol: '{symbol}'")
else:
raise ValueError(f"Invalid symbol format: '{symbol}'")
logging.debug(
"Parsed symbol: package=%s, module=%s, func=%s",
package_name,
module_name,
func_name,
)
return package_name, module_name, func_name
def _validate_location(self, location: str) -> None:
"""Validate that the location is a valid directory."""
if not os.path.isdir(location):
raise ValueError(f"Location is not a directory: {location}")
def _get_or_import_module(self, location: str, full_module_name: str) -> Any:
"""Get module from cache or import it."""
if full_module_name in sys.modules:
logging.warning(
"Module '%s' already loaded, using cached version", full_module_name
)
return sys.modules[full_module_name]
with temporary_sys_path(location):
return importlib.import_module(full_module_name)
def _extract_function(
self, module: Any, func_name: str, module_name: str
) -> Callable:
"""Extract and validate function from module."""
func = getattr(module, func_name, None)
if func is None:
raise AttributeError(
f"Function '{func_name}' not found in module '{module_name}'"
)
if not callable(func):
raise TypeError(f"'{func_name}' is not callable")
return func
def _load_single_file_udf(
self, location: str, module_name: str, func_name: str
) -> Callable:
"""Load UDF from a single Python file."""
py_file = os.path.join(location, f"{module_name}.py")
if not os.path.isfile(py_file):
raise ImportError(f"Python file not found: {py_file}")
try:
udf_module = self._get_or_import_module(location, module_name)
return self._extract_function(udf_module, func_name, module_name)
except (ImportError, AttributeError, TypeError) as e:
raise ImportError(
f"Failed to load single-file UDF '{module_name}.{func_name}': {e}"
) from e
except Exception as e:
logging.error(
"Unexpected error loading UDF: %s\n%s", e, traceback.format_exc()
)
raise
def _ensure_package_init(self, package_path: str, package_name: str) -> None:
"""Ensure __init__.py exists in the package directory."""
init_path = os.path.join(package_path, "__init__.py")
if not os.path.exists(init_path):
logging.warning(
"__init__.py not found in package '%s', attempting to create it",
package_name,
)
try:
with open(init_path, "w", encoding="utf-8") as f:
f.write(
"# Auto-generated by UDF loader to make directory a Python package\n"
)
logging.info("Created __init__.py in %s", package_path)
except OSError as e:
raise ImportError(
f"Cannot create __init__.py in package '{package_name}': {e}"
) from e
def _build_full_module_name(self, package_name: str, module_name: str) -> str:
"""Build the full module name for package mode."""
if module_name == "__init__":
return package_name
return f"{package_name}.{module_name}"
def _load_package_udf(
self, location: str, package_name: str, module_name: str, func_name: str
) -> Callable:
"""Load UDF from a Python package."""
package_path = os.path.join(location, package_name)
if not os.path.isdir(package_path):
raise ImportError(f"Package '{package_name}' not found in '{location}'")
self._ensure_package_init(package_path, package_name)
try:
full_module_name = self._build_full_module_name(package_name, module_name)
udf_module = self._get_or_import_module(location, full_module_name)
return self._extract_function(udf_module, func_name, full_module_name)
except (ImportError, AttributeError, TypeError) as e:
raise ImportError(
f"Failed to load packaged UDF '{package_name}.{module_name}.{func_name}': {e}"
) from e
except Exception as e:
logging.error(
"Unexpected error loading packaged UDF: %s\n%s",
e,
traceback.format_exc(),
)
raise
def load_udf_from_module(
self,
location: str,
package_name: Optional[str],
module_name: str,
func_name: str,
) -> Callable:
"""
Load a UDF from a Python module, supporting both:
1. Single-file mode: package_name=None, module_name="your_file"
2. Package mode: package_name="your_pkg", module_name="submodule" or "__init__"
Args:
location:
- In package mode: parent directory of the package
- In single-file mode: directory containing the .py file
package_name:
- If None or empty: treat as single-file mode
- Else: standard package name
module_name:
- In package mode: submodule name (e.g., "main") or "__init__"
- In single-file mode: filename without .py (e.g., "udf_script")
func_name: name of the function to load
Returns:
The callable UDF function.
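        Example (illustrative layout):
            Single-file: /udfs/my_udf.py defines add() →
                load_udf_from_module("/udfs", None, "my_udf", "add")
            Package: /udfs/my_pkg/main.py defines add() →
                load_udf_from_module("/udfs", "my_pkg", "main", "add")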
"""
self._validate_location(location)
if not package_name or package_name.strip() == "":
return self._load_single_file_udf(location, module_name, func_name)
else:
return self._load_package_udf(
location, package_name, module_name, func_name
)
class UDFLoaderFactory:
"""Factory to select the appropriate loader based on UDF location."""
@staticmethod
def get_loader(python_udf_meta: PythonUDFMeta) -> UDFLoader:
"""
Factory method to create the appropriate UDF loader based on metadata.
Args:
python_udf_meta: UDF metadata containing load type and location
Returns:
Appropriate UDFLoader instance (InlineUDFLoader or ModuleUDFLoader)
Raises:
ValueError: If UDF load type or location is unsupported
"""
location = python_udf_meta.location
udf_load_type = python_udf_meta.udf_load_type # 0: inline, 1: module
if udf_load_type == 0:
return InlineUDFLoader(python_udf_meta)
elif udf_load_type == 1:
if UDFLoaderFactory.check_module(location):
return ModuleUDFLoader(python_udf_meta)
else:
raise ValueError(f"Unsupported UDF location: {location}")
else:
raise ValueError(f"Unsupported UDF load type: {udf_load_type}")
@staticmethod
def check_module(location: str) -> bool:
"""
Checks if a location is a valid Python module or package.
A valid module is either:
- A .py file, or
        - A directory containing at least one .py file (searched recursively).
Raises:
ValueError: If the location does not exist or contains no Python module.
Returns:
True if valid.
"""
if not os.path.exists(location):
raise ValueError(f"Module not found: {location}")
if os.path.isfile(location):
if location.endswith(".py"):
return True
else:
raise ValueError(f"File is not a Python module (.py): {location}")
if os.path.isdir(location):
if UDFLoaderFactory.has_python_file_recursive(location):
return True
else:
raise ValueError(
f"Directory contains no Python (.py) files: {location}"
)
raise ValueError(f"Invalid module location (not file or directory): {location}")
@staticmethod
def has_python_file_recursive(location: str) -> bool:
"""
Recursively checks if a directory contains any Python (.py) files.
Args:
location: Directory path to search
Returns:
True if at least one .py file is found, False otherwise
"""
path = Path(location)
if not path.is_dir():
return False
return any(path.rglob("*.py"))
class UDFFlightServer(flight.FlightServerBase):
"""Arrow Flight server for executing Python UDFs."""
@staticmethod
def parse_python_udf_meta(
descriptor: flight.FlightDescriptor,
) -> Optional[PythonUDFMeta]:
"""Parses UDF metadata from a command descriptor."""
if descriptor.descriptor_type != flight.DescriptorType.CMD:
logging.error("Invalid descriptor type: %s", descriptor.descriptor_type)
return None
cmd_json = json.loads(descriptor.command)
name = cmd_json["name"]
symbol = cmd_json["symbol"]
location = cmd_json["location"]
udf_load_type = cmd_json["udf_load_type"]
runtime_version = cmd_json["runtime_version"]
always_nullable = cmd_json["always_nullable"]
inline_code = base64.b64decode(cmd_json["inline_code"])
input_binary = base64.b64decode(cmd_json["input_types"])
output_binary = base64.b64decode(cmd_json["return_type"])
input_schema = pa.ipc.read_schema(pa.BufferReader(input_binary))
output_schema = pa.ipc.read_schema(pa.BufferReader(output_binary))
if len(output_schema) != 1:
logging.error(
"Output schema must have exactly one field: %s", output_schema
)
return None
output_type = output_schema.field(0).type
return PythonUDFMeta(
name=name,
symbol=symbol,
location=location,
udf_load_type=udf_load_type,
runtime_version=runtime_version,
always_nullable=always_nullable,
inline_code=inline_code,
input_types=input_schema,
output_type=output_type,
)
@staticmethod
def check_schema(
record_batch: pa.RecordBatch, expected_schema: pa.Schema
) -> Tuple[bool, str]:
"""
Validates that the input RecordBatch schema matches the expected schema.
Checks that field count and types match, but field names can differ.
        Returns:
            Tuple of (is_match, error_message); error_message is empty on success.
"""
actual = record_batch.schema
expected = expected_schema
logging.info(f"Actual schema: {actual}")
logging.info(f"Expected schema: {expected}")
# Check field count
if len(actual) != len(expected):
return (
False,
f"Schema length mismatch, got {len(actual)} fields, expected {len(expected)} fields",
)
# Check each field type (ignore field names)
for i, (actual_field, expected_field) in enumerate(zip(actual, expected)):
if not actual_field.type.equals(expected_field.type):
return False, (
f"Type mismatch at field index {i}, "
f"got {actual_field.type}, expected {expected_field.type}"
)
return True, ""
def do_exchange(
self,
context: flight.ServerCallContext,
descriptor: flight.FlightDescriptor,
reader: flight.MetadataRecordBatchReader,
writer: flight.MetadataRecordBatchWriter,
) -> None:
"""Handles bidirectional streaming UDF execution."""
logging.info("Received exchange request for UDF: %s", descriptor)
python_udf_meta = UDFFlightServer.parse_python_udf_meta(descriptor)
if not python_udf_meta:
raise ValueError("Invalid or missing UDF metadata in descriptor")
loader = UDFLoaderFactory.get_loader(python_udf_meta)
udf = loader.load()
logging.info("Loaded UDF: %s", udf)
started = False
for chunk in reader:
if not chunk.data:
logging.info("Empty chunk received, skipping")
continue
check_schema_result, error_msg = UDFFlightServer.check_schema(
chunk.data, python_udf_meta.input_types
)
if not check_schema_result:
logging.error("Schema mismatch: %s", error_msg)
raise ValueError(f"Schema mismatch: {error_msg}")
result_array = udf(chunk.data)
if not python_udf_meta.output_type.equals(result_array.type):
logging.error(
"Output type mismatch: got %s, expected %s",
result_array.type,
python_udf_meta.output_type,
)
raise ValueError(
f"Output type mismatch: got {result_array.type}, expected {python_udf_meta.output_type}"
)
result_batch = pa.RecordBatch.from_arrays([result_array], ["result"])
if not started:
writer.begin(result_batch.schema)
started = True
writer.write_batch(result_batch)
def check_unix_socket_path(unix_socket_path: str) -> bool:
"""Validates the Unix domain socket path format."""
if not unix_socket_path:
logging.error("Unix socket path is empty")
return False
    if not unix_socket_path.startswith("grpc+unix://"):
        logging.error(
            "gRPC UDS URL must start with 'grpc+unix://', got: %s", unix_socket_path
        )
        return False
socket_path = unix_socket_path[len("grpc+unix://") :].strip()
if not socket_path:
logging.error("Extracted socket path is empty")
return False
return True
def main(unix_socket_path: str) -> None:
"""
Main entry point for the Python UDF server.
Args:
unix_socket_path: Base path for the Unix domain socket
Raises:
SystemExit: If socket path is invalid or server fails to start
"""
try:
if not check_unix_socket_path(unix_socket_path):
print(f"ERROR: Invalid socket path: {unix_socket_path}", flush=True)
sys.exit(1)
current_pid = os.getpid()
ServerState.unix_socket_path = f"{unix_socket_path}_{current_pid}.sock"
server = UDFFlightServer(ServerState.unix_socket_path)
print("Start python server successfully", flush=True)
logging.info("##### PYTHON UDF SERVER STARTED AT %s #####", datetime.now())
server.wait()
except Exception as e:
print(
f"ERROR: Failed to start Python UDF server: {type(e).__name__}: {e}",
flush=True,
)
tb_lines = traceback.format_exception(type(e), e, e.__traceback__)
if len(tb_lines) > 1:
print(f"DETAIL: {tb_lines[-2].strip()}", flush=True)
sys.exit(1)
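# Illustrative invocation (script name and socket path are examples):
#   python udf_server.py grpc+unix:///tmp/doris_python_udf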
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Run an Arrow Flight UDF server over Unix socket."
)
parser.add_argument(
"unix_socket_path",
type=str,
help="Path to the Unix socket (e.g., grpc+unix:///path/to/socket)",
)
args = parser.parse_args()
main(args.unix_socket_path)