| ################################################################################ |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| ################################################################################ |
| import datetime |
| import decimal |
| |
| from pandas.util.testing import assert_frame_equal |
| |
| from pyflink.common import Row |
| from pyflink.table.types import DataTypes |
| from pyflink.testing import source_sink_utils |
| from pyflink.testing.test_case_utils import PyFlinkBlinkBatchTableTestCase, \ |
| PyFlinkBlinkStreamTableTestCase, PyFlinkOldStreamTableTestCase |
| |
| |
class PandasConversionTestBase(object):
    """Shared fixtures for the pandas <-> Flink table conversion tests.

    Exposes two sample rows covering every supported field type
    (``cls.data``), the matching row schema (``cls.data_type``) and a
    pandas DataFrame built from them (``cls.pdf``).
    """

    @classmethod
    def setUpClass(cls):
        super(PandasConversionTestBase, cls).setUpClass()
        # Two sample rows; tuple position i corresponds to field f{i+1} below.
        cls.data = [
            (1, 1, 1, 1, True, 1.1, 1.2, 'hello', bytearray(b"aaa"),
             decimal.Decimal('1000000000000000000.01'), datetime.date(2014, 9, 13),
             datetime.time(hour=1, minute=0, second=1),
             datetime.datetime(1970, 1, 1, 0, 0, 0, 123000), ['hello', '中文'],
             Row(a=1, b='hello', c=datetime.datetime(1970, 1, 1, 0, 0, 0, 123000),
                 d=[1, 2])),
            (1, 2, 2, 2, False, 2.1, 2.2, 'world', bytearray(b"bbb"),
             decimal.Decimal('1000000000000000000.02'), datetime.date(2014, 9, 13),
             datetime.time(hour=1, minute=0, second=1),
             datetime.datetime(1970, 1, 1, 0, 0, 0, 123000), ['hello', '中文'],
             Row(a=1, b='hello', c=datetime.datetime(1970, 1, 1, 0, 0, 0, 123000),
                 d=[1, 2]))]
        # Row schema matching the tuples above, one FIELD per column.
        cls.data_type = DataTypes.ROW(
            [DataTypes.FIELD("f1", DataTypes.TINYINT()),
             DataTypes.FIELD("f2", DataTypes.SMALLINT()),
             DataTypes.FIELD("f3", DataTypes.INT()),
             DataTypes.FIELD("f4", DataTypes.BIGINT()),
             DataTypes.FIELD("f5", DataTypes.BOOLEAN()),
             DataTypes.FIELD("f6", DataTypes.FLOAT()),
             DataTypes.FIELD("f7", DataTypes.DOUBLE()),
             DataTypes.FIELD("f8", DataTypes.STRING()),
             DataTypes.FIELD("f9", DataTypes.BYTES()),
             DataTypes.FIELD("f10", DataTypes.DECIMAL(38, 18)),
             DataTypes.FIELD("f11", DataTypes.DATE()),
             DataTypes.FIELD("f12", DataTypes.TIME()),
             DataTypes.FIELD("f13", DataTypes.TIMESTAMP(3)),
             DataTypes.FIELD("f14", DataTypes.ARRAY(DataTypes.STRING())),
             DataTypes.FIELD("f15", DataTypes.ROW(
                 [DataTypes.FIELD("a", DataTypes.INT()),
                  DataTypes.FIELD("b", DataTypes.STRING()),
                  DataTypes.FIELD("c", DataTypes.TIMESTAMP(3)),
                  DataTypes.FIELD("d", DataTypes.ARRAY(DataTypes.INT()))]))], False)
        cls.pdf = cls.create_pandas_data_frame()

    @classmethod
    def create_pandas_data_frame(cls):
        """Return a pandas DataFrame holding ``cls.data``.

        Numeric columns are coerced to the corresponding numpy dtypes and
        nested Rows are converted to plain dicts.
        """
        import numpy as np
        import pandas as pd
        field_names = cls.data_type.names
        # Transpose the row tuples into one list per column, keyed by name.
        data_dict = {name: [row[col] for row in cls.data]
                     for col, name in enumerate(field_names)}
        # Coerce the numeric columns to concrete numpy dtypes.
        for name, np_type in (("f1", np.int8), ("f2", np.int16),
                              ("f3", np.int32), ("f4", np.int64),
                              ("f6", np.float32), ("f7", np.float64)):
            data_dict[name] = np_type(data_dict[name])
        # Nested Row values are represented as plain dicts on the pandas side.
        data_dict["f15"] = [row.as_dict() for row in data_dict["f15"]]
        # Non-default (float) index checks that conversion ignores the index.
        return pd.DataFrame(data=data_dict,
                            index=[2., 3.],
                            columns=field_names)
| |
| |
class PandasConversionTests(PandasConversionTestBase):
    """Schema-handling tests for ``TableEnvironment.from_pandas``."""

    def test_from_pandas_with_incorrect_schema(self):
        # Swap the TINYINT field (f1) with the STRING field (f8) so the
        # declared schema no longer matches the DataFrame's dtypes.
        fields = self.data_type.fields.copy()
        fields[0], fields[7] = fields[7], fields[0]  # swap str with tinyint
        wrong_schema = DataTypes.ROW(fields)  # should be DataTypes.STRING()
        with self.assertRaisesRegex(Exception, "Expected a string.*got int8"):
            self.t_env.from_pandas(self.pdf, schema=wrong_schema)

    def test_from_pandas_with_names(self):
        # Drop f10-f15: decimal support is currently limited to decimal(38, 18).
        pdf = self.pdf.drop(columns=['f10', 'f11', 'f12', 'f13', 'f14', 'f15'])
        new_names = [str(i) for i in range(len(pdf.columns))]
        # Both list and tuple schemas must be accepted as column names.
        for schema in (new_names, tuple(new_names)):
            table = self.t_env.from_pandas(pdf, schema=schema)
            self.assertEqual(new_names, table.get_schema().get_field_names())

    def test_from_pandas_with_types(self):
        new_types = self.data_type.field_types()
        new_types[0] = DataTypes.BIGINT()
        # Both list and tuple schemas must be accepted as field types.
        for schema in (new_types, tuple(new_types)):
            table = self.t_env.from_pandas(self.pdf, schema=schema)
            self.assertEqual(new_types, table.get_schema().get_field_data_types())
| |
| |
class PandasConversionITTests(PandasConversionTestBase):
    """Round-trip integration tests between pandas DataFrames and Flink tables."""

    def test_from_pandas(self):
        # The explicit third argument (5) exercises splitting the input rows.
        table = self.t_env.from_pandas(self.pdf, self.data_type, 5)
        self.assertEqual(self.data_type, table.get_schema().to_row_data_type())

        table = table.filter(table.f2 < 2)
        table_sink = source_sink_utils.TestAppendSink(
            self.data_type.field_names(),
            self.data_type.field_types())
        self.t_env.register_table_sink("Results", table_sink)
        table.execute_insert("Results").wait()
        actual = source_sink_utils.results()
        self.assert_equals(actual,
                           ["+I[1, 1, 1, 1, true, 1.1, 1.2, hello, [97, 97, 97], "
                            "1000000000000000000.010000000000000000, 2014-09-13, 01:00:01, "
                            "1970-01-01 00:00:00.123, [hello, 中文], +I[1, hello, "
                            "1970-01-01 00:00:00.123, [1, 2]]]"])

    def test_to_pandas(self):
        table = self.t_env.from_pandas(self.pdf, self.data_type)
        result_pdf = table.to_pandas()
        # Align the index before comparing; to_pandas produces a fresh index.
        result_pdf.index = self.pdf.index
        self.assertEqual(2, len(result_pdf))
        expected_arrow = self.pdf.to_records(index=False)
        result_arrow = result_pdf.to_records(index=False)
        # Compare field by field so nested dicts/arrays get special handling.
        for r in range(len(expected_arrow)):
            for e in range(len(expected_arrow[r])):
                self.assert_equal_field(expected_arrow[r][e], result_arrow[r][e])

    def test_empty_to_pandas(self):
        table = self.t_env.from_pandas(self.pdf, self.data_type)
        # f1 is always positive in the fixture, so the filter drops every row.
        pdf = table.filter(table.f1 < 0).to_pandas()
        self.assertTrue(pdf.empty)

    def test_to_pandas_for_retract_table(self):
        table = self.t_env.from_pandas(self.pdf, self.data_type)
        result_pdf = table.group_by(table.f1).select(table.f2.max.alias('f2')).to_pandas()
        import pandas as pd
        import numpy as np
        assert_frame_equal(result_pdf, pd.DataFrame(data={'f2': np.int16([2])}))

        # Use the expression DSL here as well (previously a deprecated string
        # expression) for consistency with the rest of this file.
        result_pdf = table.group_by(table.f2).select(table.f1.max.alias('f2')).to_pandas()
        assert_frame_equal(result_pdf, pd.DataFrame(data={'f2': np.int8([1, 1])}))

    def assert_equal_field(self, expected_field, result_field):
        """Recursively assert equality of one field value.

        Dicts (converted nested Rows) are compared key by key and numpy
        arrays element-wise; everything else uses plain ``==``.
        """
        import numpy as np
        result_type = type(result_field)
        if result_type == dict:
            self.assertEqual(expected_field.keys(), result_field.keys())
            for key in expected_field:
                self.assert_equal_field(expected_field[key], result_field[key])
        elif result_type == np.ndarray:
            self.assertTrue((expected_field == result_field).all())
        else:
            self.assertTrue(expected_field == result_field)
| |
| |
class StreamPandasConversionTests(PandasConversionITTests,
                                  PyFlinkOldStreamTableTestCase):
    """Runs the conversion IT tests against the old-planner stream environment."""
    pass
| |
| |
class BlinkBatchPandasConversionTests(PandasConversionTests,
                                      PandasConversionITTests,
                                      PyFlinkBlinkBatchTableTestCase):
    """Runs both the schema tests and the IT tests on the blink batch planner."""
    pass
| |
| |
class BlinkStreamPandasConversionTests(PandasConversionITTests,
                                       PyFlinkBlinkStreamTableTestCase):
    """Runs the conversion IT tests on the blink stream planner, plus an
    event-time specific case."""

    def test_to_pandas_with_event_time(self):
        """to_pandas should surface rowtime values from an event-time source."""
        # Single parallelism keeps the row order deterministic.
        self.t_env.get_config().get_configuration().set_string("parallelism.default", "1")
        # Write the source rows to a temp CSV file.
        import tempfile
        import os
        data = [
            '2018-03-11 03:10:00',
            '2018-03-11 03:10:00',
            '2018-03-11 03:10:00',
            '2018-03-11 03:40:00',
            '2018-03-11 04:20:00',
            '2018-03-11 03:30:00'
        ]
        source_path = os.path.join(tempfile.gettempdir(),
                                   'test_to_pandas_with_event_time.csv')
        with open(source_path, 'w') as fd:
            fd.write('\n'.join(data) + '\n')

        try:
            self.t_env.get_config().get_configuration().set_string(
                "pipeline.time-characteristic", "EventTime")

            source_table = """
                create table source_table(
                    rowtime TIMESTAMP(3),
                    WATERMARK FOR rowtime AS rowtime - INTERVAL '60' MINUTE
                ) with(
                    'connector.type' = 'filesystem',
                    'format.type' = 'csv',
                    'connector.path' = '%s',
                    'format.ignore-first-line' = 'false',
                    'format.field-delimiter' = ','
                )
            """ % source_path
            self.t_env.execute_sql(source_table)
            t = self.t_env.from_path("source_table")
            result_pdf = t.to_pandas()
        finally:
            # Remove the temp file even when the job fails, so a stale file
            # cannot poison a rerun of this test.
            os.remove(source_path)
        import pandas as pd
        assert_frame_equal(result_pdf, pd.DataFrame(
            data={"rowtime": [
                datetime.datetime(2018, 3, 11, 3, 10),
                datetime.datetime(2018, 3, 11, 3, 10),
                datetime.datetime(2018, 3, 11, 3, 10),
                datetime.datetime(2018, 3, 11, 3, 40),
                datetime.datetime(2018, 3, 11, 4, 20),
                datetime.datetime(2018, 3, 11, 3, 30),
            ]}))