################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################

import calendar
import datetime
import glob
import os
import tempfile
import time
import unittest
from decimal import Decimal
from typing import List, Tuple

import pandas as pd
import pytz

from pyflink.common.time import Instant
from pyflink.common import Configuration, Row
from pyflink.common.typeinfo import RowTypeInfo, Types
from pyflink.common.watermark_strategy import WatermarkStrategy
from pyflink.datastream.connectors.file_system import FileSource, FileSink
from pyflink.datastream.formats.tests.test_avro import \
_create_basic_avro_schema_and_py_objects, _check_basic_avro_schema_results, \
_create_enum_avro_schema_and_py_objects, _check_enum_avro_schema_results, \
_create_union_avro_schema_and_py_objects, _check_union_avro_schema_results, \
_create_array_avro_schema_and_py_objects, _check_array_avro_schema_results, \
_create_map_avro_schema_and_py_objects, _check_map_avro_schema_results, \
_create_map_avro_schema_and_records, _create_array_avro_schema_and_records, \
_create_union_avro_schema_and_records, _create_enum_avro_schema_and_records, \
_create_basic_avro_schema_and_records, _import_avro_classes
from pyflink.datastream.formats.avro import GenericRecordAvroTypeInfo, AvroSchema
from pyflink.datastream.formats.parquet import AvroParquetReaders, ParquetColumnarRowInputFormat, \
AvroParquetWriters, ParquetBulkWriters
from pyflink.datastream.tests.test_util import DataStreamTestSinkFunction
from pyflink.datastream.utils import create_hadoop_configuration
from pyflink.java_gateway import get_gateway
from pyflink.table.types import RowType, DataTypes, _to_java_data_type
from pyflink.testing.test_case_utils import PyFlinkStreamingTestCase, to_java_data_structure


@unittest.skipIf(os.environ.get('HADOOP_CLASSPATH') is None,
                 'Hadoop libraries are needed for the Parquet-Avro format tests')
class FileSourceAvroParquetReadersTests(PyFlinkStreamingTestCase):

    def setUp(self):
super().setUp()
self.test_sink = DataStreamTestSinkFunction()
_import_avro_classes()

    def test_parquet_avro_basic(self):
parquet_file_name = tempfile.mktemp(suffix='.parquet', dir=self.tempdir)
schema, records = _create_basic_avro_schema_and_records()
self._create_parquet_avro_file(parquet_file_name, schema, records)
self._build_parquet_avro_job(schema, parquet_file_name)
self.env.execute("test_parquet_avro_basic")
results = self.test_sink.get_results(True, False)
_check_basic_avro_schema_results(self, results)

    def test_parquet_avro_enum(self):
parquet_file_name = tempfile.mktemp(suffix='.parquet', dir=self.tempdir)
schema, records = _create_enum_avro_schema_and_records()
self._create_parquet_avro_file(parquet_file_name, schema, records)
self._build_parquet_avro_job(schema, parquet_file_name)
self.env.execute("test_parquet_avro_enum")
results = self.test_sink.get_results(True, False)
_check_enum_avro_schema_results(self, results)

    def test_parquet_avro_union(self):
parquet_file_name = tempfile.mktemp(suffix='.parquet', dir=self.tempdir)
schema, records = _create_union_avro_schema_and_records()
self._create_parquet_avro_file(parquet_file_name, schema, records)
self._build_parquet_avro_job(schema, parquet_file_name)
self.env.execute("test_parquet_avro_union")
results = self.test_sink.get_results(True, False)
_check_union_avro_schema_results(self, results)

    def test_parquet_avro_array(self):
parquet_file_name = tempfile.mktemp(suffix='.parquet', dir=self.tempdir)
schema, records = _create_array_avro_schema_and_records()
self._create_parquet_avro_file(parquet_file_name, schema, records)
self._build_parquet_avro_job(schema, parquet_file_name)
self.env.execute("test_parquet_avro_array")
results = self.test_sink.get_results(True, False)
_check_array_avro_schema_results(self, results)

    def test_parquet_avro_map(self):
parquet_file_name = tempfile.mktemp(suffix='.parquet', dir=self.tempdir)
schema, records = _create_map_avro_schema_and_records()
self._create_parquet_avro_file(parquet_file_name, schema, records)
self._build_parquet_avro_job(schema, parquet_file_name)
self.env.execute("test_parquet_avro_map")
results = self.test_sink.get_results(True, False)
_check_map_avro_schema_results(self, results)

    def _build_parquet_avro_job(self, record_schema, *parquet_file_name):
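        # Read generic Avro records from the Parquet file(s) and collect them in test_sink.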
ds = self.env.from_source(
FileSource.for_record_stream_format(
AvroParquetReaders.for_generic_record(record_schema),
*parquet_file_name
).build(),
WatermarkStrategy.for_monotonous_timestamps(),
"parquet-source"
)
ds.map(lambda e: e).add_sink(self.test_sink)

    @staticmethod
def _create_parquet_avro_file(file_path: str, schema: AvroSchema, records: list):
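        # Write the Avro records to a Parquet file directly through the Java AvroParquetWriters.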
jvm = get_gateway().jvm
j_path = jvm.org.apache.flink.core.fs.Path(file_path)
writer = jvm.org.apache.flink.formats.parquet.avro.AvroParquetWriters \
.forGenericRecord(schema._j_schema) \
.create(j_path.getFileSystem().create(
j_path,
jvm.org.apache.flink.core.fs.FileSystem.WriteMode.OVERWRITE
))
for record in records:
writer.addElement(record)
writer.flush()
writer.finish()


@unittest.skipIf(os.environ.get('HADOOP_CLASSPATH') is None,
                 'Hadoop libraries are needed for the Parquet-Avro format tests')
class FileSinkAvroParquetWritersTests(PyFlinkStreamingTestCase):

    def setUp(self):
super().setUp()
# NOTE: parallelism == 1 is required to keep the order of results
self.env.set_parallelism(1)
self.parquet_dir_name = tempfile.mkdtemp(dir=self.tempdir)
self.test_sink = DataStreamTestSinkFunction()

    def test_parquet_avro_basic_write(self):
schema, objects = _create_basic_avro_schema_and_py_objects()
self._build_avro_parquet_job(schema, objects)
self.env.execute('test_parquet_avro_basic_write')
results = self._read_parquet_avro_file(schema)
_check_basic_avro_schema_results(self, results)

    def test_parquet_avro_enum_write(self):
schema, objects = _create_enum_avro_schema_and_py_objects()
self._build_avro_parquet_job(schema, objects)
self.env.execute('test_parquet_avro_enum_write')
results = self._read_parquet_avro_file(schema)
_check_enum_avro_schema_results(self, results)

    def test_parquet_avro_union_write(self):
schema, objects = _create_union_avro_schema_and_py_objects()
self._build_avro_parquet_job(schema, objects)
self.env.execute('test_parquet_avro_union_write')
results = self._read_parquet_avro_file(schema)
_check_union_avro_schema_results(self, results)

    def test_parquet_avro_array_write(self):
schema, objects = _create_array_avro_schema_and_py_objects()
self._build_avro_parquet_job(schema, objects)
self.env.execute('test_parquet_avro_array_write')
results = self._read_parquet_avro_file(schema)
_check_array_avro_schema_results(self, results)

    def test_parquet_avro_map_write(self):
schema, objects = _create_map_avro_schema_and_py_objects()
self._build_avro_parquet_job(schema, objects)
self.env.execute('test_parquet_avro_map_write')
results = self._read_parquet_avro_file(schema)
_check_map_avro_schema_results(self, results)

    def _build_avro_parquet_job(self, schema, objects):
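        # Pass the objects through an identity map typed with GenericRecordAvroTypeInfo so they
        # are treated as Avro generic records, then bulk-write them to parquet_dir_name.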
ds = self.env.from_collection(objects)
avro_type_info = GenericRecordAvroTypeInfo(schema)
sink = FileSink.for_bulk_format(
self.parquet_dir_name, AvroParquetWriters.for_generic_record(schema)
).build()
ds.map(lambda e: e, output_type=avro_type_info).sink_to(sink)

    def _read_parquet_avro_file(self, schema) -> List[dict]:
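        # Read back what the sink job wrote, reusing the reader job from
        # FileSourceAvroParquetReadersTests.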
parquet_files = [f for f in glob.glob(self.parquet_dir_name, recursive=True)]
FileSourceAvroParquetReadersTests._build_parquet_avro_job(self, schema, *parquet_files)
self.env.execute()
return self.test_sink.get_results(True, False)


@unittest.skipIf(os.environ.get('HADOOP_CLASSPATH') is None,
                 'Hadoop libraries are needed for the Parquet Columnar format tests')
class FileSourceParquetColumnarRowInputFormatTests(PyFlinkStreamingTestCase):

    def setUp(self):
super().setUp()
self.test_sink = DataStreamTestSinkFunction()
self.parquet_file_name = tempfile.mktemp(suffix='.parquet', dir=self.tempdir)

    def test_parquet_columnar_basic_read(self):
os.environ['TZ'] = 'Asia/Shanghai'
time.tzset()
row_type, _, data = _create_parquet_basic_row_and_data()
_write_row_data_to_parquet_file(self.parquet_file_name, row_type, data)
self._build_parquet_columnar_job(row_type)
self.env.execute('test_parquet_columnar_basic_read')
results = self.test_sink.get_results(True, False)
_check_parquet_basic_results(self, results)

    def _build_parquet_columnar_job(self, row_type: RowType):
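        # Bounded file source over the Parquet file using the columnar (vectorized) row format.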
source = FileSource.for_bulk_file_format(
ParquetColumnarRowInputFormat(row_type, Configuration(), 10, True, False),
self.parquet_file_name
).build()
ds = self.env.from_source(source, WatermarkStrategy.no_watermarks(), 'parquet-source')
ds.map(lambda e: e).add_sink(self.test_sink)


@unittest.skipIf(os.environ.get('HADOOP_CLASSPATH') is None,
                 'Hadoop libraries are needed for the Parquet RowData format tests')
class FileSinkParquetBulkWriterTests(PyFlinkStreamingTestCase):

    def setUp(self):
super().setUp()
# NOTE: parallelism == 1 is required to keep the order of results
self.env.set_parallelism(1)
self.parquet_dir_name = tempfile.mkdtemp(dir=self.tempdir)

    def test_parquet_row_data_basic_write(self):
os.environ['TZ'] = 'Asia/Shanghai'
time.tzset()
row_type, row_type_info, data = _create_parquet_basic_row_and_data()
self._build_parquet_job(row_type, row_type_info, data)
self.env.execute('test_parquet_row_data_basic_write')
results = self._read_parquet_file()
_check_parquet_basic_results(self, results)

    def test_parquet_row_data_array_write(self):
row_type, row_type_info, data = _create_parquet_array_row_and_data()
self._build_parquet_job(row_type, row_type_info, data)
self.env.execute('test_parquet_row_data_array_write')
results = self._read_parquet_file()
_check_parquet_array_results(self, results)

    @unittest.skip('ParquetSchemaConverter in flink-parquet annotates map keys as optional, '
                   'but Arrow requires map keys to be required')
def test_parquet_row_data_map_write(self):
row_type, row_type_info, data = _create_parquet_map_row_and_data()
self._build_parquet_job(row_type, row_type_info, data)
self.env.execute('test_parquet_row_data_map_write')
results = self._read_parquet_file()
_check_parquet_map_results(self, results)

    def _build_parquet_job(self, row_type: RowType, row_type_info: RowTypeInfo, data: List[Row]):
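        # Bulk-write the Rows to Parquet files under parquet_dir_name using ParquetBulkWriters.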
sink = FileSink.for_bulk_format(
self.parquet_dir_name, ParquetBulkWriters.for_row_type(row_type, utc_timestamp=True)
).build()
ds = self.env.from_collection(data, type_info=row_type_info)
ds.sink_to(sink)

    def _read_parquet_file(self):
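        # Collect every record from the written Parquet files with pandas.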
records = []
        for file in glob.glob(os.path.join(self.parquet_dir_name, '**/*')):
df = pd.read_parquet(file)
for i in range(df.shape[0]):
records.append(df.loc[i])
return records


def _write_row_data_to_parquet_file(path: str, row_type: RowType, rows: List[Row]):
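    # Write the rows to a Parquet file through the Java ParquetRowDataBuilder, converting each
    # Python Row to an internal RowData with a RowRowConverter.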
jvm = get_gateway().jvm
flink = jvm.org.apache.flink
j_output_stream = flink.core.fs.local.LocalDataOutputStream(jvm.java.io.File(path))
j_bulk_writer = flink.formats.parquet.row.ParquetRowDataBuilder.createWriterFactory(
_to_java_data_type(row_type).getLogicalType(),
create_hadoop_configuration(Configuration()),
True,
).create(j_output_stream)
row_row_converter = flink.table.data.conversion.RowRowConverter.create(
_to_java_data_type(row_type)
)
row_row_converter.open(row_row_converter.getClass().getClassLoader())
for row in rows:
j_bulk_writer.addElement(row_row_converter.toInternal(to_java_data_structure(row)))
j_bulk_writer.finish()


def _create_parquet_basic_row_and_data() -> Tuple[RowType, RowTypeInfo, List[Row]]:
row_type = DataTypes.ROW([
DataTypes.FIELD('char', DataTypes.CHAR(10)),
DataTypes.FIELD('varchar', DataTypes.VARCHAR(10)),
DataTypes.FIELD('binary', DataTypes.BINARY(10)),
DataTypes.FIELD('varbinary', DataTypes.VARBINARY(10)),
DataTypes.FIELD('boolean', DataTypes.BOOLEAN()),
DataTypes.FIELD('decimal', DataTypes.DECIMAL(2, 0)),
DataTypes.FIELD('int', DataTypes.INT()),
DataTypes.FIELD('bigint', DataTypes.BIGINT()),
DataTypes.FIELD('double', DataTypes.DOUBLE()),
DataTypes.FIELD('date', DataTypes.DATE().bridged_to('java.sql.Date')),
DataTypes.FIELD('time', DataTypes.TIME().bridged_to('java.sql.Time')),
DataTypes.FIELD('timestamp', DataTypes.TIMESTAMP(3).bridged_to('java.sql.Timestamp')),
DataTypes.FIELD('timestamp_ltz', DataTypes.TIMESTAMP_LTZ(3)),
])
row_type_info = Types.ROW_NAMED(
['char', 'varchar', 'binary', 'varbinary', 'boolean', 'decimal', 'int', 'bigint', 'double',
'date', 'time', 'timestamp', 'timestamp_ltz'],
[Types.STRING(), Types.STRING(), Types.PRIMITIVE_ARRAY(Types.BYTE()),
Types.PRIMITIVE_ARRAY(Types.BYTE()), Types.BOOLEAN(), Types.BIG_DEC(), Types.INT(),
Types.LONG(), Types.DOUBLE(), Types.SQL_DATE(), Types.SQL_TIME(), Types.SQL_TIMESTAMP(),
Types.INSTANT()]
)
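    # TIMESTAMP_LTZ value: the UTC instant shifted by the local UTC offset at the epoch
    # (the tests that use this data set TZ to Asia/Shanghai).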
datetime_ltz = datetime.datetime(1970, 2, 3, 4, 5, 6, 700000, tzinfo=pytz.timezone('UTC'))
timestamp_ltz = Instant.of_epoch_milli(
(
calendar.timegm(datetime_ltz.utctimetuple()) +
calendar.timegm(time.localtime(0))
) * 1000 + datetime_ltz.microsecond // 1000
)
data = [Row(
char='char',
varchar='varchar',
binary=b'binary',
varbinary=b'varbinary',
boolean=True,
decimal=Decimal(1.5),
int=2147483647,
bigint=-9223372036854775808,
double=2e-308,
date=datetime.date(1970, 1, 1),
time=datetime.time(1, 1, 1),
timestamp=datetime.datetime(1970, 1, 2, 3, 4, 5, 600000),
timestamp_ltz=timestamp_ltz
)]
return row_type, row_type_info, data


def _check_parquet_basic_results(test, results):
row = results[0]
test.assertEqual(row['char'], 'char')
test.assertEqual(row['varchar'], 'varchar')
test.assertEqual(row['binary'], b'binary')
test.assertEqual(row['varbinary'], b'varbinary')
test.assertEqual(row['boolean'], True)
test.assertAlmostEqual(row['decimal'], 2)
test.assertEqual(row['int'], 2147483647)
test.assertEqual(row['bigint'], -9223372036854775808)
test.assertAlmostEqual(row['double'], 2e-308, delta=1e-311)
test.assertEqual(row['date'], datetime.date(1970, 1, 1))
test.assertEqual(row['time'], datetime.time(1, 1, 1))
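    # Results read back via pandas arrive as pd.Timestamp; normalize them before comparing.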
ts = row['timestamp']
if isinstance(ts, pd.Timestamp):
ts = ts.to_pydatetime()
test.assertEqual(ts, datetime.datetime(1970, 1, 2, 3, 4, 5, 600000))
ts_ltz = row['timestamp_ltz']
if isinstance(ts_ltz, pd.Timestamp):
ts_ltz = pytz.timezone('Asia/Shanghai').localize(ts_ltz.to_pydatetime())
test.assertEqual(
ts_ltz,
pytz.timezone('Asia/Shanghai').localize(datetime.datetime(1970, 2, 3, 12, 5, 6, 700000))
)


def _create_parquet_array_row_and_data() -> Tuple[RowType, RowTypeInfo, List[Row]]:
row_type = DataTypes.ROW([
DataTypes.FIELD(
'string_array',
DataTypes.ARRAY(DataTypes.STRING()).bridged_to('java.util.ArrayList')
),
DataTypes.FIELD(
'int_array',
DataTypes.ARRAY(DataTypes.INT()).bridged_to('java.util.ArrayList')
),
])
row_type_info = Types.ROW_NAMED([
'string_array',
'int_array',
], [
Types.LIST(Types.STRING()),
Types.LIST(Types.INT()),
])
data = [Row(
string_array=['a', 'b', 'c'],
int_array=[1, 2, 3],
)]
return row_type, row_type_info, data


def _check_parquet_array_results(test, results):
row = results[0]
test.assertEqual(row['string_array'][0], 'a')
test.assertEqual(row['string_array'][1], 'b')
test.assertEqual(row['string_array'][2], 'c')
test.assertEqual(row['int_array'][0], 1)
test.assertEqual(row['int_array'][1], 2)
test.assertEqual(row['int_array'][2], 3)


def _create_parquet_map_row_and_data() -> Tuple[RowType, RowTypeInfo, List[Row]]:
row_type = DataTypes.ROW([
DataTypes.FIELD('map', DataTypes.MAP(DataTypes.INT(), DataTypes.STRING())),
])
row_type_info = Types.ROW_NAMED(['map'], [Types.MAP(Types.INT(), Types.STRING())])
data = [Row(
map={0: 'a', 1: 'b', 2: 'c'}
)]
return row_type, row_type_info, data


def _check_parquet_map_results(test, results):
m = {k: v for k, v in results[0]['map']}
test.assertEqual(m[0], 'a')
test.assertEqual(m[1], 'b')
test.assertEqual(m[2], 'c')