blob: 43b73b615cc80ca85da892c32b97a8551095c85e [file] [log] [blame]
# coding: utf-8
# pylint: disable=too-many-arguments,no-member,invalid-name
"""Opencv plugin for mxnet"""
import random
import ctypes
import cv2
import mxnet as mx
from mxnet.base import _LIB
from mxnet.base import mx_uint, NDArrayHandle, check_call
def imdecode(str_img, flag=1):
    """Decode an encoded image buffer into an NDArray.

    Wrapper around the native ``MXCVImdecode`` call (the OpenCV
    ``cv2.imdecode`` equivalent) that returns an ``mx.nd.NDArray``.

    Parameters
    ----------
    str_img : bytes
        Raw contents of an encoded image file (e.g. JPEG), read in binary
        mode. It is passed through ``ctypes.c_char_p``, which requires
        ``bytes`` on Python 3.
    flag : int
        Decode flag, presumably forwarded to ``cv2.imdecode``
        (1 = 3-channel color).

    Returns
    -------
    img : NDArray
        Decoded image laid out as (height, width, channels)
        with BGR color channel order.
    """
    hdl = NDArrayHandle()
    check_call(_LIB.MXCVImdecode(ctypes.c_char_p(str_img),
                                 mx_uint(len(str_img)),
                                 ctypes.c_int(flag),
                                 ctypes.byref(hdl)))
    return mx.nd.NDArray(hdl)
def resize(src, size, interpolation=cv2.INTER_LINEAR):
    """Resize an image NDArray.

    Wrapper around the native ``MXCVResize`` call (the OpenCV
    ``cv2.resize`` equivalent) that operates on ``mx.nd.NDArray``.

    Parameters
    ----------
    src : NDArray
        Image laid out as (height, width, channels).
    size : tuple
        Target size as (width, height).
    interpolation : int
        Interpolation flag, same as for ``cv2.resize``
        (e.g. ``cv2.INTER_LINEAR``).

    Returns
    -------
    img : NDArray
        Resized image.
    """
    hdl = NDArrayHandle()
    check_call(_LIB.MXCVResize(src.handle, mx_uint(size[0]), mx_uint(size[1]),
                               ctypes.c_int(interpolation), ctypes.byref(hdl)))
    return mx.nd.NDArray(hdl)
def copyMakeBorder(src, top, bot, left, right, border_type=cv2.BORDER_CONSTANT, value=0):
    """Pad the border of an image NDArray.

    Wrapper around the native ``MXCVcopyMakeBorder`` call (the OpenCV
    ``cv2.copyMakeBorder`` equivalent) that operates on ``mx.nd.NDArray``.

    Parameters
    ----------
    src : NDArray
        Image laid out as (height, width, channels).
    top, bot, left, right : int
        Number of pixels to add on each side.
    border_type : int
        Border mode, same as ``borderType`` for ``cv2.copyMakeBorder``.
    value : float
        Fill value, presumably used when ``border_type`` is
        ``cv2.BORDER_CONSTANT``. It is passed as a single double, so the same
        value is used for every channel.

    Returns
    -------
    img : NDArray
        Padded image.
    """
    hdl = NDArrayHandle()
    check_call(_LIB.MXCVcopyMakeBorder(src.handle, ctypes.c_int(top), ctypes.c_int(bot),
                                       ctypes.c_int(left), ctypes.c_int(right),
                                       ctypes.c_int(border_type), ctypes.c_double(value),
                                       ctypes.byref(hdl)))
    return mx.nd.NDArray(hdl)
def scale_down(src_size, size):
    """Shrink a requested crop size so that it fits inside the source image.

    The crop's aspect ratio is preserved. The crop is first clamped to the
    source height, then to the source width.

    Parameters
    ----------
    src_size : tuple
        Source image size as (width, height).
    size : tuple
        Requested crop size as (width, height).

    Returns
    -------
    tuple of int
        Possibly reduced crop size as (width, height), truncated to ints.
    """
    crop_w, crop_h = size
    src_w, src_h = src_size
    if crop_h > src_h:
        # Scale the width by the same factor that brings the height to src_h.
        crop_w = float(crop_w * src_h) / crop_h
        crop_h = src_h
    if crop_w > src_w:
        # Scale the height by the same factor that brings the width to src_w.
        crop_h = float(crop_h * src_w) / crop_w
        crop_w = src_w
    return int(crop_w), int(crop_h)
def fixed_crop(src, x0, y0, w, h, size=None, interpolation=cv2.INTER_CUBIC):
    """Crop ``src`` at a fixed location and optionally resize the result.

    Parameters
    ----------
    src : NDArray
        Image laid out as (height, width, channels).
    x0, y0 : int
        Column and row of the crop's top-left corner.
    w, h : int
        Width and height of the crop.
    size : tuple or list, optional
        Output size as (width, height). If given and it differs from
        ``(w, h)``, the crop is resized to it.
    interpolation : int
        Interpolation flag passed to ``resize``.

    Returns
    -------
    NDArray
        The cropped (and possibly resized) image.
    """
    out = mx.nd.crop(src, begin=(y0, x0, 0), end=(y0 + h, x0 + w, int(src.shape[2])))
    # Compare as tuples: a list like [224, 224] never equals the tuple (w, h),
    # which would otherwise force a needless resize.
    if size is not None and (w, h) != tuple(size):
        out = resize(out, size, interpolation=interpolation)
    return out
def random_crop(src, size):
    """Crop a randomly placed window from ``src`` and resize it to ``size``.

    If the source is smaller than ``size``, the window is first shrunk with
    ``scale_down``, keeping its aspect ratio. It is then upsampled back to
    ``size`` by ``fixed_crop``.

    Parameters
    ----------
    src : NDArray
        Image laid out as (height, width, channels).
    size : tuple
        Target size as (width, height).

    Returns
    -------
    tuple
        ``(cropped_image, (x0, y0, crop_w, crop_h))``, where the second item
        describes the window taken from ``src``.
    """
    src_h, src_w, _ = src.shape
    crop_w, crop_h = scale_down((src_w, src_h), size)
    left = random.randint(0, src_w - crop_w)
    top = random.randint(0, src_h - crop_h)
    window = (left, top, crop_w, crop_h)
    return fixed_crop(src, left, top, crop_w, crop_h, size), window
def color_normalize(src, mean, std):
    """Standardize ``src`` as ``(src - mean) / std``.

    Augmented assignment is used. For types that implement in-place operators
    (such as NDArray or numpy arrays), the caller's ``src`` is therefore
    modified in place and the same object is returned.

    Parameters
    ----------
    src : array-like
        Values to normalize.
    mean : scalar or array-like
        Value subtracted from ``src``.
    std : scalar or array-like
        Value that the centered result is divided by.

    Returns
    -------
    The normalized values (the same object as ``src`` for in-place types).
    """
    normalized = src
    normalized -= mean
    normalized /= std
    return normalized
def random_size_crop(src, size, min_area=0.25, ratio=(3.0/4.0, 4.0/3.0)):
    """Randomly crop ``src`` with a random area and aspect ratio, then resize.

    Up to 10 attempts are made. Each attempt samples:

    - a crop area that is a uniform fraction in ``[min_area, 1.0]`` of the
      source area;
    - an aspect ratio drawn uniformly from ``ratio``;
    - a 50% chance of swapping width and height.

    The first sampled crop that fits inside ``src`` is taken at a random
    position and resized to ``size``. If no attempt fits, the function falls
    back to ``random_crop``.

    Parameters
    ----------
    src : NDArray
        Image laid out as (height, width, channels).
    size : tuple
        Output size as (width, height).
    min_area : float
        Minimum fraction of the source area covered by the crop.
    ratio : tuple of float
        ``(low, high)`` range of the sampled aspect ratio.

    Returns
    -------
    tuple
        ``(cropped_image, (x0, y0, crop_w, crop_h))``.
    """
    import math
    h, w, _ = src.shape
    area = w * h
    for _ in range(10):
        new_area = random.uniform(min_area, 1.0) * area
        new_ratio = random.uniform(*ratio)
        # A rectangle with area A and aspect ratio r = w/h has
        # w = sqrt(A * r) and h = sqrt(A / r). The previous code omitted the
        # sqrt, so candidates almost never fit and it always fell back.
        new_w = int(math.sqrt(new_area * new_ratio))
        new_h = int(math.sqrt(new_area / new_ratio))
        if random.uniform(0., 1.) < 0.5:
            new_w, new_h = new_h, new_w
        # Reject empty crops (possible on tiny images) and crops that overflow.
        if new_w < 1 or new_h < 1 or new_w > w or new_h > h:
            continue
        x0 = random.randint(0, w - new_w)
        y0 = random.randint(0, h - new_h)
        out = fixed_crop(src, x0, y0, new_w, new_h, size)
        return out, (x0, y0, new_w, new_h)
    return random_crop(src, size)
class ImageListIter(mx.io.DataIter):
    """An example image iterator using the opencv plugin.

    Reads image names from a list file and yields batches of randomly
    cropped images laid out as (batch, channels, height, width).
    """
    def __init__(self, root, flist, batch_size, size, mean=None):
        """Create the iterator.

        Parameters
        ----------
        root : str
            Prefix prepended to every entry of ``flist``. It is concatenated
            directly, so include a trailing separator if needed.
        flist : str
            Path to a text file with one image name per line, without the
            ``.jpg`` extension.
        batch_size : int
            Number of images per batch.
        size : tuple
            Crop size as (width, height).
        mean : array-like, optional
            Converted to an NDArray and stored.
            NOTE(review): ``next`` does not currently apply it; confirm intent.
        """
        mx.io.DataIter.__init__(self)
        self.root = root
        # Use a context manager so the list file handle is closed promptly.
        with open(flist) as fin:
            self.list = [line.strip() for line in fin]
        self.cur = 0
        self.batch_size = batch_size
        self.size = size
        if mean is not None:
            self.mean = mx.nd.array(mean)
        else:
            self.mean = None

    def reset(self):
        """Reset the iterator position to 0."""
        self.cur = 0

    def next(self):
        """Return the next batch and advance the iterator.

        The final batch may be partial. Its unused slots stay zero, and
        ``pad`` reports how many slots there are.

        Raises
        ------
        StopIteration
            When every image in the list has been consumed.
        """
        if self.cur >= len(self.list):
            raise StopIteration
        batch = mx.nd.zeros((self.batch_size, self.size[1], self.size[0], 3))
        # Exclusive end index of the images placed in this batch.
        end = min(len(self.list), self.cur + self.batch_size)
        for i in range(self.cur, end):
            # JPEG data is binary. Text mode would corrupt it on Python 3.
            with open(self.root + self.list[i] + '.jpg', 'rb') as fin:
                str_img = fin.read()
            img = imdecode(str_img, 1)
            img, _ = random_crop(img, self.size)
            batch[i - self.cur] = img
        batch = mx.nd.transpose(batch, axes=(0, 3, 1, 2))
        ret = mx.io.DataBatch(data=[batch],
                              label=[],
                              pad=self.batch_size - (end - self.cur),
                              index=None)
        # Advance past every image consumed; the old code set cur to the
        # last index, so that image was read again in the next batch.
        self.cur = end
        return ret