# blob: fa631a5263fcc8ec02e9415028bfafd2669d2bb2 [file] [log] [blame]
import mxnet as mx
import numpy as np
class MultiBoxMetric(mx.metric.EvalMetric):
    """Track the two training losses reported for Multibox training.

    Two running totals are kept, published under the names
    'CrossEntropy' (slot 0) and 'SmoothL1' (slot 1).
    """

    def __init__(self, eps=1e-8):
        """Register the two sub-metrics with the base class.

        Parameters
        ----------
        eps : float
            Constant added to each probability before the logarithm is
            taken in ``update``.
        """
        metric_names = ['CrossEntropy', 'SmoothL1']
        super(MultiBoxMetric, self).__init__(metric_names, 2)
        self.eps = eps

    def update(self, labels, preds):
        """Fold one batch of network outputs into the running totals.

        Parameters
        ----------
        labels : object
            Not read by this method; the targets are taken from
            ``preds[2]`` instead.
        preds : sequence
            ``preds[0]`` holds class probabilities, ``preds[1]`` the
            localization loss and ``preds[2]`` the class targets. Each
            entry must provide ``asnumpy()``.
        """
        # NOTE(review): assumes the base class initialised sum_metric and
        # num_inst as two-element lists -- confirm against the mxnet version.
        class_probs = preds[0].asnumpy()
        box_loss = preds[1].asnumpy()
        anchor_targets = preds[2].asnumpy()

        # Entries with a negative target are excluded from both metrics.
        num_valid = np.sum(anchor_targets >= 0)

        flat_targets = anchor_targets.flatten()
        keep = np.flatnonzero(flat_targets >= 0)
        class_ids = np.int64(flat_targets[keep])

        # Move axis 1 (indexed by target id below) last, then flatten the
        # remaining axes so each row lines up with one flattened target.
        num_classes = class_probs.shape[1]
        per_row = class_probs.transpose((0, 2, 1)).reshape((-1, num_classes))
        picked = per_row[keep, class_ids]

        # Cross-entropy: negative log of the probability at the target id.
        self.sum_metric[0] += (-np.log(picked + self.eps)).sum()
        self.num_inst[0] += num_valid

        # Smooth-L1: total of the loss tensor, averaged over the same count.
        self.sum_metric[1] += np.sum(box_loss)
        self.num_inst[1] += num_valid

    def get(self):
        """Return the current value of the metric(s).

        Overrides the default behavior.

        Returns
        -------
        name : str or list of str
            Name of the metric, or one name per sub-metric when
            ``self.num`` is set.
        value : float or list of float
            Accumulated sum divided by the instance count; NaN wherever
            the count is zero.
        """
        if self.num is not None:
            names = []
            for idx in range(self.num):
                names.append('%s' % (self.name[idx]))
            values = []
            for total, count in zip(self.sum_metric, self.num_inst):
                values.append(total / count if count != 0 else float('nan'))
            return (names, values)

        # Single-metric fallback.
        if self.num_inst == 0:
            return (self.name, float('nan'))
        return (self.name, self.sum_metric / self.num_inst)