blob: 1282749a7d81e0aad45d19198c85f9b3417f8493 [file] [log] [blame]
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import asyncio
import uuid
from pyflink.table import DataTypes
from pyflink.table.udf import AsyncScalarFunction, udf, FunctionContext
from pyflink.testing import source_sink_utils
from pyflink.testing.test_case_utils import PyFlinkStreamTableTestCase, run_with_config
def generate_random_table_name():
    """Return a unique sink-table name, ``Table`` + a uuid1 with dashes replaced by underscores."""
    unique_suffix = str(uuid.uuid1()).replace("-", "_")
    return "Table" + unique_suffix
class AsyncScalarFunctionTests(PyFlinkStreamTableTestCase):
    """
    Integration tests for Python Async Scalar Function.
    """

    def test_basic_async_scalar_function(self):
        """Test an AsyncScalarFunction subclass with a full open/eval/close lifecycle."""

        class AsyncFunctionWithLifecycle(AsyncScalarFunction):
            def open(self, function_context: FunctionContext):
                # State set in open() must be visible inside the async eval().
                self.prefix = "opened_"

            async def eval(self, value):
                await asyncio.sleep(0.001)
                return self.prefix + value

            def close(self):
                pass

        async_func = udf(
            AsyncFunctionWithLifecycle(),
            input_types=[DataTypes.STRING()],
            result_type=DataTypes.STRING()
        )
        sink_table = generate_random_table_name()
        self.t_env.execute_sql(f"""
            CREATE TABLE {sink_table}(a STRING, b STRING)
            WITH ('connector'='test-sink')
        """)
        t = self.t_env.from_elements([("test1",), ("test2",)], ['a'])
        t.select(t.a, async_func(t.a).alias('b')).execute_insert(sink_table).wait()
        actual = source_sink_utils.results()
        self.assert_equals(actual, [
            "+I[test1, opened_test1]",
            "+I[test2, opened_test2]"
        ])

    def test_raise_exception_in_async_eval(self):
        """Test async scalar function that raises exception during evaluation."""

        class ExceptionAsyncFunction(AsyncScalarFunction):
            async def eval(self, value: str) -> str:
                raise ValueError("Test exception in async eval")

        async_func = udf(
            ExceptionAsyncFunction(),
            input_types=[DataTypes.STRING()],
            result_type=DataTypes.STRING()
        )
        sink_table = generate_random_table_name()
        self.t_env.execute_sql(f"""
            CREATE TABLE {sink_table}(a STRING, b STRING)
            WITH ('connector'='test-sink')
        """)
        t = self.t_env.from_elements([("test1",)], ['a'])
        with self.assertRaises(Exception) as context:
            t.select(t.a, async_func(t.a).alias('b')).execute_insert(sink_table).wait()
        # Verify exception message is propagated
        self.assertIn("Test exception in async eval", str(context.exception))

    def test_async_function_with_retry_logic(self):
        """Test async scalar function with custom retry logic."""

        class RetryAsyncFunction(AsyncScalarFunction):
            def open(self, function_context: FunctionContext):
                # Per-value attempt counter; survives across framework retries
                # because the UDF instance is reused.
                self.attempt_count = {}

            async def eval(self, value: str) -> str:
                await asyncio.sleep(0.01)
                # Track attempts for this value
                if value not in self.attempt_count:
                    self.attempt_count[value] = 0
                self.attempt_count[value] += 1
                # Fail first attempts, succeed on 2nd
                if self.attempt_count[value] < 2:
                    raise ValueError(f"Retry attempt {self.attempt_count[value]}")
                return f"success_{value}"

        async_func = udf(
            RetryAsyncFunction(),
            input_types=[DataTypes.STRING()],
            result_type=DataTypes.STRING()
        )
        sink_table = generate_random_table_name()
        self.t_env.execute_sql(f"""
            CREATE TABLE {sink_table}(a STRING, b STRING)
            WITH ('connector'='test-sink')
        """)
        t = self.t_env.from_elements([("test",)], ['a'])
        t.select(t.a, async_func(t.a).alias('b')).execute_insert(sink_table).wait()
        actual = source_sink_utils.results()
        # If retry logic works, we should get success result
        self.assert_equals(actual, [
            "+I[test, success_test]"
        ])

    def test_chaining_async_scalar_functions(self):
        """Test chaining multiple async scalar functions."""

        @udf(result_type=DataTypes.STRING())
        async def async_add_prefix(value: str) -> str:
            await asyncio.sleep(0.001)
            return f"prefix_{value}"

        @udf(result_type=DataTypes.STRING())
        async def async_add_suffix(value: str) -> str:
            await asyncio.sleep(0.001)
            return f"{value}_suffix"

        sink_table = generate_random_table_name()
        self.t_env.execute_sql(f"""
            CREATE TABLE {sink_table}(a STRING, b STRING)
            WITH ('connector'='test-sink')
        """)
        t = self.t_env.from_elements([("test",)], ['a'])
        # Chain async functions
        t.select(t.a, async_add_suffix(async_add_prefix(t.a)).alias('b')) \
            .execute_insert(sink_table).wait()
        actual = source_sink_utils.results()
        self.assert_equals(actual, ["+I[test, prefix_test_suffix]"])

    def test_chaining_async_and_sync_functions(self):
        """Test chaining async scalar functions with regular Python UDFs."""
        # NOTE: `udf` is already imported at module level; no local import needed.

        # Define a regular synchronous UDF
        @udf(result_type=DataTypes.STRING())
        def sync_upper(value: str) -> str:
            return value.upper()

        # Define async UDFs
        @udf(result_type=DataTypes.STRING())
        async def async_add_prefix(value: str) -> str:
            await asyncio.sleep(0.001)
            return f"prefix_{value}"

        @udf(result_type=DataTypes.STRING())
        async def async_add_suffix(value: str) -> str:
            await asyncio.sleep(0.001)
            return f"{value}_suffix"

        sink_table = generate_random_table_name()
        self.t_env.execute_sql(f"""
            CREATE TABLE {sink_table}(a STRING, b STRING, c STRING)
            WITH ('connector'='test-sink')
        """)
        t = self.t_env.from_elements([("test", "test2")], ['a', 'b'])
        # Test various chaining patterns:
        # 1. sync -> async
        # 2. async -> sync -> async
        t.select(
            t.a,
            async_add_prefix(sync_upper(t.a)).alias('b'),  # sync -> async
            async_add_suffix(sync_upper(async_add_prefix(t.a))).alias('c')  # async -> sync -> async
        ).execute_insert(sink_table).wait()
        actual = source_sink_utils.results()
        self.assert_equals(actual, [
            "+I[test, prefix_TEST, PREFIX_TEST_suffix]"
        ])

    def test_async_udf_with_pandas_raises_error(self):
        """Test that using pandas func_type with async function raises an error."""
        # NOTE: `udf` is already imported at module level; no local import needed.

        # Test 1: async def function with pandas should raise error
        with self.assertRaises(ValueError) as context:
            @udf(result_type=DataTypes.STRING(), func_type='pandas')
            async def async_func(value: str) -> str:
                await asyncio.sleep(0.001)
                return f"async_{value}"
        self.assertIn("Async scalar functions do not support pandas func_type",
                      str(context.exception))

        # Test 2: AsyncScalarFunction with pandas should raise error
        class MyAsyncFunc(AsyncScalarFunction):
            async def eval(self, value: str) -> str:
                await asyncio.sleep(0.001)
                return f"class_{value}"

        with self.assertRaises(ValueError) as context:
            udf(MyAsyncFunc(), result_type=DataTypes.STRING(), func_type='pandas')
        self.assertIn("Async scalar functions do not support pandas func_type",
                      str(context.exception))

    @run_with_config({"table.exec.async-scalar.timeout": "5s"})
    def test_processing_timeout(self):
        """Test async scalar function processing timeout."""

        class TimeoutAsyncFunction(AsyncScalarFunction):
            async def eval(self, value: str) -> str:
                # Sleep longer than timeout
                await asyncio.sleep(10)
                return f"completed_{value}"

        # Note: For Table API async scalar functions, timeout behavior is handled
        # by the framework configuration. This test verifies the function can be
        # registered and used, but actual timeout handling is framework-level.
        async_func = udf(
            TimeoutAsyncFunction(),
            input_types=[DataTypes.STRING()],
            result_type=DataTypes.STRING()
        )
        sink_table = generate_random_table_name()
        self.t_env.execute_sql(f"""
            CREATE TABLE {sink_table}(a STRING, b STRING)
            WITH ('connector'='test-sink')
        """)
        t = self.t_env.from_elements([("test",)], ['a'])
        # With short processing time limit, this should either timeout or complete
        # based on framework timeout configuration
        try:
            t.select(t.a, async_func(t.a).alias('b')).execute_insert(sink_table).wait()
            # If execution completes, verify results exist
            actual = source_sink_utils.results()
            # Results could be timeout or actual completion depending on timing
            self.assert_equals(actual, [
                "+I[test, completed_test]"
            ])
        except Exception as e:
            # Timeout exceptions are acceptable for this test
            self.assertIn("Async function call has timed out for input: ['test']",
                          str(e))
if __name__ == '__main__':
    # Prefer the XML-producing runner when available so CI can collect reports;
    # otherwise fall back to the default text runner.
    import unittest

    runner = None
    try:
        import xmlrunner
    except ImportError:
        pass
    else:
        runner = xmlrunner.XMLTestRunner(output='target/test-reports')
    unittest.main(testRunner=runner, verbosity=2)