| /*! |
| * Copyright (c) 2017 by Contributors |
| * Copyright (c) 2017 Microsoft |
| * Licensed under The Apache-2.0 License [see LICENSE for details] |
| * \file psroi_pooling.cu |
| * \brief psroi pooling operator |
| * \author Yi Li, Tairui Chen, Guodong Zhang, Haozhi Qi, Jifeng Dai |
| */ |
| #include "./psroi_pooling-inl.h" |
| #include <mshadow/tensor.h> |
| #include <mshadow/cuda/reduce.cuh> |
| #include <algorithm> |
| #include <vector> |
| #include "../../common/cuda_utils.h" |
| #include "../mxnet_op.h" |
| |
/*!
 * \brief Evaluate a cudaError_t expression exactly once and CHECK-fail with
 *        the CUDA error string when it is not cudaSuccess.
 *
 * The argument is parenthesized and the temporary uses a macro-private name so
 * that a caller expression mentioning a variable called `error` is not
 * captured by the macro's own local.
 */
#define PSROIPOOLING_CUDA_CHECK(condition) \
  /* do/while(0) scopes the temporary and keeps the macro a single statement */ \
  do { \
    cudaError_t psroi_cuda_error_ = (condition); \
    CHECK_EQ(psroi_cuda_error_, cudaSuccess) << " " \
      << cudaGetErrorString(psroi_cuda_error_); \
  } while (0)
/*!
 * \brief Loop header for kernels: each thread starts at its global 1-D thread
 *        index and advances by the total number of launched threads
 *        (gridDim.x * blockDim.x) until it reaches n.
 * \param i name of the int loop variable to declare
 * \param n exclusive upper bound of the iteration space
 */
#define CUDA_KERNEL_LOOP(i, n) \
  for (int i = threadIdx.x + blockIdx.x * blockDim.x; \
       i < (n); \
       i += gridDim.x * blockDim.x)
| |
| namespace mshadow { |
| namespace cuda { |
| |
/*!
 * \brief Forward kernel of position-sensitive ROI pooling.
 *
 * Each loop iteration produces one output element, addressed in
 * (n, ctop, ph, pw) order, as the average of the bottom_data values inside the
 * spatial bin selected by (ph, pw) of ROI n, read from input channel
 * c = (ctop * group_size + gh) * group_size + gw.
 *
 * \param count          number of output elements to produce
 * \param bottom_data    input features, indexed as
 *                       ((batch * channels + c) * height + h) * width + w
 * \param spatial_scale  factor applied to the rounded ROI coordinates
 * \param channels       number of input channels
 * \param height         input feature map height
 * \param width          input feature map width
 * \param pooled_height  output height per ROI
 * \param pooled_width   output width per ROI
 * \param bottom_rois    5 values per ROI: [batch_index, x1, y1, x2, y2]
 * \param output_dim     number of output channels per ROI
 * \param group_size     number of position-sensitive groups per spatial axis
 * \param top_data       output buffer with `count` elements
 */
template <typename DType>
__global__ void PSROIPoolForwardKernel(
    const int count,
    const DType* bottom_data,
    const DType spatial_scale,
    const int channels,
    const int height, const int width,
    const int pooled_height, const int pooled_width,
    const DType* bottom_rois,
    const int output_dim,
    const int group_size,
    DType* top_data) {
  CUDA_KERNEL_LOOP(index, count) {
    // The output is in order (n, ctop, ph, pw)
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int ctop = (index / pooled_width / pooled_height) % output_dim;
    int n = index / pooled_width / pooled_height / output_dim;

    const DType* offset_bottom_rois = bottom_rois + n * 5;
    int roi_batch_ind = offset_bottom_rois[0];
    if (roi_batch_ind < 0) {
      // A negative batch index cannot address any input image; emit zero
      // instead of reading before the start of bottom_data.
      // NOTE(review): the upper bound cannot be validated here because the
      // batch size is not passed to the kernel.
      top_data[index] = (DType)0.;
      continue;
    }

    // [start, end) interval for spatial sampling
    DType roi_start_w = static_cast<DType>(round(offset_bottom_rois[1])) * spatial_scale;
    DType roi_start_h = static_cast<DType>(round(offset_bottom_rois[2])) * spatial_scale;
    DType roi_end_w = static_cast<DType>(round(offset_bottom_rois[3]) + 1.) * spatial_scale;
    DType roi_end_h = static_cast<DType>(round(offset_bottom_rois[4]) + 1.) * spatial_scale;

    // Clamp the ROI extent to at least 0.1 so the bin size is never zero
    DType roi_width = max(roi_end_w - roi_start_w, 0.1);  // avoid 0
    DType roi_height = max(roi_end_h - roi_start_h, 0.1);

    // Compute w and h at bottom
    DType bin_size_h = roi_height / static_cast<DType>(pooled_height);
    DType bin_size_w = roi_width / static_cast<DType>(pooled_width);

    int hstart = floor(static_cast<DType>(ph) * bin_size_h
                       + roi_start_h);
    int wstart = floor(static_cast<DType>(pw)* bin_size_w
                       + roi_start_w);
    int hend = ceil(static_cast<DType>(ph + 1) * bin_size_h
                    + roi_start_h);
    int wend = ceil(static_cast<DType>(pw + 1) * bin_size_w
                    + roi_start_w);
    // Add roi offsets and clip to input boundaries
    hstart = min(max(hstart, 0), height);
    hend = min(max(hend, 0), height);
    wstart = min(max(wstart, 0), width);
    wend = min(max(wend, 0), width);
    bool is_empty = (hend <= hstart) || (wend <= wstart);

    // Select the position-sensitive input channel for this bin
    int gw = floor(static_cast<DType>(pw)* group_size / pooled_width);
    int gh = floor(static_cast<DType>(ph)* group_size / pooled_height);
    gw = min(max(gw, 0), group_size - 1);
    gh = min(max(gh, 0), group_size - 1);
    int c = (ctop*group_size + gh)*group_size + gw;

    // Compute the plane offset in 64-bit: batch * channels * height * width
    // can exceed the range of a 32-bit int for large feature maps.
    const long long plane = static_cast<long long>(roi_batch_ind) * channels + c;
    const DType* offset_bottom_data = bottom_data + plane * height * width;
    DType out_sum = 0;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        int bottom_index = h*width + w;
        out_sum += offset_bottom_data[bottom_index];
      }
    }

    // An empty bin has zero area; write 0 rather than dividing by it
    DType bin_area = (hend - hstart)*(wend - wstart);
    top_data[index] = is_empty? (DType)0. : out_sum/bin_area;
  }
}
| |
/*!
 * \brief Launch PSROIPoolForwardKernel on the stream attached to `out`.
 *
 * Channel count and spatial size are read from `data` (dims 1..3); the pooled
 * size is read from `out` (dims 2..3). One kernel loop iteration is issued per
 * element of `out`.
 *
 * \param out            pooled output tensor
 * \param data           input feature tensor
 * \param bbox           ROI tensor; the kernel reads 5 values per row
 * \param spatial_scale  scale applied to ROI coordinates
 * \param output_dim_    number of output channels per ROI
 * \param group_size_    number of position-sensitive groups per spatial axis
 */
template<typename DType>
inline void PSROIPoolForward(const Tensor<gpu, 4, DType> &out,
                             const Tensor<gpu, 4, DType> &data,
                             const Tensor<gpu, 2, DType> &bbox,
                             const float spatial_scale,
                             const int output_dim_,
                             const int group_size_) {
  const int count = out.shape_.Size();
  // Nothing to compute for an empty output. Launching anyway would presumably
  // request a zero-block grid (cuda_get_num_blocks is defined elsewhere),
  // which CUDA rejects as an invalid launch configuration.
  if (count == 0) return;

  const DType *bottom_data = data.dptr_;
  const DType *bottom_rois = bbox.dptr_;
  DType *top_data = out.dptr_;
  const int channels = data.size(1);
  const int height = data.size(2);
  const int width = data.size(3);
  const int pooled_height = out.size(2);
  const int pooled_width = out.size(3);
  cudaStream_t stream = Stream<gpu>::GetStream(out.stream_);
  PSROIPoolForwardKernel<DType><<<mxnet::op::mxnet_op::cuda_get_num_blocks(count),
    kBaseThreadNum, 0, stream>>>(
      count, bottom_data, spatial_scale, channels, height, width,
      pooled_height, pooled_width, bottom_rois, output_dim_, group_size_, top_data);
  // Surface launch-time errors without clearing the runtime's error state
  PSROIPOOLING_CUDA_CHECK(cudaPeekAtLastError());
}
| |
| |
/*!
 * \brief Backward kernel of position-sensitive ROI pooling (accumulating).
 *
 * Each loop iteration takes one output-gradient element, addressed in
 * (n, ctop, ph, pw) order, divides it by the area of its bin and adds the
 * result to every bottom_diff location inside that bin. atomicAdd is used
 * because several output elements can contribute to the same bottom_diff
 * location. The kernel only adds to bottom_diff; it does not zero it.
 *
 * \param count          number of output-gradient elements
 * \param top_diff       gradient w.r.t. the pooled output
 * \param num_rois       number of ROIs (not read by the kernel body)
 * \param spatial_scale  factor applied to the rounded ROI coordinates
 * \param channels       number of input channels
 * \param height         input feature map height
 * \param width          input feature map width
 * \param pooled_height  output height per ROI
 * \param pooled_width   output width per ROI
 * \param group_size     number of position-sensitive groups per spatial axis
 * \param output_dim     number of output channels per ROI
 * \param bottom_diff    gradient w.r.t. the input, indexed as
 *                       ((batch * channels + c) * height + h) * width + w
 * \param bottom_rois    5 values per ROI: [batch_index, x1, y1, x2, y2]
 */
template <typename DType>
__global__ void PSROIPoolBackwardAccKernel(
    const int count,
    const DType* top_diff,
    const int num_rois,
    const DType spatial_scale,
    const int channels,
    const int height, const int width,
    const int pooled_height, const int pooled_width,
    const int group_size,
    const int output_dim,
    DType* bottom_diff,
    const DType* bottom_rois) {
  CUDA_KERNEL_LOOP(index, count) {
    // The output is in order (n, ctop, ph, pw)
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int ctop = (index / pooled_width / pooled_height) % output_dim;
    int n = index / pooled_width / pooled_height / output_dim;

    const DType* offset_bottom_rois = bottom_rois + n * 5;
    int roi_batch_ind = offset_bottom_rois[0];
    if (roi_batch_ind < 0) {
      // A negative batch index cannot address any input image; skip it
      // instead of writing before the start of bottom_diff.
      // NOTE(review): the upper bound cannot be validated here because the
      // batch size is not passed to the kernel.
      continue;
    }

    // [start, end) interval for spatial sampling
    DType roi_start_w = static_cast<DType>(round(offset_bottom_rois[1])) * spatial_scale;
    DType roi_start_h = static_cast<DType>(round(offset_bottom_rois[2])) * spatial_scale;
    DType roi_end_w = static_cast<DType>(round(offset_bottom_rois[3]) + 1.) * spatial_scale;
    DType roi_end_h = static_cast<DType>(round(offset_bottom_rois[4]) + 1.) * spatial_scale;

    // Clamp the ROI extent to at least 0.1 so the bin size is never zero
    DType roi_width = max(roi_end_w - roi_start_w, 0.1);  // avoid 0
    DType roi_height = max(roi_end_h - roi_start_h, 0.1);

    // Compute w and h at bottom
    DType bin_size_h = roi_height / static_cast<DType>(pooled_height);
    DType bin_size_w = roi_width / static_cast<DType>(pooled_width);

    int hstart = floor(static_cast<DType>(ph)* bin_size_h
                       + roi_start_h);
    int wstart = floor(static_cast<DType>(pw)* bin_size_w
                       + roi_start_w);
    int hend = ceil(static_cast<DType>(ph + 1) * bin_size_h
                    + roi_start_h);
    int wend = ceil(static_cast<DType>(pw + 1) * bin_size_w
                    + roi_start_w);
    // Add roi offsets and clip to input boundaries
    hstart = min(max(hstart, 0), height);
    hend = min(max(hend, 0), height);
    wstart = min(max(wstart, 0), width);
    wend = min(max(wend, 0), width);
    bool is_empty = (hend <= hstart) || (wend <= wstart);
    // An empty bin covers no input location, so it receives no gradient
    if (is_empty) continue;

    // Compute c at bottom
    int gw = floor(static_cast<DType>(pw)* group_size / pooled_width);
    int gh = floor(static_cast<DType>(ph)* group_size / pooled_height);
    gw = min(max(gw, 0), group_size - 1);
    gh = min(max(gh, 0), group_size - 1);
    int c = (ctop*group_size + gh)*group_size + gw;

    // Compute the plane offset in 64-bit: batch * channels * height * width
    // can exceed the range of a 32-bit int for large feature maps.
    const long long plane = static_cast<long long>(roi_batch_ind) * channels + c;
    DType* offset_bottom_diff = bottom_diff + plane * height * width;
    DType bin_area = (hend - hstart)*(wend - wstart);
    DType diff_val = top_diff[index] / bin_area;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        int bottom_index = h*width + w;
        atomicAdd(offset_bottom_diff + bottom_index, diff_val);
      }
    }
  }
}
| |
| |
/*!
 * \brief Launch PSROIPoolBackwardAccKernel on the stream attached to `in_grad`.
 *
 * Channel count and spatial size are read from `in_grad` (dims 1..3); the
 * pooled size is read from `out_grad` (dims 2..3). One kernel loop iteration
 * is issued per element of `out_grad`. The kernel accumulates into `in_grad`.
 *
 * \param in_grad        gradient w.r.t. the input features (accumulated into)
 * \param out_grad       gradient w.r.t. the pooled output
 * \param bbox           ROI tensor; the kernel reads 5 values per row
 * \param spatial_scale  scale applied to ROI coordinates
 * \param output_dim_    number of output channels per ROI
 * \param group_size_    number of position-sensitive groups per spatial axis
 */
template<typename DType>
inline void PSROIPoolBackwardAcc(const Tensor<gpu, 4, DType> &in_grad,
                                 const Tensor<gpu, 4, DType> &out_grad,
                                 const Tensor<gpu, 2, DType> &bbox,
                                 const float spatial_scale,
                                 const int output_dim_,
                                 const int group_size_) {
  const int count = out_grad.shape_.Size();
  // No gradient to propagate for an empty output. Launching anyway would
  // presumably request a zero-block grid (cuda_get_num_blocks is defined
  // elsewhere), which CUDA rejects as an invalid launch configuration.
  if (count == 0) return;

  const DType *top_diff = out_grad.dptr_;
  const DType *bottom_rois = bbox.dptr_;
  DType *bottom_diff = in_grad.dptr_;
  const int num_rois = bbox.size(0);
  const int channels = in_grad.size(1);
  const int height = in_grad.size(2);
  const int width = in_grad.size(3);
  const int pooled_height = out_grad.size(2);
  const int pooled_width = out_grad.size(3);
  cudaStream_t stream = Stream<gpu>::GetStream(in_grad.stream_);
  PSROIPoolBackwardAccKernel<DType><<<mxnet::op::mxnet_op::cuda_get_num_blocks(count),
    kBaseThreadNum, 0, stream>>>(
      count, top_diff, num_rois, spatial_scale, channels, height, width,
      pooled_height, pooled_width, group_size_, output_dim_, bottom_diff, bottom_rois);
  // Surface launch-time errors without clearing the runtime's error state
  PSROIPOOLING_CUDA_CHECK(cudaPeekAtLastError());
}
| |
| } // namespace cuda |
| |
/*!
 * \brief GPU overload of PSROIPoolForward in namespace mshadow; hands all
 *        arguments unchanged to cuda::PSROIPoolForward.
 */
template<typename DType>
inline void PSROIPoolForward(const Tensor<gpu, 4, DType> &out,
                             const Tensor<gpu, 4, DType> &data,
                             const Tensor<gpu, 2, DType> &bbox,
                             const float spatial_scale,
                             const int output_dim_,
                             const int group_size_) {
  cuda::PSROIPoolForward<DType>(out,
                                data,
                                bbox,
                                spatial_scale,
                                output_dim_,
                                group_size_);
}
| |
/*!
 * \brief GPU overload of PSROIPoolBackwardAcc in namespace mshadow; hands all
 *        arguments unchanged to cuda::PSROIPoolBackwardAcc.
 */
template<typename DType>
inline void PSROIPoolBackwardAcc(const Tensor<gpu, 4, DType> &in_grad,
                                 const Tensor<gpu, 4, DType> &out_grad,
                                 const Tensor<gpu, 2, DType> &bbox,
                                 const float spatial_scale,
                                 const int output_dim_,
                                 const int group_size_) {
  cuda::PSROIPoolBackwardAcc<DType>(in_grad,
                                    out_grad,
                                    bbox,
                                    spatial_scale,
                                    output_dim_,
                                    group_size_);
}
| |
| } // namespace mshadow |
| |
| |
| namespace mxnet { |
| namespace op { |
| |
/*!
 * \brief GPU specialization of CreateOp: allocates a PSROIPoolingOp<gpu, DType>
 *        with DType chosen by MSHADOW_REAL_TYPE_SWITCH from `dtype`.
 * \param param operator parameters passed to the PSROIPoolingOp constructor
 * \param dtype type flag dispatched on by MSHADOW_REAL_TYPE_SWITCH
 * \return the newly allocated operator; stays NULL if the switch body never
 *         runs (how the macro treats other dtypes is not visible here)
 */
template<>
Operator* CreateOp<gpu>(PSROIPoolingParam param, int dtype) {
  Operator* gpu_op = NULL;
  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
    gpu_op = new PSROIPoolingOp<gpu, DType>(param);
  });
  return gpu_op;
}
| |
| } // namespace op |
| } // namespace mxnet |