# blob: d33edb65beda2735a37372acf1620245c0abb7d8 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import mxnet as mx
import numpy as np
from rcnn.config import config
def get_rpn_names():
    """Return the names of the RPN outputs and RPN labels.

    The metrics in this module look up an entry's position in these lists
    with ``list.index`` and use it to index ``preds`` / ``labels``.

    Returns:
        A ``(pred, label)`` tuple of two newly created lists of strings.
    """
    return (
        ['rpn_cls_prob', 'rpn_bbox_loss'],
        ['rpn_label', 'rpn_bbox_target', 'rpn_bbox_weight'],
    )
def get_rcnn_names():
    """Return the names of the RCNN outputs and labels.

    When ``config.TRAIN.END2END`` is falsy, only the RCNN names are
    returned. When it is truthy, the prediction list is the RPN predictions
    followed by the RCNN predictions plus ``'rcnn_label'``, and the label
    list is replaced by the RPN label names.

    Returns:
        A ``(pred, label)`` tuple of two lists of strings.
    """
    rcnn_pred = ['rcnn_cls_prob', 'rcnn_bbox_loss']
    if not config.TRAIN.END2END:
        return rcnn_pred, ['rcnn_label', 'rcnn_bbox_target', 'rcnn_bbox_weight']

    # End-to-end: 'rcnn_label' is looked up among the predictions, and the
    # label list is the RPN one.
    rpn_pred, rpn_label = get_rpn_names()
    return rpn_pred + rcnn_pred + ['rcnn_label'], rpn_label
class RPNAccMetric(mx.metric.EvalMetric):
    """Accuracy of the RPN classification output, reported as 'RPNAcc'.

    Entries whose label equals -1 are excluded from both the hit count and
    the instance count.
    """

    def __init__(self):
        super(RPNAccMetric, self).__init__('RPNAcc')
        self.pred, self.label = get_rpn_names()

    def update(self, labels, preds):
        """Accumulate correct predictions and counted instances.

        Args:
            labels: sequence indexed by position in ``self.label``.
            preds: sequence indexed by position in ``self.pred``.
        """
        cls_prob = preds[self.pred.index('rpn_cls_prob')]
        rpn_label = labels[self.label.index('rpn_label')]

        # pred (b, c, p) or (b, c, h, w)
        predicted = mx.ndarray.argmax_channel(cls_prob).asnumpy().astype('int32')
        predicted = predicted.reshape((predicted.shape[0], -1))
        # label (b, p)
        target = rpn_label.asnumpy().astype('int32')

        # keep only the positions whose label is not -1
        valid = np.nonzero(target != -1)
        predicted = predicted[valid]
        target = target[valid]

        self.sum_metric += np.sum(predicted.ravel() == target.ravel())
        self.num_inst += predicted.size
class RCNNAccMetric(mx.metric.EvalMetric):
    """Accuracy of the RCNN classification output, reported as 'RCNNAcc'.

    In end-to-end mode (``config.TRAIN.END2END``) the ground-truth label is
    read from ``preds``; otherwise it is read from ``labels``.
    """

    def __init__(self):
        super(RCNNAccMetric, self).__init__('RCNNAcc')
        self.e2e = config.TRAIN.END2END
        self.pred, self.label = get_rcnn_names()

    def update(self, labels, preds):
        """Accumulate correct predictions and counted instances.

        Args:
            labels: sequence indexed by position in ``self.label``.
            preds: sequence indexed by position in ``self.pred``.
        """
        cls_prob = preds[self.pred.index('rcnn_cls_prob')]
        if self.e2e:
            rcnn_label = preds[self.pred.index('rcnn_label')]
        else:
            rcnn_label = labels[self.label.index('rcnn_label')]

        # Collapse every leading axis so each row holds the scores of one
        # instance; the last axis is the one argmax runs over.
        num_scores = cls_prob.shape[-1]
        scores = cls_prob.asnumpy().reshape(-1, num_scores)
        predicted = scores.argmax(axis=1).astype('int32')
        target = rcnn_label.asnumpy().reshape(-1,).astype('int32')

        self.sum_metric += np.sum(predicted.ravel() == target.ravel())
        self.num_inst += predicted.size
class RPNLogLossMetric(mx.metric.EvalMetric):
    """Negative log-likelihood of the RPN classifier, reported as 'RPNLogLoss'.

    Entries whose label equals -1 are excluded from the loss sum and from
    the instance count.
    """

    def __init__(self):
        super(RPNLogLossMetric, self).__init__('RPNLogLoss')
        self.pred, self.label = get_rpn_names()

    def update(self, labels, preds):
        """Accumulate the summed log loss and the number of counted entries.

        Args:
            labels: sequence indexed by position in ``self.label``.
            preds: sequence indexed by position in ``self.pred``.
        """
        cls_prob = preds[self.pred.index('rpn_cls_prob')]
        rpn_label = labels[self.label.index('rpn_label')]

        # label (b, p) --> (b*p,)
        target = rpn_label.asnumpy().astype('int32').reshape((-1))

        # pred (b, c, p) or (b, c, h, w) --> (b, p, c) --> (b*p, c)
        batch, channels = cls_prob.shape[0], cls_prob.shape[1]
        prob = cls_prob.asnumpy().reshape((batch, channels, -1))
        prob = prob.transpose((0, 2, 1))
        prob = prob.reshape((target.shape[0], -1))

        # keep only the rows whose label is not -1
        kept = np.nonzero(target != -1)[0]
        target = target[kept]
        # probability assigned to the labelled class of each kept row
        picked = prob[kept, target]

        # small offset so log() never sees an exact zero
        picked += 1e-14
        self.sum_metric += np.sum(-1 * np.log(picked))
        self.num_inst += target.shape[0]
class RCNNLogLossMetric(mx.metric.EvalMetric):
    """Negative log-likelihood of the RCNN classifier, reported as 'RCNNLogLoss'.

    In end-to-end mode (``config.TRAIN.END2END``) the ground-truth label is
    read from ``preds``; otherwise it is read from ``labels``.
    """

    def __init__(self):
        super(RCNNLogLossMetric, self).__init__('RCNNLogLoss')
        self.e2e = config.TRAIN.END2END
        self.pred, self.label = get_rcnn_names()

    def update(self, labels, preds):
        """Accumulate the summed log loss and the number of instances.

        Args:
            labels: sequence indexed by position in ``self.label``.
            preds: sequence indexed by position in ``self.pred``.
        """
        cls_prob = preds[self.pred.index('rcnn_cls_prob')]
        if self.e2e:
            rcnn_label = preds[self.pred.index('rcnn_label')]
        else:
            rcnn_label = labels[self.label.index('rcnn_label')]

        # One row of scores per instance, one integer label per instance.
        num_scores = cls_prob.shape[-1]
        prob = cls_prob.asnumpy().reshape(-1, num_scores)
        target = rcnn_label.asnumpy().reshape(-1,).astype('int32')

        # probability assigned to the labelled class of each row
        rows = np.arange(target.shape[0])
        picked = prob[rows, target]

        # small offset so log() never sees an exact zero
        picked += 1e-14
        self.sum_metric += np.sum(-1 * np.log(picked))
        self.num_inst += target.shape[0]
class RPNL1LossMetric(mx.metric.EvalMetric):
    """Accumulated RPN box-regression loss, reported as 'RPNL1Loss'."""

    def __init__(self):
        super(RPNL1LossMetric, self).__init__('RPNL1Loss')
        self.pred, self.label = get_rpn_names()

    def update(self, labels, preds):
        """Add the summed bbox loss and the derived instance count.

        Args:
            labels: sequence indexed by position in ``self.label``.
            preds: sequence indexed by position in ``self.pred``.
        """
        loss = preds[self.pred.index('rpn_bbox_loss')].asnumpy()
        weight = labels[self.label.index('rpn_bbox_weight')].asnumpy()

        # calculate num_inst (average on those fg anchors): the number of
        # positive weights divided by 4
        positive_weights = np.sum(weight > 0)

        self.sum_metric += np.sum(loss)
        self.num_inst += positive_weights / 4
class RCNNL1LossMetric(mx.metric.EvalMetric):
    """Accumulated RCNN box-regression loss, reported as 'RCNNL1Loss'.

    In end-to-end mode (``config.TRAIN.END2END``) the label is read from
    ``preds``; otherwise it is read from ``labels``.
    """

    def __init__(self):
        super(RCNNL1LossMetric, self).__init__('RCNNL1Loss')
        self.e2e = config.TRAIN.END2END
        self.pred, self.label = get_rcnn_names()

    def update(self, labels, preds):
        """Add the summed bbox loss and the count of entries with label != 0.

        Args:
            labels: sequence indexed by position in ``self.label``.
            preds: sequence indexed by position in ``self.pred``.
        """
        loss = preds[self.pred.index('rcnn_bbox_loss')].asnumpy()
        if self.e2e:
            target = preds[self.pred.index('rcnn_label')].asnumpy()
        else:
            target = labels[self.label.index('rcnn_label')].asnumpy()

        # calculate num_inst: one per entry whose label is non-zero
        nonzero_rows = np.nonzero(target != 0)[0]

        self.sum_metric += np.sum(loss)
        self.num_inst += len(nonzero_rows)