blob: 56c6ceefdc43c54ce5dbfd38cb8915f924929e13 [file] [log] [blame]
/*!
* Copyright (c) 2016 by Contributors
* \file multibox_target.cc
* \brief MultiBoxTarget op
* \author Joshua Zhang
*/
#include <algorithm>
#include "./multibox_target-inl.h"
#include "../mshadow_op.h"
namespace mshadow {
template<typename DType>
inline void AssignLocTargets(const DType *anchor, const DType *l, DType *dst,
const float vx, const float vy,
const float vw, const float vh) {
float al = *(anchor);
float at = *(anchor+1);
float ar = *(anchor+2);
float ab = *(anchor+3);
float aw = ar - al;
float ah = ab - at;
float ax = (al + ar) * 0.5;
float ay = (at + ab) * 0.5;
float gl = *(l);
float gt = *(l+1);
float gr = *(l+2);
float gb = *(l+3);
float gw = gr - gl;
float gh = gb - gt;
float gx = (gl + gr) * 0.5;
float gy = (gt + gb) * 0.5;
*(dst) = DType((gx - ax) / aw / vx);
*(dst+1) = DType((gy - ay) / ah / vy);
*(dst+2) = DType(std::log(gw / aw) / vw);
*(dst+3) = DType(std::log(gh / ah) / vh);
}
struct SortElemDescend {
float value;
int index;
SortElemDescend(float v, int i) {
value = v;
index = i;
}
bool operator<(const SortElemDescend &other) const {
return value > other.value;
}
};
template<typename DType>
inline void MultiBoxTargetForward(const Tensor<cpu, 2, DType> &loc_target,
const Tensor<cpu, 2, DType> &loc_mask,
const Tensor<cpu, 2, DType> &cls_target,
const Tensor<cpu, 2, DType> &anchors,
const Tensor<cpu, 3, DType> &labels,
const Tensor<cpu, 3, DType> &cls_preds,
const Tensor<cpu, 4, DType> &temp_space,
const float overlap_threshold,
const float background_label,
const float negative_mining_ratio,
const float negative_mining_thresh,
const int minimum_negative_samples,
const nnvm::Tuple<float> &variances) {
const DType *p_anchor = anchors.dptr_;
const int num_batches = labels.size(0);
const int num_labels = labels.size(1);
const int label_width = labels.size(2);
const int num_anchors = anchors.size(0);
CHECK_EQ(variances.ndim(), 4);
for (int nbatch = 0; nbatch < num_batches; ++nbatch) {
const DType *p_label = labels.dptr_ + nbatch * num_labels * label_width;
const DType *p_overlaps = temp_space.dptr_ + nbatch * num_anchors * num_labels;
int num_valid_gt = 0;
for (int i = 0; i < num_labels; ++i) {
if (static_cast<float>(*(p_label + i * label_width)) == -1.0f) {
CHECK_EQ(static_cast<float>(*(p_label + i * label_width + 1)), -1.0f);
CHECK_EQ(static_cast<float>(*(p_label + i * label_width + 2)), -1.0f);
CHECK_EQ(static_cast<float>(*(p_label + i * label_width + 3)), -1.0f);
CHECK_EQ(static_cast<float>(*(p_label + i * label_width + 4)), -1.0f);
break;
}
++num_valid_gt;
} // end iterate labels
if (num_valid_gt > 0) {
std::vector<bool> gt_flags(num_valid_gt, false);
std::vector<std::pair<float, int>> max_matches(num_anchors,
std::pair<float, int>(-1.0f, -1));
std::vector<char> anchor_flags(num_anchors, -1); // -1 means don't care
int num_positive = 0;
while (std::find(gt_flags.begin(), gt_flags.end(), false) != gt_flags.end()) {
// ground-truths not fully matched
int best_anchor = -1;
int best_gt = -1;
float max_overlap = 1e-6; // start with a very small positive overlap
for (int j = 0; j < num_anchors; ++j) {
if (anchor_flags[j] == 1) {
continue; // already matched this anchor
}
const DType *pp_overlaps = p_overlaps + j * num_labels;
for (int k = 0; k < num_valid_gt; ++k) {
if (gt_flags[k]) {
continue; // already matched this gt
}
float iou = static_cast<float>(*(pp_overlaps + k));
if (iou > max_overlap) {
best_anchor = j;
best_gt = k;
max_overlap = iou;
}
}
}
if (best_anchor == -1) {
CHECK_EQ(best_gt, -1);
break; // no more good match
} else {
CHECK_EQ(max_matches[best_anchor].first, -1.0f);
CHECK_EQ(max_matches[best_anchor].second, -1);
max_matches[best_anchor].first = max_overlap;
max_matches[best_anchor].second = best_gt;
num_positive += 1;
// mark as visited
gt_flags[best_gt] = true;
anchor_flags[best_anchor] = 1;
}
} // end while
if (overlap_threshold > 0) {
// find positive matches based on overlaps
for (int j = 0; j < num_anchors; ++j) {
if (anchor_flags[j] == 1) {
continue; // already matched this anchor
}
const DType *pp_overlaps = p_overlaps + j * num_labels;
int best_gt = -1;
float max_iou = -1.0f;
for (int k = 0; k < num_valid_gt; ++k) {
float iou = static_cast<float>(*(pp_overlaps + k));
if (iou > max_iou) {
best_gt = k;
max_iou = iou;
}
}
if (best_gt != -1) {
CHECK_EQ(max_matches[j].first, -1.0f);
CHECK_EQ(max_matches[j].second, -1);
max_matches[j].first = max_iou;
max_matches[j].second = best_gt;
if (max_iou > overlap_threshold) {
num_positive += 1;
// mark as visited
gt_flags[best_gt] = true;
anchor_flags[j] = 1;
}
}
} // end iterate anchors
}
if (negative_mining_ratio > 0) {
const int num_classes = cls_preds.size(1);
DType *p_cls_preds = cls_preds.dptr_ + nbatch * num_classes * num_anchors;
CHECK_GT(negative_mining_thresh, 0);
int num_negative = num_positive * negative_mining_ratio;
if (num_negative > (num_anchors - num_positive)) {
num_negative = num_anchors - num_positive;
}
if (num_negative > 0) {
// use negative mining, pick up "best" negative samples
std::vector<SortElemDescend> temp;
temp.reserve(num_anchors - num_positive);
for (int j = 0; j < num_anchors; ++j) {
if (anchor_flags[j] == 1) {
continue; // already matched this anchor
}
if (max_matches[j].first < 0) {
// not yet calculated
const DType *pp_overlaps = p_overlaps + j * num_labels;
int best_gt = -1;
float max_iou = -1.0f;
for (int k = 0; k < num_valid_gt; ++k) {
float iou = static_cast<float>(*(pp_overlaps + k));
if (iou > max_iou) {
best_gt = k;
max_iou = iou;
}
}
if (best_gt != -1) {
CHECK_EQ(max_matches[j].first, -1.0f);
CHECK_EQ(max_matches[j].second, -1);
max_matches[j].first = max_iou;
max_matches[j].second = best_gt;
}
}
if (max_matches[j].first < negative_mining_thresh &&
anchor_flags[j] == -1) {
// calcuate class predictions
DType max_val = p_cls_preds[j];
for (int k = 1; k < num_classes; ++k) {
DType tmp = p_cls_preds[j + num_anchors * k];
if (tmp > max_val) max_val = tmp;
}
DType sum = 0.f;
for (int k = 0; k < num_classes; ++k) {
DType tmp = p_cls_preds[j + num_anchors * k];
sum += std::exp(tmp - max_val);
}
DType prob = std::exp(p_cls_preds[j] - max_val) / sum;
// loss should be -log(x), but value does not matter, skip log
temp.push_back(SortElemDescend(-prob, j));
}
} // end iterate anchors
CHECK_GE(temp.size(), num_negative);
std::stable_sort(temp.begin(), temp.end());
for (int i = 0; i < num_negative; ++i) {
anchor_flags[temp[i].index] = 0; // mark as negative sample
}
}
} else {
// use all negative samples
for (int i = 0; i < num_anchors; ++i) {
if (anchor_flags[i] != 1) {
anchor_flags[i] = 0;
}
}
}
// assign training targets
DType *p_loc_target = loc_target.dptr_ + nbatch * num_anchors * 4;
DType *p_loc_mask = loc_mask.dptr_ + nbatch * num_anchors * 4;
DType *p_cls_target = cls_target.dptr_ + nbatch * num_anchors;
for (int i = 0; i < num_anchors; ++i) {
if (anchor_flags[i] == 1) {
// positive sample
CHECK_GE(max_matches[i].second, 0);
// 0 reserved for background
*(p_cls_target + i) = *(p_label + label_width * max_matches[i].second) + 1;
int offset = i * 4;
*(p_loc_mask + offset) = 1;
*(p_loc_mask + offset + 1) = 1;
*(p_loc_mask + offset + 2) = 1;
*(p_loc_mask + offset + 3) = 1;
AssignLocTargets(p_anchor + i * 4,
p_label + label_width * max_matches[i].second + 1, p_loc_target + offset,
variances[0], variances[1], variances[2], variances[3]);
} else if (anchor_flags[i] == 0) {
// negative sample
*(p_cls_target + i) = 0;
int offset = i * 4;
*(p_loc_mask + offset) = 0;
*(p_loc_mask + offset + 1) = 0;
*(p_loc_mask + offset + 2) = 0;
*(p_loc_mask + offset + 3) = 0;
}
} // end iterate anchors
}
} // end iterate batches
}
} // namespace mshadow
namespace mxnet {
namespace op {
template<>
Operator *CreateOp<cpu>(MultiBoxTargetParam param, int dtype) {
Operator *op = NULL;
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
op = new MultiBoxTargetOp<cpu, DType>(param);
});
return op;
}
Operator* MultiBoxTargetProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
std::vector<int> *in_type) const {
std::vector<TShape> out_shape, aux_shape;
std::vector<int> out_type, aux_type;
CHECK(InferShape(in_shape, &out_shape, &aux_shape));
CHECK(InferType(in_type, &out_type, &aux_type));
DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0));
}
DMLC_REGISTER_PARAMETER(MultiBoxTargetParam);
MXNET_REGISTER_OP_PROPERTY(_contrib_MultiBoxTarget, MultiBoxTargetProp)
.describe("Compute Multibox training targets")
.add_argument("anchor", "NDArray-or-Symbol", "Generated anchor boxes.")
.add_argument("label", "NDArray-or-Symbol", "Object detection labels.")
.add_argument("cls_pred", "NDArray-or-Symbol", "Class predictions.")
.add_arguments(MultiBoxTargetParam::__FIELDS__());
} // namespace op
} // namespace mxnet