blob: 890de7d0ffb8656e184b048accc079bb2d0db972 [file] [log] [blame]
# pylint: disable=no-member, too-many-lines, redefined-builtin, protected-access, unused-import, invalid-name
# pylint: disable=too-many-arguments, too-many-locals, no-name-in-module, too-many-branches, too-many-statements
"""Read individual image files and perform augmentations."""
from __future__ import absolute_import, print_function
import os
import random
import logging
import numpy as np
try:
import cv2
except ImportError:
cv2 = None
from .base import numeric_types
from . import ndarray as nd
from . import _ndarray_internal as _internal
from ._ndarray_internal import _cvimresize as imresize
from ._ndarray_internal import _cvcopyMakeBorder as copyMakeBorder
from . import io
from . import recordio
def imdecode(buf, **kwargs):
    """Decode an encoded image buffer into an NDArray.

    Note: `imdecode` relies on the OpenCV support compiled into MXNet, not on
    the `cv2` Python package, so MXNet must have been built with OpenCV.

    Parameters
    ----------
    buf : str/bytes, numpy.ndarray or NDArray
        Encoded image bytes. Anything that is not already an `NDArray` is
        reinterpreted as a flat uint8 byte array before decoding.
    flag : int, optional, default=1
        1 for three channel color output. 0 for grayscale output.
    to_rgb : int, optional, default=1
        1 for RGB formatted output (MXNet default). 0 for BGR formatted output (OpenCV default).
    out : NDArray, optional
        Output buffer. Use `None` for automatic allocation.

    Returns
    -------
    NDArray
        The decoded image.

    Example
    -------
    >>> with open("flower.jpg", 'rb') as fp:
    ...     str_image = fp.read()
    ...
    >>> image = mx.img.imdecode(str_image)
    >>> image
    <NDArray 224x224x3 @cpu(0)>

    Set `flag` parameter to 0 to get grayscale output

    >>> image = mx.img.imdecode(str_image, flag=0)
    >>> image
    <NDArray 224x224x1 @cpu(0)>

    Set `to_rgb` parameter to 0 to get output in OpenCV format (BGR)

    >>> image = mx.img.imdecode(str_image, to_rgb=0)
    >>> image
    <NDArray 224x224x3 @cpu(0)>
    """
    encoded = buf
    if not isinstance(encoded, nd.NDArray):
        # Interpret the raw bytes as uint8 values and wrap them for the decoder.
        byte_values = np.frombuffer(encoded, dtype=np.uint8)
        encoded = nd.array(byte_values, dtype=np.uint8)
    return _internal._cvimdecode(encoded, **kwargs)
def scale_down(src_size, size):
    """Shrink a crop size so that it fits inside the image.

    The crop's aspect ratio is kept. If the crop is taller than the image, it
    is scaled so its height equals the image height. If it is then still
    wider than the image, it is scaled so its width equals the image width.
    Both results are truncated to int.

    Parameters
    ----------
    src_size : tuple of int
        Size of the image in (width, height) format.
    size : tuple of int
        Size of the crop in (width, height) format.

    Returns
    -------
    tuple of int
        The (possibly reduced) crop size in (width, height) format.

    Example
    --------
    >>> src_size = (640,480)
    >>> size = (720,120)
    >>> new_size = mx.img.scale_down(src_size, size)
    >>> new_size
    (640,106)
    """
    src_w, src_h = src_size
    crop_w, crop_h = size
    if crop_h > src_h:
        # Clamp the height, shrinking the width by the same factor.
        crop_w, crop_h = float(crop_w * src_h) / crop_h, src_h
    if crop_w > src_w:
        # Clamp the width, shrinking the height by the same factor.
        crop_w, crop_h = src_w, float(crop_h * src_w) / crop_w
    return int(crop_w), int(crop_h)
def resize_short(src, size, interp=2):
    """Resizes shorter edge to size.

    Note: `resize_short` uses OpenCV (not the CV2 Python library).
    MXNet must have been built with OpenCV for `resize_short` to work.

    Resizes the original image so that its shorter edge equals `size`. The
    longer edge is scaled by the same factor and rounded down to an integer.
    Resizing function is called from OpenCV.

    Parameters
    ----------
    src : NDArray
        The original image, indexed as (height, width, channels).
    size : int
        The length to be set for the shorter edge.
    interp : int, optional, default=2
        Interpolation method used for resizing the image.
        Default method is bicubic interpolation.
        More details can be found in the documentation of OpenCV, please refer to
        http://docs.opencv.org/master/da/d54/group__imgproc__transform.html.

    Returns
    -------
    NDArray
        An 'NDArray' containing the resized image.

    Example
    -------
    >>> with open("flower.jpeg", 'rb') as fp:
    ...     str_image = fp.read()
    ...
    >>> image = mx.img.imdecode(str_image)
    >>> image
    <NDArray 2321x3482x3 @cpu(0)>
    >>> size = 640
    >>> new_image = mx.img.resize_short(image, size)
    >>> new_image
    <NDArray 640x960x3 @cpu(0)>
    """
    h, w, _ = src.shape
    # Use floor division: under Python 3, `/` yields a float, and the
    # resize op needs integer target dimensions.
    if h > w:
        new_h, new_w = size * h // w, size
    else:
        new_h, new_w = size, size * w // h
    return imresize(src, new_w, new_h, interp=interp)
def fixed_crop(src, x0, y0, w, h, size=None, interp=2):
    """Crop `src` at a fixed location and optionally resize the result.

    Parameters
    ----------
    src : NDArray
        Source image, indexed as (height, width, channels).
    x0, y0 : int
        Column and row of the crop's top-left corner.
    w, h : int
        Width and height of the crop.
    size : tuple or list of int, optional
        Target size as (width, height). When given and different from
        (w, h), the crop is resized to it.
    interp : int, optional, default=2
        OpenCV interpolation flag used when resizing (2 is bicubic).

    Returns
    -------
    NDArray
        The cropped, and possibly resized, image.
    """
    out = nd.crop(src, begin=(y0, x0, 0), end=(y0 + h, x0 + w, int(src.shape[2])))
    # Compare as tuples. Otherwise a list such as [w, h] never equals (w, h)
    # and would trigger a needless resize.
    if size is not None and (w, h) != tuple(size):
        out = imresize(out, *size, interp=interp)
    return out
def random_crop(src, size, interp=2):
    """Randomly crop `src` with `size` (width, height).

    Upsample result if `src` is smaller than `size`.

    Parameters
    ----------
    src: Source image `NDArray`, indexed as (height, width, channels).
    size: Size of the crop formatted as (width, height). If the `size` is larger
        than the image, then the source image is upsampled to `size` and returned.
    interp: Interpolation method to be used in case the size is larger (default: bicubic).
        Uses OpenCV convention for the parameters. Nearest - 0, Bilinear - 1, Bicubic - 2,
        Area - 3. See OpenCV imresize function for more details.

    Returns
    -------
    NDArray
        An `NDArray` containing the cropped image.
    Tuple
        A tuple (x, y, width, height) where (x, y) is top-left position of the crop in the
        original image and (width, height) are the dimensions of the cropped image.

    Example
    -------
    >>> im = mx.nd.array(cv2.imread("flower.jpg"))
    >>> cropped_im, rect = mx.image.random_crop(im, (100, 100))
    >>> print(cropped_im)
    <NDArray 100x100x3 @cpu(0)>
    >>> print(rect)
    (20, 21, 100, 100)
    """
    h, w, _ = src.shape
    # Shrink the requested crop (keeping its aspect ratio) until it fits.
    new_w, new_h = scale_down((w, h), size)
    # randint is inclusive, so the crop may touch the right/bottom border.
    x0 = random.randint(0, w - new_w)
    y0 = random.randint(0, h - new_h)
    out = fixed_crop(src, x0, y0, new_w, new_h, size, interp)
    return out, (x0, y0, new_w, new_h)
def center_crop(src, size, interp=2):
    """Crops the image `src` to the given `size` by trimming on all four
    sides and preserving the center of the image. Upsamples if `src` is smaller
    than `size`.

    .. note:: This requires MXNet to be compiled with USE_OPENCV.

    Parameters
    ----------
    src : NDArray
        Source image, indexed as (height, width, channels).
    size : list or tuple of int
        The desired output image size as (width, height).
    interp : int, optional, default=2 (bicubic)
        The type of interpolation that is done to the image (OpenCV flags).
        Possible values:

        0: Nearest Neighbors Interpolation.

        1: Bilinear interpolation.

        2: Bicubic interpolation over 4x4 pixel neighborhood (used by default).

        3: Area-based (resampling using pixel area relation). It may be a
        preferred method for image decimation, as it gives moire-free
        results. But when the image is zoomed, it is similar to the Nearest
        Neighbors method.

        4: Lanczos interpolation over 8x8 pixel neighborhood.

        When shrinking an image, it will generally look best with AREA-based
        interpolation, whereas, when enlarging an image, it will generally look best
        with Bicubic (slow) or Bilinear (faster but still looks OK).

    Returns
    -------
    NDArray
        The cropped image.
    Tuple
        (x, y, width, height) where x, y are the positions of the crop in the
        original image and width, height the dimensions of the crop.

    Example
    -------
    >>> with open("flower.jpg", 'rb') as fp:
    ...     str_image = fp.read()
    ...
    >>> image = mx.image.imdecode(str_image)
    >>> image
    <NDArray 2321x3482x3 @cpu(0)>
    >>> cropped_image, (x, y, width, height) = mx.image.center_crop(image, (1000, 500))
    >>> cropped_image
    <NDArray 500x1000x3 @cpu(0)>
    >>> x, y, width, height
    (1241, 910, 1000, 500)
    """
    h, w, _ = src.shape
    # Shrink the requested crop (keeping its aspect ratio) until it fits.
    new_w, new_h = scale_down((w, h), size)
    # scale_down guarantees new_w <= w and new_h <= h, so both margins are
    # non-negative and floor division matches int(x / 2).
    x0 = (w - new_w) // 2
    y0 = (h - new_h) // 2
    out = fixed_crop(src, x0, y0, new_w, new_h, size, interp)
    return out, (x0, y0, new_w, new_h)
def color_normalize(src, mean, std=None):
    """Normalize `src` in place with `mean` and `std`.

    Parameters
    ----------
    src : array
        Image to normalize. It is modified in place via `-=` and `/=`.
    mean : array, scalar or None
        Value subtracted from `src`. `None` skips the subtraction.
    std : array, scalar or None, optional
        Value `src` is divided by. `None` (the default) skips the division.

    Returns
    -------
    array
        `src` after normalization.
    """
    if mean is not None:
        src -= mean
    if std is not None:
        src /= std
    return src
def random_size_crop(src, size, min_area, ratio, interp=2):
    """Crop `src` with a random area and aspect ratio, then resize to `size`.

    Parameters
    ----------
    src : NDArray
        Source image, indexed as (height, width, channels).
    size : tuple of int
        Output size as (width, height).
    min_area : float
        Minimum crop area, as a fraction of the source image area.
    ratio : tuple of float
        (low, high) bounds for the crop's width/height aspect ratio.
    interp : int, optional, default=2
        OpenCV interpolation flag used for resizing.

    Returns
    -------
    NDArray
        The cropped and resized image.
    tuple of int
        (x0, y0, width, height) of the crop within `src`.
    """
    height, width, _ = src.shape
    aspect = random.uniform(*ratio)
    # Largest area a crop with this aspect ratio can have: limited by the
    # image width when the image is relatively narrow, otherwise by its height.
    if aspect * height > width:
        largest = width * int(width / aspect)
    else:
        largest = height * int(height * aspect)
    smallest = min_area * (height * width)
    if largest < smallest:
        # No crop with this aspect ratio meets the area bound; plain random crop.
        return random_crop(src, size, interp)
    area = random.uniform(smallest, largest)
    crop_w = int(np.sqrt(area * aspect))
    crop_h = int(np.sqrt(area / aspect))
    assert crop_w <= width and crop_h <= height
    left = random.randint(0, width - crop_w)
    top = random.randint(0, height - crop_h)
    return fixed_crop(src, left, top, crop_w, crop_h, size, interp), (left, top, crop_w, crop_h)
def ResizeAug(size, interp=2):
    """Build an augmenter that resizes an image's shorter edge to `size`.

    The returned callable takes one image and returns a one-element list
    containing `resize_short(image, size, interp)`.
    """
    def resize_short_aug(img):
        """Resize `img` so its shorter edge equals `size`."""
        resized = resize_short(img, size, interp)
        return [resized]
    return resize_short_aug
def RandomCropAug(size, interp=2):
    """Build an augmenter that takes a random crop of `size` (width, height).

    Only the cropped image is kept; the crop rectangle reported by
    `random_crop` is discarded.
    """
    def random_crop_aug(img):
        """Randomly crop `img` and return it in a one-element list."""
        cropped, _ = random_crop(img, size, interp)
        return [cropped]
    return random_crop_aug
def RandomSizedCropAug(size, min_area, ratio, interp=2):
    """Build an augmenter that crops with a random area and aspect ratio.

    Each call delegates to `random_size_crop` and keeps only the resized
    crop, discarding the crop rectangle.
    """
    def random_sized_crop_aug(img):
        """Crop `img` with jittered area/aspect ratio; return a one-element list."""
        cropped, _ = random_size_crop(img, size, min_area, ratio, interp)
        return [cropped]
    return random_sized_crop_aug
def CenterCropAug(size, interp=2):
    """Build an augmenter that takes the central crop of `size` (width, height).

    Only the cropped image is kept; the crop rectangle reported by
    `center_crop` is discarded.
    """
    def center_crop_aug(img):
        """Center-crop `img` and return it in a one-element list."""
        cropped, _ = center_crop(img, size, interp)
        return [cropped]
    return center_crop_aug
def RandomOrderAug(ts):
    """Build an augmenter that applies every augmenter in `ts` in random order.

    Each augmenter maps one image to a list of images. Every output of one
    stage is fed through the next, so the result is a flat list. The order
    is reshuffled on every call.

    Parameters
    ----------
    ts : list of callable
        Augmenters to apply. The list is read on each call but never
        modified.

    Returns
    -------
    callable
        Function taking one image and returning a list of augmented images.
    """
    def aug(src):
        """Augmenter body"""
        # Shuffle a per-call copy. Shuffling `ts` itself would reorder the
        # caller's list as a side effect of every call.
        order = list(ts)
        random.shuffle(order)
        src = [src]
        for t in order:
            src = [j for i in src for j in t(i)]
        return src
    return aug
def ColorJitterAug(brightness, contrast, saturation):
    """Build an augmenter applying random brightness, contrast and saturation jitter.

    Each jitter whose strength is > 0 draws a factor uniformly from
    [1 - strength, 1 + strength]. The enabled jitters run in a random order
    via `RandomOrderAug`. Images are modified in place.

    Parameters
    ----------
    brightness, contrast, saturation : float
        Jitter strengths; a value <= 0 disables that jitter.

    Returns
    -------
    callable
        Function taking one image and returning a list of augmented images.
    """
    jitters = []
    # Per-channel weights used to compute a gray (luma) value for each pixel.
    coef = nd.array([[[0.299, 0.587, 0.114]]])

    if brightness > 0:
        def brightness_jitter(src):
            """Scale every pixel value by a random factor."""
            src *= 1.0 + random.uniform(-brightness, brightness)
            return [src]
        jitters.append(brightness_jitter)

    if contrast > 0:
        def contrast_jitter(src):
            """Blend the image with its mean gray value."""
            alpha = 1.0 + random.uniform(-contrast, contrast)
            weighted = src * coef
            # Summing `weighted` and dividing by its element count (pixels x 3)
            # then multiplying by 3 yields the mean gray level, for a
            # 3-channel image.
            offset = (3.0 * (1.0 - alpha) / weighted.size) * nd.sum(weighted)
            src *= alpha
            src += offset
            return [src]
        jitters.append(contrast_jitter)

    if saturation > 0:
        def saturation_jitter(src):
            """Blend each pixel with its own gray value."""
            alpha = 1.0 + random.uniform(-saturation, saturation)
            luma = nd.sum(src * coef, axis=2, keepdims=True)
            luma *= (1.0 - alpha)
            src *= alpha
            src += luma
            return [src]
        jitters.append(saturation_jitter)

    return RandomOrderAug(jitters)
def LightingAug(alphastd, eigval, eigvec):
    """Build an augmenter that adds PCA-based color noise.

    Parameters
    ----------
    alphastd : float
        Standard deviation of the normal distribution the per-component
        weights are drawn from.
    eigval : numpy.ndarray
        Eigenvalues, one per principal component (3 are expected).
    eigvec : numpy.ndarray
        3x3 matrix whose columns are scaled by the random weights.

    Returns
    -------
    callable
        Function adding the noise to an image in place; returns a one-element list.
    """
    def lighting_aug(img):
        """Add a random combination of principal components to `img`."""
        weights = np.random.normal(0, alphastd, size=(3,))
        # eigvec * weights scales column j by weights[j]. The dot with eigval
        # then gives sum_j eigvec[:, j] * weights[j] * eigval[j].
        offset = np.dot(eigvec * weights, eigval)
        img += nd.array(offset)
        return [img]
    return lighting_aug
def ColorNormalizeAug(mean, std):
    """Build an augmenter that normalizes images with `mean` and `std`.

    Parameters
    ----------
    mean : array-like or None
        Value subtracted from each image.
    std : array-like or None
        Value each image is divided by.

    Non-None values are converted to NDArrays once, when the augmenter is
    built. `None` is forwarded unchanged to `color_normalize`, which skips
    the division when `std` is None, instead of being passed to
    `nd.array(None)`.

    Returns
    -------
    callable
        Function taking one image and returning a one-element list.
    """
    mean = None if mean is None else nd.array(mean)
    std = None if std is None else nd.array(std)
    def aug(src):
        """Augmenter body"""
        return [color_normalize(src, mean, std)]
    return aug
def HorizontalFlipAug(p):
    """Build an augmenter that flips an image along axis 1 with probability `p`.

    For images laid out as (height, width, channels), axis 1 is the width
    axis, i.e. a left-right mirror.
    """
    def flip_aug(img):
        """Flip `img` when a uniform draw in [0, 1) falls below `p`."""
        should_flip = random.random() < p
        return [nd.flip(img, axis=1) if should_flip else img]
    return flip_aug
def CastAug():
    """Build an augmenter that converts an image to float32."""
    def cast_aug(img):
        """Return `img` cast to float32, wrapped in a one-element list."""
        return [img.astype(np.float32)]
    return cast_aug
def CreateAugmenter(data_shape, resize=0, rand_crop=False, rand_resize=False, rand_mirror=False,
                    mean=None, std=None, brightness=0, contrast=0, saturation=0,
                    pca_noise=0, inter_method=2):
    """Creates an augmenter list.

    Parameters
    ----------
    data_shape : tuple of int
        Output shape as (channels, height, width). The crop size is taken
        from the last two entries.
    resize : int, optional
        If > 0, first resize the shorter edge to this length.
    rand_crop : bool, optional
        Use a random crop instead of a center crop.
    rand_resize : bool, optional
        Use a crop with random area (at least 0.3 of the image) and random
        aspect ratio in [3/4, 4/3]. Requires `rand_crop=True`.
    rand_mirror : bool, optional
        Flip horizontally with probability 0.5.
    mean : numpy.ndarray, True or None, optional
        Mean to subtract. `True` selects [123.68, 116.28, 103.53].
    std : numpy.ndarray, True or None, optional
        Std to divide by. `True` selects [58.395, 57.12, 57.375].
    brightness, contrast, saturation : float, optional
        Color jitter strengths; 0 disables each.
    pca_noise : float, optional
        Standard deviation of PCA lighting noise; 0 disables it.
    inter_method : int, optional, default=2
        OpenCV interpolation flag used by the resize/crop augmenters.

    Returns
    -------
    list of callable
        Augmenters, in the order they should be applied.
    """
    auglist = []
    if resize > 0:
        auglist.append(ResizeAug(resize, inter_method))

    # Crop helpers expect (width, height).
    crop_size = (data_shape[2], data_shape[1])
    if rand_resize:
        assert rand_crop
        auglist.append(RandomSizedCropAug(crop_size, 0.3, (3.0 / 4.0, 4.0 / 3.0), inter_method))
    elif rand_crop:
        auglist.append(RandomCropAug(crop_size, inter_method))
    else:
        auglist.append(CenterCropAug(crop_size, inter_method))

    if rand_mirror:
        auglist.append(HorizontalFlipAug(0.5))

    # Color augmenters below operate on float data.
    auglist.append(CastAug())

    if brightness or contrast or saturation:
        auglist.append(ColorJitterAug(brightness, contrast, saturation))

    if pca_noise > 0:
        eigval = np.array([55.46, 4.794, 1.148])
        eigvec = np.array([[-0.5675, 0.7192, 0.4009],
                           [-0.5808, -0.0045, -0.8140],
                           [-0.5836, -0.6948, 0.4203]])
        auglist.append(LightingAug(pca_noise, eigval, eigvec))

    if mean is True:
        mean = np.array([123.68, 116.28, 103.53])
    elif mean is not None:
        assert isinstance(mean, np.ndarray) and mean.shape[0] in [1, 3]

    if std is True:
        std = np.array([58.395, 57.12, 57.375])
    elif std is not None:
        assert isinstance(std, np.ndarray) and std.shape[0] in [1, 3]

    if mean is not None or std is not None:
        # Normalize when either statistic is given. Previously, passing only
        # `mean` (or only `std`) silently skipped normalization. The missing
        # statistic is filled with a neutral value (subtract 0 / divide by 1).
        if mean is None:
            mean = np.array([0.0])
        if std is None:
            std = np.array([1.0])
        auglist.append(ColorNormalizeAug(mean, std))

    return auglist
class ImageIter(io.DataIter):
    """Image data iterator with a large number of augmentation choices.

    This iterator supports reading from both .rec files and raw image files.

    To load input images from .rec files, use `path_imgrec` parameter and to load from raw image
    files, use `path_imglist` and `path_root` parameters.

    To use data partition (for distributed training) or shuffling, specify `path_imgidx` parameter.

    Parameters
    ----------
    batch_size : int
        Number of examples per batch.
    data_shape : tuple
        Data shape in (channels, height, width) format.
        For now, only RGB image with 3 channels is supported.
    label_width : int, optional
        Number of labels per example. The default label width is 1.
    path_imgrec : str
        Path to image record file (.rec).
        Created with tools/im2rec.py or bin/im2rec.
    path_imglist : str
        Path to image list (.lst).
        Created with tools/im2rec.py or with custom script.
        Format: Tab separated record of index, one or more labels and relative_path_from_root.
        Blank lines are skipped.
    imglist: list
        A list of images with the label(s).
        Each item is a list [imagelabel: float or list of float, imgpath].
    path_root : str
        Root folder of image files. If None, image paths are opened as given.
    path_imgidx : str
        Path to image index file. Needed for partition and shuffling when using .rec source.
    shuffle : bool
        Whether to shuffle all images at the start of each iteration or not.
        Can be slow for HDD.
    part_index : int
        Partition index.
    num_parts : int
        Total number of partitions.
    aug_list : list of callable, optional
        Augmenters applied to every decoded image. When None, the list is
        built with `CreateAugmenter(data_shape, **kwargs)`.
    data_name : str
        Data name for provided symbols.
    label_name : str
        Label name for provided symbols.
    kwargs : ...
        More arguments for creating augmenter. See mx.image.CreateAugmenter.
        Only used when `aug_list` is None.
    """

    def __init__(self, batch_size, data_shape, label_width=1,
                 path_imgrec=None, path_imglist=None, path_root=None, path_imgidx=None,
                 shuffle=False, part_index=0, num_parts=1, aug_list=None, imglist=None,
                 data_name='data', label_name='softmax_label', **kwargs):
        super(ImageIter, self).__init__()
        assert path_imgrec or path_imglist or (isinstance(imglist, list))
        if path_imgrec:
            print('loading recordio...')
            if path_imgidx:
                self.imgrec = recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r') # pylint: disable=redefined-variable-type
                self.imgidx = list(self.imgrec.keys)
            else:
                self.imgrec = recordio.MXRecordIO(path_imgrec, 'r') # pylint: disable=redefined-variable-type
                self.imgidx = None
        else:
            self.imgrec = None

        if path_imglist:
            print('loading image list...')
            with open(path_imglist) as fin:
                imglist = {}
                imgkeys = []
                for line in fin:
                    line = line.strip()
                    if not line:
                        # Blank lines (e.g. a trailing newline) would otherwise
                        # crash on int('') below.
                        continue
                    line = line.split('\t')
                    label = nd.array([float(i) for i in line[1:-1]])
                    key = int(line[0])
                    imglist[key] = (label, line[-1])
                    imgkeys.append(key)
                self.imglist = imglist
        elif isinstance(imglist, list):
            print('loading image list...')
            result = {}
            imgkeys = []
            index = 1
            for img in imglist:
                key = str(index) # pylint: disable=redefined-variable-type
                index += 1
                if isinstance(img[0], numeric_types):
                    label = nd.array([img[0]])
                else:
                    label = nd.array(img[0])
                result[key] = (label, img[1])
                imgkeys.append(str(key))
            self.imglist = result
        else:
            self.imglist = None
        self.path_root = path_root

        self.check_data_shape(data_shape)
        self.provide_data = [(data_name, (batch_size,) + data_shape)]
        if label_width > 1:
            self.provide_label = [(label_name, (batch_size, label_width))]
        else:
            self.provide_label = [(label_name, (batch_size,))]
        self.batch_size = batch_size
        self.data_shape = data_shape
        self.label_width = label_width

        self.shuffle = shuffle
        if self.imgrec is None:
            self.seq = imgkeys
        elif shuffle or num_parts > 1:
            assert self.imgidx is not None
            self.seq = self.imgidx
        else:
            self.seq = None

        if num_parts > 1:
            assert part_index < num_parts
            N = len(self.seq)
            # Floor division: under Python 3, `/` yields a float and the
            # slice below would raise TypeError.
            C = N // num_parts
            self.seq = self.seq[part_index * C:(part_index + 1) * C]
        if aug_list is None:
            self.auglist = CreateAugmenter(data_shape, **kwargs)
        else:
            self.auglist = aug_list
        self.cur = 0
        self.reset()

    def reset(self):
        """Resets the iterator to the beginning of the data."""
        if self.shuffle:
            random.shuffle(self.seq)
        if self.imgrec is not None:
            self.imgrec.reset()
        self.cur = 0

    def next_sample(self):
        """Helper function for reading in next sample."""
        if self.seq is not None:
            if self.cur >= len(self.seq):
                raise StopIteration
            idx = self.seq[self.cur]
            self.cur += 1
            if self.imgrec is not None:
                s = self.imgrec.read_idx(idx)
                header, img = recordio.unpack(s)
                if self.imglist is None:
                    return header.label, img
                else:
                    return self.imglist[idx][0], img
            else:
                label, fname = self.imglist[idx]
                return label, self.read_image(fname)
        else:
            # Sequential read straight from the record file.
            s = self.imgrec.read()
            if s is None:
                raise StopIteration
            header, img = recordio.unpack(s)
            return header.label, img

    def next(self):
        """Returns the next batch of data."""
        batch_size = self.batch_size
        c, h, w = self.data_shape
        batch_data = nd.empty((batch_size, c, h, w))
        batch_label = nd.empty(self.provide_label[0][1])
        i = 0
        try:
            while i < batch_size:
                label, s = self.next_sample()
                data = [self.imdecode(s)]
                try:
                    self.check_valid_image(data)
                except RuntimeError as e:
                    logging.debug('Invalid image, skipping: %s', str(e))
                    continue
                data = self.augmentation_transform(data)
                for datum in data:
                    assert i < batch_size, 'Batch size must be multiples of augmenter output length'
                    batch_data[i][:] = self.postprocess_data(datum)
                    batch_label[i][:] = label
                    i += 1
        except StopIteration:
            # Return a partial (padded) batch unless nothing was read at all.
            if not i:
                raise StopIteration

        return io.DataBatch([batch_data], [batch_label], batch_size - i)

    def check_data_shape(self, data_shape):
        """Checks if the input data shape is valid"""
        if not len(data_shape) == 3:
            raise ValueError('data_shape should have length 3, with dimensions CxHxW')
        if not data_shape[0] == 3:
            raise ValueError('This iterator expects inputs to have 3 channels.')

    def check_valid_image(self, data):
        """Checks if the input data is valid"""
        if len(data[0].shape) == 0:
            raise RuntimeError('Data shape is wrong')

    def imdecode(self, s):
        """Decodes a string or byte string to an NDArray.
        See mx.img.imdecode for more details."""
        return imdecode(s)

    def read_image(self, fname):
        """Reads an input image `fname` and returns its raw (still encoded) bytes.

        The file is opened relative to `path_root` when one was given;
        otherwise `fname` is opened as given.

        Example usage:
        ----------
        >>> dataIter.read_image('Face.jpg') # returns raw encoded bytes.
        """
        # path_root defaults to None; os.path.join(None, ...) would raise TypeError.
        if self.path_root is None:
            path = fname
        else:
            path = os.path.join(self.path_root, fname)
        with open(path, 'rb') as fin:
            img = fin.read()
        return img

    def augmentation_transform(self, data):
        """Transforms input data with specified augmentation."""
        for aug in self.auglist:
            data = [ret for src in data for ret in aug(src)]
        return data

    def postprocess_data(self, datum):
        """Final postprocessing step before image is loaded into the batch."""
        # HxWxC -> CxHxW to match provide_data.
        return nd.transpose(datum, axes=(2, 0, 1))