| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| # pytype: skip-file |
| |
| import binascii |
| import glob |
| import gzip |
| import io |
| import json |
| import logging |
| import os |
| import pickle |
| import random |
| import re |
| import shutil |
| import tempfile |
| import unittest |
| import zlib |
| from datetime import datetime |
| |
| import pytz |
| |
| import apache_beam as beam |
| from apache_beam import Create |
| from apache_beam import coders |
| from apache_beam.io.filesystem import CompressionTypes |
| from apache_beam.io.tfrecordio import ReadAllFromTFRecord |
| from apache_beam.io.tfrecordio import ReadFromTFRecord |
| from apache_beam.io.tfrecordio import WriteToTFRecord |
| from apache_beam.io.tfrecordio import _TFRecordSink |
| from apache_beam.io.tfrecordio import _TFRecordUtil |
| from apache_beam.testing.test_pipeline import TestPipeline |
| from apache_beam.testing.test_stream import TestStream |
| from apache_beam.testing.test_utils import TempDir |
| from apache_beam.testing.util import assert_that |
| from apache_beam.testing.util import equal_to |
| from apache_beam.transforms.util import LogElements |
| |
| try: |
| import tensorflow.compat.v1 as tf # pylint: disable=import-error |
| except ImportError: |
| try: |
| import tensorflow as tf # pylint: disable=import-error |
| except ImportError: |
| tf = None # pylint: disable=invalid-name |
| logging.warning('Tensorflow is not installed, so skipping some tests.') |
| |
| try: |
| import crcmod |
| except ImportError: |
| crcmod = None |
| |
| # Created by running following code in python: |
| # >>> import tensorflow as tf |
| # >>> import base64 |
| # >>> writer = tf.python_io.TFRecordWriter('/tmp/python_foo.tfrecord') |
| # >>> writer.write(b'foo') |
| # >>> writer.close() |
| # >>> with open('/tmp/python_foo.tfrecord', 'rb') as f: |
| # ... data = base64.b64encode(f.read()) |
| # ... print(data) |
| FOO_RECORD_BASE64 = b'AwAAAAAAAACwmUkOZm9vYYq+/g==' |
| |
| # Same as above but containing two records [b'foo', b'bar'] |
| FOO_BAR_RECORD_BASE64 = b'AwAAAAAAAACwmUkOZm9vYYq+/gMAAAAAAAAAsJlJDmJhckYA5cg=' |
| |
| |
def _write_file(path, base64_records):
  """Decodes base64_records and writes the raw bytes to path."""
  decoded = binascii.a2b_base64(base64_records)
  with open(path, 'wb') as out:
    out.write(decoded)
| |
| |
def _write_file_deflate(path, base64_records):
  """Decodes base64_records and writes them to path, zlib-compressed."""
  payload = zlib.compress(binascii.a2b_base64(base64_records))
  with open(path, 'wb') as out:
    out.write(payload)
| |
| |
def _write_file_gzip(path, base64_records):
  """Decodes base64_records and writes them to path as a gzip file."""
  payload = binascii.a2b_base64(base64_records)
  with gzip.open(path, 'wb') as out:
    out.write(payload)
| |
| |
class TestTFRecordUtil(unittest.TestCase):
  """Unit tests for the low-level _TFRecordUtil record read/write helpers."""
  def setUp(self):
    # A single serialized TFRecord whose payload is b'foo'.
    self.record = binascii.a2b_base64(FOO_RECORD_BASE64)

  def _as_file_handle(self, contents):
    """Wraps contents in an in-memory file handle positioned at the start."""
    return io.BytesIO(contents)

  def _increment_value_at_index(self, value, index):
    """Returns a copy of value with the byte at index incremented by one."""
    mutated = bytearray(value)
    mutated[index] += 1
    return bytes(mutated)

  def _test_error(self, record, error_text):
    """Asserts that reading record raises ValueError mentioning error_text."""
    with self.assertRaisesRegex(ValueError, re.escape(error_text)):
      _TFRecordUtil.read_record(self._as_file_handle(record))

  def test_masked_crc32c(self):
    # Golden values computed with the reference TFRecord implementation.
    cases = [
        (b'\x00' * 32, 0xfd7fffa),
        (b'\xff' * 32, 0xf909b029),
        (b'foo', 0xfebe8a61),
        (b'\x03\x00\x00\x00\x00\x00\x00\x00', 0xe4999b0),
    ]
    for data, expected in cases:
      self.assertEqual(expected, _TFRecordUtil._masked_crc32c(data))

  @unittest.skipIf(crcmod is None, 'crcmod not installed.')
  def test_masked_crc32c_crcmod(self):
    # Same golden values, but routed through an explicit crcmod CRC function.
    crc32c_fn = crcmod.predefined.mkPredefinedCrcFun('crc-32c')
    cases = [
        (b'\x00' * 32, 0xfd7fffa),
        (b'\xff' * 32, 0xf909b029),
        (b'foo', 0xfebe8a61),
        (b'\x03\x00\x00\x00\x00\x00\x00\x00', 0xe4999b0),
    ]
    for data, expected in cases:
      self.assertEqual(
          expected, _TFRecordUtil._masked_crc32c(data, crc32c_fn=crc32c_fn))

  def test_write_record(self):
    out = io.BytesIO()
    _TFRecordUtil.write_record(out, b'foo')
    self.assertEqual(self.record, out.getvalue())

  def test_read_record(self):
    self.assertEqual(
        b'foo', _TFRecordUtil.read_record(self._as_file_handle(self.record)))

  def test_read_record_invalid_record(self):
    self._test_error(b'bar', 'Not a valid TFRecord. Fewer than 12 bytes')

  def test_read_record_invalid_length_mask(self):
    # Corrupt a byte inside the length's masked CRC (header bytes 8-11).
    corrupted = self._increment_value_at_index(self.record, 9)
    self._test_error(corrupted, 'Mismatch of length mask')

  def test_read_record_invalid_data_mask(self):
    # Corrupt a byte inside the payload's masked CRC trailer.
    corrupted = self._increment_value_at_index(self.record, 16)
    self._test_error(corrupted, 'Mismatch of data mask')

  def test_compatibility_read_write(self):
    # A record written by write_record must round-trip through read_record.
    for payload in (b'', b'blah', b'another blah'):
      buffer = io.BytesIO()
      _TFRecordUtil.write_record(buffer, payload)
      buffer.seek(0)
      self.assertEqual(payload, _TFRecordUtil.read_record(buffer))
| |
| |
class TestTFRecordSink(unittest.TestCase):
  """Tests writing records to disk through _TFRecordSink."""
  def _write_lines(self, sink, path, lines):
    """Opens the sink at path, writes each record, then closes it."""
    handle = sink.open(path)
    for line in lines:
      sink.write_record(handle, line)
    sink.close(handle)

  def _make_uncompressed_sink(self, path):
    """Builds an unsharded, uncompressed bytes sink writing to path."""
    return _TFRecordSink(
        path,
        coder=coders.BytesCoder(),
        file_name_suffix='',
        num_shards=0,
        shard_name_template=None,
        compression_type=CompressionTypes.UNCOMPRESSED)

  def _check_write(self, records, expected_base64):
    """Writes records and verifies the file matches the base64 golden data."""
    expected = binascii.a2b_base64(expected_base64)
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      self._write_lines(self._make_uncompressed_sink(path), path, records)
      with open(path, 'rb') as f:
        self.assertEqual(expected, f.read())

  def test_write_record_single(self):
    self._check_write([b'foo'], FOO_RECORD_BASE64)

  def test_write_record_multiple(self):
    self._check_write([b'foo', b'bar'], FOO_BAR_RECORD_BASE64)
| |
| |
@unittest.skipIf(tf is None, 'tensorflow not installed.')
class TestWriteToTFRecord(TestTFRecordSink):
  """Tests WriteToTFRecord output by reading it back with TensorFlow."""
  def _read_gzipped_records(self, file_name):
    """Returns all records in a gzip-compressed TFRecord file, via TF."""
    options = tf.python_io.TFRecordOptions(
        tf.python_io.TFRecordCompressionType.GZIP)
    return list(tf.python_io.tf_record_iterator(file_name, options=options))

  def test_write_record_gzip(self):
    # Explicit GZIP compression type.
    with TempDir() as temp_dir:
      file_path_prefix = temp_dir.create_temp_file('result')
      input_data = [b'foo', b'bar']
      with TestPipeline() as p:
        _ = p | beam.Create(input_data) | WriteToTFRecord(
            file_path_prefix, compression_type=CompressionTypes.GZIP)

      file_name = glob.glob(file_path_prefix + '-*')[0]
      actual = self._read_gzipped_records(file_name)
      self.assertEqual(sorted(actual), sorted(input_data))

  def test_write_record_auto(self):
    # Compression inferred from the '.gz' file name suffix.
    with TempDir() as temp_dir:
      file_path_prefix = temp_dir.create_temp_file('result')
      input_data = [b'foo', b'bar']
      with TestPipeline() as p:
        _ = p | beam.Create(input_data) | WriteToTFRecord(
            file_path_prefix, file_name_suffix='.gz')

      file_name = glob.glob(file_path_prefix + '-*.gz')[0]
      actual = self._read_gzipped_records(file_name)
      self.assertEqual(sorted(actual), sorted(input_data))
| |
| |
class TestReadFromTFRecord(unittest.TestCase):
  """Tests for the ReadFromTFRecord source transform."""
  def _verify_read(self, path, expected, **read_kwargs):
    """Runs ReadFromTFRecord(path, **read_kwargs) and checks the output."""
    with TestPipeline() as p:
      result = p | ReadFromTFRecord(path, **read_kwargs)
      assert_that(result, equal_to(expected))

  def test_process_single(self):
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      _write_file(path, FOO_RECORD_BASE64)
      self._verify_read(
          path, [b'foo'],
          coder=coders.BytesCoder(),
          compression_type=CompressionTypes.AUTO,
          validate=True)

  def test_process_multiple(self):
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      _write_file(path, FOO_BAR_RECORD_BASE64)
      self._verify_read(
          path, [b'foo', b'bar'],
          coder=coders.BytesCoder(),
          compression_type=CompressionTypes.AUTO,
          validate=True)

  def test_process_deflate(self):
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      _write_file_deflate(path, FOO_BAR_RECORD_BASE64)
      self._verify_read(
          path, [b'foo', b'bar'],
          coder=coders.BytesCoder(),
          compression_type=CompressionTypes.DEFLATE,
          validate=True)

  def test_process_gzip_with_coder(self):
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      _write_file_gzip(path, FOO_BAR_RECORD_BASE64)
      self._verify_read(
          path, [b'foo', b'bar'],
          coder=coders.BytesCoder(),
          compression_type=CompressionTypes.GZIP,
          validate=True)

  def test_process_gzip_without_coder(self):
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      _write_file_gzip(path, FOO_BAR_RECORD_BASE64)
      self._verify_read(
          path, [b'foo', b'bar'], compression_type=CompressionTypes.GZIP)

  def test_process_auto(self):
    # AUTO infers gzip from the '.gz' extension.
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result.gz')
      _write_file_gzip(path, FOO_BAR_RECORD_BASE64)
      self._verify_read(
          path, [b'foo', b'bar'],
          coder=coders.BytesCoder(),
          compression_type=CompressionTypes.AUTO,
          validate=True)

  def test_process_gzip_auto(self):
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result.gz')
      _write_file_gzip(path, FOO_BAR_RECORD_BASE64)
      self._verify_read(
          path, [b'foo', b'bar'], compression_type=CompressionTypes.AUTO)
| |
| |
class TestReadAllFromTFRecord(unittest.TestCase):
  """Tests for the ReadAllFromTFRecord transform."""
  def _write_glob(self, temp_dir, suffix, include_empty=False):
    """Writes three two-record files (and optionally one empty file)."""
    for _ in range(3):
      _write_file(temp_dir.create_temp_file(suffix), FOO_BAR_RECORD_BASE64)
    if include_empty:
      _write_file(temp_dir.create_temp_file(suffix), '')

  def _verify_read_all(self, sources, expected, **read_kwargs):
    """Reads sources through ReadAllFromTFRecord and checks the output.

    Defaults to a bytes coder with AUTO compression unless overridden.
    """
    read_kwargs.setdefault('coder', coders.BytesCoder())
    read_kwargs.setdefault('compression_type', CompressionTypes.AUTO)
    with TestPipeline() as p:
      result = p | Create(sources) | ReadAllFromTFRecord(**read_kwargs)
      assert_that(result, equal_to(expected))

  def test_process_single(self):
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      _write_file(path, FOO_RECORD_BASE64)
      self._verify_read_all([path], [b'foo'])

  def test_process_multiple(self):
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      _write_file(path, FOO_BAR_RECORD_BASE64)
      self._verify_read_all([path], [b'foo', b'bar'])

  def test_process_with_filename(self):
    with TempDir() as temp_dir:
      paths = []
      for i in range(3):
        path = temp_dir.create_temp_file('result%s' % i)
        _write_file(path, FOO_BAR_RECORD_BASE64)
        paths.append(path)
      # with_filename=True yields (file_name, record) pairs.
      expected = [(path, record) for path in paths
                  for record in (b'foo', b'bar')]
      pattern = temp_dir.get_path() + os.path.sep + '*'
      self._verify_read_all([pattern], expected, with_filename=True)

  def test_process_glob(self):
    with TempDir() as temp_dir:
      self._write_glob(temp_dir, 'result')
      file_pattern = temp_dir.get_path() + os.path.sep + '*result'
      self._verify_read_all([file_pattern], [b'foo', b'bar'] * 3)

  def test_process_glob_with_empty_file(self):
    # The empty file contributes no records but must not break the read.
    with TempDir() as temp_dir:
      self._write_glob(temp_dir, 'result', include_empty=True)
      file_pattern = temp_dir.get_path() + os.path.sep + '*result'
      self._verify_read_all([file_pattern], [b'foo', b'bar'] * 3)

  def test_process_multiple_globs(self):
    with TempDir() as temp_dir:
      file_patterns = []
      for i in range(3):
        suffix = 'result' + str(i)
        self._write_glob(temp_dir, suffix)
        file_patterns.append(temp_dir.get_path() + os.path.sep + '*' + suffix)
      self._verify_read_all(file_patterns, [b'foo', b'bar'] * 9)

  def test_process_deflate(self):
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      _write_file_deflate(path, FOO_BAR_RECORD_BASE64)
      self._verify_read_all(
          [path], [b'foo', b'bar'],
          compression_type=CompressionTypes.DEFLATE)

  def test_process_gzip(self):
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      _write_file_gzip(path, FOO_BAR_RECORD_BASE64)
      self._verify_read_all(
          [path], [b'foo', b'bar'], compression_type=CompressionTypes.GZIP)

  def test_process_auto(self):
    # AUTO infers gzip from the '.gz' extension.
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result.gz')
      _write_file_gzip(path, FOO_BAR_RECORD_BASE64)
      self._verify_read_all([path], [b'foo', b'bar'])
| |
| |
class TestEnd2EndWriteAndRead(unittest.TestCase):
  """Round-trip tests: write with WriteToTFRecord, read with ReadFromTFRecord."""
  def create_inputs(self):
    """Returns pickled bytes of a 12x15 matrix of random floats in [-0.5, 0.5)."""
    matrix = [[random.random() - 0.5 for _ in range(15)] for _ in range(12)]
    return pickle.dumps(matrix)

  def test_end2end(self):
    with TempDir() as temp_dir:
      file_path_prefix = temp_dir.create_temp_file('result')
      expected_data = [self.create_inputs() for _ in range(10)]

      # Generate a TFRecord file.
      with TestPipeline() as p:
        _ = p | beam.Create(expected_data) | WriteToTFRecord(file_path_prefix)

      # Read the file back and compare.
      with TestPipeline() as p:
        actual_data = p | ReadFromTFRecord(file_path_prefix + '-*')
        assert_that(actual_data, equal_to(expected_data))

  def test_end2end_auto_compression(self):
    with TempDir() as temp_dir:
      file_path_prefix = temp_dir.create_temp_file('result')
      expected_data = [self.create_inputs() for _ in range(10)]

      # Generate a gzip TFRecord file (compression inferred from suffix).
      with TestPipeline() as p:
        _ = p | beam.Create(expected_data) | WriteToTFRecord(
            file_path_prefix, file_name_suffix='.gz')

      # Read the file back and compare.
      with TestPipeline() as p:
        actual_data = p | ReadFromTFRecord(file_path_prefix + '-*')
        assert_that(actual_data, equal_to(expected_data))

  def test_end2end_auto_compression_unsharded(self):
    with TempDir() as temp_dir:
      file_path_prefix = temp_dir.create_temp_file('result')
      expected_data = [self.create_inputs() for _ in range(10)]

      # Generate a single (unsharded) gzip TFRecord file.
      with TestPipeline() as p:
        _ = p | beam.Create(expected_data) | WriteToTFRecord(
            file_path_prefix + '.gz', shard_name_template='')

      # Read the file back and compare.
      with TestPipeline() as p:
        actual_data = p | ReadFromTFRecord(file_path_prefix + '.gz')
        assert_that(actual_data, equal_to(expected_data))

  @unittest.skipIf(tf is None, 'tensorflow not installed.')
  def test_end2end_example_proto(self):
    with TempDir() as temp_dir:
      file_path_prefix = temp_dir.create_temp_file('result')

      # Build a tf.train.Example with one int feature and one bytes feature.
      example = tf.train.Example()
      example.features.feature['int'].int64_list.value.extend(list(range(3)))
      example.features.feature['bytes'].bytes_list.value.extend(
          [b'foo', b'bar'])

      with TestPipeline() as p:
        _ = p | beam.Create([example]) | WriteToTFRecord(
            file_path_prefix, coder=beam.coders.ProtoCoder(example.__class__))

      # Read the file back and compare.
      with TestPipeline() as p:
        actual_data = (
            p | ReadFromTFRecord(
                file_path_prefix + '-*',
                coder=beam.coders.ProtoCoder(example.__class__)))
        assert_that(actual_data, equal_to([example]))

  def test_end2end_read_write_read(self):
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      with TestPipeline() as p:
        # Initial read to validate the pipeline doesn't fail before the file is
        # created.
        _ = p | ReadFromTFRecord(path + '-*', validate=False)
        expected_data = [self.create_inputs() for _ in range(10)]
        _ = p | beam.Create(expected_data) | WriteToTFRecord(
            path, file_name_suffix='.gz')

      # Read the file back and compare.
      with TestPipeline() as p:
        actual_data = p | ReadFromTFRecord(path + '-*', validate=True)
        assert_that(actual_data, equal_to(expected_data))
| |
| |
class GenerateEvent(beam.PTransform):
  """Emits a deterministic stream of dict elements via TestStream.

  Three elements ({'age': 10}, {'age': 20}, {'age': 30}) are added at every
  second from 2021-03-01T00:00:01 to 2021-03-01T00:00:20 (UTC).  The
  watermark is advanced to :05, :10, :15 and :20 immediately before the
  elements with those timestamps are added, then to :25, and finally to
  infinity.  This replaces a ~100-line hand-unrolled builder chain with an
  equivalent loop.
  """
  @staticmethod
  def sample_data():
    """Returns a fresh instance; mirrors other sample-data factories."""
    return GenerateEvent()

  def expand(self, input):  # pylint: disable=redefined-builtin
    def event_time(second):
      # Epoch timestamp for 2021-03-01 00:00:<second> UTC.
      return datetime(
          2021, 3, 1, 0, 0, second, 0, tzinfo=pytz.UTC).timestamp()

    elements = [{'age': 10}, {'age': 20}, {'age': 30}]
    stream = TestStream()
    for second in range(1, 21):
      # The watermark leads each new five-second batch of events.
      if second in (5, 10, 15, 20):
        stream = stream.advance_watermark_to(event_time(second))
      stream = stream.add_elements(
          elements=elements, event_timestamp=event_time(second))
    stream = stream.advance_watermark_to(event_time(25))
    return input | stream.advance_watermark_to_infinity()
| |
| |
class WriteStreamingTest(unittest.TestCase):
  """Tests WriteToTFRecord on streaming (windowed) input.

  Each test writes windowed records via WriteToTFRecord, then verifies the
  produced file names match the expected windowed shard-naming pattern and
  that the expected number of shard files exists.
  """

  # Matches windowed file names with epoch-float window bounds, e.g.
  # ouput_WriteToTFRecord-[1614556800.0, 1614556805.0)-00000-of-00002.tfrecord
  # It captures: window_start, window_end, shard_num, total_shards.
  _EPOCH_WINDOW_PATTERN = (
      r'.*-\[(?P<window_start>[\d\.]+), '
      r'(?P<window_end>[\d\.]+|Infinity)\)-'
      r'(?P<shard_num>\d{5})-of-(?P<total_shards>\d{5})\.tfrecord$')

  # Matches windowed file names with datetime window bounds (the '-V-' shard
  # name template), e.g.
  # ouput_WriteToTFRecord-[2021-03-01T00-00-00, 2021-03-01T00-01-00)-
  # 00000-of-00002.tfrecord
  _DATETIME_WINDOW_PATTERN = (
      r'.*-\[(?P<window_start>\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}), '
      r'(?P<window_end>\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}|Infinity)\)-'
      r'(?P<shard_num>\d{5})-of-(?P<total_shards>\d{5})\.tfrecord$')

  def setUp(self):
    super().setUp()
    self.tempdir = tempfile.mkdtemp()

  def tearDown(self):
    if os.path.exists(self.tempdir):
      shutil.rmtree(self.tempdir)

  def _assert_output_files(self, pattern_string, expected_count):
    """Asserts all output files match pattern_string and counts them.

    Args:
      pattern_string: regex every produced file name must match.
      expected_count: exact number of matching files expected.
    """
    pattern = re.compile(pattern_string)
    file_names = []
    for file_name in glob.glob(self.tempdir + '/ouput_WriteToTFRecord*'):
      match = pattern.match(file_name)
      self.assertIsNotNone(
          match, f"File name {file_name} did not match expected pattern.")
      if match:
        file_names.append(file_name)
    print("Found files matching expected pattern:", file_names)
    self.assertEqual(
        len(file_names),
        expected_count,
        "expected %d files, but got: %d" % (expected_count, len(file_names)))

  def test_write_streaming_2_shards_default_shard_name_template(
      self, num_shards=2):
    with TestPipeline() as p:
      output = (
          p
          | GenerateEvent.sample_data()
          | 'User windowing' >> beam.transforms.core.WindowInto(
              beam.transforms.window.FixedWindows(60),
              trigger=beam.transforms.trigger.AfterWatermark(),
              accumulation_mode=beam.transforms.trigger.AccumulationMode.
              DISCARDING,
              allowed_lateness=beam.utils.timestamp.Duration(seconds=0))
          | "encode" >> beam.Map(lambda s: json.dumps(s).encode('utf-8')))
      # TFRecordIO write with the default (windowed) shard name template.
      output2 = output | 'WriteToTFRecord' >> beam.io.WriteToTFRecord(
          file_path_prefix=self.tempdir + "/ouput_WriteToTFRecord",
          file_name_suffix=".tfrecord",
          num_shards=num_shards,
      )
      _ = output2 | 'LogElements after WriteToTFRecord' >> LogElements(
          prefix='after WriteToTFRecord ', with_window=True, level=logging.INFO)

    # One 60s window covers all sample data, so num_shards files total.
    self._assert_output_files(self._EPOCH_WINDOW_PATTERN, num_shards)

  def test_write_streaming_2_shards_custom_shard_name_template(
      self, num_shards=2, shard_name_template='-V-SSSSS-of-NNNNN'):
    with TestPipeline() as p:
      output = (
          p
          | GenerateEvent.sample_data()
          | "encode" >> beam.Map(lambda s: json.dumps(s).encode('utf-8')))
      # TFRecordIO write with a custom shard name template; a 60s triggering
      # frequency yields a single window over the sample data.
      output2 = output | 'WriteToTFRecord' >> beam.io.WriteToTFRecord(
          file_path_prefix=self.tempdir + "/ouput_WriteToTFRecord",
          file_name_suffix=".tfrecord",
          shard_name_template=shard_name_template,
          num_shards=num_shards,
          triggering_frequency=60,
      )
      _ = output2 | 'LogElements after WriteToTFRecord' >> LogElements(
          prefix='after WriteToTFRecord ', with_window=True, level=logging.INFO)

    self._assert_output_files(self._DATETIME_WINDOW_PATTERN, num_shards)

  def test_write_streaming_2_shards_custom_shard_name_template_5s_window(
      self,
      num_shards=2,
      shard_name_template='-V-SSSSS-of-NNNNN',
      triggering_frequency=5):
    with TestPipeline() as p:
      output = (
          p
          | GenerateEvent.sample_data()
          | "encode" >> beam.Map(lambda s: json.dumps(s).encode('utf-8')))
      # TFRecordIO write triggering every 5 seconds.
      output2 = output | 'WriteToTFRecord' >> beam.io.WriteToTFRecord(
          file_path_prefix=self.tempdir + "/ouput_WriteToTFRecord",
          file_name_suffix=".tfrecord",
          shard_name_template=shard_name_template,
          num_shards=num_shards,
          triggering_frequency=triggering_frequency,
      )
      _ = output2 | 'LogElements after WriteToTFRecord' >> LogElements(
          prefix='after WriteToTFRecord ', with_window=True, level=logging.INFO)

    # For a 5s window size the input is processed by 5 windows with
    # num_shards shards per window, hence 10 files.  (The previous failure
    # message incorrectly reported num_shards as the expected count.)
    self._assert_output_files(self._DATETIME_WINDOW_PATTERN, 5 * num_shards)
| |
| |
if __name__ == '__main__':
  # Surface INFO-level logs (e.g. LogElements output) when run directly.
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()