blob: bf3ca4fb5c56cb33d0d86c270093d908828c9a1c [file] [log] [blame]
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import unittest
from pyflink.common import Row
from pyflink.table import DataTypes
from pyflink.table.udf import TableFunction, udtf, ScalarFunction, udf
from pyflink.table.expressions import col
from pyflink.testing import source_sink_utils
from pyflink.testing.test_case_utils import PyFlinkStreamTableTestCase, \
PyFlinkBatchTableTestCase
class UserDefinedTableFunctionTests(object):
    """Shared test cases for Python user-defined table functions (UDTFs).

    Mixed into concrete ``TestCase`` subclasses below, which provide
    ``self.t_env`` (streaming, batch, or embedded-thread table environment).
    """

    def test_table_function(self):
        """Chains (left outer) lateral joins against Python UDTFs via the Table API."""
        self.t_env.execute_sql("""
            CREATE TABLE Results_test_table_function(
                a BIGINT,
                b BIGINT,
                c BIGINT
            ) WITH ('connector'='test-sink')""")
        emit_fn = udtf(MultiEmit(), result_types=[DataTypes.BIGINT(), DataTypes.BIGINT()])
        double_fn = udf(MultiNum(), result_type=DataTypes.BIGINT())
        table = self.t_env.from_elements([(1, 1, 3), (2, 1, 6), (3, 2, 9)], ['a', 'b', 'c'])
        # Cross-join each row with the (x, i) pairs produced by the UDTF.
        table = table.join_lateral(
            emit_fn((table.a + table.a) / 2, double_fn(table.b)).alias('x', 'y'))
        # Left outer lateral joins keep rows for which the UDTF emits nothing.
        table = table.left_outer_join_lateral(
            condition_multi_emit(table.x, table.y).alias('m')) \
            .select(table.x, table.y, col("m"))
        table = table.left_outer_join_lateral(identity(table.m).alias('n')) \
            .select(table.x, table.y, col("n"))
        table.execute_insert("Results_test_table_function").wait()
        actual = source_sink_utils.results()
        self.assert_equals(actual,
                           ["+I[1, 0, null]", "+I[1, 1, null]", "+I[2, 0, null]", "+I[2, 1, null]",
                            "+I[3, 0, 0]", "+I[3, 0, 1]", "+I[3, 0, 2]", "+I[3, 1, 1]",
                            "+I[3, 1, 2]", "+I[3, 2, 2]", "+I[3, 3, null]"])

    def test_table_function_with_sql_query(self):
        """A UDTF registered as a system function is usable from a SQL lateral join."""
        self.t_env.execute_sql("""
            CREATE TABLE Results_test_table_function_with_sql_query(
                a BIGINT,
                b BIGINT,
                c BIGINT
            ) WITH ('connector'='test-sink')""")
        self.t_env.create_temporary_system_function(
            "multi_emit", udtf(MultiEmit(), result_types=[DataTypes.BIGINT(), DataTypes.BIGINT()]))
        source = self.t_env.from_elements([(1, 1, 3), (2, 1, 6), (3, 2, 9)], ['a', 'b', 'c'])
        self.t_env.create_temporary_view("MyTable", source)
        result = self.t_env.sql_query(
            "SELECT a, x, y FROM MyTable LEFT JOIN LATERAL TABLE(multi_emit(a, b)) as T(x, y)"
            " ON TRUE")
        result.execute_insert("Results_test_table_function_with_sql_query").wait()
        actual = source_sink_utils.results()
        self.assert_equals(actual, ["+I[1, 1, 0]", "+I[2, 2, 0]", "+I[3, 3, 0]", "+I[3, 3, 1]"])

    def test_table_function_in_sql(self):
        """A Python UDTF declared via CREATE TEMPORARY FUNCTION ... LANGUAGE PYTHON."""
        # Reference `identity` by its fully-qualified module path so the SQL
        # runtime can import it on the worker side.
        ddl = f"""
            CREATE TEMPORARY FUNCTION pyfunc AS
            '{UserDefinedTableFunctionTests.__module__}.identity'
            LANGUAGE PYTHON
            """
        self.t_env.execute_sql(ddl)
        self.assert_equals(
            list(self.t_env.execute_sql(
                "SELECT v FROM (VALUES (1)) AS T(id), LATERAL TABLE(pyfunc(id)) AS P(v)"
            ).collect()),
            [Row(1)])
class PyFlinkStreamUserDefinedFunctionTests(UserDefinedTableFunctionTests,
                                            PyFlinkStreamTableTestCase):
    """Runs the shared UDTF tests in streaming mode, plus plan-compilation cases."""

    def test_execute_from_json_plan(self):
        """A UDTF referenced from a compiled plan (compilePlanSql) executes correctly."""
        # Write a small CSV source file into the per-test temp directory.
        tmp_dir = self.tempdir
        data = ['1,1', '3,2', '2,1']
        source_path = tmp_dir + '/test_execute_from_json_plan_input.csv'
        sink_path = tmp_dir + '/test_execute_from_json_plan_out'
        with open(source_path, 'w') as fd:
            for ele in data:
                fd.write(ele + '\n')
        source_table = """
            CREATE TABLE source_table (
                a BIGINT,
                b BIGINT
            ) WITH (
                'connector' = 'filesystem',
                'path' = '%s',
                'format' = 'csv'
            )
            """ % source_path
        self.t_env.execute_sql(source_table)
        self.t_env.execute_sql("""
            CREATE TABLE sink_table (
                a BIGINT,
                b BIGINT,
                c BIGINT
            ) WITH (
                'connector' = 'filesystem',
                'path' = '%s',
                'format' = 'csv'
            )
            """ % sink_path)
        self.t_env.create_temporary_system_function(
            "multi_emit2", udtf(MultiEmit(), result_types=[DataTypes.BIGINT(), DataTypes.BIGINT()]))
        # Compile the INSERT statement to a plan, then execute the plan itself.
        json_plan = self.t_env._j_tenv.compilePlanSql("INSERT INTO sink_table "
                                                      "SELECT a, x, y FROM source_table "
                                                      "LEFT JOIN LATERAL TABLE(multi_emit2(a, b))"
                                                      " as T(x, y)"
                                                      " ON TRUE")
        from py4j.java_gateway import get_method
        # `await` is a Python keyword, so the Java method must be fetched by name.
        get_method(json_plan.execute(), "await")()
        import glob
        # Collect sink output; use `with` so each part file is closed promptly
        # (the original comprehension leaked every file handle it opened).
        lines = []
        for part_file in glob.glob(sink_path + '/*'):
            with open(part_file, 'r') as f:
                lines.extend(line.strip() for line in f)
        lines.sort()
        self.assertEqual(lines, ['1,1,0', '2,2,0', '3,3,0', '3,3,1'])
class PyFlinkBatchUserDefinedFunctionTests(UserDefinedTableFunctionTests,
                                           PyFlinkBatchTableTestCase):
    """Runs the shared UDTF test cases against a batch table environment."""
    pass
class PyFlinkEmbeddedThreadTests(UserDefinedTableFunctionTests, PyFlinkStreamTableTestCase):
    """Runs the shared UDTF test cases with the embedded 'thread' execution mode."""

    def setUp(self):
        # Build the stream environment first, then switch the Python runtime
        # from the default process mode to embedded thread mode.
        super().setUp()
        self.t_env.get_config().set("python.execution-mode", "thread")
class MultiEmit(TableFunction, unittest.TestCase):
    """Table function emitting ``(x, 0), (x, 1), ..., (x, y - 1)`` per input row.

    NOTE(review): also inherits ``unittest.TestCase`` — presumably to exercise
    serialization of UDFs with non-trivial base classes; confirm before removing.
    """

    def open(self, function_context):
        # Running total of every `y` this instance has seen.
        self.counter_sum = 0

    def eval(self, x, y):
        self.counter_sum += y
        yield from ((x, n) for n in range(y))
@udtf(result_types=['bigint'])
def identity(x):
    """Emit the input wrapped in a one-column Row; emit nothing for None."""
    if x is None:
        return
    # Local import — presumably so the serialized function carries its own
    # Row dependency to the worker; confirm before hoisting to module level.
    from pyflink.common import Row
    return Row(x)
# test specify the input_types
@udtf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()],
      result_types=DataTypes.BIGINT())
def condition_multi_emit(x, y):
    """Emit the values y, y+1, ..., x-1, but only for rows where x == 3."""
    if x != 3:
        return
    return range(y, x)
class MultiNum(ScalarFunction):
    """Scalar function that doubles its input value."""

    def eval(self, x):
        # Used by the lateral-join tests to derive the UDTF's count argument.
        return x * 2
if __name__ == '__main__':
    # `unittest` is already imported at module level; the redundant re-import
    # that used to live here has been dropped.
    # Prefer XML test reports (consumed by CI) when xmlrunner is available,
    # falling back to the default text runner otherwise.
    try:
        import xmlrunner
        test_runner = xmlrunner.XMLTestRunner(output='target/test-reports')
    except ImportError:
        test_runner = None
    unittest.main(testRunner=test_runner, verbosity=2)