| from __future__ import print_function |
| import os |
| import sys |
| import importlib |
| import mxnet as mx |
| from dataset.iterator import DetRecordIter |
| from config.config import cfg |
| from evaluate.eval_metric import MApMetric, VOC07MApMetric |
| import logging |
| |
def evaluate_net(net, path_imgrec, num_classes, mean_pixels, data_shape,
                 model_prefix, epoch, ctx=mx.cpu(), batch_size=1,
                 path_imglist="", nms_thresh=0.45, force_nms=False,
                 ovp_thresh=0.5, use_difficult=False, class_names=None,
                 voc07_metric=False):
    """
    Evaluate a network on a validation record file, print and return the
    metric values.

    Parameters:
    ----------
    net : str or None
        Network name, imported as module ``"symbol_" + net`` from
        ``cfg.ROOT_DIR/symbol`` and built with ``get_symbol``; use None to
        load the symbol from the checkpoint json without modifying it
    path_imgrec : str
        path to the record validation file
    num_classes : int
        number of classes, not including background
    mean_pixels : tuple
        (mean_r, mean_g, mean_b), forwarded to DetRecordIter; it overrides
        any mean_pixels entry in cfg.valid. None leaves cfg.valid/iterator
        defaults in charge
    data_shape : tuple or int
        (3, height, width) or height/width
    model_prefix : str
        model prefix of saved checkpoint; ``'_' + str(height)`` is appended
        to it before loading
    epoch : int
        load model epoch
    ctx : mx.Context
        mx.gpu() or mx.cpu()
    batch_size : int
        validation batch size
    path_imglist : str
        path to the list file to replace labels in record file, optional
    nms_thresh : float
        non-maximum suppression threshold (only used when ``net`` is a name)
    force_nms : boolean
        whether to suppress objects of different classes (only used when
        ``net`` is a name)
    ovp_thresh : float
        AP overlap threshold for true/false positives
    use_difficult : boolean
        whether to use difficult objects in evaluation if applicable
    class_names : comma separated str or list of str, optional
        class names, must correspond to num_classes if set; a string is
        split on ',' and each name is stripped of surrounding whitespace
    voc07_metric : boolean
        whether to use 11-point evaluation as in VOC07 competition

    Returns:
    -------
    The value returned by ``mod.score`` (iterated here as (name, value)
    pairs), which is also printed one pair per line.

    Raises:
    -------
    AssertionError
        if data_shape is not an int or a 3-element shape starting with 3
    """
    # set up logger (configures and uses the root logger)
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # normalize data_shape to (3, height, width)
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3, \
        "data_shape must be an int or (3, height, width), got {}".format(
            data_shape)
    # the checkpoint prefix carries the input height as a suffix
    model_prefix += '_' + str(data_shape[1])

    # accept the documented comma separated string form of class_names
    if isinstance(class_names, str):
        class_names = [name.strip() for name in class_names.split(',')]

    # iterator: forward the caller's mean_pixels instead of dropping it;
    # merging into a copy of cfg.valid avoids a duplicate keyword argument
    iter_kwargs = dict(cfg.valid)
    if mean_pixels is not None:
        iter_kwargs['mean_pixels'] = mean_pixels
    eval_iter = DetRecordIter(path_imgrec, batch_size, data_shape,
                              path_imglist=path_imglist, **iter_kwargs)

    # model params
    load_net, args, auxs = mx.model.load_checkpoint(model_prefix, epoch)

    # network
    if net is None:
        net = load_net
    else:
        symbol_dir = os.path.join(cfg.ROOT_DIR, 'symbol')
        # avoid appending the same directory on every call
        if symbol_dir not in sys.path:
            sys.path.append(symbol_dir)
        net = importlib.import_module("symbol_" + net) \
            .get_symbol(num_classes, nms_thresh, force_nms)
    # the module below is bound with label_names=('label',), so make sure
    # the symbol exposes a 'label' input
    if 'label' not in net.list_arguments():
        label = mx.sym.Variable(name='label')
        net = mx.sym.Group([net, label])

    # init module; all arguments are fixed since we only run inference
    mod = mx.mod.Module(net, label_names=('label',), logger=logger,
                        context=ctx, fixed_param_names=net.list_arguments())
    mod.bind(data_shapes=eval_iter.provide_data,
             label_shapes=eval_iter.provide_label)
    mod.set_params(args, auxs, allow_missing=False, force_init=True)

    # run evaluation
    metric_cls = VOC07MApMetric if voc07_metric else MApMetric
    metric = metric_cls(ovp_thresh, use_difficult, class_names)
    results = mod.score(eval_iter, metric, num_batch=None)
    for k, v in results:
        print("{}: {}".format(k, v))
    return results