| ################################################################################ |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| ################################################################################ |
| import unittest |
| |
| from pyflink.common import Row |
| from pyflink.table import DataTypes |
| from pyflink.table.udf import TableFunction, udtf, ScalarFunction, udf |
| from pyflink.table.expressions import col |
| from pyflink.testing import source_sink_utils |
| from pyflink.testing.test_case_utils import PyFlinkStreamTableTestCase, \ |
| PyFlinkBatchTableTestCase |
| |
| |
class UserDefinedTableFunctionTests(object):
    """Table-function test cases shared by the concrete test classes in this file.

    The methods use ``self.t_env`` and ``self.assert_equals``, neither of which
    is defined here; they have to be supplied by the test-case base class this
    mixin is combined with.
    """

    def test_table_function(self):
        """Chain three lateral joins and compare the rows collected by the sink."""
        sink_ddl = """
            CREATE TABLE Results_test_table_function(
                a BIGINT,
                b BIGINT,
                c BIGINT
            ) WITH ('connector'='test-sink')"""
        self.t_env.execute_sql(sink_ddl)

        multi_emit = udtf(MultiEmit(), result_types=[DataTypes.BIGINT(), DataTypes.BIGINT()])
        multi_num = udf(MultiNum(), result_type=DataTypes.BIGINT())

        source = self.t_env.from_elements([(1, 1, 3), (2, 1, 6), (3, 2, 9)], ['a', 'b', 'c'])

        # Inner lateral join: the table function receives an arithmetic
        # expression and the output of a scalar Python function.
        emitted = multi_emit((source.a + source.a) / 2, multi_num(source.b)).alias('x', 'y')
        with_xy = source.join_lateral(emitted)

        # Left outer lateral join with the udtf that declares input_types.
        with_m = with_xy.left_outer_join_lateral(
            condition_multi_emit(with_xy.x, with_xy.y).alias('m'))
        with_m = with_m.select(with_xy.x, with_xy.y, col("m"))

        # Left outer lateral join with the decorator-defined udtf.
        with_n = with_m.left_outer_join_lateral(identity(with_m.m).alias('n'))
        with_n = with_n.select(with_m.x, with_m.y, col("n"))

        with_n.execute_insert("Results_test_table_function").wait()
        actual = source_sink_utils.results()
        expected = [
            "+I[1, 0, null]", "+I[1, 1, null]", "+I[2, 0, null]", "+I[2, 1, null]",
            "+I[3, 0, 0]", "+I[3, 0, 1]", "+I[3, 0, 2]", "+I[3, 1, 1]",
            "+I[3, 1, 2]", "+I[3, 2, 2]", "+I[3, 3, null]",
        ]
        self.assert_equals(actual, expected)

    def test_table_function_with_sql_query(self):
        """Register the table function under a name and call it from a SQL query."""
        sink_ddl = """
            CREATE TABLE Results_test_table_function_with_sql_query(
                a BIGINT,
                b BIGINT,
                c BIGINT
            ) WITH ('connector'='test-sink')"""
        self.t_env.execute_sql(sink_ddl)

        multi_emit = udtf(MultiEmit(), result_types=[DataTypes.BIGINT(), DataTypes.BIGINT()])
        self.t_env.create_temporary_system_function("multi_emit", multi_emit)

        source = self.t_env.from_elements([(1, 1, 3), (2, 1, 6), (3, 2, 9)], ['a', 'b', 'c'])
        self.t_env.create_temporary_view("MyTable", source)

        query = ("SELECT a, x, y FROM MyTable LEFT JOIN LATERAL TABLE(multi_emit(a, b)) as T(x, y)"
                 " ON TRUE")
        result = self.t_env.sql_query(query)
        result.execute_insert("Results_test_table_function_with_sql_query").wait()

        actual = source_sink_utils.results()
        expected = ["+I[1, 1, 0]", "+I[2, 2, 0]", "+I[3, 3, 0]", "+I[3, 3, 1]"]
        self.assert_equals(actual, expected)

    def test_table_function_in_sql(self):
        """Create the function with a SQL statement that names it by module path."""
        module_name = UserDefinedTableFunctionTests.__module__
        create_function = f"""
            CREATE TEMPORARY FUNCTION pyfunc AS
            '{module_name}.identity'
            LANGUAGE PYTHON
        """
        self.t_env.execute_sql(create_function)

        table_result = self.t_env.execute_sql(
            "SELECT v FROM (VALUES (1)) AS T(id), LATERAL TABLE(pyfunc(id)) AS P(v)")
        rows = list(table_result.collect())
        self.assert_equals(rows, [Row(1)])
| |
| |
class PyFlinkStreamUserDefinedFunctionTests(UserDefinedTableFunctionTests,
                                            PyFlinkStreamTableTestCase):
    """Runs the shared table-function tests on ``PyFlinkStreamTableTestCase``.

    Adds one extra test that compiles a SQL statement into a plan and executes it.
    """

    def test_execute_from_json_plan(self):
        """Compile an INSERT that uses a lateral table function and execute the plan.

        Input rows are written to a CSV file under ``self.tempdir`` and read
        through a filesystem source table. They are expanded by ``multi_emit2``
        and written to a filesystem sink table. Every file found under the sink
        path is then read back and compared with the expected rows.
        """
        # create source file path
        tmp_dir = self.tempdir
        data = ['1,1', '3,2', '2,1']
        source_path = tmp_dir + '/test_execute_from_json_plan_input.csv'
        sink_path = tmp_dir + '/test_execute_from_json_plan_out'
        with open(source_path, 'w') as fd:
            for ele in data:
                fd.write(ele + '\n')

        source_table = """
            CREATE TABLE source_table (
                a BIGINT,
                b BIGINT
            ) WITH (
                'connector' = 'filesystem',
                'path' = '%s',
                'format' = 'csv'
            )
        """ % source_path
        self.t_env.execute_sql(source_table)

        self.t_env.execute_sql("""
            CREATE TABLE sink_table (
                a BIGINT,
                b BIGINT,
                c BIGINT
            ) WITH (
                'connector' = 'filesystem',
                'path' = '%s',
                'format' = 'csv'
            )
        """ % sink_path)

        self.t_env.create_temporary_system_function(
            "multi_emit2", udtf(MultiEmit(), result_types=[DataTypes.BIGINT(), DataTypes.BIGINT()]))

        json_plan = self.t_env._j_tenv.compilePlanSql("INSERT INTO sink_table "
                                                      "SELECT a, x, y FROM source_table "
                                                      "LEFT JOIN LATERAL TABLE(multi_emit2(a, b))"
                                                      " as T(x, y)"
                                                      " ON TRUE")
        # ``await`` is a reserved word in Python, so the Java method of that
        # name cannot be called with attribute syntax; look it up by name.
        from py4j.java_gateway import get_method
        get_method(json_plan.execute(), "await")()

        import glob
        lines = []
        for result_file in glob.glob(sink_path + '/*'):
            # Open each output file in a context manager so its handle is
            # closed right away instead of being left to the garbage collector.
            with open(result_file, 'r') as result_fd:
                lines.extend(line.strip() for line in result_fd)
        # The test does not depend on the order in which files or rows come
        # back, so sort before comparing.
        lines.sort()
        self.assertEqual(lines, ['1,1,0', '2,2,0', '3,3,0', '3,3,1'])
| |
| |
class PyFlinkBatchUserDefinedFunctionTests(UserDefinedTableFunctionTests,
                                           PyFlinkBatchTableTestCase):
    """Runs the shared table-function tests on ``PyFlinkBatchTableTestCase``.

    No tests are added or overridden here.
    """
| |
| |
class PyFlinkEmbeddedThreadTests(UserDefinedTableFunctionTests, PyFlinkStreamTableTestCase):
    """Runs the shared table-function tests with the Python execution mode set to thread."""

    def setUp(self):
        """Run the inherited set-up, then set ``python.execution-mode`` to ``thread``."""
        super().setUp()
        table_config = self.t_env.get_config()
        table_config.set("python.execution-mode", "thread")
| |
| |
class MultiEmit(TableFunction, unittest.TestCase):
    """Table function that emits ``(x, i)`` for every ``i`` in ``range(y)``.

    It also keeps a running total of all ``y`` values seen in ``counter_sum``.
    """
    # NOTE(review): nothing visible in this file uses the unittest.TestCase
    # base; confirm whether it is still needed before touching it.

    def open(self, function_context):
        """Start the running total at zero; ``function_context`` is not used."""
        self.counter_sum = 0

    def eval(self, x, y):
        """Add ``y`` to the running total, then yield ``(x, 0)`` .. ``(x, y - 1)``."""
        self.counter_sum = self.counter_sum + y
        yield from ((x, index) for index in range(y))
| |
| |
@udtf(result_types=['bigint'])
def identity(x):
    """Return ``Row(x)`` for a non-None ``x``; return None when ``x`` is None."""
    if x is None:
        return None
    # NOTE(review): Row is imported locally although the module also imports
    # it; presumably this keeps the function self-contained. TODO confirm.
    from pyflink.common import Row
    return Row(x)
| |
| |
# Exercises a udtf that is declared with explicit input_types; result_types
# is given as a single data type instead of a list.
@udtf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()],
      result_types=DataTypes.BIGINT())
def condition_multi_emit(x, y):
    """Return ``range(y, x)`` when ``x`` equals 3; otherwise return None."""
    return range(y, x) if x == 3 else None
| |
| |
class MultiNum(ScalarFunction):
    """Scalar function that multiplies its argument by two."""

    def eval(self, x):
        """Return ``x * 2``."""
        doubled = x * 2
        return doubled
| |
| |
if __name__ == '__main__':
    import unittest

    # Default to unittest's own runner; use the XML runner only when the
    # optional xmlrunner package can be imported.
    testRunner = None
    try:
        import xmlrunner

        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
    except ImportError:
        # xmlrunner is not available: keep the default runner.
        pass
    unittest.main(testRunner=testRunner, verbosity=2)