#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import absolute_import
from __future__ import division
import json
import logging
import math
import os
import tempfile
import unittest
from builtins import range
import sys
import hamcrest as hc
import avro
import avro.datafile
from avro.datafile import DataFileWriter
from avro.io import DatumWriter
from fastavro.schema import parse_schema
from fastavro import writer
# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
try:
from avro.schema import Parse # avro-python3 library for python3
except ImportError:
from avro.schema import parse as Parse # avro library for python2
# pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports
import apache_beam as beam
from apache_beam import Create
from apache_beam.io import avroio
from apache_beam.io import filebasedsource
from apache_beam.io import iobase
from apache_beam.io import source_test_utils
from apache_beam.io.avroio import _create_avro_sink # For testing
from apache_beam.io.avroio import _create_avro_source # For testing
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms.display import DisplayData
from apache_beam.transforms.display_test import DisplayDataItemMatcher
# Import snappy optionally; some tests will be skipped when import fails.
try:
import snappy # pylint: disable=import-error
except ImportError:
snappy = None # pylint: disable=invalid-name
logging.warning('python-snappy is not installed; some tests will be skipped.')
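# Sample records matching the "User" schema defined in AvroBase.SCHEMA_STRING.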
RECORDS = [
{'name': 'Thomas',
'favorite_number': 1,
'favorite_color': 'blue'},
{'name': 'Henry',
'favorite_number': 3,
'favorite_color': 'green'},
{'name': 'Toby',
'favorite_number': 7,
'favorite_color': 'brown'},
{'name': 'Gordon',
'favorite_number': 4,
'favorite_color': 'blue'},
{'name': 'Emily',
'favorite_number': -1,
'favorite_color': 'Red'},
{'name': 'Percy',
'favorite_number': 6,
'favorite_color': 'Green'}
]
class AvroBase(object):
_temp_files = []
def __init__(self, methodName='runTest'):
super(AvroBase, self).__init__(methodName)
self.RECORDS = RECORDS
self.SCHEMA_STRING = '''
{"namespace": "example.avro",
"type": "record",
"name": "User",
"fields": [
{"name": "name", "type": "string"},
{"name": "favorite_number", "type": ["int", "null"]},
{"name": "favorite_color", "type": ["string", "null"]}
]
}
'''
@classmethod
def setUpClass(cls):
    # assertItemsEqual was renamed to assertCountEqual in Python 3.
if sys.version_info[0] < 3:
cls.assertCountEqual = cls.assertItemsEqual
def setUp(self):
    # Reduce the size of thread pools. Without this, test execution may fail in
    # environments with a limited amount of resources.
filebasedsource.MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 2
def tearDown(self):
for path in self._temp_files:
if os.path.exists(path):
os.remove(path)
self._temp_files = []
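  # Overridden by subclasses to write `count` records to a temporary Avro file
  # (using the given codec and sync interval) and return the file name.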
def _write_data(self, directory, prefix, codec, count, sync_interval):
raise NotImplementedError
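  # Writes `num_files` Avro files into a fresh temporary directory and returns
  # a glob pattern that matches all of them.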
def _write_pattern(self, num_files):
assert num_files > 0
temp_dir = tempfile.mkdtemp()
file_name = None
for _ in range(num_files):
file_name = self._write_data(directory=temp_dir, prefix='mytemp')
assert file_name
file_name_prefix = file_name[:file_name.rfind(os.path.sep)]
return file_name_prefix + os.path.sep + 'mytemp*'
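  # Reads `pattern` with an Avro source. Without splitting, the records read
  # are compared against `expected_result`; with splitting, the generated
  # splits are verified against the unsplit source as a reference.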
def _run_avro_test(self, pattern, desired_bundle_size, perform_splitting,
expected_result):
source = _create_avro_source(pattern, use_fastavro=self.use_fastavro)
if perform_splitting:
assert desired_bundle_size
splits = [
split
for split in source.split(desired_bundle_size=desired_bundle_size)
]
if len(splits) < 2:
raise ValueError('Test is trivial. Please adjust it so that at least '
'two splits get generated')
sources_info = [
(split.source, split.start_position, split.stop_position)
for split in splits
]
source_test_utils.assert_sources_equal_reference_source(
(source, None, None), sources_info)
else:
read_records = source_test_utils.read_from_source(source, None, None)
self.assertCountEqual(expected_result, read_records)
def test_read_without_splitting(self):
file_name = self._write_data()
expected_result = self.RECORDS
self._run_avro_test(file_name, None, False, expected_result)
def test_read_with_splitting(self):
file_name = self._write_data()
expected_result = self.RECORDS
self._run_avro_test(file_name, 100, True, expected_result)
def test_source_display_data(self):
file_name = 'some_avro_source'
source = \
_create_avro_source(
file_name,
validate=False,
use_fastavro=self.use_fastavro
)
dd = DisplayData.create_from(source)
# No extra avro parameters for AvroSource.
expected_items = [
DisplayDataItemMatcher('compression', 'auto'),
DisplayDataItemMatcher('file_pattern', file_name)]
hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
def test_read_display_data(self):
file_name = 'some_avro_source'
read = \
avroio.ReadFromAvro(
file_name,
validate=False,
use_fastavro=self.use_fastavro)
dd = DisplayData.create_from(read)
# No extra avro parameters for AvroSource.
expected_items = [
DisplayDataItemMatcher('compression', 'auto'),
DisplayDataItemMatcher('file_pattern', file_name)]
hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
def test_sink_display_data(self):
file_name = 'some_avro_sink'
sink = _create_avro_sink(
file_name,
self.SCHEMA,
'null',
'.end',
0,
None,
'application/x-avro',
use_fastavro=self.use_fastavro)
dd = DisplayData.create_from(sink)
expected_items = [
DisplayDataItemMatcher(
'schema',
str(self.SCHEMA)),
DisplayDataItemMatcher(
'file_pattern',
'some_avro_sink-%(shard_num)05d-of-%(num_shards)05d.end'),
DisplayDataItemMatcher(
'codec',
'null'),
DisplayDataItemMatcher(
'compression',
'uncompressed')]
hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
def test_write_display_data(self):
file_name = 'some_avro_sink'
write = avroio.WriteToAvro(file_name,
self.SCHEMA,
use_fastavro=self.use_fastavro)
dd = DisplayData.create_from(write)
expected_items = [
DisplayDataItemMatcher(
'schema',
str(self.SCHEMA)),
DisplayDataItemMatcher(
'file_pattern',
'some_avro_sink-%(shard_num)05d-of-%(num_shards)05d'),
DisplayDataItemMatcher(
'codec',
'deflate'),
DisplayDataItemMatcher(
'compression',
'uncompressed')]
hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
def test_read_reentrant_without_splitting(self):
file_name = self._write_data()
source = _create_avro_source(file_name, use_fastavro=self.use_fastavro)
source_test_utils.assert_reentrant_reads_succeed((source, None, None))
  def test_read_reentrant_with_splitting(self):
file_name = self._write_data()
source = _create_avro_source(file_name, use_fastavro=self.use_fastavro)
splits = [
split for split in source.split(desired_bundle_size=100000)]
assert len(splits) == 1
source_test_utils.assert_reentrant_reads_succeed(
(splits[0].source, splits[0].start_position, splits[0].stop_position))
def test_read_without_splitting_multiple_blocks(self):
file_name = self._write_data(count=12000)
expected_result = self.RECORDS * 2000
self._run_avro_test(file_name, None, False, expected_result)
def test_read_with_splitting_multiple_blocks(self):
file_name = self._write_data(count=12000)
expected_result = self.RECORDS * 2000
self._run_avro_test(file_name, 10000, True, expected_result)
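  # Verifies the split points reported by the range tracker while reading a
  # file that contains multiple Avro blocks.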
def test_split_points(self):
num_records = 12000
sync_interval = 16000
file_name = self._write_data(count=num_records,
sync_interval=sync_interval)
source = _create_avro_source(file_name, use_fastavro=self.use_fastavro)
splits = [
split
for split in source.split(desired_bundle_size=float('inf'))
]
assert len(splits) == 1
range_tracker = splits[0].source.get_range_tracker(
splits[0].start_position, splits[0].stop_position)
split_points_report = []
for _ in splits[0].source.read(range_tracker):
split_points_report.append(range_tracker.split_points())
    # There will be a total of num_blocks blocks in the generated test file,
    # proportional to the number of records in the file divided by the
    # synchronization interval used by avro during the write. Each block has
    # more than 10 records.
num_blocks = int(math.ceil(14.5 * num_records /
sync_interval))
assert num_blocks > 1
# When reading records of the first block, range_tracker.split_points()
# should return (0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)
self.assertEqual(
split_points_report[:10],
[(0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)] * 10)
# When reading records of last block, range_tracker.split_points() should
# return (num_blocks - 1, 1)
self.assertEqual(split_points_report[-10:], [(num_blocks - 1, 1)] * 10)
def test_read_without_splitting_compressed_deflate(self):
file_name = self._write_data(codec='deflate')
expected_result = self.RECORDS
self._run_avro_test(file_name, None, False, expected_result)
def test_read_with_splitting_compressed_deflate(self):
file_name = self._write_data(codec='deflate')
expected_result = self.RECORDS
self._run_avro_test(file_name, 100, True, expected_result)
@unittest.skipIf(snappy is None, 'python-snappy not installed.')
def test_read_without_splitting_compressed_snappy(self):
file_name = self._write_data(codec='snappy')
expected_result = self.RECORDS
self._run_avro_test(file_name, None, False, expected_result)
@unittest.skipIf(snappy is None, 'python-snappy not installed.')
def test_read_with_splitting_compressed_snappy(self):
file_name = self._write_data(codec='snappy')
expected_result = self.RECORDS
self._run_avro_test(file_name, 100, True, expected_result)
def test_read_without_splitting_pattern(self):
pattern = self._write_pattern(3)
expected_result = self.RECORDS * 3
self._run_avro_test(pattern, None, False, expected_result)
def test_read_with_splitting_pattern(self):
pattern = self._write_pattern(3)
expected_result = self.RECORDS * 3
self._run_avro_test(pattern, 100, True, expected_result)
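  # Exhaustively verifies dynamic work rebalancing by attempting to split at
  # every fraction while reading a small multi-block file.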
def test_dynamic_work_rebalancing_exhaustive(self):
def compare_split_points(file_name):
source = _create_avro_source(file_name,
use_fastavro=self.use_fastavro)
splits = [split
for split in source.split(desired_bundle_size=float('inf'))]
assert len(splits) == 1
source_test_utils.assert_split_at_fraction_exhaustive(splits[0].source)
    # Adjust the block size so that we can perform an exhaustive dynamic
    # work rebalancing test that completes within an acceptable amount of time.
file_name = self._write_data(count=5, sync_interval=2)
compare_split_points(file_name)
def test_corrupted_file(self):
file_name = self._write_data()
with open(file_name, 'rb') as f:
data = f.read()
    # Corrupt the last character of the file, which is also the last character
    # of the last sync_marker.
# https://avro.apache.org/docs/current/spec.html#Object+Container+Files
corrupted_data = bytearray(data)
corrupted_data[-1] = (corrupted_data[-1] + 1) % 256
with tempfile.NamedTemporaryFile(
delete=False, prefix=tempfile.template) as f:
f.write(corrupted_data)
corrupted_file_name = f.name
source = _create_avro_source(
corrupted_file_name, use_fastavro=self.use_fastavro)
with self.assertRaisesRegexp(ValueError, r'expected sync marker'):
source_test_utils.read_from_source(source, None, None)
def test_read_from_avro(self):
path = self._write_data()
with TestPipeline() as p:
assert_that(
p | avroio.ReadFromAvro(path, use_fastavro=self.use_fastavro),
equal_to(self.RECORDS))
def test_read_all_from_avro_single_file(self):
path = self._write_data()
with TestPipeline() as p:
assert_that(
p \
| Create([path]) \
| avroio.ReadAllFromAvro(use_fastavro=self.use_fastavro),
equal_to(self.RECORDS))
def test_read_all_from_avro_many_single_files(self):
path1 = self._write_data()
path2 = self._write_data()
path3 = self._write_data()
with TestPipeline() as p:
assert_that(
p \
| Create([path1, path2, path3]) \
| avroio.ReadAllFromAvro(use_fastavro=self.use_fastavro),
equal_to(self.RECORDS * 3))
def test_read_all_from_avro_file_pattern(self):
file_pattern = self._write_pattern(5)
with TestPipeline() as p:
assert_that(
p \
| Create([file_pattern]) \
| avroio.ReadAllFromAvro(use_fastavro=self.use_fastavro),
equal_to(self.RECORDS * 5))
def test_read_all_from_avro_many_file_patterns(self):
file_pattern1 = self._write_pattern(5)
file_pattern2 = self._write_pattern(2)
file_pattern3 = self._write_pattern(3)
with TestPipeline() as p:
assert_that(
p \
| Create([file_pattern1, file_pattern2, file_pattern3]) \
| avroio.ReadAllFromAvro(use_fastavro=self.use_fastavro),
equal_to(self.RECORDS * 10))
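  # Round-trips RECORDS through WriteToAvro and ReadFromAvro and verifies the
  # records read back match what was written.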
def test_sink_transform(self):
with tempfile.NamedTemporaryFile() as dst:
path = dst.name
with TestPipeline() as p:
# pylint: disable=expression-not-assigned
p \
| beam.Create(self.RECORDS) \
| avroio.WriteToAvro(path, self.SCHEMA, use_fastavro=self.use_fastavro)
with TestPipeline() as p:
# json used for stable sortability
readback = \
p \
| avroio.ReadFromAvro(path + '*', use_fastavro=self.use_fastavro) \
| beam.Map(json.dumps)
assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS]))
@unittest.skipIf(snappy is None, 'python-snappy not installed.')
def test_sink_transform_snappy(self):
with tempfile.NamedTemporaryFile() as dst:
path = dst.name
with TestPipeline() as p:
# pylint: disable=expression-not-assigned
p \
| beam.Create(self.RECORDS) \
| avroio.WriteToAvro(
path,
self.SCHEMA,
codec='snappy',
use_fastavro=self.use_fastavro)
with TestPipeline() as p:
# json used for stable sortability
readback = \
p \
| avroio.ReadFromAvro(path + '*', use_fastavro=self.use_fastavro) \
| beam.Map(json.dumps)
assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS]))
@unittest.skipIf(sys.version_info[0] == 3 and
os.environ.get('RUN_SKIPPED_PY3_TESTS') != '1',
'This test still needs to be fixed on Python 3. '
'TODO: BEAM-6522.')
class TestAvro(AvroBase, unittest.TestCase):
def __init__(self, methodName='runTest'):
super(TestAvro, self).__init__(methodName)
self.use_fastavro = False
self.SCHEMA = Parse(self.SCHEMA_STRING)
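  # Writes `count` records with avro's DataFileWriter, temporarily patching
  # avro.datafile.SYNC_INTERVAL so tests can control where block boundaries
  # fall.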
def _write_data(self,
directory=None,
prefix=tempfile.template,
codec='null',
count=len(RECORDS),
sync_interval=avro.datafile.SYNC_INTERVAL):
old_sync_interval = avro.datafile.SYNC_INTERVAL
try:
avro.datafile.SYNC_INTERVAL = sync_interval
with tempfile.NamedTemporaryFile(delete=False,
dir=directory,
prefix=prefix) as f:
writer = DataFileWriter(f, DatumWriter(), self.SCHEMA, codec=codec)
len_records = len(self.RECORDS)
for i in range(count):
writer.append(self.RECORDS[i % len_records])
writer.close()
self._temp_files.append(f.name)
return f.name
finally:
avro.datafile.SYNC_INTERVAL = old_sync_interval
class TestFastAvro(AvroBase, unittest.TestCase):
def __init__(self, methodName='runTest'):
super(TestFastAvro, self).__init__(methodName)
self.use_fastavro = True
self.SCHEMA = parse_schema(json.loads(self.SCHEMA_STRING))
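  # Writes `count` records with fastavro's writer; extra keyword arguments
  # (e.g. sync_interval) are passed through to fastavro.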
def _write_data(self,
directory=None,
prefix=tempfile.template,
codec='null',
count=len(RECORDS),
**kwargs):
all_records = self.RECORDS * \
(count // len(self.RECORDS)) + self.RECORDS[:(count % len(self.RECORDS))]
with tempfile.NamedTemporaryFile(delete=False,
dir=directory,
prefix=prefix,
mode='w+b') as f:
writer(f, self.SCHEMA, all_records,
codec=codec, **kwargs)
self._temp_files.append(f.name)
return f.name
if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()