# blob: 3d80afebfacf16cfdefa395b648a5974456084a9 [file] [log] [blame]
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import datetime
import os
import unittest
import pytz
from pyflink.table import DataTypes, expressions as expr
from pyflink.table.udf import ScalarFunction, udf
from pyflink.testing import source_sink_utils
from pyflink.testing.test_case_utils import PyFlinkStreamTableTestCase, \
PyFlinkBlinkStreamTableTestCase, PyFlinkBlinkBatchTableTestCase, \
PyFlinkBatchTableTestCase
class UserDefinedFunctionTests(object):
    """Scalar Python UDF test cases shared by several planner/mode combinations.

    This is a plain mixin: it does not derive from a test case itself. The concrete
    class it is combined with has to provide ``self.t_env`` and ``self.assert_equals``.
    Explanations on the test methods are written as comments rather than docstrings
    so that the verbose unittest output keeps showing the method names.
    """

    def test_scalar_function(self):
        # Uses every way of declaring a scalar UDF that this module knows about (lambda,
        # ScalarFunction instance, callable object, functools.partial, decorator) in a
        # single query.

        # test metric disabled.
        self.t_env.get_config().get_configuration().set_string('python.metric.enabled', 'false')

        # test lambda function
        add_one = udf(lambda i: i + 1, result_type=DataTypes.BIGINT())

        # test Python ScalarFunction
        subtract_one = udf(SubtractOne(), result_type=DataTypes.BIGINT())

        # test callable function
        add_one_callable = udf(CallablePlus(), result_type=DataTypes.BIGINT())

        def partial_func(col, param):
            return col + param

        # test partial function
        import functools
        add_one_partial = udf(functools.partial(partial_func, param=1),
                              result_type=DataTypes.BIGINT())

        # check memory limit is set
        @udf(result_type=DataTypes.BIGINT())
        def check_memory_limit():
            # A membership test states the intent directly: subscripting os.environ
            # raises KeyError for a missing key and never yields None.
            assert '_PYTHON_WORKER_MEMORY_LIMIT' in os.environ, \
                '_PYTHON_WORKER_MEMORY_LIMIT is not set'
            return 1

        table_sink = source_sink_utils.TestAppendSink(
            ['a', 'b', 'c', 'd', 'e', 'f', 'g'],
            [DataTypes.BIGINT(), DataTypes.BIGINT(), DataTypes.BIGINT(), DataTypes.BIGINT(),
             DataTypes.BIGINT(), DataTypes.BIGINT(), DataTypes.BIGINT()])
        self.t_env.register_table_sink("Results", table_sink)

        t = self.t_env.from_elements([(1, 2, 3), (2, 5, 6), (3, 1, 9)], ['a', 'b', 'c'])
        # Only the rows with b + 1 <= 3 pass the filter: (1, 2, 3) and (3, 1, 9).
        t.where(add_one(t.b) <= 3).select(
            add_one(t.a), subtract_one(t.b), add(t.a, t.c), add_one_callable(t.a),
            add_one_partial(t.a), check_memory_limit(), t.a) \
            .execute_insert("Results").wait()
        actual = source_sink_utils.results()
        self.assert_equals(actual, ["2,1,4,2,2,1,1", "4,0,12,4,4,1,3"])

    def test_chaining_scalar_function(self):
        # Passes the results of two Python UDFs as the arguments of a third one.
        add_one = udf(lambda i: i + 1, result_type=DataTypes.BIGINT())
        subtract_one = udf(SubtractOne(), result_type=DataTypes.BIGINT())

        table_sink = source_sink_utils.TestAppendSink(
            ['a', 'b', 'c'],
            [DataTypes.BIGINT(), DataTypes.BIGINT(), DataTypes.INT()])
        self.t_env.register_table_sink("Results", table_sink)

        t = self.t_env.from_elements([(1, 2, 1), (2, 5, 2), (3, 1, 3)], ['a', 'b', 'c'])
        t.select(add(add_one(t.a), subtract_one(t.b)), t.c, expr.lit(1)) \
            .execute_insert("Results").wait()
        actual = source_sink_utils.results()
        self.assert_equals(actual, ["3,1,1", "7,2,1", "4,3,1"])

    def test_udf_in_join_condition(self):
        # The UDF wraps only the left-hand column of the join predicate.
        t1 = self.t_env.from_elements([(2, "Hi")], ['a', 'b'])
        t2 = self.t_env.from_elements([(2, "Flink")], ['c', 'd'])

        f = udf(lambda i: i, result_type=DataTypes.BIGINT())

        table_sink = source_sink_utils.TestAppendSink(
            ['a', 'b', 'c', 'd'],
            [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.BIGINT(), DataTypes.STRING()])
        self.t_env.register_table_sink("Results", table_sink)

        t1.join(t2).where(f(t1.a) == t2.c).execute_insert("Results").wait()
        actual = source_sink_utils.results()
        self.assert_equals(actual, ["2,Hi,2,Flink"])

    def test_udf_in_join_condition_2(self):
        # The UDF wraps the columns on both sides of the join predicate.
        t1 = self.t_env.from_elements([(1, "Hi"), (2, "Hi")], ['a', 'b'])
        t2 = self.t_env.from_elements([(2, "Flink")], ['c', 'd'])

        f = udf(lambda i: i, result_type=DataTypes.BIGINT())

        table_sink = source_sink_utils.TestAppendSink(
            ['a', 'b', 'c', 'd'],
            [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.BIGINT(), DataTypes.STRING()])
        self.t_env.register_table_sink("Results", table_sink)

        t1.join(t2).where(f(t1.a) == f(t2.c)).execute_insert("Results").wait()
        actual = source_sink_utils.results()
        self.assert_equals(actual, ["2,Hi,2,Flink"])

    def test_udf_with_constant_params(self):
        # Calls UDFs from SQL with literal arguments of many types. The UDF checks each
        # literal it receives and adds the integer part of the numeric ones to ``p``.
        def udf_with_constant_params(p, null_param, tinyint_param, smallint_param, int_param,
                                     bigint_param, decimal_param, float_param, double_param,
                                     boolean_param, str_param,
                                     date_param, time_param, timestamp_param):

            from decimal import Decimal
            import datetime

            assert null_param is None, 'null_param is wrong value %s' % null_param

            assert isinstance(tinyint_param, int), 'tinyint_param of wrong type %s !' \
                                                   % type(tinyint_param)
            p += tinyint_param
            assert isinstance(smallint_param, int), 'smallint_param of wrong type %s !' \
                                                    % type(smallint_param)
            p += smallint_param
            assert isinstance(int_param, int), 'int_param of wrong type %s !' \
                                               % type(int_param)
            p += int_param
            assert isinstance(bigint_param, int), 'bigint_param of wrong type %s !' \
                                                  % type(bigint_param)
            p += bigint_param
            assert decimal_param == Decimal('1.05'), \
                'decimal_param is wrong value %s ' % decimal_param

            p += int(decimal_param)

            assert isinstance(float_param, float) and float_equal(float_param, 1.23, 1e-06), \
                'float_param is wrong value %s ' % float_param

            p += int(float_param)
            assert isinstance(double_param, float) and float_equal(double_param, 1.98932, 1e-07), \
                'double_param is wrong value %s ' % double_param

            p += int(double_param)

            assert boolean_param is True, 'boolean_param is wrong value %s' % boolean_param

            assert str_param == 'flink', 'str_param is wrong value %s' % str_param

            assert date_param == datetime.date(year=2014, month=9, day=13), \
                'date_param is wrong value %s' % date_param

            assert time_param == datetime.time(hour=12, minute=0, second=0), \
                'time_param is wrong value %s' % time_param

            assert timestamp_param == datetime.datetime(1999, 9, 10, 5, 20, 10), \
                'timestamp_param is wrong value %s' % timestamp_param

            return p

        self.t_env.create_temporary_system_function("udf_with_constant_params",
                                                    udf(udf_with_constant_params,
                                                        result_type=DataTypes.BIGINT()))

        self.t_env.create_temporary_system_function(
            "udf_with_all_constant_params", udf(lambda i, j: i + j,
                                                result_type=DataTypes.BIGINT()))

        table_sink = source_sink_utils.TestAppendSink(['a', 'b'],
                                                      [DataTypes.BIGINT(), DataTypes.BIGINT()])
        self.t_env.register_table_sink("Results", table_sink)

        t = self.t_env.from_elements([(1, 2, 3), (2, 5, 6), (3, 1, 9)], ['a', 'b', 'c'])
        self.t_env.register_table("test_table", t)
        # Submitted with execute_insert(...).wait() like every other test in this class,
        # instead of the separate insert_into(...) / t_env.execute(...) pair.
        self.t_env.sql_query("select udf_with_all_constant_params("
                             "cast (1 as BIGINT),"
                             "cast (2 as BIGINT)), "
                             "udf_with_constant_params(a, "
                             "cast (null as BIGINT),"
                             "cast (1 as TINYINT),"
                             "cast (1 as SMALLINT),"
                             "cast (1 as INT),"
                             "cast (1 as BIGINT),"
                             "cast (1.05 as DECIMAL),"
                             "cast (1.23 as FLOAT),"
                             "cast (1.98932 as DOUBLE),"
                             "true,"
                             "'flink',"
                             "cast ('2014-09-13' as DATE),"
                             "cast ('12:00:00' as TIME),"
                             "cast ('1999-9-10 05:20:10' as TIMESTAMP))"
                             " from test_table").execute_insert("Results").wait()
        actual = source_sink_utils.results()
        self.assert_equals(actual, ["3,8", "3,9", "3,10"])

    def test_overwrite_builtin_function(self):
        # Registers a temporary system function under the name "plus"; the expected
        # values are a + b - 1, i.e. the results of the registered lambda.
        self.t_env.create_temporary_system_function(
            "plus", udf(lambda i, j: i + j - 1,
                        result_type=DataTypes.BIGINT()))

        table_sink = source_sink_utils.TestAppendSink(['a'], [DataTypes.BIGINT()])
        self.t_env.register_table_sink("Results", table_sink)

        t = self.t_env.from_elements([(1, 2, 3), (2, 5, 6), (3, 1, 9)], ['a', 'b', 'c'])
        t.select("plus(a, b)").execute_insert("Results").wait()
        actual = source_sink_utils.results()
        self.assert_equals(actual, ["2", "6", "3"])

    def test_open(self):
        # Runs ``Subtract`` with metrics enabled; its open() sets the subtracted value
        # and registers a counter.
        self.t_env.get_config().get_configuration().set_string('python.metric.enabled', 'true')
        subtract = udf(Subtract(), result_type=DataTypes.BIGINT())

        table_sink = source_sink_utils.TestAppendSink(
            ['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()])
        self.t_env.register_table_sink("Results", table_sink)

        t = self.t_env.from_elements([(1, 2), (2, 5), (3, 4)], ['a', 'b'])
        t.select(t.a, subtract(t.b)).execute_insert("Results").wait()
        actual = source_sink_utils.results()
        self.assert_equals(actual, ["1,1", "2,4", "3,3"])

    def test_udf_without_arguments(self):
        # Zero-argument UDFs, one declared deterministic and one non-deterministic.
        one = udf(lambda: 1, result_type=DataTypes.BIGINT(), deterministic=True)
        two = udf(lambda: 2, result_type=DataTypes.BIGINT(), deterministic=False)

        table_sink = source_sink_utils.TestAppendSink(['a', 'b'],
                                                      [DataTypes.BIGINT(), DataTypes.BIGINT()])
        self.t_env.register_table_sink("Results", table_sink)

        t = self.t_env.from_elements([(1, 2), (2, 5), (3, 1)], ['a', 'b'])
        t.select(one(), two()).execute_insert("Results").wait()
        actual = source_sink_utils.results()
        self.assert_equals(actual, ["1,2", "1,2", "1,2"])

    def test_all_data_types_expression(self):
        # One decorator-declared UDF per data type, invoked through column expressions.
        # Each UDF asserts on the value it receives and returns it (array_func returns
        # the first inner array).
        @udf(result_type=DataTypes.BOOLEAN())
        def boolean_func(bool_param):
            assert isinstance(bool_param, bool), 'bool_param of wrong type %s !' \
                                                 % type(bool_param)
            return bool_param

        @udf(result_type=DataTypes.TINYINT())
        def tinyint_func(tinyint_param):
            assert isinstance(tinyint_param, int), 'tinyint_param of wrong type %s !' \
                                                   % type(tinyint_param)
            return tinyint_param

        @udf(result_type=DataTypes.SMALLINT())
        def smallint_func(smallint_param):
            assert isinstance(smallint_param, int), 'smallint_param of wrong type %s !' \
                                                    % type(smallint_param)
            assert smallint_param == 32767, 'smallint_param of wrong value %s' % smallint_param
            return smallint_param

        @udf(result_type=DataTypes.INT())
        def int_func(int_param):
            assert isinstance(int_param, int), 'int_param of wrong type %s !' \
                                               % type(int_param)
            assert int_param == -2147483648, 'int_param of wrong value %s' % int_param
            return int_param

        @udf(result_type=DataTypes.BIGINT())
        def bigint_func(bigint_param):
            assert isinstance(bigint_param, int), 'bigint_param of wrong type %s !' \
                                                  % type(bigint_param)
            return bigint_param

        @udf(result_type=DataTypes.BIGINT())
        def bigint_func_none(bigint_param):
            assert bigint_param is None, 'bigint_param %s should be None!' % bigint_param
            return bigint_param

        @udf(result_type=DataTypes.FLOAT())
        def float_func(float_param):
            assert isinstance(float_param, float) and float_equal(float_param, 1.23, 1e-6), \
                'float_param is wrong value %s !' % float_param
            return float_param

        @udf(result_type=DataTypes.DOUBLE())
        def double_func(double_param):
            assert isinstance(double_param, float) and float_equal(double_param, 1.98932, 1e-7), \
                'double_param is wrong value %s !' % double_param
            return double_param

        @udf(result_type=DataTypes.BYTES())
        def bytes_func(bytes_param):
            assert bytes_param == b'flink', \
                'bytes_param is wrong value %s !' % bytes_param
            return bytes_param

        @udf(result_type=DataTypes.STRING())
        def str_func(str_param):
            assert str_param == 'pyflink', \
                'str_param is wrong value %s !' % str_param
            return str_param

        @udf(result_type=DataTypes.DATE())
        def date_func(date_param):
            from datetime import date
            assert date_param == date(year=2014, month=9, day=13), \
                'date_param is wrong value %s !' % date_param
            return date_param

        @udf(result_type=DataTypes.TIME())
        def time_func(time_param):
            from datetime import time
            assert time_param == time(hour=12, minute=0, second=0, microsecond=123000), \
                'time_param is wrong value %s !' % time_param
            return time_param

        @udf(result_type=DataTypes.TIMESTAMP(3))
        def timestamp_func(timestamp_param):
            from datetime import datetime
            assert timestamp_param == datetime(2018, 3, 11, 3, 0, 0, 123000), \
                'timestamp_param is wrong value %s !' % timestamp_param
            return timestamp_param

        @udf(result_type=DataTypes.ARRAY(DataTypes.BIGINT()))
        def array_func(array_param):
            assert array_param == [[1, 2, 3]], \
                'array_param is wrong value %s !' % array_param
            return array_param[0]

        @udf(result_type=DataTypes.MAP(DataTypes.BIGINT(), DataTypes.STRING()))
        def map_func(map_param):
            assert map_param == {1: 'flink', 2: 'pyflink'}, \
                'map_param is wrong value %s !' % map_param
            return map_param

        @udf(result_type=DataTypes.DECIMAL(38, 18))
        def decimal_func(decimal_param):
            from decimal import Decimal
            assert decimal_param == Decimal('1000000000000000000.050000000000000000'), \
                'decimal_param is wrong value %s !' % decimal_param
            return decimal_param

        @udf(result_type=DataTypes.DECIMAL(38, 18))
        def decimal_cut_func(decimal_param):
            from decimal import Decimal
            assert decimal_param == Decimal('1000000000000000000.059999999999999999'), \
                'decimal_param is wrong value %s !' % decimal_param
            return decimal_param

        table_sink = source_sink_utils.TestAppendSink(
            ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q'],
            [DataTypes.BIGINT(), DataTypes.BIGINT(), DataTypes.TINYINT(),
             DataTypes.BOOLEAN(), DataTypes.SMALLINT(), DataTypes.INT(),
             DataTypes.FLOAT(), DataTypes.DOUBLE(), DataTypes.BYTES(),
             DataTypes.STRING(), DataTypes.DATE(), DataTypes.TIME(),
             DataTypes.TIMESTAMP(3), DataTypes.ARRAY(DataTypes.BIGINT()),
             DataTypes.MAP(DataTypes.BIGINT(), DataTypes.STRING()),
             DataTypes.DECIMAL(38, 18), DataTypes.DECIMAL(38, 18)])
        self.t_env.register_table_sink("Results", table_sink)

        import datetime
        import decimal
        t = self.t_env.from_elements(
            [(1, None, 1, True, 32767, -2147483648, 1.23, 1.98932,
              bytearray(b'flink'), 'pyflink', datetime.date(2014, 9, 13),
              datetime.time(hour=12, minute=0, second=0, microsecond=123000),
              datetime.datetime(2018, 3, 11, 3, 0, 0, 123000), [[1, 2, 3]],
              {1: 'flink', 2: 'pyflink'}, decimal.Decimal('1000000000000000000.05'),
              decimal.Decimal('1000000000000000000.05999999999999999899999999999'))],
            DataTypes.ROW(
                [DataTypes.FIELD("a", DataTypes.BIGINT()),
                 DataTypes.FIELD("b", DataTypes.BIGINT()),
                 DataTypes.FIELD("c", DataTypes.TINYINT()),
                 DataTypes.FIELD("d", DataTypes.BOOLEAN()),
                 DataTypes.FIELD("e", DataTypes.SMALLINT()),
                 DataTypes.FIELD("f", DataTypes.INT()),
                 DataTypes.FIELD("g", DataTypes.FLOAT()),
                 DataTypes.FIELD("h", DataTypes.DOUBLE()),
                 DataTypes.FIELD("i", DataTypes.BYTES()),
                 DataTypes.FIELD("j", DataTypes.STRING()),
                 DataTypes.FIELD("k", DataTypes.DATE()),
                 DataTypes.FIELD("l", DataTypes.TIME()),
                 DataTypes.FIELD("m", DataTypes.TIMESTAMP(3)),
                 DataTypes.FIELD("n", DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.BIGINT()))),
                 DataTypes.FIELD("o", DataTypes.MAP(DataTypes.BIGINT(), DataTypes.STRING())),
                 DataTypes.FIELD("p", DataTypes.DECIMAL(38, 18)),
                 DataTypes.FIELD("q", DataTypes.DECIMAL(38, 18))]))

        t.select(
            bigint_func(t.a),
            bigint_func_none(t.b),
            tinyint_func(t.c),
            boolean_func(t.d),
            smallint_func(t.e),
            int_func(t.f),
            float_func(t.g),
            double_func(t.h),
            bytes_func(t.i),
            str_func(t.j),
            date_func(t.k),
            time_func(t.l),
            timestamp_func(t.m),
            array_func(t.n),
            map_func(t.o),
            decimal_func(t.p),
            decimal_cut_func(t.q)) \
            .execute_insert("Results").wait()
        actual = source_sink_utils.results()
        # Currently the sink result precision of DataTypes.TIME(precision) only supports 0.
        self.assert_equals(actual,
                           ["1,null,1,true,32767,-2147483648,1.23,1.98932,"
                            "[102, 108, 105, 110, 107],pyflink,2014-09-13,"
                            "12:00:00,2018-03-11 03:00:00.123,[1, 2, 3],"
                            "{1=flink, 2=pyflink},1000000000000000000.050000000000000000,"
                            "1000000000000000000.059999999999999999"])

    def test_all_data_types(self):
        # Same checks as test_all_data_types_expression, but the UDFs are registered by
        # name on the table environment and invoked through a string expression.
        def boolean_func(bool_param):
            assert isinstance(bool_param, bool), 'bool_param of wrong type %s !' \
                                                 % type(bool_param)
            return bool_param

        def tinyint_func(tinyint_param):
            assert isinstance(tinyint_param, int), 'tinyint_param of wrong type %s !' \
                                                   % type(tinyint_param)
            return tinyint_param

        def smallint_func(smallint_param):
            assert isinstance(smallint_param, int), 'smallint_param of wrong type %s !' \
                                                    % type(smallint_param)
            assert smallint_param == 32767, 'smallint_param of wrong value %s' % smallint_param
            return smallint_param

        def int_func(int_param):
            assert isinstance(int_param, int), 'int_param of wrong type %s !' \
                                               % type(int_param)
            assert int_param == -2147483648, 'int_param of wrong value %s' % int_param
            return int_param

        def bigint_func(bigint_param):
            assert isinstance(bigint_param, int), 'bigint_param of wrong type %s !' \
                                                  % type(bigint_param)
            return bigint_param

        def bigint_func_none(bigint_param):
            assert bigint_param is None, 'bigint_param %s should be None!' % bigint_param
            return bigint_param

        def float_func(float_param):
            assert isinstance(float_param, float) and float_equal(float_param, 1.23, 1e-6), \
                'float_param is wrong value %s !' % float_param
            return float_param

        def double_func(double_param):
            assert isinstance(double_param, float) and float_equal(double_param, 1.98932, 1e-7), \
                'double_param is wrong value %s !' % double_param
            return double_param

        def bytes_func(bytes_param):
            assert bytes_param == b'flink', \
                'bytes_param is wrong value %s !' % bytes_param
            return bytes_param

        def str_func(str_param):
            assert str_param == 'pyflink', \
                'str_param is wrong value %s !' % str_param
            return str_param

        def date_func(date_param):
            from datetime import date
            assert date_param == date(year=2014, month=9, day=13), \
                'date_param is wrong value %s !' % date_param
            return date_param

        def time_func(time_param):
            from datetime import time
            assert time_param == time(hour=12, minute=0, second=0, microsecond=123000), \
                'time_param is wrong value %s !' % time_param
            return time_param

        def timestamp_func(timestamp_param):
            from datetime import datetime
            assert timestamp_param == datetime(2018, 3, 11, 3, 0, 0, 123000), \
                'timestamp_param is wrong value %s !' % timestamp_param
            return timestamp_param

        def array_func(array_param):
            assert array_param == [[1, 2, 3]], \
                'array_param is wrong value %s !' % array_param
            return array_param[0]

        def map_func(map_param):
            assert map_param == {1: 'flink', 2: 'pyflink'}, \
                'map_param is wrong value %s !' % map_param
            return map_param

        def decimal_func(decimal_param):
            from decimal import Decimal
            assert decimal_param == Decimal('1000000000000000000.050000000000000000'), \
                'decimal_param is wrong value %s !' % decimal_param
            return decimal_param

        def decimal_cut_func(decimal_param):
            from decimal import Decimal
            assert decimal_param == Decimal('1000000000000000000.059999999999999999'), \
                'decimal_param is wrong value %s !' % decimal_param
            return decimal_param

        self.t_env.create_temporary_system_function(
            "boolean_func", udf(boolean_func, result_type=DataTypes.BOOLEAN()))

        self.t_env.create_temporary_system_function(
            "tinyint_func", udf(tinyint_func, result_type=DataTypes.TINYINT()))

        self.t_env.create_temporary_system_function(
            "smallint_func", udf(smallint_func, result_type=DataTypes.SMALLINT()))

        self.t_env.create_temporary_system_function(
            "int_func", udf(int_func, result_type=DataTypes.INT()))

        self.t_env.create_temporary_system_function(
            "bigint_func", udf(bigint_func, result_type=DataTypes.BIGINT()))

        self.t_env.create_temporary_system_function(
            "bigint_func_none", udf(bigint_func_none, result_type=DataTypes.BIGINT()))

        self.t_env.create_temporary_system_function(
            "float_func", udf(float_func, result_type=DataTypes.FLOAT()))

        self.t_env.create_temporary_system_function(
            "double_func", udf(double_func, result_type=DataTypes.DOUBLE()))

        self.t_env.create_temporary_system_function(
            "bytes_func", udf(bytes_func, result_type=DataTypes.BYTES()))

        self.t_env.create_temporary_system_function(
            "str_func", udf(str_func, result_type=DataTypes.STRING()))

        self.t_env.create_temporary_system_function(
            "date_func", udf(date_func, result_type=DataTypes.DATE()))

        self.t_env.create_temporary_system_function(
            "time_func", udf(time_func, result_type=DataTypes.TIME()))

        self.t_env.create_temporary_system_function(
            "timestamp_func", udf(timestamp_func, result_type=DataTypes.TIMESTAMP(3)))

        self.t_env.create_temporary_system_function(
            "array_func", udf(array_func, result_type=DataTypes.ARRAY(DataTypes.BIGINT())))

        self.t_env.create_temporary_system_function(
            "map_func", udf(map_func,
                            result_type=DataTypes.MAP(DataTypes.BIGINT(), DataTypes.STRING())))

        # NOTE(review): the two decimal UDFs go through register_function while all the
        # others use create_temporary_system_function; kept as-is, confirm whether this
        # difference is intentional.
        self.t_env.register_function(
            "decimal_func", udf(decimal_func, result_type=DataTypes.DECIMAL(38, 18)))

        self.t_env.register_function(
            "decimal_cut_func", udf(decimal_cut_func, result_type=DataTypes.DECIMAL(38, 18)))

        table_sink = source_sink_utils.TestAppendSink(
            ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q'],
            [DataTypes.BIGINT(), DataTypes.BIGINT(), DataTypes.TINYINT(),
             DataTypes.BOOLEAN(), DataTypes.SMALLINT(), DataTypes.INT(),
             DataTypes.FLOAT(), DataTypes.DOUBLE(), DataTypes.BYTES(),
             DataTypes.STRING(), DataTypes.DATE(), DataTypes.TIME(),
             DataTypes.TIMESTAMP(3), DataTypes.ARRAY(DataTypes.BIGINT()),
             DataTypes.MAP(DataTypes.BIGINT(), DataTypes.STRING()),
             DataTypes.DECIMAL(38, 18), DataTypes.DECIMAL(38, 18)])
        self.t_env.register_table_sink("Results", table_sink)

        import datetime
        import decimal
        t = self.t_env.from_elements(
            [(1, None, 1, True, 32767, -2147483648, 1.23, 1.98932,
              bytearray(b'flink'), 'pyflink', datetime.date(2014, 9, 13),
              datetime.time(hour=12, minute=0, second=0, microsecond=123000),
              datetime.datetime(2018, 3, 11, 3, 0, 0, 123000), [[1, 2, 3]],
              {1: 'flink', 2: 'pyflink'}, decimal.Decimal('1000000000000000000.05'),
              decimal.Decimal('1000000000000000000.05999999999999999899999999999'))],
            DataTypes.ROW(
                [DataTypes.FIELD("a", DataTypes.BIGINT()),
                 DataTypes.FIELD("b", DataTypes.BIGINT()),
                 DataTypes.FIELD("c", DataTypes.TINYINT()),
                 DataTypes.FIELD("d", DataTypes.BOOLEAN()),
                 DataTypes.FIELD("e", DataTypes.SMALLINT()),
                 DataTypes.FIELD("f", DataTypes.INT()),
                 DataTypes.FIELD("g", DataTypes.FLOAT()),
                 DataTypes.FIELD("h", DataTypes.DOUBLE()),
                 DataTypes.FIELD("i", DataTypes.BYTES()),
                 DataTypes.FIELD("j", DataTypes.STRING()),
                 DataTypes.FIELD("k", DataTypes.DATE()),
                 DataTypes.FIELD("l", DataTypes.TIME()),
                 DataTypes.FIELD("m", DataTypes.TIMESTAMP(3)),
                 DataTypes.FIELD("n", DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.BIGINT()))),
                 DataTypes.FIELD("o", DataTypes.MAP(DataTypes.BIGINT(), DataTypes.STRING())),
                 DataTypes.FIELD("p", DataTypes.DECIMAL(38, 18)),
                 DataTypes.FIELD("q", DataTypes.DECIMAL(38, 18))]))

        t.select("bigint_func(a), bigint_func_none(b),"
                 "tinyint_func(c), boolean_func(d),"
                 "smallint_func(e),int_func(f),"
                 "float_func(g),double_func(h),"
                 "bytes_func(i),str_func(j),"
                 "date_func(k),time_func(l),"
                 "timestamp_func(m),array_func(n),"
                 "map_func(o),decimal_func(p),"
                 "decimal_cut_func(q)") \
            .execute_insert("Results").wait()
        actual = source_sink_utils.results()
        # Currently the sink result precision of DataTypes.TIME(precision) only supports 0.
        self.assert_equals(actual,
                           ["1,null,1,true,32767,-2147483648,1.23,1.98932,"
                            "[102, 108, 105, 110, 107],pyflink,2014-09-13,"
                            "12:00:00,2018-03-11 03:00:00.123,[1, 2, 3],"
                            "{1=flink, 2=pyflink},1000000000000000000.050000000000000000,"
                            "1000000000000000000.059999999999999999"])

    def test_create_and_drop_function(self):
        # Registers one temporary system function and one temporary function, checks
        # that both are listed, drops them, and checks that the list is empty again.
        t_env = self.t_env

        t_env.create_temporary_system_function(
            "add_one_func", udf(lambda i: i + 1, result_type=DataTypes.BIGINT()))
        t_env.create_temporary_function(
            "subtract_one_func", udf(SubtractOne(), result_type=DataTypes.BIGINT()))
        self.assert_equals(t_env.list_user_defined_functions(),
                           ['add_one_func', 'subtract_one_func'])

        t_env.drop_temporary_system_function("add_one_func")
        t_env.drop_temporary_function("subtract_one_func")
        self.assert_equals(t_env.list_user_defined_functions(), [])
# decide whether two floats are equal
def float_equal(a, b, rel_tol=1e-09, abs_tol=0.0):
    """Return True when ``a`` and ``b`` are equal within the given tolerances.

    For finite inputs the result is the same as the former hand-written formula,
    ``abs(a - b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol)``. Delegating to
    ``math.isclose`` also gives correct answers for non-finite inputs: an infinity is
    only close to the same infinity (the old formula reported infinity as equal to any
    finite number) and NaN is close to nothing.

    :param a: first value to compare.
    :param b: second value to compare.
    :param rel_tol: maximum allowed difference relative to the larger magnitude.
    :param abs_tol: minimum absolute tolerance, useful for comparisons near zero.
    :return: True if the values are considered equal, False otherwise.
    :raises ValueError: if ``rel_tol`` or ``abs_tol`` is negative.
    """
    # Imported locally so this function does not depend on a module-level import.
    import math
    return math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
class PyFlinkStreamUserDefinedFunctionTests(UserDefinedFunctionTests,
                                            PyFlinkStreamTableTestCase):
    """Runs the shared UserDefinedFunctionTests cases on PyFlinkStreamTableTestCase.

    No tests are added or overridden here.
    """
class PyFlinkBatchUserDefinedFunctionTests(PyFlinkBatchTableTestCase):
    """UDF test for PyFlinkBatchTableTestCase.

    This class does not mix in UserDefinedFunctionTests; it only defines the chaining
    test, which reads the result through ``self.collect`` instead of a table sink.
    """

    def test_chaining_scalar_function(self):
        # Passes the results of two Python UDFs as the arguments of a third one.
        increment = udf(lambda i: i + 1, result_type=DataTypes.BIGINT())
        decrement = udf(SubtractOne(), result_type=DataTypes.BIGINT())

        source = self.t_env.from_elements(
            [(1, 2, 1), (2, 5, 2), (3, 1, 3)],
            ['a', 'b', 'c'])
        chained = add(increment(source.a), decrement(source.b))
        projected = source.select(chained, source.c, expr.lit(1))

        rows = self.collect(projected)
        self.assertEqual(rows, ["3,1,1", "7,2,1", "4,3,1"])
class PyFlinkBlinkStreamUserDefinedFunctionTests(UserDefinedFunctionTests,
                                                 PyFlinkBlinkStreamTableTestCase):
    """Shared UDF cases on PyFlinkBlinkStreamTableTestCase, plus checks of the ``udf``
    wrapper itself (determinism flag, name, invalid inputs) and a
    TIMESTAMP_WITH_LOCAL_TIME_ZONE round trip.

    NOTE(review): per the unittest documentation, the ``msg`` argument given to
    ``assertRaises`` below is only the message reported when nothing is raised; it is
    not compared with the text of the raised exception.
    """

    def test_deterministic(self):
        # A lambda without an explicit flag is deterministic.
        lambda_default = udf(lambda i: i + 1, result_type=DataTypes.BIGINT())
        self.assertTrue(lambda_default._deterministic)

        # An explicit deterministic=False is honoured for a lambda.
        lambda_non_deterministic = udf(
            lambda i: i + 1, result_type=DataTypes.BIGINT(), deterministic=False)
        self.assertFalse(lambda_non_deterministic._deterministic)

        # A ScalarFunction instance without an explicit flag is deterministic.
        scalar_default = udf(SubtractOne(), result_type=DataTypes.BIGINT())
        self.assertTrue(scalar_default._deterministic)

        # Passing deterministic=False for that ScalarFunction is expected to raise.
        with self.assertRaises(ValueError, msg="Inconsistent deterministic: False and True"):
            udf(SubtractOne(), result_type=DataTypes.BIGINT(), deterministic=False)

        # The module-level decorated function is deterministic as well.
        self.assertTrue(add._deterministic)

        @udf(result_type=DataTypes.BIGINT(), deterministic=False)
        def non_deterministic_udf(i):
            return i

        self.assertFalse(non_deterministic_udf._deterministic)

    def test_name(self):
        # Lambda: the name is "<lambda>" unless one is passed explicitly.
        unnamed_lambda = udf(lambda i: i + 1, result_type=DataTypes.BIGINT())
        self.assertEqual("<lambda>", unnamed_lambda._name)

        named_lambda = udf(lambda i: i + 1, result_type=DataTypes.BIGINT(), name="add_one")
        self.assertEqual("add_one", named_lambda._name)

        # ScalarFunction instance: the name is the class name unless one is passed.
        unnamed_scalar = udf(SubtractOne(), result_type=DataTypes.BIGINT())
        self.assertEqual("SubtractOne", unnamed_scalar._name)

        named_scalar = udf(SubtractOne(), result_type=DataTypes.BIGINT(), name="subtract_one")
        self.assertEqual("subtract_one", named_scalar._name)

        # Decorated function: the function name, or the name given to the decorator.
        self.assertEqual("add", add._name)

        @udf(result_type=DataTypes.BIGINT(), name="named")
        def named_udf(i):
            return i

        self.assertEqual("named", named_udf._name)

    def test_abc(self):
        # A ScalarFunction subclass that defines open() but no eval().
        class UdfWithoutEval(ScalarFunction):
            def open(self, function_context):
                pass

        expected_failure = \
            "Can't instantiate abstract class UdfWithoutEval with abstract methods eval"
        with self.assertRaises(TypeError, msg=expected_failure):
            UdfWithoutEval()

    def test_invalid_udf(self):
        # An object that has eval() but is neither a ScalarFunction nor callable.
        class Plus(object):
            def eval(self, col):
                return col + 1

        expected_failure = \
            "Invalid function: not a function or callable (__call__ is not defined)"
        with self.assertRaises(TypeError, msg=expected_failure):
            # test non-callable function
            invalid_udf = udf(Plus(), DataTypes.BIGINT(), DataTypes.BIGINT())
            self.t_env.create_temporary_system_function("non-callable-udf", invalid_udf)

    def test_data_types_only_supported_in_blink_planner(self):
        # Localize a fixed datetime in the time zone configured on the table environment.
        timezone = self.t_env.get_config().get_local_timezone()
        naive_datetime = datetime.datetime(1970, 1, 1, 0, 0, 0, 123000)
        local_datetime = pytz.timezone(timezone).localize(naive_datetime)

        @udf(result_type=DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(3))
        def local_zoned_timestamp_func(local_zoned_timestamp_param):
            assert local_zoned_timestamp_param == local_datetime, \
                'local_zoned_timestamp_param is wrong value %s !' % local_zoned_timestamp_param
            return local_zoned_timestamp_param

        sink_type = DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(3)
        table_sink = source_sink_utils.TestAppendSink(['a'], [sink_type])
        self.t_env.register_table_sink("Results", table_sink)

        row_type = DataTypes.ROW(
            [DataTypes.FIELD("a", DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(3))])
        t = self.t_env.from_elements([(local_datetime,)], row_type)

        # The UDF is applied twice, so its own output is checked as input as well.
        applied_twice = local_zoned_timestamp_func(local_zoned_timestamp_func(t.a))
        t.select(applied_twice).execute_insert("Results").wait()

        actual = source_sink_utils.results()
        self.assert_equals(actual, ["1970-01-01T00:00:00.123Z"])
class PyFlinkBlinkBatchUserDefinedFunctionTests(UserDefinedFunctionTests,
                                                PyFlinkBlinkBatchTableTestCase):
    """Runs the shared UserDefinedFunctionTests cases on PyFlinkBlinkBatchTableTestCase.

    No tests are added or overridden here.
    """
# test specify the input_types
@udf(
    input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()],
    result_type=DataTypes.BIGINT())
def add(i, j):
    """Return the sum of the two arguments."""
    total = i + j
    return total
class SubtractOne(ScalarFunction):
    """Scalar function that returns its argument minus one."""

    def eval(self, i):
        decremented = i - 1
        return decremented
class Subtract(ScalarFunction, unittest.TestCase):
    """Scalar function that subtracts a value set in ``open`` and counts its inputs."""

    def open(self, function_context):
        # State created once per function instance before any eval() call.
        self.subtracted_value = 1
        metric_group = function_context.get_metric_group()
        keyed_group = metric_group.add_group("key", "value")
        self.counter = keyed_group.counter("my_counter")
        self.counter_sum = 0

    def eval(self, i):
        # counter
        self.counter.inc(i)
        # Running total of the inputs, kept next to the metric counter.
        self.counter_sum += i
        difference = i - self.subtracted_value
        return difference
class CallablePlus(object):
    """Callable object that returns its argument plus one."""

    def __call__(self, col):
        incremented = col + 1
        return incremented
if __name__ == '__main__':
    import unittest

    # Use the XML runner when xmlrunner can be imported; otherwise pass None so that
    # unittest.main uses its default runner.
    testRunner = None
    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
    except ImportError:
        testRunner = None

    unittest.main(testRunner=testRunner, verbosity=2)