blob: 4f04967f25f48bb2c266e843c449ce5a36acaaf2 [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# The leading underscore is load-bearing: pytest with --import-mode=importlib
# (used in CI) assigns synthetic module names to test modules, which breaks
# subprocess imports during multiprocessing. An underscore-prefixed module is
# not collected as a test module, so it imports under its normal __name__
# inside worker processes.
from __future__ import annotations
import pyarrow as pa
from datafusion import SessionContext, udf
from datafusion.ipc import clear_worker_ctx, set_worker_ctx
def make_double_udf():
    """Build the canonical UDF used in the multiprocessing tests.

    Returns:
        A UDF named ``double`` that accepts one int64 argument and returns
        int64, with ``immutable`` volatility. Each input value is multiplied
        by two; null inputs (``as_py()`` returns ``None``) yield ``0``.
    """
    return udf(
        # Pin the output type: without it ``pa.array([])`` infers the null
        # type for a zero-row batch, which no longer matches the declared
        # int64 return type.
        lambda arr: pa.array(
            [(v.as_py() or 0) * 2 for v in arr],
            type=pa.int64(),
        ),
        [pa.int64()],
        pa.int64(),
        volatility="immutable",
        name="double",
    )
def make_times_seven_udf():
    """Closure-capturing UDF — verifies cloudpickle preserves closed-over state.

    Returns:
        A UDF named ``times_seven`` that accepts one int64 argument and
        returns int64, with ``immutable`` volatility. Each input value is
        multiplied by the closed-over ``multiplier``; null inputs yield ``0``.
    """
    # Deliberately a local rather than a literal inside ``fn``: the value has
    # to travel with the function as closure state when it is serialized.
    multiplier = 7

    def fn(arr):
        # Pin the output type: without it ``pa.array([])`` infers the null
        # type for a zero-row batch, which no longer matches the declared
        # int64 return type.
        return pa.array(
            [(v.as_py() or 0) * multiplier for v in arr],
            type=pa.int64(),
        )

    return udf(
        fn,
        [pa.int64()],
        pa.int64(),
        volatility="immutable",
        name="times_seven",
    )
def init_worker_empty():
    """Pool initializer that installs a fresh SessionContext in the worker.

    The context is newly constructed and nothing is registered on it here,
    so the worker starts without any UDFs.
    """
    empty_ctx = SessionContext()
    set_worker_ctx(empty_ctx)
def init_worker_clear():
    """Pool initializer that drops whatever worker context was set before.

    Delegates to ``clear_worker_ctx`` and installs nothing in its place.
    """
    clear_worker_ctx()
def unpickle_and_describe(blob: bytes) -> str:
    """Restore a pickled expression and report its canonical name.

    Args:
        blob: Pickled bytes. The restored object must expose a
            ``canonical_name()`` method.

    Returns:
        The value returned by ``canonical_name()`` on the restored object.
    """
    import pickle

    # NOTE(review): presumably the blob is always produced by the test suite
    # itself — pickle must never be fed untrusted data.
    return pickle.loads(blob).canonical_name()  # noqa: S301
def unpickle_and_evaluate(blob: bytes, batch: list[int]) -> list[int]:
    """Restore a pickled expression and run it over an in-memory column.

    The values in ``batch`` are loaded into a single column named ``a`` of a
    new ``SessionContext``; the restored expression is attached as a
    ``result`` column, and only that column is returned. Used to verify that
    cloudpickled UDFs (including closure state) execute correctly in a fresh
    worker process.

    Args:
        blob: Pickled expression bytes.
        batch: Values for the input column ``a``.

    Returns:
        The ``result`` column as a Python list.
    """
    import pickle

    # Unpickle first, before any local context exists, exactly as the
    # original ordering did.
    expression = pickle.loads(blob)  # noqa: S301

    session = SessionContext()
    frame = session.from_pydict({"a": batch})
    with_result = frame.with_column("result", expression)
    columns = with_result.select("result").to_pydict()
    return columns["result"]