| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| //! Python-aware extension codecs. |
| //! |
| //! Datafusion-python plans can carry references to Python-defined |
| //! objects that the upstream protobuf codecs do not know how to |
| //! serialize: pure-Python scalar UDFs, Python query-planning |
| //! extensions, and so on. Their state lives inside `Py<PyAny>` |
| //! callables and closures rather than being recoverable from a name |
| //! in the receiver's function registry. To ship a plan across a |
| //! process boundary (pickle, `multiprocessing`, Ray actor, |
| //! `datafusion-distributed`, etc.) those payloads have to be encoded |
| //! into the proto wire format itself. |
| //! |
| //! [`PythonLogicalCodec`] is the [`LogicalExtensionCodec`] that |
| //! datafusion-python parks on every `SessionContext`. It wraps a |
| //! user-supplied (or default) inner codec and adds Python-aware |
| //! in-band encoding on top: when the encoder sees a Python-defined |
| //! UDF, the codec cloudpickles the callable + signature into the |
| //! `fun_definition` proto field; when the decoder sees a payload it |
| //! produced, it reconstructs the UDF from the bytes alone — no |
| //! pre-registration on the receiver. UDFs the codec does not |
| //! recognise are delegated to `inner`, which is typically |
| //! `DefaultLogicalExtensionCodec` but may be a downstream-supplied |
| //! FFI codec installed via |
| //! `SessionContext.with_logical_extension_codec(...)`. |
| //! |
| //! [`PythonPhysicalCodec`] is the symmetric wrapper around |
| //! [`PhysicalExtensionCodec`]. Logical and physical layers each have |
| //! a `try_encode_udf` / `try_decode_udf` pair, so a `ScalarUDF` |
| //! referenced inside a `LogicalPlan`, an `ExecutionPlan`, or a |
| //! `PhysicalExpr` must encode identically through either layer for |
| //! plans to survive a serialization round-trip. Both codecs share |
| //! the same payload framing for that reason. |
| //! |
| //! Payloads emitted by these codecs are framed as |
| //! `<family_magic: 7 bytes> <version: u8> <py_major: u8> <py_minor: u8> <cloudpickle blob>`. |
| //! The family magic identifies the UDF flavor; the version byte lets |
| //! the decoder reject too-new or too-old payloads with a clean error |
| //! instead of falling into an opaque `cloudpickle` tuple-unpack |
| //! failure when the tuple shape changes; the Python `(major, minor)` |
| //! bytes catch the cloudpickle-cross-minor-version case and raise an |
| //! actionable error instead of an opaque `marshal` failure on load |
| //! (cloudpickle payloads are not portable across Python minor |
| //! versions). Dispatch precedence on decode: **family match + |
| //! supported version + matching Python version → `inner` codec → |
| //! caller's `FunctionRegistry` fallback.** |
| //! |
| //! ## Wire-format family registry |
| //! |
| //! | Layer + kind | Family prefix | |
| //! | ----------------------------- | ------------- | |
| //! | `PythonLogicalCodec` scalar | `DFPYUDF` | |
| //! | `PythonPhysicalCodec` scalar | `DFPYUDF` | |
| //! | User FFI extension codec | user-chosen | |
| //! | Default codec | (none) | |
| //! |
| //! Aggregate and window UDF families are reserved for follow-on work. |
| //! |
| //! Current wire-format version is [`WIRE_VERSION_CURRENT`]; supported |
| //! receive range is `WIRE_VERSION_MIN_SUPPORTED..=WIRE_VERSION_CURRENT`. |
| //! Bump [`WIRE_VERSION_CURRENT`] whenever the cloudpickle tuple shape |
| //! changes; raise [`WIRE_VERSION_MIN_SUPPORTED`] when dropping support |
| //! for an older shape. |
| //! |
//! Downstream FFI codecs should pick non-colliding family prefixes
//! (use a `DF` namespace plus a crate-specific suffix). Apart from the
//! scalar-UDF hooks (`try_encode_udf` / `try_decode_udf`), the codec
//! implementations in this module currently delegate every method to
//! `inner`; encoder/decoder hooks for the remaining kinds are added as
//! the corresponding Python-side type becomes serializable.
| |
| use std::sync::Arc; |
| |
| use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; |
| use arrow::ipc::reader::StreamReader; |
| use arrow::ipc::writer::StreamWriter; |
| use datafusion::common::{Result, TableReference}; |
| use datafusion::datasource::TableProvider; |
| use datafusion::datasource::file_format::FileFormatFactory; |
| use datafusion::execution::TaskContext; |
| use datafusion::logical_expr::{ |
| AggregateUDF, Extension, LogicalPlan, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature, |
| Volatility, WindowUDF, |
| }; |
| use datafusion::physical_expr::PhysicalExpr; |
| use datafusion::physical_plan::ExecutionPlan; |
| use datafusion_proto::logical_plan::{DefaultLogicalExtensionCodec, LogicalExtensionCodec}; |
| use datafusion_proto::physical_plan::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec}; |
| use pyo3::prelude::*; |
| use pyo3::sync::PyOnceLock; |
| use pyo3::types::{PyBytes, PyTuple}; |
| |
| use crate::udf::PythonFunctionScalarUDF; |
| |
// Wire-format framing for inlined Python UDF payloads.
//
// Layout: `<family_magic: 7 bytes> <version: u8> <py_major: u8> <py_minor: u8> <cloudpickle blob>`.
// The family magic identifies the UDF flavor. The version byte lets
// the decoder reject too-new or too-old payloads with a clean error
// instead of falling into an opaque `cloudpickle` tuple-unpack failure
// when the tuple shape changes. The Python `(major, minor)` bytes
// catch the cloudpickle-cross-minor-version case (cloudpickle is not
// portable across Python minor versions) and raise an actionable
// error instead of an opaque `marshal` failure on load.
//
// Bump `WIRE_VERSION_CURRENT` whenever the tuple shape changes; raise
// `WIRE_VERSION_MIN_SUPPORTED` when dropping support for an older
// shape. The compile-time assertions below keep these constants
// consistent with the layout above.

/// Family prefix for an inlined Python scalar UDF
/// (cloudpickled tuple of name, callable, input schema, return field,
/// volatility).
pub(crate) const PY_SCALAR_UDF_FAMILY: &[u8] = b"DFPYUDF";

/// Wire-format version this build emits.
pub(crate) const WIRE_VERSION_CURRENT: u8 = 1;

/// Oldest wire-format version this build still decodes. Bump when
/// retiring support for an older payload shape.
pub(crate) const WIRE_VERSION_MIN_SUPPORTED: u8 = 1;

// The documented framing reserves exactly 7 bytes for the family magic.
const _: () = assert!(
    PY_SCALAR_UDF_FAMILY.len() == 7,
    "PY_SCALAR_UDF_FAMILY must be exactly 7 bytes to match the documented wire layout"
);

// If the minimum exceeded the current version, the decoder would reject
// every payload, including the ones this very build emits.
const _: () = assert!(
    WIRE_VERSION_MIN_SUPPORTED <= WIRE_VERSION_CURRENT,
    "WIRE_VERSION_MIN_SUPPORTED must not exceed WIRE_VERSION_CURRENT"
);
| |
| /// Tag `buf` with the framing header for `family` at the current |
| /// wire-format version, stamping `py_version` as `(major, minor)` |
| /// bytes. Append-only — the caller writes the cloudpickle payload |
| /// after. |
| fn write_wire_header(buf: &mut Vec<u8>, family: &[u8], py_version: (u8, u8)) { |
| buf.extend_from_slice(family); |
| buf.push(WIRE_VERSION_CURRENT); |
| buf.push(py_version.0); |
| buf.push(py_version.1); |
| } |
| |
| /// Inspect the framing on `buf`. |
| /// |
| /// * `Ok(None)` — `buf` does not carry `family`. The caller should |
| /// delegate to its `inner` codec. |
| /// * `Ok(Some(payload))` — `buf` carries `family` at a version this |
| /// build accepts and a Python `(major, minor)` matching |
| /// `expected_py`; `payload` is the cloudpickle blob. |
| /// * `Err(_)` — `buf` carries `family` but the wire-format version |
| /// is outside `WIRE_VERSION_MIN_SUPPORTED..=WIRE_VERSION_CURRENT`, |
| /// or the stamped Python `(major, minor)` does not match |
| /// `expected_py`. The error names the offending values so an |
| /// operator can diagnose sender/receiver drift instead of seeing |
| /// an opaque cloudpickle tuple-unpack or `marshal` failure. |
| fn strip_wire_header<'a>( |
| buf: &'a [u8], |
| family: &[u8], |
| kind: &str, |
| expected_py: (u8, u8), |
| ) -> Result<Option<&'a [u8]>> { |
| if !buf.starts_with(family) { |
| return Ok(None); |
| } |
| let version_idx = family.len(); |
| let Some(&version) = buf.get(version_idx) else { |
| return Err(datafusion::error::DataFusionError::Execution(format!( |
| "Truncated inline Python {kind} payload: missing wire-format version byte" |
| ))); |
| }; |
| if !(WIRE_VERSION_MIN_SUPPORTED..=WIRE_VERSION_CURRENT).contains(&version) { |
| return Err(datafusion::error::DataFusionError::Execution(format!( |
| "Inline Python {kind} payload wire-format version v{version}; \ |
| this build supports v{WIRE_VERSION_MIN_SUPPORTED}..=v{WIRE_VERSION_CURRENT}. \ |
| Align datafusion-python versions on sender and receiver." |
| ))); |
| } |
| let py_major_idx = version_idx + 1; |
| let Some(&encoded_major) = buf.get(py_major_idx) else { |
| return Err(datafusion::error::DataFusionError::Execution(format!( |
| "Truncated inline Python {kind} payload: missing Python major version byte" |
| ))); |
| }; |
| let py_minor_idx = version_idx + 2; |
| let Some(&encoded_minor) = buf.get(py_minor_idx) else { |
| return Err(datafusion::error::DataFusionError::Execution(format!( |
| "Truncated inline Python {kind} payload: missing Python minor version byte" |
| ))); |
| }; |
| let (current_major, current_minor) = expected_py; |
| if encoded_major != current_major || encoded_minor != current_minor { |
| return Err(datafusion::error::DataFusionError::Execution(format!( |
| "Inline Python {kind} payload was serialized on Python \ |
| {encoded_major}.{encoded_minor} but this process is running Python \ |
| {current_major}.{current_minor}. cloudpickle payloads are not portable \ |
| across Python minor versions. Align Python versions on sender and receiver." |
| ))); |
| } |
| Ok(Some(&buf[py_minor_idx + 1..])) |
| } |
| |
| /// `LogicalExtensionCodec` parked on every `SessionContext`. Holds |
| /// the Python-aware encoding hooks for logical-layer types |
| /// (`LogicalPlan`, `Expr`) and delegates everything it does not |
| /// handle to the composable `inner` codec — typically |
| /// `DefaultLogicalExtensionCodec`, or a downstream FFI codec |
| /// installed via `SessionContext.with_logical_extension_codec(...)`. |
| /// |
| /// Sitting at the top of the session's logical codec stack means |
| /// every serializer that reads `session.logical_codec()` automatically |
| /// picks up Python-aware encoding for free. |
| #[derive(Debug)] |
| pub struct PythonLogicalCodec { |
| inner: Arc<dyn LogicalExtensionCodec>, |
| } |
| |
| impl PythonLogicalCodec { |
| pub fn new(inner: Arc<dyn LogicalExtensionCodec>) -> Self { |
| Self { inner } |
| } |
| |
| pub fn inner(&self) -> &Arc<dyn LogicalExtensionCodec> { |
| &self.inner |
| } |
| } |
| |
| impl Default for PythonLogicalCodec { |
| fn default() -> Self { |
| Self::new(Arc::new(DefaultLogicalExtensionCodec {})) |
| } |
| } |
| |
| impl LogicalExtensionCodec for PythonLogicalCodec { |
| fn try_decode( |
| &self, |
| buf: &[u8], |
| inputs: &[LogicalPlan], |
| ctx: &TaskContext, |
| ) -> Result<Extension> { |
| self.inner.try_decode(buf, inputs, ctx) |
| } |
| |
| fn try_encode(&self, node: &Extension, buf: &mut Vec<u8>) -> Result<()> { |
| self.inner.try_encode(node, buf) |
| } |
| |
| fn try_decode_table_provider( |
| &self, |
| buf: &[u8], |
| table_ref: &TableReference, |
| schema: SchemaRef, |
| ctx: &TaskContext, |
| ) -> Result<Arc<dyn TableProvider>> { |
| self.inner |
| .try_decode_table_provider(buf, table_ref, schema, ctx) |
| } |
| |
| fn try_encode_table_provider( |
| &self, |
| table_ref: &TableReference, |
| node: Arc<dyn TableProvider>, |
| buf: &mut Vec<u8>, |
| ) -> Result<()> { |
| self.inner.try_encode_table_provider(table_ref, node, buf) |
| } |
| |
| fn try_decode_file_format( |
| &self, |
| buf: &[u8], |
| ctx: &TaskContext, |
| ) -> Result<Arc<dyn FileFormatFactory>> { |
| self.inner.try_decode_file_format(buf, ctx) |
| } |
| |
| fn try_encode_file_format( |
| &self, |
| buf: &mut Vec<u8>, |
| node: Arc<dyn FileFormatFactory>, |
| ) -> Result<()> { |
| self.inner.try_encode_file_format(buf, node) |
| } |
| |
| fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) -> Result<()> { |
| if try_encode_python_scalar_udf(node, buf)? { |
| return Ok(()); |
| } |
| self.inner.try_encode_udf(node, buf) |
| } |
| |
| fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Arc<ScalarUDF>> { |
| if let Some(udf) = try_decode_python_scalar_udf(buf)? { |
| return Ok(udf); |
| } |
| self.inner.try_decode_udf(name, buf) |
| } |
| |
| fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec<u8>) -> Result<()> { |
| self.inner.try_encode_udaf(node, buf) |
| } |
| |
| fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result<Arc<AggregateUDF>> { |
| self.inner.try_decode_udaf(name, buf) |
| } |
| |
| fn try_encode_udwf(&self, node: &WindowUDF, buf: &mut Vec<u8>) -> Result<()> { |
| self.inner.try_encode_udwf(node, buf) |
| } |
| |
| fn try_decode_udwf(&self, name: &str, buf: &[u8]) -> Result<Arc<WindowUDF>> { |
| self.inner.try_decode_udwf(name, buf) |
| } |
| } |
| |
| /// `PhysicalExtensionCodec` mirror of [`PythonLogicalCodec`] parked |
| /// on the same `SessionContext`. Carries the Python-aware encoding |
| /// hooks for physical-layer types (`ExecutionPlan`, `PhysicalExpr`) |
| /// and delegates the rest to `inner`. |
| /// |
| /// The `PhysicalExtensionCodec` trait has its own `try_encode_udf` |
| /// / `try_decode_udf` pair distinct from the logical one, so a |
| /// `ScalarUDF` referenced inside a physical plan needs Python-aware |
| /// encoding on this layer too — otherwise a plan with a Python UDF |
| /// would round-trip at the logical level but break at the physical |
| /// level. Both layers reuse the shared payload framing |
| /// ([`PY_SCALAR_UDF_FAMILY`]) so the wire format is identical. |
| #[derive(Debug)] |
| pub struct PythonPhysicalCodec { |
| inner: Arc<dyn PhysicalExtensionCodec>, |
| } |
| |
| impl PythonPhysicalCodec { |
| pub fn new(inner: Arc<dyn PhysicalExtensionCodec>) -> Self { |
| Self { inner } |
| } |
| |
| pub fn inner(&self) -> &Arc<dyn PhysicalExtensionCodec> { |
| &self.inner |
| } |
| } |
| |
| impl Default for PythonPhysicalCodec { |
| fn default() -> Self { |
| Self::new(Arc::new(DefaultPhysicalExtensionCodec {})) |
| } |
| } |
| |
| impl PhysicalExtensionCodec for PythonPhysicalCodec { |
| fn try_decode( |
| &self, |
| buf: &[u8], |
| inputs: &[Arc<dyn ExecutionPlan>], |
| ctx: &TaskContext, |
| ) -> Result<Arc<dyn ExecutionPlan>> { |
| self.inner.try_decode(buf, inputs, ctx) |
| } |
| |
| fn try_encode(&self, node: Arc<dyn ExecutionPlan>, buf: &mut Vec<u8>) -> Result<()> { |
| self.inner.try_encode(node, buf) |
| } |
| |
| fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) -> Result<()> { |
| if try_encode_python_scalar_udf(node, buf)? { |
| return Ok(()); |
| } |
| self.inner.try_encode_udf(node, buf) |
| } |
| |
| fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Arc<ScalarUDF>> { |
| if let Some(udf) = try_decode_python_scalar_udf(buf)? { |
| return Ok(udf); |
| } |
| self.inner.try_decode_udf(name, buf) |
| } |
| |
| fn try_encode_expr(&self, node: &Arc<dyn PhysicalExpr>, buf: &mut Vec<u8>) -> Result<()> { |
| self.inner.try_encode_expr(node, buf) |
| } |
| |
| fn try_decode_expr( |
| &self, |
| buf: &[u8], |
| inputs: &[Arc<dyn PhysicalExpr>], |
| ) -> Result<Arc<dyn PhysicalExpr>> { |
| self.inner.try_decode_expr(buf, inputs) |
| } |
| |
| fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec<u8>) -> Result<()> { |
| self.inner.try_encode_udaf(node, buf) |
| } |
| |
| fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result<Arc<AggregateUDF>> { |
| self.inner.try_decode_udaf(name, buf) |
| } |
| |
| fn try_encode_udwf(&self, node: &WindowUDF, buf: &mut Vec<u8>) -> Result<()> { |
| self.inner.try_encode_udwf(node, buf) |
| } |
| |
| fn try_decode_udwf(&self, name: &str, buf: &[u8]) -> Result<Arc<WindowUDF>> { |
| self.inner.try_decode_udwf(name, buf) |
| } |
| } |
| |
| // ============================================================================= |
| // Shared Python scalar UDF encode / decode helpers |
| // |
| // Both `PythonLogicalCodec` and `PythonPhysicalCodec` consult these on |
| // every `try_encode_udf` / `try_decode_udf` call. Same wire format on |
| // both layers — a Python `ScalarUDF` referenced inside a `LogicalPlan` |
| // or an `ExecutionPlan` round-trips identically. |
| // ============================================================================= |
| |
| /// Encode a Python scalar UDF inline if `node` is one. Returns |
| /// `Ok(true)` when the payload (`DFPYUDF` family prefix, version byte, |
| /// cloudpickled tuple) was written and the caller should skip its |
| /// inner codec. Returns `Ok(false)` for any non-Python UDF, signalling |
| /// the caller to delegate to its `inner`. |
| pub(crate) fn try_encode_python_scalar_udf(node: &ScalarUDF, buf: &mut Vec<u8>) -> Result<bool> { |
| let Some(py_udf) = node.inner().downcast_ref::<PythonFunctionScalarUDF>() else { |
| return Ok(false); |
| }; |
| |
| Python::attach(|py| -> Result<bool> { |
| let py_version = current_python_version(py) |
| .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?; |
| let bytes = encode_python_scalar_udf(py, py_udf) |
| .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?; |
| write_wire_header(buf, PY_SCALAR_UDF_FAMILY, py_version); |
| buf.extend_from_slice(&bytes); |
| Ok(true) |
| }) |
| } |
| |
| /// Decode an inline Python scalar UDF payload. Returns `Ok(None)` |
| /// when `buf` does not carry the `DFPYUDF` family prefix, signalling |
| /// the caller to delegate to its `inner` codec (and eventually the |
| /// `FunctionRegistry`). |
| pub(crate) fn try_decode_python_scalar_udf(buf: &[u8]) -> Result<Option<Arc<ScalarUDF>>> { |
| Python::attach(|py| -> Result<Option<Arc<ScalarUDF>>> { |
| let py_version = current_python_version(py) |
| .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?; |
| let Some(payload) = strip_wire_header(buf, PY_SCALAR_UDF_FAMILY, "scalar UDF", py_version)? |
| else { |
| return Ok(None); |
| }; |
| let udf = decode_python_scalar_udf(py, payload) |
| .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?; |
| Ok(Some(Arc::new(ScalarUDF::new_from_impl(udf)))) |
| }) |
| } |
| |
| /// Build the cloudpickle payload for a `PythonFunctionScalarUDF`. |
| /// |
| /// Layout: `cloudpickle.dumps((name, func, input_schema_bytes, |
| /// return_schema_bytes, volatility_str))`. Schema blobs are produced |
| /// by arrow-rs's native IPC stream writer (no pyarrow round-trip) and |
| /// decoded with the matching stream reader on the receiver. See |
| /// [`build_input_schema_bytes`] for what the input blob carries. |
| fn encode_python_scalar_udf(py: Python<'_>, udf: &PythonFunctionScalarUDF) -> PyResult<Vec<u8>> { |
| let signature = udf.signature(); |
| let input_dtypes = signature_input_dtypes(signature, "PythonFunctionScalarUDF")?; |
| let input_schema_bytes = build_input_schema_bytes(&input_dtypes)?; |
| let return_schema_bytes = build_single_field_schema_bytes(udf.return_field().as_ref())?; |
| let volatility = volatility_wire_str(signature.volatility); |
| |
| let payload = PyTuple::new( |
| py, |
| [ |
| udf.name().into_pyobject(py)?.into_any(), |
| udf.func().bind(py).clone().into_any(), |
| PyBytes::new(py, &input_schema_bytes).into_any(), |
| PyBytes::new(py, &return_schema_bytes).into_any(), |
| volatility.into_pyobject(py)?.into_any(), |
| ], |
| )?; |
| |
| cloudpickle(py)? |
| .call_method1("dumps", (payload,))? |
| .extract::<Vec<u8>>() |
| } |
| |
| /// Inverse of [`encode_python_scalar_udf`]. |
| fn decode_python_scalar_udf(py: Python<'_>, payload: &[u8]) -> PyResult<PythonFunctionScalarUDF> { |
| let tuple = cloudpickle(py)? |
| .call_method1("loads", (PyBytes::new(py, payload),))? |
| .cast_into::<PyTuple>()?; |
| |
| let name: String = tuple.get_item(0)?.extract()?; |
| let func: Py<PyAny> = tuple.get_item(1)?.unbind(); |
| let input_schema_bytes: Vec<u8> = tuple.get_item(2)?.extract()?; |
| let return_schema_bytes: Vec<u8> = tuple.get_item(3)?.extract()?; |
| let volatility_str: String = tuple.get_item(4)?.extract()?; |
| |
| let input_types = read_input_dtypes(&input_schema_bytes)?; |
| let return_field = read_single_return_field(&return_schema_bytes, "PythonFunctionScalarUDF")?; |
| let volatility = parse_volatility_str(&volatility_str)?; |
| |
| Ok(PythonFunctionScalarUDF::from_parts( |
| name, |
| func, |
| input_types, |
| return_field, |
| volatility, |
| )) |
| } |
| |
| /// Serialize a `Schema` to a self-contained IPC stream containing |
| /// only the schema message (no record batches). Inverse: |
| /// [`schema_from_ipc_bytes`]. |
| fn schema_to_ipc_bytes(schema: &Schema) -> arrow::error::Result<Vec<u8>> { |
| let mut buf: Vec<u8> = Vec::new(); |
| { |
| let mut writer = StreamWriter::try_new(&mut buf, schema)?; |
| writer.finish()?; |
| } |
| Ok(buf) |
| } |
| |
| /// Decode an IPC stream containing only a schema message back into a |
| /// `Schema`. Inverse: [`schema_to_ipc_bytes`]. |
| fn schema_from_ipc_bytes(bytes: &[u8]) -> arrow::error::Result<Schema> { |
| let reader = StreamReader::try_new(std::io::Cursor::new(bytes), None)?; |
| Ok(reader.schema().as_ref().clone()) |
| } |
| |
| /// Extract the per-arg `DataType`s from a `Signature` known to be |
| /// `TypeSignature::Exact` (all Python-defined UDFs are constructed |
| /// with `Signature::exact`). Any other variant indicates the impl was |
| /// not built by this crate's UDF/UDAF/UDWF constructors. |
| fn signature_input_dtypes(signature: &Signature, kind: &str) -> PyResult<Vec<DataType>> { |
| match &signature.type_signature { |
| TypeSignature::Exact(types) => Ok(types.clone()), |
| other => Err(pyo3::exceptions::PyValueError::new_err(format!( |
| "{kind} expected Signature::Exact, got {other:?}" |
| ))), |
| } |
| } |
| |
| /// Wrap per-arg `DataType`s in synthetic `arg_{i}` fields and emit |
| /// the IPC schema blob the encoder writes into the cloudpickle tuple. |
| /// |
| /// The names and `nullable: true` are arbitrary: the underlying |
| /// `TypeSignature::Exact` carries no per-input nullability or |
| /// metadata, and the receiver collapses these fields back to |
| /// `Vec<DataType>` via [`read_input_dtypes`], so anything set here |
| /// beyond the data type is discarded on decode. |
| fn build_input_schema_bytes(dtypes: &[DataType]) -> PyResult<Vec<u8>> { |
| let fields: Vec<Field> = dtypes |
| .iter() |
| .enumerate() |
| .map(|(i, dt)| Field::new(format!("arg_{i}"), dt.clone(), true)) |
| .collect(); |
| schema_to_ipc_bytes(&Schema::new(fields)).map_err(arrow_to_py_err) |
| } |
| |
| /// Emit a single-field IPC schema blob. Used for return-type and |
| /// state-field payloads where the receiver needs to recover field |
| /// metadata (names, nullability, key/value attributes) verbatim. |
| fn build_single_field_schema_bytes(field: &Field) -> PyResult<Vec<u8>> { |
| schema_to_ipc_bytes(&Schema::new(vec![field.clone()])).map_err(arrow_to_py_err) |
| } |
| |
| /// Decode the per-arg `DataType`s the encoder wrote via |
| /// [`build_input_schema_bytes`]. |
| fn read_input_dtypes(bytes: &[u8]) -> PyResult<Vec<DataType>> { |
| let schema = schema_from_ipc_bytes(bytes).map_err(arrow_to_py_err)?; |
| Ok(schema |
| .fields() |
| .iter() |
| .map(|f| f.data_type().clone()) |
| .collect()) |
| } |
| |
| /// Decode a single-field IPC schema blob and return that field by |
| /// value. `kind` names the UDF flavor in the error message produced |
| /// when the blob is empty (should be unreachable for sender-side |
| /// payloads built via [`build_single_field_schema_bytes`]). |
| fn read_single_return_field(bytes: &[u8], kind: &str) -> PyResult<Field> { |
| let schema = schema_from_ipc_bytes(bytes).map_err(arrow_to_py_err)?; |
| let field = schema.fields().first().ok_or_else(|| { |
| pyo3::exceptions::PyValueError::new_err(format!( |
| "{kind} return schema must contain exactly one field" |
| )) |
| })?; |
| Ok(field.as_ref().clone()) |
| } |
| |
| fn arrow_to_py_err(e: arrow::error::ArrowError) -> PyErr { |
| pyo3::exceptions::PyValueError::new_err(format!("{e}")) |
| } |
| |
| fn parse_volatility_str(s: &str) -> PyResult<Volatility> { |
| datafusion_python_util::parse_volatility(s) |
| .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}"))) |
| } |
| |
| /// Stable wire-format string for a `Volatility`. Pinned to the three |
| /// tokens [`datafusion_python_util::parse_volatility`] accepts, so an |
| /// upstream change to `Volatility`'s `Debug` repr cannot silently |
| /// produce bytes the decoder rejects. |
| fn volatility_wire_str(v: Volatility) -> &'static str { |
| match v { |
| Volatility::Immutable => "immutable", |
| Volatility::Stable => "stable", |
| Volatility::Volatile => "volatile", |
| } |
| } |
| |
| /// Read the interpreter's `sys.version_info` as `(major, minor)`. |
| /// |
| /// Used by encoder/decoder to stamp and verify the Python version a |
| /// cloudpickle payload was produced on. cloudpickle is not portable |
| /// across Python minor versions; the wire header carries these bytes |
| /// so a mismatch surfaces an actionable error instead of an opaque |
| /// `marshal` failure at `cloudpickle.loads` time. |
| fn current_python_version(py: Python<'_>) -> PyResult<(u8, u8)> { |
| let version_info = py.import("sys")?.getattr("version_info")?; |
| let major: u8 = version_info.getattr("major")?.extract()?; |
| let minor: u8 = version_info.getattr("minor")?.extract()?; |
| Ok((major, minor)) |
| } |
| |
| /// Cached handle to the `cloudpickle` module. |
| /// |
| /// The encode/decode helpers above would otherwise re-resolve the |
| /// module on every call. `py.import` is backed by `sys.modules` and |
| /// therefore cheap, but each call still walks a dict and re-binds the |
| /// result; a plan with many Python UDFs pays that cost per UDF. |
| /// |
| /// `PyOnceLock` scopes the cached `Py<PyAny>` to the current |
| /// interpreter, so the slot drops cleanly on interpreter teardown |
| /// (relevant under CPython subinterpreters, PEP 684) instead of |
| /// resurrecting a `Py` rooted in a dead interpreter on the next call. |
| fn cloudpickle<'py>(py: Python<'py>) -> PyResult<Bound<'py, PyAny>> { |
| static CLOUDPICKLE: PyOnceLock<Py<PyAny>> = PyOnceLock::new(); |
| CLOUDPICKLE |
| .get_or_try_init(py, || Ok(py.import("cloudpickle")?.unbind().into_any())) |
| .map(|cached| cached.bind(py).clone()) |
| } |
| |
#[cfg(test)]
mod wire_header_tests {
    use super::*;

    const TEST_PY: (u8, u8) = (3, 12);

    #[test]
    fn strip_returns_none_when_family_absent() {
        let buf = b"OTHER_PAYLOAD";
        assert!(matches!(
            strip_wire_header(buf, PY_SCALAR_UDF_FAMILY, "scalar UDF", TEST_PY),
            Ok(None)
        ));
    }

    #[test]
    fn strip_returns_none_for_empty_buf() {
        assert!(matches!(
            strip_wire_header(b"", PY_SCALAR_UDF_FAMILY, "scalar UDF", TEST_PY),
            Ok(None)
        ));
    }

    #[test]
    fn strip_returns_none_for_partial_family_prefix() {
        // A buffer shorter than the family magic cannot carry our frame.
        let partial = &PY_SCALAR_UDF_FAMILY[..PY_SCALAR_UDF_FAMILY.len() - 1];
        assert!(matches!(
            strip_wire_header(partial, PY_SCALAR_UDF_FAMILY, "scalar UDF", TEST_PY),
            Ok(None)
        ));
    }

    #[test]
    fn strip_errors_on_truncated_version_byte() {
        let buf = PY_SCALAR_UDF_FAMILY;
        let err = strip_wire_header(buf, PY_SCALAR_UDF_FAMILY, "scalar UDF", TEST_PY).unwrap_err();
        assert!(format!("{err}").contains("missing wire-format version byte"));
    }

    #[test]
    fn strip_errors_on_too_new_version() {
        let mut buf = PY_SCALAR_UDF_FAMILY.to_vec();
        buf.push(WIRE_VERSION_CURRENT.saturating_add(1));
        buf.push(TEST_PY.0);
        buf.push(TEST_PY.1);
        buf.extend_from_slice(b"payload");
        let err = strip_wire_header(&buf, PY_SCALAR_UDF_FAMILY, "scalar UDF", TEST_PY).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("wire-format version v"));
        assert!(msg.contains("supports"));
        assert!(msg.contains("Align datafusion-python versions"));
    }

    #[test]
    fn strip_errors_on_too_old_version() {
        // Nothing predates version 0, so there is no too-old case to test then.
        let Some(too_old) = WIRE_VERSION_MIN_SUPPORTED.checked_sub(1) else {
            return;
        };
        let mut buf = PY_SCALAR_UDF_FAMILY.to_vec();
        buf.push(too_old);
        buf.push(TEST_PY.0);
        buf.push(TEST_PY.1);
        buf.extend_from_slice(b"payload");
        let err = strip_wire_header(&buf, PY_SCALAR_UDF_FAMILY, "scalar UDF", TEST_PY).unwrap_err();
        assert!(format!("{err}").contains(&format!("wire-format version v{too_old}")));
    }

    #[test]
    fn strip_errors_on_truncated_py_major() {
        let mut buf = PY_SCALAR_UDF_FAMILY.to_vec();
        buf.push(WIRE_VERSION_CURRENT);
        let err = strip_wire_header(&buf, PY_SCALAR_UDF_FAMILY, "scalar UDF", TEST_PY).unwrap_err();
        assert!(format!("{err}").contains("missing Python major version byte"));
    }

    #[test]
    fn strip_errors_on_truncated_py_minor() {
        let mut buf = PY_SCALAR_UDF_FAMILY.to_vec();
        buf.push(WIRE_VERSION_CURRENT);
        buf.push(TEST_PY.0);
        let err = strip_wire_header(&buf, PY_SCALAR_UDF_FAMILY, "scalar UDF", TEST_PY).unwrap_err();
        assert!(format!("{err}").contains("missing Python minor version byte"));
    }

    #[test]
    fn strip_errors_on_py_minor_mismatch() {
        let mut buf = Vec::new();
        write_wire_header(&mut buf, PY_SCALAR_UDF_FAMILY, (3, 11));
        buf.extend_from_slice(b"payload");
        let err = strip_wire_header(&buf, PY_SCALAR_UDF_FAMILY, "scalar UDF", (3, 12)).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("Python 3.11"));
        assert!(msg.contains("Python 3.12"));
        assert!(msg.contains("not portable across Python minor versions"));
    }

    #[test]
    fn strip_errors_on_py_major_mismatch() {
        let mut buf = Vec::new();
        write_wire_header(&mut buf, PY_SCALAR_UDF_FAMILY, (3, 12));
        buf.extend_from_slice(b"payload");
        let err = strip_wire_header(&buf, PY_SCALAR_UDF_FAMILY, "scalar UDF", (4, 0)).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("Python 3.12"));
        assert!(msg.contains("Python 4.0"));
    }

    #[test]
    fn write_then_strip_round_trips_payload() {
        let mut buf = Vec::new();
        write_wire_header(&mut buf, PY_SCALAR_UDF_FAMILY, TEST_PY);
        buf.extend_from_slice(b"scalar-payload");

        let payload = strip_wire_header(&buf, PY_SCALAR_UDF_FAMILY, "scalar UDF", TEST_PY)
            .unwrap()
            .unwrap();
        assert_eq!(payload, b"scalar-payload");
    }

    #[test]
    fn strip_accepts_header_only_frame_as_empty_payload() {
        let mut buf = Vec::new();
        write_wire_header(&mut buf, PY_SCALAR_UDF_FAMILY, TEST_PY);

        let payload = strip_wire_header(&buf, PY_SCALAR_UDF_FAMILY, "scalar UDF", TEST_PY)
            .unwrap()
            .unwrap();
        assert!(payload.is_empty());
    }

    #[test]
    fn write_wire_header_appends_after_existing_bytes() {
        let mut buf = b"existing".to_vec();
        write_wire_header(&mut buf, PY_SCALAR_UDF_FAMILY, TEST_PY);

        let mut expected = b"existing".to_vec();
        expected.extend_from_slice(PY_SCALAR_UDF_FAMILY);
        expected.extend_from_slice(&[WIRE_VERSION_CURRENT, TEST_PY.0, TEST_PY.1]);
        assert_eq!(buf, expected);
    }
}