/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2016 by Contributors
* \file multibox_target.cu
* \brief MultiBoxTarget op
* \author Joshua Zhang
*/
#include "./multibox_target-inl.h"
#include <mshadow/cuda/tensor_gpu-inl.cuh>
#define MULTIBOX_TARGET_CUDA_CHECK(condition) \
/* Code block avoids redefinition of cudaError_t error */ \
do { \
cudaError_t error = condition; \
CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \
} while (0)
namespace mshadow {
namespace cuda {
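// Mark each ground-truth slot as valid (1) or padding (0): a label row whose
// first element (class id) is -1 is treated as padding. One thread handles one
// (batch, label) pair.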
template<typename DType>
__global__ void InitGroundTruthFlags(DType *gt_flags, const DType *labels,
const int num_batches,
const int num_labels,
const int label_width) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index >= num_batches * num_labels) return;
int b = index / num_labels;
int l = index % num_labels;
if (*(labels + b * num_labels * label_width + l * label_width) == -1.f) {
*(gt_flags + b * num_labels + l) = 0;
} else {
*(gt_flags + b * num_labels + l) = 1;
}
}
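// Greedy bipartite matching, one block per batch sample: each iteration every
// thread scans its share of unmatched anchors for the largest overlap with a
// still-unmatched ground truth, thread 0 merges the per-thread maxima, records
// the winning pair in best_matches, and marks both anchor and label as used.
// Terminates when no pair with positive overlap remains.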
template<typename DType>
__global__ void FindBestMatches(DType *best_matches, DType *gt_flags,
DType *anchor_flags, const DType *overlaps,
const int num_anchors, const int num_labels) {
int nbatch = blockIdx.x;
gt_flags += nbatch * num_labels;
overlaps += nbatch * num_anchors * num_labels;
best_matches += nbatch * num_anchors;
anchor_flags += nbatch * num_anchors;
const int num_threads = kMaxThreadsPerBlock;
__shared__ int max_indices_y[kMaxThreadsPerBlock];
__shared__ int max_indices_x[kMaxThreadsPerBlock];
__shared__ float max_values[kMaxThreadsPerBlock];
while (1) {
// check if all done.
bool finished = true;
for (int i = 0; i < num_labels; ++i) {
if (gt_flags[i] > .5) {
finished = false;
break;
}
}
if (finished) break; // all done.
// finding max indices in different threads
int max_x = -1;
int max_y = -1;
DType max_value = 1e-6; // start with very small overlap
for (int i = threadIdx.x; i < num_anchors; i += num_threads) {
if (anchor_flags[i] > .5) continue;
for (int j = 0; j < num_labels; ++j) {
if (gt_flags[j] > .5) {
DType temp = overlaps[i * num_labels + j];
if (temp > max_value) {
max_x = j;
max_y = i;
max_value = temp;
}
}
}
}
max_indices_x[threadIdx.x] = max_x;
max_indices_y[threadIdx.x] = max_y;
max_values[threadIdx.x] = max_value;
__syncthreads();
if (threadIdx.x == 0) {
// merge results and assign best match
int max_x = -1;
int max_y = -1;
DType max_value = -1;
for (int k = 0; k < num_threads; ++k) {
if (max_indices_y[k] < 0 || max_indices_x[k] < 0) continue;
float temp = max_values[k];
if (temp > max_value) {
max_x = max_indices_x[k];
max_y = max_indices_y[k];
max_value = temp;
}
}
if (max_x >= 0 && max_y >= 0) {
best_matches[max_y] = max_x;
// mark flags as visited
gt_flags[max_x] = 0.f;
anchor_flags[max_y] = 1.f;
} else {
// no more good matches
for (int i = 0; i < num_labels; ++i) {
gt_flags[i] = 0.f;
}
}
}
__syncthreads();
}
}
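// Second pass: an anchor that is still unmatched (flag < 0) also becomes a
// positive match when its best overlap with any label exceeds
// overlap_threshold; such anchors are flagged 0.9 to distinguish them from the
// greedily matched ones (flag 1).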
template<typename DType>
__global__ void FindGoodMatches(DType *best_matches, DType *anchor_flags,
const DType *overlaps, const int num_anchors,
const int num_labels,
const float overlap_threshold) {
int nbatch = blockIdx.x;
overlaps += nbatch * num_anchors * num_labels;
best_matches += nbatch * num_anchors;
anchor_flags += nbatch * num_anchors;
const int num_threads = kMaxThreadsPerBlock;
for (int i = threadIdx.x; i < num_anchors; i += num_threads) {
if (anchor_flags[i] < 0) {
int idx = -1;
float max_value = -1.f;
for (int j = 0; j < num_labels; ++j) {
DType temp = overlaps[i * num_labels + j];
if (temp > max_value) {
max_value = temp;
idx = j;
}
}
if (max_value > overlap_threshold && (idx >= 0)) {
best_matches[i] = idx;
anchor_flags[i] = 0.9f;
}
}
}
}
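// Fallback when negative mining is disabled: every anchor that is not a
// positive match is used as a negative (background) sample.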
template<typename DType>
__global__ void UseAllNegatives(DType *anchor_flags, const int num) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= num) return;
if (anchor_flags[idx] < 0.5) {
anchor_flags[idx] = 0; // regard all non-positive as negatives
}
}
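// Hard negative mining, one block per batch sample: for every unmatched anchor
// whose best overlap is below negative_mining_thresh, compute the softmax
// probability of the background class, rank candidates by -probability with an
// in-block descending merge sort (buffer holds the scores plus two index
// arrays), and keep only the top num_negative anchors as background samples
// (flag 0).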
template<typename DType>
__global__ void NegativeMining(const DType *overlaps, const DType *cls_preds,
DType *anchor_flags, DType *buffer,
const float negative_mining_ratio,
const float negative_mining_thresh,
const int minimum_negative_samples,
const int num_anchors,
const int num_labels, const int num_classes) {
int nbatch = blockIdx.x;
overlaps += nbatch * num_anchors * num_labels;
cls_preds += nbatch * num_classes * num_anchors;
anchor_flags += nbatch * num_anchors;
buffer += nbatch * num_anchors * 3;
const int num_threads = kMaxThreadsPerBlock;
int num_positive;
__shared__ int num_negative;
if (threadIdx.x == 0) {
num_positive = 0;
for (int i = 0; i < num_anchors; ++i) {
if (anchor_flags[i] > .5) {
++num_positive;
}
}
num_negative = num_positive * negative_mining_ratio;
if (num_negative < minimum_negative_samples) {
num_negative = minimum_negative_samples;
}
if (num_negative > (num_anchors - num_positive)) {
num_negative = num_anchors - num_positive;
}
}
__syncthreads();
if (num_negative < 1) return;
for (int i = threadIdx.x; i < num_anchors; i += num_threads) {
buffer[i] = -1.f;
if (anchor_flags[i] < 0) {
// compute max class prediction score
DType max_val = cls_preds[i];
for (int j = 1; j < num_classes; ++j) {
DType temp = cls_preds[i + num_anchors * j];
if (temp > max_val) max_val = temp;
}
DType sum = 0.f;
for (int j = 0; j < num_classes; ++j) {
DType temp = cls_preds[i + num_anchors * j];
sum += exp(temp - max_val);
}
DType prob = exp(cls_preds[i] - max_val) / sum;
DType max_iou = -1.f;
for (int j = 0; j < num_labels; ++j) {
DType temp = overlaps[i * num_labels + j];
if (temp > max_iou) max_iou = temp;
}
if (max_iou < negative_mining_thresh) {
// only do it for anchors with iou < thresh
        buffer[i] = -prob;  // rank by -prob; it orders anchors the same as -log(prob)
}
}
}
__syncthreads();
  // descending merge sort to rank candidate negatives
DType *index_src = buffer + num_anchors;
DType *index_dst = buffer + num_anchors * 2;
DType *src = index_src;
DType *dst = index_dst;
for (int i = threadIdx.x; i < num_anchors; i += num_threads) {
index_src[i] = i;
}
__syncthreads();
for (int width = 2; width < (num_anchors << 1); width <<= 1) {
int slices = (num_anchors - 1) / (num_threads * width) + 1;
int start = width * threadIdx.x * slices;
for (int slice = 0; slice < slices; ++slice) {
if (start >= num_anchors) break;
int middle = start + (width >> 1);
if (num_anchors < middle) middle = num_anchors;
int end = start + width;
if (num_anchors < end) end = num_anchors;
int i = start;
int j = middle;
for (int k = start; k < end; ++k) {
        // only read src[i] / src[j] while the corresponding run still has
        // elements, so src[middle] / src[end] are never dereferenced out of range
        if (i < middle &&
            (j >= end ||
             buffer[static_cast<int>(src[i])] > buffer[static_cast<int>(src[j])])) {
          dst[k] = src[i];
          ++i;
        } else {
          dst[k] = src[j];
          ++j;
        }
}
start += width;
}
__syncthreads();
// swap src/dst
    src = (src == index_src) ? index_dst : index_src;
    dst = (dst == index_src) ? index_dst : index_src;
}
__syncthreads();
for (int i = threadIdx.x; i < num_negative; i += num_threads) {
int idx = static_cast<int>(src[i]);
if (anchor_flags[idx] < 0) {
anchor_flags[idx] = 0;
}
}
}
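// Emit the final training targets: a positive anchor gets the matched class id
// shifted by one (class 0 is background), a location mask of ones, and box
// regression targets encoded as center offsets normalized by anchor size and
// log size ratios, each divided by its variance; a plain negative gets class
// target 0; ignored anchors (flag < -0.5) are left untouched.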
template<typename DType>
__global__ void AssignTrainingTargets(DType *loc_target, DType *loc_mask,
                                      DType *cls_target, DType *anchor_flags,
                                      DType *best_matches, DType *labels,
                                      DType *anchors, const int num_anchors,
                                      const int num_labels, const int label_width,
                                      const float vx, const float vy,
                                      const float vw, const float vh) {
const int nbatch = blockIdx.x;
loc_target += nbatch * num_anchors * 4;
loc_mask += nbatch * num_anchors * 4;
cls_target += nbatch * num_anchors;
anchor_flags += nbatch * num_anchors;
best_matches += nbatch * num_anchors;
labels += nbatch * num_labels * label_width;
const int num_threads = kMaxThreadsPerBlock;
for (int i = threadIdx.x; i < num_anchors; i += num_threads) {
if (anchor_flags[i] > 0.5) {
// positive sample
int offset_l = static_cast<int>(best_matches[i]) * label_width;
cls_target[i] = labels[offset_l] + 1; // 0 reserved for background
int offset = i * 4;
loc_mask[offset] = 1;
loc_mask[offset + 1] = 1;
loc_mask[offset + 2] = 1;
loc_mask[offset + 3] = 1;
// regression targets
float al = anchors[offset];
float at = anchors[offset + 1];
float ar = anchors[offset + 2];
float ab = anchors[offset + 3];
float aw = ar - al;
float ah = ab - at;
float ax = (al + ar) * 0.5;
float ay = (at + ab) * 0.5;
float gl = labels[offset_l + 1];
float gt = labels[offset_l + 2];
float gr = labels[offset_l + 3];
float gb = labels[offset_l + 4];
float gw = gr - gl;
float gh = gb - gt;
float gx = (gl + gr) * 0.5;
float gy = (gt + gb) * 0.5;
      loc_target[offset] = DType((gx - ax) / aw / vx);      // center x offset
      loc_target[offset + 1] = DType((gy - ay) / ah / vy);  // center y offset
      loc_target[offset + 2] = DType(log(gw / aw) / vw);    // log width ratio
      loc_target[offset + 3] = DType(log(gh / ah) / vh);    // log height ratio
} else if (anchor_flags[i] < 0.5 && anchor_flags[i] > -0.5) {
// background
cls_target[i] = 0;
}
}
}
} // namespace cuda
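// GPU implementation of the MultiBoxTarget forward pass. temp_space[0] is
// expected to hold the precomputed anchor-label overlaps; temp_space[1..3]
// store ground-truth flags, anchor flags, and best matches; temp_space[4] is
// scratch for negative mining. The kernels above are launched in sequence.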
template<typename DType>
inline void MultiBoxTargetForward(const Tensor<gpu, 2, DType> &loc_target,
const Tensor<gpu, 2, DType> &loc_mask,
const Tensor<gpu, 2, DType> &cls_target,
const Tensor<gpu, 2, DType> &anchors,
const Tensor<gpu, 3, DType> &labels,
const Tensor<gpu, 3, DType> &cls_preds,
const Tensor<gpu, 4, DType> &temp_space,
const float overlap_threshold,
const float background_label,
const float negative_mining_ratio,
const float negative_mining_thresh,
const int minimum_negative_samples,
const nnvm::Tuple<float> &variances) {
const int num_batches = labels.size(0);
const int num_labels = labels.size(1);
const int label_width = labels.size(2);
const int num_anchors = anchors.size(0);
const int num_classes = cls_preds.size(1);
CHECK_GE(num_batches, 1);
CHECK_GE(num_anchors, 1);
CHECK_EQ(variances.ndim(), 4);
// init ground-truth flags, by checking valid labels
temp_space[1] = 0.f;
DType *gt_flags = temp_space[1].dptr_;
const int num_threads = cuda::kMaxThreadsPerBlock;
dim3 init_thread_dim(num_threads);
dim3 init_block_dim((num_batches * num_labels - 1) / num_threads + 1);
cuda::CheckLaunchParam(init_block_dim, init_thread_dim, "MultiBoxTarget Init");
cuda::InitGroundTruthFlags<DType><<<init_block_dim, init_thread_dim>>>(
gt_flags, labels.dptr_, num_batches, num_labels, label_width);
MULTIBOX_TARGET_CUDA_CHECK(cudaPeekAtLastError());
// compute best matches
temp_space[2] = -1.f;
temp_space[3] = -1.f;
DType *anchor_flags = temp_space[2].dptr_;
DType *best_matches = temp_space[3].dptr_;
const DType *overlaps = temp_space[0].dptr_;
cuda::CheckLaunchParam(num_batches, num_threads, "MultiBoxTarget Matching");
cuda::FindBestMatches<DType><<<num_batches, num_threads>>>(best_matches,
gt_flags, anchor_flags, overlaps, num_anchors, num_labels);
MULTIBOX_TARGET_CUDA_CHECK(cudaPeekAtLastError());
// find good matches with overlap > threshold
if (overlap_threshold > 0) {
cuda::FindGoodMatches<DType><<<num_batches, num_threads>>>(best_matches,
anchor_flags, overlaps, num_anchors, num_labels,
overlap_threshold);
MULTIBOX_TARGET_CUDA_CHECK(cudaPeekAtLastError());
}
// do negative mining or not
if (negative_mining_ratio > 0) {
CHECK_GT(negative_mining_thresh, 0);
temp_space[4] = 0;
DType *buffer = temp_space[4].dptr_;
cuda::NegativeMining<DType><<<num_batches, num_threads>>>(overlaps,
cls_preds.dptr_, anchor_flags, buffer, negative_mining_ratio,
negative_mining_thresh, minimum_negative_samples,
num_anchors, num_labels, num_classes);
MULTIBOX_TARGET_CUDA_CHECK(cudaPeekAtLastError());
} else {
int num_blocks = (num_batches * num_anchors - 1) / num_threads + 1;
cuda::CheckLaunchParam(num_blocks, num_threads, "MultiBoxTarget Negative");
cuda::UseAllNegatives<DType><<<num_blocks, num_threads>>>(anchor_flags,
num_batches * num_anchors);
MULTIBOX_TARGET_CUDA_CHECK(cudaPeekAtLastError());
}
  cuda::AssignTrainingTargets<DType><<<num_batches, num_threads>>>(
loc_target.dptr_, loc_mask.dptr_, cls_target.dptr_, anchor_flags,
best_matches, labels.dptr_, anchors.dptr_, num_anchors, num_labels,
label_width, variances[0], variances[1], variances[2], variances[3]);
MULTIBOX_TARGET_CUDA_CHECK(cudaPeekAtLastError());
}
} // namespace mshadow
namespace mxnet {
namespace op {
template<>
Operator *CreateOp<gpu>(MultiBoxTargetParam param, int dtype) {
Operator *op = NULL;
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
op = new MultiBoxTargetOp<gpu, DType>(param);
});
return op;
}
} // namespace op
} // namespace mxnet