blob: 5aaf00664093716d63e4305a64afabe9368b86e2 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import pyarrow as pa
import pytest
from datafusion import SessionContext, column, lit, udwf
from datafusion import functions as f
from datafusion.expr import WindowFrame
from datafusion.user_defined import WindowEvaluator
class ExponentialSmoothDefault(WindowEvaluator):
def __init__(self, alpha: float = 0.9) -> None:
self.alpha = alpha
def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
results = []
curr_value = 0.0
values = values[0]
for idx in range(num_rows):
if idx == 0:
curr_value = values[idx].as_py()
else:
curr_value = values[idx].as_py() * self.alpha + curr_value * (
1.0 - self.alpha
)
results.append(curr_value)
return pa.array(results)
class ExponentialSmoothBounded(WindowEvaluator):
def __init__(self, alpha: float = 0.9) -> None:
self.alpha = alpha
def supports_bounded_execution(self) -> bool:
return True
def get_range(self, idx: int, num_rows: int) -> tuple[int, int]:
# Override the default range of current row since uses_window_frame is False
# So for the purpose of this test we just smooth from the previous row to
# current.
if idx == 0:
return (0, 0)
return (idx - 1, idx)
def evaluate(
self, values: list[pa.Array], eval_range: tuple[int, int]
) -> pa.Scalar:
(start, stop) = eval_range
curr_value = 0.0
values = values[0]
for idx in range(start, stop + 1):
if idx == start:
curr_value = values[idx].as_py()
else:
curr_value = values[idx].as_py() * self.alpha + curr_value * (
1.0 - self.alpha
)
return pa.scalar(curr_value).cast(pa.float64())
class ExponentialSmoothRank(WindowEvaluator):
def __init__(self, alpha: float = 0.9) -> None:
self.alpha = alpha
def include_rank(self) -> bool:
return True
def evaluate_all_with_rank(
self, num_rows: int, ranks_in_partition: list[tuple[int, int]]
) -> pa.Array:
results = []
for idx in range(num_rows):
if idx == 0:
prior_value = 1.0
matching_row = [
i
for i in range(len(ranks_in_partition))
if ranks_in_partition[i][0] <= idx and ranks_in_partition[i][1] > idx
][0] + 1
curr_value = matching_row * self.alpha + prior_value * (1.0 - self.alpha)
results.append(curr_value)
prior_value = matching_row
return pa.array(results)
class ExponentialSmoothFrame(WindowEvaluator):
def __init__(self, alpha: float = 0.9) -> None:
self.alpha = alpha
def uses_window_frame(self) -> bool:
return True
def evaluate(
self, values: list[pa.Array], eval_range: tuple[int, int]
) -> pa.Scalar:
(start, stop) = eval_range
curr_value = 0.0
if len(values) > 1:
order_by = values[1] # noqa: F841
values = values[0]
else:
values = values[0]
for idx in range(start, stop):
if idx == start:
curr_value = values[idx].as_py()
else:
curr_value = values[idx].as_py() * self.alpha + curr_value * (
1.0 - self.alpha
)
return pa.scalar(curr_value).cast(pa.float64())
class SmoothTwoColumn(WindowEvaluator):
"""This class demonstrates using two columns.
If the second column is above a threshold, then smooth over the first column from
the previous and next rows.
"""
def __init__(self, alpha: float = 0.9) -> None:
self.alpha = alpha
def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
results = []
values_a = values[0]
values_b = values[1]
for idx in range(num_rows):
if values_b[idx].as_py() > 7:
if idx == 0:
results.append(values_a[1].cast(pa.float64()))
elif idx == num_rows - 1:
results.append(values_a[num_rows - 2].cast(pa.float64()))
else:
results.append(
pa.scalar(
values_a[idx - 1].as_py() * self.alpha
+ values_a[idx + 1].as_py() * (1.0 - self.alpha)
)
)
else:
results.append(values_a[idx].cast(pa.float64()))
return pa.array(results)
class SimpleWindowCount(WindowEvaluator):
"""A simple window evaluator that counts rows."""
def __init__(self, base: int = 0) -> None:
self.base = base
def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
return pa.array([self.base + i for i in range(num_rows)])
class NotSubclassOfWindowEvaluator:
pass
@pytest.fixture
def ctx():
return SessionContext()
@pytest.fixture
def complex_window_df(ctx):
# create a RecordBatch and a new DataFrame from it
batch = pa.RecordBatch.from_arrays(
[
pa.array([0, 1, 2, 3, 4, 5, 6]),
pa.array([7, 4, 3, 8, 9, 1, 6]),
pa.array(["A", "A", "A", "A", "B", "B", "B"]),
],
names=["a", "b", "c"],
)
return ctx.create_dataframe([[batch]])
@pytest.fixture
def count_window_df(ctx):
# create a RecordBatch and a new DataFrame from it
batch = pa.RecordBatch.from_arrays(
[pa.array([1, 2, 3]), pa.array([4, 4, 6])],
names=["a", "b"],
)
return ctx.create_dataframe([[batch]], name="test_table")
def test_udwf_errors(complex_window_df):
with pytest.raises(TypeError):
udwf(
NotSubclassOfWindowEvaluator,
pa.float64(),
pa.float64(),
volatility="immutable",
)
def test_udwf_errors_with_message():
"""Test error cases for UDWF creation."""
with pytest.raises(
TypeError, match="`func` must implement the abstract base class WindowEvaluator"
):
udwf(
NotSubclassOfWindowEvaluator, pa.int64(), pa.int64(), volatility="immutable"
)
def test_udwf_basic_usage(count_window_df):
"""Test basic UDWF usage with a simple counting window function."""
simple_count = udwf(
SimpleWindowCount, pa.int64(), pa.int64(), volatility="immutable"
)
df = count_window_df.select(
simple_count(column("a"))
.window_frame(WindowFrame("rows", None, None))
.build()
.alias("count")
)
result = df.collect()[0]
assert result.column(0) == pa.array([0, 1, 2])
def test_udwf_with_args(count_window_df):
"""Test UDWF with constructor arguments."""
count_base10 = udwf(
lambda: SimpleWindowCount(10), pa.int64(), pa.int64(), volatility="immutable"
)
df = count_window_df.select(
count_base10(column("a"))
.window_frame(WindowFrame("rows", None, None))
.build()
.alias("count")
)
result = df.collect()[0]
assert result.column(0) == pa.array([10, 11, 12])
def test_udwf_decorator_basic(count_window_df):
"""Test UDWF used as a decorator."""
@udwf([pa.int64()], pa.int64(), "immutable")
def window_count() -> WindowEvaluator:
return SimpleWindowCount()
df = count_window_df.select(
window_count(column("a"))
.window_frame(WindowFrame("rows", None, None))
.build()
.alias("count")
)
result = df.collect()[0]
assert result.column(0) == pa.array([0, 1, 2])
def test_udwf_decorator_with_args(count_window_df):
"""Test UDWF decorator with constructor arguments."""
@udwf([pa.int64()], pa.int64(), "immutable")
def window_count_base10() -> WindowEvaluator:
return SimpleWindowCount(10)
df = count_window_df.select(
window_count_base10(column("a"))
.window_frame(WindowFrame("rows", None, None))
.build()
.alias("count")
)
result = df.collect()[0]
assert result.column(0) == pa.array([10, 11, 12])
def test_register_udwf(ctx, count_window_df):
"""Test registering and using UDWF in SQL context."""
window_count = udwf(
SimpleWindowCount,
[pa.int64()],
pa.int64(),
volatility="immutable",
name="window_count",
)
ctx.register_udwf(window_count)
result = ctx.sql(
"""
SELECT window_count(a)
OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED
FOLLOWING) FROM test_table
"""
).collect()[0]
assert result.column(0) == pa.array([0, 1, 2])
smooth_default = udwf(
ExponentialSmoothDefault,
pa.float64(),
pa.float64(),
volatility="immutable",
)
smooth_w_arguments = udwf(
lambda: ExponentialSmoothDefault(0.8),
pa.float64(),
pa.float64(),
volatility="immutable",
)
smooth_bounded = udwf(
ExponentialSmoothBounded,
pa.float64(),
pa.float64(),
volatility="immutable",
)
smooth_rank = udwf(
ExponentialSmoothRank,
pa.utf8(),
pa.float64(),
volatility="immutable",
)
smooth_frame = udwf(
ExponentialSmoothFrame,
pa.float64(),
pa.float64(),
volatility="immutable",
)
smooth_two_col = udwf(
SmoothTwoColumn,
[pa.int64(), pa.int64()],
pa.float64(),
volatility="immutable",
)
data_test_udwf_functions = [
(
"default_udwf_no_arguments",
smooth_default(column("a")),
[0, 0.9, 1.89, 2.889, 3.889, 4.889, 5.889],
),
(
"default_udwf_w_arguments",
smooth_w_arguments(column("a")),
[0, 0.8, 1.76, 2.752, 3.75, 4.75, 5.75],
),
(
"default_udwf_partitioned",
smooth_default(column("a")).partition_by(column("c")).build(),
[0, 0.9, 1.89, 2.889, 4.0, 4.9, 5.89],
),
(
"default_udwf_ordered",
smooth_default(column("a")).order_by(column("b")).build(),
[0.551, 1.13, 2.3, 2.755, 3.876, 5.0, 5.513],
),
(
"bounded_udwf",
smooth_bounded(column("a")),
[0, 0.9, 1.9, 2.9, 3.9, 4.9, 5.9],
),
(
"bounded_udwf_ignores_frame",
smooth_bounded(column("a"))
.window_frame(WindowFrame("rows", None, None))
.build(),
[0, 0.9, 1.9, 2.9, 3.9, 4.9, 5.9],
),
(
"rank_udwf",
smooth_rank(column("c")).order_by(column("c")).build(),
[1, 1, 1, 1, 1.9, 2, 2],
),
(
"frame_unbounded_udwf",
smooth_frame(column("a")).window_frame(WindowFrame("rows", None, None)).build(),
[5.889, 5.889, 5.889, 5.889, 5.889, 5.889, 5.889],
),
(
"frame_bounded_udwf",
smooth_frame(column("a")).window_frame(WindowFrame("rows", None, 0)).build(),
[0.0, 0.9, 1.89, 2.889, 3.889, 4.889, 5.889],
),
(
"frame_bounded_udwf",
smooth_frame(column("a"))
.window_frame(WindowFrame("rows", None, 0))
.order_by(column("b"))
.build(),
[0.551, 1.13, 2.3, 2.755, 3.876, 5.0, 5.513],
),
(
"two_column_udwf",
smooth_two_col(column("a"), column("b")),
[0.0, 1.0, 2.0, 2.2, 3.2, 5.0, 6.0],
),
]
@pytest.mark.parametrize(("name", "expr", "expected"), data_test_udwf_functions)
def test_udwf_functions(complex_window_df, name, expr, expected):
df = complex_window_df.select("a", "b", f.round(expr, lit(3)).alias(name))
# execute and collect the first (and only) batch
result = df.sort(column("a")).select(column(name)).collect()[0]
assert result.column(0) == pa.array(expected)
@pytest.mark.parametrize(
"udwf_func",
[
udwf(SimpleWindowCount, pa.int64(), pa.int64(), "immutable"),
udwf(SimpleWindowCount, [pa.int64()], pa.int64(), "immutable"),
udwf([pa.int64()], pa.int64(), "immutable")(lambda: SimpleWindowCount()),
udwf(pa.int64(), pa.int64(), "immutable")(lambda: SimpleWindowCount()),
],
)
def test_udwf_overloads(udwf_func, count_window_df):
df = count_window_df.select(
udwf_func(column("a"))
.window_frame(WindowFrame("rows", None, None))
.build()
.alias("count")
)
result = df.collect()[0]
assert result.column(0) == pa.array([0, 1, 2])
def test_udwf_named_function(ctx, count_window_df):
"""Test UDWF with explicit name parameter."""
window_count = udwf(
SimpleWindowCount,
pa.int64(),
pa.int64(),
volatility="immutable",
name="my_custom_counter",
)
ctx.register_udwf(window_count)
result = ctx.sql(
"""
SELECT my_custom_counter(a)
OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED
FOLLOWING) FROM test_table"""
).collect()[0]
assert result.column(0) == pa.array([0, 1, 2])