blob: dba46e61ab7aa247878297615bbec87b2e8f90d2 [file] [log] [blame]
import pyarrow as pa
import pyarrow.dataset as ds
from datafusion import Expr, SessionContext, Table, udtf
from datafusion.context import TableProviderExportable
def python_table_function_inner(
    num_cols: int, num_rows: int, num_batches: int
) -> Table:
    """Build a ``Table`` made of ``num_batches`` copies of one generated batch.

    The batch has ``num_cols`` integer columns named ``row_0``, ``row_1``, ...,
    and every column holds the values ``0 .. num_rows - 1``. The same
    ``RecordBatch`` object is repeated rather than rebuilt for each copy.
    """
    values = list(range(num_rows))
    column_names = [f"row_{i}" for i in range(num_cols)]
    arrays = [pa.array(values) for _ in column_names]
    template = pa.RecordBatch.from_arrays(arrays, names=column_names)
    return Table(ds.dataset([template] * num_batches))
def test_python_table_function_with_args_class() -> None:
    """Register a callable-class UDTF and call it with three integer arguments."""
    ctx = SessionContext()

    class ParameterizedTableFunction:
        """Callable object that turns its three SQL arguments into a table."""

        def __call__(self, num_cols: Expr, num_rows: Expr, num_batches: Expr) -> Table:
            # Each SQL argument is handed over as an Expr; unwrap it to a Python int.
            cols, rows, batch_count = (
                arg.to_variant().value_i64()
                for arg in (num_cols, num_rows, num_batches)
            )
            return python_table_function_inner(cols, rows, batch_count)

    ctx.register_udtf(udtf(ParameterizedTableFunction(), "param_func"))

    # Query with explicit shape parameters: 5 columns, 10 rows, 2 batches.
    batches = ctx.sql("SELECT * FROM param_func(5, 10, 2)").collect()
    assert len(batches) == 2
    first = batches[0]
    assert first.num_columns == 5
    assert first.num_rows == 10
def test_python_table_function_decorator() -> None:
    """Register a UDTF through the ``@udtf(name)`` decorator form and query it."""
    ctx = SessionContext()

    @udtf("decorated_func")
    def my_decorated_func(num_cols: Expr, num_rows: Expr, num_batches: Expr) -> Table:
        # Unwrap the three Expr arguments (in order) into Python ints.
        unwrapped = [
            arg.to_variant().value_i64() for arg in (num_cols, num_rows, num_batches)
        ]
        return python_table_function_inner(*unwrapped)

    ctx.register_udtf(my_decorated_func)

    batches = ctx.sql("SELECT * FROM decorated_func(3, 7, 2)").collect()
    assert len(batches) == 2
    head = batches[0]
    assert head.num_columns == 3
    assert head.num_rows == 7
def test_python_table_function_no_args() -> None:
    """Test Python TableFunction with no arguments.

    The UDTF always returns one batch with 2 columns x 3 rows, each column
    holding ``[0, 1, 2]``.
    """
    ctx = SessionContext()

    @udtf("static_func")
    def static_table_func() -> Table:
        return python_table_function_inner(2, 3, 1)

    ctx.register_udtf(static_table_func)

    result = ctx.sql("SELECT * FROM static_func()").collect()
    assert len(result) == 1
    batch = result[0]
    # Pin the shape as well as the contents, so an extra or missing column fails.
    assert batch.num_columns == 2
    assert batch.column(0).to_pylist() == [0, 1, 2]
    assert batch.column(1).to_pylist() == [0, 1, 2]
def test_python_table_function_single_arg() -> None:
    """Query a one-argument UDTF whose argument sets the row count."""
    ctx = SessionContext()

    @udtf("single_arg_func")
    def single_arg_func(n: Expr) -> TableProviderExportable:
        row_count = n.to_variant().value_i64()
        # Two columns and a single batch; only the row count comes from SQL.
        return python_table_function_inner(2, row_count, 1)

    ctx.register_udtf(single_arg_func)

    batches = ctx.sql("SELECT * FROM single_arg_func(15)").collect()
    assert len(batches) == 1
    (only,) = batches
    assert only.num_columns == 2
    assert only.num_rows == 15
def test_python_table_function_with_string_args() -> None:
    """Check that a string argument reaches the UDTF and shows up in its schema."""
    ctx = SessionContext()

    @udtf("string_arg_func")
    def string_arg_func(prefix: Expr) -> TableProviderExportable:
        prefix_str = prefix.to_variant().value_string()
        # Put the prefix into the column names so the query result exposes it.
        columns = [pa.array([1, 2, 3]), pa.array([4, 5, 6])]
        column_names = [f"{prefix_str}_a", f"{prefix_str}_b"]
        single_batch = pa.RecordBatch.from_arrays(columns, names=column_names)
        return Table(ds.dataset([single_batch]))

    ctx.register_udtf(string_arg_func)

    batches = ctx.sql("SELECT * FROM string_arg_func('test')").collect()
    assert len(batches) == 1
    assert batches[0].schema.names == ["test_a", "test_b"]