| ################################################################################ |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| ################################################################################ |
| |
| import array |
| import datetime |
| from decimal import Decimal |
| |
| from pyflink.common import Row |
| from pyflink.table import DataTypes, BatchTableEnvironment, EnvironmentSettings |
| from pyflink.table.expressions import row |
| from pyflink.table.tests.test_types import ExamplePoint, PythonOnlyPoint, ExamplePointUDT, \ |
| PythonOnlyUDT |
| from pyflink.testing import source_sink_utils |
| from pyflink.testing.test_case_utils import PyFlinkStreamTableTestCase |
| |
| |
class StreamTableCalcTests(PyFlinkStreamTableTestCase):
    """Tests for Table API plan generation (select/alias/where/filter) and for
    round-tripping Python values through ``from_elements`` into a test sink."""

    @staticmethod
    def _create_schema(field_names, field_types):
        """Pair each field name with its type and build a ROW schema.

        Extracted helper: the identical ``list(map(lambda ...))`` construction
        was previously duplicated in each ``from_elements`` test below.
        """
        return DataTypes.ROW(
            [DataTypes.FIELD(name, data_type)
             for name, data_type in zip(field_names, field_types)])

    def test_select(self):
        """select() should translate expressions into the expected projection
        list of the underlying Java query operation."""
        t = self.t_env.from_elements([(1, 'hi', 'hello')], ['a', 'b', 'c'])
        result = t.select(t.a + 1, t.b, t.c)
        query_operation = result._j_table.getQueryOperation()
        self.assertEqual('[plus(a, 1), b, c]',
                         query_operation.getProjectList().toString())

    def test_alias(self):
        """alias() should rename all fields of the resulting table schema."""
        t = self.t_env.from_elements([(1, 'Hi', 'Hello')], ['a', 'b', 'c'])
        t = t.alias("d, e, f")
        result = t.select(t.d, t.e, t.f)
        table_schema = result._j_table.getQueryOperation().getTableSchema()
        self.assertEqual(['d', 'e', 'f'], list(table_schema.getFieldNames()))

    def test_where(self):
        """where() should produce a conjunctive filter condition in the plan."""
        t_env = self.t_env
        t = t_env.from_elements([(1, 'Hi', 'Hello')], ['a', 'b', 'c'])
        result = t.where((t.a > 1) & (t.b == 'Hello'))
        query_operation = result._j_table.getQueryOperation()
        self.assertEqual("and("
                         "greaterThan(a, 1), "
                         "equals(b, 'Hello'))",
                         query_operation.getCondition().toString())

    def test_filter(self):
        """filter() is equivalent to where() and should yield the same plan."""
        t = self.t_env.from_elements([(1, 'Hi', 'Hello')], ['a', 'b', 'c'])
        result = t.filter((t.a > 1) & (t.b == 'Hello'))
        query_operation = result._j_table.getQueryOperation()
        self.assertEqual("and("
                         "greaterThan(a, 1), "
                         "equals(b, 'Hello'))",
                         query_operation.getCondition().toString())

    def test_from_element(self):
        """from_elements() should round-trip every supported Python value
        (dates, times, intervals, arrays, decimals, rows, maps, bytes and
        both UDT kinds) through an append sink."""
        t_env = self.t_env
        field_names = ["a", "b", "c", "d", "e", "f", "g", "h",
                       "i", "j", "k", "l", "m", "n", "o", "p", "q", "r"]
        field_types = [DataTypes.BIGINT(), DataTypes.DOUBLE(), DataTypes.STRING(),
                       DataTypes.STRING(), DataTypes.DATE(),
                       DataTypes.TIME(),
                       DataTypes.TIMESTAMP(3),
                       DataTypes.INTERVAL(DataTypes.SECOND(3)),
                       DataTypes.ARRAY(DataTypes.DOUBLE()),
                       DataTypes.ARRAY(DataTypes.DOUBLE(False)),
                       DataTypes.ARRAY(DataTypes.STRING()),
                       DataTypes.ARRAY(DataTypes.DATE()),
                       DataTypes.DECIMAL(38, 18),
                       DataTypes.ROW([DataTypes.FIELD("a", DataTypes.BIGINT()),
                                      DataTypes.FIELD("b", DataTypes.DOUBLE())]),
                       DataTypes.MAP(DataTypes.STRING(), DataTypes.DOUBLE()),
                       DataTypes.BYTES(), ExamplePointUDT(),
                       PythonOnlyUDT()]
        schema = self._create_schema(field_names, field_types)
        table_sink = source_sink_utils.TestAppendSink(field_names, field_types)
        t_env.register_table_sink("Results", table_sink)
        t = t_env.from_elements(
            [(1, 1.0, "hi", "hello", datetime.date(1970, 1, 2), datetime.time(1, 0, 0),
              datetime.datetime(1970, 1, 2, 0, 0),
              datetime.timedelta(days=1, microseconds=10),
              [1.0, None], array.array("d", [1.0, 2.0]),
              ["abc"], [datetime.date(1970, 1, 2)], Decimal(1), Row("a", "b")(1, 2.0),
              {"key": 1.0}, bytearray(b'ABCD'), ExamplePoint(1.0, 2.0),
              PythonOnlyPoint(3.0, 4.0))],
            schema)
        t.execute_insert("Results").wait()
        actual = source_sink_utils.results()

        # The sink renders each row as a comma-joined string of Java values.
        expected = ['1,1.0,hi,hello,1970-01-02,01:00:00,1970-01-02 00:00:00.0,'
                    '86400000,[1.0, null],[1.0, 2.0],[abc],[1970-01-02],'
                    '1,1,2.0,{key=1.0},[65, 66, 67, 68],[1.0, 2.0],[3.0, 4.0]']
        self.assert_equals(actual, expected)

    def test_from_element_expression(self):
        """from_elements() should also accept expression rows built with
        the ``row(...)`` expression DSL."""
        t_env = self.t_env

        field_names = ["a", "b", "c"]
        field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.FLOAT()]

        schema = self._create_schema(field_names, field_types)
        table_sink = source_sink_utils.TestAppendSink(field_names, field_types)
        t_env.register_table_sink("Results", table_sink)
        t = t_env.from_elements([row(1, 'abc', 2.0), row(2, 'def', 3.0)], schema)
        t.execute_insert("Results").wait()
        actual = source_sink_utils.results()

        expected = ['1,abc,2.0', '2,def,3.0']
        self.assert_equals(actual, expected)

    def test_blink_from_element(self):
        """Same round trip as test_from_element but on the blink batch
        planner, which does not support ExamplePointUDT and renders
        DECIMAL with full scale."""
        t_env = BatchTableEnvironment.create(environment_settings=EnvironmentSettings
                                             .new_instance().use_blink_planner()
                                             .in_batch_mode().build())
        field_names = ["a", "b", "c", "d", "e", "f", "g", "h",
                       "i", "j", "k", "l", "m", "n", "o", "p", "q"]
        field_types = [DataTypes.BIGINT(), DataTypes.DOUBLE(), DataTypes.STRING(),
                       DataTypes.STRING(), DataTypes.DATE(),
                       DataTypes.TIME(),
                       DataTypes.TIMESTAMP(3),
                       DataTypes.INTERVAL(DataTypes.SECOND(3)),
                       DataTypes.ARRAY(DataTypes.DOUBLE()),
                       DataTypes.ARRAY(DataTypes.DOUBLE(False)),
                       DataTypes.ARRAY(DataTypes.STRING()),
                       DataTypes.ARRAY(DataTypes.DATE()),
                       DataTypes.DECIMAL(38, 18),
                       DataTypes.ROW([DataTypes.FIELD("a", DataTypes.BIGINT()),
                                      DataTypes.FIELD("b", DataTypes.DOUBLE())]),
                       DataTypes.MAP(DataTypes.STRING(), DataTypes.DOUBLE()),
                       DataTypes.BYTES(),
                       PythonOnlyUDT()]
        schema = self._create_schema(field_names, field_types)
        table_sink = source_sink_utils.TestAppendSink(field_names, field_types)
        t_env.register_table_sink("Results", table_sink)
        t = t_env.from_elements(
            [(1, 1.0, "hi", "hello", datetime.date(1970, 1, 2), datetime.time(1, 0, 0),
              datetime.datetime(1970, 1, 2, 0, 0),
              datetime.timedelta(days=1, microseconds=10),
              [1.0, None], array.array("d", [1.0, 2.0]),
              ["abc"], [datetime.date(1970, 1, 2)], Decimal(1), Row("a", "b")(1, 2.0),
              {"key": 1.0}, bytearray(b'ABCD'),
              PythonOnlyPoint(3.0, 4.0))],
            schema)
        t.execute_insert("Results").wait()
        actual = source_sink_utils.results()

        # Note the blink planner keeps the DECIMAL(38, 18) scale: 1.000000000000000000.
        expected = ['1,1.0,hi,hello,1970-01-02,01:00:00,1970-01-02 00:00:00.0,'
                    '86400000,[1.0, null],[1.0, 2.0],[abc],[1970-01-02],'
                    '1.000000000000000000,1,2.0,{key=1.0},[65, 66, 67, 68],[3.0, 4.0]']
        self.assert_equals(actual, expected)
| |
| |
if __name__ == '__main__':
    import unittest

    # Prefer the XML test runner (CI report output) when available; fall
    # back to the default text runner otherwise.
    try:
        import xmlrunner
        runner = xmlrunner.XMLTestRunner(output='target/test-reports')
    except ImportError:
        runner = None
    unittest.main(testRunner=runner, verbosity=2)