# blob: 1b5e8cb76ee236801d0dbb872a7c85dd1815626e [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import mxnet as mx
import numpy as np
from timeit import default_timer as timer
from dataset.testdb import TestDB
from dataset.iterator import DetIter
import logging
import cv2
from mxnet.io import DataBatch, DataDesc
class Detector(object):
    """
    SSD detector which holds a detection network and wraps the detection API.

    Parameters:
    ----------
    symbol : mx.Symbol or None
        detection network Symbol; when None, the symbol loaded from the
        checkpoint is used instead
    model_prefix : str
        name prefix of trained model
    epoch : int
        load epoch of trained model
    data_shape : int or tuple/list of int
        input data resize shape; a single int is expanded to a square
        (data_shape, data_shape), a pair is used as given and bound as the
        last two axes of the network input
    mean_pixels : tuple of float
        (mean_r, mean_g, mean_b)
    batch_size : int
        run detection with batch size
    ctx : mx.ctx
        device to use, if None, use mx.cpu() as default context
    """

    def __init__(self, symbol, model_prefix, epoch, data_shape, mean_pixels,
                 batch_size=1, ctx=None):
        self.ctx = mx.cpu() if ctx is None else ctx
        load_symbol, args, auxs = mx.model.load_checkpoint(model_prefix, epoch)
        if symbol is None:
            symbol = load_symbol
        self.mod = mx.mod.Module(symbol, label_names=None, context=self.ctx)
        # Accept lists as well as tuples; previously a list such as
        # [300, 300] was mistaken for a scalar and duplicated.
        if isinstance(data_shape, (tuple, list)):
            data_shape = tuple(data_shape)
        else:
            data_shape = (data_shape, data_shape)
        self.data_shape = data_shape
        self.mod.bind(data_shapes=[('data', (batch_size, 3, data_shape[0], data_shape[1]))])
        self.mod.set_params(args, auxs)
        self.mean_pixels = mean_pixels
        # Shaped (3, 1, 1) so it broadcasts over a (channels, rows, cols) image.
        self.mean_pixels_nd = mx.nd.array(mean_pixels).reshape((3, 1, 1))

    def create_batch(self, frame):
        """
        Resize one image, subtract the mean pixels and wrap it in a DataBatch.

        :param frame: an (height, width, channels) numpy array (image)
        :return: DataBatch holding a single array of shape
            (1, channels, data_shape[0], data_shape[1])
        """
        # cv2.resize expects dsize as (width, height) while the resulting
        # array is indexed (rows, cols, channels). Pass the sizes swapped so
        # that, after the transpose below, the spatial axes line up with the
        # (data_shape[0], data_shape[1]) shape the module was bound with.
        target_rows, target_cols = self.data_shape[0], self.data_shape[1]
        frame_resize = mx.nd.array(cv2.resize(frame, (target_cols, target_rows)))
        # Move channels first: (rows, cols, channels) -> (channels, rows, cols)
        frame_t = mx.nd.transpose(frame_resize, axes=(2, 0, 1))
        frame_norm = frame_t - self.mean_pixels_nd
        # Add the batch dimension, results in (1, channels, rows, cols)
        batch_frame = [mx.nd.expand_dims(frame_norm, axis=0)]
        batch_shape = [DataDesc('data', batch_frame[0].shape)]
        return DataBatch(data=batch_frame, provide_data=batch_shape)

    def detect_iter(self, det_iter, show_timer=False):
        """
        detect all images in iterator

        Parameters:
        ----------
        det_iter : DetIter
            iterator for all testing images
        show_timer : Boolean
            whether to print out detection exec time

        Returns:
        ----------
        list of detection results, one entry per image (see
        filter_positive_detections)
        """
        # An already-wrapped PrefetchingIter is accepted below and may not
        # carry `_size`, so read it defensively instead of failing up front.
        num_images = getattr(det_iter, '_size', None)
        if not isinstance(det_iter, mx.io.PrefetchingIter):
            det_iter = mx.io.PrefetchingIter(det_iter)
        start = timer()
        detections = self.mod.predict(det_iter).asnumpy()
        time_elapsed = timer() - start
        if show_timer:
            if num_images is None:
                # Fall back to the number of per-image results returned.
                num_images = detections.shape[0]
            logging.info("Detection time for {} images: {:.4f} sec".format(
                num_images, time_elapsed))
        return Detector.filter_positive_detections(detections)

    def detect_batch(self, batch):
        """
        Return detections for batch

        :param batch: batch to run forward through the bound module, e.g. the
            result of create_batch
        :return: list of positive detections per image in the batch (see
            filter_positive_detections)
        """
        self.mod.forward(batch, is_train=False)
        detections = self.mod.get_outputs()[0]
        return Detector.filter_positive_detections(detections)

    def im_detect(self, im_list, root_dir=None, extension=None, show_timer=False):
        """
        wrapper for detecting multiple images

        Parameters:
        ----------
        im_list : list of str or str
            image path or list of image paths
        root_dir : str
            directory of input images, optional if image path already
            has full directory information
        extension : str
            image extension, eg. ".jpg", optional
        show_timer : Boolean
            whether to print out detection exec time

        Returns:
        ----------
        list of detection results in format [det0, det1...], det is in
        format np.array([id, score, xmin, ymin, xmax, ymax]...)
        """
        test_db = TestDB(im_list, root_dir=root_dir, extension=extension)
        test_iter = DetIter(test_db, 1, self.data_shape, self.mean_pixels,
                            is_train=False)
        return self.detect_iter(test_iter, show_timer)

    def visualize_detection(self, img, dets, classes=None, thresh=0.6):
        """
        visualize detections in one image

        Parameters:
        ----------
        img : numpy.array
            image to display; it is handed to matplotlib's imshow unchanged
        dets : numpy.array
            ssd detections, numpy.array([[id, score, x1, y1, x2, y2]...])
            each row is one object; box coordinates are scaled by the image
            width/height before drawing
        classes : tuple or list of str, optional
            class names; when missing or too short for a class id, the
            numeric id is shown instead
        thresh : float
            score threshold, detections scoring below it are skipped
        """
        # Imported lazily so that detection alone does not need matplotlib.
        import matplotlib.pyplot as plt
        import random
        plt.imshow(img)
        height = img.shape[0]
        width = img.shape[1]
        # One random colour per class id, reused for every box of that class.
        colors = dict()
        for det in dets:
            (klass, score, x0, y0, x1, y1) = det
            if score < thresh:
                continue
            cls_id = int(klass)
            if cls_id not in colors:
                colors[cls_id] = (random.random(), random.random(), random.random())
            xmin = int(x0 * width)
            ymin = int(y0 * height)
            xmax = int(x1 * width)
            ymax = int(y1 * height)
            rect = plt.Rectangle((xmin, ymin), xmax - xmin,
                                 ymax - ymin, fill=False,
                                 edgecolor=colors[cls_id],
                                 linewidth=3.5)
            plt.gca().add_patch(rect)
            class_name = str(cls_id)
            if classes and len(classes) > cls_id:
                class_name = classes[cls_id]
            plt.gca().text(xmin, ymin - 2,
                           '{:s} {:.3f}'.format(class_name, score),
                           bbox=dict(facecolor=colors[cls_id], alpha=0.5),
                           fontsize=12, color='white')
        plt.show()

    @staticmethod
    def filter_positive_detections(detections):
        """
        Keep only the positive detections of every image.

        First column (class id) is -1 for negative detections

        :param detections: mx.nd.NDArray or np.ndarray indexed as
            [image, object, field], with the class id in field 0
        :return: list with one entry per image; each entry is the list of
            object rows whose class id is >= 0
        """
        class_idx = 0
        # numpy is checked first so plain ndarrays short-circuit the check.
        assert (isinstance(detections, np.ndarray)
                or isinstance(detections, mx.nd.NDArray))
        detections_per_image = []
        # for each image
        for i in range(detections.shape[0]):
            det = detections[i, :, :]
            result = [obj for obj in det if obj[class_idx] >= 0]
            detections_per_image.append(result)
            # Logged per image, inside the loop, so an empty batch never
            # touches an unbound `result`.
            logging.info("%d positive detections", len(result))
        return detections_per_image

    def detect_and_visualize(self, im_list, root_dir=None, extension=None,
                             classes=None, thresh=0.6, show_timer=False):
        """
        wrapper for im_detect and visualize_detection

        Parameters:
        ----------
        im_list : list of str or str
            image path or list of image paths
        root_dir : str or None
            directory of input images, optional if image path already
            has full directory information
        extension : str or None
            image extension, eg. ".jpg", optional
        classes : tuple or list of str, optional
            class names forwarded to visualize_detection
        thresh : float
            score threshold forwarded to visualize_detection
        show_timer : Boolean
            whether to print out detection exec time

        Returns:
        ----------
        None; every image is displayed with its detections drawn on top

        Raises:
        ----------
        IOError if an image cannot be read back for display
        """
        import os
        dets = self.im_detect(im_list, root_dir, extension, show_timer=show_timer)
        if not isinstance(im_list, list):
            im_list = [im_list]
        assert len(dets) == len(im_list)
        for im_name, det in zip(im_list, dets):
            # Re-apply extension/root_dir so the image shown is the one that
            # was detected on, instead of reading the bare list entry.
            # NOTE(review): presumably mirrors how TestDB resolves paths
            # (name + extension, joined under root_dir) - confirm against
            # dataset.testdb.
            im_path = im_name
            if extension:
                im_path = im_path + extension
            if root_dir:
                im_path = os.path.join(root_dir, im_path)
            img = cv2.imread(im_path)
            if img is None:
                # cv2.imread signals failure by returning None; fail with the
                # offending path rather than inside cvtColor.
                raise IOError('Unable to read image: {}'.format(im_path))
            # OpenCV loads BGR; convert to RGB for display.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            self.visualize_detection(img, det, classes, thresh)