# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pyarrow as pa
import pytest
from datafusion import SessionContext, column, udf
from datafusion import functions as f
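
# These tests exercise scalar Python UDFs: plain callables and the @udf
# decorator, callable classes that carry parameters, an extension (uuid)
# return type, and nullability declared through pa.field inputs and outputs.
# The `ctx` (SessionContext) fixture is provided by the surrounding test suite.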


@pytest.fixture
def df(ctx):
    # create a RecordBatch and a new DataFrame from it
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array([4, 4, None])],
        names=["a", "b"],
    )
    return ctx.create_dataframe([[batch]], name="test_table")


def test_udf(df):
    # is_null is a pa function over arrays
    is_null = udf(
        lambda x: x.is_null(),
        [pa.int64()],
        pa.bool_(),
        volatility="immutable",
    )

    df = df.select(is_null(column("b")))
    result = df.collect()[0].column(0)
    assert result == pa.array([False, False, True])


def test_udf_decorator(df):
    @udf([pa.int64()], pa.bool_(), "immutable")
    def is_null(x: pa.Array) -> pa.Array:
        return x.is_null()

    df = df.select(is_null(column("b")))
    result = df.collect()[0].column(0)
    assert result == pa.array([False, False, True])
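

# Registering a UDF on the SessionContext makes it callable by name from SQL.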
def test_register_udf(ctx, df) -> None:
    is_null = udf(
        lambda x: x.is_null(),
        [pa.float64()],
        pa.bool_(),
        volatility="immutable",
        name="is_null",
    )
    ctx.register_udf(is_null)

    df_result = ctx.sql("select is_null(b) from test_table")
    result = df_result.collect()[0].column(0)
    assert result == pa.array([False, False, True])


class OverThresholdUDF:
    def __init__(self, threshold: int = 0) -> None:
        self.threshold = threshold

    def __call__(self, values: pa.Array) -> pa.Array:
        return pa.array(v.as_py() >= self.threshold for v in values)
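

# Using a callable class lets a UDF carry configuration (here the threshold)
# that a plain function cannot; udf() accepts the instance like any other callable.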
def test_udf_with_parameters_function(df) -> None:
    udf_no_param = udf(
        OverThresholdUDF(),
        pa.int64(),
        pa.bool_(),
        volatility="immutable",
    )
    df1 = df.select(udf_no_param(column("a")))
    result = df1.collect()[0].column(0)
    assert result == pa.array([True, True, True])

    udf_with_param = udf(
        OverThresholdUDF(2),
        pa.int64(),
        pa.bool_(),
        volatility="immutable",
    )
    df2 = df.select(udf_with_param(column("a")))
    result = df2.collect()[0].column(0)
    assert result == pa.array([False, True, True])


def test_udf_with_parameters_decorator(df) -> None:
    @udf([pa.int64()], pa.bool_(), "immutable")
    def udf_no_param(values: pa.Array) -> pa.Array:
        return OverThresholdUDF()(values)

    df1 = df.select(udf_no_param(column("a")))
    result = df1.collect()[0].column(0)
    assert result == pa.array([True, True, True])

    @udf([pa.int64()], pa.bool_(), "immutable")
    def udf_with_param(values: pa.Array) -> pa.Array:
        return OverThresholdUDF(2)(values)

    df2 = df.select(udf_with_param(column("a")))
    result = df2.collect()[0].column(0)
    assert result == pa.array([False, True, True])
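

# A UDF can also return a pyarrow extension type: here one UDF parses strings
# into pa.uuid() values and a second reads the UUID version back out of them.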
def test_udf_with_metadata(ctx) -> None:
    from uuid import UUID

    @udf([pa.string()], pa.uuid(), "stable")
    def uuid_from_string(uuid_string):
        return pa.array((UUID(s).bytes for s in uuid_string.to_pylist()), pa.uuid())

    @udf([pa.uuid()], pa.int64(), "stable")
    def uuid_version(uuid):
        return pa.array(s.version for s in uuid.to_pylist())

    batch = pa.record_batch({"idx": pa.array(range(5))})

    results = (
        ctx.create_dataframe([[batch]])
        .with_column("uuid_string", f.uuid())
        .with_column("uuid", uuid_from_string(column("uuid_string")))
        .select(uuid_version(column("uuid")).alias("uuid_version"))
        .collect()
    )

    assert results[0][0].to_pylist() == [4, 4, 4, 4, 4]
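

# Declaring inputs and the return type as pa.field(...) rather than bare
# DataTypes lets a UDF state whether its output can contain nulls.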
def test_udf_with_nullability(ctx: SessionContext) -> None:
    import pyarrow.compute as pc

    field_nullable_i64 = pa.field("with_nulls", type=pa.int64(), nullable=True)
    field_non_nullable_i64 = pa.field("no_nulls", type=pa.int64(), nullable=False)

    @udf([field_nullable_i64], field_nullable_i64, "stable")
    def nullable_abs(input_col):
        return pc.abs(input_col)

    @udf([field_non_nullable_i64], field_non_nullable_i64, "stable")
    def non_nullable_abs(input_col):
        return pc.abs(input_col)

    batch = pa.record_batch(
        {
            "with_nulls": pa.array([-2, None, 0, 1, 2]),
            "no_nulls": pa.array([-2, -1, 0, 1, 2]),
        },
        schema=pa.schema(
            [
                field_nullable_i64,
                field_non_nullable_i64,
            ]
        ),
    )
    ctx.register_record_batches("t", [[batch]])
    df = ctx.table("t")

    # Input matches expected, nullable
    df_result = df.select(nullable_abs(column("with_nulls")))
    returned_field = df_result.schema().field(0)
    assert returned_field.nullable
    results = df_result.collect()
    assert results[0][0].to_pylist() == [2, None, 0, 1, 2]

    # Input coercible to expected, nullable
    df_result = df.select(nullable_abs(column("no_nulls")))
    returned_field = df_result.schema().field(0)
    assert returned_field.nullable
    results = df_result.collect()
    assert results[0][0].to_pylist() == [2, 1, 0, 1, 2]

    # Input matches expected, no nulls
    df_result = df.select(non_nullable_abs(column("no_nulls")))
    returned_field = df_result.schema().field(0)
    assert not returned_field.nullable
    results = df_result.collect()
    assert results[0][0].to_pylist() == [2, 1, 0, 1, 2]

    # Invalid - the UDF declares a non-nullable input but the column contains
    # nulls, so the failure surfaces when the plan is executed via collect()
    df_result = df.select(non_nullable_abs(column("with_nulls")))
    returned_field = df_result.schema().field(0)
    assert not returned_field.nullable
    with pytest.raises(Exception) as e_info:
        _results = df_result.collect()
    assert "InvalidArgumentError" in str(e_info)