blob: 2006401db33e64d7124c1ee17b028b8485fbe55b [file]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use std::hash::{Hash, Hasher};
use std::ptr::NonNull;
use std::sync::Arc;
use arrow::datatypes::{Field, FieldRef};
use arrow::pyarrow::ToPyArrow;
use datafusion::arrow::array::{ArrayData, make_array};
use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::pyarrow::{FromPyArrow, PyArrowType};
use datafusion::common::internal_err;
use datafusion::error::DataFusionError;
use datafusion::logical_expr::{
ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
Volatility,
};
use datafusion_ffi::udf::FFI_ScalarUDF;
use datafusion_python_util::parse_volatility;
use pyo3::prelude::*;
use pyo3::types::{PyCapsule, PyTuple};
use crate::array::PyArrowArrayExportable;
use crate::errors::{PyDataFusionResult, to_datafusion_err};
use crate::expr::PyExpr;
/// A scalar UDF whose implementation is a Python callable.
///
/// The struct implements `ScalarUDFImpl` (see the impl in this file), so
/// it can be wrapped in a DataFusion `ScalarUDF`.
#[derive(Debug)]
pub(crate) struct PythonFunctionScalarUDF {
/// Name reported through `ScalarUDFImpl::name`.
name: String,
/// Python callable invoked with one PyArrow array per argument.
func: Py<PyAny>,
/// Exact-match signature built from the input data types and volatility.
signature: Signature,
/// Field handed back from `return_field_from_args`.
return_field: FieldRef,
}
impl PythonFunctionScalarUDF {
    /// Build a UDF from the Arrow fields supplied by the Python-facing
    /// constructor.
    ///
    /// Only the data type of each input field is kept; the rest of the
    /// construction is shared with [`Self::from_parts`].
    fn new(
        name: String,
        func: Py<PyAny>,
        input_fields: Vec<Field>,
        return_field: Field,
        volatility: Volatility,
    ) -> Self {
        let input_types: Vec<DataType> = input_fields
            .iter()
            .map(|field| field.data_type().clone())
            .collect();
        Self::from_parts(name, func, input_types, return_field, volatility)
    }

    /// Stored Python callable. Consumed by the codec to cloudpickle
    /// the function body across process boundaries.
    pub(crate) fn func(&self) -> &Py<PyAny> {
        &self.func
    }

    /// Field (data type, nullability, metadata) this UDF reports as its
    /// return value.
    pub(crate) fn return_field(&self) -> &FieldRef {
        &self.return_field
    }

    /// Reconstruct a `PythonFunctionScalarUDF` from the parts emitted
    /// by the codec.
    ///
    /// Inputs arrive as plain `Vec<DataType>` because `Signature::exact`
    /// cannot carry per-input nullability or metadata, so the encoder is
    /// free to discard that side of the schema. `return_field` stays a
    /// full `Field` so the post-decode nullability and metadata match
    /// the sender's instance.
    pub(crate) fn from_parts(
        name: String,
        func: Py<PyAny>,
        input_types: Vec<DataType>,
        return_field: Field,
        volatility: Volatility,
    ) -> Self {
        let signature = Signature::exact(input_types, volatility);
        let return_field = Arc::new(return_field);
        Self {
            name,
            func,
            signature,
            return_field,
        }
    }
}
// Marker impl promising that `PartialEq::eq` is a full equivalence
// relation. Reflexivity of the callable comparison holds because `eq`
// accepts identical Python object pointers without calling `__eq__`.
// NOTE(review): symmetry and transitivity for *distinct* callables depend
// on the Python objects' own `__eq__` being well behaved.
impl Eq for PythonFunctionScalarUDF {}
impl PartialEq for PythonFunctionScalarUDF {
    /// Two UDFs are equal when their name, signature and return field
    /// match and their Python callables are either the same object or
    /// compare equal under Python `__eq__`.
    fn eq(&self, other: &Self) -> bool {
        // Header fields are compared first: they are cheap and need no GIL.
        if self.name != other.name
            || self.signature != other.signature
            || self.return_field != other.return_field
        {
            return false;
        }

        // Identical pointers mean the same Python object. Most equality
        // checks compare `Arc`-shared clones of one UDF (e.g. during
        // expression rewriting), so this answers before touching the GIL.
        if self.func.as_ptr() == other.func.as_ptr() {
            return true;
        }

        Python::attach(|py| {
            match self.func.bind(py).eq(other.func.bind(py)) {
                Ok(same) => same,
                // `PartialEq` cannot return a `Result`, so a raising
                // `__eq__` has to be mapped to a boolean. `false` is the
                // conservative answer: reporting two UDFs as distinct is
                // safer than wrongly merging them, although the silent
                // miss may still show up as expression-dedup or
                // cache-lookup anomalies. It is logged at `debug` so the
                // failure is observable without flooding production logs.
                // FIXME: revisit if upstream `ScalarUDFImpl` exposes a
                // fallible `PartialEq`.
                Err(e) => {
                    log::debug!(
                        target: "datafusion_python::udf",
                        "PythonFunctionScalarUDF {:?} __eq__ raised; treating as unequal: {e}",
                        self.name,
                    );
                    false
                }
            }
        })
    }
}
impl Hash for PythonFunctionScalarUDF {
    /// Feed the identifying header (name, signature, return field) into
    /// `state`; the Python callable is deliberately left out.
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Destructuring every field makes the omission of `func` explicit.
        //
        // Leaving `func` out is sound: the `Hash` contract only demands
        // `a == b ⇒ hash(a) == hash(b)`, never the converse, so a coarser
        // hash is allowed and `PartialEq` still separates UDFs that share
        // a header but wrap different callables. A prior revision hashed
        // the callable and fell back to a sentinel when `py_hash` failed,
        // which silently put every unhashable closure in one bucket — the
        // worst case for a hashmap, and what this approach avoids.
        let Self {
            name,
            signature,
            return_field,
            func: _,
        } = self;

        name.hash(state);
        signature.hash(state);
        return_field.hash(state);
    }
}
impl ScalarUDFImpl for PythonFunctionScalarUDF {
    /// Name the UDF was registered under.
    fn name(&self) -> &str {
        &self.name
    }

    /// Exact-type signature built when the UDF was constructed.
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Always fails with an internal error.
    ///
    /// The return type is reported through `return_field_from_args`
    /// instead, so reaching this method indicates a caller bug.
    fn return_type(&self, _arg_types: &[DataType]) -> datafusion::common::Result<DataType> {
        internal_err!(
            "return_type should not be called when return_field_from_args is implemented."
        )
    }

    /// Return the stored return field, independent of the call-site
    /// arguments.
    fn return_field_from_args(
        &self,
        _args: ReturnFieldArgs,
    ) -> datafusion::common::Result<FieldRef> {
        Ok(Arc::clone(&self.return_field))
    }

    /// Invoke the Python callable on one batch of arguments.
    ///
    /// Each argument is converted to an array via `to_array(number_rows)`,
    /// exported to PyArrow together with its field, and passed
    /// positionally to the callable. The returned Python object is
    /// imported back as Arrow array data and always surfaced as
    /// `ColumnarValue::Array`.
    ///
    /// # Errors
    ///
    /// Returns an error when an argument cannot be converted or exported,
    /// when the Python callable raises (reported as
    /// `DataFusionError::Execution`), or when its result cannot be
    /// imported as Arrow array data.
    fn invoke_with_args(
        &self,
        args: ScalarFunctionArgs,
    ) -> datafusion::common::Result<ColumnarValue> {
        let num_rows = args.number_rows;
        Python::attach(|py| {
            // 1. Export every argument as a PyArrow array.
            let py_args = args
                .args
                .into_iter()
                .zip(args.arg_fields)
                .map(|(arg, field)| {
                    let array = arg.to_array(num_rows)?;
                    PyArrowArrayExportable::new(array, field)
                        .to_pyarrow(py)
                        .map_err(to_datafusion_err)
                })
                .collect::<Result<Vec<_>, _>>()?;
            let py_args = PyTuple::new(py, py_args).map_err(to_datafusion_err)?;

            // 2. Call the Python function with positional arguments only.
            let value = self
                .func
                .call(py, py_args, None)
                .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;

            // 3. Import the Python result back into an Arrow array.
            let array_data = ArrayData::from_pyarrow_bound(value.bind(py))
                .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
            Ok(ColumnarValue::Array(make_array(array_data)))
        })
    }
}
// Python-facing wrapper around a DataFusion `ScalarUDF`.
//
// The `#[pyclass]` attributes below expose it to Python as `ScalarUDF` in
// module `datafusion`, mark it frozen, and allow subclassing.
/// Represents a PyScalarUDF
#[pyclass(
from_py_object,
frozen,
name = "ScalarUDF",
module = "datafusion",
subclass
)]
#[derive(Debug, Clone)]
pub struct PyScalarUDF {
// The wrapped DataFusion scalar UDF; built either from a Python callable
// (`new`) or from an FFI capsule (`from_pycapsule`).
pub(crate) function: ScalarUDF,
}
#[pymethods]
impl PyScalarUDF {
    // Python constructor: wraps the callable `func` in a
    // `PythonFunctionScalarUDF`. Fails if `volatility` cannot be parsed.
    #[new]
    #[pyo3(signature=(name, func, input_types, return_type, volatility))]
    fn new(
        name: String,
        func: Py<PyAny>,
        input_types: PyArrowType<Vec<Field>>,
        return_type: PyArrowType<Field>,
        volatility: &str,
    ) -> PyResult<Self> {
        let volatility = parse_volatility(volatility)?;
        let input_fields = input_types.0;
        let return_field = return_type.0;
        let udf_impl =
            PythonFunctionScalarUDF::new(name, func, input_fields, return_field, volatility);
        Ok(Self {
            function: ScalarUDF::new_from_impl(udf_impl),
        })
    }

    // Build a `ScalarUDF` from any Python object that exposes a
    // `__datafusion_scalar_udf__()` method returning a PyCapsule named
    // "datafusion_scalar_udf". Objects without that attribute are rejected.
    #[staticmethod]
    pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
        if !func.hasattr("__datafusion_scalar_udf__")? {
            return Err(crate::errors::PyDataFusionError::Common(
                "__datafusion_scalar_udf__ does not exist on ScalarUDF object.".to_string(),
            ));
        }

        let exported = func.getattr("__datafusion_scalar_udf__")?.call0()?;
        let capsule = exported.cast::<PyCapsule>().map_err(to_datafusion_err)?;
        let ffi_ptr: NonNull<FFI_ScalarUDF> = capsule
            .pointer_checked(Some(c"datafusion_scalar_udf"))?
            .cast();
        // SAFETY: `pointer_checked` returned successfully for the capsule
        // name "datafusion_scalar_udf", and `exported` keeps the capsule
        // alive for the duration of this borrow. This relies on the
        // capsule producer having stored a valid `FFI_ScalarUDF` under
        // that name, which cannot be verified from here.
        let ffi_udf = unsafe { ffi_ptr.as_ref() };
        let shared_impl: Arc<dyn ScalarUDFImpl> = ffi_udf.into();

        Ok(Self {
            function: ScalarUDF::new_from_shared_impl(shared_impl),
        })
    }

    /// creates a new PyExpr with the call of the udf
    #[pyo3(signature = (*args))]
    fn __call__(&self, args: Vec<PyExpr>) -> PyResult<PyExpr> {
        let call_args = args.iter().map(|py_expr| py_expr.expr.clone()).collect();
        let call_expr = self.function.call(call_args);
        Ok(call_expr.into())
    }

    // Renders as `ScalarUDF(<name>)`.
    fn __repr__(&self) -> PyResult<String> {
        let udf_name = self.function.name();
        Ok(format!("ScalarUDF({})", udf_name))
    }

    // Name of the wrapped UDF, exposed as a read-only Python property.
    #[getter]
    fn name(&self) -> &str {
        self.function.name()
    }
}