blob: b02f6605d723258b2557833176da671b29037b12 [file] [log] [blame]
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import os
import tempfile
import json
from pyflink.common import ExecutionConfig, RestartStrategies
from pyflink.datastream import (StreamExecutionEnvironment, CheckpointConfig,
CheckpointingMode, MemoryStateBackend, TimeCharacteristic)
from pyflink.table import DataTypes, CsvTableSource, CsvTableSink, StreamTableEnvironment
from pyflink.testing.test_case_utils import PyFlinkTestCase
class StreamExecutionEnvironmentTests(PyFlinkTestCase):
    """Tests for the Python ``StreamExecutionEnvironment`` wrapper.

    Each test creates a fresh execution environment in :meth:`setUp` and
    exercises one getter/setter round-trip (or a similar small contract)
    against the underlying Java environment.
    """

    def setUp(self):
        # A fresh environment per test keeps configuration changes isolated.
        self.env = StreamExecutionEnvironment.get_execution_environment()

    def test_get_config(self):
        """get_config() returns the wrapped ExecutionConfig."""
        execution_config = self.env.get_config()

        self.assertIsInstance(execution_config, ExecutionConfig)

    def test_get_set_parallelism(self):
        self.env.set_parallelism(10)

        parallelism = self.env.get_parallelism()

        self.assertEqual(parallelism, 10)

    def test_get_set_buffer_timeout(self):
        self.env.set_buffer_timeout(12000)

        timeout = self.env.get_buffer_timeout()

        self.assertEqual(timeout, 12000)

    def test_get_set_default_local_parallelism(self):
        self.env.set_default_local_parallelism(8)

        parallelism = self.env.get_default_local_parallelism()

        self.assertEqual(parallelism, 8)

    def test_set_get_restart_strategy(self):
        self.env.set_restart_strategy(RestartStrategies.no_restart())

        restart_strategy = self.env.get_restart_strategy()

        self.assertEqual(restart_strategy, RestartStrategies.no_restart())

    def test_add_default_kryo_serializer(self):
        """A registered default Kryo serializer shows up in the config's class map."""
        self.env.add_default_kryo_serializer(
            "org.apache.flink.runtime.state.StateBackendTestBase$TestPojo",
            "org.apache.flink.runtime.state.StateBackendTestBase$CustomKryoTestSerializer")

        class_dict = self.env.get_config().get_default_kryo_serializer_classes()

        self.assertEqual(class_dict,
                         {'org.apache.flink.runtime.state.StateBackendTestBase$TestPojo':
                          'org.apache.flink.runtime.state'
                          '.StateBackendTestBase$CustomKryoTestSerializer'})

    def test_register_type_with_kryo_serializer(self):
        """A type registered with a Kryo serializer shows up in the config's class map."""
        self.env.register_type_with_kryo_serializer(
            "org.apache.flink.runtime.state.StateBackendTestBase$TestPojo",
            "org.apache.flink.runtime.state.StateBackendTestBase$CustomKryoTestSerializer")

        class_dict = self.env.get_config().get_registered_types_with_kryo_serializer_classes()

        self.assertEqual(class_dict,
                         {'org.apache.flink.runtime.state.StateBackendTestBase$TestPojo':
                          'org.apache.flink.runtime.state'
                          '.StateBackendTestBase$CustomKryoTestSerializer'})

    def test_register_type(self):
        self.env.register_type("org.apache.flink.runtime.state.StateBackendTestBase$TestPojo")

        type_list = self.env.get_config().get_registered_pojo_types()

        self.assertEqual(type_list,
                         ['org.apache.flink.runtime.state.StateBackendTestBase$TestPojo'])

    def test_get_set_max_parallelism(self):
        self.env.set_max_parallelism(12)

        parallelism = self.env.get_max_parallelism()

        self.assertEqual(parallelism, 12)

    def test_operation_chaining(self):
        """Chaining is on by default and can be disabled."""
        self.assertTrue(self.env.is_chaining_enabled())

        self.env.disable_operator_chaining()

        self.assertFalse(self.env.is_chaining_enabled())

    def test_get_checkpoint_config(self):
        checkpoint_config = self.env.get_checkpoint_config()

        self.assertIsInstance(checkpoint_config, CheckpointConfig)

    def test_get_set_checkpoint_interval(self):
        self.env.enable_checkpointing(30000)

        interval = self.env.get_checkpoint_interval()

        self.assertEqual(interval, 30000)

    def test_get_set_checkpointing_mode(self):
        """Default mode is EXACTLY_ONCE; enable_checkpointing can override it."""
        mode = self.env.get_checkpointing_mode()
        self.assertEqual(mode, CheckpointingMode.EXACTLY_ONCE)

        self.env.enable_checkpointing(30000, CheckpointingMode.AT_LEAST_ONCE)

        mode = self.env.get_checkpointing_mode()

        self.assertEqual(mode, CheckpointingMode.AT_LEAST_ONCE)

    def test_get_state_backend(self):
        """No state backend is configured by default."""
        state_backend = self.env.get_state_backend()

        self.assertIsNone(state_backend)

    def test_set_state_backend(self):
        input_backend = MemoryStateBackend()

        self.env.set_state_backend(input_backend)

        output_backend = self.env.get_state_backend()

        # Compare the wrapped Java objects: the Python wrapper returned by the
        # getter is a new instance, so object equality would not hold.
        self.assertEqual(output_backend._j_memory_state_backend,
                         input_backend._j_memory_state_backend)

    def test_get_set_stream_time_characteristic(self):
        default_time_characteristic = self.env.get_stream_time_characteristic()
        self.assertEqual(default_time_characteristic, TimeCharacteristic.ProcessingTime)

        self.env.set_stream_time_characteristic(TimeCharacteristic.EventTime)

        time_characteristic = self.env.get_stream_time_characteristic()

        self.assertEqual(time_characteristic, TimeCharacteristic.EventTime)

    def test_get_execution_plan(self):
        """get_execution_plan() emits valid JSON for a simple source->sink job."""
        tmp_dir = tempfile.gettempdir()
        # Pass path components as separate os.path.join arguments; the original
        # concatenated with '/' inside a single argument, which defeats the
        # purpose of os.path.join and is not portable.
        source_path = os.path.join(tmp_dir, 'streaming.csv')
        tmp_csv = os.path.join(tmp_dir, 'streaming2.csv')
        field_names = ["a", "b", "c"]
        field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]

        t_env = StreamTableEnvironment.create(self.env)
        csv_source = CsvTableSource(source_path, field_names, field_types)
        t_env.register_table_source("Orders", csv_source)
        t_env.register_table_sink(
            "Results",
            CsvTableSink(field_names, field_types, tmp_csv))
        t_env.scan("Orders").insert_into("Results")

        plan = self.env.get_execution_plan()

        # Raises ValueError if the plan is not well-formed JSON.
        json.loads(plan)