blob: 3dea72673e53cc72623b9ec0778b7e54d2d193b9 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pandas as pd
import pyarrow as pa
import pytest
from sedonadb import udf
def some_udf(arg0, arg1):
    """Combine two UDF arguments pairwise into ``"<arg0> / <arg1>"`` binary values.

    Each argument is materialised with ``to_array()`` and converted to a list
    of Python objects. The pairs are formatted, encoded to bytes, and returned
    as a ``pa.binary()`` array. Pairing uses ``zip``, so it stops at the
    shorter of the two inputs.
    """
    left_values = pa.array(arg0.to_array()).to_pylist()
    right_values = pa.array(arg1.to_array()).to_pylist()

    encoded = [
        f"{item0} / {item1}".encode()
        for item0, item1 in zip(left_values, right_values)
    ]
    return pa.array(encoded, pa.binary())
def test_udf_matchers(con):
    """Register ``some_udf`` declared with ``udf.STRING``/``udf.NUMERIC`` inputs."""
    make_udf = udf.arrow_udf(pa.binary(), [udf.STRING, udf.NUMERIC])
    udf_impl = make_udf(some_udf)
    assert udf_impl._name == "some_udf"

    con.register_udf(udf_impl)

    actual = con.sql("SELECT some_udf('abcd', 123) as col").to_pandas()
    expected = pd.DataFrame({"col": [b"abcd / 123"]})
    pd.testing.assert_frame_equal(actual, expected)
def test_udf_types(con):
    """Register ``some_udf`` declared with concrete pyarrow input types."""
    input_types = [pa.string(), pa.int64()]
    udf_impl = udf.arrow_udf(pa.binary(), input_types)(some_udf)
    assert udf_impl._name == "some_udf"

    con.register_udf(udf_impl)

    query = "SELECT some_udf('abcd', 123) as col"
    result_df = con.sql(query).to_pandas()
    pd.testing.assert_frame_equal(
        result_df,
        pd.DataFrame({"col": [b"abcd / 123"]}),
    )
def test_udf_any_input(con):
    """Register ``some_udf`` declared with a return type only (no input types)."""
    make_udf = udf.arrow_udf(pa.binary())
    udf_impl = make_udf(some_udf)
    assert udf_impl._name == "some_udf"

    con.register_udf(udf_impl)

    observed = con.sql("SELECT some_udf('abcd', 123) as col").to_pandas()
    wanted = pd.DataFrame({"col": [b"abcd / 123"]})
    pd.testing.assert_frame_equal(observed, wanted)
def test_udf_return_type_fn(con):
    """Register ``some_udf`` whose return type is given by a callable.

    The callable returns the type of the first argument, and the query passes
    a BYTEA value in that position.
    """
    return_type_fn = lambda arg_types, arg_scalars: arg_types[0]  # noqa: E731
    udf_impl = udf.arrow_udf(return_type_fn)(some_udf)
    assert udf_impl._name == "some_udf"

    con.register_udf(udf_impl)

    actual = con.sql("SELECT some_udf('abcd'::BYTEA, 123) as col").to_pandas()
    # The expected text carries the Python bytes repr of the first argument.
    expected = pd.DataFrame({"col": [b"b'abcd' / 123"]})
    pd.testing.assert_frame_equal(actual, expected)
def test_udf_array_input(con):
    """Call ``some_udf`` on a three-row column combined with a literal."""
    make_udf = udf.arrow_udf(pa.binary(), [udf.STRING, udf.NUMERIC])
    udf_impl = make_udf(some_udf)
    assert udf_impl._name == "some_udf"

    con.register_udf(udf_impl)

    query = "SELECT some_udf(x, 123) as col FROM (VALUES ('a'), ('b'), ('c')) as t(x)"
    actual = con.sql(query).to_pandas()
    expected = pd.DataFrame({"col": [b"a / 123", b"b / 123", b"c / 123"]})
    pd.testing.assert_frame_equal(actual, expected)
def test_udf_name():
    """Check that an explicit ``name=`` is what ends up in ``_name``."""
    make_named_udf = udf.arrow_udf(pa.binary(), name="foofy")
    renamed_udf = make_named_udf(some_udf)
    assert renamed_udf._name == "foofy"
def test_shapely_udf(con):
    """Run a shapely-backed buffer UDF and check the resulting area and SRID."""
    import shapely
    import geoarrow.pyarrow as ga
    import numpy as np

    @udf.arrow_udf(ga.wkb(), [udf.GEOMETRY, udf.NUMERIC])
    def shapely_udf(geom, distance):
        # Pull WKB out of the geometry's storage and hand it to shapely.
        wkb_values = pa.array(geom.storage.to_array())
        buffer_distances = pa.array(distance.to_array())
        shapely_geoms = shapely.from_wkb(wkb_values)
        buffered = shapely.buffer(shapely_geoms, buffer_distances)
        return pa.array(shapely.to_wkb(buffered))

    con.register_udf(shapely_udf)

    area_df = con.sql(
        "SELECT ST_Area(shapely_udf(ST_Point(0, 0), 2.0)) as col"
    ).to_pandas()
    expected_area = pd.DataFrame({"col": [12.485780609032208]})
    pd.testing.assert_frame_equal(area_df, expected_area)

    # Ensure we can propagate a crs
    srid_query = (
        "SELECT ST_SRID(shapely_udf(ST_SetSRID(ST_Point(0, 0), 3857), 2.0)) as col"
    )
    srid_df = con.sql(srid_query).to_pandas()
    expected_srid = pd.DataFrame({"col": [3857]}, dtype=np.uint32)
    pd.testing.assert_frame_equal(srid_df, expected_srid)
def test_py_sedona_value(con):
    """Inspect the argument object a UDF receives for a scalar integer literal."""

    @udf.arrow_udf(pa.int64())
    def fn_arg_only(arg):
        # These assertions run inside the UDF while the query executes.
        assert repr(arg) == "PySedonaValue Scalar Int64[1]"
        assert arg.is_scalar() is True
        assert repr(arg.type) == "SedonaType int64<Int64>"

        row_count = len(pa.array(arg))
        return pa.array(range(row_count))

    con.register_udf(fn_arg_only)
    query_result = con.sql("SELECT fn_arg_only(123)")
    query_result.to_arrow_table()
def test_udf_kwargs(con):
    """Check the keyword-only ``return_type`` and ``num_rows`` values seen by UDFs.

    Three UDFs are registered: one declaring ``return_type``, one declaring
    ``num_rows``, and one declaring both. Each asserts on the value it
    receives while its query runs.
    """

    # UDF that only declares ``return_type``.
    @udf.arrow_udf(pa.int64())
    def fn_return_type(arg, *, return_type=None):
        assert repr(return_type) == "SedonaType int64<Int64>"
        length = len(pa.array(arg))
        return pa.array(range(length))

    con.register_udf(fn_return_type)
    con.sql("SELECT fn_return_type('123')").to_arrow_table()

    # UDF that only declares ``num_rows``.
    @udf.arrow_udf(pa.int64())
    def fn_num_rows(arg, *, num_rows=None):
        assert num_rows == 1
        length = len(pa.array(arg))
        return pa.array(range(length))

    con.register_udf(fn_num_rows)
    con.sql("SELECT fn_num_rows('123')").to_arrow_table()

    # UDF that declares both keyword-only arguments.
    @udf.arrow_udf(pa.int64())
    def fn_num_rows_and_return_type(arg, *, num_rows=None, return_type=None):
        assert repr(return_type) == "SedonaType int64<Int64>"
        assert num_rows == 1
        length = len(pa.array(arg))
        return pa.array(range(length))

    con.register_udf(fn_num_rows_and_return_type)
    con.sql("SELECT fn_num_rows_and_return_type('123')").to_arrow_table()
def test_udf_bad_return_object(con):
    """A UDF that returns ``None`` must fail the query with a ``ValueError``."""

    @udf.arrow_udf(pa.binary())
    def questionable_udf(arg):
        return None

    con.register_udf(questionable_udf)

    expected_message = "Expected result of user-defined function to return an object implementing __arrow_c_array__"
    with pytest.raises(ValueError, match=expected_message):
        con.sql("SELECT questionable_udf(123) as col").to_pandas()
def test_udf_bad_return_type(con):
    """A UDF declared as binary but returning a string array must raise."""

    @udf.arrow_udf(pa.binary())
    def questionable_udf(arg):
        wrong_type_result = pa.array(["abc"], pa.string())
        return wrong_type_result

    con.register_udf(questionable_udf)

    expected_message = (
        "Expected result of user-defined function to "
        "return array of type Binary or its storage "
        "but got Utf8"
    )
    with pytest.raises(ValueError, match=expected_message):
        con.sql("SELECT questionable_udf(123) as col").to_pandas()
def test_udf_bad_return_length(con):
    """A UDF returning two values where one is expected must raise."""

    @udf.arrow_udf(pa.binary())
    def questionable_udf(arg):
        too_many_values = [b"abc", b"def"]
        return pa.array(too_many_values, pa.binary())

    con.register_udf(questionable_udf)

    expected_message = "Expected result of user-defined function to return array of length 1 but got 2"
    with pytest.raises(ValueError, match=expected_message):
        con.sql("SELECT questionable_udf(123) as col").to_pandas()
def test_udf_datafusion_to_sedonadb(con):
    """Register an object that only exposes ``__datafusion_scalar_udf__``.

    The wrapper forwards that one method to the real UDF, so registration has
    nothing else on the object to go by.
    """
    make_udf = udf.arrow_udf(
        pa.binary(), [udf.STRING, udf.NUMERIC], name="some_external_udf"
    )
    udf_impl = make_udf(some_udf)

    class UdfWrapper:
        def __init__(self, obj):
            self._wrapped = obj

        def __datafusion_scalar_udf__(self):
            return self._wrapped.__datafusion_scalar_udf__()

    wrapped_udf = UdfWrapper(udf_impl)
    con.register_udf(wrapped_udf)

    actual = con.sql("SELECT some_external_udf('abcd', 123) as col").to_pandas()
    expected = pd.DataFrame({"col": [b"abcd / 123"]})
    pd.testing.assert_frame_equal(actual, expected)
def test_udf_sedonadb_registry_function_to_datafusion(con):
    """Fetch a UDF registered in our session and use it from a datafusion session.

    Skipped when the ``datafusion`` package cannot be imported.
    """
    datafusion = pytest.importorskip("datafusion")

    # Register with our session
    make_udf = udf.arrow_udf(pa.binary(), [udf.STRING, udf.NUMERIC])
    udf_impl = make_udf(some_udf)
    con.register_udf(udf_impl)

    # Create a datafusion session, fetch our udf and register with the other session
    datafusion_ctx = datafusion.SessionContext()
    registered_capsule = con._impl.scalar_udf("some_udf")
    datafusion_ctx.register_udf(
        datafusion.ScalarUDF.from_pycapsule(registered_capsule)
    )

    # Can't quite use to_pandas() because there is a schema/batch nullability mismatch
    result = datafusion_ctx.sql("SELECT some_udf('abcd', 123) as col")
    batches = result.collect()
    assert len(batches) == 1

    first_batch_df = batches[0].to_pandas()
    expected = pd.DataFrame({"col": [b"abcd / 123"]})
    pd.testing.assert_frame_equal(first_batch_df, expected)
def test_udf_sedonadb_to_datafusion():
    """Register a sedonadb UDF object directly with a datafusion session.

    Skipped when the ``datafusion`` package cannot be imported.
    """
    datafusion = pytest.importorskip("datafusion")

    make_udf = udf.arrow_udf(pa.binary(), [udf.STRING, udf.NUMERIC])
    udf_impl = make_udf(some_udf)

    # Create a datafusion session, register udf_impl directly
    datafusion_ctx = datafusion.SessionContext()
    datafusion_udf = datafusion.ScalarUDF.from_pycapsule(udf_impl)
    datafusion_ctx.register_udf(datafusion_udf)

    # Can't quite use to_pandas() because there is a schema/batch nullability mismatch
    result = datafusion_ctx.sql("SELECT some_udf('abcd', 123) as col")
    batches = result.collect()
    assert len(batches) == 1

    first_batch_df = batches[0].to_pandas()
    expected = pd.DataFrame({"col": [b"abcd / 123"]})
    pd.testing.assert_frame_equal(first_batch_df, expected)