UDAF `sum` workaround (#741)
* provides a workaround for the half-migrated UDAF `sum`
Ref #730
* provides compatibility with `sqlparser::ast::NullTreatment`
This is now exposed as part of the API of the `first_value` and `last_value` functions.
If there's a more elegant way to achieve this, please let me know.
diff --git a/Cargo.lock b/Cargo.lock
index 41742da..f05c62e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1051,6 +1051,7 @@
"pyo3-build-config",
"rand",
"regex-syntax",
+ "sqlparser",
"syn 2.0.67",
"tokio",
"url",
diff --git a/Cargo.toml b/Cargo.toml
index 4e38211..e518449 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -56,6 +56,7 @@
regex-syntax = "0.8.1"
syn = "2.0.67"
url = "2.2"
+sqlparser = "0.47.0"
[build-dependencies]
pyo3-build-config = "0.21"
diff --git a/examples/tpch/_tests.py b/examples/tpch/_tests.py
index aa9491b..3f973d9 100644
--- a/examples/tpch/_tests.py
+++ b/examples/tpch/_tests.py
@@ -74,7 +74,6 @@
("q10_returned_item_reporting", "q10"),
pytest.param(
"q11_important_stock_identification", "q11",
- marks=pytest.mark.xfail # https://github.com/apache/datafusion-python/issues/730
),
("q12_ship_mode_order_priority", "q12"),
("q13_customer_distribution", "q13"),
diff --git a/src/common.rs b/src/common.rs
index 44c557c..094e70c 100644
--- a/src/common.rs
+++ b/src/common.rs
@@ -29,6 +29,7 @@
m.add_class::<data_type::DataTypeMap>()?;
m.add_class::<data_type::PythonType>()?;
m.add_class::<data_type::SqlType>()?;
+ m.add_class::<data_type::NullTreatment>()?;
m.add_class::<schema::SqlTable>()?;
m.add_class::<schema::SqlSchema>()?;
m.add_class::<schema::SqlView>()?;
diff --git a/src/common/data_type.rs b/src/common/data_type.rs
index cd4f864..313318f 100644
--- a/src/common/data_type.rs
+++ b/src/common/data_type.rs
@@ -757,3 +757,33 @@
VARBINARY,
VARCHAR,
}
+
+/// Specifies whether to ignore or respect NULL values within window functions.
+/// For example:
+/// `FIRST_VALUE(column2) IGNORE NULLS OVER (PARTITION BY column1)`
+#[allow(non_camel_case_types)]
+#[allow(clippy::upper_case_acronyms)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
+#[pyclass(name = "NullTreatment", module = "datafusion.common")]
+pub enum NullTreatment {
+ IGNORE_NULLS,
+ RESPECT_NULLS,
+}
+
+impl From<NullTreatment> for sqlparser::ast::NullTreatment {
+ fn from(null_treatment: NullTreatment) -> sqlparser::ast::NullTreatment {
+ match null_treatment {
+ NullTreatment::IGNORE_NULLS => sqlparser::ast::NullTreatment::IgnoreNulls,
+ NullTreatment::RESPECT_NULLS => sqlparser::ast::NullTreatment::RespectNulls,
+ }
+ }
+}
+
+impl From<sqlparser::ast::NullTreatment> for NullTreatment {
+ fn from(null_treatment: sqlparser::ast::NullTreatment) -> NullTreatment {
+ match null_treatment {
+ sqlparser::ast::NullTreatment::IgnoreNulls => NullTreatment::IGNORE_NULLS,
+ sqlparser::ast::NullTreatment::RespectNulls => NullTreatment::RESPECT_NULLS,
+ }
+ }
+}
diff --git a/src/functions.rs b/src/functions.rs
index 8e395ae..b39d98b 100644
--- a/src/functions.rs
+++ b/src/functions.rs
@@ -17,6 +17,7 @@
use pyo3::{prelude::*, wrap_pyfunction};
+use crate::common::data_type::NullTreatment;
use crate::context::PySessionContext;
use crate::errors::DataFusionError;
use crate::expr::conditional_expr::PyCaseBuilder;
@@ -73,15 +74,15 @@
}
#[pyfunction]
-#[pyo3(signature = (*args, distinct = false, filter = None, order_by = None))]
+#[pyo3(signature = (*args, distinct = false, filter = None, order_by = None, null_treatment = None))]
pub fn first_value(
args: Vec<PyExpr>,
distinct: bool,
filter: Option<PyExpr>,
order_by: Option<Vec<PyExpr>>,
+ null_treatment: Option<NullTreatment>,
) -> PyExpr {
- // TODO: allow user to select null_treatment
- let null_treatment = None;
+ let null_treatment = null_treatment.map(Into::into);
let args = args.into_iter().map(|x| x.expr).collect::<Vec<_>>();
let order_by = order_by.map(|x| x.into_iter().map(|x| x.expr).collect::<Vec<_>>());
functions_aggregate::expr_fn::first_value(
@@ -95,15 +96,15 @@
}
#[pyfunction]
-#[pyo3(signature = (*args, distinct = false, filter = None, order_by = None))]
+#[pyo3(signature = (*args, distinct = false, filter = None, order_by = None, null_treatment = None))]
pub fn last_value(
args: Vec<PyExpr>,
distinct: bool,
filter: Option<PyExpr>,
order_by: Option<Vec<PyExpr>>,
+ null_treatment: Option<NullTreatment>,
) -> PyExpr {
- // TODO: allow user to select null_treatment
- let null_treatment = None;
+ let null_treatment = null_treatment.map(Into::into);
let args = args.into_iter().map(|x| x.expr).collect::<Vec<_>>();
let order_by = order_by.map(|x| x.into_iter().map(|x| x.expr).collect::<Vec<_>>());
functions_aggregate::expr_fn::last_value(
@@ -320,14 +321,20 @@
window_frame: Option<PyWindowFrame>,
ctx: Option<PySessionContext>,
) -> PyResult<PyExpr> {
- let fun = find_df_window_func(name).or_else(|| {
- ctx.and_then(|ctx| {
- ctx.ctx
- .udaf(name)
- .map(WindowFunctionDefinition::AggregateUDF)
- .ok()
+ // workaround for https://github.com/apache/datafusion-python/issues/730
+ let fun = if name == "sum" {
+ let sum_udf = functions_aggregate::sum::sum_udaf();
+ Some(WindowFunctionDefinition::AggregateUDF(sum_udf))
+ } else {
+ find_df_window_func(name).or_else(|| {
+ ctx.and_then(|ctx| {
+ ctx.ctx
+ .udaf(name)
+ .map(WindowFunctionDefinition::AggregateUDF)
+ .ok()
+ })
})
- });
+ };
if fun.is_none() {
return Err(DataFusionError::Common("window function not found".to_string()).into());
}