################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import json
import os
import tempfile
import time
import unittest

from pyflink.common import ExecutionConfig, RestartStrategies
from pyflink.dataset import ExecutionEnvironment
from pyflink.table import DataTypes, BatchTableEnvironment, CsvTableSource, CsvTableSink
from pyflink.testing.test_case_utils import PyFlinkTestCase, exec_insert_table


class ExecutionEnvironmentTests(PyFlinkTestCase):

    def setUp(self):
        self.env = ExecutionEnvironment.get_execution_environment()

    def test_get_set_parallelism(self):
        self.env.set_parallelism(10)
        parallelism = self.env.get_parallelism()
        self.assertEqual(parallelism, 10)

    def test_get_set_default_local_parallelism(self):
        self.env.set_default_local_parallelism(8)
        parallelism = self.env.get_default_local_parallelism()
        self.assertEqual(parallelism, 8)

    def test_get_config(self):
        execution_config = self.env.get_config()
        self.assertIsInstance(execution_config, ExecutionConfig)

    def test_set_get_restart_strategy(self):
        self.env.set_restart_strategy(RestartStrategies.no_restart())
        restart_strategy = self.env.get_restart_strategy()
        self.assertEqual(restart_strategy, RestartStrategies.no_restart())

    def test_add_default_kryo_serializer(self):
        self.env.add_default_kryo_serializer(
            "org.apache.flink.runtime.state.StateBackendTestBase$TestPojo",
            "org.apache.flink.runtime.state.StateBackendTestBase$CustomKryoTestSerializer")
        class_dict = self.env.get_config().get_default_kryo_serializer_classes()
        self.assertEqual(class_dict,
                         {'org.apache.flink.runtime.state.StateBackendTestBase$TestPojo':
                          'org.apache.flink.runtime.state'
                          '.StateBackendTestBase$CustomKryoTestSerializer'})

    def test_register_type_with_kryo_serializer(self):
        self.env.register_type_with_kryo_serializer(
            "org.apache.flink.runtime.state.StateBackendTestBase$TestPojo",
            "org.apache.flink.runtime.state.StateBackendTestBase$CustomKryoTestSerializer")
        class_dict = self.env.get_config().get_registered_types_with_kryo_serializer_classes()
        self.assertEqual(class_dict,
                         {'org.apache.flink.runtime.state.StateBackendTestBase$TestPojo':
                          'org.apache.flink.runtime.state'
                          '.StateBackendTestBase$CustomKryoTestSerializer'})

    def test_register_type(self):
        self.env.register_type("org.apache.flink.runtime.state.StateBackendTestBase$TestPojo")
        type_list = self.env.get_config().get_registered_pojo_types()
        self.assertEqual(type_list,
                         ["org.apache.flink.runtime.state.StateBackendTestBase$TestPojo"])

    @unittest.skip("The Python API does not support DataSet yet; refactor this test later.")
    def test_get_execution_plan(self):
        tmp_dir = tempfile.gettempdir()
        source_path = os.path.join(tmp_dir, 'streaming.csv')
        tmp_csv = os.path.join(tmp_dir, 'streaming2.csv')
        field_names = ["a", "b", "c"]
        field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]

        t_env = BatchTableEnvironment.create(self.env)
        csv_source = CsvTableSource(source_path, field_names, field_types)
        t_env.register_table_source("Orders", csv_source)
        t_env.register_table_sink(
            "Results",
            CsvTableSink(field_names, field_types, tmp_csv))
        t_env.from_path("Orders").execute_insert("Results").wait()

        plan = self.env.get_execution_plan()
        # The plan must be valid JSON; json.loads raises ValueError otherwise.
        json.loads(plan)

    def test_execute(self):
        tmp_dir = tempfile.gettempdir()
        field_names = ['a', 'b', 'c']
        field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING()]

        t_env = BatchTableEnvironment.create(self.env)
        t_env.register_table_sink(
            'Results',
            CsvTableSink(field_names, field_types,
                         os.path.join(tmp_dir, '{}.csv'.format(round(time.time())))))
        execution_result = exec_insert_table(
            t_env.from_elements([(1, 'Hi', 'Hello')], ['a', 'b', 'c']),
            'Results')

        # The job ran with no user-defined accumulators, so the result carries
        # job metadata but an empty accumulator map.
        self.assertIsNotNone(execution_result.get_job_id())
        self.assertIsNotNone(execution_result.get_net_runtime())
        self.assertEqual(len(execution_result.get_all_accumulator_results()), 0)
        self.assertIsNone(execution_result.get_accumulator_result('accumulator'))
        self.assertIsNotNone(str(execution_result))
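

# Minimal standalone-runner sketch (an addition, not part of the original blob):
# it lets the module be executed directly with
# `python test_execution_environment.py`; CI pipelines may instead discover
# these tests through pytest or an XML-reporting runner.
if __name__ == '__main__':
    unittest.main(verbosity=2)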