#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import absolute_import
import binascii
import glob
import gzip
import io
import logging
import os
import pickle
import random
import re
import sys
import unittest
import zlib
from builtins import range
import crcmod
import apache_beam as beam
from apache_beam import Create
from apache_beam import coders
from apache_beam.io.filesystem import CompressionTypes
from apache_beam.io.tfrecordio import ReadAllFromTFRecord
from apache_beam.io.tfrecordio import ReadFromTFRecord
from apache_beam.io.tfrecordio import WriteToTFRecord
from apache_beam.io.tfrecordio import _TFRecordSink
from apache_beam.io.tfrecordio import _TFRecordUtil
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.test_utils import TempDir
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
try:
import tensorflow.compat.v1 as tf # pylint: disable=import-error
except ImportError:
try:
import tensorflow as tf # pylint: disable=import-error
except ImportError:
tf = None # pylint: disable=invalid-name
logging.warning('Tensorflow is not installed, so skipping some tests.')
# Created by running the following code in Python:
# >>> import tensorflow as tf
# >>> import base64
# >>> writer = tf.python_io.TFRecordWriter('/tmp/python_foo.tfrecord')
# >>> writer.write(b'foo')
# >>> writer.close()
# >>> with open('/tmp/python_foo.tfrecord', 'rb') as f:
# ... data = base64.b64encode(f.read())
# ... print(data)
FOO_RECORD_BASE64 = b'AwAAAAAAAACwmUkOZm9vYYq+/g=='
# Same as above but containing two records [b'foo', b'bar']
FOO_BAR_RECORD_BASE64 = b'AwAAAAAAAACwmUkOZm9vYYq+/gMAAAAAAAAAsJlJDmJhckYA5cg='
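# For reference, each serialized record follows the TFRecord wire format:
# an 8-byte little-endian length, a 4-byte masked CRC32C of those length
# bytes, the payload, and a 4-byte masked CRC32C of the payload. Decoding
# FOO_RECORD_BASE64 yields:
#   03 00 00 00 00 00 00 00   length = 3
#   b0 99 49 0e               masked CRC32C of the length bytes
#   66 6f 6f                  b'foo'
#   61 8a be fe               masked CRC32C of b'foo'

# Helpers that write the base64-encoded fixture records to disk, optionally
# DEFLATE- or GZIP-compressed, for the read tests below.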
def _write_file(path, base64_records):
record = binascii.a2b_base64(base64_records)
with open(path, 'wb') as f:
f.write(record)
def _write_file_deflate(path, base64_records):
record = binascii.a2b_base64(base64_records)
with open(path, 'wb') as f:
f.write(zlib.compress(record))
def _write_file_gzip(path, base64_records):
record = binascii.a2b_base64(base64_records)
with gzip.GzipFile(path, 'wb') as f:
f.write(record)
class TestTFRecordUtil(unittest.TestCase):
def setUp(self):
self.record = binascii.a2b_base64(FOO_RECORD_BASE64)
def _as_file_handle(self, contents):
result = io.BytesIO()
result.write(contents)
result.seek(0)
return result
  def _increment_value_at_index(self, value, index):
    # Return a copy of value with the byte at the given index incremented,
    # corrupting exactly one byte of the serialized record.
    l = list(value)
    if sys.version_info[0] <= 2:
      l[index] = chr(ord(l[index]) + 1)
      return b"".join(l)
    else:
      l[index] = l[index] + 1
      return bytes(l)
def _test_error(self, record, error_text):
with self.assertRaisesRegexp(ValueError, re.escape(error_text)):
_TFRecordUtil.read_record(self._as_file_handle(record))
def test_masked_crc32c(self):
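    # The masking scheme follows the TFRecord format: the raw CRC32C is
    # rotated right by 15 bits and a fixed constant is added modulo 2**32:
    #   masked = (((crc >> 15) | (crc << 17)) + 0xa282ead8) & 0xffffffff
    # For example, crc32c(b'foo') is 0xcfc4ae1d, which masks to the
    # 0xfebe8a61 asserted below.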
self.assertEqual(0xfd7fffa, _TFRecordUtil._masked_crc32c(b'\x00' * 32))
self.assertEqual(0xf909b029, _TFRecordUtil._masked_crc32c(b'\xff' * 32))
self.assertEqual(0xfebe8a61, _TFRecordUtil._masked_crc32c(b'foo'))
self.assertEqual(
0xe4999b0,
_TFRecordUtil._masked_crc32c(b'\x03\x00\x00\x00\x00\x00\x00\x00'))
def test_masked_crc32c_crcmod(self):
crc32c_fn = crcmod.predefined.mkPredefinedCrcFun('crc-32c')
self.assertEqual(
0xfd7fffa,
_TFRecordUtil._masked_crc32c(
b'\x00' * 32, crc32c_fn=crc32c_fn))
self.assertEqual(
0xf909b029,
_TFRecordUtil._masked_crc32c(
b'\xff' * 32, crc32c_fn=crc32c_fn))
self.assertEqual(
0xfebe8a61, _TFRecordUtil._masked_crc32c(
b'foo', crc32c_fn=crc32c_fn))
self.assertEqual(
0xe4999b0,
_TFRecordUtil._masked_crc32c(
b'\x03\x00\x00\x00\x00\x00\x00\x00', crc32c_fn=crc32c_fn))
def test_write_record(self):
file_handle = io.BytesIO()
_TFRecordUtil.write_record(file_handle, b'foo')
self.assertEqual(self.record, file_handle.getvalue())
def test_read_record(self):
actual = _TFRecordUtil.read_record(self._as_file_handle(self.record))
self.assertEqual(b'foo', actual)
def test_read_record_invalid_record(self):
self._test_error(b'bar', 'Not a valid TFRecord. Fewer than 12 bytes')
def test_read_record_invalid_length_mask(self):
record = self._increment_value_at_index(self.record, 9)
self._test_error(record, 'Mismatch of length mask')
def test_read_record_invalid_data_mask(self):
record = self._increment_value_at_index(self.record, 16)
self._test_error(record, 'Mismatch of data mask')
def test_compatibility_read_write(self):
for record in [b'', b'blah', b'another blah']:
file_handle = io.BytesIO()
_TFRecordUtil.write_record(file_handle, record)
file_handle.seek(0)
actual = _TFRecordUtil.read_record(file_handle)
self.assertEqual(record, actual)
class TestTFRecordSink(unittest.TestCase):
  def _write_lines(self, sink, path, lines):
    # Drive the sink's low-level open/write_record/close interface directly,
    # without running a pipeline.
    f = sink.open(path)
    for line in lines:
      sink.write_record(f, line)
    sink.close(f)
def test_write_record_single(self):
with TempDir() as temp_dir:
path = temp_dir.create_temp_file('result')
record = binascii.a2b_base64(FOO_RECORD_BASE64)
sink = _TFRecordSink(
path,
coder=coders.BytesCoder(),
file_name_suffix='',
num_shards=0,
shard_name_template=None,
compression_type=CompressionTypes.UNCOMPRESSED)
self._write_lines(sink, path, [b'foo'])
with open(path, 'rb') as f:
self.assertEqual(f.read(), record)
def test_write_record_multiple(self):
with TempDir() as temp_dir:
path = temp_dir.create_temp_file('result')
record = binascii.a2b_base64(FOO_BAR_RECORD_BASE64)
sink = _TFRecordSink(
path,
coder=coders.BytesCoder(),
file_name_suffix='',
num_shards=0,
shard_name_template=None,
compression_type=CompressionTypes.UNCOMPRESSED)
self._write_lines(sink, path, [b'foo', b'bar'])
with open(path, 'rb') as f:
self.assertEqual(f.read(), record)
@unittest.skipIf(tf is None, 'tensorflow not installed.')
class TestWriteToTFRecord(TestTFRecordSink):
def test_write_record_gzip(self):
with TempDir() as temp_dir:
file_path_prefix = temp_dir.create_temp_file('result')
with TestPipeline() as p:
input_data = [b'foo', b'bar']
_ = p | beam.Create(input_data) | WriteToTFRecord(
file_path_prefix, compression_type=CompressionTypes.GZIP)
actual = []
file_name = glob.glob(file_path_prefix + '-*')[0]
for r in tf.python_io.tf_record_iterator(
file_name, options=tf.python_io.TFRecordOptions(
tf.python_io.TFRecordCompressionType.GZIP)):
actual.append(r)
self.assertEqual(actual, input_data)
def test_write_record_auto(self):
with TempDir() as temp_dir:
file_path_prefix = temp_dir.create_temp_file('result')
with TestPipeline() as p:
input_data = [b'foo', b'bar']
_ = p | beam.Create(input_data) | WriteToTFRecord(
file_path_prefix, file_name_suffix='.gz')
actual = []
file_name = glob.glob(file_path_prefix + '-*.gz')[0]
for r in tf.python_io.tf_record_iterator(
file_name, options=tf.python_io.TFRecordOptions(
tf.python_io.TFRecordCompressionType.GZIP)):
actual.append(r)
self.assertEqual(actual, input_data)
class TestReadFromTFRecord(unittest.TestCase):
def test_process_single(self):
with TempDir() as temp_dir:
path = temp_dir.create_temp_file('result')
_write_file(path, FOO_RECORD_BASE64)
with TestPipeline() as p:
result = (p
| ReadFromTFRecord(
path,
coder=coders.BytesCoder(),
compression_type=CompressionTypes.AUTO,
validate=True))
assert_that(result, equal_to([b'foo']))
def test_process_multiple(self):
with TempDir() as temp_dir:
path = temp_dir.create_temp_file('result')
_write_file(path, FOO_BAR_RECORD_BASE64)
with TestPipeline() as p:
result = (p
| ReadFromTFRecord(
path,
coder=coders.BytesCoder(),
compression_type=CompressionTypes.AUTO,
validate=True))
assert_that(result, equal_to([b'foo', b'bar']))
def test_process_deflate(self):
with TempDir() as temp_dir:
path = temp_dir.create_temp_file('result')
_write_file_deflate(path, FOO_BAR_RECORD_BASE64)
with TestPipeline() as p:
result = (p
| ReadFromTFRecord(
path,
coder=coders.BytesCoder(),
compression_type=CompressionTypes.DEFLATE,
validate=True))
assert_that(result, equal_to([b'foo', b'bar']))
def test_process_gzip(self):
with TempDir() as temp_dir:
path = temp_dir.create_temp_file('result')
_write_file_gzip(path, FOO_BAR_RECORD_BASE64)
with TestPipeline() as p:
result = (p
| ReadFromTFRecord(
path,
coder=coders.BytesCoder(),
compression_type=CompressionTypes.GZIP,
validate=True))
assert_that(result, equal_to([b'foo', b'bar']))
def test_process_auto(self):
with TempDir() as temp_dir:
path = temp_dir.create_temp_file('result.gz')
_write_file_gzip(path, FOO_BAR_RECORD_BASE64)
with TestPipeline() as p:
result = (p
| ReadFromTFRecord(
path,
coder=coders.BytesCoder(),
compression_type=CompressionTypes.AUTO,
validate=True))
assert_that(result, equal_to([b'foo', b'bar']))
  def test_process_gzip_no_coder(self):
with TempDir() as temp_dir:
path = temp_dir.create_temp_file('result')
_write_file_gzip(path, FOO_BAR_RECORD_BASE64)
with TestPipeline() as p:
result = (p
| ReadFromTFRecord(
path, compression_type=CompressionTypes.GZIP))
assert_that(result, equal_to([b'foo', b'bar']))
def test_process_gzip_auto(self):
with TempDir() as temp_dir:
path = temp_dir.create_temp_file('result.gz')
_write_file_gzip(path, FOO_BAR_RECORD_BASE64)
with TestPipeline() as p:
result = (p
| ReadFromTFRecord(
path, compression_type=CompressionTypes.AUTO))
assert_that(result, equal_to([b'foo', b'bar']))
class TestReadAllFromTFRecord(unittest.TestCase):
def _write_glob(self, temp_dir, suffix):
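    # Write three files matching the suffix, each containing the
    # [b'foo', b'bar'] pair, so a single glob over the directory reads the
    # pair back three times.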
for _ in range(3):
path = temp_dir.create_temp_file(suffix)
_write_file(path, FOO_BAR_RECORD_BASE64)
def test_process_single(self):
with TempDir() as temp_dir:
path = temp_dir.create_temp_file('result')
_write_file(path, FOO_RECORD_BASE64)
with TestPipeline() as p:
result = (p
| Create([path])
| ReadAllFromTFRecord(
coder=coders.BytesCoder(),
compression_type=CompressionTypes.AUTO))
assert_that(result, equal_to([b'foo']))
def test_process_multiple(self):
with TempDir() as temp_dir:
path = temp_dir.create_temp_file('result')
_write_file(path, FOO_BAR_RECORD_BASE64)
with TestPipeline() as p:
result = (p
| Create([path])
| ReadAllFromTFRecord(
coder=coders.BytesCoder(),
compression_type=CompressionTypes.AUTO))
assert_that(result, equal_to([b'foo', b'bar']))
def test_process_glob(self):
with TempDir() as temp_dir:
self._write_glob(temp_dir, 'result')
      pattern = temp_dir.get_path() + os.path.sep + '*result'
with TestPipeline() as p:
result = (p
                  | Create([pattern])
| ReadAllFromTFRecord(
coder=coders.BytesCoder(),
compression_type=CompressionTypes.AUTO))
assert_that(result, equal_to([b'foo', b'bar'] * 3))
def test_process_multiple_globs(self):
with TempDir() as temp_dir:
globs = []
for i in range(3):
suffix = 'result' + str(i)
self._write_glob(temp_dir, suffix)
globs.append(temp_dir.get_path() + os.path.sep + '*' + suffix)
with TestPipeline() as p:
result = (p
| Create(globs)
| ReadAllFromTFRecord(
coder=coders.BytesCoder(),
compression_type=CompressionTypes.AUTO))
assert_that(result, equal_to([b'foo', b'bar'] * 9))
def test_process_deflate(self):
with TempDir() as temp_dir:
path = temp_dir.create_temp_file('result')
_write_file_deflate(path, FOO_BAR_RECORD_BASE64)
with TestPipeline() as p:
result = (p
| Create([path])
| ReadAllFromTFRecord(
coder=coders.BytesCoder(),
compression_type=CompressionTypes.DEFLATE))
assert_that(result, equal_to([b'foo', b'bar']))
def test_process_gzip(self):
with TempDir() as temp_dir:
path = temp_dir.create_temp_file('result')
_write_file_gzip(path, FOO_BAR_RECORD_BASE64)
with TestPipeline() as p:
result = (p
| Create([path])
| ReadAllFromTFRecord(
coder=coders.BytesCoder(),
compression_type=CompressionTypes.GZIP))
assert_that(result, equal_to([b'foo', b'bar']))
def test_process_auto(self):
with TempDir() as temp_dir:
path = temp_dir.create_temp_file('result.gz')
_write_file_gzip(path, FOO_BAR_RECORD_BASE64)
with TestPipeline() as p:
result = (p
| Create([path])
| ReadAllFromTFRecord(
coder=coders.BytesCoder(),
compression_type=CompressionTypes.AUTO))
assert_that(result, equal_to([b'foo', b'bar']))
class TestEnd2EndWriteAndRead(unittest.TestCase):
def create_inputs(self):
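    # Each element is an opaque bytes payload: a pickled 12x15 matrix of
    # random floats, exercising records that are longer and less regular
    # than the b'foo'/b'bar' fixtures above.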
input_array = [[random.random() - 0.5 for _ in range(15)]
for _ in range(12)]
memfile = io.BytesIO()
pickle.dump(input_array, memfile)
return memfile.getvalue()
def test_end2end(self):
with TempDir() as temp_dir:
file_path_prefix = temp_dir.create_temp_file('result')
# Generate a TFRecord file.
with TestPipeline() as p:
expected_data = [self.create_inputs() for _ in range(0, 10)]
_ = p | beam.Create(expected_data) | WriteToTFRecord(file_path_prefix)
# Read the file back and compare.
with TestPipeline() as p:
actual_data = p | ReadFromTFRecord(file_path_prefix + '-*')
assert_that(actual_data, equal_to(expected_data))
def test_end2end_auto_compression(self):
with TempDir() as temp_dir:
file_path_prefix = temp_dir.create_temp_file('result')
# Generate a TFRecord file.
with TestPipeline() as p:
expected_data = [self.create_inputs() for _ in range(0, 10)]
_ = p | beam.Create(expected_data) | WriteToTFRecord(
file_path_prefix, file_name_suffix='.gz')
# Read the file back and compare.
with TestPipeline() as p:
actual_data = p | ReadFromTFRecord(file_path_prefix + '-*')
assert_that(actual_data, equal_to(expected_data))
def test_end2end_auto_compression_unsharded(self):
with TempDir() as temp_dir:
file_path_prefix = temp_dir.create_temp_file('result')
# Generate a TFRecord file.
with TestPipeline() as p:
expected_data = [self.create_inputs() for _ in range(0, 10)]
_ = p | beam.Create(expected_data) | WriteToTFRecord(
file_path_prefix + '.gz', shard_name_template='')
# Read the file back and compare.
with TestPipeline() as p:
actual_data = p | ReadFromTFRecord(file_path_prefix + '.gz')
assert_that(actual_data, equal_to(expected_data))
@unittest.skipIf(tf is None, 'tensorflow not installed.')
def test_end2end_example_proto(self):
with TempDir() as temp_dir:
file_path_prefix = temp_dir.create_temp_file('result')
example = tf.train.Example()
example.features.feature['int'].int64_list.value.extend(list(range(3)))
example.features.feature['bytes'].bytes_list.value.extend(
[b'foo', b'bar'])
with TestPipeline() as p:
_ = p | beam.Create([example]) | WriteToTFRecord(
file_path_prefix, coder=beam.coders.ProtoCoder(example.__class__))
# Read the file back and compare.
with TestPipeline() as p:
actual_data = (p | ReadFromTFRecord(
file_path_prefix + '-*',
coder=beam.coders.ProtoCoder(example.__class__)))
assert_that(actual_data, equal_to([example]))
def test_end2end_read_write_read(self):
with TempDir() as temp_dir:
path = temp_dir.create_temp_file('result')
with TestPipeline() as p:
# Initial read to validate the pipeline doesn't fail before the file is
# created.
_ = p | ReadFromTFRecord(path + '-*', validate=False)
expected_data = [self.create_inputs() for _ in range(0, 10)]
_ = p | beam.Create(expected_data) | WriteToTFRecord(
path, file_name_suffix='.gz')
# Read the file back and compare.
with TestPipeline() as p:
        actual_data = p | ReadFromTFRecord(path + '-*', validate=True)
assert_that(actual_data, equal_to(expected_data))
if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()