| # | 
 | # Licensed to the Apache Software Foundation (ASF) under one or more | 
 | # contributor license agreements.  See the NOTICE file distributed with | 
 | # this work for additional information regarding copyright ownership. | 
 | # The ASF licenses this file to You under the Apache License, Version 2.0 | 
 | # (the "License"); you may not use this file except in compliance with | 
 | # the License.  You may obtain a copy of the License at | 
 | # | 
 | #    http://www.apache.org/licenses/LICENSE-2.0 | 
 | # | 
 | # Unless required by applicable law or agreed to in writing, software | 
 | # distributed under the License is distributed on an "AS IS" BASIS, | 
 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
 | # See the License for the specific language governing permissions and | 
 | # limitations under the License. | 
 | # | 
 | # pytype: skip-file | 
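"""Tests for the avroio module covering both avro and fastavro paths."""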
 |  | 
 | from __future__ import absolute_import | 
 | from __future__ import division | 
 |  | 
import json
import logging
import math
import os
import sys
import tempfile
import unittest
from builtins import range
from typing import List
 |  | 
# Patches unittest.TestCase to be Python 3 compatible.
 | import future.tests.base  # pylint: disable=unused-import | 
 | import hamcrest as hc | 
 |  | 
 | import avro | 
 | import avro.datafile | 
 | from avro.datafile import DataFileWriter | 
 | from avro.io import DatumWriter | 
 | from fastavro.schema import parse_schema | 
 | from fastavro import writer | 
 |  | 
 | # pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports | 
 | try: | 
 |   from avro.schema import Parse  # avro-python3 library for python3 | 
 | except ImportError: | 
 |   from avro.schema import parse as Parse  # avro library for python2 | 
 | # pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports | 
 |  | 
 | import apache_beam as beam | 
 | from apache_beam import Create | 
 | from apache_beam.io import avroio | 
 | from apache_beam.io import filebasedsource | 
 | from apache_beam.io import iobase | 
 | from apache_beam.io import source_test_utils | 
 | from apache_beam.io.avroio import _create_avro_sink  # For testing | 
 | from apache_beam.io.avroio import _create_avro_source  # For testing | 
 | from apache_beam.testing.test_pipeline import TestPipeline | 
 | from apache_beam.testing.util import assert_that | 
 | from apache_beam.testing.util import equal_to | 
 | from apache_beam.transforms.display import DisplayData | 
 | from apache_beam.transforms.display_test import DisplayDataItemMatcher | 
 |  | 
 | # Import snappy optionally; some tests will be skipped when import fails. | 
 | try: | 
 |   import snappy  # pylint: disable=import-error | 
 | except ImportError: | 
 |   snappy = None  # pylint: disable=invalid-name | 
  logging.warning(
      'python-snappy is not installed; some tests will be skipped.')
 |  | 
 | RECORDS = [{ | 
 |     'name': 'Thomas', 'favorite_number': 1, 'favorite_color': 'blue' | 
 | }, { | 
 |     'name': 'Henry', 'favorite_number': 3, 'favorite_color': 'green' | 
 | }, { | 
 |     'name': 'Toby', 'favorite_number': 7, 'favorite_color': 'brown' | 
 | }, { | 
 |     'name': 'Gordon', 'favorite_number': 4, 'favorite_color': 'blue' | 
 | }, { | 
 |     'name': 'Emily', 'favorite_number': -1, 'favorite_color': 'Red' | 
 | }, { | 
 |     'name': 'Percy', 'favorite_number': 6, 'favorite_color': 'Green' | 
 | }] | 
 |  | 
 |  | 
 | class AvroBase(object): | 
 |  | 
 |   _temp_files = []  # type: List[str] | 
 |  | 
 |   def __init__(self, methodName='runTest'): | 
 |     super(AvroBase, self).__init__(methodName) | 
 |     self.RECORDS = RECORDS | 
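    # Schema matching RECORDS above: 'name' is a required string, while
    # 'favorite_number' and 'favorite_color' are nullable unions.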
 |     self.SCHEMA_STRING = ''' | 
 |           {"namespace": "example.avro", | 
 |            "type": "record", | 
 |            "name": "User", | 
 |            "fields": [ | 
 |                {"name": "name", "type": "string"}, | 
 |                {"name": "favorite_number",  "type": ["int", "null"]}, | 
 |                {"name": "favorite_color", "type": ["string", "null"]} | 
 |            ] | 
 |           } | 
 |           ''' | 
 |  | 
 |   @classmethod | 
 |   def setUpClass(cls): | 
    # assertItemsEqual was renamed to assertCountEqual in Python 3; alias it
    # here so tests can call assertCountEqual on both versions.
 |     if sys.version_info[0] < 3: | 
 |       cls.assertCountEqual = cls.assertItemsEqual | 
 |  | 
 |   def setUp(self): | 
    # Reduce the size of thread pools. Without this, test execution may fail
    # in environments with a limited amount of resources.
 |     filebasedsource.MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 2 | 
 |  | 
 |   def tearDown(self): | 
 |     for path in self._temp_files: | 
 |       if os.path.exists(path): | 
 |         os.remove(path) | 
 |     self._temp_files = [] | 
 |  | 
 |   def _write_data(self, directory, prefix, codec, count, sync_interval): | 
 |     raise NotImplementedError | 
 |  | 
 |   def _write_pattern(self, num_files): | 
 |     assert num_files > 0 | 
 |     temp_dir = tempfile.mkdtemp() | 
 |  | 
 |     file_name = None | 
 |     for _ in range(num_files): | 
 |       file_name = self._write_data(directory=temp_dir, prefix='mytemp') | 
 |  | 
 |     assert file_name | 
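    # Build a glob pattern that matches every file written above, e.g.
    # '<temp_dir>/mytemp*'.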
 |     file_name_prefix = file_name[:file_name.rfind(os.path.sep)] | 
 |     return file_name_prefix + os.path.sep + 'mytemp*' | 
 |  | 
 |   def _run_avro_test( | 
 |       self, pattern, desired_bundle_size, perform_splitting, expected_result): | 
 |     source = _create_avro_source(pattern, use_fastavro=self.use_fastavro) | 
 |  | 
 |     if perform_splitting: | 
 |       assert desired_bundle_size | 
 |       splits = [ | 
 |           split | 
 |           for split in source.split(desired_bundle_size=desired_bundle_size) | 
 |       ] | 
 |       if len(splits) < 2: | 
 |         raise ValueError( | 
 |             'Test is trivial. Please adjust it so that at least ' | 
 |             'two splits get generated') | 
 |  | 
 |       sources_info = [(split.source, split.start_position, split.stop_position) | 
 |                       for split in splits] | 
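      # Reading the union of the splits must yield exactly the same records
      # as reading the reference (unsplit) source end-to-end.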
 |       source_test_utils.assert_sources_equal_reference_source( | 
 |           (source, None, None), sources_info) | 
 |     else: | 
 |       read_records = source_test_utils.read_from_source(source, None, None) | 
 |       self.assertCountEqual(expected_result, read_records) | 
 |  | 
 |   def test_read_without_splitting(self): | 
 |     file_name = self._write_data() | 
 |     expected_result = self.RECORDS | 
 |     self._run_avro_test(file_name, None, False, expected_result) | 
 |  | 
 |   def test_read_with_splitting(self): | 
 |     file_name = self._write_data() | 
 |     expected_result = self.RECORDS | 
 |     self._run_avro_test(file_name, 100, True, expected_result) | 
 |  | 
 |   def test_source_display_data(self): | 
 |     file_name = 'some_avro_source' | 
 |     source = \ | 
 |         _create_avro_source( | 
 |             file_name, | 
 |             validate=False, | 
 |             use_fastavro=self.use_fastavro | 
 |         ) | 
 |     dd = DisplayData.create_from(source) | 
 |  | 
 |     # No extra avro parameters for AvroSource. | 
 |     expected_items = [ | 
 |         DisplayDataItemMatcher('compression', 'auto'), | 
 |         DisplayDataItemMatcher('file_pattern', file_name) | 
 |     ] | 
 |     hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items)) | 
 |  | 
 |   def test_read_display_data(self): | 
 |     file_name = 'some_avro_source' | 
 |     read = \ | 
 |         avroio.ReadFromAvro( | 
 |             file_name, | 
 |             validate=False, | 
 |             use_fastavro=self.use_fastavro) | 
 |     dd = DisplayData.create_from(read) | 
 |  | 
 |     # No extra avro parameters for AvroSource. | 
 |     expected_items = [ | 
 |         DisplayDataItemMatcher('compression', 'auto'), | 
 |         DisplayDataItemMatcher('file_pattern', file_name) | 
 |     ] | 
 |     hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items)) | 
 |  | 
 |   def test_sink_display_data(self): | 
 |     file_name = 'some_avro_sink' | 
 |     sink = _create_avro_sink( | 
 |         file_name, | 
 |         self.SCHEMA, | 
 |         'null', | 
 |         '.end', | 
 |         0, | 
 |         None, | 
 |         'application/x-avro', | 
 |         use_fastavro=self.use_fastavro) | 
 |     dd = DisplayData.create_from(sink) | 
 |  | 
 |     expected_items = [ | 
 |         DisplayDataItemMatcher('schema', str(self.SCHEMA)), | 
 |         DisplayDataItemMatcher( | 
 |             'file_pattern', | 
 |             'some_avro_sink-%(shard_num)05d-of-%(num_shards)05d.end'), | 
 |         DisplayDataItemMatcher('codec', 'null'), | 
 |         DisplayDataItemMatcher('compression', 'uncompressed') | 
 |     ] | 
 |  | 
 |     hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items)) | 
 |  | 
 |   def test_write_display_data(self): | 
 |     file_name = 'some_avro_sink' | 
 |     write = avroio.WriteToAvro( | 
 |         file_name, self.SCHEMA, use_fastavro=self.use_fastavro) | 
 |     dd = DisplayData.create_from(write) | 
 |     expected_items = [ | 
 |         DisplayDataItemMatcher('schema', str(self.SCHEMA)), | 
 |         DisplayDataItemMatcher( | 
 |             'file_pattern', | 
 |             'some_avro_sink-%(shard_num)05d-of-%(num_shards)05d'), | 
 |         DisplayDataItemMatcher('codec', 'deflate'), | 
 |         DisplayDataItemMatcher('compression', 'uncompressed') | 
 |     ] | 
 |     hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items)) | 
 |  | 
 |   def test_read_reentrant_without_splitting(self): | 
 |     file_name = self._write_data() | 
 |     source = _create_avro_source(file_name, use_fastavro=self.use_fastavro) | 
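    # A reentrant read starts a fresh read from the same source while an
    # earlier read is still in progress; both reads must return the full set
    # of records.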
 |     source_test_utils.assert_reentrant_reads_succeed((source, None, None)) | 
 |  | 
  def test_read_reentrant_with_splitting(self):
 |     file_name = self._write_data() | 
 |     source = _create_avro_source(file_name, use_fastavro=self.use_fastavro) | 
 |     splits = [split for split in source.split(desired_bundle_size=100000)] | 
 |     assert len(splits) == 1 | 
 |     source_test_utils.assert_reentrant_reads_succeed( | 
 |         (splits[0].source, splits[0].start_position, splits[0].stop_position)) | 
 |  | 
 |   def test_read_without_splitting_multiple_blocks(self): | 
 |     file_name = self._write_data(count=12000) | 
 |     expected_result = self.RECORDS * 2000 | 
 |     self._run_avro_test(file_name, None, False, expected_result) | 
 |  | 
 |   def test_read_with_splitting_multiple_blocks(self): | 
 |     file_name = self._write_data(count=12000) | 
 |     expected_result = self.RECORDS * 2000 | 
 |     self._run_avro_test(file_name, 10000, True, expected_result) | 
 |  | 
 |   def test_split_points(self): | 
 |     num_records = 12000 | 
 |     sync_interval = 16000 | 
 |     file_name = self._write_data(count=num_records, sync_interval=sync_interval) | 
 |  | 
 |     source = _create_avro_source(file_name, use_fastavro=self.use_fastavro) | 
 |  | 
 |     splits = [split for split in source.split(desired_bundle_size=float('inf'))] | 
 |     assert len(splits) == 1 | 
 |     range_tracker = splits[0].source.get_range_tracker( | 
 |         splits[0].start_position, splits[0].stop_position) | 
 |  | 
 |     split_points_report = [] | 
 |  | 
 |     for _ in splits[0].source.read(range_tracker): | 
 |       split_points_report.append(range_tracker.split_points()) | 
    # The generated test file will contain a total of num_blocks blocks,
    # proportional to the number of records in the file divided by the
    # synchronization interval avro uses when writing. Each block has more
    # than 10 records.
 |     num_blocks = int(math.ceil(14.5 * num_records / sync_interval)) | 
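    # For example, with num_records = 12000 and sync_interval = 16000 this
    # estimates ceil(14.5 * 12000 / 16000) = ceil(10.875) = 11 blocks, where
    # 14.5 appears to be a rough bytes-per-record estimate for these records.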
 |     assert num_blocks > 1 | 
 |     # When reading records of the first block, range_tracker.split_points() | 
 |     # should return (0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN) | 
 |     self.assertEqual( | 
 |         split_points_report[:10], | 
 |         [(0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)] * 10) | 
 |  | 
 |     # When reading records of last block, range_tracker.split_points() should | 
 |     # return (num_blocks - 1, 1) | 
 |     self.assertEqual(split_points_report[-10:], [(num_blocks - 1, 1)] * 10) | 
 |  | 
 |   def test_read_without_splitting_compressed_deflate(self): | 
 |     file_name = self._write_data(codec='deflate') | 
 |     expected_result = self.RECORDS | 
 |     self._run_avro_test(file_name, None, False, expected_result) | 
 |  | 
 |   def test_read_with_splitting_compressed_deflate(self): | 
 |     file_name = self._write_data(codec='deflate') | 
 |     expected_result = self.RECORDS | 
 |     self._run_avro_test(file_name, 100, True, expected_result) | 
 |  | 
 |   @unittest.skipIf(snappy is None, 'python-snappy not installed.') | 
 |   def test_read_without_splitting_compressed_snappy(self): | 
 |     file_name = self._write_data(codec='snappy') | 
 |     expected_result = self.RECORDS | 
 |     self._run_avro_test(file_name, None, False, expected_result) | 
 |  | 
 |   @unittest.skipIf(snappy is None, 'python-snappy not installed.') | 
 |   def test_read_with_splitting_compressed_snappy(self): | 
 |     file_name = self._write_data(codec='snappy') | 
 |     expected_result = self.RECORDS | 
 |     self._run_avro_test(file_name, 100, True, expected_result) | 
 |  | 
 |   def test_read_without_splitting_pattern(self): | 
 |     pattern = self._write_pattern(3) | 
 |     expected_result = self.RECORDS * 3 | 
 |     self._run_avro_test(pattern, None, False, expected_result) | 
 |  | 
 |   def test_read_with_splitting_pattern(self): | 
 |     pattern = self._write_pattern(3) | 
 |     expected_result = self.RECORDS * 3 | 
 |     self._run_avro_test(pattern, 100, True, expected_result) | 
 |  | 
 |   def test_dynamic_work_rebalancing_exhaustive(self): | 
 |     def compare_split_points(file_name): | 
 |       source = _create_avro_source(file_name, use_fastavro=self.use_fastavro) | 
 |       splits = [ | 
 |           split for split in source.split(desired_bundle_size=float('inf')) | 
 |       ] | 
 |       assert len(splits) == 1 | 
 |       source_test_utils.assert_split_at_fraction_exhaustive(splits[0].source) | 
 |  | 
    # Adjust the block size so that we can perform an exhaustive dynamic
    # work rebalancing test that completes within an acceptable amount of
    # time.
 |     file_name = self._write_data(count=5, sync_interval=2) | 
 |  | 
 |     compare_split_points(file_name) | 
 |  | 
 |   def test_corrupted_file(self): | 
 |     file_name = self._write_data() | 
 |     with open(file_name, 'rb') as f: | 
 |       data = f.read() | 
 |  | 
    # Corrupt the last byte of the file, which is also the last byte of the
    # trailing 16-byte sync marker.
    # https://avro.apache.org/docs/current/spec.html#Object+Container+Files
 |     corrupted_data = bytearray(data) | 
 |     corrupted_data[-1] = (corrupted_data[-1] + 1) % 256 | 
 |     with tempfile.NamedTemporaryFile(delete=False, | 
 |                                      prefix=tempfile.template) as f: | 
 |       f.write(corrupted_data) | 
 |       corrupted_file_name = f.name | 
 |  | 
 |     source = _create_avro_source( | 
 |         corrupted_file_name, use_fastavro=self.use_fastavro) | 
 |     with self.assertRaisesRegex(ValueError, r'expected sync marker'): | 
 |       source_test_utils.read_from_source(source, None, None) | 
 |  | 
 |   def test_read_from_avro(self): | 
 |     path = self._write_data() | 
 |     with TestPipeline() as p: | 
 |       assert_that( | 
 |           p | avroio.ReadFromAvro(path, use_fastavro=self.use_fastavro), | 
 |           equal_to(self.RECORDS)) | 
 |  | 
 |   def test_read_all_from_avro_single_file(self): | 
 |     path = self._write_data() | 
 |     with TestPipeline() as p: | 
 |       assert_that( | 
 |           p \ | 
 |           | Create([path]) \ | 
 |           | avroio.ReadAllFromAvro(use_fastavro=self.use_fastavro), | 
 |           equal_to(self.RECORDS)) | 
 |  | 
 |   def test_read_all_from_avro_many_single_files(self): | 
 |     path1 = self._write_data() | 
 |     path2 = self._write_data() | 
 |     path3 = self._write_data() | 
 |     with TestPipeline() as p: | 
 |       assert_that( | 
 |           p \ | 
 |           | Create([path1, path2, path3]) \ | 
 |           | avroio.ReadAllFromAvro(use_fastavro=self.use_fastavro), | 
 |           equal_to(self.RECORDS * 3)) | 
 |  | 
 |   def test_read_all_from_avro_file_pattern(self): | 
 |     file_pattern = self._write_pattern(5) | 
 |     with TestPipeline() as p: | 
 |       assert_that( | 
 |           p \ | 
 |           | Create([file_pattern]) \ | 
 |           | avroio.ReadAllFromAvro(use_fastavro=self.use_fastavro), | 
 |           equal_to(self.RECORDS * 5)) | 
 |  | 
 |   def test_read_all_from_avro_many_file_patterns(self): | 
 |     file_pattern1 = self._write_pattern(5) | 
 |     file_pattern2 = self._write_pattern(2) | 
 |     file_pattern3 = self._write_pattern(3) | 
 |     with TestPipeline() as p: | 
 |       assert_that( | 
 |           p \ | 
 |           | Create([file_pattern1, file_pattern2, file_pattern3]) \ | 
 |           | avroio.ReadAllFromAvro(use_fastavro=self.use_fastavro), | 
 |           equal_to(self.RECORDS * 10)) | 
 |  | 
 |   def test_sink_transform(self): | 
 |     with tempfile.NamedTemporaryFile() as dst: | 
 |       path = dst.name | 
 |       with TestPipeline() as p: | 
 |         # pylint: disable=expression-not-assigned | 
 |         p \ | 
 |         | beam.Create(self.RECORDS) \ | 
 |         | avroio.WriteToAvro(path, self.SCHEMA, use_fastavro=self.use_fastavro) | 
 |       with TestPipeline() as p: | 
        # Serialize records to JSON so they can be compared as plain strings.
 |         readback = \ | 
 |             p \ | 
 |             | avroio.ReadFromAvro(path + '*', use_fastavro=self.use_fastavro) \ | 
 |             | beam.Map(json.dumps) | 
 |         assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS])) | 
 |  | 
 |   @unittest.skipIf(snappy is None, 'python-snappy not installed.') | 
 |   def test_sink_transform_snappy(self): | 
 |     with tempfile.NamedTemporaryFile() as dst: | 
 |       path = dst.name | 
 |       with TestPipeline() as p: | 
 |         # pylint: disable=expression-not-assigned | 
 |         p \ | 
 |         | beam.Create(self.RECORDS) \ | 
 |         | avroio.WriteToAvro( | 
 |             path, | 
 |             self.SCHEMA, | 
 |             codec='snappy', | 
 |             use_fastavro=self.use_fastavro) | 
 |       with TestPipeline() as p: | 
        # Serialize records to JSON so they can be compared as plain strings.
 |         readback = \ | 
 |             p \ | 
 |             | avroio.ReadFromAvro(path + '*', use_fastavro=self.use_fastavro) \ | 
 |             | beam.Map(json.dumps) | 
 |         assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS])) | 
 |  | 
 |  | 
@unittest.skipIf(
    sys.version_info[0] == 3 and os.environ.get('RUN_SKIPPED_PY3_TESTS') != '1',
    'This test requires that Beam depends on avro-python3>=1.9. '
    'See: BEAM-6522.')
 | class TestAvro(AvroBase, unittest.TestCase): | 
 |   def __init__(self, methodName='runTest'): | 
 |     super(TestAvro, self).__init__(methodName) | 
 |     self.use_fastavro = False | 
 |     self.SCHEMA = Parse(self.SCHEMA_STRING) | 
 |  | 
 |   def _write_data( | 
 |       self, | 
 |       directory=None, | 
 |       prefix=tempfile.template, | 
 |       codec='null', | 
 |       count=len(RECORDS), | 
 |       sync_interval=avro.datafile.SYNC_INTERVAL): | 
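    # DataFileWriter consults the module-level SYNC_INTERVAL when deciding
    # where to insert block sync markers, so patch it for the duration of the
    # write and restore the original value afterwards.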
 |     old_sync_interval = avro.datafile.SYNC_INTERVAL | 
 |     try: | 
 |       avro.datafile.SYNC_INTERVAL = sync_interval | 
 |       with tempfile.NamedTemporaryFile(delete=False, | 
 |                                        dir=directory, | 
 |                                        prefix=prefix) as f: | 
        # Named avro_writer to avoid shadowing the fastavro writer imported
        # at module level.
        avro_writer = DataFileWriter(f, DatumWriter(), self.SCHEMA, codec=codec)
        len_records = len(self.RECORDS)
        for i in range(count):
          avro_writer.append(self.RECORDS[i % len_records])
        avro_writer.close()
 |         self._temp_files.append(f.name) | 
 |         return f.name | 
 |     finally: | 
 |       avro.datafile.SYNC_INTERVAL = old_sync_interval | 
 |  | 
 |  | 
 | class TestFastAvro(AvroBase, unittest.TestCase): | 
 |   def __init__(self, methodName='runTest'): | 
 |     super(TestFastAvro, self).__init__(methodName) | 
 |     self.use_fastavro = True | 
 |     self.SCHEMA = parse_schema(json.loads(self.SCHEMA_STRING)) | 
 |  | 
 |   def _write_data( | 
 |       self, | 
 |       directory=None, | 
 |       prefix=tempfile.template, | 
 |       codec='null', | 
 |       count=len(RECORDS), | 
 |       **kwargs): | 
    all_records = (
        self.RECORDS * (count // len(self.RECORDS)) +
        self.RECORDS[:(count % len(self.RECORDS))])
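    # For example, with the 6 base records and count = 8 this yields
    # RECORDS * 1 + RECORDS[:2], i.e. the base list cycled up to 8 records.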
 |     with tempfile.NamedTemporaryFile(delete=False, | 
 |                                      dir=directory, | 
 |                                      prefix=prefix, | 
 |                                      mode='w+b') as f: | 
 |       writer(f, self.SCHEMA, all_records, codec=codec, **kwargs) | 
 |       self._temp_files.append(f.name) | 
 |     return f.name | 
 |  | 
 |  | 
 | if __name__ == '__main__': | 
 |   logging.getLogger().setLevel(logging.INFO) | 
 |   unittest.main() |