feat: expose` DataFrame.parse_sql_expr` (#1274)
* feat: expose DataFrame.parse_sql_expr to python
* Update python/tests/test_dataframe.py
Co-authored-by: Tim Saucer <timsaucer@gmail.com>
---------
Co-authored-by: Tim Saucer <timsaucer@gmail.com>
diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py
index 1676565..86131c4 100644
--- a/python/datafusion/dataframe.py
+++ b/python/datafusion/dataframe.py
@@ -482,6 +482,28 @@
df = df.filter(ensure_expr(p))
return DataFrame(df)
+ def parse_sql_expr(self, expr: str) -> Expr:
+ """Creates logical expression from a SQL query text.
+
+ The expression is created and processed against the current schema.
+
+ Example::
+
+ from datafusion import col, lit
+ df.parse_sql_expr("a > 1")
+
+ should produce:
+
+ col("a") > lit(1)
+
+ Args:
+ expr: Expression string to be converted to datafusion expression
+
+ Returns:
+ Logical expression .
+ """
+ return Expr(self.df.parse_sql_expr(expr))
+
def with_column(self, name: str, expr: Expr) -> DataFrame:
"""Add an additional column to the DataFrame.
diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py
index cd85221..76b8080 100644
--- a/python/tests/test_dataframe.py
+++ b/python/tests/test_dataframe.py
@@ -274,6 +274,36 @@
assert result.column(2) == pa.array([5])
+def test_parse_sql_expr(df):
+ plan1 = df.filter(df.parse_sql_expr("a > 2")).logical_plan()
+ plan2 = df.filter(column("a") > literal(2)).logical_plan()
+ # object equality not implemented but string representation should match
+ assert str(plan1) == str(plan2)
+
+ df1 = df.filter(df.parse_sql_expr("a > 2")).select(
+ column("a") + column("b"),
+ column("a") - column("b"),
+ )
+
+ # execute and collect the first (and only) batch
+ result = df1.collect()[0]
+
+ assert result.column(0) == pa.array([9])
+ assert result.column(1) == pa.array([-3])
+
+ df.show()
+ # verify that if there is no filter applied, internal dataframe is unchanged
+ df2 = df.filter()
+ assert df.df == df2.df
+
+ df3 = df.filter(df.parse_sql_expr("a > 1"), df.parse_sql_expr("b != 6"))
+ result = df3.collect()[0]
+
+ assert result.column(0) == pa.array([2])
+ assert result.column(1) == pa.array([5])
+ assert result.column(2) == pa.array([5])
+
+
def test_show_empty(df, capsys):
df_empty = df.filter(column("a") > literal(3))
df_empty.show()
diff --git a/src/dataframe.rs b/src/dataframe.rs
index c23c0c9..34da874 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -454,6 +454,14 @@
Ok(Self::new(df))
}
+ fn parse_sql_expr(&self, expr: PyBackedStr) -> PyDataFusionResult<PyExpr> {
+ self.df
+ .as_ref()
+ .parse_sql_expr(&expr)
+ .map(|e| PyExpr::from(e))
+ .map_err(PyDataFusionError::from)
+ }
+
fn with_column(&self, name: &str, expr: PyExpr) -> PyDataFusionResult<Self> {
let df = self.df.as_ref().clone().with_column(name, expr.into())?;
Ok(Self::new(df))