| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| import math |
| from datetime import datetime, timezone |
| |
| import numpy as np |
| import pyarrow as pa |
| import pytest |
| from datafusion import SessionContext, column, literal, string_literal |
| from datafusion import functions as f |
| |
# Time zone attached to every timestamp built by the fixtures and expectations.
DEFAULT_TZ = timezone.utc

# Some NumPy reference values are computed outside a function's domain
# (e.g. np.sqrt of a negative input); silence the resulting "invalid" warnings.
np.seterr(invalid="ignore")
| |
| |
@pytest.fixture
def df():
    """Three-row DataFrame shared by most tests in this module.

    Columns: ``a`` and ``c`` are string_view text, ``b`` holds integers,
    ``d`` holds timestamps in ``DEFAULT_TZ`` and ``e`` holds booleans.
    """
    timestamps = [
        datetime(year, month, day, tzinfo=DEFAULT_TZ)
        for year, month, day in ((2022, 12, 31), (2027, 6, 26), (2020, 7, 2))
    ]
    columns = {
        "a": pa.array(["Hello", "World", "!"], type=pa.string_view()),
        "b": pa.array([4, 5, 6]),
        "c": pa.array(["hello ", " world ", " !"], type=pa.string_view()),
        "d": pa.array(timestamps),
        "e": pa.array([False, True, True]),
    }
    batch = pa.RecordBatch.from_arrays(list(columns.values()), names=list(columns))
    ctx = SessionContext()
    return ctx.create_dataframe([[batch]])
| |
| |
def test_named_struct(df):
    """named_struct packs columns a, b and c into a struct that replaces column d."""
    df = df.with_column(
        "d",
        f.named_struct(
            [
                ("a", column("a")),
                ("b", column("b")),
                ("c", column("c")),
            ]
        ),
    )

    # The rendered table is compared verbatim, so padding matters: every cell is
    # left-aligned and padded to its column's widest value. Column c holds
    # "hello ", " world " and " !", and those inner spaces are kept as-is.
    expected = """DataFrame()
+-------+---+---------+------------------------------+-------+
| a     | b | c       | d                            | e     |
+-------+---+---------+------------------------------+-------+
| Hello | 4 | hello   | {a: Hello, b: 4, c: hello }  | false |
| World | 5 |  world  | {a: World, b: 5, c:  world } | true  |
| !     | 6 |  !      | {a: !, b: 6, c:  !}          | true  |
+-------+---+---------+------------------------------+-------+
""".strip()
    assert str(df) == expected
| |
| |
def test_literal(df):
    """Each literal kind is broadcast to every row with the expected Arrow type."""
    exprs = [
        literal(1),
        literal("1"),
        literal("OK"),
        literal(3.14),
        literal(value=True),
        literal(b"hello world"),
    ]
    expected_columns = [
        pa.array([1] * 3),
        pa.array(["1"] * 3, type=pa.string_view()),
        pa.array(["OK"] * 3, type=pa.string_view()),
        pa.array([3.14] * 3),
        pa.array([True] * 3),
        pa.array([b"hello world"] * 3),
    ]

    batches = df.select(*exprs).collect()
    assert len(batches) == 1
    batch = batches[0]
    for position, want in enumerate(expected_columns):
        assert batch.column(position) == want
| |
| |
def test_lit_arith(df):
    """Literals combine with column expressions in arithmetic and in concat."""
    plus_one = literal(1) + column("b")
    exclaimed = f.concat(column("a").cast(pa.string()), literal("!"))

    batches = df.select(plus_one, exclaimed).collect()
    assert len(batches) == 1
    batch = batches[0]

    assert batch.column(0) == pa.array([5, 6, 7])
    assert batch.column(1) == pa.array(
        ["Hello!", "World!", "!!"], type=pa.string_view()
    )
| |
| |
def test_math_functions():
    """Compare DataFusion math functions against NumPy/math reference values."""
    ctx = SessionContext()
    # One ordinary float column and one column that contains NaN and zero.
    batch = pa.RecordBatch.from_arrays(
        [pa.array([0.1, -0.7, 0.55]), pa.array([float("nan"), 0, 2.0])],
        names=["value", "na_value"],
    )
    df = ctx.create_dataframe([[batch]])

    values = np.array([0.1, -0.7, 0.55])
    na_values = np.array([np.nan, 0, 2.0])
    col_v = column("value")
    col_nav = column("na_value")
    # value + 1 keeps the logarithm inputs positive.
    shifted = col_v + literal(pa.scalar(1))
    shifted_np = values + 1.0

    # Each entry pairs a DataFusion expression with its reference result.
    # ``None`` marks f.random(), which can only be range-checked.
    cases = [
        (f.abs(col_v), np.abs(values)),
        (f.sin(col_v), np.sin(values)),
        (f.cos(col_v), np.cos(values)),
        (f.tan(col_v), np.tan(values)),
        (f.asin(col_v), np.arcsin(values)),
        (f.acos(col_v), np.arccos(values)),
        (f.exp(col_v), np.exp(values)),
        (f.ln(shifted), np.log(shifted_np)),
        (f.log2(shifted), np.log2(shifted_np)),
        (f.log10(shifted), np.log10(shifted_np)),
        (f.random(), None),
        (f.atan(col_v), np.arctan(values)),
        (f.atan2(col_v, literal(pa.scalar(1.1))), np.arctan2(values, 1.1)),
        (f.ceil(col_v), np.ceil(values)),
        (f.floor(col_v), np.floor(values)),
        (f.power(col_v, literal(pa.scalar(3))), np.power(values, 3)),
        (f.pow(col_v, literal(pa.scalar(4))), np.power(values, 4)),
        (f.round(col_v), np.round(values)),
        (f.round(col_v, literal(pa.scalar(3))), np.round(values, 3)),
        (f.sqrt(col_v), np.sqrt(values)),
        (f.signum(col_v), np.sign(values)),
        (f.trunc(col_v), np.trunc(values)),
        (f.asinh(col_v), np.arcsinh(values)),
        (f.acosh(col_v), np.arccosh(values)),
        (f.atanh(col_v), np.arctanh(values)),
        (f.cbrt(col_v), np.cbrt(values)),
        (f.cosh(col_v), np.cosh(values)),
        (f.degrees(col_v), np.degrees(values)),
        (f.gcd(literal(9), literal(3)), np.gcd(9, 3)),
        (f.lcm(literal(6), literal(4)), np.lcm(6, 4)),
        (f.nanvl(col_nav, literal(5)), np.where(np.isnan(na_values), 5, na_values)),
        (f.pi(), np.pi),
        (f.radians(col_v), np.radians(values)),
        (f.sinh(col_v), np.sinh(values)),
        (f.tanh(col_v), np.tanh(values)),
        (f.factorial(literal(6)), math.factorial(6)),
        (f.isnan(col_nav), np.isnan(na_values)),
        (f.iszero(col_nav), na_values == 0),
        (f.log(literal(3), shifted), np.emath.logn(3, shifted_np)),
    ]

    batches = df.select(*[expr for expr, _ in cases]).collect()
    assert len(batches) == 1
    result = batches[0]

    for idx, (_, reference) in enumerate(cases):
        if reference is None:
            np.testing.assert_array_less(result.column(idx), np.ones_like(values))
        else:
            np.testing.assert_array_almost_equal(result.column(idx), reference)
| |
| |
def py_indexof(arr, v):
    """Return the 1-based position of the first ``v`` in ``arr``, or NaN if absent."""
    try:
        zero_based = arr.index(v)
    except ValueError:
        return np.nan
    return zero_based + 1
| |
| |
def py_arr_remove(arr, v, n=None):
    """Return a copy of ``arr`` with the first ``n`` occurrences of ``v`` removed.

    ``n=None`` removes every occurrence. The input list is not modified.
    """
    remaining = arr[:]
    removed = 0
    while removed != n and v in remaining:
        remaining.remove(v)
        removed += 1
    return remaining
| |
| |
def py_arr_replace(arr, from_, to, n=None):
    """Return a copy of ``arr`` with the first ``n`` occurrences of ``from_`` replaced by ``to``.

    ``n=None`` replaces every occurrence; ``n=0`` replaces nothing. The input
    list is not modified.

    The list is scanned once from the front. Re-searching with ``list.index``
    never terminates when ``to == from_`` and ``n`` is None, because it keeps
    finding the value it just wrote.
    """
    new_arr = arr[:]
    replaced = 0
    for idx, value in enumerate(new_arr):
        if replaced == n:
            break
        # Identity-or-equality matches ``list.index`` semantics.
        if value is from_ or value == from_:
            new_arr[idx] = to
            replaced += 1
    return new_arr
| |
| |
def py_arr_resize(arr, size, value):
    """Return ``arr`` as a NumPy array resized to exactly ``size`` elements.

    Shorter inputs are right-padded with ``value``; longer inputs are truncated,
    mirroring ``array_resize``. ``np.pad`` alone raises on the negative pad width
    that shrinking would need.

    Raises:
        ValueError: if ``size`` is negative.
    """
    arr = np.asarray(arr)
    if size < 0:
        raise ValueError(f"size must be non-negative, got {size}")
    if size < arr.shape[0]:
        return arr[:size].copy()
    return np.pad(
        arr,
        [(0, size - arr.shape[0])],
        "constant",
        constant_values=value,
    )
| |
| |
def py_flatten(arr):
    """Recursively flatten nested Python lists into one flat list (order preserved).

    Only ``list`` instances are descended into; tuples and other iterables are
    kept as single elements.
    """

    def _walk(items):
        for item in items:
            if isinstance(item, list):
                yield from _walk(item)
            else:
                yield item

    return list(_walk(arr))
| |
| |
@pytest.mark.parametrize(
    ("stmt", "py_expr"),
    [
        (
            lambda col: f.array_append(col, literal(99.0)),
            lambda data: [np.append(arr, 99.0) for arr in data],
        ),
        (
            lambda col: f.array_push_back(col, literal(99.0)),
            lambda data: [np.append(arr, 99.0) for arr in data],
        ),
        (
            lambda col: f.list_append(col, literal(99.0)),
            lambda data: [np.append(arr, 99.0) for arr in data],
        ),
        (
            lambda col: f.list_push_back(col, literal(99.0)),
            lambda data: [np.append(arr, 99.0) for arr in data],
        ),
        (
            lambda col: f.array_concat(col, col),
            lambda data: [np.concatenate([arr, arr]) for arr in data],
        ),
        (
            lambda col: f.array_cat(col, col),
            lambda data: [np.concatenate([arr, arr]) for arr in data],
        ),
        (
            lambda col: f.list_cat(col, col),
            lambda data: [np.concatenate([arr, arr]) for arr in data],
        ),
        (
            lambda col: f.list_concat(col, col),
            lambda data: [np.concatenate([arr, arr]) for arr in data],
        ),
        (
            lambda col: f.array_dims(col),
            lambda data: [[len(r)] for r in data],
        ),
        # Reference is sorted: DataFusion's array_distinct emits its unique values
        # in ascending order, whereas Python set iteration order is only an
        # implementation detail.
        (
            lambda col: f.array_distinct(col),
            lambda data: [sorted(set(r)) for r in data],
        ),
        (
            lambda col: f.list_distinct(col),
            lambda data: [sorted(set(r)) for r in data],
        ),
        (
            lambda col: f.list_dims(col),
            lambda data: [[len(r)] for r in data],
        ),
        (
            lambda col: f.array_element(col, literal(1)),
            lambda data: [r[0] for r in data],
        ),
        (
            lambda col: f.array_empty(col),
            lambda data: [len(r) == 0 for r in data],
        ),
        (
            lambda col: f.empty(col),
            lambda data: [len(r) == 0 for r in data],
        ),
        (
            lambda col: f.array_extract(col, literal(1)),
            lambda data: [r[0] for r in data],
        ),
        (
            lambda col: f.list_element(col, literal(1)),
            lambda data: [r[0] for r in data],
        ),
        (
            lambda col: f.list_extract(col, literal(1)),
            lambda data: [r[0] for r in data],
        ),
        (
            lambda col: f.array_length(col),
            lambda data: [len(r) for r in data],
        ),
        (
            lambda col: f.list_length(col),
            lambda data: [len(r) for r in data],
        ),
        (
            lambda col: f.array_has(col, literal(1.0)),
            lambda data: [1.0 in r for r in data],
        ),
        (
            lambda col: f.array_has_all(
                col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]])
            ),
            lambda data: [np.all([v in r for v in [1.0, 3.0, 5.0]]) for r in data],
        ),
        (
            lambda col: f.array_has_any(
                col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]])
            ),
            lambda data: [np.any([v in r for v in [1.0, 3.0, 5.0]]) for r in data],
        ),
        (
            lambda col: f.array_position(col, literal(1.0)),
            lambda data: [py_indexof(r, 1.0) for r in data],
        ),
        (
            lambda col: f.array_indexof(col, literal(1.0)),
            lambda data: [py_indexof(r, 1.0) for r in data],
        ),
        (
            lambda col: f.list_position(col, literal(1.0)),
            lambda data: [py_indexof(r, 1.0) for r in data],
        ),
        (
            lambda col: f.list_indexof(col, literal(1.0)),
            lambda data: [py_indexof(r, 1.0) for r in data],
        ),
        (
            lambda col: f.array_positions(col, literal(1.0)),
            lambda data: [[i + 1 for i, _v in enumerate(r) if _v == 1.0] for r in data],
        ),
        (
            lambda col: f.list_positions(col, literal(1.0)),
            lambda data: [[i + 1 for i, _v in enumerate(r) if _v == 1.0] for r in data],
        ),
        (
            lambda col: f.array_ndims(col),
            lambda data: [np.array(r).ndim for r in data],
        ),
        (
            lambda col: f.list_ndims(col),
            lambda data: [np.array(r).ndim for r in data],
        ),
        (
            lambda col: f.array_prepend(literal(99.0), col),
            lambda data: [np.insert(arr, 0, 99.0) for arr in data],
        ),
        (
            lambda col: f.array_push_front(literal(99.0), col),
            lambda data: [np.insert(arr, 0, 99.0) for arr in data],
        ),
        (
            lambda col: f.list_prepend(literal(99.0), col),
            lambda data: [np.insert(arr, 0, 99.0) for arr in data],
        ),
        (
            lambda col: f.list_push_front(literal(99.0), col),
            lambda data: [np.insert(arr, 0, 99.0) for arr in data],
        ),
        (
            lambda col: f.array_pop_back(col),
            lambda data: [arr[:-1] for arr in data],
        ),
        (
            lambda col: f.array_pop_front(col),
            lambda data: [arr[1:] for arr in data],
        ),
        (
            lambda col: f.array_remove(col, literal(3.0)),
            lambda data: [py_arr_remove(arr, 3.0, 1) for arr in data],
        ),
        (
            lambda col: f.list_remove(col, literal(3.0)),
            lambda data: [py_arr_remove(arr, 3.0, 1) for arr in data],
        ),
        (
            lambda col: f.array_remove_n(col, literal(3.0), literal(2)),
            lambda data: [py_arr_remove(arr, 3.0, 2) for arr in data],
        ),
        (
            lambda col: f.list_remove_n(col, literal(3.0), literal(2)),
            lambda data: [py_arr_remove(arr, 3.0, 2) for arr in data],
        ),
        (
            lambda col: f.array_remove_all(col, literal(3.0)),
            lambda data: [py_arr_remove(arr, 3.0) for arr in data],
        ),
        (
            lambda col: f.list_remove_all(col, literal(3.0)),
            lambda data: [py_arr_remove(arr, 3.0) for arr in data],
        ),
        (
            lambda col: f.array_repeat(col, literal(2)),
            lambda data: [[arr] * 2 for arr in data],
        ),
        (
            lambda col: f.list_repeat(col, literal(2)),
            lambda data: [[arr] * 2 for arr in data],
        ),
        (
            lambda col: f.array_replace(col, literal(3.0), literal(4.0)),
            lambda data: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
        ),
        (
            lambda col: f.list_replace(col, literal(3.0), literal(4.0)),
            lambda data: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
        ),
        (
            lambda col: f.array_replace_n(col, literal(3.0), literal(4.0), literal(1)),
            lambda data: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
        ),
        (
            lambda col: f.list_replace_n(col, literal(3.0), literal(4.0), literal(2)),
            lambda data: [py_arr_replace(arr, 3.0, 4.0, 2) for arr in data],
        ),
        (
            lambda col: f.array_replace_all(col, literal(3.0), literal(4.0)),
            lambda data: [py_arr_replace(arr, 3.0, 4.0) for arr in data],
        ),
        (
            lambda col: f.list_replace_all(col, literal(3.0), literal(4.0)),
            lambda data: [py_arr_replace(arr, 3.0, 4.0) for arr in data],
        ),
        (
            lambda col: f.array_sort(col, descending=True, null_first=True),
            lambda data: [np.sort(arr)[::-1] for arr in data],
        ),
        (
            lambda col: f.list_sort(col, descending=False, null_first=False),
            lambda data: [np.sort(arr) for arr in data],
        ),
        (
            lambda col: f.array_slice(col, literal(2), literal(4)),
            lambda data: [arr[1:4] for arr in data],
        ),
        (
            lambda col: f.list_slice(col, literal(-1), literal(2)),
            lambda data: [arr[-1:2] for arr in data],
        ),
        (
            lambda col: f.array_intersect(col, literal([3.0, 4.0])),
            lambda data: [np.intersect1d(arr, [3.0, 4.0]) for arr in data],
        ),
        (
            lambda col: f.list_intersect(col, literal([3.0, 4.0])),
            lambda data: [np.intersect1d(arr, [3.0, 4.0]) for arr in data],
        ),
        (
            lambda col: f.array_union(col, literal([12.0, 999.0])),
            lambda data: [np.union1d(arr, [12.0, 999.0]) for arr in data],
        ),
        (
            lambda col: f.list_union(col, literal([12.0, 999.0])),
            lambda data: [np.union1d(arr, [12.0, 999.0]) for arr in data],
        ),
        (
            lambda col: f.array_except(col, literal([3.0])),
            lambda data: [np.setdiff1d(arr, [3.0]) for arr in data],
        ),
        (
            lambda col: f.list_except(col, literal([3.0])),
            lambda data: [np.setdiff1d(arr, [3.0]) for arr in data],
        ),
        (
            lambda col: f.array_resize(col, literal(10), literal(0.0)),
            lambda data: [py_arr_resize(arr, 10, 0.0) for arr in data],
        ),
        (
            lambda col: f.list_resize(col, literal(10), literal(0.0)),
            lambda data: [py_arr_resize(arr, 10, 0.0) for arr in data],
        ),
        (
            lambda col: f.range(literal(1), literal(5), literal(2)),
            lambda data: [np.arange(1, 5, 2)],
        ),
    ],
)
def test_array_functions(stmt, py_expr):
    """Compare each array/list function row-by-row against a Python reference.

    ``stmt`` builds the DataFusion expression from the ``arr`` column and
    ``py_expr`` computes the expected rows from the raw Python data. Rows are
    paired with ``zip``, so when ``py_expr`` yields fewer rows (as the ``range``
    case does) only that prefix is compared.
    """
    data = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 3.0], [6.0]]
    ctx = SessionContext()
    batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)], names=["arr"])
    df = ctx.create_dataframe([[batch]])

    col = column("arr")
    query_result = df.select(stmt(col)).collect()[0].column(0)
    for a, b in zip(query_result, py_expr(data)):
        np.testing.assert_array_almost_equal(
            np.array(a.as_py(), dtype=float), np.array(b, dtype=float)
        )
| |
| |
def test_array_function_flatten():
    """flatten() on a nested list literal matches the recursive Python flatten."""
    nested = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 3.0], [6.0]]
    ctx = SessionContext()
    arr_column = np.array(nested, dtype=object)
    df = ctx.create_dataframe(
        [[pa.RecordBatch.from_arrays([arr_column], names=["arr"])]]
    )

    expected_rows = [py_flatten(nested)]
    flattened = df.select(f.flatten(literal(nested))).collect()[0].column(0)
    for got, want in zip(flattened, expected_rows):
        np.testing.assert_array_almost_equal(
            np.array(got.as_py(), dtype=float), np.array(want, dtype=float)
        )
| |
| |
def test_array_function_cardinality():
    """cardinality() returns the number of elements in each list row."""
    data = [[1, 2, 3], [4, 4, 5, 6]]
    ctx = SessionContext()
    batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)], names=["arr"])
    df = ctx.create_dataframe([[batch]])

    expected = [len(arr) for arr in data]  # [3, 4]
    query_result = df.select(f.cardinality(column("arr"))).collect()[0].column(0)

    # Compare the whole column rather than zip-ing, so a missing or extra row
    # fails the test instead of being silently ignored.
    assert query_result.to_pylist() == expected
| |
| |
@pytest.mark.parametrize("make_func", [f.make_array, f.make_list])
def test_make_array_functions(make_func):
    """make_array/make_list build one list per row from string-cast columns."""
    ctx = SessionContext()
    batch = pa.RecordBatch.from_arrays(
        [
            pa.array(["Hello", "World", "!"], type=pa.string()),
            pa.array([4, 5, 6]),
            pa.array(["hello ", " world ", " !"], type=pa.string()),
        ],
        names=["a", "b", "c"],
    )
    df = ctx.create_dataframe([[batch]])

    as_strings = [column(name).cast(pa.string()) for name in ("a", "b", "c")]
    rows = df.select(make_func(*as_strings)).collect()[0].column(0)

    expected_rows = [
        ["Hello", "4", "hello "],
        ["World", "5", " world "],
        ["!", "6", " !"],
    ]
    for got, want in zip(rows, expected_rows):
        np.testing.assert_array_equal(
            np.array(got.as_py(), dtype=str), np.array(want, dtype=str)
        )
| |
| |
@pytest.mark.parametrize(
    ("stmt", "py_expr"),
    [
        (
            f.array_to_string(column("arr"), literal(",")),
            lambda data: [",".join(str(int(v)) for v in row) for row in data],
        ),
        (
            f.array_join(column("arr"), literal(",")),
            lambda data: [",".join(str(int(v)) for v in row) for row in data],
        ),
        (
            f.list_to_string(column("arr"), literal(",")),
            lambda data: [",".join(str(int(v)) for v in row) for row in data],
        ),
        (
            f.list_join(column("arr"), literal(",")),
            lambda data: [",".join(str(int(v)) for v in row) for row in data],
        ),
    ],
)
def test_array_function_obj_tests(stmt, py_expr):
    """The to_string/join aliases render each list row as comma-joined values."""
    data = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 3.0], [6.0]]
    ctx = SessionContext()
    arr_column = np.array(data, dtype=object)
    df = ctx.create_dataframe(
        [[pa.RecordBatch.from_arrays([arr_column], names=["arr"])]]
    )

    joined = np.array(df.select(stmt).collect()[0].column(0))
    for got, want in zip(joined, py_expr(data)):
        assert got == want
| |
| |
@pytest.mark.parametrize(
    ("function", "expected_result"),
    [
        (
            f.ascii(column("a")),
            pa.array([72, 87, 33], type=pa.int32()),
        ),  # H = 72; W = 87; ! = 33
        (
            f.bit_length(column("a").cast(pa.string())),
            pa.array([40, 40, 8], type=pa.int32()),
        ),
        (
            f.btrim(literal(" World ")),
            pa.array(["World", "World", "World"], type=pa.string_view()),
        ),
        (f.character_length(column("a")), pa.array([5, 5, 1], type=pa.int32())),
        (f.chr(literal(68)), pa.array(["D", "D", "D"])),
        (
            f.concat_ws("-", column("a"), literal("test")),
            pa.array(["Hello-test", "World-test", "!-test"]),
        ),
        (
            f.concat(column("a").cast(pa.string()), literal("?")),
            pa.array(["Hello?", "World?", "!?"], type=pa.string_view()),
        ),
        (
            f.initcap(column("c")),
            pa.array(["Hello ", " World ", " !"], type=pa.string_view()),
        ),
        (f.left(column("a"), literal(3)), pa.array(["Hel", "Wor", "!"])),
        (f.length(column("c")), pa.array([6, 7, 2], type=pa.int32())),
        (f.lower(column("a")), pa.array(["hello", "world", "!"])),
        # Expected values are left-padded with spaces to exactly 7 characters.
        (
            f.lpad(column("a"), literal(7)),
            pa.array(["  Hello", "  World", "      !"]),
        ),
        (
            f.ltrim(column("c")),
            pa.array(["hello ", "world ", "!"], type=pa.string_view()),
        ),
        (
            f.md5(column("a")),
            pa.array(
                [
                    "8b1a9953c4611296a827abf8c47804d7",
                    "f5a7924e621e84c9280a9a27e1bcb7f6",
                    "9033e0e305f247c0c3c80d0c7848c8b3",
                ]
            ),
        ),
        (f.octet_length(column("a")), pa.array([5, 5, 1], type=pa.int32())),
        (
            f.repeat(column("a"), literal(2)),
            pa.array(["HelloHello", "WorldWorld", "!!"]),
        ),
        (
            f.replace(column("a"), literal("l"), literal("?")),
            pa.array(["He??o", "Wor?d", "!"]),
        ),
        (f.reverse(column("a")), pa.array(["olleH", "dlroW", "!"])),
        (f.right(column("a"), literal(4)), pa.array(["ello", "orld", "!"])),
        # Expected values are right-padded with spaces to exactly 8 characters.
        (
            f.rpad(column("a"), literal(8)),
            pa.array(["Hello   ", "World   ", "!       "]),
        ),
        (
            f.rtrim(column("c")),
            pa.array(["hello", " world", " !"], type=pa.string_view()),
        ),
        (
            f.split_part(column("a"), literal("l"), literal(1)),
            pa.array(["He", "Wor", "!"]),
        ),
        (f.starts_with(column("a"), literal("Wor")), pa.array([False, True, False])),
        (f.strpos(column("a"), literal("o")), pa.array([5, 2, 0], type=pa.int32())),
        (
            f.substr(column("a"), literal(3)),
            pa.array(["llo", "rld", ""], type=pa.string_view()),
        ),
        (
            f.translate(column("a"), literal("or"), literal("ld")),
            pa.array(["Helll", "Wldld", "!"]),
        ),
        (f.trim(column("c")), pa.array(["hello", "world", "!"], type=pa.string_view())),
        (f.upper(column("c")), pa.array(["HELLO ", " WORLD ", " !"])),
        (f.ends_with(column("a"), literal("llo")), pa.array([True, False, False])),
        (
            f.overlay(column("a"), literal("--"), literal(2)),
            pa.array(["H--lo", "W--ld", "--"]),
        ),
        (
            f.regexp_like(column("a"), literal("(ell|orl)")),
            pa.array([True, True, False]),
        ),
        (
            f.regexp_match(column("a"), literal("(ell|orl)")),
            pa.array([["ell"], ["orl"], None], type=pa.list_(pa.string_view())),
        ),
        (
            f.regexp_replace(column("a"), literal("(ell|orl)"), literal("-")),
            pa.array(["H-o", "W-d", "!"]),
        ),
        (
            f.regexp_count(column("a"), literal("(ell|orl)"), literal(1)),
            pa.array([1, 1, 0], type=pa.int64()),
        ),
    ],
)
def test_string_functions(df, function, expected_result):
    """Each string function yields the expected column (values and Arrow type)."""
    df = df.select(function)
    result = df.collect()
    assert len(result) == 1
    result = result[0]
    assert result.column(0) == expected_result
| |
| |
def test_hash_functions(df):
    """digest() for every supported algorithm, plus the dedicated sha* functions."""
    digest_algorithms = (
        "md5",
        "sha224",
        "sha256",
        "sha384",
        "sha512",
        "blake2s",
        "blake3",
    )
    text = column("a")
    df = df.select(
        *[f.digest(text, literal(name)) for name in digest_algorithms],
        f.sha224(text),
        f.sha256(text),
        f.sha384(text),
        f.sha512(text),
    )
    batches = df.collect()
    assert len(batches) == 1
    result = batches[0]

    # Upper-case hex digests of column ``a``, one entry per row.
    expected_hex = {
        "md5": (
            "8B1A9953C4611296A827ABF8C47804D7",
            "F5A7924E621E84C9280A9A27E1BCB7F6",
            "9033E0E305F247C0C3C80D0C7848C8B3",
        ),
        "sha224": (
            "4149DA18AA8BFC2B1E382C6C26556D01A92C261B6436DAD5E3BE3FCC",
            "12972632B6D3B6AA52BD6434552F08C1303D56B817119406466E9236",
            "6641A7E8278BCD49E476E7ACAE158F4105B2952D22AEB2E0B9A231A0",
        ),
        "sha256": (
            "185F8DB32271FE25F561A6FC938B2E264306EC304EDA518007D1764826381969",
            "78AE647DC5544D227130A0682A51E30BC7777FBB6D8A8F17007463A3ECD1D524",
            "BB7208BC9B5D7C04F1236A82A0093A5E33F40423D5BA8D4266F7092C3BA43B62",
        ),
        "sha384": (
            "3519FE5AD2C596EFE3E276A6F351B8FC"
            "0B03DB861782490D45F7598EBD0AB5FD"
            "5520ED102F38C4A5EC834E98668035FC",
            "ED7CED84875773603AF90402E42C65F3"
            "B48A5E77F84ADC7A19E8F3E8D3101010"
            "22F552AEC70E9E1087B225930C1D260A",
            "1D0EC8C84EE9521E21F06774DE232367"
            "B64DE628474CB5B2E372B699A1F55AE3"
            "35CC37193EF823E33324DFD9A70738A6",
        ),
        "sha512": (
            "3615F80C9D293ED7402687F94B22D58E"
            "529B8CC7916F8FAC7FDDF7FBD5AF4CF7"
            "77D3D795A7A00A16BF7E7F3FB9561EE9"
            "BAAE480DA9FE7A18769E71886B03F315",
            "8EA77393A42AB8FA92500FB077A9509C"
            "C32BC95E72712EFA116EDAF2EDFAE34F"
            "BB682EFDD6C5DD13C117E08BD4AAEF71"
            "291D8AACE2F890273081D0677C16DF0F",
            "3831A6A6155E509DEE59A7F451EB3532"
            "4D8F8F2DF6E3708894740F98FDEE2388"
            "9F4DE5ADB0C5010DFB555CDA77C8AB5D"
            "C902094C52DE3278F35A75EBC25F093A",
        ),
        "blake2s": (
            "F73A5FBF881F89B814871F46E26AD3FA37CB2921C5E8561618639015B3CCBB71",
            "B792A0383FB9E7A189EC150686579532854E44B71AC394831DAED169BA85CCC5",
            "27988A0E51812297C77A433F635233346AEE29A829DCF4F46E0F58F402C6CFCB",
        ),
        "blake3": (
            "FBC2B0516EE8744D293B980779178A3508850FDCFE965985782C39601B65794F",
            "BF73D18575A736E4037D45F9E316085B86C19BE6363DE6AA789E13DEAACC1C4E",
            "C8D11B9F7237E4034ADBCD2005735F9BC4C597C75AD89F4492BEC8F77D15F7EB",
        ),
    }

    for position, name in enumerate(digest_algorithms):
        want = pa.array([bytearray.fromhex(h) for h in expected_hex[name]])
        assert result.column(position) == want

    # Columns 7-10 (sha224/256/384/512 shortcuts) must equal digest columns 1-4.
    for shortcut_pos, digest_pos in zip(range(7, 11), range(1, 5)):
        assert result.column(shortcut_pos) == result.column(digest_pos)
| |
| |
def test_temporal_functions(df):
    """Date part/trunc/bin, unix-time conversion and the to_timestamp family."""
    ts_col = column("d")
    ts_text = "2023-09-07 05:06:14.523952"
    ts_text_ns = "2023-09-07 05:06:14.523952000"
    ts_format = "%Y-%m-%d %H:%M:%S.%f"

    df = df.select(
        f.date_part(literal("month"), ts_col),
        f.datepart(literal("year"), ts_col),
        f.date_trunc(literal("month"), ts_col),
        f.datetrunc(literal("day"), ts_col),
        f.date_bin(
            literal("15 minutes").cast(pa.string()),
            ts_col,
            literal("2001-01-01 00:02:30").cast(pa.string()),
        ),
        f.from_unixtime(literal(1673383974)),
        f.to_timestamp(literal(ts_text)),
        f.to_timestamp_seconds(literal(ts_text)),
        f.to_timestamp_millis(literal(ts_text)),
        f.to_timestamp_micros(literal(ts_text)),
        f.extract(literal("day"), ts_col),
        f.to_timestamp(literal(ts_text_ns), literal(ts_format)),
        f.to_timestamp_seconds(literal(ts_text_ns), literal(ts_format)),
        f.to_timestamp_millis(literal(ts_text_ns), literal(ts_format)),
        f.to_timestamp_micros(literal(ts_text_ns), literal(ts_format)),
        f.to_timestamp_nanos(literal(ts_text)),
        f.to_timestamp_nanos(literal(ts_text_ns), literal(ts_format)),
    )
    batches = df.collect()
    assert len(batches) == 1
    result = batches[0]

    def utc(*parts):
        return datetime(*parts, tzinfo=DEFAULT_TZ)

    def repeated(moment, unit):
        # Expressions over literals produce the same value for all three rows.
        return pa.array([moment] * 3, type=pa.timestamp(unit))

    utc_ns = pa.timestamp("ns", tz=DEFAULT_TZ)
    parsed_full = utc(2023, 9, 7, 5, 6, 14, 523952)
    parsed_ms = utc(2023, 9, 7, 5, 6, 14, 523000)
    parsed_s = utc(2023, 9, 7, 5, 6, 14)

    expected_columns = [
        pa.array([12, 6, 7], type=pa.int32()),
        pa.array([2022, 2027, 2020], type=pa.int32()),
        pa.array([utc(2022, 12, 1), utc(2027, 6, 1), utc(2020, 7, 1)], type=utc_ns),
        pa.array([utc(2022, 12, 31), utc(2027, 6, 26), utc(2020, 7, 2)], type=utc_ns),
        pa.array(
            [
                utc(2022, 12, 30, 23, 47, 30),
                utc(2027, 6, 25, 23, 47, 30),
                utc(2020, 7, 1, 23, 47, 30),
            ],
            type=utc_ns,
        ),
        repeated(utc(2023, 1, 10, 20, 52, 54), "s"),
        repeated(parsed_full, "ns"),
        repeated(parsed_s, "s"),
        repeated(parsed_ms, "ms"),
        repeated(parsed_full, "us"),
        pa.array([31, 26, 2], type=pa.int32()),
        repeated(parsed_full, "ns"),
        repeated(parsed_s, "s"),
        repeated(parsed_ms, "ms"),
        repeated(parsed_full, "us"),
        repeated(parsed_full, "ns"),
        repeated(parsed_full, "ns"),
    ]
    for position, want in enumerate(expected_columns):
        assert result.column(position) == want
| |
| |
def test_arrow_cast(df):
    """arrow_cast converts an integer column to the named Arrow types."""
    source = column("b")
    # The type name is passed via `string_literal` (Utf8) instead of `literal`
    # (Utf8View), because datafusion.arrow_cast expects a Utf8 argument:
    # https://github.com/apache/datafusion/blob/86740bfd3d9831d6b7c1d0e1bf4a21d91598a0ac/datafusion/functions/src/core/arrow_cast.rs#L179
    as_float = f.arrow_cast(source, string_literal("Float64")).alias("b_as_float")
    as_int = f.arrow_cast(source, string_literal("Int32")).alias("b_as_int")

    batches = df.select(as_float, as_int).collect()
    assert len(batches) == 1
    batch = batches[0]

    assert batch.column(0) == pa.array([4.0, 5.0, 6.0], type=pa.float64())
    assert batch.column(1) == pa.array([4, 5, 6], type=pa.int32())
| |
| |
def test_case(df):
    """Exercise ``f.case`` with an ``otherwise`` fallback and with ``end()``.

    When no branch matches and ``end()`` is used instead of ``otherwise``,
    the result for that row is null.
    """
    numeric = (
        f.case(column("b"))
        .when(literal(4), literal(10))
        .otherwise(literal(8))
    )
    with_fallback = (
        f.case(column("a"))
        .when(literal("Hello"), literal("Hola"))
        .when(literal("World"), literal("Mundo"))
        .otherwise(literal("!!"))
    )
    without_fallback = (
        f.case(column("a"))
        .when(literal("Hello"), literal("Hola"))
        .when(literal("World"), literal("Mundo"))
        .end()
    )

    batch = df.select(numeric, with_fallback, without_fallback).collect()[0]

    assert batch.column(0) == pa.array([10, 8, 8])
    assert batch.column(1) == pa.array(["Hola", "Mundo", "!!"], type=pa.string_view())
    assert batch.column(2) == pa.array(["Hola", "Mundo", None], type=pa.string_view())
| |
| |
def test_when_with_no_base(df):
    """Exercise searched CASE (``f.when``) with and without an ``otherwise`` branch.

    Leftover debugging ``df.show()`` calls were removed: they only printed to
    stdout and re-executed the plan without contributing any assertion.
    """
    df = df.select(
        column("b"),
        f.when(column("b") > literal(5), literal("too big"))
        .when(column("b") < literal(5), literal("too small"))
        .otherwise(literal("just right"))
        .alias("goldilocks"),
        # With `end()` instead of `otherwise`, unmatched rows come back as null.
        f.when(column("a") == literal("Hello"), column("a")).end().alias("greeting"),
    )

    result = df.collect()
    result = result[0]
    assert result.column(0) == pa.array([4, 5, 6])
    assert result.column(1) == pa.array(
        ["too small", "just right", "too big"], type=pa.string_view()
    )
    assert result.column(2) == pa.array(["Hello", None, None], type=pa.string_view())
| |
| |
def test_regr_funcs_sql(df):
    """Check every regr_* aggregate via SQL on the single constant point (1, 1).

    Test case based on
    https://github.com/apache/arrow-datafusion/blob/d1361d56b9a9e0c165d3d71a8df6795d2a5f51dd/datafusion/core/tests/sqllogictests/test_files/aggregate.slt#L2330

    NOTE(review): the ``df`` fixture is requested but never used here.
    """
    ctx = SessionContext()
    query = (
        "select regr_slope(1,1), regr_intercept(1,1), "
        "regr_count(1,1), regr_r2(1,1), regr_avgx(1,1), "
        "regr_avgy(1,1), regr_sxx(1,1), regr_syy(1,1), "
        "regr_sxy(1,1);"
    )
    batch = ctx.sql(query).collect()[0]

    # (expected value, expected Arrow type) per output column, in SELECT order.
    expected_columns = [
        (None, pa.float64()),  # regr_slope
        (None, pa.float64()),  # regr_intercept
        (1, pa.uint64()),  # regr_count
        (None, pa.float64()),  # regr_r2
        (1, pa.float64()),  # regr_avgx
        (1, pa.float64()),  # regr_avgy
        (0, pa.float64()),  # regr_sxx
        (0, pa.float64()),  # regr_syy
        (0, pa.float64()),  # regr_sxy
    ]
    for idx, (value, dtype) in enumerate(expected_columns):
        assert batch.column(idx) == pa.array([value], type=dtype)
| |
| |
def test_regr_funcs_sql_2():
    """Check every regr_* aggregate via SQL on the points (1, 2), (2, 4), (3, 6).

    Test case based on the upstream ``regr_*()`` basic tests:
    https://github.com/apache/datafusion/blob/d1361d56b9a9e0c165d3d71a8df6795d2a5f51dd/datafusion/core/tests/sqllogictests/test_files/aggregate.slt#L2358C1-L2374C1
    """
    ctx = SessionContext()

    # (aggregate name, expected value, expected Arrow type), in SELECT order.
    cases = [
        ("regr_slope", 2, pa.float64()),
        ("regr_intercept", 0, pa.float64()),
        ("regr_count", 3, pa.uint64()),
        ("regr_r2", 1, pa.float64()),
        ("regr_avgx", 2, pa.float64()),
        ("regr_avgy", 4, pa.float64()),
        ("regr_sxx", 2, pa.float64()),
        ("regr_syy", 8, pa.float64()),
        ("regr_sxy", 4, pa.float64()),
    ]
    select_list = ", ".join(f"{name}(column2, column1)" for name, _, _ in cases)
    query = f"select {select_list} from (values (1,2), (2,4), (3,6))"
    batch = ctx.sql(query).collect()[0]

    for idx, (_, value, dtype) in enumerate(cases):
        assert batch.column(idx) == pa.array([value], type=dtype)
| |
| |
@pytest.mark.parametrize(
    ("func", "expected"),
    [
        pytest.param(f.regr_slope(column("c2"), column("c1")), [4.6], id="regr_slope"),
        pytest.param(
            f.regr_slope(column("c2"), column("c1"), filter=column("c1") > literal(2)),
            [8],
            id="regr_slope_filter",
        ),
        pytest.param(
            f.regr_intercept(column("c2"), column("c1")), [-4], id="regr_intercept"
        ),
        pytest.param(
            f.regr_intercept(
                column("c2"), column("c1"), filter=column("c1") > literal(2)
            ),
            [-16],
            id="regr_intercept_filter",
        ),
        pytest.param(f.regr_count(column("c2"), column("c1")), [4], id="regr_count"),
        pytest.param(
            f.regr_count(column("c2"), column("c1"), filter=column("c1") > literal(2)),
            [2],
            id="regr_count_filter",
        ),
        pytest.param(f.regr_r2(column("c2"), column("c1")), [0.92], id="regr_r2"),
        pytest.param(
            f.regr_r2(column("c2"), column("c1"), filter=column("c1") > literal(2)),
            [1.0],
            id="regr_r2_filter",
        ),
        pytest.param(f.regr_avgx(column("c2"), column("c1")), [2.5], id="regr_avgx"),
        pytest.param(
            f.regr_avgx(column("c2"), column("c1"), filter=column("c1") > literal(2)),
            [3.5],
            id="regr_avgx_filter",
        ),
        pytest.param(f.regr_avgy(column("c2"), column("c1")), [7.5], id="regr_avgy"),
        pytest.param(
            f.regr_avgy(column("c2"), column("c1"), filter=column("c1") > literal(2)),
            [12.0],
            id="regr_avgy_filter",
        ),
        pytest.param(f.regr_sxx(column("c2"), column("c1")), [5.0], id="regr_sxx"),
        pytest.param(
            f.regr_sxx(column("c2"), column("c1"), filter=column("c1") > literal(2)),
            [0.5],
            id="regr_sxx_filter",
        ),
        pytest.param(f.regr_syy(column("c2"), column("c1")), [115.0], id="regr_syy"),
        pytest.param(
            f.regr_syy(column("c2"), column("c1"), filter=column("c1") > literal(2)),
            [32.0],
            id="regr_syy_filter",
        ),
        pytest.param(f.regr_sxy(column("c2"), column("c1")), [23.0], id="regr_sxy"),
        pytest.param(
            f.regr_sxy(column("c2"), column("c1"), filter=column("c1") > literal(2)),
            [4.0],
            id="regr_sxy_filter",
        ),
    ],
)
def test_regr_funcs_df(func, expected):
    """Check regr_* aggregates (optionally filtered) through the DataFrame API.

    Test case based on the upstream ``regr_*()`` basic tests:
    https://github.com/apache/datafusion/blob/d1361d56b9a9e0c165d3d71a8df6795d2a5f51dd/datafusion/core/tests/sqllogictests/test_files/aggregate.slt#L2358C1-L2374C1

    Args:
        func: the aggregate expression under test.
        expected: single-element list holding the expected aggregate value.
    """
    ctx = SessionContext()

    # Six rows, two of which contain a null. regr_count expects 4 (2 with the
    # c1 > 2 filter), so rows with a null in either column do not contribute.
    data = {"c1": [1, 2, 3, 4, 5, None], "c2": [2, 4, 8, 16, None, 64]}
    df = ctx.from_pydict(data, name="test_table")

    # Perform the regression function using the DataFrame API.
    df = df.aggregate([], [func.alias("result")])
    actual = df.collect()[0].to_pydict()

    assert list(actual) == ["result"]
    # Values such as regr_slope (4.6) and regr_r2 (0.92) come out of
    # floating-point arithmetic, so compare with a tolerance, not exact ==.
    assert actual["result"] == pytest.approx(expected)
| |
| |
def test_binary_string_functions(df):
    """Round-trip column ``a`` through base64 ``f.encode`` and ``f.decode``."""
    # Both the input column and the encoding name are cast to plain Utf8.
    # NOTE(review): presumably encode/decode do not accept Utf8View (the
    # fixture stores `a` as string_view) — confirm before dropping the casts.
    text = column("a").cast(pa.string())
    scheme = literal("base64").cast(pa.string())
    encoded = f.encode(text, scheme)

    batches = df.select(encoded, f.decode(encoded, scheme)).collect()
    assert len(batches) == 1
    batch = batches[0]

    assert batch.column(0) == pa.array(["SGVsbG8", "V29ybGQ", "IQ"])
    round_tripped = pa.array(batch.column(1)).cast(pa.string())
    assert round_tripped == pa.array(["Hello", "World", "!"])
| |
| |
@pytest.mark.parametrize(
    ("python_datatype", "name", "expected"),
    [
        pytest.param(bool, "e", pa.bool_(), id="bool"),
        pytest.param(int, "b", pa.int64(), id="int"),
        pytest.param(float, "b", pa.float64(), id="float"),
        pytest.param(str, "b", pa.string(), id="str"),
    ],
)
def test_cast(df, python_datatype, name: str, expected):
    """Casting to a Python type must give the same column as its Arrow type.

    Args:
        python_datatype: builtin Python type passed to ``Expr.cast``.
        name: fixture column to cast.
        expected: the Arrow type ``python_datatype`` should map to.
    """
    source = column(name)
    batch = df.select(
        source.cast(python_datatype).alias("actual"),
        source.cast(expected).alias("expected"),
    ).collect()[0]

    via_python, via_arrow = batch.column(0), batch.column(1)
    assert via_python == via_arrow
| |
| |
@pytest.mark.parametrize(
    ("negated", "low", "high", "expected"),
    [
        pytest.param(False, 3, 5, {"filtered": [4, 5]}),
        pytest.param(False, 4, 5, {"filtered": [4, 5]}),
        pytest.param(True, 3, 5, {"filtered": [6]}),
        pytest.param(True, 4, 6, []),
    ],
)
def test_between(df, negated, low, high, expected):
    """Filter ``b`` with ``between`` / ``not between`` and check surviving rows.

    Args:
        negated: when True, keep rows outside ``[low, high]``.
        low: lower bound of the range.
        high: upper bound of the range.
        expected: ``{"filtered": [...]}`` with the surviving values, or an
            empty value when no row should survive.
    """
    df = df.filter(column("b").between(low, high, negated=negated)).select(
        column("b").alias("filtered")
    )
    assert df.schema().names == ["filtered"]

    # Gather values from every collected batch. This keeps the check from
    # depending on whether an empty result comes back as zero batches or as
    # empty batches, and from checking only the first batch of a larger result.
    batches = df.collect()
    actual = [value for batch in batches for value in batch.column(0).to_pylist()]

    expected_values = expected["filtered"] if expected else []
    assert actual == expected_values
| |
| |
def test_between_default(df):
    """``between`` without ``negated`` keeps rows inside the inclusive range."""
    in_range = column("b").between(3, 5)
    filtered = df.filter(in_range).select(column("b").alias("filtered"))

    actual = filtered.collect()[0].to_pydict()
    assert actual == {"filtered": [4, 5]}
| |
| |
def test_alias_with_metadata(df):
    """``f.alias`` attaches field metadata, exposed as bytes in the Arrow schema."""
    aliased = f.alias(f.col("a"), "b", {"key": "value"})
    field = df.select(aliased).schema().field("b")
    assert field.metadata == {b"key": b"value"}
| |
| |
def test_coalesce(df):
    """Check ``f.coalesce`` across string, integer, timestamp and boolean columns.

    The ``df`` fixture is not used; the test builds its own batch in which the
    middle row of every column is null, so each coalesce must fall back to its
    default literal for that row only.
    """
    # Create a DataFrame with null values
    ctx = SessionContext()
    batch = pa.RecordBatch.from_arrays(
        [
            pa.array(["Hello", None, "!"]),  # string column with null
            pa.array([4, None, 6]),  # integer column with null
            pa.array(["hello ", None, " !"]),  # string column with null
            pa.array(
                [
                    datetime(2022, 12, 31, tzinfo=DEFAULT_TZ),
                    None,
                    datetime(2020, 7, 2, tzinfo=DEFAULT_TZ),
                ]
            ),  # datetime with null
            pa.array([False, None, True]),  # boolean column with null
        ],
        names=["a", "b", "c", "d", "e"],
    )
    df_with_nulls = ctx.create_dataframe([[batch]])

    # Test coalesce with different data types
    result_df = df_with_nulls.select(
        f.coalesce(column("a"), literal("default")).alias("a_coalesced"),
        f.coalesce(column("b"), literal(0)).alias("b_coalesced"),
        f.coalesce(column("c"), literal("default")).alias("c_coalesced"),
        f.coalesce(column("d"), literal(datetime(2000, 1, 1, tzinfo=DEFAULT_TZ))).alias(
            "d_coalesced"
        ),
        f.coalesce(column("e"), literal(value=False)).alias("e_coalesced"),
    )

    result = result_df.collect()[0]

    # Verify results. The string inputs are plain Utf8, yet the coalesced
    # columns are compared as string_view — presumably because the string
    # `literal` defaults are Utf8View and the result type follows them; confirm.
    assert result.column(0) == pa.array(
        ["Hello", "default", "!"], type=pa.string_view()
    )
    assert result.column(1) == pa.array([4, 0, 6], type=pa.int64())
    assert result.column(2) == pa.array(
        ["hello ", "default", " !"], type=pa.string_view()
    )
    # Timestamps are compared as Python datetimes rather than Arrow arrays.
    assert result.column(3).to_pylist() == [
        datetime(2022, 12, 31, tzinfo=DEFAULT_TZ),
        datetime(2000, 1, 1, tzinfo=DEFAULT_TZ),
        datetime(2020, 7, 2, tzinfo=DEFAULT_TZ),
    ]
    assert result.column(4) == pa.array([False, False, True], type=pa.bool_())

    # Test multiple arguments: the null literal in the middle is skipped, so
    # the null row resolves to the last argument ("fallback").
    result_df = df_with_nulls.select(
        f.coalesce(column("a"), literal(None), literal("fallback")).alias(
            "multi_coalesce"
        )
    )
    result = result_df.collect()[0]
    assert result.column(0) == pa.array(
        ["Hello", "fallback", "!"], type=pa.string_view()
    )