blob: d61100281160ada5db8aa4eb433308b3db495df8 [file] [log] [blame]
import mxnet as mx
import numpy as np
import cv2
from tools.rand_sampler import RandSampler
class DetIter(mx.io.DataIter):
    """
    Detection Iterator, which will feed data and label to network
    Optional data augmentation is performed when providing batch

    Parameters:
    ----------
    imdb : Imdb
        image database; must provide num_images, image_path_from_index(i)
        and (when training) label_from_index(i)
    batch_size : int
        batch size
    data_shape : int or (int, int)
        (height, width) every image is resized to; an int means a square
        shape
    mean_pixels : float, or list/tuple/array of 3 floats
        [R, G, B] mean pixel values, subtracted per channel; a scalar is
        subtracted from every channel
    rand_samplers : RandSampler or list/tuple of RandSampler
        random cropping sampler list, if not specified, will
        use original image only
    rand_mirror : bool
        whether to randomly mirror input images (about 50% of the time),
        default False; only applied when is_train is True
    shuffle : bool
        whether to shuffle the image order on every reset(), default False
    rand_seed : int or None
        if not None, seeds numpy's global RNG with this value (0 included),
        default None
    is_train : bool
        whether in training phase, default True, if False, labels are
        not loaded and no random augmentation is applied
    max_crop_trial : int
        maximum number of random crop trials, default 50; stored as
        self._max_crop_trial but not read elsewhere in this class
        (NOTE(review): confirm whether callers/subclasses rely on it)
    """
    def __init__(self, imdb, batch_size, data_shape,
                 mean_pixels=(128, 128, 128), rand_samplers=(),
                 rand_mirror=False, shuffle=False, rand_seed=None,
                 is_train=True, max_crop_trial=50):
        super(DetIter, self).__init__()
        self._imdb = imdb
        self.batch_size = batch_size
        if isinstance(data_shape, int):
            data_shape = (data_shape, data_shape)
        self._data_shape = data_shape
        if isinstance(mean_pixels, (list, tuple, np.ndarray)):
            # reshape to (3, 1, 1) so it broadcasts over a (3, H, W) image
            self._mean_pixels = mx.nd.Reshape(
                mx.nd.array(mean_pixels),
                shape=(3, 1, 1))
        else:
            # scalar (or pre-built NDArray) mean is subtracted as-is
            self._mean_pixels = mean_pixels
        if not rand_samplers:
            self._rand_samplers = []
        else:
            if isinstance(rand_samplers, tuple):
                rand_samplers = list(rand_samplers)
            elif not isinstance(rand_samplers, list):
                rand_samplers = [rand_samplers]
            assert all(isinstance(rs, RandSampler) for rs in rand_samplers), \
                "Invalid rand sampler"
            self._rand_samplers = rand_samplers
        self.is_train = is_train
        self._rand_mirror = rand_mirror
        self._shuffle = shuffle
        if rand_seed is not None:
            np.random.seed(rand_seed)  # fix random seed; 0 is a valid seed
        self._max_crop_trial = max_crop_trial
        self._current = 0
        self._size = imdb.num_images
        self._index = np.arange(self._size)
        self._data = None
        self._label = None
        # load one batch eagerly so provide_data/provide_label have shapes
        self._get_batch()

    @property
    def provide_data(self):
        """List of (name, shape) pairs for the data arrays."""
        return [(k, v.shape) for k, v in self._data.items()]

    @property
    def provide_label(self):
        """List of (name, shape) pairs for the labels; empty when not training."""
        if self.is_train:
            return [(k, v.shape) for k, v in self._label.items()]
        else:
            return []

    def reset(self):
        """Rewind to the first batch, reshuffling the order if enabled."""
        self._current = 0
        if self._shuffle:
            np.random.shuffle(self._index)

    def iter_next(self):
        """Return True while at least one unread image remains."""
        return self._current < self._size

    def next(self):
        """Load and return the next mx.io.DataBatch; raise StopIteration at the end."""
        if not self.iter_next():
            raise StopIteration
        self._get_batch()
        # DataBatch documents data/label as lists; dict views are not lists
        # on Python 3, so materialize them explicitly.
        data_batch = mx.io.DataBatch(data=list(self._data.values()),
                                     label=list(self._label.values()),
                                     pad=self.getpad(), index=self.getindex())
        self._current += self.batch_size
        return data_batch

    def getindex(self):
        """Ordinal of the current batch within the epoch."""
        return self._current // self.batch_size

    def getpad(self):
        """Number of slots in the current batch that lie past the dataset end."""
        pad = self._current + self.batch_size - self._size
        return 0 if pad < 0 else pad

    def _get_batch(self):
        """
        Load data/label for the batch starting at self._current into
        self._data and self._label
        """
        batch_data = mx.nd.zeros((self.batch_size, 3, self._data_shape[0], self._data_shape[1]))
        batch_label = []
        for i in range(self.batch_size):
            pos = self._current + i
            if pos >= self._size:
                if not self.is_train:
                    # leave the slot zero-filled; getpad() reports these slots
                    continue
                # use padding from middle in each epoch
                idx = (pos + self._size // 2) % self._size
                index = self._index[idx]
            else:
                index = self._index[pos]
            im_path = self._imdb.image_path_from_index(index)
            with open(im_path, 'rb') as fp:
                img_content = fp.read()
            img = mx.img.imdecode(img_content)
            # copy so in-place augmentation (mirroring) never touches the imdb's label
            gt = self._imdb.label_from_index(index).copy() if self.is_train else None
            data, label = self._data_augmentation(img, gt)
            batch_data[i] = data
            if self.is_train:
                batch_label.append(label)
        self._data = {'data': batch_data}
        if self.is_train:
            # np.array stacks the per-image labels; assumes they all share one
            # shape (presumably padded by the imdb) -- TODO confirm
            self._label = {'label': mx.nd.array(np.array(batch_label))}
        else:
            self._label = {'label': None}

    @staticmethod
    def _crop_or_pad(data, xmin, ymin, xmax, ymax):
        """
        Cut the pixel box [xmin, xmax) x [ymin, ymax) out of HWC image `data`.

        If the box lies inside the image this is a plain crop. Otherwise the
        result is a (ymax - ymin, xmax - xmin, 3) uint8 canvas filled with 128.
        Whatever part of the image overlaps the box is pasted onto it at the
        matching offset. This also handles boxes that extend past only some of
        the image borders.
        """
        height, width = data.shape[0], data.shape[1]
        if xmin >= 0 and ymin >= 0 and xmax <= width and ymax <= height:
            return mx.img.fixed_crop(data, xmin, ymin, xmax - xmin, ymax - ymin)
        # padding mode
        canvas = mx.nd.full((ymax - ymin, xmax - xmin, 3), 128, dtype='uint8')
        # overlap between the box and the image, in source-image coordinates
        src_x0, src_x1 = max(xmin, 0), min(xmax, width)
        src_y0, src_y1 = max(ymin, 0), min(ymax, height)
        if src_x1 > src_x0 and src_y1 > src_y0:
            patch = mx.img.fixed_crop(data, src_x0, src_y0,
                                      src_x1 - src_x0, src_y1 - src_y0)
            canvas[src_y0 - ymin:src_y1 - ymin, src_x0 - xmin:src_x1 - xmin, :] = patch
        return canvas

    def _data_augmentation(self, data, label):
        """
        perform data augmentations: crop, mirror, resize, sub mean, swap channels...

        data : NDArray
            decoded image in HWC layout (as returned by mx.img.imdecode)
        label : numpy array or None
            ground truth for this image (None when not training)

        Returns (data, label): data is a float32 CHW NDArray of shape
        (3, data_shape[0], data_shape[1]) with the mean subtracted; label is
        the possibly replaced/flipped ground truth.
        """
        if self.is_train and self._rand_samplers:
            rand_crops = []
            for rs in self._rand_samplers:
                rand_crops += rs.sample(label)
            num_rand_crops = len(rand_crops)
            # randomly pick up one as input data
            if num_rand_crops > 0:
                index = int(np.random.uniform(0, 1) * num_rand_crops)
                height, width = data.shape[0], data.shape[1]
                # crop box appears normalized to [0, 1]; may extend past the
                # image (negative / >1) to request padding
                crop = rand_crops[index][0]
                data = self._crop_or_pad(data,
                                         int(crop[0] * width), int(crop[1] * height),
                                         int(crop[2] * width), int(crop[3] * height))
                label = rand_crops[index][1]
        if self.is_train:
            interp_methods = [cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA,
                              cv2.INTER_NEAREST, cv2.INTER_LANCZOS4]
        else:
            interp_methods = [cv2.INTER_LINEAR]
        interp_method = interp_methods[int(np.random.uniform(0, 1) * len(interp_methods))]
        # imresize takes (w, h); batch slots are (H, W) = (data_shape[0], data_shape[1])
        data = mx.img.imresize(data, self._data_shape[1], self._data_shape[0], interp_method)
        if self.is_train and self._rand_mirror:
            if np.random.uniform(0, 1) > 0.5:
                data = mx.nd.flip(data, axis=1)  # axis 1 of HWC = horizontal flip
                # label columns 1 and 3 hold normalized x coordinates and rows
                # with column 0 == -1 are skipped (inferred from this logic)
                valid_mask = np.where(label[:, 0] > -1)[0]
                tmp = 1.0 - label[valid_mask, 1]
                label[valid_mask, 1] = 1.0 - label[valid_mask, 3]
                label[valid_mask, 3] = tmp
        data = mx.nd.transpose(data, (2, 0, 1))  # HWC -> CHW
        data = data.astype('float32')
        data = data - self._mean_pixels
        return data, label