simple changes to mxnet.image.ImageIter to make it more newbie friendly (#5278)

* simple changes to mxnet.image.ImageIter to make more newbie friendly

* lint update
diff --git a/python/mxnet/image.py b/python/mxnet/image.py
index 9874be6..4b70c38 100644
--- a/python/mxnet/image.py
+++ b/python/mxnet/image.py
@@ -8,13 +8,13 @@
 import random
 import logging
 import numpy as np
-from .base import numeric_types
 
 try:
     import cv2
 except ImportError:
     cv2 = None
 
+from .base import numeric_types
 from . import ndarray as nd
 from . import _ndarray_internal as _internal
 from ._ndarray_internal import _cvimresize as imresize
@@ -313,18 +313,24 @@
         Partition index
     num_parts : int
         Total number of partitions.
+    data_name : str
+        data name for provided symbols
+    label_name : str
+        label name for provided symbols
     kwargs : ...
         More arguments for creating augumenter. See mx.image.CreateAugmenter.
     """
+
     def __init__(self, batch_size, data_shape, label_width=1,
                  path_imgrec=None, path_imglist=None, path_root=None, path_imgidx=None,
-                 shuffle=False, part_index=0, num_parts=1, aug_list=None, imglist=None, **kwargs):
+                 shuffle=False, part_index=0, num_parts=1, aug_list=None, imglist=None,
+                 data_name='data', label_name='softmax_label', **kwargs):
         super(ImageIter, self).__init__()
-        assert(path_imgrec or path_imglist or (isinstance(imglist, list)))
+        assert path_imgrec or path_imglist or (isinstance(imglist, list))
         if path_imgrec:
             print('loading recordio...')
             if path_imgidx:
-                self.imgrec = recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
+                self.imgrec = recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r') # pylint: disable=redefined-variable-type
                 self.imgidx = list(self.imgrec.keys)
             else:
                 self.imgrec = recordio.MXRecordIO(path_imgrec, 'r') # pylint: disable=redefined-variable-type
@@ -363,12 +369,12 @@
             self.imglist = None
         self.path_root = path_root
 
-        assert len(data_shape) == 3 and data_shape[0] == 3
-        self.provide_data = [('data', (batch_size,) + data_shape)]
+        self.check_data_shape(data_shape)
+        self.provide_data = [(data_name, (batch_size,) + data_shape)]
         if label_width > 1:
-            self.provide_label = [('softmax_label', (batch_size, label_width))]
+            self.provide_label = [(label_name, (batch_size, label_width))]
         else:
-            self.provide_label = [('softmax_label', (batch_size,))]
+            self.provide_label = [(label_name, (batch_size,))]
         self.batch_size = batch_size
         self.data_shape = data_shape
         self.label_width = label_width
@@ -417,10 +423,7 @@
                     return self.imglist[idx][0], img
             else:
                 label, fname = self.imglist[idx]
-                if self.imgrec is None:
-                    with open(os.path.join(self.path_root, fname), 'rb') as fin:
-                        img = fin.read()
-                return label, img
+                return label, self.read_image(fname)
         else:
             s = self.imgrec.read()
             if s is None:
@@ -437,15 +440,16 @@
         try:
             while i < batch_size:
                 label, s = self.next_sample()
-                data = [imdecode(s)]
-                if len(data[0].shape) == 0:
-                    logging.debug('Invalid image, skipping.')
+                data = [self.imdecode(s)]
+                try:
+                    self.check_valid_image(data)
+                except RuntimeError as e:
+                    logging.debug('Invalid image, skipping:  %s', str(e))
                     continue
-                for aug in self.auglist:
-                    data = [ret for src in data for ret in aug(src)]
-                for d in data:
+                data = self.augmentation_transform(data)
+                for datum in data:
                     assert i < batch_size, 'Batch size must be multiples of augmenter output length'
-                    batch_data[i][:] = nd.transpose(d, axes=(2, 0, 1))
+                    batch_data[i][:] = self.postprocess_data(datum)
                     batch_label[i][:] = label
                     i += 1
         except StopIteration:
@@ -453,3 +457,35 @@
                 raise StopIteration
 
         return io.DataBatch([batch_data], [batch_label], batch_size-i)
+
+    def check_data_shape(self, data_shape):
+        """checks that the input data shape is valid"""
+        if not len(data_shape) == 3:
+            raise ValueError('data_shape should have length 3, with dimensions CxHxW')
+        if not data_shape[0] == 3:
+            raise ValueError('This iterator expects inputs to have 3 channels.')
+
+    def check_valid_image(self, data):
+        """checks that data is valid"""
+        if len(data[0].shape) == 0:
+            raise RuntimeError('Data shape is wrong')
+
+    def imdecode(self, s):
+        """decodes a sting or byte string into an image."""
+        return imdecode(s)
+
+    def read_image(self, fname):
+        """reads image from fname and returns the raw bytes to be decoded."""
+        with open(os.path.join(self.path_root, fname), 'rb') as fin:
+            img = fin.read()
+        return img
+
+    def augmentation_transform(self, data):
+        """transforms data with specificied augmentation."""
+        for aug in self.auglist:
+            data = [ret for src in data for ret in aug(src)]
+        return data
+
+    def postprocess_data(self, datum):
+        """final postprocessing step before image is loaded into the batch."""
+        return nd.transpose(datum, axes=(2, 0, 1))