################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import os
from py4j.compat import unicode
from pyflink.dataset import ExecutionEnvironment
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.table import DataTypes, CsvTableSink, StreamTableEnvironment, EnvironmentSettings
from pyflink.table.table_config import TableConfig
from pyflink.table.table_environment import BatchTableEnvironment
from pyflink.table.types import RowType
from pyflink.testing import source_sink_utils
from pyflink.testing.test_case_utils import PyFlinkStreamTableTestCase, PyFlinkBatchTableTestCase
from pyflink.util.exceptions import TableException
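

# Tests for the Python StreamTableEnvironment API: source/sink registration,
# SQL queries and updates, explain, and environment creation.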
class StreamTableEnvironmentTests(PyFlinkStreamTableTestCase):
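    # Registering a CSV table source and scanning it should yield the expected
    # catalog table summary string.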
def test_register_table_source_scan(self):
t_env = self.t_env
field_names = ["a", "b", "c"]
field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING()]
        source_path = os.path.join(self.tempdir, 'streaming.csv')
csv_source = self.prepare_csv_source(source_path, [], field_types, field_names)
t_env.register_table_source("Source", csv_source)
result = t_env.scan("Source")
self.assertEqual(
'CatalogTable: (identifier: [`default_catalog`.`default_database`.`Source`]'
', fields: [a, b, c])',
result._j_table.getQueryOperation().asSummaryString())
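
    # Registers a test append sink, inserts a single row and checks the emitted result.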
def test_register_table_sink(self):
t_env = self.t_env
field_names = ["a", "b", "c"]
field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING()]
t_env.register_table_sink(
"Sinks",
source_sink_utils.TestAppendSink(field_names, field_types))
t_env.from_elements([(1, "Hi", "Hello")], ["a", "b", "c"]).insert_into("Sinks")
self.t_env.execute("test")
actual = source_sink_utils.results()
expected = ['1,Hi,Hello']
self.assert_equals(actual, expected)
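
    # from_table_source should wrap a TableSource directly into a Table.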
def test_from_table_source(self):
field_names = ["a", "b", "c"]
field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING()]
        source_path = os.path.join(self.tempdir, 'streaming.csv')
csv_source = self.prepare_csv_source(source_path, [], field_types, field_names)
result = self.t_env.from_table_source(csv_source)
self.assertEqual(
'TableSource: (fields: [a, b, c])',
result._j_table.getQueryOperation().asSummaryString())
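
    # list_tables() should return all registered sources and sinks.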
def test_list_tables(self):
        source_path = os.path.join(self.tempdir, 'streaming.csv')
field_names = ["a", "b", "c"]
field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = []
csv_source = self.prepare_csv_source(source_path, data, field_types, field_names)
t_env = self.t_env
t_env.register_table_source("Orders", csv_source)
t_env.register_table_sink(
"Sinks",
source_sink_utils.TestAppendSink(field_names, field_types))
t_env.register_table_sink(
"Results",
source_sink_utils.TestAppendSink(field_names, field_types))
actual = t_env.list_tables()
expected = ['Orders', 'Results', 'Sinks']
self.assert_equals(actual, expected)
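
    # explain() on a simple projection should return the plan as a string.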
def test_explain(self):
schema = RowType()\
.add('a', DataTypes.INT())\
.add('b', DataTypes.STRING())\
.add('c', DataTypes.STRING())
t_env = self.t_env
t = t_env.from_elements([], schema)
result = t.select("1 + a, b, c")
actual = t_env.explain(result)
        self.assertIsInstance(actual, (str, unicode))
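
    # explain(..., True) should also return the plan when the extended flag is set.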
def test_explain_with_extended(self):
schema = RowType() \
.add('a', DataTypes.INT()) \
.add('b', DataTypes.STRING()) \
.add('c', DataTypes.STRING())
t_env = self.t_env
t = t_env.from_elements([], schema)
result = t.select("1 + a, b, c")
actual = t_env.explain(result, True)
        self.assertIsInstance(actual, (str, unicode))
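
    # The streaming planner supports explain(extended=True) for jobs with multiple sinks.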
def test_explain_with_multi_sinks(self):
t_env = self.t_env
source = t_env.from_elements([(1, "Hi", "Hello"), (2, "Hello", "Hello")], ["a", "b", "c"])
field_names = ["a", "b", "c"]
field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING()]
t_env.register_table_sink(
"sink1",
source_sink_utils.TestAppendSink(field_names, field_types))
t_env.register_table_sink(
"sink2",
source_sink_utils.TestAppendSink(field_names, field_types))
t_env.sql_update("insert into sink1 select * from %s where a > 100" % source)
t_env.sql_update("insert into sink2 select * from %s where a < 100" % source)
actual = t_env.explain(extended=True)
        self.assertIsInstance(actual, (str, unicode))
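
    # sql_query over a Table reference should produce the projected rows in the sink.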
def test_sql_query(self):
t_env = self.t_env
source = t_env.from_elements([(1, "Hi", "Hello"), (2, "Hello", "Hello")], ["a", "b", "c"])
field_names = ["a", "b", "c"]
field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING()]
t_env.register_table_sink(
"sinks",
source_sink_utils.TestAppendSink(field_names, field_types))
result = t_env.sql_query("select a + 1, b, c from %s" % source)
result.insert_into("sinks")
self.t_env.execute("test")
actual = source_sink_utils.results()
expected = ['2,Hi,Hello', '3,Hello,Hello']
self.assert_equals(actual, expected)
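
    # sql_update with an INSERT INTO statement should write all source rows to the sink.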
def test_sql_update(self):
t_env = self.t_env
source = t_env.from_elements([(1, "Hi", "Hello"), (2, "Hello", "Hello")], ["a", "b", "c"])
field_names = ["a", "b", "c"]
field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING()]
t_env.register_table_sink(
"sinks",
source_sink_utils.TestAppendSink(field_names, field_types))
t_env.sql_update("insert into sinks select * from %s" % source)
self.t_env.execute("test_sql_job")
actual = source_sink_utils.results()
expected = ['1,Hi,Hello', '2,Hello,Hello']
self.assert_equals(actual, expected)
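
    # Java scalar, aggregate and table functions can be registered by class name.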
def test_register_java_function(self):
t_env = self.t_env
t_env.register_java_function("scalar_func",
"org.apache.flink.table.expressions.utils.RichFunc0")
t_env.register_java_function(
"agg_func", "org.apache.flink.table.functions.aggfunctions.ByteMaxAggFunction")
t_env.register_java_function("table_func", "org.apache.flink.table.utils.TableFunc1")
actual = t_env.list_user_defined_functions()
expected = ['scalar_func', 'agg_func', 'table_func']
self.assert_equals(actual, expected)
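
    # A TableConfig passed to create() should be reflected in the environment's config.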
def test_create_table_environment(self):
table_config = TableConfig()
table_config.set_max_generated_code_length(32000)
table_config.set_null_check(False)
table_config.set_local_timezone("Asia/Shanghai")
env = StreamExecutionEnvironment.get_execution_environment()
t_env = StreamTableEnvironment.create(env, table_config)
        actual_table_config = t_env.get_config()
        self.assertFalse(actual_table_config.get_null_check())
        self.assertEqual(actual_table_config.get_max_generated_code_length(), 32000)
        self.assertEqual(actual_table_config.get_local_timezone(), "Asia/Shanghai")
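
    # Creating the environment with blink planner settings should pick the blink StreamPlanner.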
def test_create_table_environment_with_blink_planner(self):
t_env = StreamTableEnvironment.create(
self.env,
environment_settings=EnvironmentSettings.new_instance().use_blink_planner().build())
planner = t_env._j_tenv.getPlanner()
self.assertEqual(
planner.getClass().getName(),
"org.apache.flink.table.planner.delegation.StreamPlanner")
def test_table_environment_with_blink_planner(self):
t_env = StreamTableEnvironment.create(
self.env,
environment_settings=EnvironmentSettings.new_instance().use_blink_planner().build())
        source_path = os.path.join(self.tempdir, 'streaming.csv')
        sink_path = os.path.join(self.tempdir, 'result.csv')
field_names = ["a", "b", "c"]
field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, 'hi', 'hello'), (2, 'hello', 'hello')]
csv_source = self.prepare_csv_source(source_path, data, field_types, field_names)
t_env.register_table_source("source", csv_source)
t_env.register_table_sink(
"sink",
CsvTableSink(field_names, field_types, sink_path))
source = t_env.scan("source")
result = source.alias("a, b, c").select("1 + a, b, c")
result.insert_into("sink")
t_env.execute("blink_test")
results = []
with open(sink_path, 'r') as f:
results.append(f.readline())
results.append(f.readline())
self.assert_equals(results, ['2,hi,hello\n', '3,hello,hello\n'])
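

# Tests for the Python BatchTableEnvironment API, covering both the legacy batch
# planner and the blink batch planner.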
class BatchTableEnvironmentTests(PyFlinkBatchTableTestCase):
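    # explain() on a projection over a CSV source should return the plan as a string.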
def test_explain(self):
        source_path = os.path.join(self.tempdir, 'streaming.csv')
field_names = ["a", "b", "c"]
field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = []
csv_source = self.prepare_csv_source(source_path, data, field_types, field_names)
t_env = self.t_env
t_env.register_table_source("Source", csv_source)
source = t_env.scan("Source")
result = source.alias("a, b, c").select("1 + a, b, c")
actual = t_env.explain(result)
self.assertIsInstance(actual, (str, unicode))
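
    # explain(..., True) should also return the plan when the extended flag is set.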
def test_explain_with_extended(self):
schema = RowType() \
.add('a', DataTypes.INT()) \
.add('b', DataTypes.STRING()) \
.add('c', DataTypes.STRING())
t_env = self.t_env
t = t_env.from_elements([], schema)
result = t.select("1 + a, b, c")
actual = t_env.explain(result, True)
        self.assertIsInstance(actual, (str, unicode))
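
    # The legacy batch planner does not support explain() with multiple sinks and
    # should raise a TableException.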
def test_explain_with_multi_sinks(self):
t_env = self.t_env
source = t_env.from_elements([(1, "Hi", "Hello"), (2, "Hello", "Hello")], ["a", "b", "c"])
field_names = ["a", "b", "c"]
field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING()]
t_env.register_table_sink(
"sink1",
CsvTableSink(field_names, field_types, "path1"))
t_env.register_table_sink(
"sink2",
CsvTableSink(field_names, field_types, "path2"))
t_env.sql_update("insert into sink1 select * from %s where a > 100" % source)
t_env.sql_update("insert into sink2 select * from %s where a < 100" % source)
with self.assertRaises(TableException):
t_env.explain(extended=True)
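
    # Java scalar, aggregate and table functions can be registered by class name.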
def test_register_java_function(self):
t_env = self.t_env
t_env.register_java_function("scalar_func",
"org.apache.flink.table.expressions.utils.RichFunc0")
t_env.register_java_function(
"agg_func", "org.apache.flink.table.functions.aggfunctions.ByteMaxAggFunction")
t_env.register_java_function("table_func", "org.apache.flink.table.utils.TableFunc1")
actual = t_env.list_user_defined_functions()
expected = ['scalar_func', 'agg_func', 'table_func']
self.assert_equals(actual, expected)
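
    # A TableConfig passed to create() should be reflected in the environment's config.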
def test_create_table_environment(self):
table_config = TableConfig()
table_config.set_max_generated_code_length(32000)
table_config.set_null_check(False)
table_config.set_local_timezone("Asia/Shanghai")
env = ExecutionEnvironment.get_execution_environment()
t_env = BatchTableEnvironment.create(env, table_config)
        actual_table_config = t_env.get_config()
        self.assertFalse(actual_table_config.get_null_check())
        self.assertEqual(actual_table_config.get_max_generated_code_length(), 32000)
        self.assertEqual(actual_table_config.get_local_timezone(), "Asia/Shanghai")
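
    # Creating the environment with blink batch settings should pick the blink BatchPlanner.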
def test_create_table_environment_with_blink_planner(self):
t_env = BatchTableEnvironment.create(
environment_settings=EnvironmentSettings.new_instance().in_batch_mode()
.use_blink_planner().build())
planner = t_env._j_tenv.getPlanner()
self.assertEqual(
planner.getClass().getName(),
"org.apache.flink.table.planner.delegation.BatchPlanner")
def test_table_environment_with_blink_planner(self):
t_env = BatchTableEnvironment.create(
environment_settings=EnvironmentSettings.new_instance().in_batch_mode()
.use_blink_planner().build())
        source_path = os.path.join(self.tempdir, 'streaming.csv')
        sink_path = os.path.join(self.tempdir, 'results')
field_names = ["a", "b", "c"]
field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, 'hi', 'hello'), (2, 'hello', 'hello')]
csv_source = self.prepare_csv_source(source_path, data, field_types, field_names)
t_env.register_table_source("source", csv_source)
t_env.register_table_sink(
"sink",
CsvTableSink(field_names, field_types, sink_path))
source = t_env.scan("source")
result = source.alias("a, b, c").select("1 + a, b, c")
result.insert_into("sink")
t_env.execute("blink_test")
results = []
for root, dirs, files in os.walk(sink_path):
for sub_file in files:
                with open(os.path.join(root, sub_file), 'r') as f:
                    results.extend(f.readlines())
self.assert_equals(results, ['2,hi,hello\n', '3,hello,hello\n'])