blob: 588caa21a8179d119b350690ed1ef911f1e167a2 [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""In-process pickle round-trip tests for :class:`Expr`.
Built-in functions and Python UDFs (scalar, aggregate, window) travel
with the pickled expression and do not need worker-side pre-registration.
The worker context (:mod:`datafusion.ipc`) is only consulted for UDFs
imported via the FFI capsule protocol.
Cross-process tests live in ``test_pickle_multiprocessing.py``.
"""
from __future__ import annotations
import pickle
import threading
import pyarrow as pa
import pytest
from datafusion import Expr, SessionContext, col, lit, udf
from datafusion.ipc import (
clear_sender_ctx,
clear_worker_ctx,
get_sender_ctx,
get_worker_ctx,
set_sender_ctx,
set_worker_ctx,
)
@pytest.fixture(autouse=True)
def _reset_worker_ctx():
"""Ensure every test starts with no worker or sender context installed."""
clear_worker_ctx()
clear_sender_ctx()
yield
clear_worker_ctx()
clear_sender_ctx()
def _double_udf():
return udf(
lambda arr: pa.array([(v.as_py() or 0) * 2 for v in arr]),
[pa.int64()],
pa.int64(),
volatility="immutable",
name="double",
)
class TestProtoRoundTrip:
def test_builtin_round_trip(self):
e = col("a") + lit(1)
blob = pickle.dumps(e)
decoded = pickle.loads(blob) # noqa: S301
assert decoded.canonical_name() == e.canonical_name()
def test_to_bytes_from_bytes(self):
e = col("x") * lit(7)
blob = e.to_bytes()
assert isinstance(blob, bytes)
decoded = Expr.from_bytes(blob)
assert decoded.canonical_name() == e.canonical_name()
def test_explicit_ctx_used(self, ctx):
e = col("a") + lit(1)
decoded = Expr.from_bytes(e.to_bytes(), ctx=ctx)
assert decoded.canonical_name() == e.canonical_name()
class TestUDFCodec:
"""Python scalar UDFs ride inside the proto blob via the Rust codec.
No worker context needed on the receiver — the cloudpickled callable is
embedded in ``fun_definition`` and reconstructed automatically.
"""
def test_udf_self_contained_blob(self):
e = _double_udf()(col("a"))
blob = pickle.dumps(e)
# The codec inlines the callable, so the blob is much bigger than a
# pure built-in blob but doesn't depend on receiver-side registration.
assert len(blob) > 200
def test_udf_decodes_into_fresh_ctx(self):
e = _double_udf()(col("a"))
blob = e.to_bytes()
fresh = SessionContext()
decoded = Expr.from_bytes(blob, ctx=fresh)
assert "double" in decoded.canonical_name()
def test_udf_decodes_via_pickle_with_no_worker_ctx(self):
e = _double_udf()(col("a"))
blob = pickle.dumps(e)
decoded = pickle.loads(blob) # noqa: S301
assert "double" in decoded.canonical_name()
def test_udf_decodes_via_pickle_with_worker_ctx(self):
set_worker_ctx(SessionContext())
e = _double_udf()(col("a"))
blob = pickle.dumps(e)
decoded = pickle.loads(blob) # noqa: S301
assert "double" in decoded.canonical_name()
def test_closure_capturing_udf_names_match(self):
captured_multiplier = 7
def fn(arr):
return pa.array([(v.as_py() or 0) * captured_multiplier for v in arr])
u = udf(
fn,
[pa.int64()],
pa.int64(),
volatility="immutable",
name="times_seven",
)
e = u(col("a"))
blob = pickle.dumps(e)
decoded = pickle.loads(blob) # noqa: S301
# Round-trip names match; functional verification of captured state
# happens in test_pickle_multiprocessing via an actual UDF call.
assert decoded.canonical_name() == e.canonical_name()
def test_multi_arg_udf_round_trip(self):
"""Wire format builds synthetic `arg_{i}` fields per input — exercise
with a 2-arg UDF spanning two distinct DataTypes."""
add_scaled = udf(
lambda a, b: pa.array(
[
(x.as_py() or 0) + (y.as_py() or 0.0)
for x, y in zip(a, b, strict=False)
]
),
[pa.int64(), pa.float64()],
pa.float64(),
volatility="immutable",
name="add_scaled",
)
e = add_scaled(col("a"), col("b"))
decoded = pickle.loads(pickle.dumps(e)) # noqa: S301
assert decoded.canonical_name() == e.canonical_name()
assert "add_scaled" in decoded.canonical_name()
class TestAggregateUDFCodec:
"""Python aggregate UDFs travel inline like scalar UDFs."""
def _build_aggregate_udf(self):
from datafusion import udaf
from datafusion.user_defined import Accumulator
class CountAcc(Accumulator):
def __init__(self):
self._count = 0
def state(self):
return [pa.scalar(self._count, type=pa.int64())]
def update(self, values):
self._count += len(values)
def merge(self, states):
partition_counts = states[0]
for i in range(len(partition_counts)):
self._count += partition_counts[i].as_py()
def evaluate(self):
return pa.scalar(self._count, type=pa.int64())
return udaf(
CountAcc,
[pa.int64()],
pa.int64(),
[pa.int64()],
"immutable",
name="count_all",
)
def test_agg_udf_self_contained_blob(self):
u = self._build_aggregate_udf()
e = u(col("a"))
blob = pickle.dumps(e)
assert len(blob) > 200
def test_agg_udf_decodes_into_fresh_ctx(self):
u = self._build_aggregate_udf()
e = u(col("a"))
blob = e.to_bytes()
fresh = SessionContext()
decoded = Expr.from_bytes(blob, ctx=fresh)
assert "count_all" in decoded.canonical_name()
def test_agg_udf_decodes_via_pickle_with_no_worker_ctx(self):
u = self._build_aggregate_udf()
e = u(col("a"))
blob = pickle.dumps(e)
decoded = pickle.loads(blob) # noqa: S301
assert "count_all" in decoded.canonical_name()
def test_agg_udf_evaluates_after_roundtrip(self):
"""End-to-end: the decoded aggregate UDF runs and merges across
partitions, exercising the round-tripped state-field schema."""
u = self._build_aggregate_udf()
e = u(col("a"))
decoded = pickle.loads(pickle.dumps(e)) # noqa: S301
ctx = SessionContext()
schema = pa.schema([pa.field("a", pa.int64())])
batch1 = pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], schema=schema)
batch2 = pa.record_batch([pa.array([4, 5], type=pa.int64())], schema=schema)
df = ctx.create_dataframe([[batch1], [batch2]])
out = df.aggregate([], [decoded.alias("n")]).to_pydict()
assert out["n"] == [5]
class TestWindowUDFCodec:
"""Python window UDFs travel inline like scalar UDFs."""
def _build_window_udf(self):
from datafusion import udwf
from datafusion.user_defined import WindowEvaluator
class CountUpEvaluator(WindowEvaluator):
def evaluate_all(self, values, num_rows):
return pa.array(list(range(num_rows)))
return udwf(
CountUpEvaluator,
[pa.int64()],
pa.int64(),
"immutable",
name="count_up",
)
def test_window_udf_self_contained_blob(self):
u = self._build_window_udf()
e = u(col("a"))
blob = pickle.dumps(e)
assert len(blob) > 200
def test_window_udf_decodes_into_fresh_ctx(self):
u = self._build_window_udf()
e = u(col("a"))
blob = e.to_bytes()
fresh = SessionContext()
decoded = Expr.from_bytes(blob, ctx=fresh)
assert "count_up" in decoded.canonical_name()
def test_window_udf_decodes_via_pickle_with_no_worker_ctx(self):
u = self._build_window_udf()
e = u(col("a"))
blob = pickle.dumps(e)
decoded = pickle.loads(blob) # noqa: S301
assert "count_up" in decoded.canonical_name()
def test_window_udf_evaluates_after_roundtrip(self):
"""End-to-end: decoded window UDF runs and emits per-row values
produced by the round-tripped evaluator factory."""
from datafusion.expr import WindowFrame
u = self._build_window_udf()
e = u(col("a"))
decoded = pickle.loads(pickle.dumps(e)) # noqa: S301
ctx = SessionContext()
df = ctx.from_pydict({"a": [1, 2, 3, 4, 5]})
framed = (
decoded.window_frame(WindowFrame("rows", None, None)).build().alias("c")
)
out = df.select(framed).to_pydict()
assert out["c"] == [0, 1, 2, 3, 4]
class TestErrorPaths:
def test_from_bytes_rejects_garbage(self):
with pytest.raises(Exception): # noqa: B017
Expr.from_bytes(b"not a valid protobuf payload")
def test_from_bytes_rejects_empty(self):
with pytest.raises(Exception): # noqa: B017
Expr.from_bytes(b"")
def test_cross_version_error_message(self):
"""Decoding a payload stamped with a different Python minor
version raises a clear, actionable error rather than an opaque
marshal/unpickle failure.
The wire frame inside the protobuf is:
``DFPYUDF (7) | version (1) | py_major (1) | py_minor (1) | cloudpickle``.
We locate the frame inside the outer protobuf and patch the
minor byte at offset 9.
"""
import sys
e = _double_udf()(col("a"))
blob = e.to_bytes()
idx = blob.find(b"DFPYUDF")
assert idx >= 0, "DFPYUDF frame not found in payload"
different_minor = (sys.version_info.minor + 1) % 256
tampered = bytearray(blob)
tampered[idx + 9] = different_minor
with pytest.raises(
Exception, match="not portable across Python minor versions"
):
Expr.from_bytes(bytes(tampered))
class TestPythonUdfInliningToggle:
"""`SessionContext.with_python_udf_inlining(enabled=False)` opts out of
inline Python UDF encoding for both encode and decode paths."""
def _build_double_udf(self):
return udf(
lambda arr: pa.array([(v.as_py() or 0) * 2 for v in arr]),
[pa.int64()],
pa.int64(),
volatility="immutable",
name="double",
)
def test_strict_encoder_omits_inline_payload(self):
"""Strict mode emits the by-name wire form: no `DFPYUDF` magic
in the blob, no cloudpickled callable. Semantic check is
sharper than a size-ratio heuristic — a renamed UDF or a
smaller-than-expected closure would still flip the magic
bytes, but might not move the size by 4x.
"""
ctx_inline = SessionContext()
ctx_strict = ctx_inline.with_python_udf_inlining(enabled=False)
u = self._build_double_udf()
e = u(col("a"))
blob_inline = e.to_bytes(ctx_inline)
blob_strict = e.to_bytes(ctx_strict)
# `DFPYUDF` is the scalar Python-UDF family prefix; see
# `PY_SCALAR_UDF_FAMILY` in crates/core/src/codec.rs.
assert b"DFPYUDF" in blob_inline
assert b"DFPYUDF" not in blob_strict
def test_toggle_off_then_on_restores_inline_encoding(self):
"""`with_python_udf_inlining` is per-call clone semantics:
flipping off and then on must produce a context that emits the
same inline form as a fresh default context, byte-for-byte.
Guards against a regression where the off→on transition leaves
the codec in a sticky strict state (e.g. by mutating shared
codec state instead of cloning).
"""
u = self._build_double_udf()
e = u(col("a"))
baseline = SessionContext()
toggled = (
SessionContext()
.with_python_udf_inlining(enabled=False)
.with_python_udf_inlining(enabled=True)
)
blob_baseline = e.to_bytes(baseline)
blob_toggled = e.to_bytes(toggled)
assert blob_baseline == blob_toggled
# Sanity check the decoded form against a fresh ctx — the
# toggled-back blob should be self-contained inline, not a
# strict by-name payload that needs registry resolution.
decoded = Expr.from_bytes(blob_toggled, ctx=SessionContext())
assert "double" in decoded.canonical_name()
def test_strict_roundtrip_via_registry(self):
"""When both sender and receiver disable inlining, the UDF
travels by name only and the receiver resolves it from its
registered functions."""
strict_sender = SessionContext().with_python_udf_inlining(enabled=False)
u = self._build_double_udf()
blob = u(col("a")).to_bytes(strict_sender)
receiver = SessionContext().with_python_udf_inlining(enabled=False)
receiver.register_udf(u)
restored = Expr.from_bytes(blob, ctx=receiver)
assert "double" in restored.canonical_name()
def test_strict_decoder_refuses_inline_payload(self):
"""An inline-encoded blob fed to a strict receiver raises with a
clear error rather than silently invoking cloudpickle.loads.
The receiver is intentionally *not* given a matching
registration: the codec refusal must trip before the registry
is ever consulted, so registering the UDF here would only mask
a regression that moved the check after registry lookup.
"""
sender = SessionContext()
u = self._build_double_udf()
blob = u(col("a")).to_bytes(sender)
strict_receiver = SessionContext().with_python_udf_inlining(enabled=False)
# `RuntimeError` (not bare `Exception`): the codec refusal is
# surfaced through `parse_expr` → `PyRuntimeError`. Tightening
# the assertion catches a regression that swallows the refusal
# as a different error type.
with pytest.raises(RuntimeError, match="inlining is disabled"):
Expr.from_bytes(blob, ctx=strict_receiver)
def test_sender_ctx_propagates_through_pickle(self):
"""`set_sender_ctx` makes `pickle.dumps` use a strict codec.
Without a sender context, pickle defaults to the inline codec
and the blob contains the `DFPYUDF` family prefix. With a
strict sender context installed, the callable encodes by name
and the prefix is absent.
"""
u = self._build_double_udf()
e = u(col("a"))
blob_default = pickle.dumps(e)
strict_sender = SessionContext().with_python_udf_inlining(enabled=False)
set_sender_ctx(strict_sender)
try:
blob_strict = pickle.dumps(e)
finally:
clear_sender_ctx()
assert b"DFPYUDF" in blob_default
assert b"DFPYUDF" not in blob_strict
def test_sender_ctx_strict_roundtrip_via_pickle(self):
"""End-to-end pickle round-trip with strict mode on both sides.
Driver installs a strict sender context. Worker installs a
matching strict context with the UDF registered. The UDF
travels by name through `pickle.dumps` / `pickle.loads`.
"""
u = self._build_double_udf()
e = u(col("a"))
strict_sender = SessionContext().with_python_udf_inlining(enabled=False)
set_sender_ctx(strict_sender)
try:
blob = pickle.dumps(e)
finally:
clear_sender_ctx()
worker = SessionContext().with_python_udf_inlining(enabled=False)
worker.register_udf(u)
set_worker_ctx(worker)
try:
decoded = pickle.loads(blob) # noqa: S301
finally:
clear_worker_ctx()
assert "double" in decoded.canonical_name()
def test_sender_ctx_strict_pickle_accepted_by_inline_worker_with_registry(self):
"""A strict-encoded blob still decodes fine on an inline worker
because the wire format is the same default-codec by-name form.
Sanity check: cross-config works as long as the receiver can
resolve the name."""
u = self._build_double_udf()
e = u(col("a"))
strict_sender = SessionContext().with_python_udf_inlining(enabled=False)
set_sender_ctx(strict_sender)
try:
blob = pickle.dumps(e)
finally:
clear_sender_ctx()
worker = SessionContext()
worker.register_udf(u)
set_worker_ctx(worker)
try:
decoded = pickle.loads(blob) # noqa: S301
finally:
clear_worker_ctx()
assert "double" in decoded.canonical_name()
class TestWorkerCtxLifecycle:
def test_set_and_clear(self):
assert get_worker_ctx() is None
ctx = SessionContext()
set_worker_ctx(ctx)
assert get_worker_ctx() is ctx
clear_worker_ctx()
assert get_worker_ctx() is None
def test_clear_when_unset_is_noop(self):
clear_worker_ctx() # no error
assert get_worker_ctx() is None
def test_thread_local_isolation(self):
main_ctx = SessionContext()
set_worker_ctx(main_ctx)
seen_in_thread: list = []
def worker():
seen_in_thread.append(get_worker_ctx())
set_worker_ctx(SessionContext())
seen_in_thread.append(get_worker_ctx())
t = threading.Thread(target=worker)
t.start()
t.join()
# Thread saw no ctx initially (thread-local), then its own.
assert seen_in_thread[0] is None
assert seen_in_thread[1] is not main_ctx
# Main thread's ctx is unchanged by the thread's actions.
assert get_worker_ctx() is main_ctx
class TestSenderCtxLifecycle:
def test_set_and_clear(self):
assert get_sender_ctx() is None
ctx = SessionContext()
set_sender_ctx(ctx)
assert get_sender_ctx() is ctx
clear_sender_ctx()
assert get_sender_ctx() is None
def test_clear_when_unset_is_noop(self):
clear_sender_ctx() # no error
assert get_sender_ctx() is None
def test_thread_local_isolation(self):
main_ctx = SessionContext()
set_sender_ctx(main_ctx)
seen_in_thread: list = []
def worker():
seen_in_thread.append(get_sender_ctx())
set_sender_ctx(SessionContext())
seen_in_thread.append(get_sender_ctx())
t = threading.Thread(target=worker)
t.start()
t.join()
assert seen_in_thread[0] is None
assert seen_in_thread[1] is not main_ctx
assert get_sender_ctx() is main_ctx