| from __future__ import print_function |
| import mxnet as mx |
| import numpy as np |
| from timeit import default_timer as timer |
| from dataset.testdb import TestDB |
| from dataset.iterator import DetIter |
| |
| class Detector(object): |
| """ |
| SSD detector which hold a detection network and wraps detection API |
| |
| Parameters: |
| ---------- |
| symbol : mx.Symbol |
| detection network Symbol |
| model_prefix : str |
| name prefix of trained model |
| epoch : int |
| load epoch of trained model |
| data_shape : int |
| input data resize shape |
| mean_pixels : tuple of float |
| (mean_r, mean_g, mean_b) |
| batch_size : int |
| run detection with batch size |
| ctx : mx.ctx |
| device to use, if None, use mx.cpu() as default context |
| """ |
| def __init__(self, symbol, model_prefix, epoch, data_shape, mean_pixels, \ |
| batch_size=1, ctx=None): |
| self.ctx = ctx |
| if self.ctx is None: |
| self.ctx = mx.cpu() |
| _, args, auxs = mx.model.load_checkpoint(model_prefix, epoch) |
| self.mod = mx.mod.Module(symbol, context=ctx) |
| self.data_shape = data_shape |
| self.mod.bind(data_shapes=[('data', (batch_size, 3, data_shape, data_shape))]) |
| self.mod.set_params(args, auxs) |
| self.data_shape = data_shape |
| self.mean_pixels = mean_pixels |
| |
| def detect(self, det_iter, show_timer=False): |
| """ |
| detect all images in iterator |
| |
| Parameters: |
| ---------- |
| det_iter : DetIter |
| iterator for all testing images |
| show_timer : Boolean |
| whether to print out detection exec time |
| |
| Returns: |
| ---------- |
| list of detection results |
| """ |
| num_images = det_iter._size |
| if not isinstance(det_iter, mx.io.PrefetchingIter): |
| det_iter = mx.io.PrefetchingIter(det_iter) |
| start = timer() |
| detections = self.mod.predict(det_iter).asnumpy() |
| time_elapsed = timer() - start |
| if show_timer: |
| print("Detection time for {} images: {:.4f} sec".format( |
| num_images, time_elapsed)) |
| result = [] |
| for i in range(detections.shape[0]): |
| det = detections[i, :, :] |
| res = det[np.where(det[:, 0] >= 0)[0]] |
| result.append(res) |
| return result |
| |
| def im_detect(self, im_list, root_dir=None, extension=None, show_timer=False): |
| """ |
| wrapper for detecting multiple images |
| |
| Parameters: |
| ---------- |
| im_list : list of str |
| image path or list of image paths |
| root_dir : str |
| directory of input images, optional if image path already |
| has full directory information |
| extension : str |
| image extension, eg. ".jpg", optional |
| |
| Returns: |
| ---------- |
| list of detection results in format [det0, det1...], det is in |
| format np.array([id, score, xmin, ymin, xmax, ymax]...) |
| """ |
| test_db = TestDB(im_list, root_dir=root_dir, extension=extension) |
| test_iter = DetIter(test_db, 1, self.data_shape, self.mean_pixels, |
| is_train=False) |
| return self.detect(test_iter, show_timer) |
| |
| def visualize_detection(self, img, dets, classes=[], thresh=0.6): |
| """ |
| visualize detections in one image |
| |
| Parameters: |
| ---------- |
| img : numpy.array |
| image, in bgr format |
| dets : numpy.array |
| ssd detections, numpy.array([[id, score, x1, y1, x2, y2]...]) |
| each row is one object |
| classes : tuple or list of str |
| class names |
| thresh : float |
| score threshold |
| """ |
| import matplotlib.pyplot as plt |
| import random |
| plt.imshow(img) |
| height = img.shape[0] |
| width = img.shape[1] |
| colors = dict() |
| for i in range(dets.shape[0]): |
| cls_id = int(dets[i, 0]) |
| if cls_id >= 0: |
| score = dets[i, 1] |
| if score > thresh: |
| if cls_id not in colors: |
| colors[cls_id] = (random.random(), random.random(), random.random()) |
| xmin = int(dets[i, 2] * width) |
| ymin = int(dets[i, 3] * height) |
| xmax = int(dets[i, 4] * width) |
| ymax = int(dets[i, 5] * height) |
| rect = plt.Rectangle((xmin, ymin), xmax - xmin, |
| ymax - ymin, fill=False, |
| edgecolor=colors[cls_id], |
| linewidth=3.5) |
| plt.gca().add_patch(rect) |
| class_name = str(cls_id) |
| if classes and len(classes) > cls_id: |
| class_name = classes[cls_id] |
| plt.gca().text(xmin, ymin - 2, |
| '{:s} {:.3f}'.format(class_name, score), |
| bbox=dict(facecolor=colors[cls_id], alpha=0.5), |
| fontsize=12, color='white') |
| plt.show() |
| |
| def detect_and_visualize(self, im_list, root_dir=None, extension=None, |
| classes=[], thresh=0.6, show_timer=False): |
| """ |
| wrapper for im_detect and visualize_detection |
| |
| Parameters: |
| ---------- |
| im_list : list of str or str |
| image path or list of image paths |
| root_dir : str or None |
| directory of input images, optional if image path already |
| has full directory information |
| extension : str or None |
| image extension, eg. ".jpg", optional |
| |
| Returns: |
| ---------- |
| |
| """ |
| import cv2 |
| dets = self.im_detect(im_list, root_dir, extension, show_timer=show_timer) |
| if not isinstance(im_list, list): |
| im_list = [im_list] |
| assert len(dets) == len(im_list) |
| for k, det in enumerate(dets): |
| img = cv2.imread(im_list[k]) |
| img[:, :, (0, 1, 2)] = img[:, :, (2, 1, 0)] |
| self.visualize_detection(img, det, classes, thresh) |