################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import glob
import os
import tempfile
from typing import Tuple, List
from avro.datafile import DataFileReader
from avro.io import DatumReader
from py4j.java_gateway import JavaObject, java_import
from pyflink.datastream import MapFunction
from pyflink.datastream.connectors.file_system import FileSink
from pyflink.datastream.formats.avro import AvroSchema, GenericRecordAvroTypeInfo, \
AvroBulkWriters, AvroInputFormat
from pyflink.datastream.tests.test_util import DataStreamTestSinkFunction
from pyflink.java_gateway import get_gateway
from pyflink.testing.test_case_utils import PyFlinkUTTestCase


class FileSourceAvroInputFormatTests(PyFlinkUTTestCase):
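    """Tests reading Avro files into a DataStream via AvroInputFormat."""
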
def setUp(self):
super().setUp()
self.test_sink = DataStreamTestSinkFunction()
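        # mktemp only generates a unique path; the file itself is created on the JVM
        # side in _create_avro_file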
self.avro_file_name = tempfile.mktemp(suffix='.avro', dir=self.tempdir)
_import_avro_classes()

    def test_avro_basic_read(self):
schema, records = _create_basic_avro_schema_and_records()
self._create_avro_file(schema, records)
self._build_avro_job(schema)
self.env.execute('test_avro_basic_read')
results = self.test_sink.get_results(True, False)
_check_basic_avro_schema_results(self, results)

    def test_avro_enum_read(self):
schema, records = _create_enum_avro_schema_and_records()
self._create_avro_file(schema, records)
self._build_avro_job(schema)
self.env.execute('test_avro_enum_read')
results = self.test_sink.get_results(True, False)
_check_enum_avro_schema_results(self, results)

    def test_avro_union_read(self):
schema, records = _create_union_avro_schema_and_records()
self._create_avro_file(schema, records)
self._build_avro_job(schema)
self.env.execute('test_avro_union_read')
results = self.test_sink.get_results(True, False)
_check_union_avro_schema_results(self, results)

    def test_avro_array_read(self):
schema, records = _create_array_avro_schema_and_records()
self._create_avro_file(schema, records)
self._build_avro_job(schema)
self.env.execute('test_avro_array_read')
results = self.test_sink.get_results(True, False)
_check_array_avro_schema_results(self, results)

    def test_avro_map_read(self):
schema, records = _create_map_avro_schema_and_records()
self._create_avro_file(schema, records)
self._build_avro_job(schema)
self.env.execute('test_avro_map_read')
results = self.test_sink.get_results(True, False)
_check_map_avro_schema_results(self, results)

    def _build_avro_job(self, record_schema):
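        """Reads the Avro file with AvroInputFormat and collects the records in the test sink."""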
ds = self.env.create_input(AvroInputFormat(self.avro_file_name, record_schema))
ds.map(PassThroughMapFunction()).add_sink(self.test_sink)

    def _create_avro_file(self, schema: AvroSchema, records: list):
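        """Writes the Java GenericRecords to an Avro file using Flink's shaded Avro classes."""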
jvm = get_gateway().jvm
j_file = jvm.java.io.File(self.avro_file_name)
j_datum_writer = jvm.org.apache.flink.avro.shaded.org.apache.avro.generic \
.GenericDatumWriter()
j_file_writer = jvm.org.apache.flink.avro.shaded.org.apache.avro.file \
.DataFileWriter(j_datum_writer)
j_file_writer.create(schema._j_schema, j_file)
for r in records:
j_file_writer.append(r)
j_file_writer.close()


class FileSinkAvroWritersTests(PyFlinkUTTestCase):
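    """Tests writing Avro files with FileSink and AvroBulkWriters."""
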
def setUp(self) -> None:
super().setUp()
# NOTE: parallelism == 1 is required to keep the order of results
self.env.set_parallelism(1)
self.avro_dir_name = tempfile.mkdtemp(dir=self.tempdir)

    def test_avro_basic_write(self):
schema, objects = _create_basic_avro_schema_and_py_objects()
self._build_avro_job(schema, objects)
self.env.execute('test_avro_basic_write')
results = self._read_avro_file()
_check_basic_avro_schema_results(self, results)

    def test_avro_enum_write(self):
schema, objects = _create_enum_avro_schema_and_py_objects()
self._build_avro_job(schema, objects)
self.env.execute('test_avro_enum_write')
results = self._read_avro_file()
_check_enum_avro_schema_results(self, results)

    def test_avro_union_write(self):
schema, objects = _create_union_avro_schema_and_py_objects()
self._build_avro_job(schema, objects)
self.env.execute('test_avro_union_write')
results = self._read_avro_file()
_check_union_avro_schema_results(self, results)

    def test_avro_array_write(self):
schema, objects = _create_array_avro_schema_and_py_objects()
self._build_avro_job(schema, objects)
self.env.execute('test_avro_array_write')
results = self._read_avro_file()
_check_array_avro_schema_results(self, results)

    def test_avro_map_write(self):
schema, objects = _create_map_avro_schema_and_py_objects()
self._build_avro_job(schema, objects)
self.env.execute('test_avro_map_write')
results = self._read_avro_file()
_check_map_avro_schema_results(self, results)

    def _build_avro_job(self, schema, objects):
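        """Writes the Python dicts as Avro generic records through a bulk-format FileSink."""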
ds = self.env.from_collection(objects)
sink = FileSink.for_bulk_format(
self.avro_dir_name, AvroBulkWriters.for_generic_record(schema)
).build()
ds.map(lambda e: e, output_type=GenericRecordAvroTypeInfo(schema)).sink_to(sink)

    def _read_avro_file(self) -> List[dict]:
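        """Reads back all records from the files the FileSink wrote into its bucket subdirectories."""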
records = []
        for file in glob.glob(os.path.join(self.avro_dir_name, '**/*')):
            with DataFileReader(open(file, 'rb'), DatumReader()) as reader:
                for record in reader:
                    records.append(record)
return records


class PassThroughMapFunction(MapFunction):
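    """Identity map; forwards each record unchanged."""
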
def map(self, value):
return value


def _import_avro_classes():
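    """Imports GenericData from Flink's shaded Avro so tests build records with Flink's Avro classes."""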
jvm = get_gateway().jvm
classes = ['org.apache.avro.generic.GenericData']
prefix = 'org.apache.flink.avro.shaded.'
for cls in classes:
java_import(jvm, prefix + cls)
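

# A record schema covering Avro's primitive types; the "null" field exercises null values.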
BASIC_SCHEMA = """
{
"type": "record",
"name": "test",
"fields": [
{ "name": "null", "type": "null" },
{ "name": "boolean", "type": "boolean" },
{ "name": "int", "type": "int" },
{ "name": "long", "type": "long" },
{ "name": "float", "type": "float" },
{ "name": "double", "type": "double" },
{ "name": "string", "type": "string" }
]
}
"""


def _create_basic_avro_schema_and_records() -> Tuple[AvroSchema, List[JavaObject]]:
schema = AvroSchema.parse_string(BASIC_SCHEMA)
records = [_create_basic_avro_record(schema, True, 0, 1, 2, 3, 's1'),
_create_basic_avro_record(schema, False, 4, 5, 6, 7, 's2')]
return schema, records


def _create_basic_avro_schema_and_py_objects() -> Tuple[AvroSchema, List[dict]]:
schema = AvroSchema.parse_string(BASIC_SCHEMA)
objects = [
{'null': None, 'boolean': True, 'int': 0, 'long': 1,
'float': 2., 'double': 3., 'string': 's1'},
{'null': None, 'boolean': False, 'int': 4, 'long': 5,
'float': 6., 'double': 7., 'string': 's2'},
]
return schema, objects


def _check_basic_avro_schema_results(test, results):
result1 = results[0]
result2 = results[1]
test.assertEqual(result1['null'], None)
test.assertEqual(result1['boolean'], True)
test.assertEqual(result1['int'], 0)
test.assertEqual(result1['long'], 1)
test.assertAlmostEqual(result1['float'], 2, delta=1e-3)
test.assertAlmostEqual(result1['double'], 3, delta=1e-3)
test.assertEqual(result1['string'], 's1')
test.assertEqual(result2['null'], None)
test.assertEqual(result2['boolean'], False)
test.assertEqual(result2['int'], 4)
test.assertEqual(result2['long'], 5)
test.assertAlmostEqual(result2['float'], 6, delta=1e-3)
test.assertAlmostEqual(result2['double'], 7, delta=1e-3)
test.assertEqual(result2['string'], 's2')


ENUM_SCHEMA = """
{
"type": "record",
"name": "test",
"fields": [
{
"name": "suit",
"type": {
"type": "enum",
"name": "Suit",
"symbols" : ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"]
}
}
]
}
"""


def _create_enum_avro_schema_and_records() -> Tuple[AvroSchema, List[JavaObject]]:
schema = AvroSchema.parse_string(ENUM_SCHEMA)
records = [_create_enum_avro_record(schema, 'SPADES'),
_create_enum_avro_record(schema, 'DIAMONDS')]
return schema, records


def _create_enum_avro_schema_and_py_objects() -> Tuple[AvroSchema, List[dict]]:
schema = AvroSchema.parse_string(ENUM_SCHEMA)
records = [
{'suit': 'SPADES'},
{'suit': 'DIAMONDS'},
]
return schema, records


def _check_enum_avro_schema_results(test, results):
test.assertEqual(results[0]['suit'], 'SPADES')
test.assertEqual(results[1]['suit'], 'DIAMONDS')


UNION_SCHEMA = """
{
"type": "record",
"name": "test",
"fields": [
{
"name": "union",
"type": [ "int", "double", "null" ]
}
]
}
"""


def _create_union_avro_schema_and_records() -> Tuple[AvroSchema, List[JavaObject]]:
schema = AvroSchema.parse_string(UNION_SCHEMA)
records = [_create_union_avro_record(schema, 1),
_create_union_avro_record(schema, 2.),
_create_union_avro_record(schema, None)]
return schema, records


def _create_union_avro_schema_and_py_objects() -> Tuple[AvroSchema, List[dict]]:
schema = AvroSchema.parse_string(UNION_SCHEMA)
records = [
{'union': 1},
{'union': 2.},
{'union': None},
]
return schema, records


def _check_union_avro_schema_results(test, results):
test.assertEqual(results[0]['union'], 1)
test.assertAlmostEqual(results[1]['union'], 2.0, delta=1e-3)
test.assertEqual(results[2]['union'], None)


# There seems to be a bug when the array item record contains only one field: reading
# such a file throws "java.lang.ClassCastException: required ... is not a group"
ARRAY_SCHEMA = """
{
"type": "record",
"name": "test",
"fields": [
{
"name": "array",
"type": {
"type": "array",
"items": {
"type": "record",
"name": "item",
"fields": [
{ "name": "int", "type": "int" },
{ "name": "double", "type": "double" }
]
}
}
}
]
}
"""


def _create_array_avro_schema_and_records() -> Tuple[AvroSchema, List[JavaObject]]:
schema = AvroSchema.parse_string(ARRAY_SCHEMA)
records = [_create_array_avro_record(schema, [(1, 2.), (3, 4.)]),
_create_array_avro_record(schema, [(5, 6.), (7, 8.)])]
return schema, records


def _create_array_avro_schema_and_py_objects() -> Tuple[AvroSchema, List[dict]]:
schema = AvroSchema.parse_string(ARRAY_SCHEMA)
records = [
{'array': [{'int': 1, 'double': 2.}, {'int': 3, 'double': 4.}]},
{'array': [{'int': 5, 'double': 6.}, {'int': 7, 'double': 8.}]},
]
return schema, records


def _check_array_avro_schema_results(test, results):
result1 = results[0]
result2 = results[1]
test.assertEqual(result1['array'][0]['int'], 1)
test.assertAlmostEqual(result1['array'][0]['double'], 2., delta=1e-3)
test.assertEqual(result1['array'][1]['int'], 3)
test.assertAlmostEqual(result1['array'][1]['double'], 4., delta=1e-3)
test.assertEqual(result2['array'][0]['int'], 5)
test.assertAlmostEqual(result2['array'][0]['double'], 6., delta=1e-3)
test.assertEqual(result2['array'][1]['int'], 7)
test.assertAlmostEqual(result2['array'][1]['double'], 8., delta=1e-3)


MAP_SCHEMA = """
{
"type": "record",
"name": "test",
"fields": [
{
"name": "map",
"type": {
"type": "map",
"values": "long"
}
}
]
}
"""


def _create_map_avro_schema_and_records() -> Tuple[AvroSchema, List[JavaObject]]:
schema = AvroSchema.parse_string(MAP_SCHEMA)
records = [_create_map_avro_record(schema, {'a': 1, 'b': 2}),
_create_map_avro_record(schema, {'c': 3, 'd': 4})]
return schema, records


def _create_map_avro_schema_and_py_objects() -> Tuple[AvroSchema, List[dict]]:
schema = AvroSchema.parse_string(MAP_SCHEMA)
records = [
{'map': {'a': 1, 'b': 2}},
{'map': {'c': 3, 'd': 4}},
]
return schema, records


def _check_map_avro_schema_results(test, results):
result1 = results[0]
result2 = results[1]
test.assertEqual(result1['map']['a'], 1)
test.assertEqual(result1['map']['b'], 2)
test.assertEqual(result2['map']['c'], 3)
test.assertEqual(result2['map']['d'], 4)


def _create_basic_avro_record(schema: AvroSchema, boolean_value, int_value, long_value,
float_value, double_value, string_value):
jvm = get_gateway().jvm
j_record = jvm.GenericData.Record(schema._j_schema)
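    # the "null" field is left unset on purpose; it reads back as None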
j_record.put('boolean', boolean_value)
j_record.put('int', int_value)
j_record.put('long', long_value)
j_record.put('float', float_value)
j_record.put('double', double_value)
j_record.put('string', string_value)
return j_record


def _create_enum_avro_record(schema: AvroSchema, enum_value):
jvm = get_gateway().jvm
j_record = jvm.GenericData.Record(schema._j_schema)
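    # enum values must be wrapped in a GenericData.EnumSymbol built from the field's enum schema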
j_enum = jvm.GenericData.EnumSymbol(schema._j_schema.getField('suit').schema(), enum_value)
j_record.put('suit', j_enum)
return j_record


def _create_union_avro_record(schema, union_value):
jvm = get_gateway().jvm
j_record = jvm.GenericData.Record(schema._j_schema)
j_record.put('union', union_value)
return j_record


def _create_array_avro_record(schema, item_values: list):
jvm = get_gateway().jvm
j_record = jvm.GenericData.Record(schema._j_schema)
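    # look up the array's element schema to build each item record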
item_schema = AvroSchema(schema._j_schema.getField('array').schema().getElementType())
j_array = jvm.java.util.ArrayList()
for idx, item_value in enumerate(item_values):
j_item = jvm.GenericData.Record(item_schema._j_schema)
j_item.put('int', item_value[0])
j_item.put('double', item_value[1])
j_array.add(j_item)
j_record.put('array', j_array)
return j_record


def _create_map_avro_record(schema, mapping: dict):
    jvm = get_gateway().jvm
    j_record = jvm.GenericData.Record(schema._j_schema)
    j_map = jvm.java.util.HashMap()
    for k, v in mapping.items():
j_map.put(k, v)
j_record.put('map', j_map)
return j_record