| from __future__ import print_function | 
 | import mxnet as mx | 
 | import numpy as np | 
 | from timeit import default_timer as timer | 
 | from dataset.testdb import TestDB | 
 | from dataset.iterator import DetIter | 
 |  | 
 | class Detector(object): | 
 |     """ | 
 |     SSD detector which hold a detection network and wraps detection API | 
 |  | 
 |     Parameters: | 
 |     ---------- | 
 |     symbol : mx.Symbol | 
 |         detection network Symbol | 
 |     model_prefix : str | 
 |         name prefix of trained model | 
 |     epoch : int | 
 |         load epoch of trained model | 
 |     data_shape : int | 
 |         input data resize shape | 
 |     mean_pixels : tuple of float | 
 |         (mean_r, mean_g, mean_b) | 
 |     batch_size : int | 
 |         run detection with batch size | 
 |     ctx : mx.ctx | 
 |         device to use, if None, use mx.cpu() as default context | 
 |     """ | 
 |     def __init__(self, symbol, model_prefix, epoch, data_shape, mean_pixels, \ | 
 |                  batch_size=1, ctx=None): | 
 |         self.ctx = ctx | 
 |         if self.ctx is None: | 
 |             self.ctx = mx.cpu() | 
 |         load_symbol, args, auxs = mx.model.load_checkpoint(model_prefix, epoch) | 
 |         if symbol is None: | 
 |             symbol = load_symbol | 
 |         self.mod = mx.mod.Module(symbol, label_names=None, context=ctx) | 
 |         self.data_shape = data_shape | 
 |         self.mod.bind(data_shapes=[('data', (batch_size, 3, data_shape, data_shape))]) | 
 |         self.mod.set_params(args, auxs) | 
 |         self.data_shape = data_shape | 
 |         self.mean_pixels = mean_pixels | 
 |  | 
 |     def detect(self, det_iter, show_timer=False): | 
 |         """ | 
 |         detect all images in iterator | 
 |  | 
 |         Parameters: | 
 |         ---------- | 
 |         det_iter : DetIter | 
 |             iterator for all testing images | 
 |         show_timer : Boolean | 
 |             whether to print out detection exec time | 
 |  | 
 |         Returns: | 
 |         ---------- | 
 |         list of detection results | 
 |         """ | 
 |         num_images = det_iter._size | 
 |         if not isinstance(det_iter, mx.io.PrefetchingIter): | 
 |             det_iter = mx.io.PrefetchingIter(det_iter) | 
 |         start = timer() | 
 |         detections = self.mod.predict(det_iter).asnumpy() | 
 |         time_elapsed = timer() - start | 
 |         if show_timer: | 
 |             print("Detection time for {} images: {:.4f} sec".format( | 
 |                 num_images, time_elapsed)) | 
 |         result = [] | 
 |         for i in range(detections.shape[0]): | 
 |             det = detections[i, :, :] | 
 |             res = det[np.where(det[:, 0] >= 0)[0]] | 
 |             result.append(res) | 
 |         return result | 
 |  | 
 |     def im_detect(self, im_list, root_dir=None, extension=None, show_timer=False): | 
 |         """ | 
 |         wrapper for detecting multiple images | 
 |  | 
 |         Parameters: | 
 |         ---------- | 
 |         im_list : list of str | 
 |             image path or list of image paths | 
 |         root_dir : str | 
 |             directory of input images, optional if image path already | 
 |             has full directory information | 
 |         extension : str | 
 |             image extension, eg. ".jpg", optional | 
 |  | 
 |         Returns: | 
 |         ---------- | 
 |         list of detection results in format [det0, det1...], det is in | 
 |         format np.array([id, score, xmin, ymin, xmax, ymax]...) | 
 |         """ | 
 |         test_db = TestDB(im_list, root_dir=root_dir, extension=extension) | 
 |         test_iter = DetIter(test_db, 1, self.data_shape, self.mean_pixels, | 
 |                             is_train=False) | 
 |         return self.detect(test_iter, show_timer) | 
 |  | 
 |     def visualize_detection(self, img, dets, classes=[], thresh=0.6): | 
 |         """ | 
 |         visualize detections in one image | 
 |  | 
 |         Parameters: | 
 |         ---------- | 
 |         img : numpy.array | 
 |             image, in bgr format | 
 |         dets : numpy.array | 
 |             ssd detections, numpy.array([[id, score, x1, y1, x2, y2]...]) | 
 |             each row is one object | 
 |         classes : tuple or list of str | 
 |             class names | 
 |         thresh : float | 
 |             score threshold | 
 |         """ | 
 |         import matplotlib.pyplot as plt | 
 |         import random | 
 |         plt.imshow(img) | 
 |         height = img.shape[0] | 
 |         width = img.shape[1] | 
 |         colors = dict() | 
 |         for i in range(dets.shape[0]): | 
 |             cls_id = int(dets[i, 0]) | 
 |             if cls_id >= 0: | 
 |                 score = dets[i, 1] | 
 |                 if score > thresh: | 
 |                     if cls_id not in colors: | 
 |                         colors[cls_id] = (random.random(), random.random(), random.random()) | 
 |                     xmin = int(dets[i, 2] * width) | 
 |                     ymin = int(dets[i, 3] * height) | 
 |                     xmax = int(dets[i, 4] * width) | 
 |                     ymax = int(dets[i, 5] * height) | 
 |                     rect = plt.Rectangle((xmin, ymin), xmax - xmin, | 
 |                                          ymax - ymin, fill=False, | 
 |                                          edgecolor=colors[cls_id], | 
 |                                          linewidth=3.5) | 
 |                     plt.gca().add_patch(rect) | 
 |                     class_name = str(cls_id) | 
 |                     if classes and len(classes) > cls_id: | 
 |                         class_name = classes[cls_id] | 
 |                     plt.gca().text(xmin, ymin - 2, | 
 |                                     '{:s} {:.3f}'.format(class_name, score), | 
 |                                     bbox=dict(facecolor=colors[cls_id], alpha=0.5), | 
 |                                     fontsize=12, color='white') | 
 |         plt.show() | 
 |  | 
 |     def detect_and_visualize(self, im_list, root_dir=None, extension=None, | 
 |                              classes=[], thresh=0.6, show_timer=False): | 
 |         """ | 
 |         wrapper for im_detect and visualize_detection | 
 |  | 
 |         Parameters: | 
 |         ---------- | 
 |         im_list : list of str or str | 
 |             image path or list of image paths | 
 |         root_dir : str or None | 
 |             directory of input images, optional if image path already | 
 |             has full directory information | 
 |         extension : str or None | 
 |             image extension, eg. ".jpg", optional | 
 |  | 
 |         Returns: | 
 |         ---------- | 
 |  | 
 |         """ | 
 |         import cv2 | 
 |         dets = self.im_detect(im_list, root_dir, extension, show_timer=show_timer) | 
 |         if not isinstance(im_list, list): | 
 |             im_list = [im_list] | 
 |         assert len(dets) == len(im_list) | 
 |         for k, det in enumerate(dets): | 
 |             img = cv2.imread(im_list[k]) | 
 |             img[:, :, (0, 1, 2)] = img[:, :, (2, 1, 0)] | 
 |             self.visualize_detection(img, det, classes, thresh) |