# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.


import pytest

import pyarrow as pa
from pyarrow import compute as pc

# UDFs are all tested with a dataset scan
pytestmark = pytest.mark.dataset

try:
    import pyarrow.dataset as ds
except ImportError:
    ds = None


def mock_scalar_udf_context(batch_length=10):
    from pyarrow._compute import _get_scalar_udf_context
    return _get_scalar_udf_context(pa.default_memory_pool(), batch_length)


class MyError(RuntimeError):
    pass


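# NOTE: reconstructed fixture. Several tests below (test_udf_array_unary,
# test_scalar_udf_context, test_scalar_input, test_input_lifetime) reference
# a `unary_func_fixture` that was missing from this file. This minimal
# version is inferred from those tests, which expect the UDF to add 1 to its
# input (10 -> 11, [1]*1000 -> [2]*1000); the function name and doc strings
# are placeholders.
@pytest.fixture(scope="session")
def unary_func_fixture():
    """
    Register a unary scalar function (reconstructed: adds 1 to its input).
    """
    def unary_function(ctx, x):
        return pc.call_function("add", [x, 1],
                                memory_pool=ctx.memory_pool)
    func_name = "y=x+1"
    unary_doc = {"summary": "add function",
                 "description": "test add function"}
    pc.register_scalar_function(unary_function,
                                func_name,
                                unary_doc,
                                {"array": pa.int64()},
                                pa.int64())
    return unary_function, func_name

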
@pytest.fixture(scope="session")
def binary_func_fixture():
    """
    Register a binary scalar function.
    """
    def binary_function(ctx, m, x):
        return pc.call_function("multiply", [m, x],
                                memory_pool=ctx.memory_pool)
    func_name = "y=mx"
    binary_doc = {"summary": "y=mx",
                  "description": "find y from y = mx"}
    pc.register_scalar_function(binary_function,
                                func_name,
                                binary_doc,
                                {"m": pa.int64(),
                                 "x": pa.int64(),
                                 },
                                pa.int64())
    return binary_function, func_name


@pytest.fixture(scope="session")
def ternary_func_fixture():
    """
    Register a ternary scalar function.
    """
    def ternary_function(ctx, m, x, c):
        mx = pc.call_function("multiply", [m, x],
                              memory_pool=ctx.memory_pool)
        return pc.call_function("add", [mx, c],
                                memory_pool=ctx.memory_pool)
    ternary_doc = {"summary": "y=mx+c",
                   "description": "find y from y = mx + c"}
    func_name = "y=mx+c"
    pc.register_scalar_function(ternary_function,
                                func_name,
                                ternary_doc,
                                {
                                    "array1": pa.int64(),
                                    "array2": pa.int64(),
                                    "array3": pa.int64(),
                                },
                                pa.int64())
    return ternary_function, func_name


@pytest.fixture(scope="session")
def varargs_func_fixture():
    """
    Register a varargs scalar function with at least two arguments.
    """
    def varargs_function(ctx, first, *values):
        acc = first
        for val in values:
            acc = pc.call_function("add", [acc, val],
                                   memory_pool=ctx.memory_pool)
        return acc
    func_name = "z=ax+by+c"
    varargs_doc = {"summary": "z=ax+by+c",
                   "description": "find z from z = ax + by + c"
                   }
    pc.register_scalar_function(varargs_function,
                                func_name,
                                varargs_doc,
                                {
                                    "array1": pa.int64(),
                                    "array2": pa.int64(),
                                },
                                pa.int64())
    return varargs_function, func_name


@pytest.fixture(scope="session")
def nullary_func_fixture():
    """
    Register a nullary scalar function.
    """
    def nullary_func(context):
        return pa.array([42] * context.batch_length, type=pa.int64(),
                        memory_pool=context.memory_pool)

    func_doc = {
        "summary": "constant function",
        "description": "generates a constant value"
    }
    func_name = "test_nullary_func"
    pc.register_scalar_function(nullary_func,
                                func_name,
                                func_doc,
                                {},
                                pa.int64())

    return nullary_func, func_name


@pytest.fixture(scope="session")
def wrong_output_type_func_fixture():
    """
    Register a scalar function which returns something that is neither
    an Arrow scalar nor an array.
    """
    def wrong_output_type(ctx):
        return 42

    func_name = "test_wrong_output_type"
    in_types = {}
    out_type = pa.int64()
    doc = {
        "summary": "return wrong output type",
        "description": ""
    }
    pc.register_scalar_function(wrong_output_type, func_name, doc,
                                in_types, out_type)
    return wrong_output_type, func_name


@pytest.fixture(scope="session")
def wrong_output_datatype_func_fixture():
    """
    Register a scalar function whose actual output DataType doesn't
    match the declared output DataType.
    """
    def wrong_output_datatype(ctx, array):
        return pc.call_function("add", [array, 1])
    func_name = "test_wrong_output_datatype"
    in_types = {"array": pa.int64()}
    # The actual output DataType will be int64.
    out_type = pa.int16()
    doc = {
        "summary": "return wrong output datatype",
        "description": ""
    }
    pc.register_scalar_function(wrong_output_datatype, func_name, doc,
                                in_types, out_type)
    return wrong_output_datatype, func_name


@pytest.fixture(scope="session")
def wrong_signature_func_fixture():
    """
    Register a scalar function with the wrong signature.
    """
    # Missing the context argument
    def wrong_signature():
        return pa.scalar(1, type=pa.int64())

    func_name = "test_wrong_signature"
    in_types = {}
    out_type = pa.int64()
    doc = {
        "summary": "UDF with wrong signature",
        "description": ""
    }
    pc.register_scalar_function(wrong_signature, func_name, doc,
                                in_types, out_type)
    return wrong_signature, func_name


@pytest.fixture(scope="session")
def raising_func_fixture():
    """
    Register a scalar function which raises a custom exception.
    """
    def raising_func(ctx):
        raise MyError("error raised by scalar UDF")
    func_name = "test_raise"
    doc = {
        "summary": "raising function",
        "description": ""
    }
    pc.register_scalar_function(raising_func, func_name, doc,
                                {}, pa.int64())
    return raising_func, func_name


def check_scalar_function(func_fixture,
                          inputs, *,
                          run_in_dataset=True,
                          batch_length=None):
    function, name = func_fixture
    if batch_length is None:
        # Infer the batch length from the array inputs,
        # falling back to 1 when all inputs are scalars.
        all_scalar = True
        for arg in inputs:
            if isinstance(arg, pa.Array):
                all_scalar = False
                batch_length = len(arg)
        if all_scalar:
            batch_length = 1

    expected_output = function(mock_scalar_udf_context(batch_length), *inputs)
    func = pc.get_function(name)
    assert func.name == name

    result = pc.call_function(name, inputs, length=batch_length)
    assert result == expected_output
    # At the moment there is an issue when handling nullary functions.
    # See: ARROW-15286 and ARROW-16290.
    if run_in_dataset:
        field_names = [f'field{index}'
                       for index, in_arr in enumerate(inputs)]
        table = pa.Table.from_arrays(inputs, field_names)
        dataset = ds.dataset(table)
        func_args = [ds.field(field_name) for field_name in field_names]
        result_table = dataset.to_table(
            columns={'result': ds.field('')._call(name, func_args)})
        assert result_table.column(0).chunks[0] == expected_output


def test_udf_array_unary(unary_func_fixture):
    check_scalar_function(unary_func_fixture,
                          [
                              pa.array([10, 20], pa.int64())
                          ]
                          )


def test_udf_array_binary(binary_func_fixture):
    check_scalar_function(binary_func_fixture,
                          [
                              pa.array([10, 20], pa.int64()),
                              pa.array([2, 4], pa.int64())
                          ]
                          )


def test_udf_array_ternary(ternary_func_fixture):
    check_scalar_function(ternary_func_fixture,
                          [
                              pa.array([10, 20], pa.int64()),
                              pa.array([2, 4], pa.int64()),
                              pa.array([5, 10], pa.int64())
                          ]
                          )


def test_udf_array_varargs(varargs_func_fixture):
    check_scalar_function(varargs_func_fixture,
                          [
                              pa.array([2, 3], pa.int64()),
                              pa.array([10, 20], pa.int64()),
                              pa.array([3, 7], pa.int64()),
                              pa.array([20, 30], pa.int64()),
                              pa.array([5, 10], pa.int64())
                          ]
                          )


def test_registration_errors():
    # validate function name
    doc = {
        "summary": "test udf input",
        "description": "parameters are validated"
    }
    in_types = {"scalar": pa.int64()}
    out_type = pa.int64()

    def test_reg_function(context):
        return pa.array([10])

    with pytest.raises(TypeError):
        pc.register_scalar_function(test_reg_function,
                                    None, doc, in_types,
                                    out_type)

    # validate function
    with pytest.raises(TypeError, match="func must be a callable"):
        pc.register_scalar_function(None, "test_none_function", doc, in_types,
                                    out_type)

    # validate output type
    expected_expr = "DataType expected, got <class 'NoneType'>"
    with pytest.raises(TypeError, match=expected_expr):
        pc.register_scalar_function(test_reg_function,
                                    "test_output_function", doc, in_types,
                                    None)

    # validate input type
    expected_expr = "in_types must be a dictionary of DataType"
    with pytest.raises(TypeError, match=expected_expr):
        pc.register_scalar_function(test_reg_function,
                                    "test_input_function", doc, None,
                                    out_type)

    # register an already registered function
    # first registration
    pc.register_scalar_function(test_reg_function,
                                "test_reg_function", doc, {},
                                out_type)
    # second registration
    expected_expr = "Already have a function registered with name:" \
        + " test_reg_function"
    with pytest.raises(KeyError, match=expected_expr):
        pc.register_scalar_function(test_reg_function,
                                    "test_reg_function", doc, {},
                                    out_type)


def test_varargs_function_validation(varargs_func_fixture):
    _, func_name = varargs_func_fixture

    error_msg = r"VarArgs function 'z=ax\+by\+c' needs at least 2 arguments"

    with pytest.raises(ValueError, match=error_msg):
        pc.call_function(func_name, [42])


def test_function_doc_validation():
    # validate arity
    in_types = {"scalar": pa.int64()}
    out_type = pa.int64()

    # doc with no summary
    func_doc = {
        "description": "desc"
    }

    def add_const(ctx, scalar):
        return pc.call_function("add", [scalar, 1])

    with pytest.raises(ValueError,
                       match="Function doc must contain a summary"):
        pc.register_scalar_function(add_const, "test_no_summary",
                                    func_doc, in_types,
                                    out_type)

    # doc with no description
    func_doc = {
        "summary": "test summary"
    }

    with pytest.raises(ValueError,
                       match="Function doc must contain a description"):
        pc.register_scalar_function(add_const, "test_no_desc",
                                    func_doc, in_types,
                                    out_type)


def test_nullary_function(nullary_func_fixture):
    # XXX the Python compute layer API doesn't let us override batch_length,
    # so only test with the default value of 1.
    check_scalar_function(nullary_func_fixture, [], run_in_dataset=False,
                          batch_length=1)


def test_wrong_output_type(wrong_output_type_func_fixture):
    _, func_name = wrong_output_type_func_fixture

    with pytest.raises(TypeError,
                       match="Unexpected output type: int"):
        pc.call_function(func_name, [], length=1)


def test_wrong_output_datatype(wrong_output_datatype_func_fixture):
    _, func_name = wrong_output_datatype_func_fixture

    expected_expr = ("Expected output datatype int16, "
                     "but function returned datatype int64")

    with pytest.raises(TypeError, match=expected_expr):
        pc.call_function(func_name, [pa.array([20, 30])])


def test_wrong_signature(wrong_signature_func_fixture):
    _, func_name = wrong_signature_func_fixture

    expected_expr = (r"wrong_signature\(\) takes 0 positional arguments "
                     "but 1 was given")

    with pytest.raises(TypeError, match=expected_expr):
        pc.call_function(func_name, [], length=1)


def test_wrong_datatype_declaration():
    def identity(ctx, val):
        return val

    func_name = "test_wrong_datatype_declaration"
    in_types = {"array": pa.int64()}
    out_type = {}
    doc = {
        "summary": "test output value",
        "description": "test output"
    }
    with pytest.raises(TypeError,
                       match="DataType expected, got <class 'dict'>"):
        pc.register_scalar_function(identity, func_name,
                                    doc, in_types, out_type)


def test_wrong_input_type_declaration():
    def identity(ctx, val):
        return val

    func_name = "test_wrong_input_type_declaration"
    in_types = {"array": None}
    out_type = pa.int64()
    doc = {
        "summary": "test invalid input type",
        "description": "invalid input function"
    }
    with pytest.raises(TypeError,
                       match="DataType expected, got <class 'NoneType'>"):
        pc.register_scalar_function(identity, func_name, doc,
                                    in_types, out_type)


def test_scalar_udf_context(unary_func_fixture):
    # Check the memory_pool argument is properly propagated
    proxy_pool = pa.proxy_memory_pool(pa.default_memory_pool())
    _, func_name = unary_func_fixture

    res = pc.call_function(func_name,
                           [pa.array([1] * 1000, type=pa.int64())],
                           memory_pool=proxy_pool)
    assert res == pa.array([2] * 1000, type=pa.int64())
    assert proxy_pool.bytes_allocated() == 1000 * 8
    # Destroying the Python array should destroy the underlying C++ memory
    res = None
    assert proxy_pool.bytes_allocated() == 0


def test_raising_func(raising_func_fixture):
    _, func_name = raising_func_fixture
    with pytest.raises(MyError, match="error raised by scalar UDF"):
        pc.call_function(func_name, [], length=1)


def test_scalar_input(unary_func_fixture):
    function, func_name = unary_func_fixture
    res = pc.call_function(func_name, [pa.scalar(10)])
    assert res == pa.scalar(11)


def test_input_lifetime(unary_func_fixture):
    function, func_name = unary_func_fixture

    proxy_pool = pa.proxy_memory_pool(pa.default_memory_pool())
    assert proxy_pool.bytes_allocated() == 0

    v = pa.array([1] * 1000, type=pa.int64(), memory_pool=proxy_pool)
    assert proxy_pool.bytes_allocated() == 1000 * 8
    pc.call_function(func_name, [v])
    assert proxy_pool.bytes_allocated() == 1000 * 8
    # Calling a UDF should not have kept `v` alive longer than required
    v = None
    assert proxy_pool.bytes_allocated() == 0


def _record_batch_from_iters(schema, *iters):
    arrays = [pa.array(list(v), type=schema[i].type)
              for i, v in enumerate(iters)]
    return pa.RecordBatch.from_arrays(arrays=arrays, schema=schema)


def _record_batch_for_range(schema, n):
    return _record_batch_from_iters(schema,
                                    range(n, n + 10),
                                    range(n + 1, n + 11))


def make_udt_func(schema, batch_gen):
    """
    Wrap a generator of record batches into the callable expected by a
    tabular UDF: each call yields the next batch as a struct array, and
    an empty batch signals exhaustion.
    """
    def udf_func(ctx):
        class UDT:
            def __init__(self):
                self.caller = None

            def __call__(self, ctx):
                try:
                    if self.caller is None:
                        # Prime the generator on the first call
                        self.caller, ctx = batch_gen(ctx).send, None
                    batch = self.caller(ctx)
                except StopIteration:
                    arrays = [pa.array([], type=field.type)
                              for field in schema]
                    batch = pa.RecordBatch.from_arrays(
                        arrays=arrays, schema=schema)
                return batch.to_struct_array()
        return UDT()
    return udf_func


def datasource1_direct():
    """A short dataset"""
    schema = datasource1_schema()

    class Generator:
        def __init__(self):
            self.n = 3

        def __call__(self, ctx):
            if self.n == 0:
                batch = _record_batch_from_iters(schema, [], [])
            else:
                self.n -= 1
                batch = _record_batch_for_range(schema, self.n)
            return batch.to_struct_array()
    return lambda ctx: Generator()


def datasource1_generator():
    schema = datasource1_schema()

    def batch_gen(ctx):
        for n in range(3, 0, -1):
            # a per-call ctx could be received here via "ctx = yield ...";
            # it is not needed
            yield _record_batch_for_range(schema, n - 1)
    return make_udt_func(schema, batch_gen)


def datasource1_exception():
    schema = datasource1_schema()

    def batch_gen(ctx):
        for n in range(3, 0, -1):
            # a per-call ctx could be received here via "ctx = yield ...";
            # it is not needed
            yield _record_batch_for_range(schema, n - 1)
        raise RuntimeError("datasource1_exception")
    return make_udt_func(schema, batch_gen)


def datasource1_schema():
    return pa.schema([('', pa.int32()), ('', pa.int32())])


def datasource1_args(func, func_name):
    func_doc = {"summary": f"{func_name} UDT",
                "description": f"test {func_name} UDT"}
    in_types = {}
    out_type = pa.struct([("", pa.int32()), ("", pa.int32())])
    return func, func_name, func_doc, in_types, out_type


def _test_datasource1_udt(func_maker):
    schema = datasource1_schema()
    func = func_maker()
    func_name = func_maker.__name__
    func_args = datasource1_args(func, func_name)
    pc.register_tabular_function(*func_args)
    n = 3
    for item in pc.call_tabular_function(func_name):
        n -= 1
        assert item == _record_batch_for_range(schema, n)


def test_udt_datasource1_direct():
    _test_datasource1_udt(datasource1_direct)


def test_udt_datasource1_generator():
    _test_datasource1_udt(datasource1_generator)


def test_udt_datasource1_exception():
    with pytest.raises(RuntimeError, match='datasource1_exception'):
        _test_datasource1_udt(datasource1_exception)