blob: 415b23d4690dd075c10b185f9eca63d1fe4ed78d [file] [log] [blame]
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import datetime
import decimal
import pytz
from pyflink.common import Row
from pyflink.table import DataTypes
from pyflink.table.tests.test_udf import SubtractOne
from pyflink.table.udf import udf
from pyflink.testing import source_sink_utils
from pyflink.testing.test_case_utils import PyFlinkBatchTableTestCase, \
PyFlinkStreamTableTestCase, PyFlinkTestCase
class PandasUDFTests(PyFlinkTestCase):
    """Checks of the argument validation performed by the ``udf`` factory."""

    def test_non_exist_func_type(self):
        # Passing a func_type other than the supported ones must raise a
        # ValueError whose message lists the accepted values.
        expected_message = 'The func_type must be one of \'general, pandas\''
        with self.assertRaisesRegex(ValueError, expected_message):
            udf(lambda i: i + 1, result_type=DataTypes.BIGINT(), func_type="non-exist")
class PandasUDFITTests(object):
    """Pandas UDF integration tests shared by the batch and stream test cases.

    The class derives from ``object`` and is combined with a concrete PyFlink
    table test case by the subclasses in this module; every test relies on the
    combined class to provide ``self.t_env`` and ``self.assert_equals``.
    """

    def test_basic_functionality(self):
        # Combines a pandas UDF, a general Python UDF and the module-level
        # ``add`` pandas UDF (defined further down in this file) in one query.

        # pandas UDF
        add_one = udf(lambda i: i + 1, result_type=DataTypes.BIGINT(), func_type="pandas")

        # general Python UDF
        subtract_one = udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT())

        table_sink = source_sink_utils.TestAppendSink(
            ['a', 'b', 'c', 'd'],
            [DataTypes.BIGINT(), DataTypes.BIGINT(), DataTypes.BIGINT(), DataTypes.BIGINT()])
        self.t_env.register_table_sink("Results", table_sink)

        t = self.t_env.from_elements([(1, 2, 3), (2, 5, 6), (3, 1, 9)], ['a', 'b', 'c'])

        # The filter keeps the rows with b + 1 <= 3, i.e. the first and the third row.
        t.where(add_one(t.b) <= 3) \
            .select(t.a, t.b + 1, add(t.a + 1, subtract_one(t.c)) + 2, add(add_one(t.a), 1)) \
            .execute_insert("Results") \
            .wait()
        actual = source_sink_utils.results()
        self.assert_equals(actual, ["+I[1, 3, 6, 3]", "+I[3, 2, 14, 5]"])

    def test_all_data_types(self):
        # Each UDF below asserts the Python type (and for some, the value) of the
        # first element it receives and then returns its input, so the sink is
        # expected to contain the original row. The only exception is
        # ``nested_array_func``, which returns the inner array of its first element.
        import pandas as pd
        import numpy as np

        @udf(result_type=DataTypes.TINYINT(), func_type="pandas")
        def tinyint_func(tinyint_param):
            assert isinstance(tinyint_param, pd.Series)
            assert isinstance(tinyint_param[0], np.int8), \
                'tinyint_param of wrong type %s !' % type(tinyint_param[0])
            return tinyint_param

        @udf(result_type=DataTypes.SMALLINT(), func_type="pandas")
        def smallint_func(smallint_param):
            assert isinstance(smallint_param, pd.Series)
            assert isinstance(smallint_param[0], np.int16), \
                'smallint_param of wrong type %s !' % type(smallint_param[0])
            # Report the offending element, matching the check and timestamp_func.
            assert smallint_param[0] == 32767, \
                'smallint_param of wrong value %s' % smallint_param[0]
            return smallint_param

        @udf(result_type=DataTypes.INT(), func_type="pandas")
        def int_func(int_param):
            assert isinstance(int_param, pd.Series)
            assert isinstance(int_param[0], np.int32), \
                'int_param of wrong type %s !' % type(int_param[0])
            # Report the offending element, matching the check and timestamp_func.
            assert int_param[0] == -2147483648, \
                'int_param of wrong value %s' % int_param[0]
            return int_param

        @udf(result_type=DataTypes.BIGINT(), func_type="pandas")
        def bigint_func(bigint_param):
            assert isinstance(bigint_param, pd.Series)
            assert isinstance(bigint_param[0], np.int64), \
                'bigint_param of wrong type %s !' % type(bigint_param[0])
            return bigint_param

        @udf(result_type=DataTypes.BOOLEAN(), func_type="pandas")
        def boolean_func(boolean_param):
            assert isinstance(boolean_param, pd.Series)
            assert isinstance(boolean_param[0], np.bool_), \
                'boolean_param of wrong type %s !' % type(boolean_param[0])
            return boolean_param

        @udf(result_type=DataTypes.FLOAT(), func_type="pandas")
        def float_func(float_param):
            assert isinstance(float_param, pd.Series)
            assert isinstance(float_param[0], np.float32), \
                'float_param of wrong type %s !' % type(float_param[0])
            return float_param

        @udf(result_type=DataTypes.DOUBLE(), func_type="pandas")
        def double_func(double_param):
            assert isinstance(double_param, pd.Series)
            assert isinstance(double_param[0], np.float64), \
                'double_param of wrong type %s !' % type(double_param[0])
            return double_param

        @udf(result_type=DataTypes.STRING(), func_type="pandas")
        def varchar_func(varchar_param):
            assert isinstance(varchar_param, pd.Series)
            assert isinstance(varchar_param[0], str), \
                'varchar_param of wrong type %s !' % type(varchar_param[0])
            return varchar_param

        @udf(result_type=DataTypes.BYTES(), func_type="pandas")
        def varbinary_func(varbinary_param):
            assert isinstance(varbinary_param, pd.Series)
            assert isinstance(varbinary_param[0], bytes), \
                'varbinary_param of wrong type %s !' % type(varbinary_param[0])
            return varbinary_param

        @udf(result_type=DataTypes.DECIMAL(38, 18), func_type="pandas")
        def decimal_func(decimal_param):
            assert isinstance(decimal_param, pd.Series)
            assert isinstance(decimal_param[0], decimal.Decimal), \
                'decimal_param of wrong type %s !' % type(decimal_param[0])
            return decimal_param

        @udf(result_type=DataTypes.DATE(), func_type="pandas")
        def date_func(date_param):
            assert isinstance(date_param, pd.Series)
            assert isinstance(date_param[0], datetime.date), \
                'date_param of wrong type %s !' % type(date_param[0])
            return date_param

        @udf(result_type=DataTypes.TIME(), func_type="pandas")
        def time_func(time_param):
            assert isinstance(time_param, pd.Series)
            assert isinstance(time_param[0], datetime.time), \
                'time_param of wrong type %s !' % type(time_param[0])
            return time_param

        # Shared by the input row and by the value check inside timestamp_func.
        timestamp_value = datetime.datetime(1970, 1, 2, 0, 0, 0, 123000)

        @udf(result_type=DataTypes.TIMESTAMP(3), func_type="pandas")
        def timestamp_func(timestamp_param):
            assert isinstance(timestamp_param, pd.Series)
            assert isinstance(timestamp_param[0], datetime.datetime), \
                'timestamp_param of wrong type %s !' % type(timestamp_param[0])
            assert timestamp_param[0] == timestamp_value, \
                'timestamp_param is wrong value %s, should be %s!' % (timestamp_param[0],
                                                                     timestamp_value)
            return timestamp_param

        # Plain function: wrapped three times below with different element types.
        def array_func(array_param):
            assert isinstance(array_param, pd.Series)
            assert isinstance(array_param[0], np.ndarray), \
                'array_param of wrong type %s !' % type(array_param[0])
            return array_param

        array_str_func = udf(array_func,
                             result_type=DataTypes.ARRAY(DataTypes.STRING()),
                             func_type="pandas")

        array_timestamp_func = udf(array_func,
                                   result_type=DataTypes.ARRAY(DataTypes.TIMESTAMP(3)),
                                   func_type="pandas")

        array_int_func = udf(array_func,
                             result_type=DataTypes.ARRAY(DataTypes.INT()),
                             func_type="pandas")

        @udf(result_type=DataTypes.ARRAY(DataTypes.STRING()), func_type="pandas")
        def nested_array_func(nested_array_param):
            assert isinstance(nested_array_param, pd.Series)
            assert isinstance(nested_array_param[0], np.ndarray), \
                'nested_array_param of wrong type %s !' % type(nested_array_param[0])
            # Builds the result from the first element only; the declared result
            # type is ARRAY(STRING) while column "t" is ARRAY(ARRAY(STRING)).
            return pd.Series(nested_array_param[0])

        row_type = DataTypes.ROW(
            [DataTypes.FIELD("f1", DataTypes.INT()),
             DataTypes.FIELD("f2", DataTypes.STRING()),
             DataTypes.FIELD("f3", DataTypes.TIMESTAMP(3)),
             DataTypes.FIELD("f4", DataTypes.ARRAY(DataTypes.INT()))])

        @udf(result_type=row_type, func_type="pandas")
        def row_func(row_param):
            # A ROW argument is asserted to arrive as a DataFrame with one
            # Series per field of ``row_type``.
            assert isinstance(row_param, pd.DataFrame)
            assert isinstance(row_param.f1, pd.Series)
            assert isinstance(row_param.f1[0], np.int32), \
                'row_param.f1 of wrong type %s !' % type(row_param.f1[0])
            assert isinstance(row_param.f2, pd.Series)
            assert isinstance(row_param.f2[0], str), \
                'row_param.f2 of wrong type %s !' % type(row_param.f2[0])
            assert isinstance(row_param.f3, pd.Series)
            assert isinstance(row_param.f3[0], datetime.datetime), \
                'row_param.f3 of wrong type %s !' % type(row_param.f3[0])
            assert isinstance(row_param.f4, pd.Series)
            assert isinstance(row_param.f4[0], np.ndarray), \
                'row_param.f4 of wrong type %s !' % type(row_param.f4[0])
            return row_param

        table_sink = source_sink_utils.TestAppendSink(
            ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
             'r', 's', 't', 'u'],
            [DataTypes.TINYINT(), DataTypes.SMALLINT(), DataTypes.INT(), DataTypes.BIGINT(),
             DataTypes.BOOLEAN(), DataTypes.BOOLEAN(), DataTypes.FLOAT(), DataTypes.DOUBLE(),
             DataTypes.STRING(), DataTypes.STRING(), DataTypes.BYTES(), DataTypes.DECIMAL(38, 18),
             DataTypes.DECIMAL(38, 18), DataTypes.DATE(), DataTypes.TIME(), DataTypes.TIMESTAMP(3),
             DataTypes.ARRAY(DataTypes.STRING()), DataTypes.ARRAY(DataTypes.TIMESTAMP(3)),
             DataTypes.ARRAY(DataTypes.INT()),
             DataTypes.ARRAY(DataTypes.STRING()), row_type])
        self.t_env.register_table_sink("Results", table_sink)

        # A single input row with one column per data type under test.
        t = self.t_env.from_elements(
            [(1, 32767, -2147483648, 1, True, False, 1.0, 1.0, 'hello', '中文',
              bytearray(b'flink'), decimal.Decimal('1000000000000000000.05'),
              decimal.Decimal('1000000000000000000.05999999999999999899999999999'),
              datetime.date(2014, 9, 13), datetime.time(hour=1, minute=0, second=1),
              timestamp_value, ['hello', '中文', None], [timestamp_value], [1, 2],
              [['hello', '中文', None]], Row(1, 'hello', timestamp_value, [1, 2]))],
            DataTypes.ROW(
                [DataTypes.FIELD("a", DataTypes.TINYINT()),
                 DataTypes.FIELD("b", DataTypes.SMALLINT()),
                 DataTypes.FIELD("c", DataTypes.INT()),
                 DataTypes.FIELD("d", DataTypes.BIGINT()),
                 DataTypes.FIELD("e", DataTypes.BOOLEAN()),
                 DataTypes.FIELD("f", DataTypes.BOOLEAN()),
                 DataTypes.FIELD("g", DataTypes.FLOAT()),
                 DataTypes.FIELD("h", DataTypes.DOUBLE()),
                 DataTypes.FIELD("i", DataTypes.STRING()),
                 DataTypes.FIELD("j", DataTypes.STRING()),
                 DataTypes.FIELD("k", DataTypes.BYTES()),
                 DataTypes.FIELD("l", DataTypes.DECIMAL(38, 18)),
                 DataTypes.FIELD("m", DataTypes.DECIMAL(38, 18)),
                 DataTypes.FIELD("n", DataTypes.DATE()),
                 DataTypes.FIELD("o", DataTypes.TIME()),
                 DataTypes.FIELD("p", DataTypes.TIMESTAMP(3)),
                 DataTypes.FIELD("q", DataTypes.ARRAY(DataTypes.STRING())),
                 DataTypes.FIELD("r", DataTypes.ARRAY(DataTypes.TIMESTAMP(3))),
                 DataTypes.FIELD("s", DataTypes.ARRAY(DataTypes.INT())),
                 DataTypes.FIELD("t", DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.STRING()))),
                 DataTypes.FIELD("u", row_type)]))

        t.select(
            tinyint_func(t.a),
            smallint_func(t.b),
            int_func(t.c),
            bigint_func(t.d),
            boolean_func(t.e),
            boolean_func(t.f),
            float_func(t.g),
            double_func(t.h),
            varchar_func(t.i),
            varchar_func(t.j),
            varbinary_func(t.k),
            decimal_func(t.l),
            decimal_func(t.m),
            date_func(t.n),
            time_func(t.o),
            timestamp_func(t.p),
            array_str_func(t.q),
            array_timestamp_func(t.r),
            array_int_func(t.s),
            nested_array_func(t.t),
            row_func(t.u)) \
            .execute_insert("Results").wait()
        actual = source_sink_utils.results()
        self.assert_equals(
            actual,
            ["+I[1, 32767, -2147483648, 1, true, false, 1.0, 1.0, hello, 中文, "
             "[102, 108, 105, 110, 107], 1000000000000000000.050000000000000000, "
             "1000000000000000000.059999999999999999, 2014-09-13, 01:00:01, "
             "1970-01-02 00:00:00.123, [hello, 中文, null], [1970-01-02 00:00:00.123], "
             "[1, 2], [hello, 中文, null], +I[1, hello, 1970-01-02 00:00:00.123, [1, 2]]]"])

    def test_invalid_pandas_udf(self):
        # Two misbehaving pandas UDFs are executed and the error surfaced through
        # py4j is matched against the expected message.
        import re

        from py4j.protocol import Py4JJavaError

        @udf(result_type=DataTypes.INT(), func_type="pandas")
        def length_mismatch(i):
            # Drops the first element, so the result is shorter than the input.
            return i[1:]

        @udf(result_type=DataTypes.INT(), func_type="pandas")
        def result_type_not_series(i):
            # Returns a single element instead of a Series / DataFrame.
            return i.iloc[0]

        t = self.t_env.from_elements([(1, 2, 3), (2, 5, 6), (3, 1, 9)], ['a', 'b', 'c'])

        # The expected messages are literal text, so they are escaped before being
        # used as regular expressions (they contain '.' characters).
        msg = "The result length '0' of Pandas UDF 'length_mismatch' is not equal " \
              "to the input length '1'"
        with self.assertRaisesRegex(Py4JJavaError, expected_regex=re.escape(msg)):
            t.select(length_mismatch(t.a)).to_pandas()

        msg = "The result type of Pandas UDF 'result_type_not_series' must be pandas.Series or " \
              "pandas.DataFrame, got <class 'numpy.int64'>"
        with self.assertRaisesRegex(Py4JJavaError, expected_regex=re.escape(msg)):
            t.select(result_type_not_series(t.a)).to_pandas()

    def test_data_types(self):
        # Round-trips a TIMESTAMP_WITH_LOCAL_TIME_ZONE(3) value through a pandas UDF.
        import pandas as pd

        # Localize the naive timestamp to the time zone configured on the table
        # environment; the UDF below compares its input against this value.
        timezone = self.t_env.get_config().get_local_timezone()
        local_datetime = pytz.timezone(timezone).localize(
            datetime.datetime(1970, 1, 2, 0, 0, 0, 123000))

        @udf(result_type=DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(3), func_type="pandas")
        def local_zoned_timestamp_func(local_zoned_timestamp_param):
            assert isinstance(local_zoned_timestamp_param, pd.Series)
            assert isinstance(local_zoned_timestamp_param[0], datetime.datetime), \
                'local_zoned_timestamp_param of wrong type %s !' % type(
                    local_zoned_timestamp_param[0])
            assert local_zoned_timestamp_param[0] == local_datetime, \
                'local_zoned_timestamp_param is wrong value %s, %s!' % \
                (local_zoned_timestamp_param[0], local_datetime)
            return local_zoned_timestamp_param

        table_sink = source_sink_utils.TestAppendSink(
            ['a'], [DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(3)])
        self.t_env.register_table_sink("Results", table_sink)

        t = self.t_env.from_elements(
            [(local_datetime,)],
            DataTypes.ROW([DataTypes.FIELD("a", DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(3))]))

        # The UDF is applied twice (nested), so its own output is fed back in as
        # input and has to pass the same assertions again.
        t.select(local_zoned_timestamp_func(local_zoned_timestamp_func(t.a))) \
            .execute_insert("Results").wait()
        actual = source_sink_utils.results()
        self.assert_equals(actual, ["+I[1970-01-02T00:00:00.123Z]"])
class BatchPandasUDFITTests(PandasUDFITTests, PyFlinkBatchTableTestCase):
    """Combines the ``PandasUDFITTests`` cases with ``PyFlinkBatchTableTestCase``."""
class StreamPandasUDFITTests(PandasUDFITTests, PyFlinkStreamTableTestCase):
    """Combines the ``PandasUDFITTests`` cases with ``PyFlinkStreamTableTestCase``."""
@udf(result_type=DataTypes.BIGINT(), func_type='pandas')
def add(i, j):
    # Module-level pandas UDF with a BIGINT result: returns the sum of its
    # two arguments. Used by PandasUDFITTests.test_basic_functionality.
    total = i + j
    return total
if __name__ == '__main__':
    import unittest

    # Default to unittest's own runner; switch to xmlrunner's XMLTestRunner
    # (output directory 'target/test-reports') only when that optional
    # package can be imported.
    testRunner = None
    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
    except ImportError:
        pass

    unittest.main(testRunner=testRunner, verbosity=2)