blob: 2bf49f34d1ff53e14ff1cf1363ba1dcb11f17539 [file] [log] [blame]
/*!
* Copyright (c) 2016 by Contributors
* \file multibox_detection.cc
* \brief MultiBoxDetection op
* \author Joshua Zhang
*/
#include "./multibox_detection-inl.h"
#include <algorithm>
namespace mshadow {
/*!
 * \brief A (value, index) pair with an inverted ordering.
 *
 * operator< reports "less" for the element holding the LARGER value, so
 * sorting a range of these with the default comparator arranges the
 * values in descending order while remembering each original position.
 * \tparam DType type of the stored value
 */
template<typename DType>
struct SortElemDescend {
  DType value;  // key used for ordering
  int index;    // position the value came from

  SortElemDescend(DType v, int i) : value(v), index(i) {}

  /*! \brief true when this element must be placed before \p other. */
  bool operator<(const SortElemDescend &other) const {
    const bool comes_first = value > other.value;
    return comes_first;
  }
};
/*!
 * \brief Decode a single location prediction relative to its anchor box.
 *
 * The anchor is read as corners [left, top, right, bottom] and the
 * prediction as [dx, dy, dw, dh]. The decoded box centre is the anchor
 * centre shifted by dx * vx * anchor_width (resp. dy * vy * anchor_height);
 * the decoded half extent is exp(dw * vw) * anchor_width / 2 (resp. height).
 *
 * \param out      receives 4 values: [xmin, ymin, xmax, ymax]
 * \param anchors  4 values describing the anchor corners
 * \param loc_pred 4 values describing the predicted offsets
 * \param clip     when true every output coordinate is clamped to [0, 1]
 * \param vx       scale applied to the x offset
 * \param vy       scale applied to the y offset
 * \param vw       scale applied to the width term before exponentiation
 * \param vh       scale applied to the height term before exponentiation
 */
template<typename DType>
inline void TransformLocations(DType *out, const DType *anchors,
                               const DType *loc_pred, const bool clip,
                               const float vx, const float vy,
                               const float vw, const float vh) {
  // anchor corners, size and centre
  const DType anchor_left = anchors[0];
  const DType anchor_top = anchors[1];
  const DType anchor_right = anchors[2];
  const DType anchor_bottom = anchors[3];
  const DType anchor_w = anchor_right - anchor_left;
  const DType anchor_h = anchor_bottom - anchor_top;
  const DType anchor_cx = (anchor_left + anchor_right) / 2.f;
  const DType anchor_cy = (anchor_top + anchor_bottom) / 2.f;

  // raw regression outputs
  const DType pred_x = loc_pred[0];
  const DType pred_y = loc_pred[1];
  const DType pred_w = loc_pred[2];
  const DType pred_h = loc_pred[3];

  // decoded centre and half extents
  const DType center_x = pred_x * vx * anchor_w + anchor_cx;
  const DType center_y = pred_y * vy * anchor_h + anchor_cy;
  const DType half_w = exp(pred_w * vw) * anchor_w / 2;
  const DType half_h = exp(pred_h * vh) * anchor_h / 2;

  // all inputs have been read above, so writing to out is safe from here on
  if (clip) {
    out[0] = std::max(DType(0), std::min(DType(1), center_x - half_w));
    out[1] = std::max(DType(0), std::min(DType(1), center_y - half_h));
    out[2] = std::max(DType(0), std::min(DType(1), center_x + half_w));
    out[3] = std::max(DType(0), std::min(DType(1), center_y + half_h));
  } else {
    out[0] = center_x - half_w;
    out[1] = center_y - half_h;
    out[2] = center_x + half_w;
    out[3] = center_y + half_h;
  }
}
/*!
 * \brief Intersection-over-union of two axis-aligned boxes.
 * \param a first box as [xmin, ymin, xmax, ymax]
 * \param b second box as [xmin, ymin, xmax, ymax]
 * \return intersection area divided by union area, or 0 when the union
 *         area is not positive
 */
template<typename DType>
inline DType CalculateOverlap(const DType *a, const DType *b) {
  // extent of the intersection rectangle; clamped to zero for disjoint boxes
  const DType inter_w = std::max(DType(0), std::min(a[2], b[2]) - std::max(a[0], b[0]));
  const DType inter_h = std::max(DType(0), std::min(a[3], b[3]) - std::max(a[1], b[1]));
  const DType inter_area = inter_w * inter_h;
  const DType union_area =
      (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter_area;
  // degenerate boxes would otherwise divide by zero
  if (union_area <= 0.f) {
    return static_cast<DType>(0);
  }
  return static_cast<DType>(inter_area / union_area);
}
/*!
 * \brief CPU forward pass: turn class probabilities and location
 *  predictions into detection rows and optionally apply NMS.
 *
 * For every batch item and anchor, the best non-background class
 * (class index >= 1) is selected. Anchors whose best score reaches
 * \p threshold are written compactly to the front of that batch's output
 * as rows of 6 values: [class_id, score, xmin, ymin, xmax, ymax], where
 * class_id is the class index minus one (background removed).
 *
 * When \p nms_threshold lies in (0, 1] the valid rows are then sorted by
 * descending score, optionally truncated to the \p nms_topk best, and
 * suppressed: a lower-scored row gets class_id = -1 when its IoU with a
 * higher-scored surviving row is >= \p nms_threshold.
 *
 * Rows this function never writes are left as they were on entry.
 * NOTE(review): presumably the caller pre-fills \p out with an "invalid"
 * marker - confirm against the operator's Forward().
 *
 * \param out            indexed as (batch, anchor, 6); receives detections
 * \param cls_prob       indexed as (batch, class, anchor); class 0 is skipped
 * \param loc_pred       indexed as (batch, anchor * 4)
 * \param anchors        4 values per anchor, shared by every batch item
 * \param temp_space     scratch with the same layout as \p out
 * \param threshold      minimum score for a detection to be kept
 * \param clip           clamp decoded coordinates to [0, 1]
 * \param variances      the 4 scales forwarded to TransformLocations
 * \param nms_threshold  IoU threshold; values outside (0, 1] disable
 *                       sorting and NMS
 * \param force_suppress suppress across classes instead of per class
 * \param nms_topk       keep only this many best rows before NMS; values
 *                       <= 0 mean no limit
 */
template<typename DType>
inline void MultiBoxDetectionForward(const Tensor<cpu, 3, DType> &out,
                                     const Tensor<cpu, 3, DType> &cls_prob,
                                     const Tensor<cpu, 2, DType> &loc_pred,
                                     const Tensor<cpu, 2, DType> &anchors,
                                     const Tensor<cpu, 3, DType> &temp_space,
                                     const float threshold,
                                     const bool clip,
                                     const nnvm::Tuple<float> &variances,
                                     const float nms_threshold,
                                     const bool force_suppress,
                                     const int nms_topk) {
  CHECK_EQ(variances.ndim(), 4) << "Variance size must be 4";
  const int num_classes = cls_prob.size(1);
  const int num_anchors = cls_prob.size(2);
  const int num_batches = cls_prob.size(0);
  const DType *p_anchor = anchors.dptr_;
  for (int nbatch = 0; nbatch < num_batches; ++nbatch) {
    const DType *p_cls_prob = cls_prob.dptr_ + nbatch * num_classes * num_anchors;
    const DType *p_loc_pred = loc_pred.dptr_ + nbatch * num_anchors * 4;
    DType *p_out = out.dptr_ + nbatch * num_anchors * 6;
    int valid_count = 0;
    for (int i = 0; i < num_anchors; ++i) {
      // find the predicted class id and probability (class 0 is background)
      DType score = -1;
      int id = 0;
      for (int j = 1; j < num_classes; ++j) {
        DType temp = p_cls_prob[j * num_anchors + i];
        if (temp > score) {
          score = temp;
          id = j;
        }
      }
      // no foreground class found, or not confident enough: drop the anchor
      if (id == 0 || score < threshold) continue;
      // [id, prob, xmin, ymin, xmax, ymax]
      DType *row = p_out + valid_count * 6;
      row[0] = id - 1;  // remove background, restore original id
      row[1] = score;
      const int offset = i * 4;
      TransformLocations(row + 2, p_anchor + offset,
                         p_loc_pred + offset, clip, variances[0], variances[1],
                         variances[2], variances[3]);
      ++valid_count;
    }  // end iter num_anchors

    if (valid_count < 1 || nms_threshold <= 0 || nms_threshold > 1) continue;

    // sort and apply NMS; keep an unsorted snapshot to reorder from
    Copy(temp_space[nbatch], out[nbatch], out.stream_);
    // sort confidence in descending order, ties keep their anchor order
    std::vector<SortElemDescend<DType>> sorter;
    sorter.reserve(valid_count);
    for (int i = 0; i < valid_count; ++i) {
      sorter.push_back(SortElemDescend<DType>(p_out[i * 6 + 1], i));
    }
    std::stable_sort(sorter.begin(), sorter.end());

    // re-order output
    const DType *ptemp = temp_space.dptr_ + nbatch * num_anchors * 6;
    int nkeep = static_cast<int>(sorter.size());
    if (nms_topk > 0 && nms_topk < nkeep) {
      // keep only the top-k detections
      nkeep = nms_topk;
      // Rows past the top-k are not rewritten below and would otherwise
      // keep their pre-sort contents, which can duplicate detections that
      // were moved into the first nkeep rows. Mark them as eliminated.
      for (int i = nkeep; i < valid_count; ++i) {
        p_out[i * 6] = -1;
      }
    }
    for (int i = 0; i < nkeep; ++i) {
      for (int j = 0; j < 6; ++j) {
        p_out[i * 6 + j] = ptemp[sorter[i].index * 6 + j];
      }
    }

    // apply nms over the kept rows only
    for (int i = 0; i < nkeep; ++i) {
      const int offset_i = i * 6;
      if (p_out[offset_i] < 0) continue;  // skip eliminated
      for (int j = i + 1; j < nkeep; ++j) {
        const int offset_j = j * 6;
        if (p_out[offset_j] < 0) continue;  // skip eliminated
        if (force_suppress || (p_out[offset_i] == p_out[offset_j])) {
          // when force_suppress == true or class_id equals
          DType iou = CalculateOverlap(p_out + offset_i + 2, p_out + offset_j + 2);
          if (iou >= nms_threshold) {
            p_out[offset_j] = -1;
          }
        }
      }
    }
  }  // end iter batch
}
} // namespace mshadow
namespace mxnet {
namespace op {
/*!
 * \brief CPU specialization of the operator factory.
 *
 * Allocates a MultiBoxDetectionOp<cpu, DType> with `new`, where DType is
 * chosen by MSHADOW_REAL_TYPE_SWITCH from \p dtype. The raw pointer is
 * handed to the caller.
 * NOTE(review): what the switch macro does for an unsupported dtype is
 * defined outside this file; the pointer stays null unless the macro body
 * runs.
 * \param param operator parameters forwarded to the operator constructor
 * \param dtype runtime type flag used to pick DType
 * \return the newly allocated operator
 */
template<>
Operator *CreateOp<cpu>(MultiBoxDetectionParam param, int dtype) {
  Operator *created = nullptr;
  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
    created = new MultiBoxDetectionOp<cpu, DType>(param);
  });
  return created;
}
/*!
 * \brief Check the input shapes and types, then create the operator.
 *
 * InferShape and InferType are run purely as checks: their output and
 * auxiliary results are discarded, and a failure trips CHECK. The operator
 * is then created through DO_BIND_DISPATCH using the type of the first
 * input.
 * \param ctx      context consumed by DO_BIND_DISPATCH
 * \param in_shape input shapes passed to InferShape
 * \param in_type  input types passed to InferType; element 0 selects the
 *                 operator's data type
 */
Operator* MultiBoxDetectionProp::CreateOperatorEx(Context ctx,
                                                  std::vector<TShape> *in_shape,
                                                  std::vector<int> *in_type) const {
  // shape inference must succeed for the given inputs
  std::vector<TShape> inferred_out_shapes;
  std::vector<TShape> inferred_aux_shapes;
  CHECK(InferShape(in_shape, &inferred_out_shapes, &inferred_aux_shapes));

  // type inference must succeed for the given inputs
  std::vector<int> inferred_out_types;
  std::vector<int> inferred_aux_types;
  CHECK(InferType(in_type, &inferred_out_types, &inferred_aux_types));

  // the macro supplies the return statement
  DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0));
}
// Register the parameter struct used by this operator.
DMLC_REGISTER_PARAMETER(MultiBoxDetectionParam);
// Register MultiBoxDetectionProp under the name _contrib_MultiBoxDetection,
// together with its description, its three inputs (cls_prob, loc_pred,
// anchor) and the fields declared by MultiBoxDetectionParam.
MXNET_REGISTER_OP_PROPERTY(_contrib_MultiBoxDetection, MultiBoxDetectionProp)
.describe("Convert multibox detection predictions.")
.add_argument("cls_prob", "NDArray-or-Symbol", "Class probabilities.")
.add_argument("loc_pred", "NDArray-or-Symbol", "Location regression predictions.")
.add_argument("anchor", "NDArray-or-Symbol", "Multibox prior anchor boxes")
.add_arguments(MultiBoxDetectionParam::__FIELDS__());
} // namespace op
} // namespace mxnet