# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
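"""Tests for DataFusion's Python function bindings: literals, math, array, string,
hash, temporal, conditional, regression, and cast functions."""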
import math
from datetime import datetime, timezone

import numpy as np
import pyarrow as pa
import pytest

from datafusion import SessionContext, column, literal, string_literal
from datafusion import functions as f

# ignore NumPy "invalid value" warnings from NaN-producing reference ops
# (e.g. arccosh of values < 1)
np.seterr(invalid="ignore")

DEFAULT_TZ = timezone.utc


@pytest.fixture
def df():
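    """Return a DataFrame with string_view (a, c), int (b), timestamp (d), and bool (e) columns."""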
ctx = SessionContext()
# create a RecordBatch and a new DataFrame from it
batch = pa.RecordBatch.from_arrays(
[
pa.array(["Hello", "World", "!"], type=pa.string_view()),
pa.array([4, 5, 6]),
pa.array(["hello ", " world ", " !"], type=pa.string_view()),
pa.array(
[
datetime(2022, 12, 31, tzinfo=DEFAULT_TZ),
datetime(2027, 6, 26, tzinfo=DEFAULT_TZ),
datetime(2020, 7, 2, tzinfo=DEFAULT_TZ),
]
),
pa.array([False, True, True]),
],
names=["a", "b", "c", "d", "e"],
)
return ctx.create_dataframe([[batch]])
def test_named_struct(df):
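    """Pack columns a, b, and c into a struct column with f.named_struct."""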
df = df.with_column(
"d",
f.named_struct(
[
("a", column("a")),
("b", column("b")),
("c", column("c")),
]
),
)
expected = """DataFrame()
+-------+---+---------+------------------------------+-------+
| a | b | c | d | e |
+-------+---+---------+------------------------------+-------+
| Hello | 4 | hello | {a: Hello, b: 4, c: hello } | false |
| World | 5 | world | {a: World, b: 5, c: world } | true |
| ! | 6 | ! | {a: !, b: 6, c: !} | true |
+-------+---+---------+------------------------------+-------+
""".strip()
assert str(df) == expected
def test_literal(df):
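    """Test literal() with int, str, float, bool, and bytes values."""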
df = df.select(
literal(1),
literal("1"),
literal("OK"),
literal(3.14),
literal(value=True),
literal(b"hello world"),
)
result = df.collect()
assert len(result) == 1
result = result[0]
assert result.column(0) == pa.array([1] * 3)
assert result.column(1) == pa.array(["1"] * 3, type=pa.string_view())
assert result.column(2) == pa.array(["OK"] * 3, type=pa.string_view())
assert result.column(3) == pa.array([3.14] * 3)
assert result.column(4) == pa.array([True] * 3)
assert result.column(5) == pa.array([b"hello world"] * 3)
def test_lit_arith(df):
"""Test literals with arithmetic operations"""
df = df.select(
literal(1) + column("b"), f.concat(column("a").cast(pa.string()), literal("!"))
)
result = df.collect()
assert len(result) == 1
result = result[0]
assert result.column(0) == pa.array([5, 6, 7])
assert result.column(1) == pa.array(
["Hello!", "World!", "!!"], type=pa.string_view()
)
def test_math_functions():
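    """Compare DataFusion math functions against NumPy (and math.factorial) references."""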
ctx = SessionContext()
# create a RecordBatch and a new DataFrame from it
batch = pa.RecordBatch.from_arrays(
[pa.array([0.1, -0.7, 0.55]), pa.array([float("nan"), 0, 2.0])],
names=["value", "na_value"],
)
df = ctx.create_dataframe([[batch]])
values = np.array([0.1, -0.7, 0.55])
na_values = np.array([np.nan, 0, 2.0])
col_v = column("value")
col_nav = column("na_value")
df = df.select(
f.abs(col_v),
f.sin(col_v),
f.cos(col_v),
f.tan(col_v),
f.asin(col_v),
f.acos(col_v),
f.exp(col_v),
f.ln(col_v + literal(pa.scalar(1))),
f.log2(col_v + literal(pa.scalar(1))),
f.log10(col_v + literal(pa.scalar(1))),
f.random(),
f.atan(col_v),
f.atan2(col_v, literal(pa.scalar(1.1))),
f.ceil(col_v),
f.floor(col_v),
f.power(col_v, literal(pa.scalar(3))),
f.pow(col_v, literal(pa.scalar(4))),
f.round(col_v),
f.round(col_v, literal(pa.scalar(3))),
f.sqrt(col_v),
f.signum(col_v),
f.trunc(col_v),
f.asinh(col_v),
f.acosh(col_v),
f.atanh(col_v),
f.cbrt(col_v),
f.cosh(col_v),
f.degrees(col_v),
f.gcd(literal(9), literal(3)),
f.lcm(literal(6), literal(4)),
f.nanvl(col_nav, literal(5)),
f.pi(),
f.radians(col_v),
f.sinh(col_v),
f.tanh(col_v),
f.factorial(literal(6)),
f.isnan(col_nav),
f.iszero(col_nav),
f.log(literal(3), col_v + literal(pa.scalar(1))),
)
batches = df.collect()
assert len(batches) == 1
result = batches[0]
np.testing.assert_array_almost_equal(result.column(0), np.abs(values))
np.testing.assert_array_almost_equal(result.column(1), np.sin(values))
np.testing.assert_array_almost_equal(result.column(2), np.cos(values))
np.testing.assert_array_almost_equal(result.column(3), np.tan(values))
np.testing.assert_array_almost_equal(result.column(4), np.arcsin(values))
np.testing.assert_array_almost_equal(result.column(5), np.arccos(values))
np.testing.assert_array_almost_equal(result.column(6), np.exp(values))
np.testing.assert_array_almost_equal(result.column(7), np.log(values + 1.0))
np.testing.assert_array_almost_equal(result.column(8), np.log2(values + 1.0))
np.testing.assert_array_almost_equal(result.column(9), np.log10(values + 1.0))
np.testing.assert_array_less(result.column(10), np.ones_like(values))
np.testing.assert_array_almost_equal(result.column(11), np.arctan(values))
np.testing.assert_array_almost_equal(result.column(12), np.arctan2(values, 1.1))
np.testing.assert_array_almost_equal(result.column(13), np.ceil(values))
np.testing.assert_array_almost_equal(result.column(14), np.floor(values))
np.testing.assert_array_almost_equal(result.column(15), np.power(values, 3))
np.testing.assert_array_almost_equal(result.column(16), np.power(values, 4))
np.testing.assert_array_almost_equal(result.column(17), np.round(values))
np.testing.assert_array_almost_equal(result.column(18), np.round(values, 3))
np.testing.assert_array_almost_equal(result.column(19), np.sqrt(values))
np.testing.assert_array_almost_equal(result.column(20), np.sign(values))
np.testing.assert_array_almost_equal(result.column(21), np.trunc(values))
np.testing.assert_array_almost_equal(result.column(22), np.arcsinh(values))
np.testing.assert_array_almost_equal(result.column(23), np.arccosh(values))
np.testing.assert_array_almost_equal(result.column(24), np.arctanh(values))
np.testing.assert_array_almost_equal(result.column(25), np.cbrt(values))
np.testing.assert_array_almost_equal(result.column(26), np.cosh(values))
np.testing.assert_array_almost_equal(result.column(27), np.degrees(values))
np.testing.assert_array_almost_equal(result.column(28), np.gcd(9, 3))
np.testing.assert_array_almost_equal(result.column(29), np.lcm(6, 4))
np.testing.assert_array_almost_equal(
result.column(30), np.where(np.isnan(na_values), 5, na_values)
)
np.testing.assert_array_almost_equal(result.column(31), np.pi)
np.testing.assert_array_almost_equal(result.column(32), np.radians(values))
np.testing.assert_array_almost_equal(result.column(33), np.sinh(values))
np.testing.assert_array_almost_equal(result.column(34), np.tanh(values))
np.testing.assert_array_almost_equal(result.column(35), math.factorial(6))
np.testing.assert_array_almost_equal(result.column(36), np.isnan(na_values))
np.testing.assert_array_almost_equal(result.column(37), na_values == 0)
np.testing.assert_array_almost_equal(
result.column(38), np.emath.logn(3, values + 1.0)
)
def py_indexof(arr, v):
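    """Return the 1-based position of v in arr, or NaN if absent (reference for array_position)."""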
try:
return arr.index(v) + 1
except ValueError:
return np.nan
def py_arr_remove(arr, v, n=None):
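    """Remove up to n occurrences of v from arr; n=None removes all (reference for array_remove*)."""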
new_arr = arr[:]
found = 0
try:
while found != n:
new_arr.remove(v)
found += 1
except ValueError:
pass
return new_arr
def py_arr_replace(arr, from_, to, n=None):
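    """Replace up to n occurrences of from_ with to; n=None replaces all (reference for array_replace*)."""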
new_arr = arr[:]
found = 0
try:
while found != n:
idx = new_arr.index(from_)
new_arr[idx] = to
found += 1
except ValueError:
pass
return new_arr
def py_arr_resize(arr, size, value):
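    """Pad arr with `value` up to `size` elements (reference for array_resize)."""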
arr = np.asarray(arr)
return np.pad(
arr,
[(0, size - arr.shape[0])],
"constant",
constant_values=value,
)
def py_flatten(arr):
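    """Recursively flatten nested lists (reference for f.flatten)."""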
result = []
for elem in arr:
if isinstance(elem, list):
result.extend(py_flatten(elem))
else:
result.append(elem)
return result
@pytest.mark.parametrize(
("stmt", "py_expr"),
[
(
lambda col: f.array_append(col, literal(99.0)),
lambda data: [np.append(arr, 99.0) for arr in data],
),
(
lambda col: f.array_push_back(col, literal(99.0)),
lambda data: [np.append(arr, 99.0) for arr in data],
),
(
lambda col: f.list_append(col, literal(99.0)),
lambda data: [np.append(arr, 99.0) for arr in data],
),
(
lambda col: f.list_push_back(col, literal(99.0)),
lambda data: [np.append(arr, 99.0) for arr in data],
),
(
lambda col: f.array_concat(col, col),
lambda data: [np.concatenate([arr, arr]) for arr in data],
),
(
lambda col: f.array_cat(col, col),
lambda data: [np.concatenate([arr, arr]) for arr in data],
),
(
lambda col: f.list_cat(col, col),
lambda data: [np.concatenate([arr, arr]) for arr in data],
),
(
lambda col: f.list_concat(col, col),
lambda data: [np.concatenate([arr, arr]) for arr in data],
),
(
lambda col: f.array_dims(col),
lambda data: [[len(r)] for r in data],
),
(
lambda col: f.array_distinct(col),
lambda data: [list(set(r)) for r in data],
),
(
lambda col: f.list_distinct(col),
lambda data: [list(set(r)) for r in data],
),
(
lambda col: f.list_dims(col),
lambda data: [[len(r)] for r in data],
),
(
lambda col: f.array_element(col, literal(1)),
lambda data: [r[0] for r in data],
),
(
lambda col: f.array_empty(col),
lambda data: [len(r) == 0 for r in data],
),
(
lambda col: f.empty(col),
lambda data: [len(r) == 0 for r in data],
),
(
lambda col: f.array_extract(col, literal(1)),
lambda data: [r[0] for r in data],
),
(
lambda col: f.list_element(col, literal(1)),
lambda data: [r[0] for r in data],
),
(
lambda col: f.list_extract(col, literal(1)),
lambda data: [r[0] for r in data],
),
(
lambda col: f.array_length(col),
lambda data: [len(r) for r in data],
),
(
lambda col: f.list_length(col),
lambda data: [len(r) for r in data],
),
(
lambda col: f.array_has(col, literal(1.0)),
lambda data: [1.0 in r for r in data],
),
(
lambda col: f.array_has_all(
col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]])
),
lambda data: [np.all([v in r for v in [1.0, 3.0, 5.0]]) for r in data],
),
(
lambda col: f.array_has_any(
col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]])
),
lambda data: [np.any([v in r for v in [1.0, 3.0, 5.0]]) for r in data],
),
(
lambda col: f.array_position(col, literal(1.0)),
lambda data: [py_indexof(r, 1.0) for r in data],
),
(
lambda col: f.array_indexof(col, literal(1.0)),
lambda data: [py_indexof(r, 1.0) for r in data],
),
(
lambda col: f.list_position(col, literal(1.0)),
lambda data: [py_indexof(r, 1.0) for r in data],
),
(
lambda col: f.list_indexof(col, literal(1.0)),
lambda data: [py_indexof(r, 1.0) for r in data],
),
(
lambda col: f.array_positions(col, literal(1.0)),
lambda data: [[i + 1 for i, _v in enumerate(r) if _v == 1.0] for r in data],
),
(
lambda col: f.list_positions(col, literal(1.0)),
lambda data: [[i + 1 for i, _v in enumerate(r) if _v == 1.0] for r in data],
),
(
lambda col: f.array_ndims(col),
lambda data: [np.array(r).ndim for r in data],
),
(
lambda col: f.list_ndims(col),
lambda data: [np.array(r).ndim for r in data],
),
(
lambda col: f.array_prepend(literal(99.0), col),
lambda data: [np.insert(arr, 0, 99.0) for arr in data],
),
(
lambda col: f.array_push_front(literal(99.0), col),
lambda data: [np.insert(arr, 0, 99.0) for arr in data],
),
(
lambda col: f.list_prepend(literal(99.0), col),
lambda data: [np.insert(arr, 0, 99.0) for arr in data],
),
(
lambda col: f.list_push_front(literal(99.0), col),
lambda data: [np.insert(arr, 0, 99.0) for arr in data],
),
(
lambda col: f.array_pop_back(col),
lambda data: [arr[:-1] for arr in data],
),
(
lambda col: f.array_pop_front(col),
lambda data: [arr[1:] for arr in data],
),
(
lambda col: f.array_remove(col, literal(3.0)),
lambda data: [py_arr_remove(arr, 3.0, 1) for arr in data],
),
(
lambda col: f.list_remove(col, literal(3.0)),
lambda data: [py_arr_remove(arr, 3.0, 1) for arr in data],
),
(
lambda col: f.array_remove_n(col, literal(3.0), literal(2)),
lambda data: [py_arr_remove(arr, 3.0, 2) for arr in data],
),
(
lambda col: f.list_remove_n(col, literal(3.0), literal(2)),
lambda data: [py_arr_remove(arr, 3.0, 2) for arr in data],
),
(
lambda col: f.array_remove_all(col, literal(3.0)),
lambda data: [py_arr_remove(arr, 3.0) for arr in data],
),
(
lambda col: f.list_remove_all(col, literal(3.0)),
lambda data: [py_arr_remove(arr, 3.0) for arr in data],
),
(
lambda col: f.array_repeat(col, literal(2)),
lambda data: [[arr] * 2 for arr in data],
),
(
lambda col: f.list_repeat(col, literal(2)),
lambda data: [[arr] * 2 for arr in data],
),
(
lambda col: f.array_replace(col, literal(3.0), literal(4.0)),
lambda data: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
),
(
lambda col: f.list_replace(col, literal(3.0), literal(4.0)),
lambda data: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
),
(
lambda col: f.array_replace_n(col, literal(3.0), literal(4.0), literal(1)),
lambda data: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
),
(
lambda col: f.list_replace_n(col, literal(3.0), literal(4.0), literal(2)),
lambda data: [py_arr_replace(arr, 3.0, 4.0, 2) for arr in data],
),
(
lambda col: f.array_replace_all(col, literal(3.0), literal(4.0)),
lambda data: [py_arr_replace(arr, 3.0, 4.0) for arr in data],
),
(
lambda col: f.list_replace_all(col, literal(3.0), literal(4.0)),
lambda data: [py_arr_replace(arr, 3.0, 4.0) for arr in data],
),
(
lambda col: f.array_sort(col, descending=True, null_first=True),
lambda data: [np.sort(arr)[::-1] for arr in data],
),
(
lambda col: f.list_sort(col, descending=False, null_first=False),
lambda data: [np.sort(arr) for arr in data],
),
(
lambda col: f.array_slice(col, literal(2), literal(4)),
lambda data: [arr[1:4] for arr in data],
),
pytest.param(
lambda col: f.list_slice(col, literal(-1), literal(2)),
lambda data: [arr[-1:2] for arr in data],
),
(
lambda col: f.array_intersect(col, literal([3.0, 4.0])),
lambda data: [np.intersect1d(arr, [3.0, 4.0]) for arr in data],
),
(
lambda col: f.list_intersect(col, literal([3.0, 4.0])),
lambda data: [np.intersect1d(arr, [3.0, 4.0]) for arr in data],
),
(
lambda col: f.array_union(col, literal([12.0, 999.0])),
lambda data: [np.union1d(arr, [12.0, 999.0]) for arr in data],
),
(
lambda col: f.list_union(col, literal([12.0, 999.0])),
lambda data: [np.union1d(arr, [12.0, 999.0]) for arr in data],
),
(
lambda col: f.array_except(col, literal([3.0])),
lambda data: [np.setdiff1d(arr, [3.0]) for arr in data],
),
(
lambda col: f.list_except(col, literal([3.0])),
lambda data: [np.setdiff1d(arr, [3.0]) for arr in data],
),
(
lambda col: f.array_resize(col, literal(10), literal(0.0)),
lambda data: [py_arr_resize(arr, 10, 0.0) for arr in data],
),
(
lambda col: f.list_resize(col, literal(10), literal(0.0)),
lambda data: [py_arr_resize(arr, 10, 0.0) for arr in data],
),
(
lambda col: f.range(literal(1), literal(5), literal(2)),
lambda data: [np.arange(1, 5, 2)],
),
],
)
def test_array_functions(stmt, py_expr):
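    """Evaluate each parametrized array expression and compare it row by row with its Python reference."""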
data = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 3.0], [6.0]]
ctx = SessionContext()
batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)], names=["arr"])
df = ctx.create_dataframe([[batch]])
col = column("arr")
query_result = df.select(stmt(col)).collect()[0].column(0)
for a, b in zip(query_result, py_expr(data)):
np.testing.assert_array_almost_equal(
np.array(a.as_py(), dtype=float), np.array(b, dtype=float)
)
def test_array_function_flatten():
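    """Test flatten() on a nested list literal against the py_flatten reference."""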
data = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 3.0], [6.0]]
ctx = SessionContext()
batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)], names=["arr"])
df = ctx.create_dataframe([[batch]])
stmt = f.flatten(literal(data))
py_expr = [py_flatten(data)]
query_result = df.select(stmt).collect()[0].column(0)
for a, b in zip(query_result, py_expr):
np.testing.assert_array_almost_equal(
np.array(a.as_py(), dtype=float), np.array(b, dtype=float)
)
def test_array_function_cardinality():
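    """Test cardinality() against per-row Python len()."""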
data = [[1, 2, 3], [4, 4, 5, 6]]
ctx = SessionContext()
batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)], names=["arr"])
df = ctx.create_dataframe([[batch]])
stmt = f.cardinality(column("arr"))
    py_expr = [len(arr) for arr in data]  # expected cardinalities: [3, 4]
query_result = df.select(stmt).collect()[0].column(0)
for a, b in zip(query_result, py_expr):
np.testing.assert_array_equal(
np.array([a.as_py()], dtype=int), np.array([b], dtype=int)
)
@pytest.mark.parametrize("make_func", [f.make_array, f.make_list])
def test_make_array_functions(make_func):
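    """Test make_array/make_list building a list column from three cast columns."""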
ctx = SessionContext()
batch = pa.RecordBatch.from_arrays(
[
pa.array(["Hello", "World", "!"], type=pa.string()),
pa.array([4, 5, 6]),
pa.array(["hello ", " world ", " !"], type=pa.string()),
],
names=["a", "b", "c"],
)
df = ctx.create_dataframe([[batch]])
stmt = make_func(
column("a").cast(pa.string()),
column("b").cast(pa.string()),
column("c").cast(pa.string()),
)
py_expr = [
["Hello", "4", "hello "],
["World", "5", " world "],
["!", "6", " !"],
]
query_result = df.select(stmt).collect()[0].column(0)
for a, b in zip(query_result, py_expr):
np.testing.assert_array_equal(
np.array(a.as_py(), dtype=str), np.array(b, dtype=str)
)
@pytest.mark.parametrize(
("stmt", "py_expr"),
[
(
f.array_to_string(column("arr"), literal(",")),
lambda data: [",".join([str(int(v)) for v in r]) for r in data],
),
(
f.array_join(column("arr"), literal(",")),
lambda data: [",".join([str(int(v)) for v in r]) for r in data],
),
(
f.list_to_string(column("arr"), literal(",")),
lambda data: [",".join([str(int(v)) for v in r]) for r in data],
),
(
f.list_join(column("arr"), literal(",")),
lambda data: [",".join([str(int(v)) for v in r]) for r in data],
),
],
)
def test_array_function_obj_tests(stmt, py_expr):
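    """Check the array-to-string join functions against a Python ','.join reference."""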
data = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 3.0], [6.0]]
ctx = SessionContext()
batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)], names=["arr"])
df = ctx.create_dataframe([[batch]])
query_result = np.array(df.select(stmt).collect()[0].column(0))
for a, b in zip(query_result, py_expr(data)):
assert a == b
@pytest.mark.parametrize(
("function", "expected_result"),
[
(
f.ascii(column("a")),
pa.array([72, 87, 33], type=pa.int32()),
), # H = 72; W = 87; ! = 33
(
f.bit_length(column("a").cast(pa.string())),
pa.array([40, 40, 8], type=pa.int32()),
),
(
f.btrim(literal(" World ")),
pa.array(["World", "World", "World"], type=pa.string_view()),
),
(f.character_length(column("a")), pa.array([5, 5, 1], type=pa.int32())),
(f.chr(literal(68)), pa.array(["D", "D", "D"])),
(
f.concat_ws("-", column("a"), literal("test")),
pa.array(["Hello-test", "World-test", "!-test"]),
),
(
f.concat(column("a").cast(pa.string()), literal("?")),
pa.array(["Hello?", "World?", "!?"], type=pa.string_view()),
),
(
f.initcap(column("c")),
pa.array(["Hello ", " World ", " !"], type=pa.string_view()),
),
(f.left(column("a"), literal(3)), pa.array(["Hel", "Wor", "!"])),
(f.length(column("c")), pa.array([6, 7, 2], type=pa.int32())),
(f.lower(column("a")), pa.array(["hello", "world", "!"])),
(f.lpad(column("a"), literal(7)), pa.array([" Hello", " World", " !"])),
(
f.ltrim(column("c")),
pa.array(["hello ", "world ", "!"], type=pa.string_view()),
),
(
f.md5(column("a")),
pa.array(
[
"8b1a9953c4611296a827abf8c47804d7",
"f5a7924e621e84c9280a9a27e1bcb7f6",
"9033e0e305f247c0c3c80d0c7848c8b3",
]
),
),
(f.octet_length(column("a")), pa.array([5, 5, 1], type=pa.int32())),
(
f.repeat(column("a"), literal(2)),
pa.array(["HelloHello", "WorldWorld", "!!"]),
),
(
f.replace(column("a"), literal("l"), literal("?")),
pa.array(["He??o", "Wor?d", "!"]),
),
(f.reverse(column("a")), pa.array(["olleH", "dlroW", "!"])),
(f.right(column("a"), literal(4)), pa.array(["ello", "orld", "!"])),
(
f.rpad(column("a"), literal(8)),
pa.array(["Hello ", "World ", "! "]),
),
(
f.rtrim(column("c")),
pa.array(["hello", " world", " !"], type=pa.string_view()),
),
(
f.split_part(column("a"), literal("l"), literal(1)),
pa.array(["He", "Wor", "!"]),
),
(f.starts_with(column("a"), literal("Wor")), pa.array([False, True, False])),
(f.strpos(column("a"), literal("o")), pa.array([5, 2, 0], type=pa.int32())),
(
f.substr(column("a"), literal(3)),
pa.array(["llo", "rld", ""], type=pa.string_view()),
),
(
f.translate(column("a"), literal("or"), literal("ld")),
pa.array(["Helll", "Wldld", "!"]),
),
(f.trim(column("c")), pa.array(["hello", "world", "!"], type=pa.string_view())),
(f.upper(column("c")), pa.array(["HELLO ", " WORLD ", " !"])),
(f.ends_with(column("a"), literal("llo")), pa.array([True, False, False])),
(
f.overlay(column("a"), literal("--"), literal(2)),
pa.array(["H--lo", "W--ld", "--"]),
),
(
f.regexp_like(column("a"), literal("(ell|orl)")),
pa.array([True, True, False]),
),
(
f.regexp_match(column("a"), literal("(ell|orl)")),
pa.array([["ell"], ["orl"], None], type=pa.list_(pa.string_view())),
),
(
f.regexp_replace(column("a"), literal("(ell|orl)"), literal("-")),
pa.array(["H-o", "W-d", "!"]),
),
(
f.regexp_count(column("a"), literal("(ell|orl)"), literal(1)),
pa.array([1, 1, 0], type=pa.int64()),
),
],
)
def test_string_functions(df, function, expected_result):
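    """Evaluate each parametrized string expression and compare it with the expected Arrow array."""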
df = df.select(function)
result = df.collect()
assert len(result) == 1
result = result[0]
assert result.column(0) == expected_result
def test_hash_functions(df):
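    """Check f.digest for several algorithms against known digests, and that the sha* helpers agree with digest."""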
exprs = [
f.digest(column("a"), literal(m))
for m in (
"md5",
"sha224",
"sha256",
"sha384",
"sha512",
"blake2s",
"blake3",
)
]
df = df.select(
*exprs,
f.sha224(column("a")),
f.sha256(column("a")),
f.sha384(column("a")),
f.sha512(column("a")),
)
result = df.collect()
assert len(result) == 1
result = result[0]
b = bytearray.fromhex
assert result.column(0) == pa.array(
[
b("8B1A9953C4611296A827ABF8C47804D7"),
b("F5A7924E621E84C9280A9A27E1BCB7F6"),
b("9033E0E305F247C0C3C80D0C7848C8B3"),
]
)
assert result.column(1) == pa.array(
[
b("4149DA18AA8BFC2B1E382C6C26556D01A92C261B6436DAD5E3BE3FCC"),
b("12972632B6D3B6AA52BD6434552F08C1303D56B817119406466E9236"),
b("6641A7E8278BCD49E476E7ACAE158F4105B2952D22AEB2E0B9A231A0"),
]
)
assert result.column(2) == pa.array(
[
b("185F8DB32271FE25F561A6FC938B2E264306EC304EDA518007D1764826381969"),
b("78AE647DC5544D227130A0682A51E30BC7777FBB6D8A8F17007463A3ECD1D524"),
b("BB7208BC9B5D7C04F1236A82A0093A5E33F40423D5BA8D4266F7092C3BA43B62"),
]
)
assert result.column(3) == pa.array(
[
b(
"3519FE5AD2C596EFE3E276A6F351B8FC"
"0B03DB861782490D45F7598EBD0AB5FD"
"5520ED102F38C4A5EC834E98668035FC"
),
b(
"ED7CED84875773603AF90402E42C65F3"
"B48A5E77F84ADC7A19E8F3E8D3101010"
"22F552AEC70E9E1087B225930C1D260A"
),
b(
"1D0EC8C84EE9521E21F06774DE232367"
"B64DE628474CB5B2E372B699A1F55AE3"
"35CC37193EF823E33324DFD9A70738A6"
),
]
)
assert result.column(4) == pa.array(
[
b(
"3615F80C9D293ED7402687F94B22D58E"
"529B8CC7916F8FAC7FDDF7FBD5AF4CF7"
"77D3D795A7A00A16BF7E7F3FB9561EE9"
"BAAE480DA9FE7A18769E71886B03F315"
),
b(
"8EA77393A42AB8FA92500FB077A9509C"
"C32BC95E72712EFA116EDAF2EDFAE34F"
"BB682EFDD6C5DD13C117E08BD4AAEF71"
"291D8AACE2F890273081D0677C16DF0F"
),
b(
"3831A6A6155E509DEE59A7F451EB3532"
"4D8F8F2DF6E3708894740F98FDEE2388"
"9F4DE5ADB0C5010DFB555CDA77C8AB5D"
"C902094C52DE3278F35A75EBC25F093A"
),
]
)
assert result.column(5) == pa.array(
[
b("F73A5FBF881F89B814871F46E26AD3FA37CB2921C5E8561618639015B3CCBB71"),
b("B792A0383FB9E7A189EC150686579532854E44B71AC394831DAED169BA85CCC5"),
b("27988A0E51812297C77A433F635233346AEE29A829DCF4F46E0F58F402C6CFCB"),
]
)
assert result.column(6) == pa.array(
[
b("FBC2B0516EE8744D293B980779178A3508850FDCFE965985782C39601B65794F"),
b("BF73D18575A736E4037D45F9E316085B86C19BE6363DE6AA789E13DEAACC1C4E"),
b("C8D11B9F7237E4034ADBCD2005735F9BC4C597C75AD89F4492BEC8F77D15F7EB"),
]
)
assert result.column(7) == result.column(1) # SHA-224
assert result.column(8) == result.column(2) # SHA-256
assert result.column(9) == result.column(3) # SHA-384
assert result.column(10) == result.column(4) # SHA-512
def test_temporal_functions(df):
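    """Exercise date_part, date_trunc, date_bin, from_unixtime, to_timestamp*, and extract."""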
df = df.select(
f.date_part(literal("month"), column("d")),
f.datepart(literal("year"), column("d")),
f.date_trunc(literal("month"), column("d")),
f.datetrunc(literal("day"), column("d")),
f.date_bin(
literal("15 minutes").cast(pa.string()),
column("d"),
literal("2001-01-01 00:02:30").cast(pa.string()),
),
f.from_unixtime(literal(1673383974)),
f.to_timestamp(literal("2023-09-07 05:06:14.523952")),
f.to_timestamp_seconds(literal("2023-09-07 05:06:14.523952")),
f.to_timestamp_millis(literal("2023-09-07 05:06:14.523952")),
f.to_timestamp_micros(literal("2023-09-07 05:06:14.523952")),
f.extract(literal("day"), column("d")),
f.to_timestamp(
literal("2023-09-07 05:06:14.523952000"), literal("%Y-%m-%d %H:%M:%S.%f")
),
f.to_timestamp_seconds(
literal("2023-09-07 05:06:14.523952000"), literal("%Y-%m-%d %H:%M:%S.%f")
),
f.to_timestamp_millis(
literal("2023-09-07 05:06:14.523952000"), literal("%Y-%m-%d %H:%M:%S.%f")
),
f.to_timestamp_micros(
literal("2023-09-07 05:06:14.523952000"), literal("%Y-%m-%d %H:%M:%S.%f")
),
f.to_timestamp_nanos(literal("2023-09-07 05:06:14.523952")),
f.to_timestamp_nanos(
literal("2023-09-07 05:06:14.523952000"), literal("%Y-%m-%d %H:%M:%S.%f")
),
)
result = df.collect()
assert len(result) == 1
result = result[0]
assert result.column(0) == pa.array([12, 6, 7], type=pa.int32())
assert result.column(1) == pa.array([2022, 2027, 2020], type=pa.int32())
assert result.column(2) == pa.array(
[
datetime(2022, 12, 1, tzinfo=DEFAULT_TZ),
datetime(2027, 6, 1, tzinfo=DEFAULT_TZ),
datetime(2020, 7, 1, tzinfo=DEFAULT_TZ),
],
type=pa.timestamp("ns", tz=DEFAULT_TZ),
)
assert result.column(3) == pa.array(
[
datetime(2022, 12, 31, tzinfo=DEFAULT_TZ),
datetime(2027, 6, 26, tzinfo=DEFAULT_TZ),
datetime(2020, 7, 2, tzinfo=DEFAULT_TZ),
],
type=pa.timestamp("ns", tz=DEFAULT_TZ),
)
assert result.column(4) == pa.array(
[
datetime(2022, 12, 30, 23, 47, 30, tzinfo=DEFAULT_TZ),
datetime(2027, 6, 25, 23, 47, 30, tzinfo=DEFAULT_TZ),
datetime(2020, 7, 1, 23, 47, 30, tzinfo=DEFAULT_TZ),
],
type=pa.timestamp("ns", tz=DEFAULT_TZ),
)
assert result.column(5) == pa.array(
[datetime(2023, 1, 10, 20, 52, 54, tzinfo=DEFAULT_TZ)] * 3,
type=pa.timestamp("s"),
)
assert result.column(6) == pa.array(
[datetime(2023, 9, 7, 5, 6, 14, 523952, tzinfo=DEFAULT_TZ)] * 3,
type=pa.timestamp("ns"),
)
assert result.column(7) == pa.array(
[datetime(2023, 9, 7, 5, 6, 14, tzinfo=DEFAULT_TZ)] * 3, type=pa.timestamp("s")
)
assert result.column(8) == pa.array(
[datetime(2023, 9, 7, 5, 6, 14, 523000, tzinfo=DEFAULT_TZ)] * 3,
type=pa.timestamp("ms"),
)
assert result.column(9) == pa.array(
[datetime(2023, 9, 7, 5, 6, 14, 523952, tzinfo=DEFAULT_TZ)] * 3,
type=pa.timestamp("us"),
)
assert result.column(10) == pa.array([31, 26, 2], type=pa.int32())
assert result.column(11) == pa.array(
[datetime(2023, 9, 7, 5, 6, 14, 523952, tzinfo=DEFAULT_TZ)] * 3,
type=pa.timestamp("ns"),
)
assert result.column(12) == pa.array(
[datetime(2023, 9, 7, 5, 6, 14, tzinfo=DEFAULT_TZ)] * 3,
type=pa.timestamp("s"),
)
assert result.column(13) == pa.array(
[datetime(2023, 9, 7, 5, 6, 14, 523000, tzinfo=DEFAULT_TZ)] * 3,
type=pa.timestamp("ms"),
)
assert result.column(14) == pa.array(
[datetime(2023, 9, 7, 5, 6, 14, 523952, tzinfo=DEFAULT_TZ)] * 3,
type=pa.timestamp("us"),
)
assert result.column(15) == pa.array(
[datetime(2023, 9, 7, 5, 6, 14, 523952, tzinfo=DEFAULT_TZ)] * 3,
type=pa.timestamp("ns"),
)
assert result.column(16) == pa.array(
[datetime(2023, 9, 7, 5, 6, 14, 523952, tzinfo=DEFAULT_TZ)] * 3,
type=pa.timestamp("ns"),
)
def test_arrow_cast(df):
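    """Test arrow_cast casting column b to Float64 and Int32."""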
df = df.select(
        # we use `string_literal`, which returns utf8, instead of `literal`, which
        # returns utf8view, because datafusion.arrow_cast expects utf8 rather than
        # utf8view:
        # https://github.com/apache/datafusion/blob/86740bfd3d9831d6b7c1d0e1bf4a21d91598a0ac/datafusion/functions/src/core/arrow_cast.rs#L179
f.arrow_cast(column("b"), string_literal("Float64")).alias("b_as_float"),
f.arrow_cast(column("b"), string_literal("Int32")).alias("b_as_int"),
)
result = df.collect()
assert len(result) == 1
result = result[0]
assert result.column(0) == pa.array([4.0, 5.0, 6.0], type=pa.float64())
assert result.column(1) == pa.array([4, 5, 6], type=pa.int32())
def test_case(df):
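    """Test f.case with a base expression, terminated by otherwise() and by end()."""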
df = df.select(
f.case(column("b")).when(literal(4), literal(10)).otherwise(literal(8)),
f.case(column("a"))
.when(literal("Hello"), literal("Hola"))
.when(literal("World"), literal("Mundo"))
.otherwise(literal("!!")),
f.case(column("a"))
.when(literal("Hello"), literal("Hola"))
.when(literal("World"), literal("Mundo"))
.end(),
)
result = df.collect()
result = result[0]
assert result.column(0) == pa.array([10, 8, 8])
assert result.column(1) == pa.array(["Hola", "Mundo", "!!"], type=pa.string_view())
assert result.column(2) == pa.array(["Hola", "Mundo", None], type=pa.string_view())
def test_when_with_no_base(df):
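    """Test f.when without a base expression, with otherwise() and with end()."""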
df.show()
df = df.select(
column("b"),
f.when(column("b") > literal(5), literal("too big"))
.when(column("b") < literal(5), literal("too small"))
.otherwise(literal("just right"))
.alias("goldilocks"),
f.when(column("a") == literal("Hello"), column("a")).end().alias("greeting"),
)
df.show()
result = df.collect()
result = result[0]
assert result.column(0) == pa.array([4, 5, 6])
assert result.column(1) == pa.array(
["too small", "just right", "too big"], type=pa.string_view()
)
assert result.column(2) == pa.array(["Hello", None, None], type=pa.string_view())
def test_regr_funcs_sql(df):
    # test case based on
# https://github.com/apache/arrow-datafusion/blob/d1361d56b9a9e0c165d3d71a8df6795d2a5f51dd/datafusion/core/tests/sqllogictests/test_files/aggregate.slt#L2330
ctx = SessionContext()
result = ctx.sql(
"select regr_slope(1,1), regr_intercept(1,1), "
"regr_count(1,1), regr_r2(1,1), regr_avgx(1,1), "
"regr_avgy(1,1), regr_sxx(1,1), regr_syy(1,1), "
"regr_sxy(1,1);"
).collect()
assert result[0].column(0) == pa.array([None], type=pa.float64())
assert result[0].column(1) == pa.array([None], type=pa.float64())
assert result[0].column(2) == pa.array([1], type=pa.uint64())
assert result[0].column(3) == pa.array([None], type=pa.float64())
assert result[0].column(4) == pa.array([1], type=pa.float64())
assert result[0].column(5) == pa.array([1], type=pa.float64())
assert result[0].column(6) == pa.array([0], type=pa.float64())
assert result[0].column(7) == pa.array([0], type=pa.float64())
assert result[0].column(8) == pa.array([0], type=pa.float64())
def test_regr_funcs_sql_2():
    # test case based on the `regr_*()` basic tests at
# https://github.com/apache/datafusion/blob/d1361d56b9a9e0c165d3d71a8df6795d2a5f51dd/datafusion/core/tests/sqllogictests/test_files/aggregate.slt#L2358C1-L2374C1
ctx = SessionContext()
# Perform the regression functions using SQL
result_sql = ctx.sql(
"select "
"regr_slope(column2, column1), "
"regr_intercept(column2, column1), "
"regr_count(column2, column1), "
"regr_r2(column2, column1), "
"regr_avgx(column2, column1), "
"regr_avgy(column2, column1), "
"regr_sxx(column2, column1), "
"regr_syy(column2, column1), "
"regr_sxy(column2, column1) "
"from (values (1,2), (2,4), (3,6))"
).collect()
# Assertions for SQL results
assert result_sql[0].column(0) == pa.array([2], type=pa.float64())
assert result_sql[0].column(1) == pa.array([0], type=pa.float64())
assert result_sql[0].column(2) == pa.array([3], type=pa.uint64())
assert result_sql[0].column(3) == pa.array([1], type=pa.float64())
assert result_sql[0].column(4) == pa.array([2], type=pa.float64())
assert result_sql[0].column(5) == pa.array([4], type=pa.float64())
assert result_sql[0].column(6) == pa.array([2], type=pa.float64())
assert result_sql[0].column(7) == pa.array([8], type=pa.float64())
assert result_sql[0].column(8) == pa.array([4], type=pa.float64())
@pytest.mark.parametrize(
("func", "expected"),
[
pytest.param(f.regr_slope(column("c2"), column("c1")), [4.6], id="regr_slope"),
pytest.param(
f.regr_slope(column("c2"), column("c1"), filter=column("c1") > literal(2)),
[8],
id="regr_slope_filter",
),
pytest.param(
f.regr_intercept(column("c2"), column("c1")), [-4], id="regr_intercept"
),
pytest.param(
f.regr_intercept(
column("c2"), column("c1"), filter=column("c1") > literal(2)
),
[-16],
id="regr_intercept_filter",
),
pytest.param(f.regr_count(column("c2"), column("c1")), [4], id="regr_count"),
pytest.param(
f.regr_count(column("c2"), column("c1"), filter=column("c1") > literal(2)),
[2],
id="regr_count_filter",
),
pytest.param(f.regr_r2(column("c2"), column("c1")), [0.92], id="regr_r2"),
pytest.param(
f.regr_r2(column("c2"), column("c1"), filter=column("c1") > literal(2)),
[1.0],
id="regr_r2_filter",
),
pytest.param(f.regr_avgx(column("c2"), column("c1")), [2.5], id="regr_avgx"),
pytest.param(
f.regr_avgx(column("c2"), column("c1"), filter=column("c1") > literal(2)),
[3.5],
id="regr_avgx_filter",
),
pytest.param(f.regr_avgy(column("c2"), column("c1")), [7.5], id="regr_avgy"),
pytest.param(
f.regr_avgy(column("c2"), column("c1"), filter=column("c1") > literal(2)),
[12.0],
id="regr_avgy_filter",
),
pytest.param(f.regr_sxx(column("c2"), column("c1")), [5.0], id="regr_sxx"),
pytest.param(
f.regr_sxx(column("c2"), column("c1"), filter=column("c1") > literal(2)),
[0.5],
id="regr_sxx_filter",
),
pytest.param(f.regr_syy(column("c2"), column("c1")), [115.0], id="regr_syy"),
pytest.param(
f.regr_syy(column("c2"), column("c1"), filter=column("c1") > literal(2)),
[32.0],
id="regr_syy_filter",
),
pytest.param(f.regr_sxy(column("c2"), column("c1")), [23.0], id="regr_sxy"),
pytest.param(
f.regr_sxy(column("c2"), column("c1"), filter=column("c1") > literal(2)),
[4.0],
id="regr_sxy_filter",
),
],
)
def test_regr_funcs_df(func, expected):
    # test case based on the `regr_*()` basic tests at
# https://github.com/apache/datafusion/blob/d1361d56b9a9e0c165d3d71a8df6795d2a5f51dd/datafusion/core/tests/sqllogictests/test_files/aggregate.slt#L2358C1-L2374C1
ctx = SessionContext()
# Create a DataFrame
data = {"c1": [1, 2, 3, 4, 5, None], "c2": [2, 4, 8, 16, None, 64]}
df = ctx.from_pydict(data, name="test_table")
# Perform the regression function using DataFrame API
df = df.aggregate([], [func.alias("result")])
df.show()
expected_dict = {
"result": expected,
}
assert df.collect()[0].to_pydict() == expected_dict
def test_binary_string_functions(df):
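    """Round-trip column a through base64 encode/decode."""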
df = df.select(
f.encode(column("a").cast(pa.string()), literal("base64").cast(pa.string())),
f.decode(
f.encode(
column("a").cast(pa.string()), literal("base64").cast(pa.string())
),
literal("base64").cast(pa.string()),
),
)
result = df.collect()
assert len(result) == 1
result = result[0]
assert result.column(0) == pa.array(["SGVsbG8", "V29ybGQ", "IQ"])
assert pa.array(result.column(1)).cast(pa.string()) == pa.array(
["Hello", "World", "!"]
)
@pytest.mark.parametrize(
("python_datatype", "name", "expected"),
[
pytest.param(bool, "e", pa.bool_(), id="bool"),
pytest.param(int, "b", pa.int64(), id="int"),
pytest.param(float, "b", pa.float64(), id="float"),
pytest.param(str, "b", pa.string(), id="str"),
],
)
def test_cast(df, python_datatype, name: str, expected):
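    """Cast a column with a Python type and the equivalent Arrow type; both must agree."""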
df = df.select(
column(name).cast(python_datatype).alias("actual"),
column(name).cast(expected).alias("expected"),
)
result = df.collect()
result = result[0]
assert result.column(0) == result.column(1)
@pytest.mark.parametrize(
("negated", "low", "high", "expected"),
[
pytest.param(False, 3, 5, {"filtered": [4, 5]}),
pytest.param(False, 4, 5, {"filtered": [4, 5]}),
pytest.param(True, 3, 5, {"filtered": [6]}),
pytest.param(True, 4, 6, []),
],
)
def test_between(df, negated, low, high, expected):
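    """Filter with between() for negated and non-negated ranges over column b."""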
df = df.filter(column("b").between(low, high, negated=negated)).select(
column("b").alias("filtered")
)
actual = df.collect()
if expected:
actual = actual[0].to_pydict()
assert actual == expected
else:
        assert len(actual) == 0  # no record batches were returned
def test_between_default(df):
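    """Test between() with the default negated=False."""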
df = df.filter(column("b").between(3, 5)).select(column("b").alias("filtered"))
expected = {"filtered": [4, 5]}
actual = df.collect()[0].to_pydict()
assert actual == expected
def test_alias_with_metadata(df):
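    """Attach field metadata via f.alias and verify it on the output schema."""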
df = df.select(f.alias(f.col("a"), "b", {"key": "value"}))
assert df.schema().field("b").metadata == {b"key": b"value"}
def test_coalesce(df):
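    """Test coalesce() across string, int, timestamp, and bool columns containing nulls."""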
# Create a DataFrame with null values
ctx = SessionContext()
batch = pa.RecordBatch.from_arrays(
[
pa.array(["Hello", None, "!"]), # string column with null
pa.array([4, None, 6]), # integer column with null
pa.array(["hello ", None, " !"]), # string column with null
pa.array(
[
datetime(2022, 12, 31, tzinfo=DEFAULT_TZ),
None,
datetime(2020, 7, 2, tzinfo=DEFAULT_TZ),
]
), # datetime with null
pa.array([False, None, True]), # boolean column with null
],
names=["a", "b", "c", "d", "e"],
)
df_with_nulls = ctx.create_dataframe([[batch]])
# Test coalesce with different data types
result_df = df_with_nulls.select(
f.coalesce(column("a"), literal("default")).alias("a_coalesced"),
f.coalesce(column("b"), literal(0)).alias("b_coalesced"),
f.coalesce(column("c"), literal("default")).alias("c_coalesced"),
f.coalesce(column("d"), literal(datetime(2000, 1, 1, tzinfo=DEFAULT_TZ))).alias(
"d_coalesced"
),
f.coalesce(column("e"), literal(value=False)).alias("e_coalesced"),
)
result = result_df.collect()[0]
# Verify results
assert result.column(0) == pa.array(
["Hello", "default", "!"], type=pa.string_view()
)
assert result.column(1) == pa.array([4, 0, 6], type=pa.int64())
assert result.column(2) == pa.array(
["hello ", "default", " !"], type=pa.string_view()
)
assert result.column(3).to_pylist() == [
datetime(2022, 12, 31, tzinfo=DEFAULT_TZ),
datetime(2000, 1, 1, tzinfo=DEFAULT_TZ),
datetime(2020, 7, 2, tzinfo=DEFAULT_TZ),
]
assert result.column(4) == pa.array([False, False, True], type=pa.bool_())
# Test multiple arguments
result_df = df_with_nulls.select(
f.coalesce(column("a"), literal(None), literal("fallback")).alias(
"multi_coalesce"
)
)
result = result_df.collect()[0]
assert result.column(0) == pa.array(
["Hello", "fallback", "!"], type=pa.string_view()
)