################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import datetime
import decimal
import glob
import json
import os
import shutil
import tempfile
import time
import unittest
import uuid
from pyflink.common import Configuration, ExecutionConfig, RestartStrategies
from pyflink.common.serialization import JsonRowDeserializationSchema
from pyflink.common.typeinfo import Types
from pyflink.datastream import (StreamExecutionEnvironment, CheckpointConfig,
                                CheckpointingMode, MemoryStateBackend, TimeCharacteristic,
                                SlotSharingGroup)
from pyflink.datastream.connectors import FlinkKafkaConsumer
from pyflink.datastream.execution_mode import RuntimeExecutionMode
from pyflink.datastream.functions import SourceFunction
from pyflink.datastream.slot_sharing_group import MemorySize
from pyflink.datastream.tests.test_util import DataStreamTestSinkFunction
from pyflink.find_flink_home import _find_flink_source_root
from pyflink.java_gateway import get_gateway
from pyflink.pyflink_gateway_server import on_windows
from pyflink.table import DataTypes, CsvTableSource, CsvTableSink, StreamTableEnvironment, \
    EnvironmentSettings
from pyflink.testing.test_case_utils import PyFlinkTestCase, exec_insert_table
from pyflink.util.java_utils import get_j_env_configuration


class StreamExecutionEnvironmentTests(PyFlinkTestCase):

    def setUp(self):
        self.env = self.create_new_env()
        self.test_sink = DataStreamTestSinkFunction()

    @staticmethod
    def create_new_env(execution_mode='process'):
        env = StreamExecutionEnvironment.get_execution_environment()
        env.set_parallelism(2)
        env._execution_mode = execution_mode
        return env

    def test_get_config(self):
        execution_config = self.env.get_config()
        self.assertIsInstance(execution_config, ExecutionConfig)

    def test_get_set_parallelism(self):
        self.env.set_parallelism(10)
        parallelism = self.env.get_parallelism()
        self.assertEqual(parallelism, 10)

    def test_get_set_buffer_timeout(self):
        self.env.set_buffer_timeout(12000)
        timeout = self.env.get_buffer_timeout()
        self.assertEqual(timeout, 12000)

    def test_get_set_default_local_parallelism(self):
        self.env.set_default_local_parallelism(8)
        parallelism = self.env.get_default_local_parallelism()
        self.assertEqual(parallelism, 8)

    def test_set_get_restart_strategy(self):
        self.env.set_restart_strategy(RestartStrategies.no_restart())
        restart_strategy = self.env.get_restart_strategy()
        self.assertEqual(restart_strategy, RestartStrategies.no_restart())

    def test_add_default_kryo_serializer(self):
        self.env.add_default_kryo_serializer(
            "org.apache.flink.runtime.state.StateBackendTestBase$TestPojo",
            "org.apache.flink.runtime.state.StateBackendTestBase$CustomKryoTestSerializer")
        class_dict = self.env.get_config().get_default_kryo_serializer_classes()
        self.assertEqual(class_dict,
                         {'org.apache.flink.runtime.state.StateBackendTestBase$TestPojo':
                          'org.apache.flink.runtime.state'
                          '.StateBackendTestBase$CustomKryoTestSerializer'})

    def test_register_type_with_kryo_serializer(self):
        self.env.register_type_with_kryo_serializer(
            "org.apache.flink.runtime.state.StateBackendTestBase$TestPojo",
            "org.apache.flink.runtime.state.StateBackendTestBase$CustomKryoTestSerializer")
        class_dict = self.env.get_config().get_registered_types_with_kryo_serializer_classes()
        self.assertEqual(class_dict,
                         {'org.apache.flink.runtime.state.StateBackendTestBase$TestPojo':
                          'org.apache.flink.runtime.state'
                          '.StateBackendTestBase$CustomKryoTestSerializer'})

    def test_register_type(self):
        self.env.register_type("org.apache.flink.runtime.state.StateBackendTestBase$TestPojo")
        type_list = self.env.get_config().get_registered_pojo_types()
        self.assertEqual(type_list,
                         ['org.apache.flink.runtime.state.StateBackendTestBase$TestPojo'])

    def test_get_set_max_parallelism(self):
        self.env.set_max_parallelism(12)
        parallelism = self.env.get_max_parallelism()
        self.assertEqual(parallelism, 12)

    def test_set_runtime_mode(self):
        self.env.set_runtime_mode(RuntimeExecutionMode.BATCH)
        config = get_j_env_configuration(self.env._j_stream_execution_environment)
        runtime_mode = config.getValue(
            get_gateway().jvm.org.apache.flink.configuration.ExecutionOptions.RUNTIME_MODE)
        self.assertEqual(runtime_mode, "BATCH")

    def test_operation_chaining(self):
        self.assertTrue(self.env.is_chaining_enabled())
        self.env.disable_operator_chaining()
        self.assertFalse(self.env.is_chaining_enabled())

    def test_get_checkpoint_config(self):
        checkpoint_config = self.env.get_checkpoint_config()
        self.assertIsInstance(checkpoint_config, CheckpointConfig)

    def test_get_set_checkpoint_interval(self):
        self.env.enable_checkpointing(30000)
        interval = self.env.get_checkpoint_interval()
        self.assertEqual(interval, 30000)

    def test_get_set_checkpointing_mode(self):
        mode = self.env.get_checkpointing_mode()
        self.assertEqual(mode, CheckpointingMode.EXACTLY_ONCE)
        self.env.enable_checkpointing(30000, CheckpointingMode.AT_LEAST_ONCE)
        mode = self.env.get_checkpointing_mode()
        self.assertEqual(mode, CheckpointingMode.AT_LEAST_ONCE)

    def test_get_state_backend(self):
        state_backend = self.env.get_state_backend()
        self.assertIsNone(state_backend)

    def test_set_state_backend(self):
        input_backend = MemoryStateBackend()
        self.env.set_state_backend(input_backend)
        output_backend = self.env.get_state_backend()
        self.assertEqual(output_backend._j_memory_state_backend,
                         input_backend._j_memory_state_backend)

    def test_get_set_stream_time_characteristic(self):
        default_time_characteristic = self.env.get_stream_time_characteristic()
        self.assertEqual(default_time_characteristic, TimeCharacteristic.EventTime)
        self.env.set_stream_time_characteristic(TimeCharacteristic.ProcessingTime)
        time_characteristic = self.env.get_stream_time_characteristic()
        self.assertEqual(time_characteristic, TimeCharacteristic.ProcessingTime)

    def test_configure(self):
        configuration = Configuration()
        configuration.set_string('pipeline.operator-chaining', 'false')
        configuration.set_string('pipeline.time-characteristic', 'IngestionTime')
        configuration.set_string('execution.buffer-timeout', '1 min')
        configuration.set_string('execution.checkpointing.timeout', '12000')
        configuration.set_string('state.backend', 'jobmanager')
        self.env.configure(configuration)
        self.assertEqual(self.env.is_chaining_enabled(), False)
        self.assertEqual(self.env.get_stream_time_characteristic(),
                         TimeCharacteristic.IngestionTime)
        self.assertEqual(self.env.get_buffer_timeout(), 60000)
        self.assertEqual(self.env.get_checkpoint_config().get_checkpoint_timeout(), 12000)
        self.assertTrue(isinstance(self.env.get_state_backend(), MemoryStateBackend))

    @unittest.skip("The Python API does not support DataStream yet. Refactor this test later.")
    def test_get_execution_plan(self):
        tmp_dir = tempfile.gettempdir()
        source_path = os.path.join(tmp_dir, 'streaming.csv')
        tmp_csv = os.path.join(tmp_dir, 'streaming2.csv')
        field_names = ["a", "b", "c"]
        field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
        t_env = StreamTableEnvironment.create(self.env)
        csv_source = CsvTableSource(source_path, field_names, field_types)
        t_env.register_table_source("Orders", csv_source)
        t_env.register_table_sink(
            "Results",
            CsvTableSink(field_names, field_types, tmp_csv))
        t_env.from_path("Orders").execute_insert("Results").wait()

        plan = self.env.get_execution_plan()
        json.loads(plan)

    def test_execute(self):
        tmp_dir = tempfile.gettempdir()
        field_names = ['a', 'b', 'c']
        field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING()]
        t_env = StreamTableEnvironment.create(self.env)
        t_env.register_table_sink(
            'Results',
            CsvTableSink(field_names, field_types,
                         os.path.join(tmp_dir, '{}.csv'.format(round(time.time())))))
        execution_result = exec_insert_table(
            t_env.from_elements([(1, 'Hi', 'Hello')], ['a', 'b', 'c']),
            'Results')
        self.assertIsNotNone(execution_result.get_job_id())
        self.assertIsNotNone(execution_result.get_net_runtime())
        self.assertEqual(len(execution_result.get_all_accumulator_results()), 0)
        self.assertIsNone(execution_result.get_accumulator_result('accumulator'))
        self.assertIsNotNone(str(execution_result))

    def test_from_collection_without_data_types(self):
        ds = self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')])
        ds.add_sink(self.test_sink)
        self.env.execute("test from collection")
        results = self.test_sink.get_results(True)
        # Since the user did not specify data types for the input data, the collected results
        # should be in the same tuple format as the inputs.
        expected = ["(1, 'Hi', 'Hello')", "(2, 'Hello', 'Hi')"]
        results.sort()
        expected.sort()
        self.assertEqual(expected, results)

    def test_from_collection_with_data_types(self):
        # Verify from_collection for a collection of single objects.
        ds = self.env.from_collection(['Hi', 'Hello'], type_info=Types.STRING())
        ds.add_sink(self.test_sink)
        self.env.execute("test from collection with single object")
        results = self.test_sink.get_results(False)
        expected = ['Hello', 'Hi']
        results.sort()
        expected.sort()
        self.assertEqual(expected, results)

        # Verify from_collection for a collection of composite objects such as tuples.
        ds = self.env.from_collection([(1, None, 1, True, 32767, -2147483648, 1.23, 1.98932,
                                        bytearray(b'flink'), 'pyflink', datetime.date(2014, 9, 13),
                                        datetime.time(hour=12, minute=0, second=0,
                                                      microsecond=123000),
                                        datetime.datetime(2018, 3, 11, 3, 0, 0, 123000), [1, 2, 3],
                                        decimal.Decimal('1000000000000000000.05'),
                                        decimal.Decimal('1000000000000000000.0599999999999'
                                                        '9999899999999999')),
                                       (2, None, 2, True, 43878, 9147483648, 9.87, 2.98936,
                                        bytearray(b'flink'), 'pyflink', datetime.date(2015, 10, 14),
                                        datetime.time(hour=11, minute=2, second=2,
                                                      microsecond=234500),
                                        datetime.datetime(2020, 4, 15, 8, 2, 6, 235000), [2, 4, 6],
                                        decimal.Decimal('2000000000000000000.74'),
                                        decimal.Decimal('2000000000000000000.061111111111111'
                                                        '11111111111111'))],
                                      type_info=Types.ROW(
                                          [Types.LONG(), Types.LONG(), Types.SHORT(),
                                           Types.BOOLEAN(), Types.SHORT(), Types.INT(),
                                           Types.FLOAT(), Types.DOUBLE(),
                                           Types.PICKLED_BYTE_ARRAY(),
                                           Types.STRING(), Types.SQL_DATE(), Types.SQL_TIME(),
                                           Types.SQL_TIMESTAMP(),
                                           Types.BASIC_ARRAY(Types.LONG()), Types.BIG_DEC(),
                                           Types.BIG_DEC()]))
        ds.add_sink(self.test_sink)
        self.env.execute("test from collection with tuple object")
        results = self.test_sink.get_results(False)
        # If the user specifies data types for the input data, the collected results should be
        # in row format.
        expected = [
            '+I[1, null, 1, true, 32767, -2147483648, 1.23, 1.98932, [102, 108, 105, 110, 107], '
            'pyflink, 2014-09-13, 12:00:00, 2018-03-11 03:00:00.123, [1, 2, 3], '
            '1000000000000000000.05, 1000000000000000000.05999999999999999899999999999]',
            '+I[2, null, 2, true, -21658, 557549056, 9.87, 2.98936, [102, 108, 105, 110, 107], '
            'pyflink, 2015-10-14, 11:02:02, 2020-04-15 08:02:06.235, [2, 4, 6], '
            '2000000000000000000.74, 2000000000000000000.06111111111111111111111111111]']
        results.sort()
        expected.sort()
        self.assertEqual(expected, results)

    def test_add_custom_source(self):
        custom_source = SourceFunction("org.apache.flink.python.util.MyCustomSourceFunction")
        ds = self.env.add_source(custom_source, type_info=Types.ROW([Types.INT(), Types.STRING()]))
        ds.add_sink(self.test_sink)
        self.env.execute("test add custom source")
        results = self.test_sink.get_results(False)
        expected = [
            '+I[3, Mike]',
            '+I[1, Marry]',
            '+I[4, Ted]',
            '+I[5, Jack]',
            '+I[0, Bob]',
            '+I[2, Henry]']
        results.sort()
        expected.sort()
        self.assertEqual(expected, results)

    def test_read_text_file(self):
        texts = ["Mike", "Marry", "Ted", "Jack", "Bob", "Henry"]
        text_file_path = self.tempdir + '/text_file'
        with open(text_file_path, 'a') as f:
            for text in texts:
                f.write(text)
                f.write('\n')

        ds = self.env.read_text_file(text_file_path)
        ds.add_sink(self.test_sink)
        self.env.execute("test read text file")
        results = self.test_sink.get_results()
        results.sort()
        texts.sort()
        self.assertEqual(texts, results)

    def test_execute_async(self):
        ds = self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')],
                                      type_info=Types.ROW(
                                          [Types.INT(), Types.STRING(), Types.STRING()]))
        ds.add_sink(self.test_sink)
        job_client = self.env.execute_async("test execute async")
        job_id = job_client.get_job_id()
        self.assertIsNotNone(job_id)
        execution_result = job_client.get_job_execution_result().result()
        self.assertEqual(str(job_id), str(execution_result.get_job_id()))

    def test_add_python_file(self):
        import uuid
        env = self.create_new_env("loopback")
        python_file_dir = os.path.join(self.tempdir, "python_file_dir_" + str(uuid.uuid4()))
        os.mkdir(python_file_dir)
        python_file_path = os.path.join(python_file_dir, "test_dep1.py")
        with open(python_file_path, 'w') as f:
            f.write("def add_two(a):\n    return a + 2")

        def plus_two_map(value):
            from test_dep1 import add_two
            return add_two(value)

        get_j_env_configuration(env._j_stream_execution_environment).\
            setString("taskmanager.numberOfTaskSlots", "10")
        env.add_python_file(python_file_path)
        ds = env.from_collection([1, 2, 3, 4, 5])
        ds = ds.map(plus_two_map, Types.LONG()) \
            .slot_sharing_group("data_stream") \
            .map(lambda i: i, Types.LONG()) \
            .slot_sharing_group("table")

        python_file_path = os.path.join(python_file_dir, "test_dep2.py")
        with open(python_file_path, 'w') as f:
            f.write("def add_three(a):\n    return a + 3")

        def plus_three(value):
            from test_dep2 import add_three
            return add_three(value)

        t_env = StreamTableEnvironment.create(
            stream_execution_environment=env,
            environment_settings=EnvironmentSettings.in_streaming_mode())
        env.add_python_file(python_file_path)

        from pyflink.table.udf import udf
        from pyflink.table.expressions import col
        add_three = udf(plus_three, result_type=DataTypes.BIGINT())

        tab = t_env.from_data_stream(ds, 'a') \
            .select(add_three(col('a')))
        t_env.to_append_stream(tab, Types.ROW([Types.LONG()])) \
            .map(lambda i: i[0]) \
            .add_sink(self.test_sink)
        env.execute("test add_python_file")
        result = self.test_sink.get_results(True)
        expected = ['6', '7', '8', '9', '10']
        result.sort()
        expected.sort()
        self.assertEqual(expected, result)

    def test_add_python_file_2(self):
        import uuid
        env = self.create_new_env("loopback")
        python_file_dir = os.path.join(self.tempdir, "python_file_dir_" + str(uuid.uuid4()))
        os.mkdir(python_file_dir)
        python_file_path = os.path.join(python_file_dir, "test_dep1.py")
        with open(python_file_path, 'w') as f:
            f.write("def add_two(a):\n    return a + 2")

        def plus_two_map(value):
            from test_dep1 import add_two
            return add_two(value)

        get_j_env_configuration(env._j_stream_execution_environment).\
            setString("taskmanager.numberOfTaskSlots", "10")
        env.add_python_file(python_file_path)
        ds = env.from_collection([1, 2, 3, 4, 5])
        ds = ds.map(plus_two_map, Types.LONG()) \
            .slot_sharing_group("data_stream") \
            .map(lambda i: i, Types.LONG()) \
            .slot_sharing_group("table")

        python_file_path = os.path.join(python_file_dir, "test_dep2.py")
        with open(python_file_path, 'w') as f:
            f.write("def add_three(a):\n    return a + 3")

        def plus_three(value):
            from test_dep2 import add_three
            return add_three(value)

        t_env = StreamTableEnvironment.create(
            stream_execution_environment=env,
            environment_settings=EnvironmentSettings.in_streaming_mode())
        env.add_python_file(python_file_path)

        from pyflink.table.udf import udf
        from pyflink.table.expressions import col
        add_three = udf(plus_three, result_type=DataTypes.BIGINT())

        tab = t_env.from_data_stream(ds, 'a') \
            .select(add_three(col('a')))
        result = [i[0] for i in tab.execute().collect()]
        expected = [6, 7, 8, 9, 10]
        result.sort()
        expected.sort()
        self.assertEqual(expected, result)

    def test_set_requirements_without_cached_directory(self):
        import uuid
        requirements_txt_path = os.path.join(self.tempdir, str(uuid.uuid4()))
        with open(requirements_txt_path, 'w') as f:
            f.write("cloudpickle==1.2.2")
        self.env.set_python_requirements(requirements_txt_path)

        def check_requirements(i):
            import cloudpickle  # noqa # pylint: disable=unused-import
            return i

        ds = self.env.from_collection([1, 2, 3, 4, 5])
        ds.map(check_requirements).add_sink(self.test_sink)
        self.env.execute("test set requirements without cache dir")
        result = self.test_sink.get_results(True)
        expected = ['1', '2', '3', '4', '5']
        result.sort()
        expected.sort()
        self.assertEqual(expected, result)

    def test_set_requirements_with_cached_directory(self):
        import uuid
        tmp_dir = self.tempdir
        env = self.create_new_env("loopback")
        requirements_txt_path = os.path.join(tmp_dir, "requirements_txt_" + str(uuid.uuid4()))
        with open(requirements_txt_path, 'w') as f:
            f.write("python-package1==0.0.0")

        requirements_dir_path = os.path.join(tmp_dir, "requirements_dir_" + str(uuid.uuid4()))
        os.mkdir(requirements_dir_path)
        package_file_name = "python-package1-0.0.0.tar.gz"
        with open(os.path.join(requirements_dir_path, package_file_name), 'wb') as f:
            import base64
            # This base64 data is encoded from a Python package file which includes a
            # "python_package1" module. The module contains a "plus(a, b)" function.
            # The base64 string can be recomputed with the following code:
            # base64.b64encode(open("python-package1-0.0.0.tar.gz", "rb").read()).decode("utf-8")
            f.write(base64.b64decode(
                "H4sICNefrV0C/2Rpc3QvcHl0aG9uLXBhY2thZ2UxLTAuMC4wLnRhcgDtmVtv2jAYhnPtX2H1CrRCY+ckI"
                "XEx7axuUA11u5imyICTRc1JiVnHfv1MKKWjYxwKEdPehws7xkmUfH5f+3PyqfqWpa1cjG5EKFnLbOvfhX"
                "FQTI3nOPPSdavS5Pa8nGMwy3Esi3ke9wyTObbnGNQxamBSKlFQavzUryG8ldG6frpbEGx4yNmDLMp/hPy"
                "P8b+6fNN613vdP1z8XdteG3+ug/17/F3Hcw1qIv5H54NUYiyUaH2SRRllaYeytkl6IpEdujI2yH2XapCQ"
                "wSRJRDHt0OveZa//uUfeZonUvUO5bHo+0ZcoVo9bMhFRvGx9H41kWj447aUsR0WUq+pui8arWKggK5Jli"
                "wGOo/95q79ovXi6/nfyf246Dof/n078fT9KI+X77Xx6BP83bX4Xf5NxT7dz7toO/L8OxjKgeTwpG+KcDp"
                "sdQjWFVJMipYI+o0MCk4X/t2UYtqI0yPabCHb3f861XcD/Ty/+Y5nLdCzT0dSPo/SmbKsf6un+b7KV+Ls"
                "W4/D/OoC9w/930P9eGwM75//csrD+Q/6P/P/k9D/oX3988Wqw1bS/tf6tR+s/m3EG/ddBqXO9XKf15C8p"
                "P9k4HZBtBgzZaVW5vrfKcj+W32W82ygEB9D/Xu9+4/qfP9L/rBv0X1v87yONKRX61/qfzwqjIDzIPTbv/"
                "7or3/88i0H/tfBFW7s/s/avRInQH06ieEy7tDrQeYHUdRN7wP+n/vf62LOH/pld7f9xz7a5Pfufedy0oP"
                "86iJI8KxStAq6yLC4JWdbbVbWRikR2z1ZGytk5vauW3QdnBFE6XqwmykazCesAAAAAAAAAAAAAAAAAAAA"
                "AAAAAAAAAAAAAAOBw/AJw5CHBAFAAAA=="))
        env.set_python_requirements(requirements_txt_path, requirements_dir_path)

        def add_one(i):
            from python_package1 import plus
            return plus(i, 1)

        ds = env.from_collection([1, 2, 3, 4, 5])
        ds.map(add_one).add_sink(self.test_sink)
        env.execute("test set requirements with cached dir")
        result = self.test_sink.get_results(True)
        expected = ['2', '3', '4', '5', '6']
        result.sort()
        expected.sort()
        self.assertEqual(expected, result)

    def test_add_python_archive(self):
        import uuid
        import shutil
        tmp_dir = self.tempdir
        env = self.create_new_env("loopback")
        archive_dir_path = os.path.join(tmp_dir, "archive_" + str(uuid.uuid4()))
        os.mkdir(archive_dir_path)
        with open(os.path.join(archive_dir_path, "data.txt"), 'w') as f:
            f.write("2")
        archive_file_path = \
            shutil.make_archive(os.path.dirname(archive_dir_path), 'zip', archive_dir_path)
        env.add_python_archive(archive_file_path, "data")

        def add_from_file(i):
            with open("data/data.txt", 'r') as f:
                return i + int(f.read())

        ds = env.from_collection([1, 2, 3, 4, 5])
        ds.map(add_from_file).add_sink(self.test_sink)
        env.execute("test set python archive")
        result = self.test_sink.get_results(True)
        expected = ['3', '4', '5', '6', '7']
        result.sort()
        expected.sort()
        self.assertEqual(expected, result)

    @unittest.skipIf(on_windows(), "Symbolic link is not supported on Windows, skipping.")
    def test_set_stream_env(self):
        import sys
        python_exec = sys.executable
        tmp_dir = self.tempdir
        env = self.create_new_env("loopback")
        python_exec_link_path = os.path.join(tmp_dir, "py_exec")
        os.symlink(python_exec, python_exec_link_path)
        env.set_python_executable(python_exec_link_path)

        def check_python_exec(i):
            import os
            assert os.environ["python"] == python_exec_link_path
            return i

        ds = env.from_collection([1, 2, 3, 4, 5])
        ds.map(check_python_exec).add_sink(self.test_sink)
        env.execute("test set python executable")
        result = self.test_sink.get_results(True)
        expected = ['1', '2', '3', '4', '5']
        result.sort()
        expected.sort()
        self.assertEqual(expected, result)

    def test_add_jars(self):
        # Find the kafka connector jars.
        flink_source_root = _find_flink_source_root()
        jars_abs_path = flink_source_root + '/flink-connectors/flink-sql-connector-kafka'
        specific_jars = glob.glob(jars_abs_path + '/target/flink*.jar')
        specific_jars = ['file://' + specific_jar for specific_jar in specific_jars]

        self.env.add_jars(*specific_jars)
        source_topic = 'test_source_topic'
        props = {'bootstrap.servers': 'localhost:9092', 'group.id': 'test_group'}
        type_info = Types.ROW([Types.INT(), Types.STRING()])

        # Test for the kafka consumer.
        deserialization_schema = JsonRowDeserializationSchema.builder() \
            .type_info(type_info=type_info).build()

        # A ClassNotFoundException will be raised if the kafka connector is not added to the
        # pipeline jars.
        kafka_consumer = FlinkKafkaConsumer(source_topic, deserialization_schema, props)
        self.env.add_source(kafka_consumer).print()
        self.env.get_execution_plan()

    def test_add_classpaths(self):
        # Find the kafka connector jars.
        flink_source_root = _find_flink_source_root()
        jars_abs_path = flink_source_root + '/flink-connectors/flink-sql-connector-kafka'
        specific_jars = glob.glob(jars_abs_path + '/target/flink*.jar')
        specific_jars = ['file://' + specific_jar for specific_jar in specific_jars]

        self.env.add_classpaths(*specific_jars)
        source_topic = 'test_source_topic'
        props = {'bootstrap.servers': 'localhost:9092', 'group.id': 'test_group'}
        type_info = Types.ROW([Types.INT(), Types.STRING()])

        # Test for the kafka consumer.
        deserialization_schema = JsonRowDeserializationSchema.builder() \
            .type_info(type_info=type_info).build()

        # It will raise a ClassNotFoundException if the kafka connector is not added to the
        # pipeline classpaths.
        kafka_consumer = FlinkKafkaConsumer(source_topic, deserialization_schema, props)
        self.env.add_source(kafka_consumer).print()
        self.env.get_execution_plan()

    def test_generate_stream_graph_with_dependencies(self):
        python_file_dir = os.path.join(self.tempdir, "python_file_dir_" + str(uuid.uuid4()))
        os.mkdir(python_file_dir)
        python_file_path = os.path.join(python_file_dir, "test_stream_dependency_manage_lib.py")
        with open(python_file_path, 'w') as f:
            f.write("def add_two(a):\n    return a + 2")
        env = self.env
        env.add_python_file(python_file_path)

        def plus_two_map(value):
            from test_stream_dependency_manage_lib import add_two
            return value[0], add_two(value[1])

        def add_from_file(i):
            with open("data/data.txt", 'r') as f:
                return i[0], i[1] + int(f.read())

        from_collection_source = env.from_collection([('a', 0), ('b', 0), ('c', 1), ('d', 1),
                                                      ('e', 2)],
                                                     type_info=Types.ROW([Types.STRING(),
                                                                          Types.INT()]))
        from_collection_source.name("From Collection")
        keyed_stream = from_collection_source.key_by(lambda x: x[1], key_type=Types.INT())
        plus_two_map_stream = keyed_stream.map(plus_two_map).name("Plus Two Map").set_parallelism(3)
        add_from_file_map = plus_two_map_stream.map(add_from_file).name("Add From File Map")
        test_stream_sink = add_from_file_map.add_sink(self.test_sink).name("Test Sink")
        test_stream_sink.set_parallelism(4)

        archive_dir_path = os.path.join(self.tempdir, "archive_" + str(uuid.uuid4()))
        os.mkdir(archive_dir_path)
        with open(os.path.join(archive_dir_path, "data.txt"), 'w') as f:
            f.write("3")
        archive_file_path = \
            shutil.make_archive(os.path.dirname(archive_dir_path), 'zip', archive_dir_path)
        env.add_python_archive(archive_file_path, "data")

        nodes = eval(env.get_execution_plan())['nodes']
        # The StreamGraph should be as below:
        # Source: From Collection -> _stream_key_by_map_operator ->
        #     Plus Two Map -> Add From File Map -> Sink: Test Sink.
        # Source: From Collection and _stream_key_by_map_operator should have the same parallelism.
        self.assertEqual(nodes[0]['parallelism'], nodes[1]['parallelism'])
        # The parallelism of Plus Two Map should be 3.
        self.assertEqual(nodes[2]['parallelism'], 3)
        # The ship_strategy between Source: From Collection and _stream_key_by_map_operator
        # should be FORWARD.
        self.assertEqual(nodes[1]['predecessors'][0]['ship_strategy'], "FORWARD")
        # The ship_strategy between _keyed_stream_values_operator and Plus Two Map should be HASH.
        self.assertEqual(nodes[2]['predecessors'][0]['ship_strategy'], "HASH")
        # The parallelism of Sink: Test Sink should be 4.
        self.assertEqual(nodes[4]['parallelism'], 4)

        env_config_with_dependencies = dict(get_gateway().jvm.org.apache.flink.python.util
                                            .PythonConfigUtil.getEnvConfigWithDependencies(
                                                env._j_stream_execution_environment).toMap())
        # Make sure that the user-specified files and archives are correctly added.
        self.assertIsNotNone(env_config_with_dependencies['python.files'])
        self.assertIsNotNone(env_config_with_dependencies['python.archives'])

    def test_register_slot_sharing_group(self):
        slot_sharing_group_1 = SlotSharingGroup.builder('slot_sharing_group_1') \
            .set_cpu_cores(1.0).set_task_heap_memory_mb(100).build()
        slot_sharing_group_2 = SlotSharingGroup.builder('slot_sharing_group_2') \
            .set_cpu_cores(2.0).set_task_heap_memory_mb(200).build()
        slot_sharing_group_3 = SlotSharingGroup.builder('slot_sharing_group_3').build()
        self.env.register_slot_sharing_group(slot_sharing_group_1)
        self.env.register_slot_sharing_group(slot_sharing_group_2)
        self.env.register_slot_sharing_group(slot_sharing_group_3)

        ds = self.env.from_collection([1, 2, 3]).slot_sharing_group(
            'slot_sharing_group_1')
        ds.map(lambda x: x + 1).set_parallelism(3) \
            .slot_sharing_group('slot_sharing_group_2') \
            .add_sink(self.test_sink)

        j_generated_stream_graph = self.env._j_stream_execution_environment \
            .getStreamGraph(True)
        j_resource_profile_1 = j_generated_stream_graph.getSlotSharingGroupResource(
            'slot_sharing_group_1').get()
        j_resource_profile_2 = j_generated_stream_graph.getSlotSharingGroupResource(
            'slot_sharing_group_2').get()
        j_resource_profile_3 = j_generated_stream_graph.getSlotSharingGroupResource(
            'slot_sharing_group_3')

        self.assertEqual(j_resource_profile_1.getCpuCores().getValue(), 1.0)
        self.assertEqual(MemorySize(j_memory_size=j_resource_profile_1.getTaskHeapMemory()),
                         MemorySize.of_mebi_bytes(100))
        self.assertEqual(j_resource_profile_2.getCpuCores().getValue(), 2.0)
        self.assertEqual(MemorySize(j_memory_size=j_resource_profile_2.getTaskHeapMemory()),
                         MemorySize.of_mebi_bytes(200))
        self.assertFalse(j_resource_profile_3.isPresent())

    def tearDown(self) -> None:
        self.test_sink.clear()