blob: 93f2ba9ebfd390280bb18f270cf2873d7ca72396 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import absolute_import
import json
import logging
import os
import tempfile
import unittest
from builtins import range
import avro.datafile
import avro.schema
from avro.datafile import DataFileWriter
from avro.io import DatumWriter
import hamcrest as hc
import apache_beam as beam
from apache_beam import Create
from apache_beam.io import avroio
from apache_beam.io import filebasedsource
from apache_beam.io import iobase
from apache_beam.io import source_test_utils
from apache_beam.io.avroio import _create_avro_sink # For testing
from apache_beam.io.avroio import _create_avro_source # For testing
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms.display import DisplayData
from apache_beam.transforms.display_test import DisplayDataItemMatcher
# Import snappy optionally; some tests will be skipped when import fails.
try:
import snappy # pylint: disable=import-error
except ImportError:
snappy = None # pylint: disable=invalid-name
logging.warning('snappy is not installed; some tests will be skipped.')
class TestAvro(unittest.TestCase):
_temp_files = []
def __init__(self, methodName='runTest'):
super(TestAvro, self).__init__(methodName)
self.use_fastavro = False
def setUp(self):
# Reducing the size of thread pools. Without this test execution may fail in
# environments with limited amount of resources.
filebasedsource.MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 2
def tearDown(self):
for path in self._temp_files:
if os.path.exists(path):
os.remove(path)
self._temp_files = []
RECORDS = [{'name': 'Thomas',
'favorite_number': 1,
'favorite_color': 'blue'}, {'name': 'Henry',
'favorite_number': 3,
'favorite_color': 'green'},
{'name': 'Toby',
'favorite_number': 7,
'favorite_color': 'brown'}, {'name': 'Gordon',
'favorite_number': 4,
'favorite_color': 'blue'},
{'name': 'Emily',
'favorite_number': -1,
'favorite_color': 'Red'}, {'name': 'Percy',
'favorite_number': 6,
'favorite_color': 'Green'}]
SCHEMA = avro.schema.parse('''
{"namespace": "example.avro",
"type": "record",
"name": "User",
"fields": [
{"name": "name", "type": "string"},
{"name": "favorite_number", "type": ["int", "null"]},
{"name": "favorite_color", "type": ["string", "null"]}
]
}
''')
def _write_data(self,
directory=None,
prefix=tempfile.template,
codec='null',
count=len(RECORDS)):
with tempfile.NamedTemporaryFile(
delete=False, dir=directory, prefix=prefix) as f:
writer = DataFileWriter(f, DatumWriter(), self.SCHEMA, codec=codec)
len_records = len(self.RECORDS)
for i in range(count):
writer.append(self.RECORDS[i % len_records])
writer.close()
self._temp_files.append(f.name)
return f.name
def _write_pattern(self, num_files):
assert num_files > 0
temp_dir = tempfile.mkdtemp()
file_name = None
for _ in range(num_files):
file_name = self._write_data(directory=temp_dir, prefix='mytemp')
assert file_name
file_name_prefix = file_name[:file_name.rfind(os.path.sep)]
return file_name_prefix + os.path.sep + 'mytemp*'
def _run_avro_test(self, pattern, desired_bundle_size, perform_splitting,
expected_result):
source = _create_avro_source(pattern, use_fastavro=self.use_fastavro)
read_records = []
if perform_splitting:
assert desired_bundle_size
splits = [
split
for split in source.split(desired_bundle_size=desired_bundle_size)
]
if len(splits) < 2:
raise ValueError('Test is trivial. Please adjust it so that at least '
'two splits get generated')
sources_info = [
(split.source, split.start_position, split.stop_position)
for split in splits
]
source_test_utils.assert_sources_equal_reference_source(
(source, None, None), sources_info)
else:
read_records = source_test_utils.read_from_source(source, None, None)
self.assertItemsEqual(expected_result, read_records)
def test_read_without_splitting(self):
file_name = self._write_data()
expected_result = self.RECORDS
self._run_avro_test(file_name, None, False, expected_result)
def test_read_with_splitting(self):
file_name = self._write_data()
expected_result = self.RECORDS
self._run_avro_test(file_name, 100, True, expected_result)
def test_source_display_data(self):
file_name = 'some_avro_source'
source = \
_create_avro_source(
file_name,
validate=False,
use_fastavro=self.use_fastavro
)
dd = DisplayData.create_from(source)
# No extra avro parameters for AvroSource.
expected_items = [
DisplayDataItemMatcher('compression', 'auto'),
DisplayDataItemMatcher('file_pattern', file_name)]
hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
def test_read_display_data(self):
file_name = 'some_avro_source'
read = \
avroio.ReadFromAvro(
file_name,
validate=False,
use_fastavro=self.use_fastavro)
dd = DisplayData.create_from(read)
# No extra avro parameters for AvroSource.
expected_items = [
DisplayDataItemMatcher('compression', 'auto'),
DisplayDataItemMatcher('file_pattern', file_name)]
hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
def test_sink_display_data(self):
file_name = 'some_avro_sink'
sink = _create_avro_sink(
file_name,
self.SCHEMA,
'null',
'.end',
0,
None,
'application/x-avro',
use_fastavro=self.use_fastavro)
dd = DisplayData.create_from(sink)
expected_items = [
DisplayDataItemMatcher(
'schema',
str(self.SCHEMA)),
DisplayDataItemMatcher(
'file_pattern',
'some_avro_sink-%(shard_num)05d-of-%(num_shards)05d.end'),
DisplayDataItemMatcher(
'codec',
'null'),
DisplayDataItemMatcher(
'compression',
'uncompressed')]
hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
def test_write_display_data(self):
file_name = 'some_avro_sink'
write = avroio.WriteToAvro(file_name,
self.SCHEMA,
use_fastavro=self.use_fastavro)
dd = DisplayData.create_from(write)
expected_items = [
DisplayDataItemMatcher(
'schema',
str(self.SCHEMA)),
DisplayDataItemMatcher(
'file_pattern',
'some_avro_sink-%(shard_num)05d-of-%(num_shards)05d'),
DisplayDataItemMatcher(
'codec',
'deflate'),
DisplayDataItemMatcher(
'compression',
'uncompressed')]
hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
def test_read_reentrant_without_splitting(self):
file_name = self._write_data()
source = _create_avro_source(file_name, use_fastavro=self.use_fastavro)
source_test_utils.assert_reentrant_reads_succeed((source, None, None))
def test_read_reantrant_with_splitting(self):
file_name = self._write_data()
source = _create_avro_source(file_name, use_fastavro=self.use_fastavro)
splits = [
split for split in source.split(desired_bundle_size=100000)]
assert len(splits) == 1
source_test_utils.assert_reentrant_reads_succeed(
(splits[0].source, splits[0].start_position, splits[0].stop_position))
def test_read_without_splitting_multiple_blocks(self):
file_name = self._write_data(count=12000)
expected_result = self.RECORDS * 2000
self._run_avro_test(file_name, None, False, expected_result)
def test_read_with_splitting_multiple_blocks(self):
file_name = self._write_data(count=12000)
expected_result = self.RECORDS * 2000
self._run_avro_test(file_name, 10000, True, expected_result)
def test_split_points(self):
file_name = self._write_data(count=12000)
source = _create_avro_source(file_name, use_fastavro=self.use_fastavro)
splits = [
split
for split in source.split(desired_bundle_size=float('inf'))
]
assert len(splits) == 1
range_tracker = splits[0].source.get_range_tracker(
splits[0].start_position, splits[0].stop_position)
split_points_report = []
for _ in splits[0].source.read(range_tracker):
split_points_report.append(range_tracker.split_points())
# There are a total of three blocks. Each block has more than 10 records.
# When reading records of the first block, range_tracker.split_points()
# should return (0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)
self.assertEquals(
split_points_report[:10],
[(0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)] * 10)
# When reading records of last block, range_tracker.split_points() should
# return (2, 1)
self.assertEquals(split_points_report[-10:], [(2, 1)] * 10)
def test_read_without_splitting_compressed_deflate(self):
file_name = self._write_data(codec='deflate')
expected_result = self.RECORDS
self._run_avro_test(file_name, None, False, expected_result)
def test_read_with_splitting_compressed_deflate(self):
file_name = self._write_data(codec='deflate')
expected_result = self.RECORDS
self._run_avro_test(file_name, 100, True, expected_result)
@unittest.skipIf(snappy is None, 'snappy not installed.')
def test_read_without_splitting_compressed_snappy(self):
file_name = self._write_data(codec='snappy')
expected_result = self.RECORDS
self._run_avro_test(file_name, None, False, expected_result)
@unittest.skipIf(snappy is None, 'snappy not installed.')
def test_read_with_splitting_compressed_snappy(self):
file_name = self._write_data(codec='snappy')
expected_result = self.RECORDS
self._run_avro_test(file_name, 100, True, expected_result)
def test_read_without_splitting_pattern(self):
pattern = self._write_pattern(3)
expected_result = self.RECORDS * 3
self._run_avro_test(pattern, None, False, expected_result)
def test_read_with_splitting_pattern(self):
pattern = self._write_pattern(3)
expected_result = self.RECORDS * 3
self._run_avro_test(pattern, 100, True, expected_result)
def test_dynamic_work_rebalancing_exhaustive(self):
# Adjusting block size so that we can perform a exhaustive dynamic
# work rebalancing test that completes within an acceptable amount of time.
old_sync_interval = avro.datafile.SYNC_INTERVAL
try:
avro.datafile.SYNC_INTERVAL = 2
file_name = self._write_data(count=5)
source = _create_avro_source(file_name, use_fastavro=self.use_fastavro)
splits = [split
for split in source.split(desired_bundle_size=float('inf'))]
assert len(splits) == 1
source_test_utils.assert_split_at_fraction_exhaustive(splits[0].source)
finally:
avro.datafile.SYNC_INTERVAL = old_sync_interval
def test_corrupted_file(self):
file_name = self._write_data()
with open(file_name, 'rb') as f:
data = f.read()
# Corrupt the last character of the file which is also the last character of
# the last sync_marker.
last_char_index = len(data) - 1
corrupted_data = data[:last_char_index]
corrupted_data += 'A' if data[last_char_index] == 'B' else 'B'
with tempfile.NamedTemporaryFile(
delete=False, prefix=tempfile.template) as f:
f.write(corrupted_data)
corrupted_file_name = f.name
source = _create_avro_source(
corrupted_file_name, use_fastavro=self.use_fastavro)
with self.assertRaises(ValueError) as exn:
source_test_utils.read_from_source(source, None, None)
self.assertEqual(0, exn.exception.message.find('Unexpected sync marker'))
def test_read_from_avro(self):
path = self._write_data()
with TestPipeline() as p:
assert_that(
p | avroio.ReadFromAvro(path, use_fastavro=self.use_fastavro),
equal_to(self.RECORDS))
def test_read_all_from_avro_single_file(self):
path = self._write_data()
with TestPipeline() as p:
assert_that(
p \
| Create([path]) \
| avroio.ReadAllFromAvro(use_fastavro=self.use_fastavro),
equal_to(self.RECORDS))
def test_read_all_from_avro_many_single_files(self):
path1 = self._write_data()
path2 = self._write_data()
path3 = self._write_data()
with TestPipeline() as p:
assert_that(
p \
| Create([path1, path2, path3]) \
| avroio.ReadAllFromAvro(use_fastavro=self.use_fastavro),
equal_to(self.RECORDS * 3))
def test_read_all_from_avro_file_pattern(self):
file_pattern = self._write_pattern(5)
with TestPipeline() as p:
assert_that(
p \
| Create([file_pattern]) \
| avroio.ReadAllFromAvro(use_fastavro=self.use_fastavro),
equal_to(self.RECORDS * 5))
def test_read_all_from_avro_many_file_patterns(self):
file_pattern1 = self._write_pattern(5)
file_pattern2 = self._write_pattern(2)
file_pattern3 = self._write_pattern(3)
with TestPipeline() as p:
assert_that(
p \
| Create([file_pattern1, file_pattern2, file_pattern3]) \
| avroio.ReadAllFromAvro(use_fastavro=self.use_fastavro),
equal_to(self.RECORDS * 10))
def test_sink_transform(self):
with tempfile.NamedTemporaryFile() as dst:
path = dst.name
with TestPipeline() as p:
# pylint: disable=expression-not-assigned
p \
| beam.Create(self.RECORDS) \
| avroio.WriteToAvro(path, self.SCHEMA, use_fastavro=self.use_fastavro)
with TestPipeline() as p:
# json used for stable sortability
readback = \
p \
| avroio.ReadFromAvro(path + '*', use_fastavro=self.use_fastavro) \
| beam.Map(json.dumps)
assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS]))
@unittest.skipIf(snappy is None, 'snappy not installed.')
def test_sink_transform_snappy(self):
with tempfile.NamedTemporaryFile() as dst:
path = dst.name
with TestPipeline() as p:
# pylint: disable=expression-not-assigned
p \
| beam.Create(self.RECORDS) \
| avroio.WriteToAvro(
path,
self.SCHEMA,
codec='snappy',
use_fastavro=self.use_fastavro)
with TestPipeline() as p:
# json used for stable sortability
readback = \
p \
| avroio.ReadFromAvro(path + '*', use_fastavro=self.use_fastavro) \
| beam.Map(json.dumps)
assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS]))
class TestFastAvro(TestAvro):
def __init__(self, methodName='runTest'):
super(TestFastAvro, self).__init__(methodName)
self.use_fastavro = True
if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()