# blob: 26a70aa930af8ba83609a106e5e195a4095defed [file] [log] [blame]
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import collections
import datetime
from decimal import Decimal
import pandas as pd
from pandas.util.testing import assert_frame_equal
from pyflink.common import Row, RowKind
from pyflink.fn_execution.state_impl import RemovableConcatIterator
from pyflink.table import DataTypes
from pyflink.table.data_view import ListView, MapView
from pyflink.table.expressions import col
from pyflink.table.udf import AggregateFunction, udaf
from pyflink.testing.test_case_utils import PyFlinkBlinkStreamTableTestCase
class CountAggregateFunction(AggregateFunction):
    """Python UDAF that counts the rows it receives.

    The accumulator is a one-element list; its single slot holds the running count.
    """

    def create_accumulator(self):
        """Return a new accumulator with the count at zero."""
        return [0]

    def accumulate(self, accumulator, *args):
        """Increase the count by one; the input values are not inspected."""
        accumulator[0] += 1

    def retract(self, accumulator, *args):
        """Decrease the count by one; the input values are not inspected."""
        accumulator[0] -= 1

    def merge(self, accumulator, accumulators):
        """Add the count of every accumulator in ``accumulators`` into ``accumulator``."""
        for partial in accumulators:
            accumulator[0] += partial[0]

    def get_value(self, accumulator):
        """Return the current count."""
        return accumulator[0]

    def get_accumulator_type(self):
        """Declare the accumulator as an array of BIGINT."""
        return DataTypes.ARRAY(DataTypes.BIGINT())

    def get_result_type(self):
        """Declare the result as BIGINT."""
        return DataTypes.BIGINT()
class SumAggregateFunction(AggregateFunction):
    """Python UDAF that sums its first argument.

    The accumulator is a one-element list; its single slot holds the running total.
    """

    def create_accumulator(self):
        """Return a new accumulator with the total at zero."""
        return [0]

    def accumulate(self, accumulator, *args):
        """Add the first argument to the running total."""
        value = args[0]
        accumulator[0] += value

    def retract(self, accumulator, *args):
        """Subtract the first argument from the running total."""
        value = args[0]
        accumulator[0] -= value

    def merge(self, accumulator, accumulators):
        """Add the total of every accumulator in ``accumulators`` into ``accumulator``."""
        for partial in accumulators:
            accumulator[0] += partial[0]

    def get_value(self, accumulator):
        """Return the current total."""
        return accumulator[0]

    def get_accumulator_type(self):
        """Declare the accumulator as an array of BIGINT."""
        return DataTypes.ARRAY(DataTypes.BIGINT())

    def get_result_type(self):
        """Declare the result as BIGINT."""
        return DataTypes.BIGINT()
class ConcatAggregateFunction(AggregateFunction):
    """Python UDAF that joins its non-null inputs, sorted, with a separator.

    It is invoked with two arguments: the value to collect and the separator.
    The accumulator is a ``Row`` whose field 0 is the list of collected values
    and whose field 1 is the separator taken from the latest non-null input.
    """

    def get_value(self, accumulator):
        """Return the collected values, sorted, joined by the stored separator."""
        # sorted() works on a copy, so the accumulator's own list keeps its order.
        return accumulator[1].join(sorted(accumulator[0]))

    def create_accumulator(self):
        """Return a new accumulator: no collected values and an empty separator."""
        return Row([], '')

    def accumulate(self, accumulator, *args):
        """Record ``args[0]`` and remember ``args[1]`` as separator; null values are skipped."""
        if args[0] is not None:
            accumulator[1] = args[1]
            accumulator[0].append(args[0])

    def retract(self, accumulator, *args):
        """Remove one occurrence of ``args[0]`` from the collected values; nulls are skipped.

        ``list.remove`` raises ``ValueError`` if the value was never accumulated.
        """
        if args[0] is not None:
            accumulator[0].remove(args[0])

    def get_accumulator_type(self):
        """Declare the accumulator as ROW<f0 ARRAY<STRING>, f1 STRING>."""
        return DataTypes.ROW([
            DataTypes.FIELD("f0", DataTypes.ARRAY(DataTypes.STRING())),
            # f1 holds the separator, which is a str, so it is declared as STRING.
            DataTypes.FIELD("f1", DataTypes.STRING())])

    def get_result_type(self):
        """Declare the result as STRING."""
        return DataTypes.STRING()
class ListViewConcatAggregateFunction(AggregateFunction):
    """Python UDAF that joins its inputs with a separator, collecting them in a ListView.

    It is invoked with two arguments: the value to collect and the separator.
    The accumulator is a ``Row`` whose field 0 is a ``ListView`` of the collected
    values and whose field 1 is the separator from the latest input.
    """

    def get_value(self, accumulator):
        """Return the ListView's elements joined by the stored separator."""
        return accumulator[1].join(accumulator[0])

    def create_accumulator(self):
        """Return a new accumulator: an empty ListView and an empty separator."""
        return Row(ListView(), '')

    def accumulate(self, accumulator, *args):
        """Remember ``args[1]`` as the separator and add ``args[0]`` to the ListView."""
        accumulator[1] = args[1]
        accumulator[0].add(args[0])

    def retract(self, accumulator, *args):
        """Retraction is not supported by this function."""
        raise NotImplementedError

    def get_accumulator_type(self):
        """Declare the accumulator as ROW<f0 LIST_VIEW<STRING>, f1 STRING>."""
        return DataTypes.ROW([
            DataTypes.FIELD("f0", DataTypes.LIST_VIEW(DataTypes.STRING())),
            # f1 holds the separator, which is a str, so it is declared as STRING.
            DataTypes.FIELD("f1", DataTypes.STRING())])

    def get_result_type(self):
        """Declare the result as STRING."""
        return DataTypes.STRING()
class CountDistinctAggregateFunction(AggregateFunction):
    """Python UDAF that counts distinct inputs, tracking per-value counts in a MapView.

    The accumulator is a ``Row`` whose field 0 is a ``MapView`` from input value to
    the number of times it is currently present, and whose field 1 is the number
    of distinct values.
    """

    def get_value(self, accumulator):
        """Return the current number of distinct values."""
        return accumulator[1]

    def create_accumulator(self):
        """Return a new accumulator: an empty MapView and a distinct count of zero."""
        return Row(MapView(), 0)

    def accumulate(self, accumulator, *args):
        """Count one more occurrence of ``args[0]``.

        A value that is absent from the map, or whose entry was reset to ``None``
        by :meth:`retract`, is treated as new and raises the distinct count.
        The special input ``"clear"`` empties the map and resets the distinct
        count to zero after the regular bookkeeping.
        """
        input_str = args[0]
        if accumulator[0].is_empty() or input_str not in accumulator[0] \
                or accumulator[0][input_str] is None:
            accumulator[0][input_str] = 1
            accumulator[1] += 1
        else:
            accumulator[0][input_str] += 1
        if input_str == "clear":
            accumulator[0].clear()
            accumulator[1] = 0

    def retract(self, accumulator, *args):
        """Remove one occurrence of ``args[0]``; unknown values are ignored.

        When the per-value count drops to zero or below, the distinct count is
        lowered and the entry is set to ``None`` rather than deleted.
        """
        input_str = args[0]
        if accumulator[0].is_empty() or input_str not in accumulator[0]:
            return
        # The decrement is written through put_all rather than item assignment.
        accumulator[0].put_all({input_str: accumulator[0][input_str] - 1})
        if accumulator[0][input_str] <= 0:
            accumulator[1] -= 1
            accumulator[0][input_str] = None

    def get_accumulator_type(self):
        """Declare the accumulator as ROW<f0 MAP_VIEW<STRING, STRING>, f1 BIGINT>."""
        return DataTypes.ROW([
            # Field 0 is created as a MapView and used through the MapView API
            # (is_empty / put_all / clear), so it must be declared as MAP_VIEW,
            # not as a plain MAP.
            # NOTE(review): the stored values are ints (or None) although the value
            # type is declared STRING - confirm whether BIGINT is intended.
            DataTypes.FIELD("f0", DataTypes.MAP_VIEW(DataTypes.STRING(), DataTypes.STRING())),
            DataTypes.FIELD("f1", DataTypes.BIGINT())])

    def get_result_type(self):
        """Declare the result as BIGINT."""
        return DataTypes.BIGINT()
class TestIterateAggregateFunction(AggregateFunction):
    """Python UDAF that exercises the iteration APIs of a MapView accumulator.

    The accumulator is a ``Row`` whose field 0 is a ``MapView`` from input value to
    its current occurrence count and whose field 1 is the number of distinct values.
    The statements below are ordered on purpose: each one touches a specific
    MapView iteration path (keys, values, items, removable iterator).
    """

    def get_value(self, accumulator):
        """Return a Row of (sorted keys, sorted values, sorted ``key:value`` pairs, distinct count).

        The first three fields are comma-joined strings.
        """
        # test iterate keys
        key_set = [i for i in accumulator[0]]
        key_set.sort()
        # test iterate values
        value_set = [str(i) for i in accumulator[0].values()]
        value_set.sort()
        item_set = {}
        # test iterate items
        for key, value in accumulator[0].items():
            item_set[key] = value
        # Re-emit the items in sorted-key order so the joined output is deterministic.
        ordered_item_set = collections.OrderedDict()
        for key in key_set:
            ordered_item_set[key] = str(item_set[key])
        try:
            # test auto clear the cached iterators
            # Only the first element is pulled; the iterator is abandoned unfinished.
            next(iter(accumulator[0].items()))
        except StopIteration:
            # An empty map has nothing to yield; that is fine here.
            pass
        return Row(",".join(key_set),
                   ','.join(value_set),
                   ",".join([":".join(item) for item in ordered_item_set.items()]),
                   accumulator[1])

    def create_accumulator(self):
        """Return a new accumulator: an empty MapView and a distinct count of zero."""
        return Row(MapView(), 0)

    def accumulate(self, accumulator, *args):
        """Count one more occurrence of ``args[0]``; a new key also raises the distinct count."""
        input_str = args[0]
        if input_str not in accumulator[0]:
            accumulator[0][input_str] = 1
            accumulator[1] += 1
        else:
            accumulator[0][input_str] += 1

    def retract(self, accumulator, *args):
        """Remove one occurrence of ``args[0]``; unknown keys are ignored.

        When a key's count reaches zero the key is deleted through the key
        iterator's ``remove()`` and the distinct count is lowered.
        """
        input_str = args[0]
        if input_str not in accumulator[0]:
            return
        accumulator[0][input_str] -= 1
        if accumulator[0][input_str] == 0:
            # test removable iterator
            key_iter = iter(accumulator[0].keys())  # type: RemovableConcatIterator
            # Walk every key to the end (no early break) and drop the matching one.
            while True:
                try:
                    key = next(key_iter)
                    if key == input_str:
                        key_iter.remove()
                except StopIteration:
                    break
            accumulator[1] -= 1

    def get_accumulator_type(self):
        """Declare the accumulator as ROW<f0 MAP_VIEW<STRING, BIGINT>, f1 BIGINT>."""
        return DataTypes.ROW([
            DataTypes.FIELD("f0", DataTypes.MAP_VIEW(DataTypes.STRING(), DataTypes.BIGINT())),
            DataTypes.FIELD("f1", DataTypes.BIGINT())])

    def get_result_type(self):
        """Declare the result as ROW<f0 STRING, f1 STRING, f2 STRING, f3 BIGINT>."""
        return DataTypes.ROW([
            DataTypes.FIELD("f0", DataTypes.STRING()),
            DataTypes.FIELD("f1", DataTypes.STRING()),
            DataTypes.FIELD("f2", DataTypes.STRING()),
            DataTypes.FIELD("f3", DataTypes.BIGINT())])
class StreamTableAggregateTests(PyFlinkBlinkStreamTableTestCase):
    """Runs the Python aggregate functions defined in this module through the table environment."""

    def test_double_aggregate(self):
        # Two chained aggregations: per-group count/sum first, then a global
        # aggregation over the per-group results.
        self.t_env.register_function("my_count", CountAggregateFunction())
        self.t_env.create_temporary_function("my_sum", SumAggregateFunction())
        configuration = self.t_env.get_config().get_configuration()
        # A small bundle size triggers finish-bundle more often, which exercises the
        # communication between RemoteKeyedStateBackend and the StateGrpcService.
        configuration.set_string("python.fn-execution.bundle.size", "2")
        # A small state cache triggers cache eviction within a bundle.
        configuration.set_string("python.state.cache-size", "1")
        rows = [
            (1, 'Hi', 'Hello'),
            (3, 'Hi', 'hi'),
            (3, 'Hi2', 'hi'),
            (3, 'Hi', 'hi2'),
            (2, 'Hi', 'Hello'),
        ]
        t = self.t_env.from_elements(rows, ['a', 'b', 'c'])
        per_group = t.group_by(t.c).select("my_count(a) as a, my_sum(a) as b, c")
        result = per_group.select(
            "my_count(a) as a, my_sum(b) as b, sum0(b) as c, sum0(b.cast(double)) as d")
        actual = result.to_pandas()
        expected = pd.DataFrame([[3, 12, 12, 12.0]], columns=['a', 'b', 'c', 'd'])
        assert_frame_equal(actual, expected)

    def test_mixed_with_built_in_functions_with_retract(self):
        # The Python "concat" UDAF is evaluated side by side with built-in
        # aggregates over the table named table_with_retract_message.
        self.env.set_parallelism(1)
        self.t_env.create_temporary_system_function("concat", ConcatAggregateFunction())
        rows = [
            (1, 'Hi_', 1),
            (1, 'Hi', 2),
            (2, 'Hi_', 3),
            (2, 'Hi', 4),
            (3, None, None),
            (3, None, None),
            (4, 'hello2_', 7),
            (4, 'hello2', 8),
            (5, 'hello_', 9),
            (5, 'hello', 10),
        ]
        t = self.t_env.from_elements(rows, ['a', 'b', 'c'])
        self.t_env.create_temporary_view("source", t)
        table_with_retract_message = self.t_env.sql_query(
            "select a, LAST_VALUE(b) as b, LAST_VALUE(c) as c from source group by a")
        self.t_env.create_temporary_view("retract_table", table_with_retract_message)
        result_table = self.t_env.sql_query(
            "select concat(b, ',') as a, "
            "FIRST_VALUE(b) as b, "
            "LAST_VALUE(b) as c, "
            "COUNT(c) as d, "
            "COUNT(1) as e, "
            "LISTAGG(b) as f,"
            "LISTAGG(b, '|') as g,"
            "MAX(c) as h,"
            "MAX(cast(c as float) + 1) as i,"
            "MIN(c) as j,"
            "MIN(cast(c as decimal) + 1) as k,"
            "SUM(c) as l,"
            "SUM(cast(c as float) + 1) as m,"
            "AVG(c) as n,"
            "AVG(cast(c as double) + 1) as o,"
            "STDDEV_POP(cast(c as float)),"
            "STDDEV_SAMP(cast(c as float)),"
            "VAR_POP(cast(c as float)),"
            "VAR_SAMP(cast(c as float))"
            " from retract_table")
        result = list(result_table.execute().collect())
        expected = Row(
            'Hi,Hi,hello,hello2', 'Hi', 'hello', 4, 5, 'Hi,Hi,hello2,hello',
            'Hi|Hi|hello2|hello', 10, 11.0, 2, Decimal(3.0), 24, 28.0, 6, 7.0,
            3.1622777, 3.6514838, 10.0, 13.333333)
        expected.set_row_kind(RowKind.UPDATE_AFTER)
        # Only the last collected row is compared.
        self.assertEqual(result[-1], expected)

    def test_mixed_with_built_in_functions_without_retract(self):
        # Same mix of the Python "concat" UDAF and built-in aggregates, but
        # computed directly over the source table.
        self.env.set_parallelism(1)
        self.t_env.create_temporary_system_function("concat", ConcatAggregateFunction())
        rows = [
            ('Hi', 2),
            ('Hi', 4),
            (None, None),
            ('hello2', 8),
            ('hello', 10),
        ]
        t = self.t_env.from_elements(rows, ['b', 'c'])
        self.t_env.create_temporary_view("source", t)
        result_table = self.t_env.sql_query(
            "select concat(b, ',') as a, "
            "FIRST_VALUE(b) as b, "
            "LAST_VALUE(b) as c, "
            "COUNT(c) as d, "
            "COUNT(1) as e, "
            "LISTAGG(b) as f,"
            "LISTAGG(b, '|') as g,"
            "MAX(c) as h,"
            "MAX(cast(c as float) + 1) as i,"
            "MIN(c) as j,"
            "MIN(cast(c as decimal) + 1) as k,"
            "SUM(c) as l,"
            "SUM(cast(c as float) + 1) as m "
            "from source")
        result = list(result_table.execute().collect())
        expected = Row(
            'Hi,Hi,hello,hello2', 'Hi', 'hello', 4, 5, 'Hi,Hi,hello2,hello',
            'Hi|Hi|hello2|hello', 10, 11.0, 2, Decimal(3.0), 24, 28.0)
        expected.set_row_kind(RowKind.UPDATE_AFTER)
        # Only the last collected row is compared.
        self.assertEqual(result[-1], expected)

    def test_using_decorator(self):
        # The types passed to udaf() (INT) differ from the ones the class itself
        # declares (BIGINT); the resulting schema is expected to report INT.
        my_count = udaf(CountAggregateFunction(),
                        accumulator_type=DataTypes.ARRAY(DataTypes.INT()),
                        result_type=DataTypes.INT())
        t = self.t_env.from_elements([(1, 'Hi', 'Hello')], ['a', 'b', 'c'])
        result = t.group_by(t.c).select(my_count(t.a).alias("a"), t.c.alias("b"))
        plan = result.explain()
        result_type = result.get_schema().get_field_data_type(0)
        self.assertIn("PythonGroupAggregate(groupBy=[c], ", plan)
        self.assertEqual(result_type, DataTypes.INT())

    def test_list_view(self):
        my_concat = udaf(ListViewConcatAggregateFunction())
        configuration = self.t_env.get_config().get_configuration()
        configuration.set_string("python.fn-execution.bundle.size", "2")
        # A small state cache triggers cache eviction within a bundle.
        configuration.set_string("python.state.cache-size", "2")
        rows = [
            (1, 'Hi', 'Hello'),
            (3, 'Hi', 'hi'),
            (3, 'Hi2', 'hi'),
            (3, 'Hi', 'hi'),
            (2, 'Hi', 'Hello'),
            (1, 'Hi2', 'Hello'),
            (3, 'Hi3', 'hi'),
            (3, 'Hi2', 'Hello'),
            (3, 'Hi3', 'hi'),
            (2, 'Hi3', 'Hello'),
        ]
        t = self.t_env.from_elements(rows, ['a', 'b', 'c'])
        result = t.group_by(t.c).select(my_concat(t.b, ',').alias("a"), t.c)
        actual = result.to_pandas()
        expected = pd.DataFrame(
            [["Hi,Hi2,Hi,Hi3,Hi3", "hi"],
             ["Hi,Hi,Hi2,Hi2,Hi3", "Hello"]],
            columns=['a', 'c'])
        assert_frame_equal(actual, expected)

    def test_map_view(self):
        my_count = udaf(CountDistinctAggregateFunction())
        table_config = self.t_env.get_config()
        table_config.set_idle_state_retention(datetime.timedelta(days=1))
        configuration = table_config.get_configuration()
        configuration.set_string("python.fn-execution.bundle.size", "2")
        # A small state cache triggers cache eviction within a bundle.
        configuration.set_string("python.state.cache-size", "1")
        configuration.set_string("python.map-state.read-cache-size", "1")
        configuration.set_string("python.map-state.write-cache-size", "1")
        rows = [
            (1, 'Hi_', 'hi'),
            (1, 'Hi', 'hi'),
            (2, 'hello', 'hello'),
            (3, 'Hi_', 'hi'),
            (3, 'Hi', 'hi'),
            (4, 'hello', 'hello'),
            (5, 'Hi2_', 'hi'),
            (5, 'Hi2', 'hi'),
            (6, 'hello2', 'hello'),
            (7, 'Hi', 'hi'),
            (8, 'hello', 'hello'),
            (9, 'Hi2', 'hi'),
            (13, 'Hi3', 'hi'),
        ]
        t = self.t_env.from_elements(rows, ['a', 'b', 'c'])
        self.t_env.create_temporary_view("source", t)
        table_with_retract_message = self.t_env.sql_query(
            "select LAST_VALUE(b) as b, LAST_VALUE(c) as c from source group by a")
        result = table_with_retract_message.group_by(t.c).select(my_count(t.b).alias("a"), t.c)
        actual = result.to_pandas()
        expected = pd.DataFrame([[2, "hello"], [3, "hi"]], columns=['a', 'c'])
        assert_frame_equal(actual, expected)

    def test_data_view_clear(self):
        # The 'clear' input makes CountDistinctAggregateFunction clear its MapView
        # and reset its count, so only the inputs after it are reflected.
        my_count = udaf(CountDistinctAggregateFunction())
        table_config = self.t_env.get_config()
        table_config.set_idle_state_retention(datetime.timedelta(days=1))
        configuration = table_config.get_configuration()
        configuration.set_string("python.fn-execution.bundle.size", "2")
        # A small state cache triggers cache eviction within a bundle.
        configuration.set_string("python.state.cache-size", "1")
        rows = [
            (2, 'hello', 'hello'),
            (4, 'clear', 'hello'),
            (6, 'hello2', 'hello'),
            (8, 'hello', 'hello'),
        ]
        t = self.t_env.from_elements(rows, ['a', 'b', 'c'])
        result = t.group_by(t.c).select(my_count(t.b).alias("a"), t.c)
        actual = result.to_pandas()
        expected = pd.DataFrame([[2, "hello"]], columns=['a', 'c'])
        assert_frame_equal(actual, expected)

    def test_map_view_iterate(self):
        test_iterate = udaf(TestIterateAggregateFunction())
        table_config = self.t_env.get_config()
        table_config.set_idle_state_retention(datetime.timedelta(days=1))
        configuration = table_config.get_configuration()
        configuration.set_string("python.fn-execution.bundle.size", "2")
        # A small state cache triggers cache eviction within a bundle.
        configuration.set_string("python.state.cache-size", "2")
        configuration.set_string("python.map-state.read-cache-size", "2")
        configuration.set_string("python.map-state.write-cache-size", "2")
        configuration.set_string("python.map-state.iterate-response-batch-size", "2")
        rows = [
            (1, 'Hi_', 'hi'),
            (1, 'Hi', 'hi'),
            (2, 'hello', 'hello'),
            (3, 'Hi_', 'hi'),
            (3, 'Hi', 'hi'),
            (4, 'hello', 'hello'),
            (5, 'Hi2_', 'hi'),
            (5, 'Hi2', 'hi'),
            (6, 'hello2', 'hello'),
            (7, 'Hi', 'hi'),
            (8, 'hello', 'hello'),
            (9, 'Hi2', 'hi'),
            (13, 'Hi3', 'hi'),
        ]
        t = self.t_env.from_elements(rows, ['a', 'b', 'c'])
        self.t_env.create_temporary_view("source", t)
        table_with_retract_message = self.t_env.sql_query(
            "select LAST_VALUE(b) as b, LAST_VALUE(c) as c from source group by a")
        aggregated = table_with_retract_message.group_by(t.c) \
            .select(test_iterate(t.b).alias("a"), t.c)
        # Split the ROW result of the UDAF into one column per field.
        result = aggregated.select(
            col("a").get(0).alias("a"),
            col("a").get(1).alias("b"),
            col("a").get(2).alias("c"),
            col("a").get(3).alias("d"),
            t.c.alias("e"))
        actual = result.to_pandas()
        expected = pd.DataFrame(
            [["hello,hello2", "1,3", 'hello:3,hello2:1', 2, "hello"],
             ["Hi,Hi2,Hi3", "1,2,3", "Hi:3,Hi2:2,Hi3:1", 3, "hi"]],
            columns=['a', 'b', 'c', 'd', 'e'])
        assert_frame_equal(actual, expected)

    def test_distinct_and_filter(self):
        # The Python "concat" UDAF is used with DISTINCT and with FILTER clauses.
        self.t_env.create_temporary_system_function("concat", ConcatAggregateFunction())
        rows = [
            (1, 'Hi_', 'hi'),
            (1, 'Hi', 'hi'),
            (2, 'hello', 'hello'),
            (3, 'Hi_', 'hi'),
            (3, 'Hi', 'hi'),
            (4, 'hello', 'hello'),
            (5, 'Hi2_', 'hi'),
            (5, 'Hi2', 'hi'),
            (6, 'hello2', 'hello'),
            (7, 'Hi', 'hi'),
            (8, 'hello', 'hello'),
            (9, 'Hi2', 'hi'),
            (13, 'Hi3', 'hi'),
        ]
        t = self.t_env.from_elements(rows, ['a', 'b', 'c'])
        self.t_env.create_temporary_view("source", t)
        table_with_retract_message = self.t_env.sql_query(
            "select LAST_VALUE(b) as b, LAST_VALUE(c) as c from source group by a")
        self.t_env.create_temporary_view("retract_table", table_with_retract_message)
        result = self.t_env.sql_query(
            "select concat(distinct b, '.') as a, "
            "concat(distinct b, ',') filter (where c = 'hi') as b, "
            "concat(distinct b, ',') filter (where c = 'hello') as c, "
            "c as d "
            "from retract_table group by c")
        # Sort by the first column so the comparison does not depend on row order.
        actual = result.to_pandas().sort_values(by='a').reset_index(drop=True)
        expected = pd.DataFrame(
            [["Hi.Hi2.Hi3", "Hi,Hi2,Hi3", "", "hi"],
             ["hello.hello2", "", "hello,hello2", "hello"]],
            columns=['a', 'b', 'c', 'd'])
        assert_frame_equal(actual, expected)
if __name__ == '__main__':
    import unittest

    # Fall back to unittest's default runner unless xmlrunner is installed, in
    # which case XML reports are written under target/test-reports.
    testRunner = None
    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
    except ImportError:
        pass
    unittest.main(testRunner=testRunner, verbosity=2)