################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
from pandas.testing import assert_frame_equal
from pyflink.common import Row
from pyflink.table import expressions as expr, ListView
from pyflink.table.types import DataTypes
from pyflink.table.udf import udf, udtf, udaf, AggregateFunction, TableAggregateFunction, udtaf
from pyflink.testing import source_sink_utils
from pyflink.testing.test_case_utils import PyFlinkBatchTableTestCase, \
PyFlinkStreamTableTestCase


class RowBasedOperationTests(object):
def test_map(self):
t = self.t_env.from_elements(
[(1, 2, 3), (2, 1, 3), (1, 5, 4), (1, 8, 6), (2, 3, 4)],
DataTypes.ROW(
[DataTypes.FIELD("a", DataTypes.TINYINT()),
DataTypes.FIELD("b", DataTypes.SMALLINT()),
DataTypes.FIELD("c", DataTypes.INT())]))
table_sink = source_sink_utils.TestAppendSink(
['a', 'b'],
[DataTypes.BIGINT(), DataTypes.BIGINT()])
self.t_env.register_table_sink("Results", table_sink)
func = udf(lambda x: Row(a=x + 1, b=x * x), result_type=DataTypes.ROW(
[DataTypes.FIELD("a", DataTypes.BIGINT()),
DataTypes.FIELD("b", DataTypes.BIGINT())]))
func2 = udf(lambda x: Row(x.a + 1, x.b * 2), result_type=DataTypes.ROW(
[DataTypes.FIELD("a", DataTypes.BIGINT()),
DataTypes.FIELD("b", DataTypes.BIGINT())]))
t.map(func(t.b)).alias("a", "b") \
.map(func(t.a)) \
.map(func2) \
.execute_insert("Results") \
.wait()
actual = source_sink_utils.results()
self.assert_equals(
actual, ["+I[5, 18]", "+I[4, 8]", "+I[8, 72]", "+I[11, 162]", "+I[6, 32]"])

    def test_map_with_pandas_udf(self):
t = self.t_env.from_elements(
[(1, Row(2, 3)), (2, Row(1, 3)), (1, Row(5, 4)), (1, Row(8, 6)), (2, Row(3, 4))],
DataTypes.ROW(
[DataTypes.FIELD("a", DataTypes.TINYINT()),
DataTypes.FIELD("b",
DataTypes.ROW([DataTypes.FIELD("c", DataTypes.INT()),
DataTypes.FIELD("d", DataTypes.INT())]))]))
table_sink = source_sink_utils.TestAppendSink(
['a', 'b'],
[DataTypes.BIGINT(), DataTypes.BIGINT()])
self.t_env.register_table_sink("Results", table_sink)
        def func(x):
            # x is a pandas.DataFrame holding the flattened input columns a, c and d
            import pandas as pd
            res = pd.concat([x.a, x.c + x.d], axis=1)
            return res

        def func2(x):
            return x * 2

        def func3(x):
            assert isinstance(x, Row)
            return x
pandas_udf = udf(func,
result_type=DataTypes.ROW(
[DataTypes.FIELD("c", DataTypes.BIGINT()),
DataTypes.FIELD("d", DataTypes.BIGINT())]),
func_type='pandas')
pandas_udf_2 = udf(func2,
result_type=DataTypes.ROW(
[DataTypes.FIELD("c", DataTypes.BIGINT()),
DataTypes.FIELD("d", DataTypes.BIGINT())]),
func_type='pandas')
general_udf = udf(func3,
result_type=DataTypes.ROW(
[DataTypes.FIELD("c", DataTypes.BIGINT()),
DataTypes.FIELD("d", DataTypes.BIGINT())]))
t.map(pandas_udf).map(pandas_udf_2).map(general_udf).execute_insert("Results").wait()
actual = source_sink_utils.results()
self.assert_equals(
actual,
["+I[4, 8]", "+I[2, 10]", "+I[2, 28]", "+I[2, 18]", "+I[4, 14]"])

    def test_flat_map(self):
t = self.t_env.from_elements(
[(1, "2,3"), (2, "1"), (1, "5,6,7")],
DataTypes.ROW(
[DataTypes.FIELD("a", DataTypes.TINYINT()),
DataTypes.FIELD("b", DataTypes.STRING())]))
table_sink = source_sink_utils.TestAppendSink(
['a', 'b', 'c', 'd', 'e', 'f'],
[DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.BIGINT(),
DataTypes.STRING(), DataTypes.BIGINT(), DataTypes.STRING()])
self.t_env.register_table_sink("Results", table_sink)
@udtf(result_types=[DataTypes.INT(), DataTypes.STRING()])
def split(x):
for s in x.b.split(","):
yield x.a, s
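
        # flat_map applies the UDTF to every row; (left_outer_)join_lateral
        # correlates each input row with the rows produced by the UDTF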
t.flat_map(split).alias("a", "b") \
.flat_map(split).alias("a", "b") \
.join_lateral(split.alias("c", "d")) \
.left_outer_join_lateral(split.alias("e", "f")) \
.execute_insert("Results") \
.wait()
actual = source_sink_utils.results()
self.assert_equals(
actual,
["+I[1, 2, 1, 2, 1, 2]", "+I[1, 3, 1, 3, 1, 3]", "+I[2, 1, 2, 1, 2, 1]",
"+I[1, 5, 1, 5, 1, 5]", "+I[1, 6, 1, 6, 1, 6]", "+I[1, 7, 1, 7, 1, 7]"])


class BatchRowBasedOperationITTests(RowBasedOperationTests, PyFlinkBatchTableTestCase):
def test_aggregate_with_pandas_udaf(self):
t = self.t_env.from_elements(
[(1, 2, 3), (2, 1, 3), (1, 5, 4), (1, 8, 6), (2, 3, 4)],
DataTypes.ROW(
[DataTypes.FIELD("a", DataTypes.TINYINT()),
DataTypes.FIELD("b", DataTypes.SMALLINT()),
DataTypes.FIELD("c", DataTypes.INT())]))
table_sink = source_sink_utils.TestAppendSink(
['a', 'b', 'c'],
[DataTypes.TINYINT(), DataTypes.FLOAT(), DataTypes.INT()])
self.t_env.register_table_sink("Results", table_sink)
pandas_udaf = udaf(lambda pd: (pd.b.mean(), pd.a.max()),
result_type=DataTypes.ROW(
[DataTypes.FIELD("a", DataTypes.FLOAT()),
DataTypes.FIELD("b", DataTypes.INT())]),
func_type="pandas")
t.select(t.a, t.b) \
.group_by(t.a) \
.aggregate(pandas_udaf) \
.select("*") \
.execute_insert("Results") \
.wait()
actual = source_sink_utils.results()
self.assert_equals(actual, ["+I[1, 5.0, 1]", "+I[2, 2.0, 2]"])

    def test_aggregate_with_pandas_udaf_without_keys(self):
t = self.t_env.from_elements(
[(1, 2, 3), (2, 1, 3), (1, 5, 4), (1, 8, 6), (2, 3, 4)],
DataTypes.ROW(
[DataTypes.FIELD("a", DataTypes.TINYINT()),
DataTypes.FIELD("b", DataTypes.SMALLINT()),
DataTypes.FIELD("c", DataTypes.INT())]))
table_sink = source_sink_utils.TestAppendSink(
['a', 'b'],
[DataTypes.FLOAT(), DataTypes.INT()])
self.t_env.register_table_sink("Results", table_sink)
pandas_udaf = udaf(lambda pd: Row(pd.b.mean(), pd.b.max()),
result_type=DataTypes.ROW(
[DataTypes.FIELD("a", DataTypes.FLOAT()),
DataTypes.FIELD("b", DataTypes.INT())]),
func_type="pandas")
t.select(t.b) \
.aggregate(pandas_udaf.alias("a", "b")) \
.select("a, b") \
.execute_insert("Results") \
.wait()
actual = source_sink_utils.results()
self.assert_equals(actual, ["+I[3.8, 8]"])

    def test_window_aggregate_with_pandas_udaf(self):
import datetime
from pyflink.table.window import Tumble
t = self.t_env.from_elements(
[
(1, 2, 3, datetime.datetime(2018, 3, 11, 3, 10, 0, 0)),
(3, 2, 4, datetime.datetime(2018, 3, 11, 3, 10, 0, 0)),
(2, 1, 2, datetime.datetime(2018, 3, 11, 3, 10, 0, 0)),
(1, 3, 1, datetime.datetime(2018, 3, 11, 3, 40, 0, 0)),
(1, 8, 5, datetime.datetime(2018, 3, 11, 4, 20, 0, 0)),
(2, 3, 6, datetime.datetime(2018, 3, 11, 3, 30, 0, 0))
],
DataTypes.ROW(
[DataTypes.FIELD("a", DataTypes.TINYINT()),
DataTypes.FIELD("b", DataTypes.SMALLINT()),
DataTypes.FIELD("c", DataTypes.INT()),
DataTypes.FIELD("rowtime", DataTypes.TIMESTAMP(3))]))
table_sink = source_sink_utils.TestAppendSink(
['a', 'b', 'c'],
[
DataTypes.TIMESTAMP(3),
DataTypes.FLOAT(),
DataTypes.INT()
])
self.t_env.register_table_sink("Results", table_sink)
pandas_udaf = udaf(lambda pd: (pd.b.mean(), pd.b.max()),
result_type=DataTypes.ROW(
[DataTypes.FIELD("a", DataTypes.FLOAT()),
DataTypes.FIELD("b", DataTypes.INT())]),
func_type="pandas")
tumble_window = Tumble.over(expr.lit(1).hours) \
.on(expr.col("rowtime")) \
.alias("w")
t.select(t.b, t.rowtime) \
.window(tumble_window) \
.group_by("w") \
.aggregate(pandas_udaf.alias("d", "e")) \
.select("w.rowtime, d, e") \
.execute_insert("Results") \
.wait()
actual = source_sink_utils.results()
self.assert_equals(actual,
["+I[2018-03-11 03:59:59.999, 2.2, 3]",
"+I[2018-03-11 04:59:59.999, 8.0, 8]"])


class StreamRowBasedOperationITTests(RowBasedOperationTests, PyFlinkStreamTableTestCase):
def test_aggregate(self):
import pandas as pd
t = self.t_env.from_elements(
[(1, 2, 3), (2, 1, 3), (1, 5, 4), (1, 8, 6), (2, 3, 4)],
DataTypes.ROW(
[DataTypes.FIELD("a", DataTypes.BIGINT()),
DataTypes.FIELD("b", DataTypes.SMALLINT()),
DataTypes.FIELD("c", DataTypes.INT())]))
function = CountAndSumAggregateFunction()
agg = udaf(function,
result_type=function.get_result_type(),
accumulator_type=function.get_accumulator_type(),
name=str(function.__class__.__name__))
result = t.group_by(t.a) \
.aggregate(agg.alias("c", "d")) \
.select("a, c, d") \
.to_pandas()
assert_frame_equal(result.sort_values('a').reset_index(drop=True),
pd.DataFrame([[1, 3, 15], [2, 2, 4]], columns=['a', 'c', 'd']))

    def test_flat_aggregate(self):
import pandas as pd
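        # wrap the Top2 table aggregate function as a user-defined table aggregate function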
mytop = udtaf(Top2())
t = self.t_env.from_elements([(1, 'Hi', 'Hello'),
(3, 'Hi', 'hi'),
(5, 'Hi2', 'hi'),
(7, 'Hi', 'Hello'),
(2, 'Hi', 'Hello')], ['a', 'b', 'c'])
result = t.select(t.a, t.c) \
.group_by(t.c) \
.flat_aggregate(mytop.alias('a')) \
.select(t.a) \
.flat_aggregate(mytop.alias("b")) \
.select("b") \
.to_pandas()
assert_frame_equal(result, pd.DataFrame([[7], [5]], columns=['b']))

    def test_flat_aggregate_list_view(self):
import pandas as pd
my_concat = udtaf(ListViewConcatTableAggregateFunction())
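        # finish a bundle after every two elements so that state is flushed frequently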
self.t_env.get_config().get_configuration().set_string(
"python.fn-execution.bundle.size", "2")
# trigger the cache eviction in a bundle.
self.t_env.get_config().get_configuration().set_string(
"python.state.cache-size", "2")
t = self.t_env.from_elements([(1, 'Hi', 'Hello'),
(3, 'Hi', 'hi'),
(3, 'Hi2', 'hi'),
(3, 'Hi', 'hi'),
(2, 'Hi', 'Hello'),
(1, 'Hi2', 'Hello'),
(3, 'Hi3', 'hi'),
(3, 'Hi2', 'Hello'),
(3, 'Hi3', 'hi'),
(2, 'Hi3', 'Hello')], ['a', 'b', 'c'])
result = t.group_by(t.c) \
.flat_aggregate(my_concat(t.b, ',').alias("b")) \
.select(t.b, t.c) \
.alias("a, c")
assert_frame_equal(result.to_pandas().sort_values('c').reset_index(drop=True),
pd.DataFrame([["Hi,Hi,Hi2,Hi2,Hi3", "Hello"],
["Hi,Hi,Hi2,Hi2,Hi3", "Hello"],
["Hi,Hi2,Hi,Hi3,Hi3", "hi"],
["Hi,Hi2,Hi,Hi3,Hi3", "hi"]],
columns=['a', 'c']))


class CountAndSumAggregateFunction(AggregateFunction):
    """Counts the rows of a group and sums up their "b" column."""

    def get_value(self, accumulator):
        return Row(accumulator[0], accumulator[1])

    def create_accumulator(self):
        return Row(0, 0)

    def accumulate(self, accumulator, row: Row):
        accumulator[0] += 1
        accumulator[1] += row.b

    def retract(self, accumulator, row: Row):
        accumulator[0] -= 1
        # retract must mirror accumulate: subtract the "b" column, not "a"
        accumulator[1] -= row.b

    def merge(self, accumulator, accumulators):
        for other_acc in accumulators:
            accumulator[0] += other_acc[0]
            accumulator[1] += other_acc[1]

    def get_accumulator_type(self):
        return DataTypes.ROW(
            [DataTypes.FIELD("a", DataTypes.BIGINT()),
             DataTypes.FIELD("b", DataTypes.BIGINT())])

    def get_result_type(self):
        return DataTypes.ROW(
            [DataTypes.FIELD("a", DataTypes.BIGINT()),
             DataTypes.FIELD("b", DataTypes.BIGINT())])


class Top2(TableAggregateFunction):
    """Emits (up to) the two largest values accumulated so far."""

    def emit_value(self, accumulator):
        # sort in descending order and emit the two largest values
        accumulator.sort(reverse=True)
        size = len(accumulator)
        if size > 0:
            yield accumulator[0]
        if size > 1:
            yield accumulator[1]

    def create_accumulator(self):
        return []

    def accumulate(self, accumulator, row: Row):
        accumulator.append(row.a)

    def retract(self, accumulator, row: Row):
        accumulator.remove(row.a)

    def get_accumulator_type(self):
        return DataTypes.ARRAY(DataTypes.BIGINT())

    def get_result_type(self):
        return DataTypes.BIGINT()


class ListViewConcatTableAggregateFunction(TableAggregateFunction):
    """Concatenates the accumulated strings with a delimiter, keeping the
    strings in a ListView, and emits the result twice."""

    def emit_value(self, accumulator):
        result = accumulator[1].join(accumulator[0])
        # emit the concatenated result twice to produce two output rows per group
        yield Row(result)
        yield Row(result)

    def create_accumulator(self):
        return Row(ListView(), '')

    def accumulate(self, accumulator, *args):
        # args[0] is the string to concatenate, args[1] the delimiter
        accumulator[1] = args[1]
        accumulator[0].add(args[0])

    def retract(self, accumulator, *args):
        raise NotImplementedError

    def get_accumulator_type(self):
        return DataTypes.ROW([
            DataTypes.FIELD("f0", DataTypes.LIST_VIEW(DataTypes.STRING())),
            # f1 stores the delimiter string, so it must be declared as STRING
            DataTypes.FIELD("f1", DataTypes.STRING())])

    def get_result_type(self):
        return DataTypes.ROW([DataTypes.FIELD("a", DataTypes.STRING())])


if __name__ == '__main__':
import unittest
try:
import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)