simple changes to mxnet.image.ImageIter to make it more newbie friendly (#5278)
* simple changes to mxnet.image.ImageIter to make it more newbie friendly
* lint update
diff --git a/python/mxnet/image.py b/python/mxnet/image.py
index 9874be6..4b70c38 100644
--- a/python/mxnet/image.py
+++ b/python/mxnet/image.py
@@ -8,13 +8,13 @@
import random
import logging
import numpy as np
-from .base import numeric_types
try:
import cv2
except ImportError:
cv2 = None
+from .base import numeric_types
from . import ndarray as nd
from . import _ndarray_internal as _internal
from ._ndarray_internal import _cvimresize as imresize
@@ -313,18 +313,24 @@
Partition index
num_parts : int
Total number of partitions.
+ data_name : str
+ data name for provided symbols
+ label_name : str
+ label name for provided symbols
kwargs : ...
More arguments for creating augumenter. See mx.image.CreateAugmenter.
"""
+
def __init__(self, batch_size, data_shape, label_width=1,
path_imgrec=None, path_imglist=None, path_root=None, path_imgidx=None,
- shuffle=False, part_index=0, num_parts=1, aug_list=None, imglist=None, **kwargs):
+ shuffle=False, part_index=0, num_parts=1, aug_list=None, imglist=None,
+ data_name='data', label_name='softmax_label', **kwargs):
super(ImageIter, self).__init__()
- assert(path_imgrec or path_imglist or (isinstance(imglist, list)))
+ assert path_imgrec or path_imglist or (isinstance(imglist, list))
if path_imgrec:
print('loading recordio...')
if path_imgidx:
- self.imgrec = recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
+ self.imgrec = recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r') # pylint: disable=redefined-variable-type
self.imgidx = list(self.imgrec.keys)
else:
self.imgrec = recordio.MXRecordIO(path_imgrec, 'r') # pylint: disable=redefined-variable-type
@@ -363,12 +369,12 @@
self.imglist = None
self.path_root = path_root
- assert len(data_shape) == 3 and data_shape[0] == 3
- self.provide_data = [('data', (batch_size,) + data_shape)]
+ self.check_data_shape(data_shape)
+ self.provide_data = [(data_name, (batch_size,) + data_shape)]
if label_width > 1:
- self.provide_label = [('softmax_label', (batch_size, label_width))]
+ self.provide_label = [(label_name, (batch_size, label_width))]
else:
- self.provide_label = [('softmax_label', (batch_size,))]
+ self.provide_label = [(label_name, (batch_size,))]
self.batch_size = batch_size
self.data_shape = data_shape
self.label_width = label_width
@@ -417,10 +423,7 @@
return self.imglist[idx][0], img
else:
label, fname = self.imglist[idx]
- if self.imgrec is None:
- with open(os.path.join(self.path_root, fname), 'rb') as fin:
- img = fin.read()
- return label, img
+ return label, self.read_image(fname)
else:
s = self.imgrec.read()
if s is None:
@@ -437,15 +440,16 @@
try:
while i < batch_size:
label, s = self.next_sample()
- data = [imdecode(s)]
- if len(data[0].shape) == 0:
- logging.debug('Invalid image, skipping.')
+ data = [self.imdecode(s)]
+ try:
+ self.check_valid_image(data)
+ except RuntimeError as e:
+ logging.debug('Invalid image, skipping: %s', str(e))
continue
- for aug in self.auglist:
- data = [ret for src in data for ret in aug(src)]
- for d in data:
+ data = self.augmentation_transform(data)
+ for datum in data:
assert i < batch_size, 'Batch size must be multiples of augmenter output length'
- batch_data[i][:] = nd.transpose(d, axes=(2, 0, 1))
+ batch_data[i][:] = self.postprocess_data(datum)
batch_label[i][:] = label
i += 1
except StopIteration:
@@ -453,3 +457,35 @@
raise StopIteration
return io.DataBatch([batch_data], [batch_label], batch_size-i)
+
+ def check_data_shape(self, data_shape):
+ """Checks that data_shape is valid: raises ValueError unless it has length 3 (CxHxW) and data_shape[0] == 3."""
+ if not len(data_shape) == 3:
+ raise ValueError('data_shape should have length 3, with dimensions CxHxW')
+ if not data_shape[0] == 3:  # leading entry is the channel count (C in CxHxW)
+ raise ValueError('This iterator expects inputs to have 3 channels.')
+
+ def check_valid_image(self, data):
+ """Checks that data is valid: raises RuntimeError when data[0] has an empty shape (only the first element is inspected)."""
+ if len(data[0].shape) == 0:  # the batch loop in next() catches this RuntimeError and skips the sample
+ raise RuntimeError('Data shape is wrong')
+
+ def imdecode(self, s):
+ """Decodes a string or byte string into an image by delegating to the module-level imdecode function."""
+ return imdecode(s)
+
+ def read_image(self, fname):
+ """Reads the file at os.path.join(self.path_root, fname) in binary mode and returns its raw bytes, still to be decoded."""
+ with open(os.path.join(self.path_root, fname), 'rb') as fin:
+ img = fin.read()
+ return img
+
+ def augmentation_transform(self, data):
+ """Transforms data with the specified augmentation: applies each augmenter in self.auglist in order and returns the resulting list."""
+ for aug in self.auglist:
+ data = [ret for src in data for ret in aug(src)]  # aug(src) is iterated; all outputs are flattened into one list
+ return data
+
+ def postprocess_data(self, datum):
+ """Final postprocessing step before the image is loaded into the batch: transposes axes to (2, 0, 1), presumably HWC -> CHW (TODO confirm)."""
+ return nd.transpose(datum, axes=(2, 0, 1))