| # pylint: disable=no-member, too-many-lines, redefined-builtin, protected-access, unused-import, invalid-name |
| # pylint: disable=too-many-arguments, too-many-locals, no-name-in-module, too-many-branches, too-many-statements |
| """Read invidual image files and perform augmentations.""" |
| |
| from __future__ import absolute_import, print_function |
| |
| import os |
| import random |
| import logging |
| import numpy as np |
| |
| try: |
| import cv2 |
| except ImportError: |
| cv2 = None |
| |
| from .base import numeric_types |
| from . import ndarray as nd |
| from . import _ndarray_internal as _internal |
| from ._ndarray_internal import _cvimresize as imresize |
| from ._ndarray_internal import _cvcopyMakeBorder as copyMakeBorder |
| from . import io |
| from . import recordio |
| |
| |
| def imdecode(buf, **kwargs): |
| """Decode an image from string. Requires OpenCV to work. |
| |
| Parameters |
| ---------- |
| buf : str/bytes, or numpy.ndarray |
| Binary image data. |
| flag : int |
| 0 for grayscale. 1 for colored. |
| to_rgb : int |
| 0 for BGR format (OpenCV default). 1 for RGB format (MXNet default). |
| out : NDArray |
| Output buffer. Use None for automatic allocation. |
| """ |
| if not isinstance(buf, nd.NDArray): |
| buf = nd.array(np.frombuffer(buf, dtype=np.uint8), dtype=np.uint8) |
| return _internal._cvimdecode(buf, **kwargs) |
| |
| def scale_down(src_size, size): |
| """Scale down crop size if it's bigger than image size.""" |
| w, h = size |
| sw, sh = src_size |
| if sh < h: |
| w, h = float(w*sh)/h, sh |
| if sw < w: |
| w, h = sw, float(h*sw)/w |
| return int(w), int(h) |
| |
| def resize_short(src, size, interp=2): |
| """Resize shorter edge to size.""" |
| h, w, _ = src.shape |
| if h > w: |
| new_h, new_w = size*h/w, size |
| else: |
| new_h, new_w = size, size*w/h |
| return imresize(src, new_w, new_h, interp=interp) |
| |
| def fixed_crop(src, x0, y0, w, h, size=None, interp=2): |
| """Crop src at fixed location, and (optionally) resize it to size.""" |
| out = nd.crop(src, begin=(y0, x0, 0), end=(y0+h, x0+w, int(src.shape[2]))) |
| if size is not None and (w, h) != size: |
| out = imresize(out, *size, interp=interp) |
| return out |
| |
| def random_crop(src, size, interp=2): |
| """Randomly crop src with size. Upsample result if src is smaller than size.""" |
| h, w, _ = src.shape |
| new_w, new_h = scale_down((w, h), size) |
| |
| x0 = random.randint(0, w - new_w) |
| y0 = random.randint(0, h - new_h) |
| |
| out = fixed_crop(src, x0, y0, new_w, new_h, size, interp) |
| return out, (x0, y0, new_w, new_h) |
| |
| def center_crop(src, size, interp=2): |
| """Randomly crop src with size. Upsample result if src is smaller than size.""" |
| h, w, _ = src.shape |
| new_w, new_h = scale_down((w, h), size) |
| |
| x0 = (w - new_w)/2 |
| y0 = (h - new_h)/2 |
| |
| out = fixed_crop(src, x0, y0, new_w, new_h, size, interp) |
| return out, (x0, y0, new_w, new_h) |
| |
| def color_normalize(src, mean, std=None): |
| """Normalize src with mean and std.""" |
| src -= mean |
| if std is not None: |
| src /= std |
| return src |
| |
| def random_size_crop(src, size, min_area, ratio, interp=2): |
| """Randomly crop src with size. Randomize area and aspect ratio.""" |
| h, w, _ = src.shape |
| new_ratio = random.uniform(*ratio) |
| if new_ratio * h > w: |
| max_area = w*int(w/new_ratio) |
| else: |
| max_area = h*int(h*new_ratio) |
| |
| min_area *= h*w |
| if max_area < min_area: |
| return random_crop(src, size, interp) |
| new_area = random.uniform(min_area, max_area) |
| new_w = int(np.sqrt(new_area*new_ratio)) |
| new_h = int(np.sqrt(new_area/new_ratio)) |
| |
| assert new_w <= w and new_h <= h |
| x0 = random.randint(0, w - new_w) |
| y0 = random.randint(0, h - new_h) |
| |
| out = fixed_crop(src, x0, y0, new_w, new_h, size, interp) |
| return out, (x0, y0, new_w, new_h) |
| |
| def ResizeAug(size, interp=2): |
| """Make resize shorter edge to size augumenter.""" |
| def aug(src): |
| """Augumenter body""" |
| return [resize_short(src, size, interp)] |
| return aug |
| |
| def RandomCropAug(size, interp=2): |
| """Make random crop augumenter""" |
| def aug(src): |
| """Augumenter body""" |
| return [random_crop(src, size, interp)[0]] |
| return aug |
| |
| def RandomSizedCropAug(size, min_area, ratio, interp=2): |
| """Make random crop with random resizing and random aspect ratio jitter augumenter.""" |
| def aug(src): |
| """Augumenter body""" |
| return [random_size_crop(src, size, min_area, ratio, interp)[0]] |
| return aug |
| |
| def CenterCropAug(size, interp=2): |
| """Make center crop augmenter.""" |
| def aug(src): |
| """Augumenter body""" |
| return [center_crop(src, size, interp)[0]] |
| return aug |
| |
| def RandomOrderAug(ts): |
| """Apply list of augmenters in random order""" |
| def aug(src): |
| """Augumenter body""" |
| src = [src] |
| random.shuffle(ts) |
| for t in ts: |
| src = [j for i in src for j in t(i)] |
| return src |
| return aug |
| |
| def ColorJitterAug(brightness, contrast, saturation): |
| """Apply random brightness, contrast and saturation jitter in random order.""" |
| ts = [] |
| coef = nd.array([[[0.299, 0.587, 0.114]]]) |
| if brightness > 0: |
| def baug(src): |
| """Augumenter body""" |
| alpha = 1.0 + random.uniform(-brightness, brightness) |
| src *= alpha |
| return [src] |
| ts.append(baug) |
| |
| if contrast > 0: |
| def caug(src): |
| """Augumenter body""" |
| alpha = 1.0 + random.uniform(-contrast, contrast) |
| gray = src*coef |
| gray = (3.0*(1.0-alpha)/gray.size)*nd.sum(gray) |
| src *= alpha |
| src += gray |
| return [src] |
| ts.append(caug) |
| |
| if saturation > 0: |
| def saug(src): |
| """Augumenter body""" |
| alpha = 1.0 + random.uniform(-saturation, saturation) |
| gray = src*coef |
| gray = nd.sum(gray, axis=2, keepdims=True) |
| gray *= (1.0-alpha) |
| src *= alpha |
| src += gray |
| return [src] |
| ts.append(saug) |
| return RandomOrderAug(ts) |
| |
| def LightingAug(alphastd, eigval, eigvec): |
| """Add PCA based noise.""" |
| def aug(src): |
| """Augumenter body""" |
| alpha = np.random.normal(0, alphastd, size=(3,)) |
| rgb = np.dot(eigvec*alpha, eigval) |
| src += nd.array(rgb) |
| return [src] |
| return aug |
| |
| def ColorNormalizeAug(mean, std): |
| """Mean and std normalization.""" |
| mean = nd.array(mean) |
| std = nd.array(std) |
| def aug(src): |
| """Augumenter body""" |
| return [color_normalize(src, mean, std)] |
| return aug |
| |
| def HorizontalFlipAug(p): |
| """Random horizontal flipping.""" |
| def aug(src): |
| """Augumenter body""" |
| if random.random() < p: |
| src = nd.flip(src, axis=1) |
| return [src] |
| return aug |
| |
| def CastAug(): |
| """Cast to float32""" |
| def aug(src): |
| """Augumenter body""" |
| src = src.astype(np.float32) |
| return [src] |
| return aug |
| |
| def CreateAugmenter(data_shape, resize=0, rand_crop=False, rand_resize=False, rand_mirror=False, |
| mean=None, std=None, brightness=0, contrast=0, saturation=0, |
| pca_noise=0, inter_method=2): |
| """Create augumenter list.""" |
| auglist = [] |
| |
| if resize > 0: |
| auglist.append(ResizeAug(resize, inter_method)) |
| |
| crop_size = (data_shape[2], data_shape[1]) |
| if rand_resize: |
| assert rand_crop |
| auglist.append(RandomSizedCropAug(crop_size, 0.3, (3.0/4.0, 4.0/3.0), inter_method)) |
| elif rand_crop: |
| auglist.append(RandomCropAug(crop_size, inter_method)) |
| else: |
| auglist.append(CenterCropAug(crop_size, inter_method)) |
| |
| if rand_mirror: |
| auglist.append(HorizontalFlipAug(0.5)) |
| |
| auglist.append(CastAug()) |
| |
| if brightness or contrast or saturation: |
| auglist.append(ColorJitterAug(brightness, contrast, saturation)) |
| |
| if pca_noise > 0: |
| eigval = np.array([55.46, 4.794, 1.148]) |
| eigvec = np.array([[-0.5675, 0.7192, 0.4009], |
| [-0.5808, -0.0045, -0.8140], |
| [-0.5836, -0.6948, 0.4203]]) |
| auglist.append(LightingAug(pca_noise, eigval, eigvec)) |
| |
| if mean is True: |
| mean = np.array([123.68, 116.28, 103.53]) |
| if std is True: |
| std = np.array([58.395, 57.12, 57.375]) |
| if mean is not None: |
| assert std is not None |
| auglist.append(ColorNormalizeAug(mean, std)) |
| |
| return auglist |
| |
| |
| class ImageIter(io.DataIter): |
| """Image data iterator with a large number of augumentation choices. |
| Supports reading from both .rec files and raw image files with image list. |
| |
| To load from .rec files, please specify path_imgrec. Also specify path_imgidx |
| to use data partition (for distributed training) or shuffling. |
| |
| To load from raw image files, specify path_imglist and path_root. |
| |
| Parameters |
| ---------- |
| batch_size : int |
| Number of examples per batch. |
| data_shape : tuple |
| Data shape in (channels, height, width). |
| For now, only RGB image with 3 channels is supported. |
| label_width : int |
| dimension of label |
| path_imgrec : str |
| path to image record file (.rec). |
| Created with tools/im2rec.py or bin/im2rec |
| path_imglist : str |
| path to image list (.lst) |
| Created with tools/im2rec.py or with custom script. |
| Format: index\t[one or more label separated by \t]\trelative_path_from_root. |
| imglist: list |
| a list of image with the label(s) |
| each item is a list [imagelabel: float or list of float, imgpath]. |
| path_root : str |
| Root folder of image files |
| path_imgidx : str |
| Path to image index file. Needed for partition and shuffling when using .rec source. |
| shuffle : bool |
| Whether to shuffle all images at the start of each iteration. |
| Can be slow for HDD. |
| part_index : int |
| Partition index |
| num_parts : int |
| Total number of partitions. |
| data_name : str |
| data name for provided symbols |
| label_name : str |
| label name for provided symbols |
| kwargs : ... |
| More arguments for creating augumenter. See mx.image.CreateAugmenter. |
| """ |
| |
| def __init__(self, batch_size, data_shape, label_width=1, |
| path_imgrec=None, path_imglist=None, path_root=None, path_imgidx=None, |
| shuffle=False, part_index=0, num_parts=1, aug_list=None, imglist=None, |
| data_name='data', label_name='softmax_label', **kwargs): |
| super(ImageIter, self).__init__() |
| assert path_imgrec or path_imglist or (isinstance(imglist, list)) |
| if path_imgrec: |
| print('loading recordio...') |
| if path_imgidx: |
| self.imgrec = recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r') # pylint: disable=redefined-variable-type |
| self.imgidx = list(self.imgrec.keys) |
| else: |
| self.imgrec = recordio.MXRecordIO(path_imgrec, 'r') # pylint: disable=redefined-variable-type |
| self.imgidx = None |
| else: |
| self.imgrec = None |
| |
| if path_imglist: |
| print('loading image list...') |
| with open(path_imglist) as fin: |
| imglist = {} |
| imgkeys = [] |
| for line in iter(fin.readline, ''): |
| line = line.strip().split('\t') |
| label = nd.array([float(i) for i in line[1:-1]]) |
| key = int(line[0]) |
| imglist[key] = (label, line[-1]) |
| imgkeys.append(key) |
| self.imglist = imglist |
| elif isinstance(imglist, list): |
| print('loading image list...') |
| result = {} |
| imgkeys = [] |
| index = 1 |
| for img in imglist: |
| key = str(index) # pylint: disable=redefined-variable-type |
| index += 1 |
| if isinstance(img[0], numeric_types): |
| label = nd.array([img[0]]) |
| else: |
| label = nd.array(img[0]) |
| result[key] = (label, img[1]) |
| imgkeys.append(str(key)) |
| self.imglist = result |
| else: |
| self.imglist = None |
| self.path_root = path_root |
| |
| self.check_data_shape(data_shape) |
| self.provide_data = [(data_name, (batch_size,) + data_shape)] |
| if label_width > 1: |
| self.provide_label = [(label_name, (batch_size, label_width))] |
| else: |
| self.provide_label = [(label_name, (batch_size,))] |
| self.batch_size = batch_size |
| self.data_shape = data_shape |
| self.label_width = label_width |
| |
| self.shuffle = shuffle |
| if self.imgrec is None: |
| self.seq = imgkeys |
| elif shuffle or num_parts > 1: |
| assert self.imgidx is not None |
| self.seq = self.imgidx |
| else: |
| self.seq = None |
| |
| if num_parts > 1: |
| assert part_index < num_parts |
| N = len(self.seq) |
| C = N/num_parts |
| self.seq = self.seq[part_index*C:(part_index+1)*C] |
| if aug_list is None: |
| self.auglist = CreateAugmenter(data_shape, **kwargs) |
| else: |
| self.auglist = aug_list |
| self.cur = 0 |
| self.reset() |
| |
| def reset(self): |
| if self.shuffle: |
| random.shuffle(self.seq) |
| if self.imgrec is not None: |
| self.imgrec.reset() |
| self.cur = 0 |
| |
| def next_sample(self): |
| """Helper function for reading in next sample.""" |
| if self.seq is not None: |
| if self.cur >= len(self.seq): |
| raise StopIteration |
| idx = self.seq[self.cur] |
| self.cur += 1 |
| if self.imgrec is not None: |
| s = self.imgrec.read_idx(idx) |
| header, img = recordio.unpack(s) |
| if self.imglist is None: |
| return header.label, img |
| else: |
| return self.imglist[idx][0], img |
| else: |
| label, fname = self.imglist[idx] |
| return label, self.read_image(fname) |
| else: |
| s = self.imgrec.read() |
| if s is None: |
| raise StopIteration |
| header, img = recordio.unpack(s) |
| return header.label, img |
| |
| def next(self): |
| batch_size = self.batch_size |
| c, h, w = self.data_shape |
| batch_data = nd.empty((batch_size, c, h, w)) |
| batch_label = nd.empty(self.provide_label[0][1]) |
| i = 0 |
| try: |
| while i < batch_size: |
| label, s = self.next_sample() |
| data = [self.imdecode(s)] |
| try: |
| self.check_valid_image(data) |
| except RuntimeError as e: |
| logging.debug('Invalid image, skipping: %s', str(e)) |
| continue |
| data = self.augmentation_transform(data) |
| for datum in data: |
| assert i < batch_size, 'Batch size must be multiples of augmenter output length' |
| batch_data[i][:] = self.postprocess_data(datum) |
| batch_label[i][:] = label |
| i += 1 |
| except StopIteration: |
| if not i: |
| raise StopIteration |
| |
| return io.DataBatch([batch_data], [batch_label], batch_size-i) |
| |
| def check_data_shape(self, data_shape): |
| """checks that the input data shape is valid""" |
| if not len(data_shape) == 3: |
| raise ValueError('data_shape should have length 3, with dimensions CxHxW') |
| if not data_shape[0] == 3: |
| raise ValueError('This iterator expects inputs to have 3 channels.') |
| |
| def check_valid_image(self, data): |
| """checks that data is valid""" |
| if len(data[0].shape) == 0: |
| raise RuntimeError('Data shape is wrong') |
| |
| def imdecode(self, s): |
| """decodes a sting or byte string into an image.""" |
| return imdecode(s) |
| |
| def read_image(self, fname): |
| """reads image from fname and returns the raw bytes to be decoded.""" |
| with open(os.path.join(self.path_root, fname), 'rb') as fin: |
| img = fin.read() |
| return img |
| |
| def augmentation_transform(self, data): |
| """transforms data with specificied augmentation.""" |
| for aug in self.auglist: |
| data = [ret for src in data for ret in aug(src)] |
| return data |
| |
| def postprocess_data(self, datum): |
| """final postprocessing step before image is loaded into the batch.""" |
| return nd.transpose(datum, axes=(2, 0, 1)) |