| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| use std::hash::{Hash, Hasher}; |
| use std::ptr::NonNull; |
| use std::sync::Arc; |
| |
| use arrow::datatypes::{Field, FieldRef}; |
| use arrow::pyarrow::ToPyArrow; |
| use datafusion::arrow::array::{ArrayData, make_array}; |
| use datafusion::arrow::datatypes::DataType; |
| use datafusion::arrow::pyarrow::{FromPyArrow, PyArrowType}; |
| use datafusion::common::internal_err; |
| use datafusion::error::DataFusionError; |
| use datafusion::logical_expr::{ |
| ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, |
| Volatility, |
| }; |
| use datafusion_ffi::udf::FFI_ScalarUDF; |
| use datafusion_python_util::parse_volatility; |
| use pyo3::prelude::*; |
| use pyo3::types::{PyCapsule, PyTuple}; |
| |
| use crate::array::PyArrowArrayExportable; |
| use crate::errors::{PyDataFusionResult, to_datafusion_err}; |
| use crate::expr::PyExpr; |
| |
| /// This struct holds the Python written function that is a |
| /// ScalarUDF. |
| #[derive(Debug)] |
| pub(crate) struct PythonFunctionScalarUDF { |
| name: String, |
| func: Py<PyAny>, |
| signature: Signature, |
| return_field: FieldRef, |
| } |
| |
| impl PythonFunctionScalarUDF { |
| fn new( |
| name: String, |
| func: Py<PyAny>, |
| input_fields: Vec<Field>, |
| return_field: Field, |
| volatility: Volatility, |
| ) -> Self { |
| let input_types = input_fields.iter().map(|f| f.data_type().clone()).collect(); |
| let signature = Signature::exact(input_types, volatility); |
| Self { |
| name, |
| func, |
| signature, |
| return_field: Arc::new(return_field), |
| } |
| } |
| |
| /// Stored Python callable. Consumed by the codec to cloudpickle |
| /// the function body across process boundaries. |
| pub(crate) fn func(&self) -> &Py<PyAny> { |
| &self.func |
| } |
| |
| pub(crate) fn return_field(&self) -> &FieldRef { |
| &self.return_field |
| } |
| |
| /// Reconstruct a `PythonFunctionScalarUDF` from the parts emitted |
| /// by the codec. Inputs collapse to `Vec<DataType>` because |
| /// `Signature::exact` cannot carry per-input nullability or |
| /// metadata — the encoder is free to discard that side of the |
| /// schema. `return_field` is kept as a `Field` so the post-decode |
| /// nullability and metadata match the sender's instance. |
| pub(crate) fn from_parts( |
| name: String, |
| func: Py<PyAny>, |
| input_types: Vec<DataType>, |
| return_field: Field, |
| volatility: Volatility, |
| ) -> Self { |
| Self { |
| name, |
| func, |
| signature: Signature::exact(input_types, volatility), |
| return_field: Arc::new(return_field), |
| } |
| } |
| } |
| |
| impl Eq for PythonFunctionScalarUDF {} |
| impl PartialEq for PythonFunctionScalarUDF { |
| fn eq(&self, other: &Self) -> bool { |
| self.name == other.name |
| && self.signature == other.signature |
| && self.return_field == other.return_field |
| // Identical pointers ⇒ same Python object. Most equality |
| // checks compare `Arc`-shared clones of the same UDF |
| // (e.g. expression rewriting), so the pointer match short- |
| // circuits before touching the GIL. |
| && (self.func.as_ptr() == other.func.as_ptr() |
| || Python::attach(|py| { |
| // Rust's `PartialEq` cannot return `Result`, so we |
| // have to pick a side when Python `__eq__` raises. |
| // `false` is the conservative choice — better to |
| // report two UDFs as distinct than to wrongly |
| // merge them — but the silent miss can still |
| // surface as expression-dedup or cache-lookup |
| // anomalies. Log at `debug` so the failure is |
| // observable without flooding production logs. |
| // FIXME: revisit if upstream `ScalarUDFImpl` |
| // exposes a fallible `PartialEq`. |
| self.func |
| .bind(py) |
| .eq(other.func.bind(py)) |
| .unwrap_or_else(|e| { |
| log::debug!( |
| target: "datafusion_python::udf", |
| "PythonFunctionScalarUDF {:?} __eq__ raised; treating as unequal: {e}", |
| self.name, |
| ); |
| false |
| }) |
| })) |
| } |
| } |
| |
| impl Hash for PythonFunctionScalarUDF { |
| fn hash<H: Hasher>(&self, state: &mut H) { |
| // Hash only the identifying header (name + signature + return |
| // field). Skipping `func` is intentional: the Rust `Hash` |
| // contract requires `a == b ⇒ hash(a) == hash(b)`, not the |
| // converse, so a coarser hash is sound — `PartialEq` still |
| // disambiguates two UDFs with the same header but distinct |
| // callables. Falling back to a sentinel on `py_hash` failure |
| // (as a prior revision did) silently mapped every unhashable |
| // closure to the same bucket; that is the worst case for a |
| // hashmap and is what this rewrite avoids. |
| self.name.hash(state); |
| self.signature.hash(state); |
| self.return_field.hash(state); |
| } |
| } |
| |
| impl ScalarUDFImpl for PythonFunctionScalarUDF { |
| fn name(&self) -> &str { |
| &self.name |
| } |
| |
| fn signature(&self) -> &Signature { |
| &self.signature |
| } |
| |
| fn return_type(&self, _arg_types: &[DataType]) -> datafusion::common::Result<DataType> { |
| internal_err!( |
| "return_field should not be called when return_field_from_args is implemented." |
| ) |
| } |
| |
| fn return_field_from_args( |
| &self, |
| _args: ReturnFieldArgs, |
| ) -> datafusion::common::Result<FieldRef> { |
| Ok(Arc::clone(&self.return_field)) |
| } |
| |
| fn invoke_with_args( |
| &self, |
| args: ScalarFunctionArgs, |
| ) -> datafusion::common::Result<ColumnarValue> { |
| let num_rows = args.number_rows; |
| Python::attach(|py| { |
| // 1. cast args to Pyarrow arrays |
| let py_args = args |
| .args |
| .into_iter() |
| .zip(args.arg_fields) |
| .map(|(arg, field)| { |
| let array = arg.to_array(num_rows)?; |
| PyArrowArrayExportable::new(array, field) |
| .to_pyarrow(py) |
| .map_err(to_datafusion_err) |
| }) |
| .collect::<Result<Vec<_>, _>>()?; |
| let py_args = PyTuple::new(py, py_args).map_err(to_datafusion_err)?; |
| |
| // 2. call function |
| let value = self |
| .func |
| .call(py, py_args, None) |
| .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?; |
| |
| // 3. cast to arrow::array::Array |
| let array_data = ArrayData::from_pyarrow_bound(value.bind(py)) |
| .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?; |
| Ok(ColumnarValue::Array(make_array(array_data))) |
| }) |
| } |
| } |
| |
| /// Represents a PyScalarUDF |
| #[pyclass( |
| from_py_object, |
| frozen, |
| name = "ScalarUDF", |
| module = "datafusion", |
| subclass |
| )] |
| #[derive(Debug, Clone)] |
| pub struct PyScalarUDF { |
| pub(crate) function: ScalarUDF, |
| } |
| |
| #[pymethods] |
| impl PyScalarUDF { |
| #[new] |
| #[pyo3(signature=(name, func, input_types, return_type, volatility))] |
| fn new( |
| name: String, |
| func: Py<PyAny>, |
| input_types: PyArrowType<Vec<Field>>, |
| return_type: PyArrowType<Field>, |
| volatility: &str, |
| ) -> PyResult<Self> { |
| let py_function = PythonFunctionScalarUDF::new( |
| name, |
| func, |
| input_types.0, |
| return_type.0, |
| parse_volatility(volatility)?, |
| ); |
| let function = ScalarUDF::new_from_impl(py_function); |
| |
| Ok(Self { function }) |
| } |
| |
| #[staticmethod] |
| pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult<Self> { |
| if func.hasattr("__datafusion_scalar_udf__")? { |
| let capsule = func.getattr("__datafusion_scalar_udf__")?.call0()?; |
| let capsule = capsule.cast::<PyCapsule>().map_err(to_datafusion_err)?; |
| let data: NonNull<FFI_ScalarUDF> = capsule |
| .pointer_checked(Some(c"datafusion_scalar_udf"))? |
| .cast(); |
| let udf = unsafe { data.as_ref() }; |
| let udf: Arc<dyn ScalarUDFImpl> = udf.into(); |
| |
| Ok(Self { |
| function: ScalarUDF::new_from_shared_impl(udf), |
| }) |
| } else { |
| Err(crate::errors::PyDataFusionError::Common( |
| "__datafusion_scalar_udf__ does not exist on ScalarUDF object.".to_string(), |
| )) |
| } |
| } |
| |
| /// creates a new PyExpr with the call of the udf |
| #[pyo3(signature = (*args))] |
| fn __call__(&self, args: Vec<PyExpr>) -> PyResult<PyExpr> { |
| let args = args.iter().map(|e| e.expr.clone()).collect(); |
| Ok(self.function.call(args).into()) |
| } |
| |
| fn __repr__(&self) -> PyResult<String> { |
| Ok(format!("ScalarUDF({})", self.function.name())) |
| } |
| |
| #[getter] |
| fn name(&self) -> &str { |
| self.function.name() |
| } |
| } |