Support None comparisons for null expressions (#1489)
* Support None comparisons for null expressions: `expr == None` now returns `expr.is_null()` and `expr != None` returns `expr.is_not_null()`
* Fold None comparison coverage into relational expr test
diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py
index 7cd74ec..3200465 100644
--- a/python/datafusion/expr.py
+++ b/python/datafusion/expr.py
@@ -483,6 +483,8 @@
Accepts either an expression or any valid PyArrow scalar literal value.
"""
+ if rhs is None:
+ return self.is_null()
if not isinstance(rhs, Expr):
rhs = Expr.literal(rhs)
return Expr(self.expr.__eq__(rhs.expr))
@@ -492,6 +494,8 @@
Accepts either an expression or any valid PyArrow scalar literal value.
"""
+ if rhs is None:
+ return self.is_not_null()
if not isinstance(rhs, Expr):
rhs = Expr.literal(rhs)
return Expr(self.expr.__ne__(rhs.expr))
diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py
index 1cf824a..d046eb4 100644
--- a/python/tests/test_expr.py
+++ b/python/tests/test_expr.py
@@ -153,8 +153,8 @@
batch = pa.RecordBatch.from_arrays(
[
- pa.array([1, 2, 3]),
- pa.array(["alpha", "beta", "gamma"], type=pa.string_view()),
+ pa.array([1, 2, 3, None]),
+ pa.array(["alpha", "beta", "gamma", None], type=pa.string_view()),
],
names=["a", "b"],
)
@@ -171,6 +171,10 @@
assert df.filter(col("b") != "beta").count() == 2
assert df.filter(col("a") == "beta").count() == 0
+ assert df.filter(col("a") == None).count() == 1 # noqa: E711
+ assert df.filter(col("a") != None).count() == 3 # noqa: E711
+ assert df.filter(col("b") == None).count() == 1 # noqa: E711
+ assert df.filter(col("b") != None).count() == 3 # noqa: E711
def test_expr_to_variant():