Add missing array functions (#551)
* Add array_append, array_concat and array_cat
* Add tests for array functions array_append, array_concat and array_cat
* Add array_dims and list_dims
* Add tests for array_dims and list_dims
* Add array_element, array_extract, list_element and list_extract
* Add tests for array_element, array_extract, list_element and list_extract
* Add array_length and list_length
diff --git a/datafusion/tests/test_functions.py b/datafusion/tests/test_functions.py
index be2a2f1..d0514f8 100644
--- a/datafusion/tests/test_functions.py
+++ b/datafusion/tests/test_functions.py
@@ -25,6 +25,8 @@
from datafusion import functions as f
from datafusion import literal
+np.seterr(invalid="ignore")
+
@pytest.fixture
def df():
@@ -197,6 +199,68 @@
)
+def test_array_functions():
+ data = [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]]
+ ctx = SessionContext()
+ batch = pa.RecordBatch.from_arrays(
+ [np.array(data, dtype=object)], names=["arr"]
+ )
+ df = ctx.create_dataframe([[batch]])
+
+ col = column("arr")
+ test_items = [
+ [
+ f.array_append(col, literal(99.0)),
+ lambda: [np.append(arr, 99.0) for arr in data],
+ ],
+ [
+ f.array_concat(col, col),
+ lambda: [np.concatenate([arr, arr]) for arr in data],
+ ],
+ [
+ f.array_cat(col, col),
+ lambda: [np.concatenate([arr, arr]) for arr in data],
+ ],
+ [
+ f.array_dims(col),
+ lambda: [[len(r)] for r in data],
+ ],
+ [
+ f.list_dims(col),
+ lambda: [[len(r)] for r in data],
+ ],
+ [
+ f.array_element(col, literal(1)),
+ lambda: [r[0] for r in data],
+ ],
+ [
+ f.array_extract(col, literal(1)),
+ lambda: [r[0] for r in data],
+ ],
+ [
+ f.list_element(col, literal(1)),
+ lambda: [r[0] for r in data],
+ ],
+ [
+ f.list_extract(col, literal(1)),
+ lambda: [r[0] for r in data],
+ ],
+ [
+ f.array_length(col),
+ lambda: [len(r) for r in data],
+ ],
+ [
+ f.list_length(col),
+ lambda: [len(r) for r in data],
+ ],
+ ]
+
+ for stmt, py_expr in test_items:
+ query_result = df.select(stmt).collect()[0].column(0).tolist()
+ for a, b in zip(query_result, py_expr()):
+ np.testing.assert_array_almost_equal(a, b)
+
+
def test_string_functions(df):
df = df.select(
f.ascii(column("a")),
diff --git a/src/functions.rs b/src/functions.rs
index d1f3e80..3dc5322 100644
--- a/src/functions.rs
+++ b/src/functions.rs
@@ -357,6 +357,19 @@
scalar_function!(encode, Encode);
scalar_function!(decode, Decode);
+// Array Functions
+// Each line passes a function name and a built-in variant to the
+// `scalar_function!` macro (defined elsewhere in this file). Names that are
+// given the same variant are aliases of one another:
+//   array_concat / array_cat                                    -> ArrayConcat
+//   array_dims / list_dims                                      -> ArrayDims
+//   array_element / array_extract / list_element / list_extract -> ArrayElement
+//   array_length / list_length                                  -> ArrayLength
+scalar_function!(array_append, ArrayAppend);
+scalar_function!(array_concat, ArrayConcat);
+scalar_function!(array_cat, ArrayConcat);
+scalar_function!(array_dims, ArrayDims);
+scalar_function!(list_dims, ArrayDims);
+scalar_function!(array_element, ArrayElement);
+scalar_function!(array_extract, ArrayElement);
+scalar_function!(list_element, ArrayElement);
+scalar_function!(list_extract, ArrayElement);
+scalar_function!(array_length, ArrayLength);
+scalar_function!(list_length, ArrayLength);
+
aggregate_function!(approx_distinct, ApproxDistinct);
aggregate_function!(approx_median, ApproxMedian);
aggregate_function!(approx_percentile_cont, ApproxPercentileCont);
@@ -546,5 +559,19 @@
//Binary String Functions
m.add_wrapped(wrap_pyfunction!(encode))?;
m.add_wrapped(wrap_pyfunction!(decode))?;
+
+ // Array Functions
+ m.add_wrapped(wrap_pyfunction!(array_append))?;
+ m.add_wrapped(wrap_pyfunction!(array_concat))?;
+ m.add_wrapped(wrap_pyfunction!(array_cat))?;
+ m.add_wrapped(wrap_pyfunction!(array_dims))?;
+ m.add_wrapped(wrap_pyfunction!(list_dims))?;
+ m.add_wrapped(wrap_pyfunction!(array_element))?;
+ m.add_wrapped(wrap_pyfunction!(array_extract))?;
+ m.add_wrapped(wrap_pyfunction!(list_element))?;
+ m.add_wrapped(wrap_pyfunction!(list_extract))?;
+ m.add_wrapped(wrap_pyfunction!(array_length))?;
+ m.add_wrapped(wrap_pyfunction!(list_length))?;
+
Ok(())
}