blob: 21c9af7049cb98109439cbd67092e44bcd5e8bea [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A set of PyTorch specific helper functions"""
import re
import collections
import decimal
import numpy as np
from six import PY2
from torch.utils.data.dataloader import default_collate
if PY2:
_string_classes = basestring # noqa: F821
else:
_string_classes = (str, bytes)
def _sanitize_pytorch_types(row_as_dict):
"""Promotes values types in a dictionary to the types supported by pytorch. Raises an error if type is clear error
if the type can not be promoted.
The parameter is modified in-place.
int8, uint16 are promoted to int32; uint32 -> int64;
numpy string_, unicode_, object arrays are not supported.
:param dict[str,obj] row_as_dict: a dictionary of key-value pairs. The values types are promoted to
pytorch compatible.
:return: None
"""
for name, value in row_as_dict.items():
# PyTorch supported types are: double, float, float16, int64, int32, and uint8
if isinstance(value, np.ndarray):
if value.dtype == np.int8:
row_as_dict[name] = value.astype(np.int16)
elif value.dtype == np.uint16:
row_as_dict[name] = value.astype(np.int32)
elif value.dtype == np.uint32:
row_as_dict[name] = value.astype(np.int64)
elif value.dtype == np.bool_:
row_as_dict[name] = value.astype(np.uint8)
elif re.search('[SaUO]', value.dtype.str):
raise TypeError('Pytorch does not support arrays of string classes. Found in field {}'.format(name))
elif isinstance(value, np.bool_):
row_as_dict[name] = np.uint8(value)
elif value is None:
raise TypeError('Pytorch does not support nullable fields. Found None in {}'.format(name))
def decimal_friendly_collate(batch):
"""A wrapper on top of ``default_collate`` function that allows decimal.Decimal types to be collated.
We use ``decimal.Decimal`` types in dataset to represent timestamps. PyTorch's ``default_collate``
implementation does not support collating ``decimal.Decimal`` types. ``decimal_friendly_collate`` collates
``decimal.Decimal`` separately and then combines with the rest of the fields collated by a standard
``default_collate``.
:param batch: A list of dictionaries to collate
:return: A dictionary of lists/pytorch.Tensor types
"""
if isinstance(batch[0], decimal.Decimal):
return batch
elif isinstance(batch[0], collections.Mapping):
return {key: decimal_friendly_collate([d[key] for d in batch]) for key in batch[0]}
elif isinstance(batch[0], _string_classes):
return batch
elif isinstance(batch[0], collections.Sequence):
transposed = zip(*batch)
return [decimal_friendly_collate(samples) for samples in transposed]
else:
return default_collate(batch)
class DataLoader(object):
"""
A data loader adaptor for ``torch.utils.data.DataLoader``.
This class iterates and returns items from the Reader in batches.
This loader can be used as an iterator and will terminate when the reader used in the construction of the class
runs out of samples.
"""
def __init__(self, reader, batch_size=1, collate_fn=decimal_friendly_collate):
"""
Initializes a data loader object, with a default collate.
Number of epochs is defined by the configuration of the reader argument.
:param reader: Reader instance
:param batch_size: the number of items to return per batch; factored into the len() of this reader
:param collate_fn: an optional callable to merge a list of samples to form a mini-batch.
"""
self.reader = reader
self.batch_size = batch_size
self.collate_fn = collate_fn
def __iter__(self):
"""
The Data Loader iterator stops the for-loop when reader runs out of samples.
"""
batch = []
for row in self.reader:
# Default collate does not work nicely on namedtuples and treat them as lists
# Using dict will result in the yielded structures being dicts as well
row_as_dict = row._asdict()
_sanitize_pytorch_types(row_as_dict)
batch.append(row_as_dict)
if len(batch) == self.batch_size:
yield self.collate_fn(batch)
batch = []
if batch:
yield self.collate_fn(batch)
# Functions needed to treat data loader as a context manager
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.reader.stop()
self.reader.join()