blob: ab7d2f5444a3ef79de572cdf14569c6b7723fd6d [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""TFRecord sources and sinks."""
from __future__ import absolute_import
import codecs
import logging
import struct
from builtins import object
from functools import partial
import crcmod
from apache_beam import coders
from apache_beam.io import filebasedsink
from apache_beam.io.filebasedsource import FileBasedSource
from apache_beam.io.filebasedsource import ReadAllFiles
from apache_beam.io.filesystem import CompressionTypes
from apache_beam.io.iobase import Read
from apache_beam.io.iobase import Write
from apache_beam.transforms import PTransform
__all__ = ['ReadFromTFRecord', 'WriteToTFRecord']
_LOGGER = logging.getLogger(__name__)
def _default_crc32c_fn(value):
"""Calculates crc32c of a bytes object using either snappy or crcmod."""
if not _default_crc32c_fn.fn:
try:
import snappy # pylint: disable=import-error
# Support multiple versions of python-snappy:
# https://github.com/andrix/python-snappy/pull/53
if getattr(snappy, '_crc32c', None):
_default_crc32c_fn.fn = snappy._crc32c # pylint: disable=protected-access
elif getattr(snappy, '_snappy', None):
_default_crc32c_fn.fn = snappy._snappy._crc32c # pylint: disable=protected-access
except ImportError:
pass
if not _default_crc32c_fn.fn:
_LOGGER.warning('Couldn\'t find python-snappy so the implementation of '
'_TFRecordUtil._masked_crc32c is not as fast as it could '
'be.')
_default_crc32c_fn.fn = crcmod.predefined.mkPredefinedCrcFun('crc-32c')
return _default_crc32c_fn.fn(value)
_default_crc32c_fn.fn = None
class _TFRecordUtil(object):
"""Provides basic TFRecord encoding/decoding with consistency checks.
For detailed TFRecord format description see:
https://www.tensorflow.org/versions/r1.11/api_guides/python/python_io#TFRecords_Format_Details
Note that masks and length are represented in LittleEndian order.
"""
@classmethod
def _masked_crc32c(cls, value, crc32c_fn=_default_crc32c_fn):
"""Compute a masked crc32c checksum for a value.
Args:
value: A bytes object for which we compute the crc.
crc32c_fn: A function that can compute a crc32c.
This is a performance hook that also helps with testing. Callers are
not expected to make use of it directly.
Returns:
Masked crc32c checksum.
"""
crc = crc32c_fn(value)
return (((crc >> 15) | (crc << 17)) + 0xa282ead8) & 0xffffffff
@staticmethod
def encoded_num_bytes(record):
"""Return the number of bytes consumed by a record in its encoded form."""
# 16 = 8 (Length) + 4 (crc of length) + 4 (crc of data)
return len(record) + 16
@classmethod
def write_record(cls, file_handle, value):
"""Encode a value as a TFRecord.
Args:
file_handle: The file to write to.
value: A bytes object representing content of the record.
"""
encoded_length = struct.pack(b'<Q', len(value))
file_handle.write(b''.join([
encoded_length,
struct.pack(b'<I', cls._masked_crc32c(encoded_length)),
value,
struct.pack(b'<I', cls._masked_crc32c(value))
]))
@classmethod
def read_record(cls, file_handle):
"""Read a record from a TFRecords file.
Args:
file_handle: The file to read from.
Returns:
None if EOF is reached; the paylod of the record otherwise.
Raises:
ValueError: If file appears to not be a valid TFRecords file.
"""
buf_length_expected = 12
buf = file_handle.read(buf_length_expected)
if not buf:
return None # EOF Reached.
# Validate all length related payloads.
if len(buf) != buf_length_expected:
raise ValueError('Not a valid TFRecord. Fewer than %d bytes: %s' %
(buf_length_expected, codecs.encode(buf, 'hex')))
length, length_mask_expected = struct.unpack('<QI', buf)
length_mask_actual = cls._masked_crc32c(buf[:8])
if length_mask_actual != length_mask_expected:
raise ValueError('Not a valid TFRecord. Mismatch of length mask: %s' %
codecs.encode(buf, 'hex'))
# Validate all data related payloads.
buf_length_expected = length + 4
buf = file_handle.read(buf_length_expected)
if len(buf) != buf_length_expected:
raise ValueError('Not a valid TFRecord. Fewer than %d bytes: %s' %
(buf_length_expected, codecs.encode(buf, 'hex')))
data, data_mask_expected = struct.unpack('<%dsI' % length, buf)
data_mask_actual = cls._masked_crc32c(data)
if data_mask_actual != data_mask_expected:
raise ValueError('Not a valid TFRecord. Mismatch of data mask: %s' %
codecs.encode(buf, 'hex'))
# All validation checks passed.
return data
class _TFRecordSource(FileBasedSource):
"""A File source for reading files of TFRecords.
For detailed TFRecords format description see:
https://www.tensorflow.org/versions/r1.11/api_guides/python/python_io#TFRecords_Format_Details
"""
def __init__(self,
file_pattern,
coder,
compression_type,
validate):
"""Initialize a TFRecordSource. See ReadFromTFRecord for details."""
super(_TFRecordSource, self).__init__(
file_pattern=file_pattern,
compression_type=compression_type,
splittable=False,
validate=validate)
self._coder = coder
def read_records(self, file_name, offset_range_tracker):
if offset_range_tracker.start_position():
raise ValueError('Start position not 0:%s' %
offset_range_tracker.start_position())
current_offset = offset_range_tracker.start_position()
with self.open_file(file_name) as file_handle:
while True:
if not offset_range_tracker.try_claim(current_offset):
raise RuntimeError('Unable to claim position: %s' % current_offset)
record = _TFRecordUtil.read_record(file_handle)
if record is None:
return # Reached EOF
else:
current_offset += _TFRecordUtil.encoded_num_bytes(record)
yield self._coder.decode(record)
def _create_tfrecordio_source(
file_pattern=None, coder=None, compression_type=None):
# We intentionally disable validation for ReadAll pattern so that reading does
# not fail for globs (elements) that are empty.
return _TFRecordSource(file_pattern, coder, compression_type,
validate=False)
class ReadAllFromTFRecord(PTransform):
"""A ``PTransform`` for reading a ``PCollection`` of TFRecord files."""
def __init__(
self,
coder=coders.BytesCoder(),
compression_type=CompressionTypes.AUTO):
"""Initialize the ``ReadAllFromTFRecord`` transform.
Args:
coder: Coder used to decode each record.
compression_type: Used to handle compressed input files. Default value
is CompressionTypes.AUTO, in which case the file_path's extension will
be used to detect the compression.
"""
super(ReadAllFromTFRecord, self).__init__()
source_from_file = partial(
_create_tfrecordio_source, compression_type=compression_type,
coder=coder)
# Desired and min bundle sizes do not matter since TFRecord files are
# unsplittable.
self._read_all_files = ReadAllFiles(
splittable=False, compression_type=compression_type,
desired_bundle_size=0, min_bundle_size=0,
source_from_file=source_from_file)
def expand(self, pvalue):
return pvalue | 'ReadAllFiles' >> self._read_all_files
class ReadFromTFRecord(PTransform):
"""Transform for reading TFRecord sources."""
def __init__(self,
file_pattern,
coder=coders.BytesCoder(),
compression_type=CompressionTypes.AUTO,
validate=True):
"""Initialize a ReadFromTFRecord transform.
Args:
file_pattern: A file glob pattern to read TFRecords from.
coder: Coder used to decode each record.
compression_type: Used to handle compressed input files. Default value
is CompressionTypes.AUTO, in which case the file_path's extension will
be used to detect the compression.
validate: Boolean flag to verify that the files exist during the pipeline
creation time.
Returns:
A ReadFromTFRecord transform object.
"""
super(ReadFromTFRecord, self).__init__()
self._source = _TFRecordSource(file_pattern, coder, compression_type,
validate)
def expand(self, pvalue):
return pvalue.pipeline | Read(self._source)
class _TFRecordSink(filebasedsink.FileBasedSink):
"""Sink for writing TFRecords files.
For detailed TFRecord format description see:
https://www.tensorflow.org/versions/r1.11/api_guides/python/python_io#TFRecords_Format_Details
"""
def __init__(self, file_path_prefix, coder, file_name_suffix, num_shards,
shard_name_template, compression_type):
"""Initialize a TFRecordSink. See WriteToTFRecord for details."""
super(_TFRecordSink, self).__init__(
file_path_prefix=file_path_prefix,
coder=coder,
file_name_suffix=file_name_suffix,
num_shards=num_shards,
shard_name_template=shard_name_template,
mime_type='application/octet-stream',
compression_type=compression_type)
def write_encoded_record(self, file_handle, value):
_TFRecordUtil.write_record(file_handle, value)
class WriteToTFRecord(PTransform):
"""Transform for writing to TFRecord sinks."""
def __init__(self,
file_path_prefix,
coder=coders.BytesCoder(),
file_name_suffix='',
num_shards=0,
shard_name_template=None,
compression_type=CompressionTypes.AUTO):
"""Initialize WriteToTFRecord transform.
Args:
file_path_prefix: The file path to write to. The files written will begin
with this prefix, followed by a shard identifier (see num_shards), and
end in a common extension, if given by file_name_suffix.
coder: Coder used to encode each record.
file_name_suffix: Suffix for the files written.
num_shards: The number of files (shards) used for output. If not set, the
default value will be used.
shard_name_template: A template string containing placeholders for
the shard number and shard count. When constructing a filename for a
particular shard number, the upper-case letters 'S' and 'N' are
replaced with the 0-padded shard number and shard count respectively.
This argument can be '' in which case it behaves as if num_shards was
set to 1 and only one file will be generated. The default pattern used
is '-SSSSS-of-NNNNN' if None is passed as the shard_name_template.
compression_type: Used to handle compressed output files. Typical value
is CompressionTypes.AUTO, in which case the file_path's extension will
be used to detect the compression.
Returns:
A WriteToTFRecord transform object.
"""
super(WriteToTFRecord, self).__init__()
self._sink = _TFRecordSink(file_path_prefix, coder, file_name_suffix,
num_shards, shard_name_template,
compression_type)
def expand(self, pcoll):
return pcoll | Write(self._sink)