| ################################################################################ |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| ################################################################################ |
| import datetime |
| import decimal |
| import os |
| import uuid |
| |
| from pyflink.common.typeinfo import Types |
| from pyflink.common.watermark_strategy import WatermarkStrategy, TimestampAssigner |
| from pyflink.datastream import StreamExecutionEnvironment, TimeCharacteristic |
| from pyflink.datastream.data_stream import DataStream |
| from pyflink.datastream.functions import FilterFunction, ProcessFunction, KeyedProcessFunction, \ |
| RuntimeContext |
| from pyflink.datastream.functions import KeySelector |
| from pyflink.datastream.functions import MapFunction, FlatMapFunction |
| from pyflink.datastream.functions import CoMapFunction, CoFlatMapFunction |
| from pyflink.datastream.tests.test_util import DataStreamTestSinkFunction |
| from pyflink.java_gateway import get_gateway |
| from pyflink.common import Row |
| from pyflink.testing.test_case_utils import PyFlinkTestCase, invoke_java_object_method |
| |
| |
class DataStreamTests(PyFlinkTestCase):
    """Integration tests for the Python DataStream API.

    Each test builds a small pipeline on a local StreamExecutionEnvironment
    and either inspects the generated execution plan (a JSON string) or runs
    the job and checks the elements collected by DataStreamTestSinkFunction.
    """

    def setUp(self) -> None:
        self.env = StreamExecutionEnvironment.get_execution_environment()
        self.env.set_parallelism(2)
        # Raise the akka ask timeout so mini-cluster based tests stay stable
        # on slow CI machines.
        getConfigurationMethod = invoke_java_object_method(
            self.env._j_stream_execution_environment, "getConfiguration")
        getConfigurationMethod.setString("akka.ask.timeout", "20 s")
        self.test_sink = DataStreamTestSinkFunction()

    def test_data_stream_name(self):
        """name() should be reflected by get_name()."""
        ds = self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')])
        test_name = 'test_name'
        ds.name(test_name)
        self.assertEqual(test_name, ds.get_name())

    def test_set_parallelism(self):
        """Operator-level parallelism should show up in the execution plan."""
        parallelism = 3
        ds = self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')]).map(lambda x: x)
        ds.set_parallelism(parallelism).add_sink(self.test_sink)
        # The plan is a JSON string containing only dict/list/str/int
        # literals, so eval() parses it; json.loads would be the safer choice.
        plan = eval(str(self.env.get_execution_plan()))
        self.assertEqual(parallelism, plan['nodes'][1]['parallelism'])

    def test_set_max_parallelism(self):
        """Operator parallelism overrides the environment-wide default.

        NOTE(review): despite its name this test calls set_parallelism, not
        set_max_parallelism — the JSON execution plan does not expose the max
        parallelism, so only the regular parallelism can be asserted here.
        Confirm whether a real set_max_parallelism check is intended.
        """
        max_parallelism = 4
        self.env.set_parallelism(8)
        ds = self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')]).map(lambda x: x)
        ds.set_parallelism(max_parallelism).add_sink(self.test_sink)
        plan = eval(str(self.env.get_execution_plan()))
        self.assertEqual(max_parallelism, plan['nodes'][1]['parallelism'])

    def test_force_non_parallel(self):
        """force_non_parallel() must pin the operator to parallelism 1."""
        self.env.set_parallelism(8)
        ds = self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')])
        ds.force_non_parallel().add_sink(self.test_sink)
        plan = eval(str(self.env.get_execution_plan()))
        self.assertEqual(1, plan['nodes'][0]['parallelism'])

    def test_reduce_function_without_data_types(self):
        """reduce() on a keyed stream emits one running aggregate per input."""
        ds = self.env.from_collection([(1, 'a'), (2, 'a'), (3, 'a'), (4, 'b')],
                                      type_info=Types.ROW([Types.INT(), Types.STRING()]))
        ds.key_by(lambda a: a[1]) \
          .reduce(lambda a, b: Row(a[0] + b[0], b[1])) \
          .add_sink(self.test_sink)
        self.env.execute('reduce_function_test')
        result = self.test_sink.get_results()
        # Per-key running sums: key 'a' yields 1, 3, 6; key 'b' yields 4.
        expected = ["1,a", "3,a", "6,a", "4,b"]
        expected.sort()
        result.sort()
        self.assertEqual(expected, result)

    def test_map_function_without_data_types(self):
        """map() without an output type falls back to generic Row results."""
        self.env.set_parallelism(1)
        ds = self.env.from_collection([('ab', decimal.Decimal(1)),
                                       ('bdc', decimal.Decimal(2)),
                                       ('cfgs', decimal.Decimal(3)),
                                       ('deeefg', decimal.Decimal(4))],
                                      type_info=Types.ROW([Types.STRING(), Types.BIG_DEC()]))
        ds.map(MyMapFunction()).add_sink(self.test_sink)
        self.env.execute('map_function_test')
        results = self.test_sink.get_results(True)
        expected = ["<Row('ab', 2, Decimal('1'))>", "<Row('bdc', 3, Decimal('2'))>",
                    "<Row('cfgs', 4, Decimal('3'))>", "<Row('deeefg', 6, Decimal('4'))>"]
        expected.sort()
        results.sort()
        self.assertEqual(expected, results)

    def test_map_function_with_data_types(self):
        """map() with an explicit output type serializes results accordingly."""
        ds = self.env.from_collection([('ab', 1), ('bdc', 2), ('cfgs', 3), ('deeefg', 4)],
                                      type_info=Types.TUPLE([Types.STRING(), Types.INT()]))

        def map_func(value):
            result = Row(value[0], len(value[0]), value[1])
            return result

        ds.map(map_func, output_type=Types.ROW([Types.STRING(), Types.INT(), Types.INT()]))\
            .add_sink(self.test_sink)
        self.env.execute('map_function_test')
        results = self.test_sink.get_results(False)
        expected = ['ab,2,1', 'bdc,3,2', 'cfgs,4,3', 'deeefg,6,4']
        expected.sort()
        results.sort()
        self.assertEqual(expected, results)

    def test_co_map_function_without_data_types(self):
        """CoMapFunction without output type maps both connected streams."""
        self.env.set_parallelism(1)
        ds1 = self.env.from_collection([(1, 1), (2, 2), (3, 3)],
                                       type_info=Types.ROW([Types.INT(), Types.INT()]))
        ds2 = self.env.from_collection([("a", "a"), ("b", "b"), ("c", "c")],
                                       type_info=Types.ROW([Types.STRING(), Types.STRING()]))
        ds1.connect(ds2).map(MyCoMapFunction()).add_sink(self.test_sink)
        self.env.execute('co_map_function_test')
        results = self.test_sink.get_results(True)
        expected = ['2', '3', '4', 'a', 'b', 'c']
        expected.sort()
        results.sort()
        self.assertEqual(expected, results)

    def test_connected_streams_with_dependency(self):
        """A Python file registered via add_python_file is importable in UDFs."""
        python_file_dir = os.path.join(self.tempdir, "python_file_dir_" + str(uuid.uuid4()))
        os.mkdir(python_file_dir)
        python_file_path = os.path.join(python_file_dir, "test_stream_dependency_manage_lib.py")
        with open(python_file_path, 'w') as f:
            f.write("def add_two(a):\n    return a + 2")

        class TestCoMapFunction(CoMapFunction):

            def map1(self, value):
                # Resolved against the file shipped through add_python_file.
                from test_stream_dependency_manage_lib import add_two
                return add_two(value)

            def map2(self, value):
                return value + 1

        self.env.add_python_file(python_file_path)
        ds = self.env.from_collection([1, 2, 3, 4, 5])
        ds_1 = ds.map(lambda x: x * 2)
        ds.connect(ds_1).map(TestCoMapFunction()).add_sink(self.test_sink)
        self.env.execute("test co-map add python file")
        result = self.test_sink.get_results(True)
        expected = ['11', '3', '3', '4', '5', '5', '6', '7', '7', '9']
        result.sort()
        expected.sort()
        self.assertEqual(expected, result)

    def test_co_map_function_with_data_types(self):
        """CoMapFunction with an explicit STRING output type."""
        self.env.set_parallelism(1)
        ds1 = self.env.from_collection([(1, 1), (2, 2), (3, 3)],
                                       type_info=Types.ROW([Types.INT(), Types.INT()]))
        ds2 = self.env.from_collection([("a", "a"), ("b", "b"), ("c", "c")],
                                       type_info=Types.ROW([Types.STRING(), Types.STRING()]))
        ds1.connect(ds2).map(MyCoMapFunction(), output_type=Types.STRING()).add_sink(self.test_sink)
        self.env.execute('co_map_function_test')
        results = self.test_sink.get_results(False)
        expected = ['2', '3', '4', 'a', 'b', 'c']
        expected.sort()
        results.sort()
        self.assertEqual(expected, results)

    def test_key_by_on_connect_stream(self):
        """key_by on connected streams routes equal keys to the same subtask."""
        ds1 = self.env.from_collection([('a', 0), ('b', 0), ('c', 1), ('d', 1), ('e', 2)],
                                       type_info=Types.ROW([Types.STRING(), Types.INT()])) \
            .key_by(MyKeySelector(), key_type_info=Types.INT())
        ds2 = self.env.from_collection([('a', 0), ('b', 0), ('c', 1), ('d', 1), ('e', 2)],
                                       type_info=Types.ROW([Types.STRING(), Types.INT()]))

        class AssertKeyCoMapFunction(CoMapFunction):
            # Asserts arrival order inside each key group: with parallelism 2,
            # ('a','b') and ('c','d') land on the same subtasks in order.
            def __init__(self):
                self.pre1 = None
                self.pre2 = None

            def map1(self, value):
                if value[0] == 'b':
                    assert self.pre1 == 'a'
                if value[0] == 'd':
                    assert self.pre1 == 'c'
                self.pre1 = value[0]
                return value

            def map2(self, value):
                if value[0] == 'b':
                    assert self.pre2 == 'a'
                if value[0] == 'd':
                    assert self.pre2 == 'c'
                self.pre2 = value[0]
                return value

        ds1.connect(ds2)\
            .key_by(MyKeySelector(), MyKeySelector(), key_type_info=Types.INT())\
            .map(AssertKeyCoMapFunction())\
            .add_sink(self.test_sink)

        self.env.execute('key_by_test')
        results = self.test_sink.get_results(True)
        expected = ["Row(f0='e', f1=2)", "Row(f0='a', f1=0)", "Row(f0='b', f1=0)",
                    "Row(f0='c', f1=1)", "Row(f0='d', f1=1)", "Row(f0='e', f1=2)",
                    "Row(f0='a', f1=0)", "Row(f0='b', f1=0)", "Row(f0='c', f1=1)",
                    "Row(f0='d', f1=1)"]
        results.sort()
        expected.sort()
        self.assertEqual(expected, results)

    def test_map_function_with_data_types_and_function_object(self):
        """map() accepts a MapFunction instance plus an explicit output type."""
        ds = self.env.from_collection([('ab', 1), ('bdc', 2), ('cfgs', 3), ('deeefg', 4)],
                                      type_info=Types.ROW([Types.STRING(), Types.INT()]))

        ds.map(MyMapFunction(), output_type=Types.ROW([Types.STRING(), Types.INT(), Types.INT()]))\
            .add_sink(self.test_sink)
        self.env.execute('map_function_test')
        results = self.test_sink.get_results(False)
        expected = ['ab,2,1', 'bdc,3,2', 'cfgs,4,3', 'deeefg,6,4']
        expected.sort()
        results.sort()
        self.assertEqual(expected, results)

    def test_flat_map_function(self):
        """flat_map() with a FlatMapFunction instance keeps only even rows."""
        ds = self.env.from_collection([('a', 0), ('ab', 1), ('bdc', 2), ('cfgs', 3), ('deeefg', 4)],
                                      type_info=Types.ROW([Types.STRING(), Types.INT()]))
        ds.flat_map(MyFlatMapFunction(), result_type=Types.ROW([Types.STRING(), Types.INT()]))\
            .add_sink(self.test_sink)

        self.env.execute('flat_map_test')
        results = self.test_sink.get_results(False)
        expected = ['a,0', 'bdc,2', 'deeefg,4']
        results.sort()
        expected.sort()

        self.assertEqual(expected, results)

    def test_flat_map_function_with_function_object(self):
        """flat_map() also accepts a plain generator function."""
        ds = self.env.from_collection([('a', 0), ('ab', 1), ('bdc', 2), ('cfgs', 3), ('deeefg', 4)],
                                      type_info=Types.ROW([Types.STRING(), Types.INT()]))

        def flat_map(value):
            if value[1] % 2 == 0:
                yield value

        ds.flat_map(flat_map, result_type=Types.ROW([Types.STRING(), Types.INT()]))\
            .add_sink(self.test_sink)
        self.env.execute('flat_map_test')
        results = self.test_sink.get_results(False)
        expected = ['a,0', 'bdc,2', 'deeefg,4']
        results.sort()
        expected.sort()
        self.assertEqual(expected, results)

    def test_co_flat_map_function_without_data_types(self):
        """CoFlatMapFunction without output type flat-maps both streams."""
        self.env.set_parallelism(1)
        ds1 = self.env.from_collection([(1, 1), (2, 2), (3, 3)],
                                       type_info=Types.ROW([Types.INT(), Types.INT()]))
        ds2 = self.env.from_collection([("a", "a"), ("b", "b"), ("c", "c")],
                                       type_info=Types.ROW([Types.STRING(), Types.STRING()]))
        ds1.connect(ds2).flat_map(MyCoFlatMapFunction()).add_sink(self.test_sink)
        self.env.execute('co_flat_map_function_test')
        results = self.test_sink.get_results(True)
        expected = ['2', '2', '3', '3', '4', '4', 'b']
        expected.sort()
        results.sort()
        self.assertEqual(expected, results)

    def test_co_flat_map_function_with_data_types(self):
        """CoFlatMapFunction with an explicit STRING output type."""
        self.env.set_parallelism(1)
        ds1 = self.env.from_collection([(1, 1), (2, 2), (3, 3)],
                                       type_info=Types.ROW([Types.INT(), Types.INT()]))
        ds2 = self.env.from_collection([("a", "a"), ("b", "b"), ("c", "c")],
                                       type_info=Types.ROW([Types.STRING(), Types.STRING()]))
        ds1.connect(ds2).flat_map(MyCoFlatMapFunction(), output_type=Types.STRING())\
            .add_sink(self.test_sink)
        self.env.execute('co_flat_map_function_test')
        results = self.test_sink.get_results(False)
        expected = ['2', '2', '3', '3', '4', '4', 'b']
        expected.sort()
        results.sort()
        self.assertEqual(expected, results)

    def test_filter_without_data_types(self):
        """filter() with a FilterFunction instance and no declared type."""
        ds = self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')])
        ds.filter(MyFilterFunction()).add_sink(self.test_sink)
        self.env.execute("test filter")
        results = self.test_sink.get_results(True)
        expected = ["(2, 'Hello', 'Hi')"]
        results.sort()
        expected.sort()
        self.assertEqual(expected, results)

    def test_filter_with_data_types(self):
        """filter() with a lambda on a typed stream."""
        ds = self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')],
                                      type_info=Types.ROW(
                                          [Types.INT(), Types.STRING(), Types.STRING()])
                                      )
        ds.filter(lambda x: x[0] % 2 == 0).add_sink(self.test_sink)
        self.env.execute("test filter")
        results = self.test_sink.get_results(False)
        expected = ['2,Hello,Hi']
        results.sort()
        expected.sort()
        self.assertEqual(expected, results)

    def test_add_sink(self):
        """add_sink() delivers every element of the stream to the sink."""
        ds = self.env.from_collection([('ab', 1), ('bdc', 2), ('cfgs', 3), ('deeefg', 4)],
                                      type_info=Types.ROW([Types.STRING(), Types.INT()]))
        ds.add_sink(self.test_sink)
        self.env.execute("test_add_sink")
        results = self.test_sink.get_results(False)
        expected = ['deeefg,4', 'bdc,2', 'ab,1', 'cfgs,3']
        results.sort()
        expected.sort()
        self.assertEqual(expected, results)

    def test_key_by_map(self):
        """map() after key_by keeps per-key ordering; naming a keyed stream fails."""
        ds = self.env.from_collection([('a', 0), ('b', 0), ('c', 1), ('d', 1), ('e', 2)],
                                      type_info=Types.ROW([Types.STRING(), Types.INT()]))
        keyed_stream = ds.key_by(MyKeySelector(), key_type_info=Types.INT())

        # A keyed stream is not a physical operator, so it cannot be named.
        with self.assertRaises(Exception):
            keyed_stream.name("keyed stream")

        class AssertKeyMapFunction(MapFunction):
            def __init__(self):
                self.pre = None

            def map(self, value):
                if value[0] == 'b':
                    assert self.pre == 'a'
                if value[0] == 'd':
                    assert self.pre == 'c'
                self.pre = value[0]
                return value

        keyed_stream.map(AssertKeyMapFunction()).add_sink(self.test_sink)
        self.env.execute('key_by_test')
        results = self.test_sink.get_results(True)
        expected = ["Row(f0='e', f1=2)", "Row(f0='a', f1=0)", "Row(f0='b', f1=0)",
                    "Row(f0='c', f1=1)", "Row(f0='d', f1=1)"]
        results.sort()
        expected.sort()
        self.assertEqual(expected, results)

    def test_multi_key_by(self):
        """A second key_by re-keys the stream without losing elements."""
        ds = self.env.from_collection([('a', 0), ('b', 0), ('c', 1), ('d', 1), ('e', 2)],
                                      type_info=Types.ROW([Types.STRING(), Types.INT()]))
        ds.key_by(MyKeySelector(), key_type_info=Types.INT()).key_by(lambda x: x[0])\
            .add_sink(self.test_sink)

        self.env.execute("test multi key by")
        results = self.test_sink.get_results(False)
        expected = ['d,1', 'c,1', 'a,0', 'b,0', 'e,2']
        results.sort()
        expected.sort()
        self.assertEqual(expected, results)

    def test_print_without_align_output(self):
        # No need to align output typeinfo since we have specified the type info of the DataStream.
        self.env.set_parallelism(1)
        ds = self.env.from_collection([('ab', 1), ('bdc', 2), ('cfgs', 3), ('deeefg', 4)],
                                      type_info=Types.ROW([Types.STRING(), Types.INT()]))
        ds.print()
        plan = eval(str(self.env.get_execution_plan()))
        self.assertEqual("Sink: Print to Std. Out", plan['nodes'][1]['type'])

    def test_print_with_align_output(self):
        # need to align output type before print, therefore the plan will contain three nodes
        ds = self.env.from_collection([('ab', 1), ('bdc', 2), ('cfgs', 3), ('deeefg', 4)])
        ds.print()
        plan = eval(str(self.env.get_execution_plan()))
        self.assertEqual(3, len(plan['nodes']))
        self.assertEqual("Sink: Print to Std. Out", plan['nodes'][2]['type'])

    def test_union_stream(self):
        """union() makes every source a predecessor of the downstream operator."""
        ds_1 = self.env.from_collection([1, 2, 3])
        ds_2 = self.env.from_collection([4, 5, 6])
        ds_3 = self.env.from_collection([7, 8, 9])

        united_stream = ds_3.union(ds_1, ds_2)

        united_stream.map(lambda x: x + 1).add_sink(self.test_sink)
        exec_plan = eval(self.env.get_execution_plan())
        source_ids = []
        union_node_pre_ids = []
        for node in exec_plan['nodes']:
            if node['pact'] == 'Data Source':
                source_ids.append(node['id'])
            if node['pact'] == 'Operator':
                for pre in node['predecessors']:
                    union_node_pre_ids.append(pre['id'])

        source_ids.sort()
        union_node_pre_ids.sort()
        self.assertEqual(source_ids, union_node_pre_ids)

    def test_project(self):
        """project() produces a dedicated Projection node in the plan."""
        ds = self.env.from_collection([[1, 2, 3, 4], [5, 6, 7, 8]],
                                      type_info=Types.TUPLE(
                                          [Types.INT(), Types.INT(), Types.INT(), Types.INT()]))
        ds.project(1, 3).map(lambda x: (x[0], x[1] + 1)).add_sink(self.test_sink)
        exec_plan = eval(self.env.get_execution_plan())
        self.assertEqual(exec_plan['nodes'][1]['type'], 'Projection')

    def test_broadcast(self):
        """broadcast() sets the ship strategy of the downstream edge."""
        ds_1 = self.env.from_collection([1, 2, 3])
        ds_1.broadcast().map(lambda x: x + 1).set_parallelism(3).add_sink(self.test_sink)
        exec_plan = eval(self.env.get_execution_plan())
        broadcast_node = exec_plan['nodes'][1]
        pre_ship_strategy = broadcast_node['predecessors'][0]['ship_strategy']
        self.assertEqual(pre_ship_strategy, 'BROADCAST')

    def test_rebalance(self):
        """rebalance() sets the ship strategy of the downstream edge."""
        ds_1 = self.env.from_collection([1, 2, 3])
        ds_1.rebalance().map(lambda x: x + 1).set_parallelism(3).add_sink(self.test_sink)
        exec_plan = eval(self.env.get_execution_plan())
        rebalance_node = exec_plan['nodes'][1]
        pre_ship_strategy = rebalance_node['predecessors'][0]['ship_strategy']
        self.assertEqual(pre_ship_strategy, 'REBALANCE')

    def test_rescale(self):
        """rescale() sets the ship strategy of the downstream edge."""
        ds_1 = self.env.from_collection([1, 2, 3])
        ds_1.rescale().map(lambda x: x + 1).set_parallelism(3).add_sink(self.test_sink)
        exec_plan = eval(self.env.get_execution_plan())
        rescale_node = exec_plan['nodes'][1]
        pre_ship_strategy = rescale_node['predecessors'][0]['ship_strategy']
        self.assertEqual(pre_ship_strategy, 'RESCALE')

    def test_shuffle(self):
        """shuffle() sets the ship strategy of the downstream edge."""
        ds_1 = self.env.from_collection([1, 2, 3])
        ds_1.shuffle().map(lambda x: x + 1).set_parallelism(3).add_sink(self.test_sink)
        exec_plan = eval(self.env.get_execution_plan())
        shuffle_node = exec_plan['nodes'][1]
        pre_ship_strategy = shuffle_node['predecessors'][0]['ship_strategy']
        self.assertEqual(pre_ship_strategy, 'SHUFFLE')

    def test_partition_custom(self):
        """partition_custom() feeds the user partitioner the downstream parallelism."""
        ds = self.env.from_collection([('a', 0), ('b', 0), ('c', 1), ('d', 1), ('e', 2),
                                       ('f', 7), ('g', 7), ('h', 8), ('i', 8), ('j', 9)],
                                      type_info=Types.ROW([Types.STRING(), Types.INT()]))

        expected_num_partitions = 5

        def my_partitioner(key, num_partitions):
            # The downstream operator runs with parallelism 5, so the custom
            # partitioner must be offered exactly that many partitions.
            # (The original `assert expected_num_partitions, num_partitions`
            # was a no-op: it asserted a truthy constant with a message.)
            assert num_partitions == expected_num_partitions
            return key % num_partitions

        partitioned_stream = ds.map(lambda x: x, output_type=Types.ROW([Types.STRING(),
                                                                        Types.INT()]))\
            .set_parallelism(4).partition_custom(my_partitioner, lambda x: x[1])

        JPartitionCustomTestMapFunction = get_gateway().jvm\
            .org.apache.flink.python.util.PartitionCustomTestMapFunction
        test_map_stream = DataStream(partitioned_stream
                                     ._j_data_stream.map(JPartitionCustomTestMapFunction()))
        test_map_stream.set_parallelism(expected_num_partitions).add_sink(self.test_sink)

        # The Java-side map function verifies the actual partitioning.
        self.env.execute('test_partition_custom')

    def test_keyed_stream_partitioning(self):
        """Explicit repartitioning operations are forbidden on keyed streams."""
        ds = self.env.from_collection([('ab', 1), ('bdc', 2), ('cfgs', 3), ('deeefg', 4)])
        keyed_stream = ds.key_by(lambda x: x[1])
        with self.assertRaises(Exception):
            keyed_stream.shuffle()

        with self.assertRaises(Exception):
            keyed_stream.rebalance()

        with self.assertRaises(Exception):
            keyed_stream.rescale()

        with self.assertRaises(Exception):
            keyed_stream.broadcast()

        with self.assertRaises(Exception):
            keyed_stream.forward()

    def test_slot_sharing_group(self):
        """slot_sharing_group() is propagated to the generated stream nodes."""
        source_operator_name = 'collection source'
        map_operator_name = 'map_operator'
        slot_sharing_group_1 = 'slot_sharing_group_1'
        slot_sharing_group_2 = 'slot_sharing_group_2'
        ds_1 = self.env.from_collection([1, 2, 3]).name(source_operator_name)
        ds_1.slot_sharing_group(slot_sharing_group_1).map(lambda x: x + 1).set_parallelism(3)\
            .name(map_operator_name).slot_sharing_group(slot_sharing_group_2)\
            .add_sink(self.test_sink)

        j_generated_stream_graph = self.env._j_stream_execution_environment \
            .getStreamGraph("test start new_chain", True)

        j_stream_nodes = list(j_generated_stream_graph.getStreamNodes().toArray())
        for j_stream_node in j_stream_nodes:
            if j_stream_node.getOperatorName() == source_operator_name:
                self.assertEqual(j_stream_node.getSlotSharingGroup(), slot_sharing_group_1)
            elif j_stream_node.getOperatorName() == map_operator_name:
                self.assertEqual(j_stream_node.getSlotSharingGroup(), slot_sharing_group_2)

    def test_chaining_strategy(self):
        """start_new_chain()/disable_chaining() control operator chainability."""
        chained_operator_name_0 = "map_operator_0"
        chained_operator_name_1 = "map_operator_1"
        chained_operator_name_2 = "map_operator_2"

        ds = self.env.from_collection([1, 2, 3])
        ds.map(lambda x: x).set_parallelism(2).name(chained_operator_name_0)\
            .map(lambda x: x).set_parallelism(2).name(chained_operator_name_1)\
            .map(lambda x: x).set_parallelism(2).name(chained_operator_name_2)\
            .add_sink(self.test_sink)

        def assert_chainable(j_stream_graph, expected_upstream_chainable,
                             expected_downstream_chainable):
            # Checks map_operator_1's in/out edges against the Java-side
            # chainability decision.
            j_stream_nodes = list(j_stream_graph.getStreamNodes().toArray())
            for j_stream_node in j_stream_nodes:
                if j_stream_node.getOperatorName() == chained_operator_name_1:
                    JStreamingJobGraphGenerator = get_gateway().jvm \
                        .org.apache.flink.streaming.api.graph.StreamingJobGraphGenerator

                    j_in_stream_edge = j_stream_node.getInEdges().get(0)
                    upstream_chainable = JStreamingJobGraphGenerator.isChainable(j_in_stream_edge,
                                                                                 j_stream_graph)
                    self.assertEqual(expected_upstream_chainable, upstream_chainable)

                    j_out_stream_edge = j_stream_node.getOutEdges().get(0)
                    downstream_chainable = JStreamingJobGraphGenerator.isChainable(
                        j_out_stream_edge, j_stream_graph)
                    self.assertEqual(expected_downstream_chainable, downstream_chainable)

        # The map_operator_1 has the same parallelism with map_operator_0 and map_operator_2, and
        # ship_strategy for map_operator_0 and map_operator_1 is FORWARD, so the map_operator_1
        # can be chained with map_operator_0 and map_operator_2.
        j_generated_stream_graph = self.env._j_stream_execution_environment\
            .getStreamGraph("test start new_chain", True)
        assert_chainable(j_generated_stream_graph, True, True)

        ds = self.env.from_collection([1, 2, 3])
        # Start a new chain for map_operator_1
        ds.map(lambda x: x).set_parallelism(2).name(chained_operator_name_0) \
            .map(lambda x: x).set_parallelism(2).name(chained_operator_name_1).start_new_chain() \
            .map(lambda x: x).set_parallelism(2).name(chained_operator_name_2) \
            .add_sink(self.test_sink)

        j_generated_stream_graph = self.env._j_stream_execution_environment \
            .getStreamGraph("test start new_chain", True)
        # We start a new chain for map operator, therefore, it cannot be chained with upstream
        # operator, but can be chained with downstream operator.
        assert_chainable(j_generated_stream_graph, False, True)

        ds = self.env.from_collection([1, 2, 3])
        # Disable chaining for map_operator_1
        ds.map(lambda x: x).set_parallelism(2).name(chained_operator_name_0) \
            .map(lambda x: x).set_parallelism(2).name(chained_operator_name_1).disable_chaining() \
            .map(lambda x: x).set_parallelism(2).name(chained_operator_name_2) \
            .add_sink(self.test_sink)

        j_generated_stream_graph = self.env._j_stream_execution_environment \
            .getStreamGraph("test start new_chain", True)
        # We disable chaining for map_operator_1, therefore, it cannot be chained with
        # upstream and downstream operators.
        assert_chainable(j_generated_stream_graph, False, False)

    def test_primitive_array_type_info(self):
        """PRIMITIVE_ARRAY round-trips float arrays through serialization."""
        ds = self.env.from_collection([(1, [1.1, 1.2, 1.30]), (2, [2.1, 2.2, 2.3]),
                                       (3, [3.1, 3.2, 3.3])],
                                      type_info=Types.ROW([Types.INT(),
                                                           Types.PRIMITIVE_ARRAY(Types.FLOAT())]))

        ds.map(lambda x: x, output_type=Types.ROW([Types.INT(),
                                                   Types.PRIMITIVE_ARRAY(Types.FLOAT())]))\
            .add_sink(self.test_sink)
        self.env.execute("test primitive array type info")
        results = self.test_sink.get_results()
        expected = ['1,[1.1, 1.2, 1.3]', '2,[2.1, 2.2, 2.3]', '3,[3.1, 3.2, 3.3]']
        results.sort()
        expected.sort()
        self.assertEqual(expected, results)

    def test_basic_array_type_info(self):
        """BASIC_ARRAY supports null elements in float and string arrays."""
        ds = self.env.from_collection([(1, [1.1, None, 1.30], [None, 'hi', 'flink']),
                                       (2, [None, 2.2, 2.3], ['hello', None, 'flink']),
                                       (3, [3.1, 3.2, None], ['hello', 'hi', None])],
                                      type_info=Types.ROW([Types.INT(),
                                                           Types.BASIC_ARRAY(Types.FLOAT()),
                                                           Types.BASIC_ARRAY(Types.STRING())]))

        ds.map(lambda x: x, output_type=Types.ROW([Types.INT(),
                                                   Types.BASIC_ARRAY(Types.FLOAT()),
                                                   Types.BASIC_ARRAY(Types.STRING())]))\
            .add_sink(self.test_sink)
        self.env.execute("test basic array type info")
        results = self.test_sink.get_results()
        expected = ['1,[1.1, null, 1.3],[null, hi, flink]',
                    '2,[null, 2.2, 2.3],[hello, null, flink]',
                    '3,[3.1, 3.2, null],[hello, hi, null]']
        results.sort()
        expected.sort()
        self.assertEqual(expected, results)

    def test_sql_timestamp_type_info(self):
        """SQL_DATE / SQL_TIME / SQL_TIMESTAMP round-trip Python date/time values."""
        ds = self.env.from_collection([(datetime.date(2021, 1, 9),
                                        datetime.time(12, 0, 0),
                                        datetime.datetime(2021, 1, 9, 12, 0, 0, 11000))],
                                      type_info=Types.ROW([Types.SQL_DATE(),
                                                           Types.SQL_TIME(),
                                                           Types.SQL_TIMESTAMP()]))

        ds.map(lambda x: x, output_type=Types.ROW([Types.SQL_DATE(),
                                                   Types.SQL_TIME(),
                                                   Types.SQL_TIMESTAMP()]))\
            .add_sink(self.test_sink)
        self.env.execute("test sql timestamp type info")
        results = self.test_sink.get_results()
        expected = ['2021-01-09,12:00:00,2021-01-09 12:00:00.011']
        self.assertEqual(expected, results)

    def test_timestamp_assigner_and_watermark_strategy(self):
        """Event-time timestamps, watermarks and timers in a KeyedProcessFunction."""
        self.env.set_parallelism(1)
        self.env.get_config().set_auto_watermark_interval(2000)
        self.env.set_stream_time_characteristic(TimeCharacteristic.EventTime)
        data_stream = self.env.from_collection([(1, '1603708211000'),
                                                (2, '1603708224000'),
                                                (3, '1603708226000'),
                                                (4, '1603708289000')],
                                               type_info=Types.ROW([Types.INT(), Types.STRING()]))

        class MyTimestampAssigner(TimestampAssigner):

            def extract_timestamp(self, value, record_timestamp) -> int:
                # The second field carries the epoch-millis timestamp.
                return int(value[1])

        class MyProcessFunction(KeyedProcessFunction):

            def __init__(self):
                self.timer_registered = False

            def open(self, runtime_context: RuntimeContext):
                self.timer_registered = False

            def process_element(self, value, ctx):
                if not self.timer_registered:
                    ctx.timer_service().register_event_time_timer(3)
                    self.timer_registered = True
                current_timestamp = ctx.timestamp()
                current_watermark = ctx.timer_service().current_watermark()
                current_key = ctx.get_current_key()
                yield "current key: {}, current timestamp: {}, current watermark: {}, " \
                      "current_value: {}".format(str(current_key), str(current_timestamp),
                                                 str(current_watermark), str(value))

            def on_timer(self, timestamp, ctx):
                yield "on timer: " + str(timestamp)

        watermark_strategy = WatermarkStrategy.for_monotonous_timestamps()\
            .with_timestamp_assigner(MyTimestampAssigner())
        data_stream.assign_timestamps_and_watermarks(watermark_strategy)\
            .key_by(lambda x: x[0], key_type_info=Types.INT()) \
            .process(MyProcessFunction(), output_type=Types.STRING()).add_sink(self.test_sink)
        self.env.execute('test time stamp assigner with keyed process function')
        result = self.test_sink.get_results()
        # Because the watermark interval is too long, no watermark was sent before processing these
        # data. So all current watermarks are Long.MIN_VALUE.
        expected_result = ["current key: 1, current timestamp: 1603708211000, current watermark: "
                           "-9223372036854775808, current_value: Row(f0=1, f1='1603708211000')",
                           "current key: 2, current timestamp: 1603708224000, current watermark: "
                           "-9223372036854775808, current_value: Row(f0=2, f1='1603708224000')",
                           "current key: 3, current timestamp: 1603708226000, current watermark: "
                           "-9223372036854775808, current_value: Row(f0=3, f1='1603708226000')",
                           "current key: 4, current timestamp: 1603708289000, current watermark: "
                           "-9223372036854775808, current_value: Row(f0=4, f1='1603708289000')",
                           "on timer: 3"]
        result.sort()
        expected_result.sort()
        self.assertEqual(expected_result, result)

    def test_process_function(self):
        """Event-time timestamps and watermarks in a non-keyed ProcessFunction."""
        self.env.set_parallelism(1)
        self.env.get_config().set_auto_watermark_interval(2000)
        self.env.set_stream_time_characteristic(TimeCharacteristic.EventTime)
        data_stream = self.env.from_collection([(1, '1603708211000'),
                                                (2, '1603708224000'),
                                                (3, '1603708226000'),
                                                (4, '1603708289000')],
                                               type_info=Types.ROW([Types.INT(), Types.STRING()]))

        class MyTimestampAssigner(TimestampAssigner):

            def extract_timestamp(self, value, record_timestamp) -> int:
                return int(value[1])

        class MyProcessFunction(ProcessFunction):

            def process_element(self, value, ctx):
                current_timestamp = ctx.timestamp()
                current_watermark = ctx.timer_service().current_watermark()
                yield "current timestamp: {}, current watermark: {}, current_value: {}"\
                    .format(str(current_timestamp), str(current_watermark), str(value))

            def on_timer(self, timestamp, ctx, out):
                pass

        watermark_strategy = WatermarkStrategy.for_monotonous_timestamps()\
            .with_timestamp_assigner(MyTimestampAssigner())
        data_stream.assign_timestamps_and_watermarks(watermark_strategy)\
            .process(MyProcessFunction(), output_type=Types.STRING()).add_sink(self.test_sink)
        self.env.execute('test process function')
        result = self.test_sink.get_results()
        expected_result = ["current timestamp: 1603708211000, current watermark: "
                           "-9223372036854775808, current_value: Row(f0=1, f1='1603708211000')",
                           "current timestamp: 1603708224000, current watermark: "
                           "-9223372036854775808, current_value: Row(f0=2, f1='1603708224000')",
                           "current timestamp: 1603708226000, current watermark: "
                           "-9223372036854775808, current_value: Row(f0=3, f1='1603708226000')",
                           "current timestamp: 1603708289000, current watermark: "
                           "-9223372036854775808, current_value: Row(f0=4, f1='1603708289000')"]
        result.sort()
        expected_result.sort()
        self.assertEqual(expected_result, result)

    def test_function_with_error(self):
        """An exception raised inside a UDF surfaces as a Py4JJavaError."""
        ds = self.env.from_collection([('a', 0), ('b', 0), ('c', 1), ('d', 1), ('e', 1)],
                                      type_info=Types.ROW([Types.STRING(), Types.INT()]))
        keyed_stream = ds.key_by(MyKeySelector())

        def flat_map_func(x):
            # The trailing yield makes this a generator; the ValueError is
            # raised when the runtime iterates it during execution.
            raise ValueError('flat_map_func error')
            yield x

        from py4j.protocol import Py4JJavaError
        import pytest
        with pytest.raises(Py4JJavaError, match="flat_map_func error"):
            keyed_stream.flat_map(flat_map_func).print()
            self.env.execute("test_process_function_with_error")

    def tearDown(self) -> None:
        # Drop the elements collected by the shared sink between tests.
        self.test_sink.clear()
| |
| |
class MyMapFunction(MapFunction):
    """Expand a (name, num) record into a (name, len(name), num) Row."""

    def map(self, value):
        name = value[0]
        return Row(name, len(name), value[1])
| |
| |
class MyFlatMapFunction(FlatMapFunction):
    """Forward only the records whose second field is an even number."""

    def flat_map(self, value):
        is_even = value[1] % 2 == 0
        if is_even:
            yield value
| |
| |
class MyKeySelector(KeySelector):
    """Key each record by its second field (index 1)."""

    def get_key(self, value):
        return value[1]
| |
| |
class MyFilterFunction(FilterFunction):
    """Keep only the records whose first field is an even number."""

    def filter(self, value):
        remainder = value[0] % 2
        return remainder == 0
| |
| |
class MyCoMapFunction(CoMapFunction):
    """Map stream 1 to the stringified successor of its first field;
    map stream 2 to its first field unchanged."""

    def map1(self, value):
        successor = value[0] + 1
        return str(successor)

    def map2(self, value):
        return value[0]
| |
| |
class MyCoFlatMapFunction(CoFlatMapFunction):
    """Emit each stream-1 record's stringified successor twice; from stream 2
    pass through only records whose first field equals 'b'."""

    def flat_map1(self, value):
        successor = str(value[0] + 1)
        for _ in range(2):
            yield successor

    def flat_map2(self, value):
        if value[0] == 'b':
            yield value[0]